├── LICENSE
├── README.md
├── assets
│   └── liger_framework.png
├── checkpoints
│   ├── liger_gla_base
│   │   └── config.json
│   ├── liger_gsa_base
│   │   └── config.json
│   ├── liger_hgrn2_base
│   │   └── config.json
│   ├── liger_mistral_gla_base
│   │   └── config.json
│   ├── liger_qwen25_gla_base
│   │   └── config.json
│   └── lolcats_base
│       └── config.json
├── configs
│   ├── liger_gla.yaml
│   ├── liger_gsa.yaml
│   ├── liger_hgrn2.yaml
│   ├── liger_mistral_gla.yaml
│   ├── liger_qwen2_gla.yaml
│   ├── lolcats_ar.yaml
│   └── lolcats_at.yaml
├── eval
│   └── harness.py
├── liger
│   ├── __init__.py
│   └── models
│       ├── __init__.py
│       ├── liger_gla
│       │   ├── __init__.py
│       │   ├── configuration_liger_gla.py
│       │   └── modeling_liger_gla.py
│       ├── liger_gsa
│       │   ├── __init__.py
│       │   ├── configuration_liger_gsa.py
│       │   └── modeling_liger_gsa.py
│       ├── liger_hgrn2
│       │   ├── __init__.py
│       │   ├── configuration_liger_hgrn2.py
│       │   └── modeling_liger_hgrn2.py
│       ├── liger_mistral_gla
│       │   ├── __init__.py
│       │   ├── configuration_liger_mistral_gla.py
│       │   └── modeling_liger_mistral_gla.py
│       └── liger_qwen2_gla
│           ├── __init__.py
│           ├── configuration_liger_qwen2_gla.py
│           └── modeling_liger_qwen2_gla.py
├── lolcats
│   ├── __init__.py
│   └── models
│       ├── __init__.py
│       └── lolcats
│           ├── __init__.py
│           ├── configuration_lolcats.py
│           └── modeling_lolcats.py
├── requirements.txt
├── run.py
├── scripts
│   ├── train_liger.sh
│   ├── train_lolcats_stage1.sh
│   └── train_lolcats_stage2.sh
└── training
    ├── __init__.py
    ├── dataloader.py
    ├── train.py
    ├── trainer.py
    └── utils.py
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Liger: Linearizing Large Language Models to Gated Recurrent Structures
2 |
3 | [arXiv](https://arxiv.org/abs/2503.01496)
4 | [Hugging Face Collection](https://huggingface.co/collections/linear-moe-hub/liger-67d904bffd7f9b77ade7747d)
5 |
6 | ## Framework
7 |
8 | <div align="center">
9 |   <img src="assets/liger_framework.png" alt="Liger Framework">
10 | </div>
11 | <div align="center">
12 | Figure 1: Liger Framework
13 | </div>
14 |
15 | ## Environment
16 |
17 | ```bash
18 | git clone --recurse-submodules https://github.com/OpenSparseLLMs/Linearization.git && cd Linearization
19 | conda create -n liger python=3.10
20 | conda activate liger
21 | pip install -r requirements.txt
22 | pip install flash-attn --no-build-isolation
23 | cd third_party/flash-linear-attention
24 | pip install -e .
25 | ```
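
As a quick optional sanity check (run from the repository root), confirm that the installed packages are importable:

```bash
python -c "import fla, liger, lolcats; print('environment ok')"
```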
26 |
27 | ## Linearization
28 |
29 | 1. Copy your pre-trained base model directory (e.g. `Meta-Llama-3-8B`) to `./checkpoints/`;
30 | 2. Replace the copied base model's `config.json` with the corresponding Liger config (see `./checkpoints/liger_gla_base/config.json` for an example);
31 | 3. Adjust the linearization settings in the matching config file under `./configs/` (e.g. `liger_gla.yaml`);
32 | 4. Run the linearization script:
33 |
34 | ```bash
35 | sh scripts/train_liger.sh
36 | ```
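
After linearization finishes, the resulting model can be loaded through the standard Hugging Face and PEFT APIs, since importing `liger` registers the Liger architectures with the `Auto*` classes. The snippet below is a minimal sketch; the base checkpoint and LoRA adapter paths are placeholders for your own output directories:

```python
import torch
import liger  # registers the Liger configs/models with the Hugging Face Auto classes
from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer

base_path = "checkpoints/liger_gla_base"        # linearized base model (placeholder path)
adapter_path = "checkpoints/your_lora_adapter"  # LoRA adapter from training (placeholder path)

tokenizer = AutoTokenizer.from_pretrained(base_path)
model = AutoModelForCausalLM.from_pretrained(base_path, torch_dtype=torch.bfloat16, device_map="auto")
model = PeftModel.from_pretrained(model, adapter_path)

inputs = tokenizer("Gated linear attention is", return_tensors="pt").to(model.device)
outputs = model.generate(**inputs, max_new_tokens=32)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```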
37 |
38 | ## Evaluation
39 |
40 | You need to install [lm-evaluation-harness](https://github.com/EleutherAI/lm-evaluation-harness) for evaluation:
41 |
42 | ```bash
43 | cd third_party/lm-evaluation-harness
44 | pip install -e .
45 | ```
46 |
47 | ```bash
48 | python -m eval.harness --model hf \
49 | --model_args pretrained=/your/Liger/checkpoints/liger_base_model,peft=/your/Liger/checkpoints/lora_adapter_path \
50 | --tasks piqa,arc_easy,arc_challenge,hellaswag,winogrande \
51 | --batch_size 64 \
52 | --device cuda \
53 | --seed 0
54 | ```
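
`eval/harness.py` also registers an `fla` wrapper around the Hugging Face backend, so the same evaluation should presumably work with `--model fla` as well; the invocation below is an assumed variant of the command above:

```bash
python -m eval.harness --model fla \
    --model_args pretrained=/your/Liger/checkpoints/liger_base_model,peft=/your/Liger/checkpoints/lora_adapter_path \
    --tasks piqa,arc_easy,arc_challenge,hellaswag,winogrande \
    --batch_size 64 \
    --device cuda \
    --seed 0
```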
55 |
56 | ## Acknowledgements
57 |
58 | We use the Triton-based linear attention kernels from [fla-org/flash-linear-attention](https://github.com/fla-org/flash-linear-attention) and refer to [HazyResearch/lolcats](https://github.com/HazyResearch/lolcats) when constructing our linearization training process. The evaluation is supported by [lm-evaluation-harness](https://github.com/EleutherAI/lm-evaluation-harness). We sincerely thank them for their contributions!
59 |
60 | ## Citation
61 |
62 | If you find this repo useful, please cite and star our work:
63 |
64 | ```bibtex
65 | @article{lan2025liger,
66 | title={Liger: Linearizing Large Language Models to Gated Recurrent Structures},
67 | author={Lan, Disen and Sun, Weigao and Hu, Jiaxi and Du, Jusen and Cheng, Yu},
68 | journal={arXiv preprint arXiv:2503.01496},
69 | year={2025}
70 | }
71 | ```
--------------------------------------------------------------------------------
/assets/liger_framework.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OpenSparseLLMs/Linearization/0e3cfae33a700fa5f644cf5752d8434c6afc2412/assets/liger_framework.png
--------------------------------------------------------------------------------
/checkpoints/liger_gla_base/config.json:
--------------------------------------------------------------------------------
1 | {
2 | "architectures": [
3 | "LigerGLAForCausalLM"
4 | ],
5 | "attention_bias": false,
6 | "attention_dropout": 0.0,
7 | "bos_token_id": 128000,
8 | "eos_token_id": 128001,
9 | "expand_k": 1,
10 | "expand_v": 1,
11 | "head_dim": 128,
12 | "hidden_act": "silu",
13 | "hidden_ratio": 4,
14 | "hidden_size": 4096,
15 | "initializer_range": 0.02,
16 | "intermediate_size": 14336,
17 | "max_position_embeddings": 8192,
18 | "mlp_bias": false,
19 | "model_type": "liger_gla",
20 | "num_attention_heads": 32,
21 | "num_hidden_layers": 32,
22 | "num_key_value_heads": 8,
23 | "pool_size": 128,
24 | "pretraining_tp": 1,
25 | "rms_norm_eps": 1e-05,
26 | "rope_scaling": null,
27 | "rope_theta": 500000.0,
28 | "tie_word_embeddings": false,
29 | "torch_dtype": "bfloat16",
30 | "transformers_version": "4.47.1",
31 | "use_cache": true,
32 | "vocab_size": 128256
33 | }
34 |
--------------------------------------------------------------------------------
/checkpoints/liger_gsa_base/config.json:
--------------------------------------------------------------------------------
1 | {
2 | "architectures": [
3 | "LigerGSAForCausalLM"
4 | ],
5 | "attention_bias": false,
6 | "attention_dropout": 0.0,
7 | "bos_token_id": 128000,
8 | "eos_token_id": 128001,
9 | "expand_k": 1,
10 | "expand_v": 1,
11 | "head_dim": 128,
12 | "hidden_act": "silu",
13 | "hidden_ratio": 4,
14 | "hidden_size": 4096,
15 | "initializer_range": 0.02,
16 | "intermediate_size": 14336,
17 | "max_position_embeddings": 8192,
18 | "mlp_bias": false,
19 | "model_type": "liger_gsa",
20 | "num_attention_heads": 32,
21 | "num_hidden_layers": 32,
22 | "num_key_value_heads": 8,
23 | "pool_size": 128,
24 | "pretraining_tp": 1,
25 | "rms_norm_eps": 1e-05,
26 | "rope_scaling": null,
27 | "rope_theta": 500000.0,
28 | "tie_word_embeddings": false,
29 | "torch_dtype": "bfloat16",
30 | "transformers_version": "4.47.1",
31 | "use_cache": true,
32 | "vocab_size": 128256
33 | }
34 |
--------------------------------------------------------------------------------
/checkpoints/liger_hgrn2_base/config.json:
--------------------------------------------------------------------------------
1 | {
2 | "architectures": [
3 | "LigerHGRN2ForCausalLM"
4 | ],
5 | "attention_bias": false,
6 | "attention_dropout": 0.0,
7 | "bos_token_id": 128000,
8 | "eos_token_id": 128001,
9 | "expand_k": 1,
10 | "expand_v": 1,
11 | "head_dim": 128,
12 | "hidden_act": "silu",
13 | "hidden_ratio": 4,
14 | "hidden_size": 4096,
15 | "initializer_range": 0.02,
16 | "intermediate_size": 14336,
17 | "max_position_embeddings": 8192,
18 | "mlp_bias": false,
19 | "model_type": "liger_hgrn2",
20 | "num_attention_heads": 32,
21 | "num_hidden_layers": 32,
22 | "num_key_value_heads": 8,
23 | "pool_size": 128,
24 | "pretraining_tp": 1,
25 | "rms_norm_eps": 1e-05,
26 | "rope_scaling": null,
27 | "rope_theta": 500000.0,
28 | "tie_word_embeddings": false,
29 | "torch_dtype": "bfloat16",
30 | "transformers_version": "4.47.1",
31 | "use_cache": true,
32 | "vocab_size": 128256
33 | }
34 |
--------------------------------------------------------------------------------
/checkpoints/liger_mistral_gla_base/config.json:
--------------------------------------------------------------------------------
1 | {
2 | "architectures": [
3 | "LigerMistralGLAForCausalLM"
4 | ],
5 | "bos_token_id": 1,
6 | "eos_token_id": 2,
7 | "hidden_act": "silu",
8 | "hidden_size": 4096,
9 | "initializer_range": 0.02,
10 | "intermediate_size": 14336,
11 | "max_position_embeddings": 32768,
12 | "model_type": "liger_mistral_gla",
13 | "num_attention_heads": 32,
14 | "num_hidden_layers": 32,
15 | "num_key_value_heads": 8,
16 | "rms_norm_eps": 1e-05,
17 | "rope_theta": 10000.0,
18 | "sliding_window": 4096,
19 | "tie_word_embeddings": false,
20 | "torch_dtype": "bfloat16",
21 | "transformers_version": "4.34.0.dev0",
22 | "use_cache": true,
23 | "vocab_size": 32000
24 | }
25 |
--------------------------------------------------------------------------------
/checkpoints/liger_qwen25_gla_base/config.json:
--------------------------------------------------------------------------------
1 | {
2 | "architectures": [
3 | "LigerQwen2GLAForCausalLM"
4 | ],
5 | "attention_dropout": 0.0,
6 | "bos_token_id": 151643,
7 | "eos_token_id": 151643,
8 | "hidden_act": "silu",
9 | "hidden_size": 3584,
10 | "initializer_range": 0.02,
11 | "intermediate_size": 18944,
12 | "max_position_embeddings": 131072,
13 | "max_window_layers": 28,
14 | "model_type": "liger_qwen2_gla",
15 | "num_attention_heads": 28,
16 | "num_hidden_layers": 28,
17 | "num_key_value_heads": 4,
18 | "rms_norm_eps": 1e-06,
19 | "rope_theta": 1000000.0,
20 | "sliding_window": 131072,
21 | "tie_word_embeddings": false,
22 | "torch_dtype": "bfloat16",
23 | "transformers_version": "4.40.1",
24 | "use_cache": true,
25 | "use_mrope": false,
26 | "use_sliding_window": false,
27 | "vocab_size": 152064
28 | }
29 |
--------------------------------------------------------------------------------
/checkpoints/lolcats_base/config.json:
--------------------------------------------------------------------------------
1 | {
2 | "architectures": [
3 | "LolcatsModelForCausalLM"
4 | ],
5 | "attention_bias": false,
6 | "attention_dropout": 0.0,
7 | "attn": null,
8 | "attn_mode": "fused_chunk",
9 | "bos_token_id": 128000,
10 | "elementwise_affine": true,
11 | "eos_token_id": 128001,
12 | "expand_k": 1,
13 | "expand_v": 1,
14 | "feature_map": "lolcats_hedgehog",
15 | "fuse_cross_entropy": true,
16 | "head_dim": 128,
17 | "hidden_act": "silu",
18 | "hidden_ratio": 4,
19 | "hidden_size": 4096,
20 | "initializer_range": 0.02,
21 | "intermediate_size": 14336,
22 | "max_position_embeddings": 8192,
23 | "mlp_bias": false,
24 | "model_type": "lolcats",
25 | "norm_eps": 1e-06,
26 | "norm_feature_map": false,
27 | "norm_k": false,
28 | "norm_q": false,
29 | "num_attention_heads": 32,
30 | "num_hidden_layers": 32,
31 | "num_key_value_heads": 8,
32 | "pretraining_tp": 1,
33 | "rms_norm_eps": 1e-05,
34 | "rope_scaling": null,
35 | "rope_theta": 500000.0,
36 | "tie_feature_map_qk": false,
37 | "tie_word_embeddings": false,
38 | "torch_dtype": "bfloat16",
39 | "transformers_version": "4.46.2",
40 | "use_cache": true,
41 | "vocab_size": 128256
42 | }
43 |
--------------------------------------------------------------------------------
/configs/liger_gla.yaml:
--------------------------------------------------------------------------------
1 | data:
2 | name: "alpaca_cleand"
3 | path: "yahma/alpaca-cleaned"
4 | batch_size: 8
5 | micro_batch_size: 1
6 | val_set_size: 200
7 | model:
8 | name: 'liger_gla'
9 | pretrained_model_name_or_path: '/your/Liger/checkpoints/liger_gla_base'
10 | max_length: 1024
11 | # tokenizer
12 | add_eos_token: False
13 |
14 | device_map: 'auto'
15 | train:
16 | optim: 'adamw_torch'
17 | lr: 0.001
18 | epochs: 2
19 | max_grad_norm: 1.0
20 | output_dir: 'checkpoints'
21 | train_qk: True
22 | train_qk_lora: True
23 | train_v: True
24 | train_v_lora: True
--------------------------------------------------------------------------------
/configs/liger_gsa.yaml:
--------------------------------------------------------------------------------
1 | data:
2 | name: "alpaca_cleand"
3 | path: "yahma/alpaca-cleaned"
4 | batch_size: 8
5 | micro_batch_size: 1
6 | val_set_size: 200
7 | model:
8 | name: 'liger_gsa'
9 | pretrained_model_name_or_path: '/your/Liger/checkpoints/liger_gsa_base'
10 | max_length: 1024
11 | # tokenizer
12 | add_eos_token: False
13 |
14 | device_map: 'auto'
15 | train:
16 | optim: 'adamw_torch'
17 | lr: 0.001
18 | epochs: 2
19 | max_grad_norm: 1.0
20 | output_dir: 'checkpoints'
21 | train_qk: True
22 | train_qk_lora: True
23 | train_v: True
24 | train_v_lora: True
--------------------------------------------------------------------------------
/configs/liger_hgrn2.yaml:
--------------------------------------------------------------------------------
1 | data:
2 | name: "alpaca_cleand"
3 | path: "yahma/alpaca-cleaned"
4 | batch_size: 8
5 | micro_batch_size: 1
6 | val_set_size: 200
7 | model:
8 | name: 'liger_hgrn2'
9 | pretrained_model_name_or_path: '/your/Liger/checkpoints/liger_hgrn2_base'
10 | max_length: 1024
11 | # tokenizer
12 | add_eos_token: False
13 |
14 | device_map: 'auto'
15 | train:
16 | optim: 'adamw_torch'
17 | lr: 0.001
18 | epochs: 2
19 | max_grad_norm: 1.0
20 | output_dir: 'checkpoints'
21 | train_qk: True
22 | train_qk_lora: True
23 | train_v: True
24 | train_v_lora: True
--------------------------------------------------------------------------------
/configs/liger_mistral_gla.yaml:
--------------------------------------------------------------------------------
1 | data:
2 | name: "alpaca_cleand"
3 | path: "yahma/alpaca-cleaned"
4 | batch_size: 8
5 | micro_batch_size: 1
6 | val_set_size: 200
7 | model:
8 | name: 'liger_mistral_gla'
9 | pretrained_model_name_or_path: '/your/checkpoints/liger_mistral_gla_base'
10 | max_length: 1024
11 | # tokenizer
12 | add_eos_token: False
13 |
14 | device_map: 'auto'
15 | train:
16 | optim: 'adamw_torch'
17 | lr: 0.001
18 | epochs: 2
19 | max_grad_norm: 1.0
20 | output_dir: 'checkpoints'
21 | train_qk: True
22 | train_qk_lora: True
23 | train_v: True
24 | train_v_lora: True
--------------------------------------------------------------------------------
/configs/liger_qwen2_gla.yaml:
--------------------------------------------------------------------------------
1 | data:
2 | name: "alpaca_cleand"
3 | path: "yahma/alpaca-cleaned"
4 | batch_size: 8
5 | micro_batch_size: 1
6 | val_set_size: 200
7 | model:
8 | name: 'liger_qwen2_gla'
9 | pretrained_model_name_or_path: '/your/Liger/checkpoints/liger_qwen25_gla_base'
10 | max_length: 1024
11 | # tokenizer
12 | add_eos_token: False
13 |
14 | device_map: 'auto'
15 | train:
16 | optim: 'adamw_torch'
17 | lr: 0.001
18 | epochs: 2
19 | max_grad_norm: 1.0
20 | output_dir: 'checkpoints'
21 | train_qk: True
22 | train_qk_lora: True
23 | train_v: True
24 | train_v_lora: True
--------------------------------------------------------------------------------
/configs/lolcats_ar.yaml:
--------------------------------------------------------------------------------
1 | data:
2 | name: "alpaca_cleand"
3 | path: "yahma/alpaca-cleaned"
4 | batch_size: 8
5 | micro_batch_size: 1
6 | val_set_size: 200
7 | model:
8 | name: 'lolcats_ar'
9 | pretrained_model_name_or_path: '/your/Liger/checkpoints/lolcats_at'
10 | max_length: 1024
11 | # tokenizer
12 | add_eos_token: False
13 |
14 | device_map: 'auto'
15 | train:
16 | optim: 'adamw_torch'
17 | lr: 0.0001
18 | epochs: 2
19 | max_grad_norm: 1.0
20 | output_dir: 'checkpoints'
21 | train_qk: True
22 | train_qk_lora: True
23 | train_v: True
24 | train_v_lora: True
--------------------------------------------------------------------------------
/configs/lolcats_at.yaml:
--------------------------------------------------------------------------------
1 | data:
2 | name: "alpaca_cleand"
3 | path: "yahma/alpaca-cleaned"
4 | batch_size: 8
5 | micro_batch_size: 1
6 | val_set_size: 200
7 | model:
8 | name: 'lolcats_at'
9 | pretrained_model_name_or_path: '/your/Liger/checkpoints/lolcats_base'
10 | max_length: 1024
11 | # tokenizer
12 | add_eos_token: False
13 |
14 | device_map: 'auto'
15 | train:
16 | optim: 'adamw_torch'
17 | lr: 0.01
18 | epochs: 2
19 | max_grad_norm: 1.0
20 | output_dir: 'checkpoints'
21 | train_qk: False
22 | train_qk_lora: False
23 | train_v: False
24 | train_v_lora: False
--------------------------------------------------------------------------------
/eval/harness.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | from __future__ import annotations
4 |
5 | import fla  # noqa: F401, imported for its side effects (registers FLA models with the HF Auto classes)
6 | import liger  # noqa: F401, registers the Liger model classes with the HF Auto classes
7 | import lolcats  # noqa: F401, registers the LoLCATs model classes with the HF Auto classes
8 | from lm_eval.__main__ import cli_evaluate
9 | from lm_eval.api.registry import register_model
10 | from lm_eval.models.huggingface import HFLM
11 |
12 |
13 | @register_model('fla')
14 | class FlashLinearAttentionLMWrapper(HFLM):
15 |     def __init__(self, **kwargs) -> None:
16 |
17 | # TODO: provide options for doing inference with different kernels
18 |
19 | super().__init__(**kwargs)
20 |
21 |
22 | if __name__ == "__main__":
23 | cli_evaluate()
24 |
--------------------------------------------------------------------------------
/liger/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | from liger.models.liger_gla import LigerGLAConfig, LigerGLAForCausalLM, LigerGLAModel
4 | from liger.models.liger_gsa import LigerGSAConfig, LigerGSAForCausalLM, LigerGSAModel
5 | from liger.models.liger_hgrn2 import LigerHGRN2Config, LigerHGRN2ForCausalLM, LigerHGRN2Model
6 | from liger.models.liger_mistral_gla import LigerMistralGLAConfig, LigerMistralGLAForCausalLM, LigerMistralGLAModel
7 | from liger.models.liger_qwen2_gla import LigerQwen2GLAConfig, LigerQwen2GLAForCausalLM, LigerQwen2GLAModel
8 |
9 | __all__ = [
10 | 'LigerGLAConfig', 'LigerGLAForCausalLM', 'LigerGLAModel',
11 | 'LigerGSAConfig', 'LigerGSAForCausalLM', 'LigerGSAModel',
12 | 'LigerHGRN2Config', 'LigerHGRN2ForCausalLM', 'LigerHGRN2Model',
13 | 'LigerMistralGLAConfig', 'LigerMistralGLAForCausalLM', 'LigerMistralGLAModel',
14 | 'LigerQwen2GLAConfig', 'LigerQwen2GLAForCausalLM', 'LigerQwen2GLAModel',
15 | ]
--------------------------------------------------------------------------------
/liger/models/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | from liger.models.liger_gla import LigerGLAConfig, LigerGLAForCausalLM, LigerGLAModel
4 | from liger.models.liger_gsa import LigerGSAConfig, LigerGSAForCausalLM, LigerGSAModel
5 | from liger.models.liger_hgrn2 import LigerHGRN2Config, LigerHGRN2ForCausalLM, LigerHGRN2Model
6 | from liger.models.liger_mistral_gla import LigerMistralGLAConfig, LigerMistralGLAForCausalLM, LigerMistralGLAModel
7 | from liger.models.liger_qwen2_gla import LigerQwen2GLAConfig, LigerQwen2GLAForCausalLM, LigerQwen2GLAModel
8 |
9 | __all__ = [
10 | 'LigerGLAConfig', 'LigerGLAForCausalLM', 'LigerGLAModel',
11 | 'LigerGSAConfig', 'LigerGSAForCausalLM', 'LigerGSAModel',
12 | 'LigerHGRN2Config', 'LigerHGRN2ForCausalLM', 'LigerHGRN2Model',
13 | 'LigerMistralGLAConfig', 'LigerMistralGLAForCausalLM', 'LigerMistralGLAModel',
14 | 'LigerQwen2GLAConfig', 'LigerQwen2GLAForCausalLM', 'LigerQwen2GLAModel',
15 | ]
--------------------------------------------------------------------------------
/liger/models/liger_gla/__init__.py:
--------------------------------------------------------------------------------
1 | from transformers import AutoConfig, AutoModel, AutoModelForCausalLM
2 |
3 | from liger.models.liger_gla.configuration_liger_gla import LigerGLAConfig
4 | from liger.models.liger_gla.modeling_liger_gla import LigerGLAForCausalLM, LigerGLAModel
5 |
6 | AutoConfig.register(LigerGLAConfig.model_type, LigerGLAConfig)
7 | AutoModel.register(LigerGLAConfig, LigerGLAModel)
8 | AutoModelForCausalLM.register(LigerGLAConfig, LigerGLAForCausalLM)
9 |
10 |
11 | __all__ = ['LigerGLAConfig', 'LigerGLAForCausalLM', 'LigerGLAModel']
--------------------------------------------------------------------------------
/liger/models/liger_gla/configuration_liger_gla.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | from typing import Dict, Optional
4 |
5 | from transformers.configuration_utils import PretrainedConfig
6 | from transformers.models.llama.configuration_llama import LlamaConfig
7 |
8 | class LigerGLAConfig(LlamaConfig, PretrainedConfig):
9 | model_type = 'liger_gla'
10 | keys_to_ignore_at_inference = ['past_key_values']
11 |
12 | def __init__(
13 | self,
14 | # llama config
15 | vocab_size=32000,
16 | hidden_size=4096,
17 | intermediate_size=11008,
18 | num_hidden_layers=32,
19 | num_attention_heads=32,
20 | num_key_value_heads=None,
21 | hidden_act="silu",
22 | max_position_embeddings=2048,
23 | initializer_range=0.02,
24 | rms_norm_eps=1e-6,
25 | use_cache=True,
26 | pad_token_id=None,
27 | bos_token_id=1,
28 | eos_token_id=2,
29 | pretraining_tp=1,
30 | tie_word_embeddings=False,
31 | rope_theta=10000.0,
32 | rope_scaling=None,
33 | attention_bias=False,
34 | attention_dropout=0.0,
35 | mlp_bias=False,
36 | head_dim=None,
37 | **kwargs,
38 | ):
39 | super().__init__(
40 | vocab_size=vocab_size,
41 | hidden_size=hidden_size,
42 | intermediate_size=intermediate_size,
43 | num_hidden_layers=num_hidden_layers,
44 | num_attention_heads=num_attention_heads,
45 | num_key_value_heads=num_key_value_heads,
46 | hidden_act=hidden_act,
47 | max_position_embeddings=max_position_embeddings,
48 | initializer_range=initializer_range,
49 | rms_norm_eps=rms_norm_eps,
50 | use_cache=use_cache,
51 | pad_token_id=pad_token_id,
52 | bos_token_id=bos_token_id,
53 | eos_token_id=eos_token_id,
54 | pretraining_tp=pretraining_tp,
55 | tie_word_embeddings=tie_word_embeddings,
56 | rope_theta=rope_theta,
57 | rope_scaling=rope_scaling,
58 | attention_bias=attention_bias,
59 | attention_dropout=attention_dropout,
60 | mlp_bias=mlp_bias,
61 | head_dim=head_dim,
62 | **kwargs,
63 | )
--------------------------------------------------------------------------------
/liger/models/liger_gla/modeling_liger_gla.py:
--------------------------------------------------------------------------------
1 | import math
2 | import warnings
3 | import copy
4 | from typing import List, Optional, Tuple, Union
5 | from einops import rearrange, repeat
6 |
7 | import torch
8 | import torch.nn as nn
9 | import torch.nn.functional as F
10 | import torch.utils.checkpoint
11 | from transformers.cache_utils import Cache, DynamicCache, StaticCache
12 | from transformers.generation import GenerationMixin
13 | from transformers.modeling_outputs import (
14 | BaseModelOutputWithPast,
15 | CausalLMOutputWithPast,
16 | )
17 | from transformers.models.llama.modeling_llama import (
18 | LlamaRMSNorm,
19 | LlamaRotaryEmbedding,
20 | repeat_kv,
21 | apply_rotary_pos_emb,
22 | LlamaMLP,
23 | LlamaAttention,
24 | LlamaFlashAttention2,
25 | LlamaSdpaAttention,
26 | LlamaDecoderLayer,
27 | LlamaForCausalLM,
28 | LlamaModel,
29 | LlamaPreTrainedModel,
30 | LLAMA_INPUTS_DOCSTRING,
31 | )
32 |
33 | from transformers.utils import logging, add_start_docstrings_to_model_forward
34 | from transformers.utils import is_flash_attn_2_available
35 |
36 | if is_flash_attn_2_available():
37 | from transformers.modeling_flash_attention_utils import _flash_attention_forward
38 | else:
39 | print("flash_attn_2 is not available")
40 |
41 | from fla.models.utils import Cache as FlaCache
42 | from fla.ops.gla import fused_chunk_gla, fused_recurrent_gla
43 |
44 | from .configuration_liger_gla import LigerGLAConfig
45 |
46 | logger = logging.get_logger(__name__)
47 |
48 | class LigerGatedLinearAttention(nn.Module):
49 | def __init__(
50 | self,
51 | config: LigerGLAConfig,
52 | layer_idx: Optional[int] = None,
53 | ):
54 | super().__init__()
55 | self.config = config
56 | self.layer_idx = layer_idx
57 |
58 | self.attention_dropout = config.attention_dropout
59 | self.hidden_size = config.hidden_size
60 | self.num_heads = config.num_attention_heads
61 | self.head_dim = getattr(config, "head_dim", self.hidden_size // self.num_heads)
62 | self.num_key_value_heads = config.num_key_value_heads
63 | self.num_key_value_groups = self.num_heads // self.num_key_value_heads
64 |
65 | self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
66 | self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
67 | self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
68 | self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
69 |
70 |
71 | self.rotary_emb = LlamaRotaryEmbedding(config=self.config)
72 | self.pool_g = nn.AdaptiveAvgPool1d(output_size=self.head_dim * self.num_key_value_heads)
73 |
74 | def forward(
75 | self,
76 | hidden_states: torch.Tensor,
77 | attention_mask: Optional[torch.Tensor] = None,
78 | position_ids: Optional[torch.LongTensor] = None,
79 | past_key_value: Optional[FlaCache] = None,
80 | output_attentions: bool = False,
81 | use_cache: bool = False,
82 | cache_position: Optional[torch.LongTensor] = None,
83 | position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
84 | **kwargs,
85 | ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
86 | last_state = None
87 | if past_key_value is not None and len(past_key_value) > self.layer_idx:
88 | last_state = past_key_value[self.layer_idx]
89 |
90 | q = self.q_proj(hidden_states)
91 | k = self.k_proj(hidden_states)
92 | v = self.v_proj(hidden_states)
93 | g = self.pool_g(k)
94 |
95 | # dealing with left-padding
96 | if attention_mask is not None:
97 | v = v.mul_(attention_mask[:, -v.shape[-2]:, None])
98 |
99 | q = rearrange(q, 'b n (h d) -> b h n d', h=self.num_heads)
100 | k = rearrange(k, 'b n (h d) -> b h n d', h=self.num_key_value_heads)
101 | v = rearrange(v, 'b n (h d) -> b h n d', h=self.num_key_value_heads)
102 | g = rearrange(g, 'b n (h m) -> b h n m', h=self.num_key_value_heads)
103 |
104 | k = repeat_kv(k, self.num_key_value_groups)
105 | v = repeat_kv(v, self.num_key_value_groups)
106 | g = repeat_kv(g, self.num_key_value_groups)
107 |
108 | sq, sk, sv = q, k, v
109 |
110 | # norm
111 | q = F.softmax(q, dim=-1)
112 | k = F.softmax(k, dim=-1)
113 |
114 | gate_logit_normalizer = 16
115 | g = F.logsigmoid(g) / gate_logit_normalizer # (b, h, n, m)
116 |
117 | recurrent_state = last_state['recurrent_state'] if last_state is not None else None
118 | offsets = kwargs.get('offsets', None)
119 | scale = 1
120 | q, k, v, g = (x.to(torch.float32).contiguous() for x in (q, k, v, g))
121 |
122 | if self.training or q.shape[-2] > 1:
123 | o_, recurrent_state = fused_chunk_gla(q, k, v, g, scale=scale, initial_state=recurrent_state, output_final_state=True)
124 | else:
125 | o_, recurrent_state = fused_recurrent_gla(q, k, v, g, scale=scale, initial_state=recurrent_state, output_final_state=True)
126 |
127 | if past_key_value is not None:
128 | past_key_value.update(
129 | recurrent_state=recurrent_state,
130 | layer_idx=self.layer_idx,
131 | offset=q.shape[1]
132 | )
133 |
134 | q_len = hidden_states.size(-2)
135 |
136 | if position_embeddings is None:
137 | logger.warning_once(
138 | "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
139 | "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
140 | "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be "
141 | "removed and `position_embeddings` will be mandatory."
142 | )
143 | cos, sin = self.rotary_emb(sv, position_ids)
144 | else:
145 | cos, sin = position_embeddings
146 | sq, sk = apply_rotary_pos_emb(sq, sk, cos, sin)
147 |
148 | input_dtype = sq.dtype
149 | if input_dtype == torch.float32:
150 | if torch.is_autocast_enabled():
151 | target_dtype = torch.get_autocast_gpu_dtype()
152 | # Handle the case where the model is quantized
153 | elif hasattr(self.config, "_pre_quantization_dtype"):
154 | target_dtype = self.config._pre_quantization_dtype
155 | else:
156 | target_dtype = self.q_proj.weight.dtype
157 |
158 | logger.warning_once(
159 | f"The input hidden states seems to be silently casted in float32, this might be related to"
160 | f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
161 | f" {target_dtype}."
162 | )
163 |
164 | sq = sq.to(target_dtype)
165 | sk = sk.to(target_dtype)
166 | sv = sv.to(target_dtype)
167 |
168 | window_size = 64
169 |         # Pass the attention mask to flash attention only when it actually contains padding (zeros);
170 |         # otherwise drop it so the dense causal fast path is taken.
171 |         if attention_mask is None or 0.0 not in attention_mask:
172 |             attention_mask = None
173 |
174 |         y = _flash_attention_forward(  # reshape q/k/v to the layout expected by Flash Attention
175 | sq.transpose(1, 2),
176 | sk.transpose(1, 2),
177 | sv.transpose(1, 2),
178 | attention_mask,
179 | q_len,
180 | position_ids=position_ids,
181 | dropout=0.0,
182 | sliding_window=window_size,
183 | use_top_left_mask=False,
184 | is_causal=True,
185 | target_dtype=torch.float32,
186 | ).transpose(1, 2)
187 |         o_ = 0.5 * y + 0.5 * o_  # blend sliding-window softmax attention with the gated linear attention output
188 | o = rearrange(o_.bfloat16(), 'b h n d -> b n (h d)')
189 | o = self.o_proj(o)
190 |
191 | return o, None, past_key_value
192 |
193 | class LigerGLADecoderLayer(LlamaDecoderLayer):
194 | def __init__(self, config: LigerGLAConfig, layer_idx: int):
195 | super().__init__(config, layer_idx)
196 | self.hidden_size = config.hidden_size
197 | self.self_attn = LigerGatedLinearAttention(config=config, layer_idx=layer_idx)
198 | self.mlp = LlamaMLP(config)
199 | self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
200 | self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
201 |
202 | class LigerGLAPreTrainedModel(LlamaPreTrainedModel):
203 |
204 | config_class = LigerGLAConfig
205 | base_model_prefix = "model"
206 | supports_gradient_checkpointing = True
207 | _no_split_modules = ['LigerGLADecoderLayer']
208 | _skip_keys_device_placement = "past_key_values"
209 |
210 | class LigerGLAModel(LlamaModel, LigerGLAPreTrainedModel):
211 |
212 | def __init__(self, config: LigerGLAConfig):
213 | super().__init__(config)
214 | self.padding_idx = config.pad_token_id
215 | self.vocab_size = config.vocab_size
216 |
217 | self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
218 | self.layers = nn.ModuleList(
219 | [LigerGLADecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
220 | )
221 | self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
222 | self.rotary_emb = LlamaRotaryEmbedding(config=config)
223 | self.gradient_checkpointing = False
224 |
225 | # Initialize weights and apply final processing
226 | self.post_init()
227 |
228 | @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
229 | def forward(
230 | self,
231 | input_ids: torch.LongTensor = None,
232 | attention_mask: Optional[torch.Tensor] = None,
233 | position_ids: Optional[torch.LongTensor] = None,
234 | past_key_values: Optional[Union[Tuple, FlaCache, List[torch.FloatTensor]]] = None,
235 | inputs_embeds: Optional[torch.FloatTensor] = None,
236 | use_cache: Optional[bool] = None,
237 | output_attentions: Optional[bool] = None,
238 | output_hidden_states: Optional[bool] = None,
239 | return_dict: Optional[bool] = None,
240 | cache_position: Optional[torch.LongTensor] = None,
241 | **kwargs,
242 | ) -> Union[Tuple, BaseModelOutputWithPast]:
243 | output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
244 | output_hidden_states = (
245 | output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
246 | )
247 | use_cache = use_cache if use_cache is not None else self.config.use_cache
248 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict
249 |
250 | if (input_ids is None) ^ (inputs_embeds is not None):
251 | raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
252 |
253 | if self.gradient_checkpointing and self.training and use_cache:
254 | logger.warning_once(
255 | "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
256 | )
257 | use_cache = False
258 |
259 | if inputs_embeds is None:
260 | inputs_embeds = self.embed_tokens(input_ids)
261 |
262 |
263 | # kept for BC (non `Cache` `past_key_values` inputs)
264 | return_legacy_cache = False
265 | if use_cache and not isinstance(past_key_values, FlaCache):
266 | past_key_values = FlaCache.from_legacy_cache(past_key_values)
267 |
268 | if cache_position is None:
269 | past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
270 | cache_position = torch.arange(
271 | past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
272 | )
273 |
274 | if position_ids is None:
275 | position_ids = cache_position.unsqueeze(0)
276 |
277 | causal_mask = attention_mask
278 |
279 | hidden_states = inputs_embeds
280 |
281 | # create position embeddings to be shared across the decoder layers
282 | position_embeddings = self.rotary_emb(hidden_states, position_ids)
283 |
284 | # decoder layers
285 | all_hidden_states = () if output_hidden_states else None
286 | all_self_attns = () if output_attentions else None
287 | next_decoder_cache = None
288 |
289 |         # defined unconditionally so the check inside the loop below never raises a NameError
290 |         all_softmax_hidden_states = () if output_attentions else None
291 |
292 | for decoder_layer in self.layers:
293 | if output_hidden_states:
294 | all_hidden_states += (hidden_states,)
295 | if all_softmax_hidden_states is not None:
296 | all_softmax_hidden_states += (hidden_states,)
297 |
298 | if self.gradient_checkpointing and self.training:
299 | layer_outputs = self._gradient_checkpointing_func(
300 | decoder_layer.__call__,
301 | hidden_states,
302 | causal_mask,
303 | position_ids,
304 | past_key_values,
305 | output_attentions,
306 | use_cache,
307 | cache_position,
308 | position_embeddings,
309 | )
310 |
311 | else:
312 | layer_outputs = decoder_layer(
313 | hidden_states,
314 | attention_mask=causal_mask,
315 | position_ids=position_ids,
316 | past_key_value=past_key_values,
317 | output_attentions=output_attentions,
318 | use_cache=use_cache,
319 | cache_position=cache_position,
320 | position_embeddings=position_embeddings,
321 | )
322 | hidden_states = layer_outputs[0]
323 |
324 | if use_cache:
325 | next_decoder_cache = layer_outputs[2 if output_attentions else 1]
326 |
327 | if output_attentions:
328 | all_self_attns += (layer_outputs[1],)
329 |
330 | hidden_states = self.norm(hidden_states)
331 |
332 | # add hidden states from the last decoder layer
333 | if output_hidden_states:
334 | all_hidden_states += (hidden_states,)
335 |
336 | next_cache = next_decoder_cache if use_cache else None
337 |
338 | if not return_dict:
339 | return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
340 | return BaseModelOutputWithPast(
341 | last_hidden_state=hidden_states,
342 | past_key_values=next_cache,
343 | hidden_states=all_hidden_states,
344 | attentions=all_self_attns,
345 | )
346 |
347 | class LigerGLAForCausalLM(LlamaForCausalLM, LigerGLAPreTrainedModel, GenerationMixin):
348 | _tied_weights_keys = ["lm_head.weight"]
349 |
350 | def __init__(self, config):
351 | super().__init__(config)
352 | self.model = LigerGLAModel(config)
353 | self.vocab_size = config.vocab_size
354 | self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
355 |
356 | # Initialize weights and apply final processing
357 | self.post_init()
--------------------------------------------------------------------------------
/liger/models/liger_gsa/__init__.py:
--------------------------------------------------------------------------------
1 | from transformers import AutoConfig, AutoModel, AutoModelForCausalLM
2 |
3 | from liger.models.liger_gsa.configuration_liger_gsa import LigerGSAConfig
4 | from liger.models.liger_gsa.modeling_liger_gsa import LigerGSAForCausalLM, LigerGSAModel
5 |
6 | AutoConfig.register(LigerGSAConfig.model_type, LigerGSAConfig)
7 | AutoModel.register(LigerGSAConfig, LigerGSAModel)
8 | AutoModelForCausalLM.register(LigerGSAConfig, LigerGSAForCausalLM)
9 |
10 |
11 | __all__ = ['LigerGSAConfig', 'LigerGSAForCausalLM', 'LigerGSAModel']
--------------------------------------------------------------------------------
/liger/models/liger_gsa/configuration_liger_gsa.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | from typing import Dict, Optional
4 |
5 | from transformers.configuration_utils import PretrainedConfig
6 | from transformers.models.llama.configuration_llama import LlamaConfig
7 |
8 | class LigerGSAConfig(LlamaConfig, PretrainedConfig):
9 | model_type = 'liger_gsa'
10 | keys_to_ignore_at_inference = ['past_key_values']
11 |
12 | def __init__(
13 | self,
14 | # llama config
15 | vocab_size=32000,
16 | hidden_size=4096,
17 | intermediate_size=11008,
18 | num_hidden_layers=32,
19 | num_attention_heads=32,
20 | num_key_value_heads=None,
21 | hidden_act="silu",
22 | max_position_embeddings=2048,
23 | initializer_range=0.02,
24 | rms_norm_eps=1e-6,
25 | use_cache=True,
26 | pad_token_id=None,
27 | bos_token_id=1,
28 | eos_token_id=2,
29 | pretraining_tp=1,
30 | tie_word_embeddings=False,
31 | rope_theta=10000.0,
32 | rope_scaling=None,
33 | attention_bias=False,
34 | attention_dropout=0.0,
35 | mlp_bias=False,
36 | head_dim=None,
37 | pool_size: int = 64, # pooling
38 | **kwargs,
39 | ):
40 | self.pool_size = pool_size
41 |
42 | super().__init__(
43 | vocab_size=vocab_size,
44 | hidden_size=hidden_size,
45 | intermediate_size=intermediate_size,
46 | num_hidden_layers=num_hidden_layers,
47 | num_attention_heads=num_attention_heads,
48 | num_key_value_heads=num_key_value_heads,
49 | hidden_act=hidden_act,
50 | max_position_embeddings=max_position_embeddings,
51 | initializer_range=initializer_range,
52 | rms_norm_eps=rms_norm_eps,
53 | use_cache=use_cache,
54 | pad_token_id=pad_token_id,
55 | bos_token_id=bos_token_id,
56 | eos_token_id=eos_token_id,
57 | pretraining_tp=pretraining_tp,
58 | tie_word_embeddings=tie_word_embeddings,
59 | rope_theta=rope_theta,
60 | rope_scaling=rope_scaling,
61 | attention_bias=attention_bias,
62 | attention_dropout=attention_dropout,
63 | mlp_bias=mlp_bias,
64 | head_dim=head_dim,
65 | **kwargs,
66 | )
--------------------------------------------------------------------------------
/liger/models/liger_gsa/modeling_liger_gsa.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | import math
4 | import warnings
5 | from typing import List, Optional, Tuple, Union
6 | from einops import rearrange, repeat
7 |
8 | import torch
9 | import torch.nn as nn
10 | import torch.nn.functional as F
11 | import torch.utils.checkpoint
12 | from transformers.activations import ACT2FN
13 | from transformers.cache_utils import Cache, DynamicCache, StaticCache
14 | from transformers.modeling_outputs import (
15 | BaseModelOutputWithPast,
16 | CausalLMOutputWithPast,
17 | )
18 | from transformers.models.llama.modeling_llama import (
19 | LlamaRMSNorm,
20 | LlamaRotaryEmbedding,
21 | apply_rotary_pos_emb,
22 | repeat_kv,
23 | LlamaMLP,
24 | LlamaAttention, LlamaFlashAttention2, LlamaSdpaAttention,
25 | LlamaDecoderLayer,
26 | LlamaForCausalLM,
27 | LlamaModel,
28 | LlamaPreTrainedModel,
29 | LLAMA_INPUTS_DOCSTRING,
30 | )
31 | from transformers.utils import logging, add_start_docstrings_to_model_forward
32 | from transformers.utils import is_flash_attn_2_available, is_flash_attn_greater_or_equal_2_10
33 |
34 | if is_flash_attn_2_available():
35 | from transformers.modeling_flash_attention_utils import _flash_attention_forward
36 |
37 | from fla.models.utils import Cache as FlaCache
38 | from fla.modules.activations import swish
39 | from fla.ops.gsa import chunk_gsa, fused_recurrent_gsa
40 |
41 | from .configuration_liger_gsa import LigerGSAConfig
42 |
43 | logger = logging.get_logger(__name__)
44 |
45 |
46 | class LigerGatedSlotAttention(nn.Module):
47 | def __init__(
48 | self,
49 | config: LigerGSAConfig,
50 | layer_idx: Optional[int] = None,
51 | ):
52 | super().__init__()
53 | self.config = config
54 | self.layer_idx = layer_idx
55 |
56 | self.attention_dropout = config.attention_dropout
57 | self.hidden_size = config.hidden_size
58 | self.num_heads = config.num_attention_heads
59 | self.head_dim = getattr(config, "head_dim", self.hidden_size // self.num_heads)
60 | self.num_key_value_heads = config.num_key_value_heads
61 | self.num_key_value_groups = self.num_heads // self.num_key_value_heads
62 |
63 | self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
64 | self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
65 | self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
66 | self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
67 |
68 | self.rotary_emb = LlamaRotaryEmbedding(config=self.config)
69 |
70 | self.pool_size = config.pool_size
71 | self.pool_g = nn.AdaptiveAvgPool1d(output_size=self.pool_size * self.num_key_value_heads)
72 |
73 | def forward(
74 | self,
75 | hidden_states: torch.Tensor,
76 | attention_mask: Optional[torch.Tensor] = None,
77 | position_ids: Optional[torch.LongTensor] = None,
78 | past_key_value: Optional[FlaCache] = None,
79 | output_attentions: bool = False,
80 | use_cache: bool = False,
81 | cache_position: Optional[torch.LongTensor] = None,
82 | position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
83 | **kwargs,
84 | ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
85 | last_state = None
86 | if past_key_value is not None and len(past_key_value) > self.layer_idx:
87 | last_state = past_key_value[self.layer_idx]
88 |
89 | q = self.q_proj(hidden_states)
90 | k = self.k_proj(hidden_states)
91 | v = self.v_proj(hidden_states)
92 | g = self.pool_g(k)
93 |
94 | q = rearrange(q, 'b n (h d) -> b h n d', h=self.num_heads)
95 | k = rearrange(k, 'b n (h d) -> b h n d', h=self.num_key_value_heads)
96 | v = rearrange(v, 'b n (h d) -> b h n d', h=self.num_key_value_heads)
97 | g = rearrange(g, 'b n (h m) -> b h n m', h=self.num_key_value_heads)
98 |
99 | k = repeat_kv(k, self.num_key_value_groups)
100 | v = repeat_kv(v, self.num_key_value_groups)
101 | g = repeat_kv(g, self.num_key_value_groups)
102 |
103 | sq, sk, sv = q, k, v
104 |
105 | gate_logit_normalizer = 16
106 | g = F.logsigmoid(g) / gate_logit_normalizer # (b, h, n, m)
107 | s = 1 - torch.exp(g)
108 | # dealing with left-padding
109 | if attention_mask is not None:
110 | s = s.mul_(attention_mask[:, None, -s.shape[2]:, None])
111 | v = v.mul_(attention_mask[:, None, -v.shape[2]:, None])
112 |
113 | recurrent_state = last_state['recurrent_state'] if last_state is not None else None
114 | scale = 1
115 |
116 | q, k, v, s, g = (x.to(torch.float32).contiguous() for x in (q, k, v, s, g))
117 |
118 | if self.training or q.shape[-2] > 1:
119 | o_, recurrent_state = chunk_gsa(q, k, v, s, g, scale=scale, initial_state=recurrent_state, output_final_state=True)
120 | else:
121 | o_, recurrent_state = fused_recurrent_gsa(q, k, v, s, g, scale=scale, initial_state=recurrent_state, output_final_state=True)
122 |
123 | if past_key_value is not None:
124 | past_key_value.update(
125 | recurrent_state=recurrent_state,
126 | layer_idx=self.layer_idx,
127 | offset=q.shape[1]
128 | )
129 |
130 | q_len = hidden_states.size(-2)
131 |
132 | if position_embeddings is None:
133 | logger.warning_once(
134 | "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
135 | "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
136 | "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be "
137 | "removed and `position_embeddings` will be mandatory."
138 | )
139 | cos, sin = self.rotary_emb(sv, position_ids)
140 | else:
141 | cos, sin = position_embeddings
142 | sq, sk = apply_rotary_pos_emb(sq, sk, cos, sin)
143 |
144 |
145 | # In PEFT, usually we cast the layer norms in float32 for training stability reasons
146 | # therefore the input hidden states gets silently casted in float32. Hence, we need
147 | # cast them back in the correct dtype just to be sure everything works as expected.
148 | # This might slowdown training & inference so it is recommended to not cast the LayerNorms
149 | # in fp32. (LlamaRMSNorm handles it correctly)
150 |
151 | input_dtype = sq.dtype
152 | if input_dtype == torch.float32:
153 | if torch.is_autocast_enabled():
154 | target_dtype = torch.get_autocast_gpu_dtype()
155 | # Handle the case where the model is quantized
156 | elif hasattr(self.config, "_pre_quantization_dtype"):
157 | target_dtype = self.config._pre_quantization_dtype
158 | else:
159 | target_dtype = self.q_proj.weight.dtype
160 |
161 | logger.warning_once(
162 | f"The input hidden states seems to be silently casted in float32, this might be related to"
163 | f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
164 | f" {target_dtype}."
165 | )
166 |
167 | sq = sq.to(target_dtype)
168 | sk = sk.to(target_dtype)
169 | sv = sv.to(target_dtype)
170 |
171 | window_size = 64
172 |         y = _flash_attention_forward(  # reshape to the layout expected by Flash Attention
173 | sq.transpose(1, 2),
174 | sk.transpose(1, 2),
175 | sv.transpose(1, 2),
176 | attention_mask,
177 | q_len,
178 | position_ids=position_ids,
179 | dropout=0.0,
180 | sliding_window=window_size,
181 | use_top_left_mask=not is_flash_attn_greater_or_equal_2_10(),
182 | is_causal=True,
183 | target_dtype=torch.float32,
184 | **kwargs,
185 | ).transpose(1, 2)
186 |
187 |         o_ = 0.5 * y + 0.5 * o_  # equal-weight fusion of the sliding-window branch (y) and the GSA branch (o_)
188 | o = rearrange(o_.bfloat16(), 'b h n d -> b n (h d)')
189 | o = self.o_proj(o)
190 |
191 | return o, None, past_key_value
192 |
193 |
194 |
195 | class LigerGSADecoderLayer(LlamaDecoderLayer):
196 | def __init__(self, config: LigerGSAConfig, layer_idx: int):
197 |         super().__init__(config, layer_idx)
198 | self.hidden_size = config.hidden_size
199 | self.self_attn = LigerGatedSlotAttention(config=config, layer_idx=layer_idx)
200 | self.mlp = LlamaMLP(config)
201 | self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
202 | self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
203 |
204 | def forward(
205 | self,
206 | hidden_states: torch.Tensor,
207 | attention_mask: Optional[torch.Tensor] = None,
208 | position_ids: Optional[torch.LongTensor] = None,
209 | past_key_value: Optional[Union[FlaCache, Tuple]] = None,
210 | output_attentions: Optional[bool] = False,
211 | use_cache: Optional[bool] = False,
212 | cache_position: Optional[torch.LongTensor] = None,
213 | position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
214 | **kwargs,
215 | ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
216 | outputs = super().forward(
217 | hidden_states,
218 | attention_mask,
219 | position_ids,
220 | past_key_value,
221 | output_attentions,
222 | use_cache,
223 | cache_position,
224 | position_embeddings,
225 | **kwargs
226 | )
227 | return outputs
228 |
229 |
230 | class LigerGSAPreTrainedModel(LlamaPreTrainedModel):
231 |
232 | config_class = LigerGSAConfig
233 | base_model_prefix = "model"
234 | supports_gradient_checkpointing = True
235 | _no_split_modules = ['LigerGSADecoderLayer']
236 | _skip_keys_device_placement = "past_key_values"
237 |
238 | def _init_weights(
239 | self,
240 | module,
241 | ):
242 | std = self.config.initializer_range
243 | if isinstance(module, nn.Linear):
244 | module.weight.data.normal_(mean=0.0, std=std)
245 | if module.bias is not None:
246 | module.bias.data.zero_()
247 | elif isinstance(module, nn.Embedding):
248 | module.weight.data.normal_(mean=0.0, std=std)
249 | if module.padding_idx is not None:
250 | module.weight.data[module.padding_idx].zero_()
251 |
252 | class LigerGSAModel(LlamaModel, LigerGSAPreTrainedModel):
253 |
254 | def __init__(self, config: LigerGSAConfig):
255 | super().__init__(config)
256 | self.padding_idx = config.pad_token_id
257 | self.vocab_size = config.vocab_size
258 |
259 | self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
260 | self.layers = nn.ModuleList(
261 | [LigerGSADecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
262 | )
263 | self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
264 | self.rotary_emb = LlamaRotaryEmbedding(config=config)
265 | self.gradient_checkpointing = False
266 |
267 | # Initialize weights and apply final processing
268 | self.post_init()
269 |
270 | def get_input_embeddings(self):
271 | return self.embed_tokens
272 |
273 | def set_input_embeddings(self, value):
274 | self.embed_tokens = value
275 |
276 | @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
277 | def forward(
278 | self,
279 | input_ids: torch.LongTensor = None,
280 | attention_mask: Optional[torch.Tensor] = None,
281 | position_ids: Optional[torch.LongTensor] = None,
282 | past_key_values: Optional[Union[Tuple, FlaCache, List[torch.FloatTensor]]] = None,
283 | inputs_embeds: Optional[torch.FloatTensor] = None,
284 | use_cache: Optional[bool] = None,
285 | output_attentions: Optional[bool] = None,
286 | output_hidden_states: Optional[bool] = None,
287 | return_dict: Optional[bool] = None,
288 | cache_position: Optional[torch.LongTensor] = None,
289 | ) -> Union[Tuple, BaseModelOutputWithPast]:
290 | output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
291 | output_hidden_states = (
292 | output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
293 | )
294 | use_cache = use_cache if use_cache is not None else self.config.use_cache
295 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict
296 |
297 | if (input_ids is None) ^ (inputs_embeds is not None):
298 | raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
299 |
300 | if self.gradient_checkpointing and self.training and use_cache:
301 | logger.warning_once(
302 | "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
303 | )
304 | use_cache = False
305 |
306 | if inputs_embeds is None:
307 | inputs_embeds = self.embed_tokens(input_ids)
308 |
309 | # kept for BC (non `Cache` `past_key_values` inputs)
310 | return_legacy_cache = False
311 | if use_cache:
312 | if output_attentions:
313 | # LigerGSA kv
314 | LigerGSA_past_key_values, softmax_past_key_values = None, None
315 | if past_key_values is None:
316 | LigerGSA_past_key_values = FlaCache.from_legacy_cache(past_key_values)
317 | softmax_past_key_values = DynamicCache()
318 |                 else:
319 |                     # reuse an existing FlaCache, otherwise convert the legacy LigerGSA kv
320 |                     LigerGSA_past_key_values = past_key_values[0] if isinstance(past_key_values[0], FlaCache) else FlaCache.from_legacy_cache(past_key_values[0])
321 |                     # softmax kv: reuse an existing `Cache`, otherwise convert the legacy format
322 |                     if isinstance(past_key_values[1], Cache):
323 |                         softmax_past_key_values = past_key_values[1]
324 |                     else:
325 |                         return_legacy_cache = True
326 |                         softmax_past_key_values = (DynamicCache() if past_key_values[1] is None
327 |                                                    else DynamicCache.from_legacy_cache(past_key_values[1]))
328 | logger.warning_once(
329 | "We detected that you are passing `past_key_values` as a tuple of tuples. This is deprecated and "
330 | "will be removed in v4.47. Please convert your cache or use an appropriate `Cache` class "
331 | "(https://huggingface.co/docs/transformers/kv_cache#legacy-cache-format)"
332 | )
333 |
334 | past_key_values = (LigerGSA_past_key_values, softmax_past_key_values)
335 | else:
336 | # only LigerGSA kv
337 | if not isinstance(past_key_values, FlaCache):
338 | past_key_values = FlaCache.from_legacy_cache(past_key_values)
339 |
340 | if cache_position is None:
341 | if output_attentions:
342 | past_seen_tokens = past_key_values[1].get_seq_length() if past_key_values is not None else 0
343 | else:
344 | past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
345 | cache_position = torch.arange(
346 | past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
347 | )
348 | if position_ids is None:
349 | position_ids = cache_position.unsqueeze(0)
350 |
351 | if output_attentions:
352 | causal_mask = self._update_causal_mask(
353 | attention_mask, inputs_embeds, cache_position, past_key_values[1], output_attentions
354 | )
355 | causal_mask = (attention_mask, causal_mask)
356 | else:
357 | causal_mask = attention_mask
358 | hidden_states = inputs_embeds
359 |
360 | # create position embeddings to be shared across the decoder layers
361 | position_embeddings = self.rotary_emb(hidden_states, position_ids)
362 |
363 | # decoder layers
364 | all_hidden_states = () if output_hidden_states else None
365 | all_self_attns = () if output_attentions else None
366 | next_decoder_cache = None
367 |
368 |         # defined even when attentions are not requested, so the check inside the loop below never hits an undefined name
369 |         all_softmax_hidden_states = () if output_attentions else None
370 |
371 | for decoder_layer in self.layers:
372 | if output_hidden_states:
373 | all_hidden_states += (hidden_states,)
374 | if all_softmax_hidden_states is not None:
375 | all_softmax_hidden_states += (hidden_states,)
376 |
377 | if self.gradient_checkpointing and self.training:
378 | layer_outputs = self._gradient_checkpointing_func(
379 | decoder_layer.__call__,
380 | hidden_states,
381 | causal_mask,
382 | position_ids,
383 | past_key_values,
384 | output_attentions,
385 | use_cache,
386 | cache_position,
387 | position_embeddings,
388 | )
389 |
390 | else:
391 | layer_outputs = decoder_layer(
392 | hidden_states,
393 | attention_mask=causal_mask,
394 | position_ids=position_ids,
395 | past_key_value=past_key_values,
396 | output_attentions=output_attentions,
397 | use_cache=use_cache,
398 | cache_position=cache_position,
399 | position_embeddings=position_embeddings,
400 | )
401 | hidden_states = layer_outputs[0]
402 |
403 | if use_cache:
404 | next_decoder_cache = layer_outputs[2 if output_attentions else 1]
405 |
406 | if output_attentions:
407 | all_self_attns += (layer_outputs[1],)
408 |
409 | hidden_states = self.norm(hidden_states)
410 |
411 | # add hidden states from the last decoder layer
412 | if output_hidden_states:
413 | all_hidden_states += (hidden_states,)
414 |
415 | if output_attentions:
416 | next_cache = next_decoder_cache[1] if use_cache else None
417 | if return_legacy_cache:
418 | next_cache = next_cache.to_legacy_cache()
419 |
420 |             next_cache = (next_decoder_cache[0], next_cache) if use_cache else None
421 | else:
422 | next_cache = next_decoder_cache if use_cache else None
423 |
424 | if not return_dict:
425 | return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
426 | return BaseModelOutputWithPast(
427 | last_hidden_state=hidden_states,
428 | past_key_values=next_cache,
429 | hidden_states=all_hidden_states,
430 | attentions=all_self_attns,
431 | )
432 |
433 | class LigerGSAForCausalLM(LlamaForCausalLM, LigerGSAPreTrainedModel):
434 | _tied_weights_keys = ["lm_head.weight"]
435 |
436 | def __init__(self, config):
437 | super().__init__(config)
438 | self.model = LigerGSAModel(config)
439 | self.vocab_size = config.vocab_size
440 | self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
441 |
442 | # Initialize weights and apply final processing
443 | self.post_init()
--------------------------------------------------------------------------------
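
For orientation, a minimal sketch of the gate construction used in `LigerGatedSlotAttention.forward` above (PyTorch only; the tensor sizes are illustrative placeholders, not values from any config). The pooled gate logits are squashed with `logsigmoid`, divided by the hard-coded `gate_logit_normalizer = 16`, and the resulting slot write strength `s = 1 - exp(g)` always lands in [0, 1):

    import torch
    import torch.nn.functional as F

    b, h, n, m = 2, 4, 8, 16                               # batch, heads, sequence length, slots (illustrative)
    gate_logits = torch.randn(b, h, n, m)                  # stands in for pool_g(k) after the head rearrange

    gate_logit_normalizer = 16
    g = F.logsigmoid(gate_logits) / gate_logit_normalizer  # log-space decay, always <= 0
    s = 1 - torch.exp(g)                                   # slot write strength, always in [0, 1)

    assert (g <= 0).all() and (s >= 0).all() and (s < 1).all()

Dividing the log-gate by 16 keeps `exp(g)` close to 1, i.e. a slow per-step decay, before `chunk_gsa` / `fused_recurrent_gsa` consume `s` and `g`.
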
/liger/models/liger_hgrn2/__init__.py:
--------------------------------------------------------------------------------
1 | from transformers import AutoConfig, AutoModel, AutoModelForCausalLM
2 |
3 | from liger.models.liger_hgrn2.configuration_liger_hgrn2 import LigerHGRN2Config
4 | from liger.models.liger_hgrn2.modeling_liger_hgrn2 import LigerHGRN2ForCausalLM, LigerHGRN2Model
5 |
6 | AutoConfig.register(LigerHGRN2Config.model_type, LigerHGRN2Config)
7 | AutoModel.register(LigerHGRN2Config, LigerHGRN2Model)
8 | AutoModelForCausalLM.register(LigerHGRN2Config, LigerHGRN2ForCausalLM)
9 |
10 |
11 | __all__ = ['LigerHGRN2Config', 'LigerHGRN2ForCausalLM', 'LigerHGRN2Model']
--------------------------------------------------------------------------------
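
A brief usage sketch of what these registrations enable (assuming the repo's requirements, notably `transformers` and `fla`, are installed; the tiny hyperparameters are placeholders, not values from any released checkpoint). Importing the package runs the `Auto*` registrations, after which the generic factories resolve the custom config class directly:

    from transformers import AutoModelForCausalLM

    from liger.models.liger_hgrn2 import LigerHGRN2Config  # importing the package runs the registrations above

    config = LigerHGRN2Config(
        vocab_size=1000, hidden_size=256, intermediate_size=512,
        num_hidden_layers=2, num_attention_heads=4, num_key_value_heads=4,
    )
    model = AutoModelForCausalLM.from_config(config)        # dispatches to LigerHGRN2ForCausalLM
    print(type(model).__name__)

`AutoConfig.register` keys the config by its `model_type` string, so `AutoConfig.from_pretrained` on a saved config directory whose config.json carries `"model_type": "liger_hgrn2"` resolves the same way.
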
/liger/models/liger_hgrn2/configuration_liger_hgrn2.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | from typing import Dict, Optional
4 |
5 | from transformers.configuration_utils import PretrainedConfig
6 | from transformers.models.llama.configuration_llama import LlamaConfig
7 |
8 | class LigerHGRN2Config(LlamaConfig, PretrainedConfig):
9 | model_type = 'liger_hgrn2'
10 | keys_to_ignore_at_inference = ['past_key_values']
11 |
12 | def __init__(
13 | self,
14 | # llama config
15 | vocab_size=32000,
16 | hidden_size=4096,
17 | intermediate_size=11008,
18 | num_hidden_layers=32,
19 | num_attention_heads=32,
20 | num_key_value_heads=None,
21 | hidden_act="silu",
22 | max_position_embeddings=2048,
23 | initializer_range=0.02,
24 | rms_norm_eps=1e-6,
25 | use_cache=True,
26 | pad_token_id=None,
27 | bos_token_id=1,
28 | eos_token_id=2,
29 | pretraining_tp=1,
30 | tie_word_embeddings=False,
31 | rope_theta=10000.0,
32 | rope_scaling=None,
33 | attention_bias=False,
34 | attention_dropout=0.0,
35 | mlp_bias=False,
36 | head_dim=None,
37 | pool_size: int = 128, # pooling
38 | **kwargs,
39 | ):
40 | self.pool_size = pool_size
41 |
42 | super().__init__(
43 | vocab_size=vocab_size,
44 | hidden_size=hidden_size,
45 | intermediate_size=intermediate_size,
46 | num_hidden_layers=num_hidden_layers,
47 | num_attention_heads=num_attention_heads,
48 | num_key_value_heads=num_key_value_heads,
49 | hidden_act=hidden_act,
50 | max_position_embeddings=max_position_embeddings,
51 | initializer_range=initializer_range,
52 | rms_norm_eps=rms_norm_eps,
53 | use_cache=use_cache,
54 | pad_token_id=pad_token_id,
55 | bos_token_id=bos_token_id,
56 | eos_token_id=eos_token_id,
57 | pretraining_tp=pretraining_tp,
58 | tie_word_embeddings=tie_word_embeddings,
59 | rope_theta=rope_theta,
60 | rope_scaling=rope_scaling,
61 | attention_bias=attention_bias,
62 | attention_dropout=attention_dropout,
63 | mlp_bias=mlp_bias,
64 | head_dim=head_dim,
65 | **kwargs,
66 | )
--------------------------------------------------------------------------------
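
As a small illustration (hypothetical values), the only field this config adds on top of the inherited Llama parameters is `pool_size`, and it round-trips through the standard serialization:

    from liger.models.liger_hgrn2.configuration_liger_hgrn2 import LigerHGRN2Config

    config = LigerHGRN2Config(hidden_size=256, num_attention_heads=4, pool_size=64)
    print(config.model_type)                 # 'liger_hgrn2'
    print(config.pool_size)                  # 64
    print('pool_size' in config.to_dict())   # True: saved alongside the Llama fields by save_pretrained()
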
/liger/models/liger_hgrn2/modeling_liger_hgrn2.py:
--------------------------------------------------------------------------------
1 | import math
2 | import warnings
3 | from typing import List, Optional, Tuple, Union
4 | from einops import rearrange, repeat
5 |
6 | import torch
7 | import torch.nn as nn
8 | import torch.nn.functional as F
9 | import torch.utils.checkpoint
10 | from transformers.activations import ACT2FN
11 | from transformers.cache_utils import Cache, DynamicCache, StaticCache
12 | from transformers.modeling_outputs import (
13 | BaseModelOutputWithPast,
14 | CausalLMOutputWithPast,
15 | )
16 | from transformers.models.llama.modeling_llama import (
17 | LlamaRMSNorm,
18 | LlamaRotaryEmbedding,
19 | apply_rotary_pos_emb,
20 | repeat_kv,
21 | LlamaMLP,
22 | LlamaAttention, # LlamaFlashAttention2, LlamaSdpaAttention,
23 | LlamaDecoderLayer,
24 | LlamaForCausalLM,
25 | LlamaModel,
26 | LlamaPreTrainedModel,
27 | LLAMA_INPUTS_DOCSTRING,
28 | )
29 |
30 |
31 | from transformers.utils import logging, add_start_docstrings_to_model_forward
32 | from transformers.utils import is_flash_attn_2_available
33 |
34 | if is_flash_attn_2_available():
35 | from transformers.modeling_flash_attention_utils import _flash_attention_forward
36 |
37 | from fla.modules.activations import swish
38 | from fla.models.utils import Cache as FlaCache
39 | from fla.ops.gla import chunk_gla, fused_chunk_gla, fused_recurrent_gla
40 |
41 | from liger.models.liger_hgrn2 import LigerHGRN2Config
42 |
43 | logger = logging.get_logger(__name__)
44 |
45 |
46 | class LigerHGRN2Attention(nn.Module):
47 | def __init__(
48 | self,
49 | config: LigerHGRN2Config,
50 | layer_idx: Optional[int] = None,
51 | ):
52 | super().__init__()
53 | self.config = config
54 | self.layer_idx = layer_idx
55 |
56 | self.attention_dropout = config.attention_dropout
57 | self.hidden_size = config.hidden_size
58 | self.num_heads = config.num_attention_heads
59 | self.head_dim = getattr(config, "head_dim", self.hidden_size // self.num_heads)
60 | self.num_key_value_heads = config.num_key_value_heads
61 | self.num_key_value_groups = self.num_heads // self.num_key_value_heads
62 |
63 | self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
64 | self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
65 | self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
66 | self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
67 |
68 | self.rotary_emb = LlamaRotaryEmbedding(config=self.config)
69 |
70 | self.pool_g = nn.AdaptiveAvgPool1d(output_size=self.head_dim * self.num_key_value_heads)
71 |
72 | def forward(
73 | self,
74 | hidden_states: torch.Tensor,
75 | attention_mask: Optional[torch.Tensor] = None,
76 | position_ids: Optional[torch.LongTensor] = None,
77 | past_key_value: Optional[FlaCache] = None,
78 | output_attentions: bool = False,
79 | use_cache: bool = False,
80 | cache_position: Optional[torch.LongTensor] = None,
81 | position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
82 | **kwargs,
83 | ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
84 | last_state = None
85 | if past_key_value is not None and len(past_key_value) > self.layer_idx:
86 | last_state = past_key_value[self.layer_idx]
87 |
88 | q = self.q_proj(hidden_states)
89 | k = self.k_proj(hidden_states)
90 | v = self.v_proj(hidden_states)
91 | f = self.pool_g(k)
92 |
93 | q = rearrange(q, 'b n (h d) -> b h n d', h=self.num_heads)
94 | k = rearrange(k, 'b n (h d) -> b h n d', h=self.num_key_value_heads)
95 | v = rearrange(v, 'b n (h d) -> b h n d', h=self.num_key_value_heads)
96 | f = rearrange(f, 'b n (h m) -> b h n m', h=self.num_key_value_heads)
97 |
98 | k = repeat_kv(k, self.num_key_value_groups)
99 | v = repeat_kv(v, self.num_key_value_groups)
100 | f = repeat_kv(f, self.num_key_value_groups)
101 |
102 | sq, sk, sv = q, k, v
103 |
104 |         # the lower bound for the first layer is zero; per-layer lower bounds are not propagated here,
105 |         lower_bound = None  # so the `lower_bound is None` branch below is always taken
106 | if lower_bound is None or self.layer_idx == 0:
107 | k, g = 1 - f.sigmoid(), F.logsigmoid(f)
108 | else:
109 | g = lower_bound + (1 - lower_bound) * f.sigmoid()
110 | k, g = 1 - g, g.log()
111 |
112 | # norm
113 | q = F.softmax(q, dim=-1)
114 | # k = F.softmax(k, dim=-1)
115 |
116 | # dealing with left-padding
117 | if attention_mask is not None:
118 |             v = v.mul_(attention_mask[:, None, -v.shape[-2]:, None])  # mask is (b, n), v is (b, h, n, d)
119 |
120 | recurrent_state = last_state['recurrent_state'] if last_state is not None else None
121 | offsets = kwargs.get('offsets', None)
122 | scale = 1 # default
123 | q, k, v, g = (x.to(torch.float32).contiguous() for x in (q, k, v, g))
124 |
125 | if self.training or q.shape[-2] > 1:
126 | o_, recurrent_state = fused_chunk_gla(q, k, v, g, scale=scale, initial_state=recurrent_state, output_final_state=True)
127 | else:
128 | o_, recurrent_state = fused_recurrent_gla(q, k, v, g, scale=scale, initial_state=recurrent_state, output_final_state=True, offsets=offsets)
129 |
130 | if past_key_value is not None:
131 | past_key_value.update(
132 | recurrent_state=recurrent_state,
133 | layer_idx=self.layer_idx,
134 | offset=q.shape[1]
135 | )
136 |
137 | q_len = hidden_states.size(-2)
138 |
139 | if position_embeddings is None:
140 | logger.warning_once(
141 | "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
142 | "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
143 | "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be "
144 | "removed and `position_embeddings` will be mandatory."
145 | )
146 | cos, sin = self.rotary_emb(sv, position_ids)
147 | else:
148 | cos, sin = position_embeddings
149 | sq, sk = apply_rotary_pos_emb(sq, sk, cos, sin)
150 |
151 | # if past_key_value is not None:
152 | # # sin and cos are specific to RoPE models; cache_position needed for the static cache
153 | # cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
154 | # key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
155 |
156 | input_dtype = sq.dtype
157 | if input_dtype == torch.float32:
158 | if torch.is_autocast_enabled():
159 | target_dtype = torch.get_autocast_gpu_dtype()
160 | # Handle the case where the model is quantized
161 | elif hasattr(self.config, "_pre_quantization_dtype"):
162 | target_dtype = self.config._pre_quantization_dtype
163 | else:
164 | target_dtype = self.q_proj.weight.dtype
165 |
166 | logger.warning_once(
167 |                 f"The input hidden states seem to have been silently cast to float32; this is likely because"
168 |                 f" you have upcast embedding or layer norm layers to float32. We will cast the input back to"
169 |                 f" {target_dtype}."
170 | )
171 |
172 | sq = sq.to(target_dtype)
173 | sk = sk.to(target_dtype)
174 | sv = sv.to(target_dtype)
175 |
176 | window_size = 64
177 |
178 |         y = _flash_attention_forward(  # reshape to the layout expected by Flash Attention
179 | sq.transpose(1, 2),
180 | sk.transpose(1, 2),
181 | sv.transpose(1, 2),
182 | attention_mask,
183 | q_len,
184 | position_ids=position_ids,
185 | dropout=0.0,
186 | sliding_window=window_size,
187 | use_top_left_mask=False,
188 | is_causal=True,
189 | target_dtype=torch.float32,
190 | ).transpose(1, 2)
191 |
192 | o_ = 0.5 * y + 0.5 * o_
193 | o = rearrange(o_.bfloat16(), 'b h n d -> b n (h d)')
194 | o = self.o_proj(o)
195 |
196 | return o, o_, past_key_value
197 |
198 | class LigerHGRN2DecoderLayer(LlamaDecoderLayer):
199 | def __init__(self, config: LigerHGRN2Config, layer_idx: int):
200 | super().__init__(config, layer_idx)
201 | self.hidden_size = config.hidden_size
202 | self.self_attn = LigerHGRN2Attention(config=config, layer_idx=layer_idx)
203 | self.mlp = LlamaMLP(config)
204 | self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
205 | self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
206 |
207 |
208 | class LigerHGRN2PreTrainedModel(LlamaPreTrainedModel):
209 |
210 | config_class = LigerHGRN2Config
211 | base_model_prefix = "model"
212 | supports_gradient_checkpointing = True
213 | _no_split_modules = ['LigerHGRN2DecoderLayer']
214 | _skip_keys_device_placement = "past_key_values"
215 |
216 | class LigerHGRN2Model(LlamaModel, LigerHGRN2PreTrainedModel):
217 |
218 | def __init__(self, config: LigerHGRN2Config):
219 | super().__init__(config)
220 | self.padding_idx = config.pad_token_id
221 | self.vocab_size = config.vocab_size
222 |
223 | self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
224 | self.layers = nn.ModuleList(
225 | [LigerHGRN2DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
226 | )
227 | self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
228 | self.rotary_emb = LlamaRotaryEmbedding(config=config)
229 | self.gradient_checkpointing = False
230 |
231 | # Initialize weights and apply final processing
232 | self.post_init()
233 |
234 | @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
235 | def forward(
236 | self,
237 | input_ids: torch.LongTensor = None,
238 | attention_mask: Optional[torch.Tensor] = None,
239 | position_ids: Optional[torch.LongTensor] = None,
240 | past_key_values: Optional[Union[Tuple, FlaCache, List[torch.FloatTensor]]] = None,
241 | inputs_embeds: Optional[torch.FloatTensor] = None,
242 | use_cache: Optional[bool] = None,
243 | output_attentions: Optional[bool] = None,
244 | output_hidden_states: Optional[bool] = None,
245 | return_dict: Optional[bool] = None,
246 | cache_position: Optional[torch.LongTensor] = None,
247 | **kwargs,
248 | ) -> Union[Tuple, BaseModelOutputWithPast]:
249 | output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
250 | output_hidden_states = (
251 | output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
252 | )
253 | use_cache = use_cache if use_cache is not None else self.config.use_cache
254 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict
255 |
256 | if (input_ids is None) ^ (inputs_embeds is not None):
257 | raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
258 |
259 | if self.gradient_checkpointing and self.training and use_cache:
260 | logger.warning_once(
261 | "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
262 | )
263 | use_cache = False
264 |
265 | if inputs_embeds is None:
266 | inputs_embeds = self.embed_tokens(input_ids)
267 |
268 |
269 | # kept for BC (non `Cache` `past_key_values` inputs)
270 | return_legacy_cache = False
271 | if use_cache and not isinstance(past_key_values, FlaCache):
272 | past_key_values = FlaCache.from_legacy_cache(past_key_values)
273 |
274 | if cache_position is None:
275 | past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
276 | cache_position = torch.arange(
277 | past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
278 | )
279 |
280 | if position_ids is None:
281 | position_ids = cache_position.unsqueeze(0)
282 |
283 | causal_mask = attention_mask
284 |
285 | hidden_states = inputs_embeds
286 |
287 | # create position embeddings to be shared across the decoder layers
288 | position_embeddings = self.rotary_emb(hidden_states, position_ids)
289 |
290 | # decoder layers
291 | all_hidden_states = () if output_hidden_states else None
292 | all_self_attns = () if output_attentions else None
293 | next_decoder_cache = None
294 |
295 |         # defined even when attentions are not requested, so the check inside the loop below never hits an undefined name
296 |         all_softmax_hidden_states = () if output_attentions else None
297 |
298 | for decoder_layer in self.layers:
299 | if output_hidden_states:
300 | all_hidden_states += (hidden_states,)
301 | if all_softmax_hidden_states is not None:
302 | all_softmax_hidden_states += (hidden_states,)
303 |
304 | if self.gradient_checkpointing and self.training:
305 | layer_outputs = self._gradient_checkpointing_func(
306 | decoder_layer.__call__,
307 | hidden_states,
308 | causal_mask,
309 | position_ids,
310 | past_key_values,
311 | output_attentions,
312 | use_cache,
313 | cache_position,
314 | position_embeddings,
315 | )
316 |
317 | else:
318 | layer_outputs = decoder_layer(
319 | hidden_states,
320 | attention_mask=causal_mask,
321 | position_ids=position_ids,
322 | past_key_value=past_key_values,
323 | output_attentions=output_attentions,
324 | use_cache=use_cache,
325 | cache_position=cache_position,
326 | position_embeddings=position_embeddings,
327 | )
328 | hidden_states = layer_outputs[0]
329 |
330 | if use_cache:
331 | next_decoder_cache = layer_outputs[2 if output_attentions else 1]
332 |
333 | if output_attentions:
334 | all_self_attns += (layer_outputs[1],)
335 |
336 | hidden_states = self.norm(hidden_states)
337 |
338 | # add hidden states from the last decoder layer
339 | if output_hidden_states:
340 | all_hidden_states += (hidden_states,)
341 |
342 | next_cache = next_decoder_cache if use_cache else None
343 |
344 | if not return_dict:
345 | return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
346 | return BaseModelOutputWithPast(
347 | last_hidden_state=hidden_states,
348 | past_key_values=next_cache,
349 | hidden_states=all_hidden_states,
350 | attentions=all_self_attns,
351 | )
352 |
353 | class LigerHGRN2ForCausalLM(LlamaForCausalLM, LigerHGRN2PreTrainedModel):
354 | _tied_weights_keys = ["lm_head.weight"]
355 |
356 | def __init__(self, config):
357 | super().__init__(config)
358 | self.model = LigerHGRN2Model(config)
359 | self.vocab_size = config.vocab_size
360 | self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
361 |
362 | # Initialize weights and apply final processing
363 | self.post_init()
--------------------------------------------------------------------------------
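
A small numerical check (PyTorch only; shapes are illustrative) of the first-layer gate construction in `LigerHGRN2Attention.forward` above: with `k = 1 - sigmoid(f)` and `g = logsigmoid(f)`, the rewritten key and the forget gate are complementary, i.e. `k = 1 - exp(g)`:

    import torch
    import torch.nn.functional as F

    f = torch.randn(2, 4, 8, 32)         # pooled gate logits, (batch, heads, seq, pooled dim), illustrative

    k = 1 - f.sigmoid()                   # data-dependent write strength used as the key
    g = F.logsigmoid(f)                   # log-space forget gate handed to the GLA kernels

    assert torch.allclose(k, 1 - torch.exp(g), atol=1e-6)  # exp(g) = sigmoid(f), so k and exp(g) sum to 1
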
/liger/models/liger_mistral_gla/__init__.py:
--------------------------------------------------------------------------------
1 | from transformers import AutoConfig, AutoModel, AutoModelForCausalLM
2 |
3 | from liger.models.liger_mistral_gla.configuration_liger_mistral_gla import LigerMistralGLAConfig
4 | from liger.models.liger_mistral_gla.modeling_liger_mistral_gla import LigerMistralGLAForCausalLM, LigerMistralGLAModel
5 |
6 | AutoConfig.register(LigerMistralGLAConfig.model_type, LigerMistralGLAConfig)
7 | AutoModel.register(LigerMistralGLAConfig, LigerMistralGLAModel)
8 | AutoModelForCausalLM.register(LigerMistralGLAConfig, LigerMistralGLAForCausalLM)
9 |
10 |
11 | __all__ = ['LigerMistralGLAConfig', 'LigerMistralGLAForCausalLM', 'LigerMistralGLAModel']
--------------------------------------------------------------------------------
/liger/models/liger_mistral_gla/configuration_liger_mistral_gla.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | from typing import Dict, Optional
4 |
5 | from transformers.configuration_utils import PretrainedConfig
6 | from transformers.utils import logging
7 | from transformers.models.mistral.configuration_mistral import MistralConfig
8 |
9 |
10 | logger = logging.get_logger(__name__)
11 |
12 | class LigerMistralGLAConfig(MistralConfig, PretrainedConfig):
13 | model_type = "liger_mistral_gla"
14 | keys_to_ignore_at_inference = ["past_key_values"]
15 |
16 | base_model_tp_plan = {
17 | "layers.*.self_attn.q_proj": "colwise",
18 | "layers.*.self_attn.k_proj": "colwise",
19 | "layers.*.self_attn.v_proj": "colwise",
20 | "layers.*.self_attn.o_proj": "rowwise",
21 | "layers.*.mlp.gate_proj": "colwise",
22 | "layers.*.mlp.up_proj": "colwise",
23 | "layers.*.mlp.down_proj": "rowwise",
24 | }
25 | def __init__(
26 | self,
27 | vocab_size=32000,
28 | hidden_size=4096,
29 | intermediate_size=14336,
30 | num_hidden_layers=32,
31 | num_attention_heads=32,
32 | num_key_value_heads=8,
33 | head_dim=None,
34 | hidden_act="silu",
35 | max_position_embeddings=4096 * 32,
36 | initializer_range=0.02,
37 | rms_norm_eps=1e-6,
38 | use_cache=True,
39 | pad_token_id=None,
40 | bos_token_id=1,
41 | eos_token_id=2,
42 | tie_word_embeddings=False,
43 | rope_theta=10000.0,
44 | sliding_window=4096,
45 | attention_dropout=0.0,
46 | # linear attention
47 | expand_k: int = 1,
48 | expand_v: int = 1,
49 | hidden_ratio: Optional[int] = 4,
50 | **kwargs,
51 | ):
52 | self.expand_k = expand_k
53 | self.expand_v = expand_v
54 | self.hidden_ratio = hidden_ratio
55 |
56 | super().__init__(
57 | vocab_size=vocab_size,
58 | hidden_size=hidden_size,
59 | intermediate_size=intermediate_size,
60 | num_hidden_layers=num_hidden_layers,
61 | num_attention_heads=num_attention_heads,
62 | num_key_value_heads=num_key_value_heads,
63 | head_dim=head_dim,
64 | hidden_act=hidden_act,
65 | max_position_embeddings=max_position_embeddings,
66 | initializer_range=initializer_range,
67 | rms_norm_eps=rms_norm_eps,
68 | use_cache=use_cache,
69 | pad_token_id=pad_token_id,
70 | bos_token_id=bos_token_id,
71 | eos_token_id=eos_token_id,
72 | tie_word_embeddings=tie_word_embeddings,
73 | rope_theta=rope_theta,
74 | sliding_window=sliding_window,
75 | attention_dropout=attention_dropout,
76 | **kwargs,
77 | )
78 |
--------------------------------------------------------------------------------
/liger/models/liger_mistral_gla/modeling_liger_mistral_gla.py:
--------------------------------------------------------------------------------
1 | import math
2 | import warnings
3 | import copy
4 | from typing import List, Optional, Tuple, Union
5 | from einops import rearrange, repeat
6 |
7 | import torch
8 | import torch.nn as nn
9 | import torch.nn.functional as F
10 | import torch.utils.checkpoint
11 | from transformers.cache_utils import Cache, DynamicCache, StaticCache
12 | from transformers.modeling_outputs import (
13 | BaseModelOutputWithPast,
14 | CausalLMOutputWithPast,
15 | )
16 |
17 | from transformers.models.mistral.modeling_mistral import (
18 | MistralRMSNorm,
19 | MistralMLP,
20 | repeat_kv,
21 | apply_rotary_pos_emb,
22 | MistralRotaryEmbedding,
23 | MistralDecoderLayer,
24 | MistralForCausalLM,
25 | MistralModel,
26 | MistralPreTrainedModel,
27 | )
28 | from transformers.utils import logging, add_start_docstrings_to_model_forward
29 | from transformers.utils import is_flash_attn_2_available
30 |
31 | if is_flash_attn_2_available():
32 | from transformers.modeling_flash_attention_utils import _flash_attention_forward
33 |
34 | from fla.modules.activations import swish
35 | from fla.models.utils import Cache as FlaCache
36 | from fla.ops.gla import fused_chunk_gla, fused_recurrent_gla
37 |
38 | from liger.models.liger_mistral_gla import LigerMistralGLAConfig
39 |
40 | logger = logging.get_logger(__name__)
41 |
42 |
43 | class LigerMistralGatedLinearAttention(nn.Module):
44 | def __init__(
45 | self,
46 | config: LigerMistralGLAConfig,
47 | layer_idx: Optional[int] = None,
48 | ):
49 | super().__init__()
50 | self.config = config
51 | self.layer_idx = layer_idx
52 | if layer_idx is None:
53 | logger.warning_once(
54 | f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will "
55 | "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` "
56 | "when creating this class."
57 | )
58 |
59 | self.attention_dropout = config.attention_dropout
60 | self.hidden_size = config.hidden_size
61 | self.num_heads = config.num_attention_heads
62 | self.head_dim = config.head_dim
63 | self.num_key_value_heads = config.num_key_value_heads
64 | self.num_key_value_groups = self.num_heads // self.num_key_value_heads
65 | self.max_position_embeddings = config.max_position_embeddings
66 | self.rope_theta = config.rope_theta
67 | self.is_causal = True
68 |
69 | self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
70 | self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
71 | self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
72 | self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
73 |
74 | self.rotary_emb = MistralRotaryEmbedding(
75 | self.head_dim,
76 | max_position_embeddings=self.max_position_embeddings,
77 | base=self.rope_theta,
78 | )
79 |
80 | self.pool_g = nn.AdaptiveAvgPool1d(output_size=self.head_dim * self.num_key_value_heads)
81 |
82 | def forward(
83 | self,
84 | hidden_states: torch.Tensor,
85 | attention_mask: Optional[torch.Tensor] = None,
86 | position_ids: Optional[torch.LongTensor] = None,
87 | past_key_value: Optional[FlaCache] = None,
88 | output_attentions: bool = False,
89 | use_cache: bool = False,
90 | cache_position: Optional[torch.LongTensor] = None,
91 | position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
92 | **kwargs,
93 | ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
94 | last_state = None
95 | if past_key_value is not None and len(past_key_value) > self.layer_idx:
96 | last_state = past_key_value[self.layer_idx]
97 |
98 | bsz, q_len, _ = hidden_states.size()
99 |
100 | q = self.q_proj(hidden_states)
101 | k = self.k_proj(hidden_states)
102 | v = self.v_proj(hidden_states)
103 | g = self.pool_g(k)
104 |
105 | q = q.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
106 | k = k.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
107 | v = v.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
108 | g = g.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
109 |
110 | sv = v
111 | cos, sin = self.rotary_emb(sv, position_ids)
112 | sq, sk = apply_rotary_pos_emb(q, k, cos, sin)
113 |
114 | k = repeat_kv(k, self.num_key_value_groups)
115 | v = repeat_kv(v, self.num_key_value_groups)
116 | g = repeat_kv(g, self.num_key_value_groups)
117 |
118 | sk = repeat_kv(sk, self.num_key_value_groups)
119 | sv = repeat_kv(sv, self.num_key_value_groups)
120 |
121 | # norm
122 | q = F.softmax(q, dim=-1)
123 | k = F.softmax(k, dim=-1)
124 |
125 | # dealing with left-padding
126 | if attention_mask is not None:
127 |             v = v.mul_(attention_mask[:, None, -v.shape[-2]:, None])  # mask is (b, n), v is (b, h, n, d)
128 |
129 | gate_logit_normalizer = 16
130 | g = F.logsigmoid(g) / gate_logit_normalizer # (b, h, n, m)
131 |
132 | recurrent_state = last_state['recurrent_state'] if last_state is not None else None
133 | offsets = kwargs.get('offsets', None)
134 | scale = 1
135 | q, k, v, g = (x.to(torch.float32).contiguous() for x in (q, k, v, g))
136 |
137 | if self.training or q.shape[-2] > 1:
138 | o_, recurrent_state = fused_chunk_gla(q, k, v, g, scale=scale, initial_state=recurrent_state, output_final_state=True)
139 | else:
140 | o_, recurrent_state = fused_recurrent_gla(q, k, v, g, scale=scale, initial_state=recurrent_state, output_final_state=True, offsets=offsets)
141 |
142 | if past_key_value is not None:
143 | past_key_value.update(
144 | recurrent_state=recurrent_state,
145 | layer_idx=self.layer_idx,
146 | offset=q.shape[1]
147 | )
148 |
149 | # if past_key_value is not None:
150 | # # sin and cos are specific to RoPE models; cache_position needed for the static cache
151 | # cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
152 | # key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
153 |
154 | input_dtype = sq.dtype
155 | if input_dtype == torch.float32:
156 | if torch.is_autocast_enabled():
157 | target_dtype = torch.get_autocast_gpu_dtype()
158 | # Handle the case where the model is quantized
159 | elif hasattr(self.config, "_pre_quantization_dtype"):
160 | target_dtype = self.config._pre_quantization_dtype
161 | else:
162 | target_dtype = self.q_proj.weight.dtype
163 |
164 | logger.warning_once(
165 |                 f"The input hidden states seem to have been silently cast to float32; this is likely because"
166 |                 f" you have upcast embedding or layer norm layers to float32. We will cast the input back to"
167 |                 f" {target_dtype}."
168 | )
169 |
170 | sq = sq.to(target_dtype)
171 | sk = sk.to(target_dtype)
172 | sv = sv.to(target_dtype)
173 |
174 | window_size = 64
175 |
176 |         y = _flash_attention_forward(  # reshape to the layout expected by Flash Attention
177 | sq.transpose(1, 2),
178 | sk.transpose(1, 2),
179 | sv.transpose(1, 2),
180 | attention_mask,
181 | q_len,
182 | position_ids=position_ids,
183 | dropout=0.0,
184 | sliding_window=window_size,
185 | use_top_left_mask=False,
186 | is_causal=True,
187 | target_dtype=torch.float32,
188 | ).transpose(1, 2)
189 |
190 | o_ = 0.5 * y + 0.5 * o_
191 | o = rearrange(o_.bfloat16(), 'b h n d -> b n (h d)')
192 | o = self.o_proj(o)
193 |
194 | return o, o_, past_key_value
195 |
196 |
197 | class LigerMistralGLADecoderLayer(MistralDecoderLayer):
198 | def __init__(self, config: LigerMistralGLAConfig, layer_idx: int):
199 | super().__init__(config, layer_idx)
200 | self.hidden_size = config.hidden_size
201 | self.self_attn = LigerMistralGatedLinearAttention(config=config, layer_idx=layer_idx)
202 | self.mlp = MistralMLP(config)
203 | self.input_layernorm = MistralRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
204 | self.post_attention_layernorm = MistralRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
205 |
206 |
207 | class LigerMistralPreTrainedModel(MistralPreTrainedModel):
208 | config_class = LigerMistralGLAConfig
209 | base_model_prefix = "model"
210 | supports_gradient_checkpointing = True
211 | _no_split_modules = ["MistralDecoderLayer"]
212 | _skip_keys_device_placement = "past_key_values"
213 |
214 |
215 | class LigerMistralGLAModel(MistralModel, LigerMistralPreTrainedModel):
216 | def __init__(self, config: LigerMistralGLAConfig):
217 | super().__init__(config)
218 | self.layers = nn.ModuleList(
219 | [LigerMistralGLADecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
220 | )
221 |
222 | def forward(
223 | self,
224 | input_ids: torch.LongTensor = None,
225 | attention_mask: Optional[torch.Tensor] = None,
226 | position_ids: Optional[torch.LongTensor] = None,
227 | past_key_values: Optional[Union[Tuple, FlaCache, List[torch.FloatTensor]]] = None,
228 | inputs_embeds: Optional[torch.FloatTensor] = None,
229 | use_cache: Optional[bool] = None,
230 | output_attentions: Optional[bool] = None,
231 | output_hidden_states: Optional[bool] = None,
232 | return_dict: Optional[bool] = None,
233 | cache_position: Optional[torch.LongTensor] = None,
234 | ) -> Union[Tuple, BaseModelOutputWithPast]:
235 | output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
236 | output_hidden_states = (
237 | output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
238 | )
239 | use_cache = use_cache if use_cache is not None else self.config.use_cache
240 |
241 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict
242 |
243 | # retrieve input_ids and inputs_embeds
244 | if (input_ids is None) ^ (inputs_embeds is not None):
245 | raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
246 |
247 | if self.gradient_checkpointing and self.training and use_cache:
248 | logger.warning_once(
249 | "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
250 | )
251 | use_cache = False
252 |
253 | if inputs_embeds is None:
254 | inputs_embeds = self.embed_tokens(input_ids)
255 |
256 | # kept for BC (non `Cache` `past_key_values` inputs)
257 | return_legacy_cache = False
258 | if use_cache and not isinstance(past_key_values, FlaCache):
259 | past_key_values = FlaCache.from_legacy_cache(past_key_values)
260 |
261 | if cache_position is None:
262 | past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
263 | cache_position = torch.arange(
264 | past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
265 | )
266 |
267 | if position_ids is None:
268 | position_ids = cache_position.unsqueeze(0)
269 |
270 | causal_mask = self._update_causal_mask(
271 | attention_mask, inputs_embeds, cache_position, past_key_values, use_cache, output_attentions
272 | )
273 |
274 | hidden_states = inputs_embeds
275 |
276 | # decoder layers
277 | all_hidden_states = () if output_hidden_states else None
278 | all_self_attns = () if output_attentions else None
279 | next_decoder_cache = None
280 |
281 | for decoder_layer in self.layers:
282 | if output_hidden_states:
283 | all_hidden_states += (hidden_states,)
284 |
285 | if self.gradient_checkpointing and self.training:
286 | layer_outputs = self._gradient_checkpointing_func(
287 | decoder_layer.__call__,
288 | hidden_states,
289 | causal_mask,
290 | position_ids,
291 | past_key_values,
292 | output_attentions,
293 | use_cache,
294 | cache_position,
295 | )
296 | else:
297 | layer_outputs = decoder_layer(
298 | hidden_states,
299 | attention_mask=causal_mask,
300 | position_ids=position_ids,
301 | past_key_value=past_key_values,
302 | output_attentions=output_attentions,
303 | use_cache=use_cache,
304 | cache_position=cache_position,
305 | )
306 |
307 | hidden_states = layer_outputs[0]
308 |
309 | if use_cache:
310 | next_decoder_cache = layer_outputs[2 if output_attentions else 1]
311 |
312 | if output_attentions:
313 | all_self_attns += (layer_outputs[1],)
314 |
315 | hidden_states = self.norm(hidden_states)
316 |
317 | # add hidden states from the last decoder layer
318 | if output_hidden_states:
319 | all_hidden_states += (hidden_states,)
320 |
321 | # next_cache = next_decoder_cache if use_cache else None
322 | # if return_legacy_cache:
323 | # next_cache = next_cache.to_legacy_cache()
324 | next_cache = next_decoder_cache if use_cache else None
325 |
326 | if not return_dict:
327 | return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
328 | return BaseModelOutputWithPast(
329 | last_hidden_state=hidden_states,
330 | past_key_values=next_cache,
331 | hidden_states=all_hidden_states,
332 | attentions=all_self_attns,
333 | )
334 |
335 | class LigerMistralGLAForCausalLM(MistralForCausalLM, LigerMistralPreTrainedModel):
336 | _tied_weights_keys = ["lm_head.weight"]
337 | _tp_plan = {"lm_head": "colwise_rep"}
338 |
339 | def __init__(self, config: LigerMistralGLAConfig):
340 | super().__init__(config)
341 | self.model = LigerMistralGLAModel(config)
342 | self.vocab_size = config.vocab_size
343 | self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
344 |
345 | # Initialize weights and apply final processing
346 | self.post_init()
--------------------------------------------------------------------------------
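
For reference, a minimal sketch (illustrative shapes) of the feature-map normalization applied before the GLA kernels in `LigerMistralGatedLinearAttention.forward` above: the softmax over the head dimension turns queries and keys into non-negative features that sum to one, which is what `fused_chunk_gla` / `fused_recurrent_gla` then consume:

    import torch
    import torch.nn.functional as F

    q = torch.randn(2, 4, 8, 32)         # (batch, heads, seq, head_dim), illustrative
    k = torch.randn(2, 4, 8, 32)

    phi_q = F.softmax(q, dim=-1)          # non-negative, each feature vector sums to 1
    phi_k = F.softmax(k, dim=-1)

    assert (phi_q >= 0).all() and (phi_k >= 0).all()
    assert torch.allclose(phi_q.sum(-1), torch.ones(2, 4, 8), atol=1e-5)
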
/liger/models/liger_qwen2_gla/__init__.py:
--------------------------------------------------------------------------------
1 | from transformers import AutoConfig, AutoModel, AutoModelForCausalLM
2 |
3 | from liger.models.liger_qwen2_gla.configuration_liger_qwen2_gla import LigerQwen2GLAConfig
4 | from liger.models.liger_qwen2_gla.modeling_liger_qwen2_gla import LigerQwen2GLAForCausalLM, LigerQwen2GLAModel
5 |
6 | AutoConfig.register(LigerQwen2GLAConfig.model_type, LigerQwen2GLAConfig)
7 | AutoModel.register(LigerQwen2GLAConfig, LigerQwen2GLAModel)
8 | AutoModelForCausalLM.register(LigerQwen2GLAConfig, LigerQwen2GLAForCausalLM)
9 |
10 |
11 | __all__ = ['LigerQwen2GLAConfig', 'LigerQwen2GLAForCausalLM', 'LigerQwen2GLAModel']
--------------------------------------------------------------------------------
/liger/models/liger_qwen2_gla/configuration_liger_qwen2_gla.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | from typing import Dict, Optional
4 |
5 | from transformers.configuration_utils import PretrainedConfig
6 | from transformers.models.qwen2.configuration_qwen2 import Qwen2Config
7 |
8 | class LigerQwen2GLAConfig(Qwen2Config, PretrainedConfig):
9 | model_type = "liger_qwen2_gla"
10 | keys_to_ignore_at_inference = ["past_key_values"]
11 |
12 | def __init__(
13 | self,
14 | vocab_size=151936,
15 | hidden_size=4096,
16 | intermediate_size=22016,
17 | num_hidden_layers=32,
18 | num_attention_heads=32,
19 | num_key_value_heads=32,
20 | hidden_act="silu",
21 | max_position_embeddings=32768,
22 | initializer_range=0.02,
23 | rms_norm_eps=1e-6,
24 | use_cache=True,
25 | tie_word_embeddings=False,
26 | rope_theta=10000.0,
27 | rope_scaling=None,
28 | use_sliding_window=False,
29 | sliding_window=4096,
30 | max_window_layers=28,
31 | attention_dropout=0.0,
32 | **kwargs,
33 | ):
34 | super().__init__(
35 | vocab_size=vocab_size,
36 | hidden_size=hidden_size,
37 | intermediate_size=intermediate_size,
38 | num_hidden_layers=num_hidden_layers,
39 | num_attention_heads=num_attention_heads,
40 | num_key_value_heads=num_key_value_heads,
41 | hidden_act=hidden_act,
42 | max_position_embeddings=max_position_embeddings,
43 | initializer_range=initializer_range,
44 | rms_norm_eps=rms_norm_eps,
45 | use_cache=use_cache,
46 | tie_word_embeddings=tie_word_embeddings,
47 | rope_theta=rope_theta,
48 | rope_scaling=rope_scaling,
49 | use_sliding_window=use_sliding_window,
50 | sliding_window=sliding_window,
51 | max_window_layers=max_window_layers,
52 | attention_dropout=attention_dropout,
53 | **kwargs,
54 | )
--------------------------------------------------------------------------------
/liger/models/liger_qwen2_gla/modeling_liger_qwen2_gla.py:
--------------------------------------------------------------------------------
1 | import math
2 | import warnings
3 | import copy
4 | from typing import List, Optional, Tuple, Union
5 | from einops import rearrange, repeat
6 |
7 | import torch
8 | import torch.nn as nn
9 | import torch.nn.functional as F
10 | import torch.utils.checkpoint
11 | from transformers.cache_utils import Cache, DynamicCache, StaticCache
12 | from transformers.generation import GenerationMixin
13 | from transformers.modeling_outputs import (
14 | BaseModelOutputWithPast,
15 | CausalLMOutputWithPast,
16 | )
17 | from transformers.models.qwen2.modeling_qwen2 import (
18 | repeat_kv,
19 | apply_rotary_pos_emb,
20 | Qwen2RotaryEmbedding,
21 | Qwen2RMSNorm,
22 | Qwen2MLP,
23 | Qwen2Attention,
24 | Qwen2DecoderLayer,
25 | Qwen2PreTrainedModel,
26 | Qwen2Model,
27 | Qwen2ForCausalLM,
28 | )
29 | from transformers.utils import (
30 | LossKwargs,
31 | add_code_sample_docstrings,
32 | add_start_docstrings,
33 | add_start_docstrings_to_model_forward,
34 | logging,
35 | replace_return_docstrings,
36 | )
37 |
38 | from transformers.utils import is_flash_attn_2_available
39 |
40 | if is_flash_attn_2_available():
41 | from transformers.modeling_flash_attention_utils import _flash_attention_forward
42 | else:
43 | print("flash_attn_2 is not available")
44 |
45 | from fla.models.utils import Cache as FlaCache
46 | from fla.ops.gla import fused_chunk_gla, fused_recurrent_gla
47 |
48 | from .configuration_liger_qwen2_gla import LigerQwen2GLAConfig
49 |
50 | logger = logging.get_logger(__name__)
51 |
52 | class LigerQwen2GatedLinearAttention(nn.Module):
53 | def __init__(
54 | self,
55 | config: LigerQwen2GLAConfig,
56 | layer_idx: Optional[int] = None,
57 | ):
58 | super().__init__()
59 | self.config = config
60 | self.layer_idx = layer_idx
61 | if layer_idx is None:
62 | logger.warning_once(
63 |                 f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will "
64 |                 "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` "
65 | "when creating this class."
66 | )
67 |
68 | self.hidden_size = config.hidden_size
69 | self.num_heads = config.num_attention_heads
70 | self.head_dim = self.hidden_size // self.num_heads
71 | self.num_key_value_heads = config.num_key_value_heads
72 | self.num_key_value_groups = self.num_heads // self.num_key_value_heads
73 | self.max_position_embeddings = config.max_position_embeddings
74 | self.rope_theta = config.rope_theta
75 | self.is_causal = True
76 | self.attention_dropout = config.attention_dropout
77 |
78 | if (self.head_dim * self.num_heads) != self.hidden_size:
79 | raise ValueError(
80 | f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
81 | f" and `num_heads`: {self.num_heads})."
82 | )
83 | self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=True)
84 | self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True)
85 | self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True)
86 | self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
87 |
88 | self.rotary_emb = Qwen2RotaryEmbedding(config=self.config)
89 | self.pool_g = nn.AdaptiveAvgPool1d(output_size=self.head_dim * self.num_key_value_heads)
90 |
91 |
92 | def forward(
93 | self,
94 | hidden_states: torch.Tensor,
95 | attention_mask: Optional[torch.Tensor] = None,
96 | position_ids: Optional[torch.LongTensor] = None,
97 | past_key_value: Optional[FlaCache] = None,
98 | output_attentions: bool = False,
99 | use_cache: bool = False,
100 | cache_position: Optional[torch.LongTensor] = None,
101 | position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
102 | **kwargs,
103 | ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
104 | last_state = None
105 | if past_key_value is not None and len(past_key_value) > self.layer_idx:
106 | last_state = past_key_value[self.layer_idx]
107 |
108 | bsz, q_len, _ = hidden_states.size()
109 |
110 | query_states = self.q_proj(hidden_states)
111 | key_states = self.k_proj(hidden_states)
112 | value_states = self.v_proj(hidden_states)
113 | g = self.pool_g(key_states)
114 |
115 | query_states = query_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
116 | key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
117 | value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
118 | g = g.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
119 |
120 | # if position_embeddings is None:
121 | # logger.warning_once(
122 | # "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
123 | # "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
124 | # "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be "
125 | # "removed and `position_embeddings` will be mandatory."
126 | # )
127 | # cos, sin = self.rotary_emb(value_states, position_ids)
128 | # else:
129 | # cos, sin = position_embeddings
130 | # query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
131 |
132 | # if past_key_value is not None:
133 | # cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models
134 | # key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs=cache_kwargs)
135 |
136 | # repeat k/v heads if n_kv_heads < n_heads
137 | q = query_states
138 | k = repeat_kv(key_states, self.num_key_value_groups)
139 | v = repeat_kv(value_states, self.num_key_value_groups)
140 | g = repeat_kv(g, self.num_key_value_groups)
141 |
142 | sq, sk, sv = q, k, v
143 |
144 | # norm
145 | q = F.softmax(q, dim=-1)
146 | k = F.softmax(k, dim=-1)
147 |
148 | gate_logit_normalizer = 16
149 | g = F.logsigmoid(g) / gate_logit_normalizer # (b, h, n, m)
150 |
151 | recurrent_state = last_state['recurrent_state'] if last_state is not None else None
152 | offsets = kwargs.get('offsets', None)
153 | scale = 1
154 | q, k, v, g = (x.to(torch.float32).contiguous() for x in (q, k, v, g))
155 |
156 | if self.training or q.shape[-2] > 1:
157 | o_, recurrent_state = fused_chunk_gla(q, k, v, g, scale=scale, initial_state=recurrent_state, output_final_state=True)
158 | else:
159 | o_, recurrent_state = fused_recurrent_gla(q, k, v, g, scale=scale, initial_state=recurrent_state, output_final_state=True)
160 |
161 | if past_key_value is not None:
162 | past_key_value.update(
163 | recurrent_state=recurrent_state,
164 | layer_idx=self.layer_idx,
165 | offset=q.shape[1]
166 | )
167 |
168 | if position_embeddings is None:
169 | logger.warning_once(
170 | "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
171 | "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
172 | "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be "
173 | "removed and `position_embeddings` will be mandatory."
174 | )
175 | cos, sin = self.rotary_emb(sv, position_ids)
176 | else:
177 | cos, sin = position_embeddings
178 | sq, sk = apply_rotary_pos_emb(sq, sk, cos, sin)
179 |
180 | input_dtype = sq.dtype
181 | if input_dtype == torch.float32:
182 | if torch.is_autocast_enabled():
183 | target_dtype = torch.get_autocast_gpu_dtype()
184 | # Handle the case where the model is quantized
185 | elif hasattr(self.config, "_pre_quantization_dtype"):
186 | target_dtype = self.config._pre_quantization_dtype
187 | else:
188 | target_dtype = self.q_proj.weight.dtype
189 |
190 | logger.warning_once(
191 |                 f"The input hidden states seem to have been silently cast to float32; this is likely because"
192 |                 f" you have upcast embedding or layer norm layers to float32. We will cast the input back to"
193 |                 f" {target_dtype}."
194 | )
195 |
196 | sq = sq.to(target_dtype)
197 | sk = sk.to(target_dtype)
198 | sv = sv.to(target_dtype)
199 |
200 | window_size = 64
201 |         # drop the mask when it contains no padding (no zeros), letting Flash Attention take the unpadded path
202 |         if attention_mask is not None and 0.0 not in attention_mask:
203 |             attention_mask = None
204 |
205 |
206 |         y = _flash_attention_forward(  # reshape to the layout expected by Flash Attention
207 | sq.transpose(1, 2),
208 | sk.transpose(1, 2),
209 | sv.transpose(1, 2),
210 | attention_mask,
211 | q_len,
212 | position_ids=position_ids,
213 | dropout=0.0,
214 | sliding_window=window_size,
215 | use_top_left_mask=False,
216 | is_causal=True,
217 | target_dtype=torch.float32,
218 | ).transpose(1, 2)
219 | o_ = 0.5 * y + 0.5 * o_
220 | o = rearrange(o_.bfloat16(), 'b h n d -> b n (h d)')
221 | o = self.o_proj(o)
222 |
223 | return o, None, past_key_value
224 |
225 | class LigerQwen2DecoderLayer(Qwen2DecoderLayer):
226 | def __init__(self, config: LigerQwen2GLAConfig, layer_idx: int):
227 | super().__init__(config, layer_idx)
228 | self.hidden_size = config.hidden_size
229 |
230 | if config.sliding_window and config._attn_implementation != "flash_attention_2":
231 | logger.warning_once(
232 | f"Sliding Window Attention is enabled but not implemented for `{config._attn_implementation}`; "
233 | "unexpected results may be encountered."
234 | )
235 | self.self_attn = LigerQwen2GatedLinearAttention(config, layer_idx)
236 | self.mlp = Qwen2MLP(config)
237 | self.input_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
238 | self.post_attention_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
239 |
240 | class LigerQwen2PreTrainedModel(Qwen2PreTrainedModel):
241 |
242 | config_class = LigerQwen2GLAConfig
243 | base_model_prefix = "model"
244 | supports_gradient_checkpointing = True
245 | _no_split_modules = ["Qwen2DecoderLayer"]
246 | _skip_keys_device_placement = "past_key_values"
247 |
248 | class LigerQwen2GLAModel(Qwen2Model, LigerQwen2PreTrainedModel):
249 |
250 | def __init__(self, config: LigerQwen2GLAConfig):
251 | super().__init__(config)
252 | self.padding_idx = config.pad_token_id
253 | self.vocab_size = config.vocab_size
254 |
255 | self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
256 | self.layers = nn.ModuleList(
257 | [LigerQwen2DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
258 | )
259 | self._attn_implementation = config._attn_implementation
260 | self.norm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
261 | self.rotary_emb = Qwen2RotaryEmbedding(config=config)
262 |
263 | self.gradient_checkpointing = False
264 | # Initialize weights and apply final processing
265 | self.post_init()
266 |
267 | def forward(
268 | self,
269 | input_ids: torch.LongTensor = None,
270 | attention_mask: Optional[torch.Tensor] = None,
271 | position_ids: Optional[torch.LongTensor] = None,
272 | past_key_values: Optional[Union[Tuple, FlaCache, List[torch.FloatTensor]]] = None,
273 | inputs_embeds: Optional[torch.FloatTensor] = None,
274 | use_cache: Optional[bool] = None,
275 | output_attentions: Optional[bool] = None,
276 | output_hidden_states: Optional[bool] = None,
277 | return_dict: Optional[bool] = None,
278 | cache_position: Optional[torch.LongTensor] = None,
279 | **kwargs,
280 | ) -> Union[Tuple, BaseModelOutputWithPast]:
281 | output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
282 | output_hidden_states = (
283 | output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
284 | )
285 | use_cache = use_cache if use_cache is not None else self.config.use_cache
286 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict
287 |
288 | if (input_ids is None) ^ (inputs_embeds is not None):
289 | raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
290 |
291 | if self.gradient_checkpointing and self.training and use_cache:
292 | logger.warning_once(
293 | "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
294 | )
295 | use_cache = False
296 |
297 |
298 | # kept for BC (non `Cache` `past_key_values` inputs)
299 | return_legacy_cache = False
300 | if use_cache and not isinstance(past_key_values, FlaCache):
301 | past_key_values = FlaCache.from_legacy_cache(past_key_values)
302 |
303 | if inputs_embeds is None:
304 | inputs_embeds = self.embed_tokens(input_ids)
305 |
306 | if cache_position is None:
307 | past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
308 | cache_position = torch.arange(
309 | past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
310 | )
311 |
312 | if position_ids is None:
313 | position_ids = cache_position.unsqueeze(0)
314 |
315 | causal_mask = attention_mask
316 |
317 | hidden_states = inputs_embeds
318 |
319 | # create position embeddings to be shared across the decoder layers
320 | position_embeddings = self.rotary_emb(hidden_states, position_ids)
321 |
322 | # decoder layers
323 | all_hidden_states = () if output_hidden_states else None
324 | all_self_attns = () if output_attentions else None
325 | next_decoder_cache = None
326 |
327 |         # Defined unconditionally so the check inside the layer loop below cannot raise a NameError.
328 |         all_softmax_hidden_states = () if output_attentions else None
329 |
330 | for decoder_layer in self.layers:
331 | if output_hidden_states:
332 | all_hidden_states += (hidden_states,)
333 | if all_softmax_hidden_states is not None:
334 | all_softmax_hidden_states += (hidden_states,)
335 |
336 | if self.gradient_checkpointing and self.training:
337 | layer_outputs = self._gradient_checkpointing_func(
338 | decoder_layer.__call__,
339 | hidden_states,
340 | causal_mask,
341 | position_ids,
342 | past_key_values,
343 | output_attentions,
344 | use_cache,
345 | cache_position,
346 | position_embeddings,
347 | )
348 |
349 | else:
350 | layer_outputs = decoder_layer(
351 | hidden_states,
352 | attention_mask=causal_mask,
353 | position_ids=position_ids,
354 | past_key_value=past_key_values,
355 | output_attentions=output_attentions,
356 | use_cache=use_cache,
357 | cache_position=cache_position,
358 | position_embeddings=position_embeddings,
359 | )
360 | hidden_states = layer_outputs[0]
361 |
362 | if use_cache:
363 | next_decoder_cache = layer_outputs[2 if output_attentions else 1]
364 |
365 | if output_attentions:
366 | all_self_attns += (layer_outputs[1],)
367 |
368 | hidden_states = self.norm(hidden_states)
369 |
370 | # add hidden states from the last decoder layer
371 | if output_hidden_states:
372 | all_hidden_states += (hidden_states,)
373 |
374 | next_cache = next_decoder_cache if use_cache else None
375 |
376 | if not return_dict:
377 | return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
378 | return BaseModelOutputWithPast(
379 | last_hidden_state=hidden_states,
380 | past_key_values=next_cache,
381 | hidden_states=all_hidden_states,
382 | attentions=all_self_attns,
383 | )
384 |
385 | class LigerQwen2GLAForCausalLM(LigerQwen2PreTrainedModel, Qwen2ForCausalLM, GenerationMixin):
386 | _tied_weights_keys = ["lm_head.weight"]
387 | _tp_plan = {"lm_head": "colwise_rep"}
388 |
389 | def __init__(self, config):
390 | super().__init__(config)
391 | self.model = LigerQwen2GLAModel(config)
392 | self.vocab_size = config.vocab_size
393 | self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
394 |
395 | # Initialize weights and apply final processing
396 | self.post_init()
--------------------------------------------------------------------------------
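Note: the hybrid attention above averages a 64-token sliding-window softmax path with the gated-linear-attention path (`o_ = 0.5 * y + 0.5 * o_`). Below is a minimal loading sketch that mirrors the pattern in `training/train.py`; it assumes `liger.models.liger_qwen2_gla` re-exports `LigerQwen2GLAConfig` and that importing `liger` registers the model with the `Auto*` classes, and the base checkpoint name is a placeholder, not something taken from this repo.

```python
import torch
from transformers import AutoConfig, AutoModelForCausalLM

import liger  # noqa: F401  -- assumed side effect: registers the Liger model classes
from liger.models.liger_qwen2_gla import LigerQwen2GLAConfig  # assumed re-export

base = "Qwen/Qwen2.5-7B-Instruct"  # placeholder base checkpoint
liger_config = LigerQwen2GLAConfig()
liger_config.__dict__.update(AutoConfig.from_pretrained(base).__dict__)

# Matching weights are copied from the softmax-attention checkpoint; any
# Liger-specific parameters missing from the base model are newly initialized.
model = AutoModelForCausalLM.from_pretrained(base, config=liger_config).to(torch.bfloat16)
```

Running the forward pass additionally requires the `fla` Triton kernels and `flash-attn`, since the attention above calls `fused_chunk_gla`/`fused_recurrent_gla` and `_flash_attention_forward`.
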
/lolcats/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | from lolcats.models.lolcats import LolcatsConfig, LolcatsModel, LolcatsModelForCausalLM
4 |
5 | __all__ = [
6 | 'LolcatsConfig', 'LolcatsModel', 'LolcatsModelForCausalLM'
7 | ]
--------------------------------------------------------------------------------
/lolcats/models/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | from lolcats.models.lolcats import LolcatsConfig, LolcatsModel, LolcatsModelForCausalLM
4 |
5 | __all__ = [
6 | 'LolcatsConfig', 'LolcatsModel', 'LolcatsModelForCausalLM'
7 | ]
--------------------------------------------------------------------------------
/lolcats/models/lolcats/__init__.py:
--------------------------------------------------------------------------------
1 | from transformers import AutoConfig, AutoModel, AutoModelForCausalLM
2 |
3 | from lolcats.models.lolcats.configuration_lolcats import LolcatsConfig
4 | from lolcats.models.lolcats.modeling_lolcats import LolcatsModel, LolcatsModelForCausalLM
5 |
6 | AutoConfig.register(LolcatsConfig.model_type, LolcatsConfig)
7 | AutoModel.register(LolcatsConfig, LolcatsModel)
8 | AutoModelForCausalLM.register(LolcatsConfig, LolcatsModelForCausalLM)
9 |
10 |
11 | __all__ = ['LolcatsConfig', 'LolcatsModelForCausalLM', 'LolcatsModel']
--------------------------------------------------------------------------------
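Because the module above registers `LolcatsConfig` and the model classes with the `Auto*` factories at import time, the custom `model_type` becomes resolvable through the standard transformers entry points. A small sketch, assuming the repository root is on `PYTHONPATH`:

```python
import lolcats  # noqa: F401  -- import side effect runs the Auto* registrations above
from transformers import AutoConfig

cfg = AutoConfig.for_model("lolcats")       # resolves to LolcatsConfig
print(type(cfg).__name__, cfg.model_type)   # -> LolcatsConfig lolcats
```
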
/lolcats/models/lolcats/configuration_lolcats.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | from typing import Dict, Optional
4 |
5 | from transformers.configuration_utils import PretrainedConfig
6 | from transformers.modeling_rope_utils import rope_config_validation
7 | from transformers.models.llama.configuration_llama import LlamaConfig
8 |
9 | class LolcatsConfig(LlamaConfig, PretrainedConfig):
10 |
11 | model_type = 'lolcats'
12 | keys_to_ignore_at_inference = ['past_key_values']
13 |
14 | def __init__(
15 | self,
16 | # llama config
17 | vocab_size=32000,
18 | hidden_size=4096,
19 | intermediate_size=11008,
20 | num_hidden_layers=32,
21 | num_attention_heads=32,
22 | num_key_value_heads=None,
23 | hidden_act="silu",
24 | max_position_embeddings=2048,
25 | initializer_range=0.02,
26 | rms_norm_eps=1e-6,
27 | use_cache=True,
28 | pad_token_id=None,
29 | bos_token_id=1,
30 | eos_token_id=2,
31 | pretraining_tp=1,
32 | tie_word_embeddings=False,
33 | rope_theta=10000.0,
34 | rope_scaling=None,
35 | attention_bias=False,
36 | attention_dropout=0.0,
37 | mlp_bias=False,
38 | head_dim=None,
39 | # linear attention
40 | attn_mode: str = "fused_chunk",
41 | expand_k: int = 1,
42 | expand_v: int = 1,
43 | hidden_ratio: Optional[int] = 4,
44 | # num_heads: int = 4,
45 | # num_kv_heads: Optional[int] = None,
46 | feature_map: str = "lolcats_t2r",
47 | tie_feature_map_qk: bool = False,
48 | norm_q: bool = True,
49 | norm_k: bool = True,
50 | norm_feature_map: bool = False,
51 | elementwise_affine: Optional[bool] = True,
52 | norm_eps: float = 1e-6,
53 | attn: Optional[Dict] = None,
54 | fuse_cross_entropy: bool = True,
55 | **kwargs
56 | ):
57 |
58 | # linear attention settings
59 | self.attn_mode = attn_mode
60 | self.expand_k = expand_k
61 | self.expand_v = expand_v
62 | self.hidden_ratio = hidden_ratio
63 | # self.num_heads = num_heads
64 | # self.num_kv_heads = num_kv_heads
65 | self.feature_map = feature_map
66 | self.tie_feature_map_qk = tie_feature_map_qk
67 | self.norm_q = norm_q
68 | self.norm_k = norm_k
69 | self.norm_feature_map = norm_feature_map
70 | self.max_position_embeddings = max_position_embeddings
71 | self.elementwise_affine = elementwise_affine
72 | self.norm_eps = norm_eps
73 | self.attn = attn
74 | self.fuse_cross_entropy = fuse_cross_entropy
75 |
76 | if attn is not None:
77 | if not isinstance(attn, Dict):
78 | raise ValueError("attn must be a dictionary")
79 | if 'layers' not in attn:
80 | raise ValueError("Layer indices must be provided to initialize hybrid attention layers")
81 | if 'num_heads' not in attn:
82 | raise ValueError("Number of heads must be provided to initialize hybrid attention layers")
83 | attn['num_kv_heads'] = attn.get('num_kv_heads', attn['num_heads'])
84 | attn['window_size'] = attn.get('window_size', None)
85 |
86 | super().__init__(
87 | vocab_size=vocab_size,
88 | hidden_size=hidden_size,
89 | intermediate_size=intermediate_size,
90 | num_hidden_layers=num_hidden_layers,
91 | num_attention_heads=num_attention_heads,
92 | num_key_value_heads=num_key_value_heads,
93 | hidden_act=hidden_act,
94 | max_position_embeddings=max_position_embeddings,
95 | initializer_range=initializer_range,
96 | rms_norm_eps=rms_norm_eps,
97 | use_cache=use_cache,
98 | pad_token_id=pad_token_id,
99 | bos_token_id=bos_token_id,
100 | eos_token_id=eos_token_id,
101 | pretraining_tp=pretraining_tp,
102 | tie_word_embeddings=tie_word_embeddings,
103 | rope_theta=rope_theta,
104 | rope_scaling=rope_scaling,
105 | attention_bias=attention_bias,
106 | attention_dropout=attention_dropout,
107 | mlp_bias=mlp_bias,
108 | head_dim=head_dim,
109 | **kwargs,
110 | )
111 |
112 |
--------------------------------------------------------------------------------
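The `attn` dictionary is how hybrid (softmax + linear) layers are declared: `layers` and `num_heads` are required, while `num_kv_heads` and `window_size` fall back to defaults. A short sketch of that validation path:

```python
from lolcats.models.lolcats.configuration_lolcats import LolcatsConfig

# 'layers' and 'num_heads' are mandatory; missing optional keys get filled in.
cfg = LolcatsConfig(attn={"layers": [0, 1], "num_heads": 32})
print(cfg.attn["num_kv_heads"], cfg.attn["window_size"])  # -> 32 None

# Omitting 'layers' or 'num_heads' raises a ValueError.
```
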
/requirements.txt:
--------------------------------------------------------------------------------
1 | torch==2.5.1
2 | triton==3.1.0
3 | transformers==4.47.1
4 | datasets
5 | peft
6 | evaluate
7 | omegaconf
--------------------------------------------------------------------------------
/run.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import sys
3 | import os
4 | import random
5 | import numpy as np
6 | from omegaconf import OmegaConf
7 | import torch
8 |
9 | from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
10 |
11 | from training.train import train
12 |
13 |
14 | def set_random_seed(seed=0):
15 | os.environ['PYTHONHASHSEED'] = str(seed)
16 | random.seed(seed)
17 | np.random.seed(seed)
18 | torch.manual_seed(seed)
19 | torch.cuda.manual_seed(seed)
20 | torch.cuda.manual_seed_all(seed)
21 | torch.backends.cudnn.benchmark = False # if benchmark=True, deterministic will be False
22 | torch.backends.cudnn.deterministic = True # choose a deterministic algorithm
23 |
24 | def get_args():
25 | parser = argparse.ArgumentParser()
26 |     parser.add_argument("--cfg", type=str, default="configs/liger_gla.yaml")
27 | args = parser.parse_args()
28 | return args
29 |
30 | def main():
31 | set_random_seed(seed=0)
32 | args = get_args()
33 | config = OmegaConf.load(args.cfg)
34 | output_dir = args.cfg.split('/')[-1].split('.')[0]
35 | config.train.output_dir = os.path.join(config.train.output_dir, output_dir) # 'checkpoints/${filename}'
36 | train(config)
37 |
38 | if __name__ == "__main__":
39 | main()
--------------------------------------------------------------------------------
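The output directory is derived from the config filename, so each YAML gets its own checkpoint subfolder, matching the `'checkpoints/${filename}'` comment in `main()`. A toy illustration of that convention (the `checkpoints` base directory is an assumed value of `train.output_dir` from the YAML):

```python
import os

cfg_path = "configs/liger_gla.yaml"
base_output_dir = "checkpoints"                      # assumed train.output_dir from the YAML
run_name = cfg_path.split('/')[-1].split('.')[0]     # -> "liger_gla"
print(os.path.join(base_output_dir, run_name))       # -> checkpoints/liger_gla
```
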
/scripts/train_liger.sh:
--------------------------------------------------------------------------------
1 | export CUDA_VISIBLE_DEVICES=0
2 |
3 | python run.py --cfg configs/liger_gla.yaml
--------------------------------------------------------------------------------
/scripts/train_lolcats_stage1.sh:
--------------------------------------------------------------------------------
1 | export CUDA_VISIBLE_DEVICES=3
2 |
3 | python run.py --cfg configs/lolcats_at.yaml
--------------------------------------------------------------------------------
/scripts/train_lolcats_stage2.sh:
--------------------------------------------------------------------------------
1 | export CUDA_VISIBLE_DEVICES=7
2 |
3 | python run.py --cfg configs/lolcats_ar.yaml
--------------------------------------------------------------------------------
/training/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OpenSparseLLMs/Linearization/0e3cfae33a700fa5f644cf5752d8434c6afc2412/training/__init__.py
--------------------------------------------------------------------------------
/training/dataloader.py:
--------------------------------------------------------------------------------
1 | import os
2 | import shutil
3 | import random
4 | from tqdm import tqdm
5 | from functools import partial
6 | import torch
7 | from torch.utils.data import Dataset, DataLoader
8 | from datasets import Dataset as HFDataset
9 | from datasets import load_dataset, load_from_disk
10 | import evaluate
11 | from huggingface_hub import hf_hub_download
12 |
13 | from transformers import AutoTokenizer, AutoModel, AutoModelForCausalLM
14 | from transformers import DataCollatorForSeq2Seq
15 |
16 |
17 | PROMPT_DICT = {
18 | "prompt_input": (
19 | "Below is an instruction that describes a task, paired with an input that provides further context. "
20 | "Write a response that appropriately completes the request.\n\n"
21 | "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n"
22 | ),
23 | "prompt_no_input": (
24 | "Below is an instruction that describes a task. "
25 | "Write a response that appropriately completes the request.\n\n"
26 | "### Instruction:\n{instruction}\n\n### Response:\n"
27 | ),
28 | }
29 |
30 | def encode_response(response: str, tokenizer) -> list[int]:
31 | tokens = tokenizer.encode(response.strip(), add_special_tokens=False)
32 | # For Llama 3 Instruct: tokens.append(tokenizer.get_added_vocab()["<|eot_id|>"])
33 | tokens.append(tokenizer.eos_token_id)
34 | try: # Llama 3 Instruct
35 | tokens.append(tokenizer.get_added_vocab()["<|end_of_text|>"])
36 | except KeyError:
37 | pass
38 | return tokens
39 |
40 | def load_data(config):
41 | cache_dir = "/root/.cache"
42 | input_len = config.model.max_length
43 | concat_data = True
44 |
45 | tokenizer_path = config.model.pretrained_model_name_or_path
46 | tokenizer_name = tokenizer_path.split('/')[-1]
47 |
48 | tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
49 | if tokenizer.pad_token is None:
50 | tokenizer.pad_token = tokenizer.eos_token
51 | print(f'Setting tokenizer.pad_token to {tokenizer.pad_token}')
52 |
53 | tokenizer.padding_side = 'left' # for decoder-only generation
54 | # Get initial data
55 | ignore_kwargs = ['concat_data', 'chunk_size', 'pose_kwargs']
56 | dataset_config = {
57 | "name": "default",
58 | "path": "yahma/alpaca-cleaned",
59 | "chunk_size": input_len,
60 | "concat_data": concat_data,
61 | "cache_dir": cache_dir,
62 | }
63 | dataset = load_dataset(
64 | **{k: v for k, v in dataset_config.items() if k not in ignore_kwargs}
65 | )
66 | dataset = dataset['train']
67 | train_set = convert_to_hf_dataset([dataset[ix] for ix in range(200, len(dataset))], cache_dir)
68 | val_set = convert_to_hf_dataset([dataset[ix] for ix in range(200)], cache_dir)
69 | test_set = convert_to_hf_dataset([dataset[ix] for ix in range(200)], cache_dir)
70 |
71 | # Convert to dicts of {input_ids, attention_mask, labels}
72 | train_set = train_set.map(
73 | partial(template_and_tokenize, tokenizer=tokenizer, include_label=True),
74 | remove_columns=list(dataset.features),)
75 | val_set = val_set.map(
76 | partial(template_and_tokenize, tokenizer=tokenizer, include_label=True),
77 | remove_columns=list(dataset.features),)
78 | test_set = test_set.map(
79 | partial(template_and_tokenize, tokenizer=tokenizer, include_label=False),
80 | remove_columns=list(dataset.features),)
81 |
82 | # Chunk together train and val sets
83 | if concat_data:
84 | train_set = ConcatDataset(train_set, chunk_size=input_len)
85 | val_set = ConcatDataset(val_set, chunk_size=input_len)
86 |
87 | loader_kwargs = {
88 | "batch_size": config.data.micro_batch_size,
89 | "num_workers": 0,
90 | "drop_last": False,
91 | "pin_memory": True,
92 | }
93 |
94 | # Get dataloaders
95 | dataloaders = {
96 | 'train': get_lm_loader(train_set, tokenizer, 'train', input_len, **loader_kwargs),
97 | 'validation': get_lm_loader(val_set, tokenizer, 'validation', input_len, **loader_kwargs),
98 | 'test': get_seq2seq_loader(test_set, tokenizer, 'test', **loader_kwargs),
99 | }
100 | # Evaluation metric
101 | try:
102 | # metric = load_metric(download_metric(), 'gov_report') # hack but we want rouge
103 | metric = evaluate.load(download_metric(), 'gov_report')
104 | except Exception as e:
105 | print(f'Error loading metric: {e}')
106 | metric = None
107 |
108 | # Finishing touches
109 | for k, v in dataloaders.items(): # Make tokenizer accessible
110 | dataloaders[k].dataset.tokenizer = tokenizer
111 | dataloaders[k].dataset.metric = metric
112 | return dataloaders
113 |
114 |
115 | def convert_to_hf_dataset(dataset, cache_dir: str):
116 | """
117 | Convert iterable dataset to HuggingFace HFDataset object
118 | """
119 | def gen():
120 | for _, sample in enumerate(dataset):
121 | yield sample # dataset[idx]
122 | return HFDataset.from_generator(gen, cache_dir=cache_dir)
123 |
124 | def template_and_tokenize(sample, tokenizer, include_label: bool = True):
125 | """
126 | Format dataset context and answers into single-sequence prompts
127 | """
128 | if sample.get('input', '') == '':
129 | prompt = PROMPT_DICT["prompt_no_input"].format_map(sample)
130 | else:
131 | prompt = PROMPT_DICT["prompt_input"].format_map(sample)
132 |
133 | prompt = tokenizer.encode(prompt, add_special_tokens=True)
134 | if include_label:
135 | answer = tokenizer.encode(f'{sample["output"]}{tokenizer.eos_token}',
136 | add_special_tokens=False)
137 | target = None
138 | else:
139 | answer = []
140 | target = tokenizer.encode(f'{sample["output"]}{tokenizer.eos_token}',
141 | add_special_tokens=False)
142 | input_ids = prompt + answer
143 | attn_mask = [1] * len(input_ids)
144 |
145 | sample = {
146 | "input_ids": input_ids,
147 | "attention_mask" : attn_mask,
148 | "labels": [-100] * len(prompt) + answer if include_label else target,
149 | }
150 | return sample
151 |
152 | def get_lm_loader(dataset: Dataset, tokenizer: AutoTokenizer,
153 | split: str, max_length: int = None, **loader_kwargs: any):
154 | """
155 | Get dataloader for language modeling (training)
156 | -> Currently this ends up being the same as get_seq2seq_loader
157 | """
158 | # collate_fn = DefaultDataCollator(return_tensors='pt')
159 | # collate_fn = DataCollatorWithPadding(tokenizer=tokenizer, padding=True,
160 | # max_length=max_length, return_tensors='pt')
161 | collate_fn = DataCollatorForSeq2Seq(
162 | tokenizer, label_pad_token_id=-100, return_tensors='pt')
163 | return DataLoader(
164 | dataset, shuffle='train' in split, collate_fn=collate_fn, **loader_kwargs)
165 |
166 | def get_seq2seq_loader(dataset: Dataset, tokenizer: AutoTokenizer,
167 | split: str, **loader_kwargs: any):
168 | """
169 | Get dataloader for seq2seq tasks (evaluation)
170 | """
171 | tokenizer.padding_side = 'right'
172 | collate_fn = DataCollatorForSeq2Seq(
173 | tokenizer, label_pad_token_id=-100, return_tensors='pt')
174 | return DataLoader(
175 | dataset, shuffle='train' in split, collate_fn=collate_fn, **loader_kwargs)
176 |
177 | def download_metric():
178 | """
179 | Download ROUGE, F1, and other accuracy metrics included in the SCROLLS dataset
180 | """
181 | scrolls_metric_path = hf_hub_download(
182 | repo_id="tau/scrolls", filename="metrics/scrolls.py", repo_type="dataset"
183 | )
184 | updated_scrolls_metric_path = (
185 | os.path.dirname(scrolls_metric_path) +
186 | os.path.basename(scrolls_metric_path).replace(".", "_") + ".py"
187 | )
188 | shutil.copy(scrolls_metric_path, updated_scrolls_metric_path)
189 | return updated_scrolls_metric_path
190 |
191 | class ConcatDataset(Dataset):
192 | """
193 | Concatenates or packs samples of a dataset into chunks of size `chunk_size`
194 | """
195 | def __init__(self, dataset, chunk_size: int = 1024, seed: int = 42,) -> None:
196 | self.dataset = dataset
197 | self.chunk_size = chunk_size
198 | self.samples = []
199 | buffer = {
200 | "input_ids": [],
201 | "attention_mask": [],
202 | "labels": [],
203 | }
204 | random.seed(seed)
205 | for sample in tqdm(self.dataset, desc="Preprocessing dataset", dynamic_ncols=True):
206 | buffer = {k: v + sample[k] for k,v in buffer.items()}
207 |
208 | while len(next(iter(buffer.values()))) > self.chunk_size:
209 | self.samples.append({k: v[:self.chunk_size] for k,v in buffer.items()})
210 | buffer = {k: v[self.chunk_size:] for k,v in buffer.items()}
211 | # Slow hack, but filter out any samples without valid labels (all -100)
212 | self.filtered_samples = []
213 | for s in self.samples:
214 | if sum(s['labels']) != chunk_size * -100:
215 | self.filtered_samples.append(s)
216 | if len(self.filtered_samples) < len(self.samples):
217 | print(f'OG dataset: {len(self.samples)} samples -> Filtered dataset: {len(self.filtered_samples)}')
218 | print(f'-> Filtered out {len(self.samples) - len(self.filtered_samples)} samples')
219 |
220 | def __getitem__(self, idx):
221 | return self.filtered_samples[idx]
222 |
223 | def __len__(self):
224 | return len(self.filtered_samples)
225 |
--------------------------------------------------------------------------------
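`ConcatDataset` packs tokenized samples into fixed-length chunks, drops the trailing remainder shorter than `chunk_size`, and filters out chunks whose labels are all `-100`. A toy run of that packing logic, assuming the repository root is on `PYTHONPATH`:

```python
from training.dataloader import ConcatDataset

samples = [
    {"input_ids": [1] * 5, "attention_mask": [1] * 5, "labels": [-100] * 5},
    {"input_ids": [2] * 5, "attention_mask": [1] * 5, "labels": [2] * 5},
    {"input_ids": [3] * 5, "attention_mask": [1] * 5, "labels": [3] * 5},
]
packed = ConcatDataset(samples, chunk_size=6)
print(len(packed), len(packed[0]["input_ids"]))  # -> 2 6 (the 3-token remainder is dropped)
```
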
/training/train.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import os
3 | import numpy as np
4 |
5 | import torch
6 | import torch.nn as nn
7 | import torch.nn.functional as F
8 |
9 | import fla
10 | import liger
11 | import lolcats
12 | import torch.utils
13 | import torch.utils.data
14 | import torch.utils.data.dataloader
15 | from transformers import AutoModel, AutoModelForCausalLM, AutoTokenizer, AutoConfig
16 | from transformers import Trainer, TrainingArguments
17 | from peft import LoraConfig, TaskType, PeftModel, get_peft_model
18 |
19 | from training.trainer import DefaultTrainer, FinetuneTrainer
20 | from training.utils import get_optimizer_and_scheduler, count_model_params
21 | from training.dataloader import load_data
22 |
23 |
24 | def train(config):
25 |
26 | trainer = FinetuneTrainer
27 | model_config = AutoConfig.from_pretrained(config.model.pretrained_model_name_or_path)
28 | if config.model.name == "liger_gla":
29 | from liger.models.liger_gla import LigerGLAConfig
30 | liger_model_config = LigerGLAConfig()
31 | elif config.model.name == "liger_gsa":
32 | from liger.models.liger_gsa import LigerGSAConfig
33 | liger_model_config = LigerGSAConfig()
34 | elif config.model.name == "lolcats_at":
35 | # first stage: attention transfer
36 | from lolcats.models.lolcats import LolcatsConfig
37 | liger_model_config = LolcatsConfig()
38 | trainer = DefaultTrainer
39 | elif config.model.name == "lolcats_ar":
40 | # second stage
41 | from lolcats.models.lolcats import LolcatsConfig
42 | liger_model_config = LolcatsConfig()
43 | else:
44 | raise NotImplementedError(config.model.name)
45 |
46 | liger_model_config.__dict__.update(model_config.__dict__)
47 | model_config = liger_model_config
48 | model = AutoModelForCausalLM.from_pretrained(
49 | config.model.pretrained_model_name_or_path,
50 | config=model_config,
51 | device_map="cuda"
52 | ).to(torch.bfloat16)
53 |
54 |
55 | print("Model config:")
56 | print(model_config)
57 | print("Model:")
58 | print(model)
59 |
60 | tokenizer = AutoTokenizer.from_pretrained(config.model.pretrained_model_name_or_path)
61 | tokenizer.pad_token_id = tokenizer.eos_token_id
62 | tokenizer.padding_side = "left" # Allow batched inference
63 |
64 | for name, param in model.named_parameters():
65 | param.requires_grad = False
66 | if "train_qk" in config.train and config.train.train_qk:
67 | if "self_attn.q_proj" in name:
68 | param.requires_grad = True
69 | elif "self_attn.k_proj" in name:
70 | param.requires_grad = True
71 | if "train_v" in config.train and config.train.train_v and "self_attn.v_proj" in name:
72 | param.requires_grad = True
73 | if "train_o" in config.train and config.train.train_o and "self_attn.o_proj" in name:
74 | param.requires_grad = True
75 |
76 | # LoRA finetune
77 | target_modules = []
78 | if "train_qk" in config.train and config.train.train_qk and config.train.train_qk_lora:
79 | target_modules.append("self_attn.q_proj")
80 | target_modules.append("self_attn.k_proj")
81 | if "train_v" in config.train and config.train.train_v and config.train.train_v_lora:
82 | target_modules.append("self_attn.v_proj")
83 | if "train_o" in config.train and config.train.train_o and config.train.train_o_lora:
84 | target_modules.append("self_attn.o_proj")
85 | # lolcats attention transfer
86 | if config.model.name == "lolcats_at":
87 | for name, param in model.named_parameters():
88 | if "feature_map" in name:
89 | param.requires_grad = True
90 | else:
91 | param.requires_grad = False
92 |
93 |
94 | if len(target_modules) != 0:
95 | lora_config = LoraConfig(task_type=TaskType.CAUSAL_LM, r=8, target_modules=target_modules)
96 | model = get_peft_model(model, peft_config=lora_config)
97 |
98 | # print trainable params count
99 | trainable_params = count_model_params(model, requires_grad=True)
100 | total_params = count_model_params(model, requires_grad=False)
101 | print(f"Model trainable params: {trainable_params}")
102 | print(f"Model total params: {total_params}")
103 | print(f"trainable%: {trainable_params / total_params}")
104 |
105 | gradient_accumulation_steps = config.data.batch_size // config.data.micro_batch_size
106 |
107 | print("Preparing data...")
108 |
109 | dataloaders = load_data(config)
110 | train_loader = dataloaders["train"]
111 | eval_loader = dataloaders["validation"]
112 |
113 | print("Building trainer...")
114 |
115 | training_args = TrainingArguments(
116 | per_device_train_batch_size=config.data.micro_batch_size,
117 | gradient_accumulation_steps=gradient_accumulation_steps,
118 | warmup_steps=0,
119 | num_train_epochs=config.train.epochs,
120 | learning_rate=config.train.lr,
121 | bf16=True,
122 | max_grad_norm=config.train.max_grad_norm,
123 | logging_steps=1,
124 | optim=config.train.optim,
125 | evaluation_strategy="steps" if config.data.val_set_size > 0 else "no",
126 | save_strategy="steps",
127 | eval_steps=200 if config.data.val_set_size > 0 else None,
128 | save_steps=1000,
129 | logging_dir=config.train.output_dir,
130 | output_dir=config.train.output_dir,
131 | save_total_limit=3,
132 | load_best_model_at_end=True if config.data.val_set_size > 0 else False,
133 | # default trainer args
134 | greater_is_better = False,
135 | metric_for_best_model = 'eval/loss',
136 | # wandb
137 | report_to="none" # wandb off "wandb"
138 | )
139 |
140 | trainer = trainer(
141 | model=model,
142 | train_loader=train_loader,
143 | eval_loader=eval_loader,
144 | args=training_args,
145 | optimizers=get_optimizer_and_scheduler(model, config),
146 | tokenizer=tokenizer,
147 | config=config
148 | )
149 |
150 | print("Train start")
151 | best_model = trainer.train()
152 | save_path = trainer.save_path + '/best'
153 | best_model.save_pretrained(save_path)
154 | tokenizer.save_pretrained(save_path)
155 | print(f'\n-> Saved best model checkpoint to: {save_path}!')
156 |
157 | print("Train over")
158 |
--------------------------------------------------------------------------------
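The parameter-selection logic above freezes everything, optionally re-enables specific projections, and only wraps the model with PEFT/LoRA when at least one `target_modules` entry is configured. A hedged sketch of that freeze-then-LoRA pattern on a tiny public test checkpoint (`hf-internal-testing/tiny-random-LlamaForCausalLM` is just a small stand-in, not a checkpoint used by this repo):

```python
from peft import LoraConfig, TaskType, get_peft_model
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-LlamaForCausalLM")
for param in model.parameters():          # freeze the full backbone first
    param.requires_grad = False

lora_config = LoraConfig(task_type=TaskType.CAUSAL_LM, r=8,
                         target_modules=["self_attn.q_proj", "self_attn.k_proj"])
model = get_peft_model(model, peft_config=lora_config)
model.print_trainable_parameters()        # only the LoRA A/B matrices remain trainable
```
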
/training/trainer.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import os
3 | import pandas as pd
4 |
5 | from tqdm import tqdm
6 |
7 | import torch
8 | import torch.nn as nn
9 | import torch.nn.functional as F
10 | from torch.utils.data import DataLoader
11 |
23 | from transformers import AutoModel, AutoModelForCausalLM
24 | from transformers import Trainer, TrainingArguments
25 |
26 | from peft import PeftModel
27 |
28 | class DefaultTrainer():
29 | # code is modified from: https://github.com/HazyResearch/lolcats/blob/main/src/trainer/default_lm.py
30 | def __init__(self, model, train_loader, eval_loader, args, optimizers, tokenizer, config):
31 | super().__init__()
32 | self.model = model
33 | self.args = args
34 | self.tokenizer = tokenizer
35 | self.config = config
36 | self.type = 'default'
37 |
38 | self.step = 0 # Total steps taken
39 | self.grad_step = 0 # Total gradient updates
40 | self.compute_loss_backprop = False # Whether we backprop in self.compute_loss
41 |
42 | self.optimizer, self.scheduler = optimizers
43 | self.scheduler_step_after_epoch = True
44 | # Dataloaders
45 | self.train_loader = train_loader
46 | self.eval_loader = eval_loader
47 |
48 | self.device = model.device
49 | wandb = None
50 | self.wandb = wandb
51 |
52 | # args
53 | self.metric_for_best_model = self.args.metric_for_best_model
54 | self.num_train_epochs = self.args.num_train_epochs
55 | self.gradient_accumulation_steps = self.args.gradient_accumulation_steps
56 | self.evaluation_strategy = self.args.evaluation_strategy
57 | self.greater_is_better = self.args.greater_is_better
58 | self.is_better = (lambda x, y: x > y if self.args.greater_is_better else x < y)
59 | self.load_best_model_at_end = self.args.load_best_model_at_end
60 | self.logging_steps = self.args.logging_steps
61 | self.max_steps = self.args.max_steps
62 | self.eval_steps = self.args.eval_steps
63 |
64 | max_eval_batches = -1
65 | print_samples = False
66 | initial_eval = True
67 | self.max_eval_batches = max_eval_batches
68 | self.print_samples = print_samples
69 | self.initial_eval = initial_eval
70 | self.save_total_limit = self.args.save_total_limit
71 | self.save_steps = self.args.save_steps # num_save_ckpt_steps
72 |
73 | # Saving metrics
74 | self.train_metrics = {'train/loss': None,
75 | 'train/epoch': None,
76 | 'train/step': None}
77 | self.eval_metrics = {self.metric_for_best_model: None}
78 | self.eval_metrics_by_step = {'eval_step': []} # save all eval metrics
79 | self.criterion = nn.CrossEntropyLoss(reduction='mean')
80 |
81 | save_results = True
82 | save_checkpoints = True
83 |
84 | self.save_results = save_results
85 | self.results_path = None
86 | self.best_val_metric = 0 if self.greater_is_better else 1e10
87 | self.best_val_metric_epoch = 0
88 | self.best_val_metric_step = 0
89 | if save_checkpoints: # Also initializes best_val_metrics
90 | self.init_checkpointing(config=config)
91 |
92 | def train(self) -> nn.Module:
93 | """
94 | Entire training run
95 | """
96 | model = self.model
97 | pbar = tqdm(range(self.num_train_epochs), leave=False, colour='white', desc='Training')
98 | for ix, epoch in enumerate(pbar):
99 | model, early_stopping = self.train_step(model, epoch)
100 | if self.evaluation_strategy == 'epoch':
101 | _eval_metrics = self.eval_step(model, step=self.grad_step)
102 | print(f'Epoch {ix} metrics:', _eval_metrics)
103 | if early_stopping:
104 | break
105 |
106 | if self.load_best_model_at_end: # Return best checkpoint
107 | try:
108 |                 model = model.from_pretrained(self.best_val_checkpoint_path)
109 |                 print(f'-> Loaded best checkpoint from {self.best_val_checkpoint_path}')
110 |             except OSError as e:
111 | print(e)
112 | print('-> Returning most recent model instead')
113 | return model
114 |
115 |     def train_step(self, model, epoch) -> tuple[nn.Module, bool]:
116 | if self.gradient_accumulation_steps is None:
117 | accum_iter = 1
118 | else:
119 | accum_iter = self.gradient_accumulation_steps
120 |
121 | model.train()
122 | model.zero_grad()
123 | pbar = tqdm(self.train_loader, leave=False, colour='blue', desc=f'-> Training (epoch {epoch} / {self.args.num_train_epochs})')
124 | total_loss = 0
125 | eval_for_step = False
126 |
127 | # Initial eval
128 | if self.initial_eval:
129 | print('')
130 | print('-> Initial eval')
131 | self.compute_eval_metrics(model, step=self.grad_step)
132 |
133 | # model.to(self.device)
134 | for ix, data in enumerate(pbar):
135 | loss, train_metrics = self.compute_loss(model, data, return_outputs=True)
136 | loss /= accum_iter
137 | if not self.compute_loss_backprop:
138 | # loss.backward() did not occur in compute_loss
139 | try:
140 | with torch.autograd.set_detect_anomaly(True):
141 | loss.backward()
142 | except Exception as e:
143 | breakpoint()
144 | if (self.step + 1) % accum_iter == 0: # and self.step != 0:
145 | self.optimizer.step()
146 | if not self.scheduler_step_after_epoch and self.scheduler is not None:
147 | self.scheduler.step()
148 | self.optimizer.zero_grad()
149 | self.grad_step += 1
150 | if not self.compute_loss_backprop:
151 | loss = loss.detach().cpu().item()
152 |
153 | self.step += 1
154 | if not isinstance(loss, float):
155 | total_loss += loss.item()
156 | else:
157 | total_loss += loss
158 | desc = f"Training epoch {epoch} | loss: {total_loss / (ix + 1):.3f} | lr: {self.optimizer.param_groups[0]['lr']:.5f}"
159 | desc += f' | gradient step: {self.grad_step}'
160 | for k, v in train_metrics.items():
161 | desc += f' | {k}: {v:.3f}'
162 | pbar.set_description(desc)
163 |
164 | # Logging
165 |             if self.grad_step % self.logging_steps == 0:
166 | self.train_metrics['train/loss'] = loss.item() if not isinstance(loss, float) else loss
167 | self.train_metrics['train/epoch'] = epoch
168 | self.train_metrics['train/step'] = self.grad_step
169 | self.train_metrics['train/lr'] = self.optimizer.param_groups[0]['lr']
170 | for k, v in train_metrics.items():
171 | self.train_metrics[f'train/{k}'] = v
172 |
173 | if self.wandb is not None:
174 | self.wandb.log(self.train_metrics, step=self.grad_step)
175 |
176 | if self.evaluation_strategy == 'steps':
177 | if (self.grad_step % self.eval_steps == 0 and self.grad_step > 0 and not eval_for_step):
178 | _eval_metrics = self.eval_step(model, step=self.grad_step)
179 | print(f'Grad Step {self.grad_step} eval metrics:', _eval_metrics)
180 | eval_for_step = True
181 | model.train() # Need to set back to train mode
182 | elif self.grad_step == 0 and self.save_steps < 1000 and not eval_for_step: # hack for micros
183 | _eval_metrics = self.eval_step(model, step=self.grad_step)
184 | print(f'Grad Step {self.grad_step} eval metrics:', _eval_metrics)
185 | eval_for_step = True
186 | model.train() # Need to set back to train mode
187 |
188 | elif self.grad_step % self.eval_steps == 0 and self.grad_step > 0 and eval_for_step:
189 | pass
190 | else:
191 | if self.grad_step > 0:
192 | eval_for_step = False
193 | if self.grad_step == self.max_steps:
194 | early_stopping = True
195 | return model, early_stopping
196 |
197 | early_stopping = False
198 | return model, early_stopping
199 |
200 |
201 | def eval_step(self, model: nn.Module, step: int = None, **kwargs: any) -> dict[any]:
202 | """
203 | Evaluation loop over one epoch
204 | """
205 | with torch.no_grad():
206 | self.eval_metrics = self.compute_eval_metrics(model, step=step, **kwargs)
207 | val_metric = self.eval_metrics[self.metric_for_best_model]
208 |
209 | # Save results
210 | if self.wandb is not None: # log to WandB
211 | self.wandb.log(self.eval_metrics, step=self.grad_step)
212 |
213 | if self.results_path is not None: # log to local file
214 | self.eval_metrics_by_step['eval_step'].append(step)
215 | for k, v in self.eval_metrics.items():
216 | if k not in self.eval_metrics_by_step:
217 | self.eval_metrics_by_step[k] = [v]
218 | else:
219 | self.eval_metrics_by_step[k].append(v)
220 |                 # Inefficient, but keeps a running CSV of experiment results
221 | pd.DataFrame(self.eval_metrics_by_step).to_csv(self.results_path)
222 |
223 | # Save best metric and checkpoint
224 | if self.grad_step % self.eval_steps == 0 and step > 0:
225 | if self.is_better(val_metric, self.best_val_metric):
226 | self.best_val_metric = val_metric
227 | self.best_val_metric_step = self.grad_step
228 |
229 | save_path = self.save_path + '/iter' + '_' + str(step)
230 | self.best_val_checkpoint_path = save_path
231 | model.save_pretrained(save_path)
232 | self.tokenizer.save_pretrained(save_path)
233 | print(f'\n-> Saved best model checkpoint to: {save_path}!')
234 |
235 | if self.grad_step % self.save_steps == 0 and step > 0:
236 |
237 | save_path = self.save_path + '/' + self.type + '_' + str(step)
238 | self.best_val_checkpoint_path = save_path
239 | model.save_pretrained(save_path)
240 | self.tokenizer.save_pretrained(save_path)
241 | print(f'\n-> Saved model checkpoint to: {save_path}!')
242 |
243 | if self.scheduler_step_after_epoch and self.scheduler is not None:
244 | self.scheduler.step(val_metric)
245 | return self.eval_metrics
246 |
247 | def compute_eval_metrics(self,
248 | model: nn.Module, step: int,
249 | max_batches: int = None,
250 | dataloader: DataLoader = None,
251 | **kwargs: any) -> dict[any]:
252 | """
253 | One evaluation loop over a validation dataset
254 | """
255 | max_batches = (self.max_eval_batches if max_batches is None else max_batches)
256 | dataloader = self.eval_loader if dataloader is None else dataloader
257 | pbar = tqdm(dataloader, leave=False, colour='green', desc=f'Evaluating at step {step}')
258 |
259 | model.eval()
260 | step_loss = 0
261 | step_eval_metrics = {}
262 | with torch.no_grad():
263 | for ix, data in enumerate(pbar):
264 | loss, eval_metrics = self.compute_loss(model, data, return_outputs=True)
265 | if not self.compute_loss_backprop:
266 | loss = loss.item() # otherwise already float
267 | if ix == 0:
268 | step_eval_metrics[self.metric_for_best_model] = [loss]
269 | for k, v in eval_metrics.items():
270 | step_eval_metrics[f'eval/{k}'] = [v]
271 | else:
272 | step_eval_metrics[self.metric_for_best_model].append(loss)
273 | for k, v in eval_metrics.items():
274 | step_eval_metrics[f'eval/{k}'].append(v)
275 |
276 | step_loss += loss
277 | desc = f"Evaluating at step {step} | loss: {step_loss / (ix + 1):.3f}"
278 | if self.optimizer is not None:
279 | desc += f" | lr: {self.optimizer.param_groups[0]['lr']:.5f}"
280 | pbar.set_description(desc)
281 | if ix == max_batches:
282 | break
283 |
284 | # Average over batches
285 | for k, v in step_eval_metrics.items():
286 | step_eval_metrics[k] = sum(v) / len(v)
287 | print(f'Eval step {step}:', step_eval_metrics)
288 | del loss
289 | torch.cuda.empty_cache()
290 | return step_eval_metrics
291 |
292 | def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
293 | inputs = {k: v.to(model.device) for k, v in inputs.items() if k != 'labels'}
294 | outputs = model(**inputs, output_attentions=True)
295 |
296 | outputs = outputs.attentions # tuple [num_decoder_layers, 2, B, H, L, L]
297 | loss_mse = 0
298 | self.mse_factor = 1000
299 | self.criterion_mse = nn.MSELoss(reduction='mean')
300 | n_layers = 0 # Number of layers to distill
301 |
302 | for layer_idx, attns in enumerate(outputs):
303 | if attns is not None:
304 | loss_mse += self.criterion_mse(attns[0], attns[1])
305 | n_layers += 1
306 |
307 | if n_layers > 0:
308 | loss_mse = loss_mse / n_layers * self.mse_factor
309 | loss = loss_mse
310 | outputs = {'loss_mse': loss_mse.item() if self.mse_factor > 0 else 0,
311 | 'mse_factor': self.mse_factor}
312 |
313 | return (loss, outputs) if return_outputs else loss
314 |
315 | def init_checkpointing(self, config) -> None:
316 | self.save_path = config.train.output_dir
317 | self.best_val_checkpoint_path = config.train.output_dir
318 |
319 | # Best metric setup
320 | self.best_val_metric = 0 if self.greater_is_better else 1e10
321 | self.best_val_metric_epoch = 0
322 | self.best_val_metric_step = 0
323 | self.best_train_metric = 0 if self.greater_is_better else 1e10
324 | self.best_train_metric_epoch = 0
325 | self.best_train_metric_step = 0
326 |         # Normalize the best-model metric name into the 'eval/...' namespace
327 | if self.metric_for_best_model is not None:
328 | if 'eval' not in self.metric_for_best_model:
329 | self.metric_for_best_model = f'eval/{self.metric_for_best_model}'
330 |
331 | class FinetuneTrainer(DefaultTrainer):
332 | def __init__(self, model, train_loader, eval_loader, args, optimizers, tokenizer, config):
333 | super().__init__(model, train_loader, eval_loader, args, optimizers, tokenizer, config)
334 | self.type = 'finetune'
335 |
336 | def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
337 | input_keys = {'input_ids', 'attention_mask'}
338 | data = {k: v.to(model.device) for k, v in inputs.items() if k in input_keys}
339 | outputs = model(**data, output_attentions=False)
340 | outputs = outputs.get('logits')[..., :-1, :].contiguous()
341 | targets = inputs.get('labels')[..., 1:].contiguous()
342 | # Flatten and compute cross-entropy loss
343 | outputs = outputs.view(-1, outputs.shape[-1])
344 | targets = targets.view(-1).to(outputs.device)
345 | loss = self.criterion(outputs, targets)
346 |
347 | targets = targets.cpu()
348 | outputs = outputs.cpu()
349 | outputs = {'ppl': torch.exp(loss).item(), 'seq_len': targets.shape[-1] + 1}
350 | return (loss, outputs) if return_outputs else loss
351 |
--------------------------------------------------------------------------------
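`FinetuneTrainer.compute_loss` is a standard shifted next-token cross-entropy: logits are cut at the last position, labels at the first, and `-100` prompt positions are ignored by the criterion. A worked toy example of the same computation:

```python
import torch
import torch.nn as nn

logits = torch.randn(1, 5, 11)                  # (batch, seq_len, vocab_size)
labels = torch.tensor([[-100, -100, 4, 7, 2]])  # prompt tokens masked with -100

shift_logits = logits[..., :-1, :].contiguous().view(-1, 11)
shift_labels = labels[..., 1:].contiguous().view(-1)
loss = nn.CrossEntropyLoss(reduction='mean')(shift_logits, shift_labels)
print(loss.item(), torch.exp(loss).item())      # the loss and the logged 'ppl'
```
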
/training/utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.optim
4 |
5 | def get_optimizer_and_scheduler(model, config):
6 | params = [p for p in model.parameters() if p.requires_grad]
7 |     optimizer = torch.optim.AdamW(params, lr=config.train.lr, fused=True)
8 | scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
9 | optimizer=optimizer,
10 | mode='min',
11 | factor=0.1,
12 | patience=10,
13 | min_lr=0.00001
14 | )
15 | return optimizer, scheduler
16 |
17 | def count_model_params(model, requires_grad: bool = True):
18 |     # code from lolcats
19 |     """
20 |     Return the total number of model parameters (only trainable ones if `requires_grad=True`)
21 | """
22 | if requires_grad:
23 |         model_parameters = [p for p in model.parameters() if p.requires_grad]
24 |     else:
25 |         model_parameters = list(model.parameters())
26 |     try:
27 |         return sum([np.prod(p.size()) for p in model_parameters]).item()
28 |     except AttributeError:  # sum([]) is a plain int and has no .item()
29 |         return sum([np.prod(p.size()) for p in model_parameters])
--------------------------------------------------------------------------------
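For reference, a minimal sketch of how `trainer.py` consumes this pair: `AdamW` is stepped after every gradient-accumulation cycle, while `ReduceLROnPlateau` is stepped once per evaluation with the validation metric (the toy model and the metric value below are placeholders):

```python
import torch

model = torch.nn.Linear(8, 8)
optimizer = torch.optim.AdamW(
    [p for p in model.parameters() if p.requires_grad], lr=1e-4)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode='min', factor=0.1, patience=10, min_lr=1e-5)

loss = model(torch.randn(2, 8)).pow(2).mean()
loss.backward()
optimizer.step()                # per gradient step (train_step)
optimizer.zero_grad()
scheduler.step(0.42)            # per eval, called with the eval/loss metric (eval_step)
```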