├── .gitignore
├── LICENSE
├── README.md
├── SVDLLM.py
├── component
│   ├── svd_llama.py
│   ├── svd_llama_kvcache.py
│   ├── svd_mistral.py
│   └── svd_opt.py
├── compress_llama.sh
├── evaluater.py
├── figures
│   ├── framework_v1.jpg
│   ├── framework_v2.jpg
│   └── logo.png
├── gptq
│   ├── gptq.py
│   └── quant.py
├── quant_llama.py
├── requirements.txt
├── svdllm_gptq.sh
└── utils
    ├── LoRA.py
    ├── Prompter.py
    ├── data_utils.py
    ├── model_utils.py
    └── peft
        ├── __init__.py
        ├── import_utils.py
        ├── mapping.py
        ├── peft_model.py
        ├── tuners
        │   ├── __init__.py
        │   ├── adalora.py
        │   ├── lora.py
        │   ├── p_tuning.py
        │   ├── prefix_tuning.py
        │   └── prompt_tuning.py
        └── utils
            ├── __init__.py
            ├── config.py
            ├── other.py
            └── save_and_load.py
/.gitignore:
--------------------------------------------------------------------------------
1 | .vscode/*
2 | cache/*
3 | **/__pycache__
4 | utils/c4-*
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # SVD-LLM: Singular Value Decomposition for Large Language Model Compression
2 |
32 | ## ✨Roadmap
33 | We are working on the following tasks. Please stay tuned!
34 |
35 | - [X] Release the code of SVD-LLM.
36 | - [ ] Release the code of SVD-LLM V2.
37 | - [ ] Upgrade the transformers package to the latest version.
38 | - [ ] Update the efficiency evaluation code.
39 | - [ ] Optimize the compression of GQA-based LLMs.
40 |
41 |
42 | ## Introduction
43 |
44 | > **[SVD-LLM: Truncation-aware Singular Value Decomposition for Large Language Model Compression](https://openreview.net/forum?id=LNYIUouhdt)**
45 | >
46 | > *Xin Wang¹, Yu Zheng², Zhongwei Wan¹, Mi Zhang¹*
47 | > *¹The Ohio State University, ²Michigan State University*
48 | >
49 | > International Conference on Learning Representations (ICLR) 2025
50 |
51 |
52 | > **[SVD-LLM V2: Optimizing Singular Value Truncation for Large Language Model Compression](https://arxiv.org/abs/2503.12340)**
53 | >
54 | > *Xin Wang, Samiul Alam, Zhongwei Wan, Hui Shen, Mi Zhang*
55 | > *The Ohio State University*
56 | >
57 | > Annual Conference of the Nations of the Americas Chapter of the Association for Computational Linguistics (NAACL) 2025
58 |
59 |
60 | ## Quick Start
61 |
62 | ### Installation
63 | Please keep the transformers package pinned to exactly version 4.35.2, since the SVD-compressed LLMs use a slightly modified model structure (defined in the `component/` folder).
64 | Create and activate a conda environment with Python 3.9 (newer Python versions break some dependencies).
65 | ```
66 | conda create -n compress python=3.9
67 | conda activate compress
68 | ```
69 | Clone the repository and navigate into the `SVD-LLM` folder.
70 | ```
71 | git clone https://github.com/AIoT-MLSys-Lab/SVD-LLM.git
72 | ```
73 | Install the dependencies listed in requirements.txt
74 | ```
75 | pip install -r requirements.txt
76 | ```
77 |
78 | ### Quick Example
79 | ```
80 | bash compress_llama.sh
81 | ```
82 | This script compresses the LLaMA-7B model at a 20% compression ratio and automatically runs the evaluation code, reporting both the perplexity and the efficiency of the compressed model.
83 |
84 |
85 | ## Step-by-Step Instructions of SVD-LLM
86 |
87 | ### 1. Truncation-Aware Data Whitening + SVD Compression
88 | At a low compression ratio (recommended ratio <= 0.3), we first run data whitening on the LLM and save the compressed weights along with the whitening information.
89 | ```
90 | python SVDLLM.py \
91 | --step 1 \
92 | --ratio COMPRESSION_RATIO \
93 | --model HUGGINGFACE_MODEL_REPO \
94 | --whitening_nsamples WHITENING_SAMPLE_NUMBER \
95 | --dataset WHITENING_DATASET \
96 | --seed SAMPLING_SEED \
97 | --model_seq_len MODEL_SEQ_LEN \
98 | --save_path WHITENING_INFO_SAVING_PATH
99 | ```
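
In essence, this step computes, for each linear layer, a whitening matrix S from the calibration activations (the paper obtains S via a Cholesky decomposition of the activation Gram matrix XXᵀ) and then truncates the SVD of the whitened weight W·S. Below is a minimal Python sketch of that idea for a single layer; the function name, the calibration matrix `X`, and the `eps` regularizer are illustrative only and are not part of this repository's API.
```
import torch

def whiten_and_truncate(W: torch.Tensor, X: torch.Tensor, ratio: float, eps: float = 1e-6):
    """Sketch of truncation-aware whitening + SVD for one weight W of shape (d_out, d_in).
    X holds calibration activations with shape (d_in, n_tokens)."""
    d_out, d_in = W.shape
    # Whitening matrix: Cholesky factor S with S @ S.T == X @ X.T (a small ridge keeps it well-posed).
    S = torch.linalg.cholesky(X @ X.T + eps * torch.eye(d_in, dtype=W.dtype, device=W.device))
    # Truncating the SVD of W @ S minimizes the error measured on the whitened calibration inputs.
    U, sigma, Vh = torch.linalg.svd(W @ S, full_matrices=False)
    # Rank budget: r * (d_out + d_in) stored parameters ~= ratio * d_out * d_in original parameters.
    r = int(ratio * d_out * d_in / (d_out + d_in))
    W_u = U[:, :r] * sigma[:r]              # shape (d_out, r)
    W_v = Vh[:r, :] @ torch.linalg.inv(S)   # shape (r, d_in)
    return W_u, W_v                         # x -> W_u @ (W_v @ x) approximates W @ x
```
Because ‖(W − Ŵ)X‖F equals ‖(W − Ŵ)S‖F, truncating the SVD of W·S directly minimizes the compression error measured on the calibration data, which is what makes the whitening truncation-aware.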
100 |
101 |
102 |
103 |
104 | ### 2. Parameter Update with Sequential Low-rank Approximation
105 | We first update the compressed weight matrices U and then V with LoRA fine-tuning (a simplified sketch of this sequential update follows the command below).
106 | ```
107 | python utils/LoRA.py \
108 | --prune_model COMPRESSED_MODEL_PATH \
109 | --data_path yahma/alpaca-cleaned \
110 | --output_dir LORA_OUTPUT_PATH \
111 | --lora_r 8 \
112 | --num_epochs 2 \
113 | --learning_rate 1e-4 \
114 | --batch_size 64
115 | ```
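
Conceptually, the sequential update first trains the U factors and then the V factors of the compressed layers. The sketch below is a simplification of what `utils/LoRA.py` does: it uses plain parameter freezing instead of LoRA adapters, assumes the `*_u_proj` / `*_v_proj` layer names from `component/svd_llama.py`, and its helper names are hypothetical.
```
def train_factor_family(model, suffix, train_one_epoch):
    """Unfreeze only the parameters whose names end with `suffix`, then fine-tune them."""
    for name, param in model.named_parameters():
        param.requires_grad = name.endswith(suffix)
    train_one_epoch(model)  # any standard causal-LM fine-tuning loop, e.g. on yahma/alpaca-cleaned

# Stage 1: update the U factors; Stage 2: update the V factors.
# train_factor_family(model, "_u_proj.weight", train_one_epoch)
# train_factor_family(model, "_v_proj.weight", train_one_epoch)
```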
116 |
117 | ### 3. SVD-LLM + GPTQ
118 | SVD-LLM can also be integrated with quantization methods to achieve better compression. Here is an example of integrating SVD-LLM (20% compression ratio) with 4-bit GPTQ to compress LLaMA-7B:
119 | ```
120 | bash svdllm_gptq.sh
121 | ```
122 |
123 | ### 4. Evaluation
124 | - Perplexity Evaluation:
125 | ```
126 | python SVDLLM.py \
127 | --step 4 \
128 | --model_path COMPRESSED_MODEL_SAVING_PATH
129 | ```
130 | We use the same c4 dataset as in [SparseGPT](https://github.com/IST-DASLab/sparsegpt). Since the original download link is invalid, please download it directly from this [link](https://drive.google.com/drive/folders/123Id1MkZVsKySGy_sMO4RgiJKrtPcvUp?usp=sharing) and place the two JSON files under the `utils/` folder. A minimal sketch of the perplexity computation itself appears after the efficiency evaluation command below.
131 | - Efficiency Evaluation:
132 | ```
133 | python SVDLLM.py \
134 | --step 5 \
135 | --model_path COMPRESSED_MODEL_SAVING_PATH
136 | ```
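
For reference, the perplexity reported by `--step 4` is the exponential of the average next-token negative log-likelihood over the evaluation corpus. The following is a minimal, self-contained sketch of that computation; it is not the code path used by `SVDLLM.py`, and `model` / `tokenizer` are assumed to be already-loaded Hugging Face objects.
```
import math
import torch

@torch.no_grad()
def perplexity(model, tokenizer, text, seq_len=2048, device="cuda"):
    ids = tokenizer(text, return_tensors="pt").input_ids.to(device)
    nll_sum, token_count = 0.0, 0
    for start in range(0, ids.shape[1], seq_len):
        chunk = ids[:, start:start + seq_len]
        if chunk.shape[1] < 2:                        # need at least one next-token prediction
            break
        out = model(chunk, labels=chunk)              # HF causal LMs shift the labels internally
        n_predicted = chunk.shape[1] - 1
        nll_sum += out.loss.item() * n_predicted      # out.loss is the mean NLL over predicted tokens
        token_count += n_predicted
    return math.exp(nll_sum / token_count)
```
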
137 | ## Citation
138 | If you find this work useful, please cite
139 | ```
140 | @inproceedings{wang2025svdllm,
141 | title={{SVD}-{LLM}: Truncation-aware Singular Value Decomposition for Large Language Model Compression},
142 | author={Xin Wang and Yu Zheng and Zhongwei Wan and Mi Zhang},
143 | booktitle={International Conference on Learning Representations (ICLR)},
144 | year={2025},
145 | url={https://openreview.net/forum?id=LNYIUouhdt}
146 | }
147 | ```
148 |
--------------------------------------------------------------------------------
/component/svd_llama.py:
--------------------------------------------------------------------------------
1 | import math
2 | from typing import Optional, Tuple
3 |
4 | import torch
5 | import torch.utils.checkpoint
6 | from torch import nn
7 |
8 | from transformers.activations import ACT2FN
9 | from transformers.utils import logging
10 | from transformers import LlamaConfig
11 |
12 | logger = logging.get_logger(__name__)
13 |
14 | _CONFIG_FOR_DOC = "LlamaConfig"
15 |
16 | class LlamaRMSNorm(nn.Module):
17 | def __init__(self, hidden_size, eps=1e-6):
18 | """
19 | LlamaRMSNorm is equivalent to T5LayerNorm
20 | """
21 | super().__init__()
22 | self.weight = nn.Parameter(torch.ones(hidden_size))
23 | self.variance_epsilon = eps
24 |
25 | def forward(self, hidden_states):
26 | variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
27 | hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
28 |
29 | # convert into half-precision if necessary
30 | if self.weight.dtype in [torch.float16, torch.bfloat16]:
31 | hidden_states = hidden_states.to(self.weight.dtype)
32 |
33 | return self.weight * hidden_states
34 |
35 |
36 | class LlamaRotaryEmbedding(torch.nn.Module):
37 | def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
38 | super().__init__()
39 | inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim))
40 | self.register_buffer("inv_freq", inv_freq)
41 |
42 | # Build here to make `torch.jit.trace` work.
43 | self.max_seq_len_cached = max_position_embeddings
44 | t = torch.arange(self.max_seq_len_cached, device=self.inv_freq.device, dtype=self.inv_freq.dtype)
45 | freqs = torch.einsum("i,j->ij", t, self.inv_freq)
46 | # Different from paper, but it uses a different permutation in order to obtain the same calculation
47 | emb = torch.cat((freqs, freqs), dim=-1)
48 | self.register_buffer("cos_cached", emb.cos()[None, None, :, :], persistent=False)
49 | self.register_buffer("sin_cached", emb.sin()[None, None, :, :], persistent=False)
50 |
51 | def forward(self, x, seq_len=None):
52 | # x: [bs, num_attention_heads, seq_len, head_size]
53 | # This `if` block is unlikely to be run after we build sin/cos in `__init__`. Keep the logic here just in case.
54 | if seq_len > self.max_seq_len_cached:
55 | self.max_seq_len_cached = seq_len
56 | t = torch.arange(self.max_seq_len_cached, device=x.device, dtype=self.inv_freq.dtype)
57 | freqs = torch.einsum("i,j->ij", t, self.inv_freq)
58 | # Different from paper, but it uses a different permutation in order to obtain the same calculation
59 | emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
60 | self.register_buffer("cos_cached", emb.cos()[None, None, :, :], persistent=False)
61 | self.register_buffer("sin_cached", emb.sin()[None, None, :, :], persistent=False)
62 | return (
63 | self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
64 | self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
65 | )
66 |
67 |
68 | def rotate_half(x):
69 | """Rotates half the hidden dims of the input."""
70 | x1 = x[..., : x.shape[-1] // 2]
71 | x2 = x[..., x.shape[-1] // 2 :]
72 | return torch.cat((-x2, x1), dim=-1)
73 |
74 |
75 | def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
76 | gather_indices = position_ids[:, None, :, None] # [bs, 1, seq_len, 1]
77 | gather_indices = gather_indices.repeat(1, cos.shape[1], 1, cos.shape[3])
78 | cos = torch.gather(cos.repeat(gather_indices.shape[0], 1, 1, 1), 2, gather_indices)
79 | sin = torch.gather(sin.repeat(gather_indices.shape[0], 1, 1, 1), 2, gather_indices)
80 |
81 | q_embed = (q * cos) + (rotate_half(q) * sin)
82 | k_embed = (k * cos) + (rotate_half(k) * sin)
83 | return q_embed, k_embed
84 |
85 |
86 | class SVD_LlamaMLP(nn.Module):
87 | def __init__(
88 | self,
89 | hidden_size: int,
90 | intermediate_size: int,
91 | hidden_act: str,
92 | ratio=1
93 | ):
94 | super().__init__()
95 | self.ratio = ratio
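        # Rank budget: the two factors (low_rank x hidden_size and intermediate_size x low_rank)
        # together hold roughly ratio * hidden_size * intermediate_size parameters.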
96 | low_rank = int(intermediate_size * hidden_size * self.ratio / (intermediate_size + hidden_size))
97 | self.gate_u_proj = nn.Linear(low_rank, intermediate_size, bias=False)
98 | self.gate_v_proj = nn.Linear(hidden_size, low_rank, bias=False)
99 |
100 | self.down_u_proj = nn.Linear(low_rank, hidden_size, bias=False)
101 | self.down_v_proj = nn.Linear(intermediate_size, low_rank, bias=False)
102 |
103 | self.up_u_proj = nn.Linear(low_rank, intermediate_size, bias=False)
104 | self.up_v_proj = nn.Linear(hidden_size, low_rank, bias=False)
105 | self.act_fn = ACT2FN[hidden_act]
106 |
107 | def forward(self, x):
108 | up = self.up_u_proj(self.up_v_proj(x))
109 | gate = self.gate_u_proj(self.gate_v_proj(x))
110 | return self.down_u_proj(self.down_v_proj(self.act_fn(gate) * up))
111 |
112 |
113 | class SVD_LlamaAttention(nn.Module):
114 | """Multi-headed attention from 'Attention Is All You Need' paper"""
115 |
116 | def __init__(self, config: LlamaConfig, ratio=1):
117 | super().__init__()
118 | self.config = config
119 | self.hidden_size = config.hidden_size
120 | self.num_heads = config.num_attention_heads
121 | self.head_dim = self.hidden_size // self.num_heads
122 | self.max_position_embeddings = config.max_position_embeddings
123 | self.ratio = ratio # 1 means no truncate, just keep normal attn
124 |
125 | if (self.head_dim * self.num_heads) != self.hidden_size:
126 | raise ValueError(
127 | f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
128 | f" and `num_heads`: {self.num_heads})."
129 | )
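        # Rank budget: each square (hidden_size x hidden_size) projection is split into two
        # low_rank x hidden_size factors, so low_rank = hidden_size * ratio / 2 keeps a ratio
        # fraction of the original parameters.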
130 | low_rank = int(self.hidden_size * self.ratio/2)
131 | self.q_u_proj = nn.Linear(low_rank, self.num_heads * self.head_dim, bias=False)
132 | self.q_v_proj = nn.Linear(self.hidden_size, low_rank, bias=False)
133 |
134 | self.k_u_proj = nn.Linear(low_rank, self.num_heads * self.head_dim, bias=False)
135 | self.k_v_proj = nn.Linear(self.hidden_size, low_rank, bias=False)
136 |
137 | self.v_u_proj = nn.Linear(low_rank, self.num_heads * self.head_dim, bias=False)
138 | self.v_v_proj = nn.Linear(self.hidden_size, low_rank, bias=False)
139 |
140 | self.o_u_proj = nn.Linear(low_rank, self.hidden_size, bias=False)
141 | self.o_v_proj = nn.Linear(self.num_heads * self.head_dim, low_rank, bias=False)
142 |
143 | self.rotary_emb = LlamaRotaryEmbedding(self.head_dim, max_position_embeddings=self.max_position_embeddings)
144 |
145 | def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
146 | return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
147 |
148 | def forward(
149 | self,
150 | hidden_states: torch.Tensor,
151 | attention_mask: Optional[torch.Tensor] = None,
152 | position_ids: Optional[torch.LongTensor] = None,
153 | past_key_value: Optional[Tuple[torch.Tensor]] = None,
154 | output_attentions: bool = False,
155 | use_cache: bool = False,
156 | ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
157 | bsz, q_len, _ = hidden_states.size()
158 |
159 | query_states = self.q_u_proj(self.q_v_proj(hidden_states)).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
160 |
161 | key_states = self.k_u_proj(self.k_v_proj(hidden_states)).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
162 |
163 | value_states = self.v_u_proj(self.v_v_proj(hidden_states)).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
164 |
165 | kv_seq_len = key_states.shape[-2]
166 | if past_key_value is not None:
167 | kv_seq_len += past_key_value[0].shape[-2]
168 | cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
169 |
170 | query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
171 | # [bsz, nh, t, hd]
172 |
173 | if past_key_value is not None:
174 | # reuse k, v, self_attention
175 | key_states = torch.cat([past_key_value[0], key_states], dim=2)
176 | value_states = torch.cat([past_key_value[1], value_states], dim=2)
177 |
178 | past_key_value = (key_states, value_states) if use_cache else None
179 |
180 | attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
181 |
182 | if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
183 | raise ValueError(
184 |                 f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
185 | f" {attn_weights.size()}"
186 | )
187 |
188 | if attention_mask is not None:
189 | if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
190 | raise ValueError(
191 | f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
192 | )
193 | attn_weights = attn_weights + attention_mask
194 | attn_weights = torch.max(attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min, device=attn_weights.device))
195 |
196 | # upcast attention to fp32
197 | attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
198 | attn_output = torch.matmul(attn_weights, value_states)
199 |
200 | if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
201 | raise ValueError(
202 | f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
203 | f" {attn_output.size()}"
204 | )
205 |
206 | attn_output = attn_output.transpose(1, 2)
207 | attn_output = attn_output.reshape(bsz, q_len, -1)
208 |
209 | attn_output = self.o_u_proj(self.o_v_proj(attn_output))
210 |
211 | if not output_attentions:
212 | attn_weights = None
213 |
214 | return attn_output, attn_weights, past_key_value
215 |
--------------------------------------------------------------------------------
/component/svd_llama_kvcache.py:
--------------------------------------------------------------------------------
1 | import math
2 | from typing import Optional, Tuple
3 |
4 | import torch
5 | import torch.utils.checkpoint
6 | from torch import nn
7 |
8 | from transformers.activations import ACT2FN
9 | from transformers.utils import logging
10 | from transformers import LlamaConfig
11 |
12 | logger = logging.get_logger(__name__)
13 |
14 | _CONFIG_FOR_DOC = "LlamaConfig"
15 |
16 | class LlamaRMSNorm(nn.Module):
17 | def __init__(self, hidden_size, eps=1e-6):
18 | """
19 | LlamaRMSNorm is equivalent to T5LayerNorm
20 | """
21 | super().__init__()
22 | self.weight = nn.Parameter(torch.ones(hidden_size))
23 | self.variance_epsilon = eps
24 |
25 | def forward(self, hidden_states):
26 | variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
27 | hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
28 |
29 | # convert into half-precision if necessary
30 | if self.weight.dtype in [torch.float16, torch.bfloat16]:
31 | hidden_states = hidden_states.to(self.weight.dtype)
32 |
33 | return self.weight * hidden_states
34 |
35 |
36 | class LlamaRotaryEmbedding(torch.nn.Module):
37 | def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
38 | super().__init__()
39 | inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim))
40 | self.register_buffer("inv_freq", inv_freq)
41 |
42 | # Build here to make `torch.jit.trace` work.
43 | self.max_seq_len_cached = max_position_embeddings
44 | t = torch.arange(self.max_seq_len_cached, device=self.inv_freq.device, dtype=self.inv_freq.dtype)
45 | freqs = torch.einsum("i,j->ij", t, self.inv_freq)
46 | # Different from paper, but it uses a different permutation in order to obtain the same calculation
47 | emb = torch.cat((freqs, freqs), dim=-1)
48 | self.register_buffer("cos_cached", emb.cos()[None, None, :, :], persistent=False)
49 | self.register_buffer("sin_cached", emb.sin()[None, None, :, :], persistent=False)
50 |
51 | def forward(self, x, seq_len=None):
52 | # x: [bs, num_attention_heads, seq_len, head_size]
53 | # This `if` block is unlikely to be run after we build sin/cos in `__init__`. Keep the logic here just in case.
54 | if seq_len > self.max_seq_len_cached:
55 | self.max_seq_len_cached = seq_len
56 | t = torch.arange(self.max_seq_len_cached, device=x.device, dtype=self.inv_freq.dtype)
57 | freqs = torch.einsum("i,j->ij", t, self.inv_freq)
58 | # Different from paper, but it uses a different permutation in order to obtain the same calculation
59 | emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
60 | self.register_buffer("cos_cached", emb.cos()[None, None, :, :], persistent=False)
61 | self.register_buffer("sin_cached", emb.sin()[None, None, :, :], persistent=False)
62 | return (
63 | self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
64 | self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
65 | )
66 |
67 |
68 | def rotate_half(x):
69 | """Rotates half the hidden dims of the input."""
70 | x1 = x[..., : x.shape[-1] // 2]
71 | x2 = x[..., x.shape[-1] // 2 :]
72 | return torch.cat((-x2, x1), dim=-1)
73 |
74 |
75 | def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
76 | gather_indices = position_ids[:, None, :, None] # [bs, 1, seq_len, 1]
77 | gather_indices = gather_indices.repeat(1, cos.shape[1], 1, cos.shape[3])
78 | cos = torch.gather(cos.repeat(gather_indices.shape[0], 1, 1, 1), 2, gather_indices)
79 | sin = torch.gather(sin.repeat(gather_indices.shape[0], 1, 1, 1), 2, gather_indices)
80 |
81 | q_embed = (q * cos) + (rotate_half(q) * sin)
82 | k_embed = (k * cos) + (rotate_half(k) * sin)
83 | return q_embed, k_embed
84 |
85 |
86 | class SVD_LlamaMLP(nn.Module):
87 | def __init__(
88 | self,
89 | hidden_size: int,
90 | intermediate_size: int,
91 | hidden_act: str,
92 | compression_ratio=1,
93 | ):
94 | super().__init__()
95 | self.compression_ratio = compression_ratio
96 | if self.compression_ratio != 1:
97 | low_rank = int(intermediate_size * hidden_size * self.compression_ratio / (intermediate_size + hidden_size))
98 | self.gate_u_proj = nn.Linear(low_rank, intermediate_size, bias=False)
99 | self.gate_v_proj = nn.Linear(hidden_size, low_rank, bias=False)
100 | else:
101 | self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
102 | if self.compression_ratio != 1:
103 | low_rank = int(intermediate_size * hidden_size * self.compression_ratio / (intermediate_size + hidden_size))
104 | self.down_u_proj = nn.Linear(low_rank, hidden_size, bias=False)
105 | self.down_v_proj = nn.Linear(intermediate_size, low_rank, bias=False)
106 | else:
107 | self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False)
108 | if self.compression_ratio != 1:
109 | low_rank = int(intermediate_size * hidden_size * self.compression_ratio / (intermediate_size + hidden_size))
110 | self.up_u_proj = nn.Linear(low_rank, intermediate_size, bias=False)
111 | self.up_v_proj = nn.Linear(hidden_size, low_rank, bias=False)
112 | else:
113 | self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
114 | self.act_fn = ACT2FN[hidden_act]
115 |
116 | def forward(self, x):
117 | if self.compression_ratio != 1:
118 | up = self.up_u_proj(self.up_v_proj(x))
119 | else:
120 | up = self.up_proj(x)
121 | if self.compression_ratio != 1:
122 | gate = self.gate_u_proj(self.gate_v_proj(x))
123 | else:
124 | gate = self.gate_proj(x)
125 | if self.compression_ratio != 1:
126 | return self.down_u_proj(self.down_v_proj(self.act_fn(gate) * up))
127 | else:
128 | return self.down_proj(self.act_fn(gate) * up)
129 |
130 |
131 | class SVD_LlamaAttention(nn.Module):
132 | """Multi-headed attention from 'Attention Is All You Need' paper"""
133 |
134 | def __init__(self, config: LlamaConfig, compression_ratio=1):
135 | super().__init__()
136 | self.config = config
137 | self.hidden_size = config.hidden_size
138 | self.num_heads = config.num_attention_heads
139 | self.head_dim = self.hidden_size // self.num_heads
140 | self.max_position_embeddings = config.max_position_embeddings
141 | self.compression_ratio = compression_ratio
142 | if (self.head_dim * self.num_heads) != self.hidden_size:
143 | raise ValueError(
144 | f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
145 | f" and `num_heads`: {self.num_heads})."
146 | )
147 | if self.compression_ratio != 1:
148 | low_rank = int(self.hidden_size * self.compression_ratio/2)
149 | self.q_u_proj = nn.Linear(low_rank, self.num_heads * self.head_dim, bias=False)
150 | self.q_v_proj = nn.Linear(self.hidden_size, low_rank, bias=False)
151 | else:
152 | self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
153 | if self.compression_ratio != 1:
154 | low_rank = int(self.hidden_size * self.compression_ratio/2)
155 | self.k_u_proj = nn.Linear(low_rank, self.num_heads * self.head_dim, bias=False)
156 | self.k_v_proj = nn.Linear(self.hidden_size, low_rank, bias=False)
157 | else:
158 | self.k_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
159 | if self.compression_ratio != 1:
160 | low_rank = int(self.hidden_size * self.compression_ratio/2)
161 | self.v_u_proj = nn.Linear(low_rank, self.num_heads * self.head_dim, bias=False)
162 | self.v_v_proj = nn.Linear(self.hidden_size, low_rank, bias=False)
163 | else:
164 | self.v_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
165 | if self.compression_ratio != 1:
166 | low_rank = int(self.hidden_size * self.compression_ratio/2)
167 | self.o_u_proj = nn.Linear(low_rank, self.hidden_size, bias=False)
168 | self.o_v_proj = nn.Linear(self.num_heads * self.head_dim, low_rank, bias=False)
169 | else:
170 | self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
171 | self.rotary_emb = LlamaRotaryEmbedding(self.head_dim, max_position_embeddings=self.max_position_embeddings)
172 |
173 | def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
174 | return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
175 |
176 | def forward(
177 | self,
178 | hidden_states: torch.Tensor,
179 | attention_mask: Optional[torch.Tensor] = None,
180 | position_ids: Optional[torch.LongTensor] = None,
181 | past_key_value: Optional[Tuple[torch.Tensor]] = None,
182 | output_attentions: bool = False,
183 | use_cache: bool = False,
184 | ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
185 | bsz, q_len, _ = hidden_states.size()
186 |
187 | if self.compression_ratio != 1:
188 | query_states = self.q_u_proj(self.q_v_proj(hidden_states)).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
189 | else:
190 | query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
191 | if self.compression_ratio != 1:
192 |             # TODO: when caching, keep the low-rank latent k_v_proj(x) so it can be stored in place of the full key states
193 | if use_cache:
194 | key_v_x_states = self.k_v_proj(hidden_states)
195 | key_states = self.k_u_proj(key_v_x_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
196 | else:
197 | key_states = self.k_u_proj(self.k_v_proj(hidden_states)).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
198 | else:
199 | key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
200 | if self.compression_ratio != 1:
201 |             # TODO: when caching, keep the low-rank latent v_v_proj(x) so it can be stored in place of the full value states
202 | if use_cache:
203 | value_v_x_states = self.v_v_proj(hidden_states)
204 | value_states = self.v_u_proj(value_v_x_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
205 | else:
206 | value_states = self.v_u_proj(self.v_v_proj(hidden_states)).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
207 | else:
208 | value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
209 |
210 |         # key_states, query_states, value_states: [bsz, num_heads, seq_len, head_dim], e.g. [8, 32, 2048, 128]
211 |         # cos, sin: [1, 1, seq_len, head_dim], e.g. [1, 1, 2048, 128]
212 | kv_seq_len = key_states.shape[-2]
213 | if past_key_value is not None:
214 | kv_seq_len += past_key_value[0].shape[-2]
215 | # TODO: pass the shared cos / sin cache from top module (original code only require dev, increasing kv_seq_len, value datatype)
216 | cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
217 |         query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
219 | ################################################################################################################################
220 | # [bsz, nh, t, hd]
221 | # TODO: store the key_v_x_states, value_v_x_states in past_key_value rather than the key,value states
222 | if past_key_value is not None:
223 | # reuse k, v, self_attention
224 |
225 | key_v_x_states = torch.cat([past_key_value[0], key_v_x_states], dim=(1 if self.compression_ratio != 1 else 2))
226 | value_v_x_states = torch.cat([past_key_value[1], value_v_x_states], dim=(1 if self.compression_ratio != 1 else 2))
227 |
228 | past_key_value = (key_v_x_states, value_v_x_states) if use_cache else None
229 | ##################################################################################
230 | # TODO: restore the k and v
231 | if self.compression_ratio != 1:
232 | key_states = self.k_u_proj(key_v_x_states).view(bsz, kv_seq_len, self.num_heads, self.head_dim).transpose(1, 2)
233 | value_states = self.v_u_proj(value_v_x_states).view(bsz, kv_seq_len, self.num_heads, self.head_dim).transpose(1, 2)
234 | attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
235 |
236 | if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
237 | raise ValueError(
238 |                 f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
239 | f" {attn_weights.size()}"
240 | )
241 |
242 | if attention_mask is not None:
243 | if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
244 | raise ValueError(
245 | f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
246 | )
247 | attn_weights = attn_weights + attention_mask
248 | attn_weights = torch.max(attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min, device=attn_weights.device))
249 |
250 | # upcast attention to fp32
251 | attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
252 | attn_output = torch.matmul(attn_weights, value_states)
253 | # # TODO: clean
254 | # del key_states, value_states
255 | # key_states = None
256 | # value_states = None
257 | # torch.cuda.empty_cache()
258 |
259 | if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
260 | raise ValueError(
261 | f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
262 | f" {attn_output.size()}"
263 | )
264 |
265 | attn_output = attn_output.transpose(1, 2)
266 | attn_output = attn_output.reshape(bsz, q_len, -1)
267 |
268 | if self.compression_ratio != 1:
269 | attn_output = self.o_u_proj(self.o_v_proj(attn_output))
270 | else:
271 | attn_output = self.o_proj(attn_output)
272 |
273 | if not output_attentions:
274 | attn_weights = None
275 |
276 | return attn_output, attn_weights, past_key_value
--------------------------------------------------------------------------------
/component/svd_mistral.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2023 Mistral AI and the HuggingFace Inc. team. All rights reserved.
3 | #
4 | # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
5 | # and OPT implementations in this library. It has been modified from its
6 | # original forms to accommodate minor architectural differences compared
7 | # to GPT-NeoX and OPT used by the Meta AI team that trained the model.
8 | #
9 | # Licensed under the Apache License, Version 2.0 (the "License");
10 | # you may not use this file except in compliance with the License.
11 | # You may obtain a copy of the License at
12 | #
13 | # http://www.apache.org/licenses/LICENSE-2.0
14 | #
15 | # Unless required by applicable law or agreed to in writing, software
16 | # distributed under the License is distributed on an "AS IS" BASIS,
17 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
18 | # See the License for the specific language governing permissions and
19 | # limitations under the License.
20 | """ PyTorch Mistral model."""
21 | import inspect
22 | import math
23 | import warnings
24 | from typing import List, Optional, Tuple, Union
25 |
26 | import torch
27 | import torch.nn.functional as F
28 | import torch.utils.checkpoint
29 | from torch import nn
30 | from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
31 |
32 | from transformers.activations import ACT2FN
33 | from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
34 | # from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast
35 | from transformers.modeling_utils import PreTrainedModel
36 | from transformers.utils import (
37 | add_start_docstrings,
38 | add_start_docstrings_to_model_forward,
39 | is_flash_attn_2_available,
40 | logging,
41 | replace_return_docstrings,
42 | )
43 | from transformers.models.mistral import MistralConfig
44 |
45 |
46 | if is_flash_attn_2_available():
47 | from flash_attn import flash_attn_func, flash_attn_varlen_func
48 | from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
49 |
50 | _flash_supports_window_size = "window_size" in list(inspect.signature(flash_attn_func).parameters)
51 |
52 |
53 | logger = logging.get_logger(__name__)
54 |
55 | _CONFIG_FOR_DOC = "MistralConfig"
56 |
57 |
58 | # Copied from transformers.models.llama.modeling_llama._get_unpad_data
59 | def _get_unpad_data(attention_mask):
60 | seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
61 | indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
62 | max_seqlen_in_batch = seqlens_in_batch.max().item()
63 | cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0))
64 | return (
65 | indices,
66 | cu_seqlens,
67 | max_seqlen_in_batch,
68 | )
69 |
70 |
71 | # Copied from transformers.models.llama.modeling_llama.LlamaRMSNorm with Llama->Mistral
72 | class MistralRMSNorm(nn.Module):
73 | def __init__(self, hidden_size, eps=1e-6):
74 | """
75 | MistralRMSNorm is equivalent to T5LayerNorm
76 | """
77 | super().__init__()
78 | self.weight = nn.Parameter(torch.ones(hidden_size))
79 | self.variance_epsilon = eps
80 |
81 | def forward(self, hidden_states):
82 | input_dtype = hidden_states.dtype
83 | hidden_states = hidden_states.to(torch.float32)
84 | variance = hidden_states.pow(2).mean(-1, keepdim=True)
85 | hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
86 | return self.weight * hidden_states.to(input_dtype)
87 |
88 |
89 | # Copied from transformers.models.llama.modeling_llama.LlamaRotaryEmbedding with Llama->Mistral
90 | class MistralRotaryEmbedding(nn.Module):
91 | def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
92 | super().__init__()
93 |
94 | self.dim = dim
95 | self.max_position_embeddings = max_position_embeddings
96 | self.base = base
97 | inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
98 | self.register_buffer("inv_freq", inv_freq, persistent=False)
99 |
100 | # Build here to make `torch.jit.trace` work.
101 | self._set_cos_sin_cache(
102 | seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype()
103 | )
104 |
105 | def _set_cos_sin_cache(self, seq_len, device, dtype):
106 | self.max_seq_len_cached = seq_len
107 | t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
108 |
109 | freqs = torch.einsum("i,j->ij", t, self.inv_freq)
110 | # Different from paper, but it uses a different permutation in order to obtain the same calculation
111 | emb = torch.cat((freqs, freqs), dim=-1)
112 | self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
113 | self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
114 |
115 | def forward(self, x, seq_len=None):
116 | # x: [bs, num_attention_heads, seq_len, head_size]
117 | if seq_len > self.max_seq_len_cached:
118 | self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
119 |
120 | return (
121 | self.cos_cached[:seq_len].to(dtype=x.dtype),
122 | self.sin_cached[:seq_len].to(dtype=x.dtype),
123 | )
124 |
125 |
126 | # Copied from transformers.models.llama.modeling_llama.rotate_half
127 | def rotate_half(x):
128 | """Rotates half the hidden dims of the input."""
129 | x1 = x[..., : x.shape[-1] // 2]
130 | x2 = x[..., x.shape[-1] // 2 :]
131 | return torch.cat((-x2, x1), dim=-1)
132 |
133 |
134 | # Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb
135 | def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
136 | """Applies Rotary Position Embedding to the query and key tensors.
137 |
138 | Args:
139 | q (`torch.Tensor`): The query tensor.
140 | k (`torch.Tensor`): The key tensor.
141 | cos (`torch.Tensor`): The cosine part of the rotary embedding.
142 | sin (`torch.Tensor`): The sine part of the rotary embedding.
143 | position_ids (`torch.Tensor`):
144 | The position indices of the tokens corresponding to the query and key tensors. For example, this can be
145 | used to pass offsetted position ids when working with a KV-cache.
146 | unsqueeze_dim (`int`, *optional*, defaults to 1):
147 | The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
148 | sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
149 | that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
150 | k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
151 | cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
152 | the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
153 | Returns:
154 | `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
155 | """
156 | cos = cos[position_ids].unsqueeze(unsqueeze_dim)
157 | sin = sin[position_ids].unsqueeze(unsqueeze_dim)
158 | q_embed = (q * cos) + (rotate_half(q) * sin)
159 | k_embed = (k * cos) + (rotate_half(k) * sin)
160 | return q_embed, k_embed
161 |
162 |
163 | class SVD_MistralMLP(nn.Module):
164 | def __init__(self, config,
165 | ratio=1 # 1 means no truncate, just keep normal MLP
166 | ):
167 | super().__init__()
168 | self.config = config
169 | self.hidden_size = config.hidden_size
170 | self.intermediate_size = config.intermediate_size
171 | self.ratio = ratio
172 | low_rank = int(self.intermediate_size * self.hidden_size * self.ratio / (self.intermediate_size + self.hidden_size))
173 | self.gate_u_proj = nn.Linear(low_rank, self.intermediate_size, bias=False)
174 | self.gate_v_proj = nn.Linear(self.hidden_size, low_rank, bias=False)
175 |
176 | self.down_u_proj = nn.Linear(low_rank, self.hidden_size, bias=False)
177 | self.down_v_proj = nn.Linear(self.intermediate_size, low_rank, bias=False)
178 |
179 | self.up_u_proj = nn.Linear(low_rank, self.intermediate_size, bias=False)
180 | self.up_v_proj = nn.Linear(self.hidden_size, low_rank, bias=False)
181 | self.act_fn = ACT2FN[config.hidden_act]
182 |
183 | def forward(self, x):
184 | up = self.up_u_proj(self.up_v_proj(x))
185 | gate = self.gate_u_proj(self.gate_v_proj(x))
186 | return self.down_u_proj(self.down_v_proj(self.act_fn(gate) * up))
187 |
188 |
189 | # Copied from transformers.models.llama.modeling_llama.repeat_kv
190 | def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
191 | """
192 | This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
193 | num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
194 | """
195 | batch, num_key_value_heads, slen, head_dim = hidden_states.shape
196 | if n_rep == 1:
197 | return hidden_states
198 | hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
199 | return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
200 |
201 |
202 | class SVD_MistralAttention(nn.Module):
203 | """
204 | Multi-headed attention from 'Attention Is All You Need' paper. Modified to use sliding window attention: Longformer
205 | and "Generating Long Sequences with Sparse Transformers".
206 | """
207 |
208 | def __init__(self, config: MistralConfig,
209 | ratio=1):
210 | super().__init__()
211 | self.config = config
212 | self.hidden_size = config.hidden_size
213 | self.num_heads = config.num_attention_heads
214 | self.head_dim = self.hidden_size // self.num_heads
215 | self.num_key_value_heads = config.num_key_value_heads
216 | self.num_key_value_groups = self.num_heads // self.num_key_value_heads
217 | self.max_position_embeddings = config.max_position_embeddings
218 | self.rope_theta = config.rope_theta
219 | self.is_causal = True
220 | self.ratio = ratio # 1 means no truncate, just keep normal attn
221 |
222 | if (self.head_dim * self.num_heads) != self.hidden_size:
223 | raise ValueError(
224 | f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
225 | f" and `num_heads`: {self.num_heads})."
226 | )
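        # NOTE: the rank budget below assumes square (hidden_size x hidden_size) projections; for GQA
        # models (num_key_value_heads < num_heads) the k/v projections are smaller, which the roadmap
        # item on optimizing the compression of GQA-based LLMs is meant to address.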
227 | low_rank = int(self.hidden_size * self.ratio/2)
228 | self.q_u_proj = nn.Linear(low_rank, self.num_heads * self.head_dim, bias=False)
229 | self.q_v_proj = nn.Linear(self.hidden_size, low_rank, bias=False)
230 | self.k_u_proj = nn.Linear(low_rank, self.num_key_value_heads * self.head_dim, bias=False)
231 | self.k_v_proj = nn.Linear(self.hidden_size, low_rank, bias=False)
232 | self.v_u_proj = nn.Linear(low_rank, self.num_key_value_heads * self.head_dim, bias=False)
233 | self.v_v_proj = nn.Linear(self.hidden_size, low_rank, bias=False)
234 | self.o_u_proj = nn.Linear(low_rank, self.hidden_size, bias=False)
235 | self.o_v_proj = nn.Linear(self.num_heads * self.head_dim, low_rank, bias=False)
236 |
237 | self.rotary_emb = MistralRotaryEmbedding(
238 | self.head_dim,
239 | max_position_embeddings=self.max_position_embeddings,
240 | base=self.rope_theta,
241 | )
242 |
243 | def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
244 | return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
245 |
246 | def forward(
247 | self,
248 | hidden_states: torch.Tensor,
249 | attention_mask: Optional[torch.Tensor] = None,
250 | position_ids: Optional[torch.LongTensor] = None,
251 | past_key_value: Optional[Tuple[torch.Tensor]] = None,
252 | output_attentions: bool = False,
253 | use_cache: bool = False,
254 | **kwargs,
255 | ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
256 | if "padding_mask" in kwargs:
257 | warnings.warn(
258 |                 "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure to use `attention_mask` instead."
259 | )
260 | bsz, q_len, _ = hidden_states.size()
261 | query_states = self.q_u_proj(self.q_v_proj(hidden_states))
262 | key_states = self.k_u_proj(self.k_v_proj(hidden_states))
263 | value_states = self.v_u_proj(self.v_v_proj(hidden_states))
264 |
265 | query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
266 | key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
267 | value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
268 |
269 | kv_seq_len = key_states.shape[-2]
270 | if past_key_value is not None:
271 | kv_seq_len += past_key_value[0].shape[-2]
272 | cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
273 | query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
274 |
275 | if past_key_value is not None:
276 | # reuse k, v, self_attention
277 | key_states = torch.cat([past_key_value[0], key_states], dim=2)
278 | value_states = torch.cat([past_key_value[1], value_states], dim=2)
279 |
280 | past_key_value = (key_states, value_states) if use_cache else None
281 |
282 | # repeat k/v heads if n_kv_heads < n_heads
283 | key_states = repeat_kv(key_states, self.num_key_value_groups)
284 | value_states = repeat_kv(value_states, self.num_key_value_groups)
285 |
286 | attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
287 |
288 | if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
289 | raise ValueError(
290 | f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
291 | f" {attn_weights.size()}"
292 | )
293 |
294 | if attention_mask is not None:
295 | if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
296 | raise ValueError(
297 | f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
298 | )
299 |
300 | attn_weights = attn_weights + attention_mask
301 |
302 | # upcast attention to fp32
303 | attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
304 | attn_output = torch.matmul(attn_weights, value_states)
305 |
306 | if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
307 | raise ValueError(
308 | f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
309 | f" {attn_output.size()}"
310 | )
311 |
312 | attn_output = attn_output.transpose(1, 2).contiguous()
313 | attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
314 |
315 | attn_output = self.o_u_proj(self.o_v_proj(attn_output))
316 |
317 | if not output_attentions:
318 | attn_weights = None
319 |
320 | return attn_output, attn_weights, past_key_value
321 |
--------------------------------------------------------------------------------
/component/svd_opt.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2022 The Fairseq Authors and The HuggingFace Inc. team. All rights reserved.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | """ PyTorch OPT model."""
16 | from typing import List, Optional, Tuple, Union
17 |
18 | import torch
19 | import torch.nn.functional as F
20 | import torch.utils.checkpoint
21 | from torch import nn
22 |
23 | from transformers.activations import ACT2FN
24 | from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
25 |
26 | from transformers.models.opt.configuration_opt import OPTConfig
27 |
28 |
29 | from transformers.utils import logging
30 |
31 | _CHECKPOINT_FOR_DOC = "facebook/opt-350m"
32 | _CONFIG_FOR_DOC = "OPTConfig"
33 |
34 | # Base model docstring
35 | _EXPECTED_OUTPUT_SHAPE = [1, 8, 1024]
36 |
37 | # SequenceClassification docstring
38 | _CHECKPOINT_FOR_SEQUENCE_CLASSIFICATION = "ArthurZ/opt-350m-dummy-sc"
39 | _SEQ_CLASS_EXPECTED_LOSS = 1.71
40 | _SEQ_CLASS_EXPECTED_OUTPUT = "'LABEL_0'"
41 |
42 | OPT_PRETRAINED_MODEL_ARCHIVE_LIST = [
43 | "facebook/opt-125m",
44 | "facebook/opt-350m",
45 | "facebook/opt-1.3b",
46 | "facebook/opt-2.7b",
47 | "facebook/opt-6.7b",
48 | "facebook/opt-13b",
49 | "facebook/opt-30b",
50 | # See all OPT models at https://huggingface.co/models?filter=opt
51 | ]
52 |
53 |
54 | # Copied from transformers.models.llama.modeling_llama._get_unpad_data
55 | def _get_unpad_data(attention_mask):
56 | seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
57 | indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
58 | max_seqlen_in_batch = seqlens_in_batch.max().item()
59 |     cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
60 | return (
61 | indices,
62 | cu_seqlens,
63 | max_seqlen_in_batch,
64 | )
65 |
66 |
67 | class OPTLearnedPositionalEmbedding(nn.Embedding):
68 | """
69 | This module learns positional embeddings up to a fixed maximum size.
70 | """
71 |
72 | def __init__(self, num_embeddings: int, embedding_dim: int):
73 | # OPT is set up so that if padding_idx is specified then offset the embedding ids by 2
74 | # and adjust num_embeddings appropriately. Other models don't have this hack
75 | self.offset = 2
76 | super().__init__(num_embeddings + self.offset, embedding_dim)
77 |
78 | def forward(self, attention_mask: torch.LongTensor, past_key_values_length: int = 0):
79 |         """`attention_mask` is expected to be [bsz x seqlen]."""
80 | attention_mask = attention_mask.long()
81 |
82 | # create positions depending on attention_mask
83 | positions = (torch.cumsum(attention_mask, dim=1).type_as(attention_mask) * attention_mask).long() - 1
84 |
85 | # cut positions if `past_key_values_length` is > 0
86 | positions = positions[:, past_key_values_length:]
87 |
88 | return super().forward(positions + self.offset)
89 |
90 |
91 | class SVDOPTAttention(nn.Module):
92 | """Multi-headed attention from 'Attention Is All You Need' paper"""
93 |
94 | def __init__(
95 | self,
96 | config: OPTConfig,
97 | is_decoder: bool = False,
98 | ratio=1,
99 | **kwargs,
100 | ):
101 | super().__init__()
102 | self.config = config
103 |
104 | def _handle_deprecated_argument(config_arg_name, config, fn_arg_name, kwargs):
105 | """
106 |             If the deprecated argument `fn_arg_name` is passed, raise a deprecation
107 |             warning and return that value; otherwise take the equivalent `config.config_arg_name`.
108 | """
109 | val = None
110 | if fn_arg_name in kwargs:
111 |                 logging.get_logger(__name__).warning(
112 |                     f"Passing in {fn_arg_name} to {self.__class__.__name__} is deprecated and won't be supported from v4.38."
113 |                     " Please set it in the config instead"
114 |                 )
115 | val = kwargs.pop(fn_arg_name)
116 | else:
117 | val = getattr(config, config_arg_name)
118 | return val
119 |
120 | self.embed_dim = _handle_deprecated_argument("hidden_size", config, "embed_dim", kwargs)
121 | self.num_heads = _handle_deprecated_argument("num_attention_heads", config, "num_heads", kwargs)
122 | self.dropout = _handle_deprecated_argument("attention_dropout", config, "dropout", kwargs)
123 | self.enable_bias = _handle_deprecated_argument("enable_bias", config, "bias", kwargs)
124 |
125 | self.head_dim = self.embed_dim // self.num_heads
126 | self.is_causal = True
127 | self.ratio = ratio
128 |
129 | if (self.head_dim * self.num_heads) != self.embed_dim:
130 | raise ValueError(
131 | f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}"
132 | f" and `num_heads`: {self.num_heads})."
133 | )
134 | self.scaling = self.head_dim**-0.5
135 | self.is_decoder = is_decoder
136 |
137 | if self.ratio != 1:
138 | low_rank = int(self.embed_dim * self.ratio/2)
139 | self.q_u_proj = nn.Linear(low_rank, self.embed_dim, bias=self.enable_bias)
140 | self.q_v_proj = nn.Linear(self.embed_dim, low_rank, bias=False)
141 | else:
142 | self.q_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=self.enable_bias)
143 | if self.ratio != 1:
144 | low_rank = int(self.embed_dim * self.ratio/2)
145 | self.k_u_proj = nn.Linear(low_rank, self.embed_dim, bias=self.enable_bias)
146 | self.k_v_proj = nn.Linear(self.embed_dim, low_rank, bias=False)
147 | else:
148 | self.k_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=self.enable_bias)
149 | if self.ratio != 1:
150 | low_rank = int(self.embed_dim * self.ratio/2)
151 | self.v_u_proj = nn.Linear(low_rank, self.embed_dim, bias=self.enable_bias)
152 | self.v_v_proj = nn.Linear(self.embed_dim, low_rank, bias=False)
153 | else:
154 | self.v_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=self.enable_bias)
155 | if self.ratio != 1:
156 | low_rank = int(self.embed_dim * self.ratio/2)
157 | self.out_u_proj = nn.Linear(low_rank, self.embed_dim, bias=self.enable_bias)
158 | self.out_v_proj = nn.Linear(self.embed_dim, low_rank, bias=False)
159 | else:
160 | self.out_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=self.enable_bias)
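        # Low-rank layout: with compression ratio `ratio`, each d x d projection is stored as
        # u_proj (low_rank -> d) and v_proj (d -> low_rank) with low_rank = int(d * ratio / 2),
        # so the pair holds 2 * d * low_rank ~= ratio * d^2 parameters. A rough sketch of how the
        # factors could be filled from a truncated SVD of a dense weight W (illustrative only,
        # not necessarily this repo's exact whitening-aware procedure):
        #     U, S, Vh = torch.linalg.svd(W, full_matrices=False)
        #     u_proj.weight.data = U[:, :low_rank] * S[:low_rank]   # shape (d, low_rank)
        #     v_proj.weight.data = Vh[:low_rank, :]                 # shape (low_rank, d)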
161 |
162 | def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
163 | return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
164 |
165 | def forward(
166 | self,
167 | hidden_states: torch.Tensor,
168 | key_value_states: Optional[torch.Tensor] = None,
169 | past_key_value: Optional[Tuple[torch.Tensor]] = None,
170 | attention_mask: Optional[torch.Tensor] = None,
171 | layer_head_mask: Optional[torch.Tensor] = None,
172 | output_attentions: bool = False,
173 | ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
174 | """Input shape: Batch x Time x Channel"""
175 |
176 | # if key_value_states are provided this layer is used as a cross-attention layer
177 | # for the decoder
178 | is_cross_attention = key_value_states is not None
179 |
180 | bsz, tgt_len, _ = hidden_states.size()
181 |
182 | # get query proj
183 | if self.ratio != 1:
184 | query_states = self.q_u_proj(self.q_v_proj(hidden_states)) * self.scaling
185 | else:
186 | query_states = self.q_proj(hidden_states) * self.scaling
187 | # get key, value proj
188 | if is_cross_attention and past_key_value is not None:
189 | # reuse k,v, cross_attentions
190 | key_states = past_key_value[0]
191 | value_states = past_key_value[1]
192 | elif is_cross_attention:
193 | # cross_attentions
194 | if self.ratio != 1:
195 | key_states = self._shape(self.k_u_proj(self.k_v_proj(key_value_states)), -1, bsz)
196 | else:
197 | key_states = self._shape(self.k_proj(key_value_states), -1, bsz)
198 | if self.ratio != 1:
199 | value_states = self._shape(self.v_u_proj(self.v_v_proj(key_value_states)), -1, bsz)
200 | else:
201 | value_states = self._shape(self.v_proj(key_value_states), -1, bsz)
202 | elif past_key_value is not None:
203 | # reuse k, v, self_attention
204 | if self.ratio != 1:
205 | key_states = self._shape(self.k_u_proj(self.k_v_proj(hidden_states)), -1, bsz)
206 | else:
207 | key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
208 | if self.ratio != 1:
209 | value_states = self._shape(self.v_u_proj(self.v_v_proj(hidden_states)), -1, bsz)
210 | else:
211 | value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
212 | key_states = torch.cat([past_key_value[0], key_states], dim=2)
213 | value_states = torch.cat([past_key_value[1], value_states], dim=2)
214 | else:
215 | # self_attention
216 | if self.ratio != 1:
217 | key_states = self._shape(self.k_u_proj(self.k_v_proj(hidden_states)), -1, bsz)
218 | else:
219 | key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
220 | if self.ratio != 1:
221 | value_states = self._shape(self.v_u_proj(self.v_v_proj(hidden_states)), -1, bsz)
222 | else:
223 | value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
224 |
225 | if self.is_decoder:
226 | # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
227 | # Further calls to cross_attention layer can then reuse all cross-attention
228 | # key/value_states (first "if" case)
229 | # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
230 | # all previous decoder key/value_states. Further calls to uni-directional self-attention
231 | # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
232 | # if encoder bi-directional self-attention `past_key_value` is always `None`
233 | past_key_value = (key_states, value_states)
234 |
235 | proj_shape = (bsz * self.num_heads, -1, self.head_dim)
236 | query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)
237 | key_states = key_states.view(*proj_shape)
238 | value_states = value_states.view(*proj_shape)
239 |
240 | src_len = key_states.size(1)
241 | attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
242 |
243 | if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
244 | raise ValueError(
245 | f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
246 | f" {attn_weights.size()}"
247 | )
248 |
249 | if attention_mask is not None:
250 | if attention_mask.size() != (bsz, 1, tgt_len, src_len):
251 | raise ValueError(
252 | f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
253 | )
254 | attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
255 | attn_weights = torch.max(
256 | attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min, device=attn_weights.device)
257 | )
258 | attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
259 |
260 | # upcast to fp32 if the weights are in fp16. Please see https://github.com/huggingface/transformers/pull/17437
261 | if attn_weights.dtype == torch.float16:
262 | attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(torch.float16)
263 | else:
264 | attn_weights = nn.functional.softmax(attn_weights, dim=-1)
265 |
266 | if layer_head_mask is not None:
267 | if layer_head_mask.size() != (self.num_heads,):
268 | raise ValueError(
269 | f"Head mask for a single layer should be of size {(self.num_heads,)}, but is"
270 | f" {layer_head_mask.size()}"
271 | )
272 | attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
273 | attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
274 |
275 | if output_attentions:
276 | # this operation is a bit awkward, but it's required to
277 | # make sure that attn_weights keeps its gradient.
278 | # In order to do so, attn_weights have to be reshaped
279 | # twice and have to be reused in the following
280 | attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
281 | attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len)
282 | else:
283 | attn_weights_reshaped = None
284 |
285 | attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
286 |
287 | attn_output = torch.bmm(attn_probs, value_states)
288 |
289 | if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
290 | raise ValueError(
291 |                 f"`attn_output` should be of size {(bsz * self.num_heads, tgt_len, self.head_dim)}, but is"
292 | f" {attn_output.size()}"
293 | )
294 |
295 | attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
296 | attn_output = attn_output.transpose(1, 2)
297 |
298 | # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be
299 |         # partitioned across GPUs when using tensor-parallelism.
300 | attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)
301 |
302 | if self.ratio != 1:
303 | attn_output = self.out_u_proj(self.out_v_proj(attn_output))
304 | else:
305 | attn_output = self.out_proj(attn_output)
306 |
307 | return attn_output, attn_weights_reshaped, past_key_value
308 |
309 |
310 |
311 | class SVDOPTDecoderLayer(nn.Module):
312 | def __init__(self, config: OPTConfig, ratio = 1):
313 | super().__init__()
314 | self.embed_dim = config.hidden_size
315 |
316 | self.self_attn = SVDOPTAttention(config=config, ratio=ratio, is_decoder=True)
317 |
318 | self.do_layer_norm_before = config.do_layer_norm_before
319 | self.dropout = config.dropout
320 | self.activation_fn = ACT2FN[config.activation_function]
321 |
322 | self.self_attn_layer_norm = nn.LayerNorm(
323 | self.embed_dim, elementwise_affine=config.layer_norm_elementwise_affine
324 | )
325 | self.ratio = ratio
326 | if self.ratio != 1:
327 | low_rank = int(config.ffn_dim * self.embed_dim * self.ratio / (config.ffn_dim + self.embed_dim))
328 | self.fc1_u_proj = nn.Linear(low_rank, config.ffn_dim, bias=config.enable_bias)
329 | self.fc1_v_proj = nn.Linear(self.embed_dim, low_rank, bias=False)
330 | else:
331 | self.fc1 = nn.Linear(self.embed_dim, config.ffn_dim, bias=config.enable_bias)
332 | if self.ratio != 1:
333 | low_rank = int(config.ffn_dim * self.embed_dim * self.ratio / (config.ffn_dim + self.embed_dim))
334 | self.fc2_u_proj = nn.Linear(low_rank, self.embed_dim, bias=config.enable_bias)
335 | self.fc2_v_proj = nn.Linear(config.ffn_dim, low_rank, bias=False)
336 | else:
337 | self.fc2 = nn.Linear(config.ffn_dim, self.embed_dim, bias=config.enable_bias)
338 | self.final_layer_norm = nn.LayerNorm(self.embed_dim, elementwise_affine=config.layer_norm_elementwise_affine)
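        # FFN rank choice: for the rectangular ffn_dim x embed_dim weights fc1/fc2,
        # low_rank = ratio * ffn_dim * embed_dim / (ffn_dim + embed_dim) makes the factored pair
        # cost (ffn_dim + embed_dim) * low_rank = ratio * ffn_dim * embed_dim parameters, i.e. the
        # same `ratio` fraction as the attention projections. For example, with embed_dim = 2048
        # and ffn_dim = 8192 (OPT-1.3b) at ratio = 0.8, low_rank = int(0.8 * 8192 * 2048 / 10240) = 1310.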
339 |
340 | def forward(
341 | self,
342 | hidden_states: torch.Tensor,
343 | attention_mask: Optional[torch.Tensor] = None,
344 | layer_head_mask: Optional[torch.Tensor] = None,
345 | past_key_value: Optional[Tuple[torch.Tensor]] = None,
346 | output_attentions: Optional[bool] = False,
347 | use_cache: Optional[bool] = False,
348 | ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
349 | """
350 | Args:
351 | hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
352 | attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
353 | `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
354 | layer_head_mask (`torch.FloatTensor`, *optional*): mask for attention heads in a given layer of size
355 | `(encoder_attention_heads,)`.
356 | output_attentions (`bool`, *optional*):
357 | Whether or not to return the attentions tensors of all attention layers. See `attentions` under
358 | returned tensors for more detail.
359 | use_cache (`bool`, *optional*):
360 | If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
361 | (see `past_key_values`).
362 | past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
363 | """
364 |
365 | residual = hidden_states
366 |
367 |         # 125m, 1.3B, ..., 175B applies layer norm BEFORE attention
368 | if self.do_layer_norm_before:
369 | hidden_states = self.self_attn_layer_norm(hidden_states)
370 |
371 | # Self Attention
372 | hidden_states, self_attn_weights, present_key_value = self.self_attn(
373 | hidden_states=hidden_states,
374 | past_key_value=past_key_value,
375 | attention_mask=attention_mask,
376 | layer_head_mask=layer_head_mask,
377 | output_attentions=output_attentions,
378 | )
379 | hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
380 | hidden_states = residual + hidden_states
381 |
382 | # 350m applies layer norm AFTER attention
383 | if not self.do_layer_norm_before:
384 | hidden_states = self.self_attn_layer_norm(hidden_states)
385 |
386 | # Fully Connected
387 | hidden_states_shape = hidden_states.shape
388 | hidden_states = hidden_states.reshape(-1, hidden_states.size(-1))
389 | residual = hidden_states
390 |
391 |         # 125m, 1.3B, ..., 175B applies layer norm BEFORE the feed-forward block
392 | if self.do_layer_norm_before:
393 | hidden_states = self.final_layer_norm(hidden_states)
394 |
395 | if self.ratio != 1:
396 | hidden_states = self.fc1_u_proj(self.fc1_v_proj(hidden_states))
397 | else:
398 | hidden_states = self.fc1(hidden_states)
399 | hidden_states = self.activation_fn(hidden_states)
400 |
401 | if self.ratio != 1:
402 | hidden_states = self.fc2_u_proj(self.fc2_v_proj(hidden_states))
403 | else:
404 | hidden_states = self.fc2(hidden_states)
405 | hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
406 |
407 | hidden_states = (residual + hidden_states).view(hidden_states_shape)
408 |
409 |         # 350m applies layer norm AFTER the feed-forward block
410 | if not self.do_layer_norm_before:
411 | hidden_states = self.final_layer_norm(hidden_states)
412 |
413 | outputs = (hidden_states,)
414 |
415 | if output_attentions:
416 | outputs += (self_attn_weights,)
417 |
418 | if use_cache:
419 | outputs += (present_key_value,)
420 |
421 | return outputs
422 |
423 |
--------------------------------------------------------------------------------
/compress_llama.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | # example of compressing LLaMA-7B with SVDLLM
4 | FINE_TUNE_PATH="."
5 | # run data whitening with 20% compression ratio
6 | python SVDLLM.py --model jeffwan/llama-7b-hf --step 1 --ratio 0.2 --whitening_nsamples 256 --dataset wikitext2 --seed 3 --model_seq_len 2048 --save_path .
7 | ## You can also run the following command on a low-resource GPU (e.g., LLaMA-7B needs only about 15 GB of GPU memory to compress) or to compress a large-scale LLM (e.g., LLaMA-65B):
8 | # python SVDLLM.py --model jeffwan/llama-7b-hf --step 1 --ratio 0.2 --whitening_nsamples 256 --dataset wikitext2 --model_seq_len 2048 --save_path ./ --run_low_resource
9 | python SVDLLM.py --step 4 --model_path jeffwan_llama_7b_hf_whitening_only_0.8.pt
10 | # finetune the compressed model with lora
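# note: LoRA is applied in two passes over the factorized weights: the pass below tunes the *_u_proj
# factors and the adapter is merged via --step 4 (the later commands read $FINE_TUNE_PATH/first_half/merge.pt),
# then a second pass tunes the *_v_proj factors on the merged checkpoint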
11 | python utils/LoRA.py --prune_model jeffwan_llama_7b_hf_whitening_only_0.8.pt --data_path yahma/alpaca-cleaned --output_dir $FINE_TUNE_PATH/first_half --lora_target_modules q_u_proj,k_u_proj,v_u_proj,o_u_proj,gate_u_proj,down_u_proj,up_u_proj --lora_r 8 --num_epochs 3 --learning_rate 1e-4 --batch_size 64
12 | python SVDLLM.py --model_path jeffwan_llama_7b_hf_whitening_only_0.8.pt --lora $FINE_TUNE_PATH/first_half --step 4
13 | python utils/LoRA.py --prune_model $FINE_TUNE_PATH/first_half/merge.pt --data_path yahma/alpaca-cleaned --output_dir $FINE_TUNE_PATH/second_half --lora_target_modules q_v_proj,k_v_proj,v_v_proj,o_v_proj,gate_v_proj,down_v_proj,up_v_proj --lora_r 8 --num_epochs 3 --learning_rate 1e-4 --batch_size 64
14 | python SVDLLM.py --model_path jeffwan_llama_7b_hf_whitening_only_0.8.pt --lora $FINE_TUNE_PATH/first_half --step 4
15 | python SVDLLM.py --model_path $FINE_TUNE_PATH/first_half/merge.pt --lora $FINE_TUNE_PATH/second_half --step 4
--------------------------------------------------------------------------------
/evaluater.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | from tqdm import tqdm
4 | import time
5 | import itertools
6 | from utils.data_utils import get_test_data
7 | import os
8 | import sys
9 |
10 | current_path = os.path.dirname(os.path.abspath(__file__))
11 | parent_path = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
12 | sys.path.append(current_path)
13 |
14 | @torch.no_grad()
15 | def ppl_eval(model, tokenizer, datasets=['wikitext2', 'ptb', 'c4'], model_seq_len=2048, batch_size=32, device="cuda"):
16 | model.to(device)
17 | model.eval()
18 | ppls = {}
19 | for dataset in datasets:
20 | test_loader = get_test_data(dataset, tokenizer, seq_len=model_seq_len, batch_size = batch_size)
21 | nlls = []
22 | for batch in tqdm(test_loader):
23 | batch = batch.to(device)
24 | output = model(batch, use_cache=False)
25 | lm_logits = output.logits
26 | if torch.isfinite(lm_logits).all():
27 | shift_logits = lm_logits[:, :-1, :].contiguous()
28 | shift_labels = batch[:, 1:].contiguous()
29 |
30 | loss_fct = torch.nn.CrossEntropyLoss(reduction="none")
31 | loss = loss_fct(shift_logits.reshape(-1, shift_logits.size(-1)), shift_labels.view(-1))
32 | nlls.append(loss)
33 | ppl = np.exp(torch.cat(nlls, dim=-1).mean().item())
34 | ppls[dataset] = ppl
35 | print("PPL after pruning: {}".format(ppls))
36 | print("Weight Memory: {} MiB\n".format(torch.cuda.memory_allocated()/1024/1024))
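# minimal usage sketch for ppl_eval, assuming a Hugging Face causal LM and tokenizer are already
# loaded; the model name below is only a placeholder taken from the repo's scripts:
#     from transformers import AutoModelForCausalLM, AutoTokenizer
#     model = AutoModelForCausalLM.from_pretrained("jeffwan/llama-7b-hf", torch_dtype=torch.float16)
#     tokenizer = AutoTokenizer.from_pretrained("jeffwan/llama-7b-hf", use_fast=False)
#     ppl_eval(model, tokenizer, datasets=['wikitext2'], model_seq_len=2048, batch_size=4)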
37 |
38 | # only call this function for 65B or larger models
39 | @torch.no_grad()
40 | def ppl_eval_large(model, tokenizer, datasets=['wikitext2', 'ptb', 'c4'], seq_len=2048, batch_size=32, device="cuda"):
41 | import torch.nn as nn
42 | class LlamaRMSNorm(nn.Module):
43 | def __init__(self, hidden_size=model.config.hidden_size, eps=model.config.rms_norm_eps):
44 | """
45 | LlamaRMSNorm is equivalent to T5LayerNorm
46 | """
47 | super().__init__()
48 | self.weight = nn.Parameter(torch.ones(hidden_size))
49 | self.variance_epsilon = eps
50 |
51 | def forward(self, hidden_states):
52 | input_dtype = hidden_states.dtype
53 | hidden_states = hidden_states.to(torch.float32)
54 | variance = hidden_states.pow(2).mean(-1, keepdim=True)
55 | hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
56 | return self.weight * hidden_states.to(input_dtype)
57 | norm = LlamaRMSNorm().half().cuda()
58 | lm_head = model.lm_head.cuda()
59 | model.eval()
60 | ppls = {}
61 | layers = model.model.layers
62 | for dataset in datasets:
63 | test_loader = get_test_data(dataset, tokenizer, seq_len=seq_len, batch_size = batch_size)
64 | nlls = []
65 | for batch in tqdm(test_loader):
66 | model.model.embed_tokens = model.model.embed_tokens.cuda()
67 | model.model.norm = model.model.norm.cuda()
68 | layers[0] = layers[0].cuda()
69 |
70 | dtype = next(iter(model.parameters())).dtype
71 | inps = torch.zeros(
72 | (batch.shape[0], model.seqlen, model.config.hidden_size), dtype=dtype, device="cuda"
73 | )
74 | cache = {'i': 0, 'attention_mask': None, "position_ids": None}
75 | class Catcher(nn.Module):
76 | def __init__(self, module):
77 | super().__init__()
78 | self.module = module
79 | def forward(self, inp, **kwargs):
80 | inps[cache['i']] = inp
81 | cache['i'] += 1
82 | if cache['attention_mask'] is None:
83 | cache['attention_mask'] = kwargs['attention_mask']
84 | cache['position_ids'] = kwargs['position_ids']
85 | else:
86 | cache['attention_mask'] = torch.cat((cache['attention_mask'], kwargs['attention_mask']), dim=0)
87 | cache['position_ids'] = torch.cat((cache['position_ids'], kwargs['position_ids']), dim=0)
88 | raise ValueError
89 | layers[0] = Catcher(layers[0])
90 | for j in range(batch.shape[0]):
91 | try:
92 | model(batch[j].unsqueeze(0).cuda())
93 | except ValueError:
94 | pass
95 | layers[0] = layers[0].module
96 | layers[0] = layers[0].cpu()
97 | model.model.embed_tokens = model.model.embed_tokens.cpu()
98 | model.model.norm = model.model.norm.cpu()
99 | torch.cuda.empty_cache()
100 | attention_masks = cache['attention_mask']
101 | position_ids = cache['position_ids']
102 | for i in range(len(layers)):
103 | layer = layers[i].cuda()
104 | outs = layer(inps, attention_mask=attention_masks, position_ids=position_ids)[0]
105 | layers[i] = layer.cpu()
106 | inps = outs
107 | torch.cuda.empty_cache()
108 | hidden_states = norm(outs)
109 | lm_logits = lm_head(hidden_states)
110 | if torch.isfinite(lm_logits).all():
111 | shift_logits = lm_logits[:, :-1, :].contiguous()
112 | shift_labels = batch[:, 1:].contiguous().cuda()
113 |
114 | loss_fct = torch.nn.CrossEntropyLoss(reduction="none")
115 | loss = loss_fct(shift_logits.reshape(-1, shift_logits.size(-1)), shift_labels.view(-1))
116 | nlls.append(loss)
117 | else:
118 | print("warning: nan or inf in lm_logits")
119 | ppl = np.exp(torch.cat(nlls, dim=-1).mean().item())
120 | ppls[dataset] = ppl
121 | print("PPL after pruning: {}".format(ppls))
122 | print("Weight Memory: {} MiB\n".format(torch.cuda.memory_allocated()/1024/1024))
123 |
124 | @torch.no_grad()
125 | def eff_eval(model, tokenizer, dataset='wikitext2', original_len=4, generated_len=128, batch_size=1, device="cuda"):
126 | model.eval()
127 | throughput = 0
128 | token_num = 0
129 | end_memory = 0
130 | num_batches_to_fetch = 10
131 | test_loader = get_test_data(dataset, tokenizer, seq_len=original_len, batch_size = batch_size)
132 | weight_memory = torch.cuda.memory_allocated()
133 | for batch_idx, batch_data in enumerate(itertools.islice(test_loader, num_batches_to_fetch)):
134 | batch = batch_data.to(device)
135 | token_num += batch.shape[0] * generated_len
136 | torch.cuda.empty_cache()
137 | start_memory = torch.cuda.memory_allocated()
138 | torch.cuda.reset_peak_memory_stats(0)
139 | torch.cuda.synchronize()
140 | start_time = time.time()
141 | generation_output = model.generate(
142 | input_ids=batch,
143 | pad_token_id=tokenizer.eos_token_id,
144 | do_sample=True,
145 | use_cache=True,
146 | top_k=50,
147 | max_length=original_len+generated_len,
148 | top_p=0.95,
149 | temperature=1,
150 | )
151 | torch.cuda.synchronize()
152 | end_time = time.time()
153 | end_memory = max(torch.cuda.max_memory_allocated(0), end_memory)
154 | if torch.isfinite(generation_output[0]).all(): # check if the generation is successful since fp16 may cause nan
155 | throughput += end_time - start_time
156 | print("time: {}".format(end_time - start_time))
157 | print("Total Memory: {} GB".format(end_memory/(1024 ** 3)))
158 | print("Weight Memory: {} GB".format(weight_memory/(1024 ** 3)))
159 | print("Activation Memory: {} GB".format((end_memory - start_memory)/(1024 ** 3)))
160 | print("Throughput: {} tokens/sec".format(token_num / throughput))
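    # note: despite its name, `throughput` accumulates wall-clock generation time; time for batches
    # whose outputs contain NaN/Inf (possible under fp16 sampling) is not counted, and the final
    # print reports total generated tokens divided by this accumulated time.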
161 |
162 |
--------------------------------------------------------------------------------
/figures/framework_v1.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AIoT-MLSys-Lab/SVD-LLM/1c1009a99bea54c22c27a4dd5d9fba7116dc456b/figures/framework_v1.jpg
--------------------------------------------------------------------------------
/figures/framework_v2.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AIoT-MLSys-Lab/SVD-LLM/1c1009a99bea54c22c27a4dd5d9fba7116dc456b/figures/framework_v2.jpg
--------------------------------------------------------------------------------
/figures/logo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AIoT-MLSys-Lab/SVD-LLM/1c1009a99bea54c22c27a4dd5d9fba7116dc456b/figures/logo.png
--------------------------------------------------------------------------------
/gptq/gptq.py:
--------------------------------------------------------------------------------
1 | import math
2 | import time
3 |
4 | import torch
5 | import torch.nn as nn
6 | import transformers
7 |
8 | import sys
9 | import os
10 | current_path = os.path.dirname(os.path.abspath(__file__))
11 | parent_path = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
12 | sys.path.append(current_path)
13 | from quant import *
14 |
15 |
16 | DEBUG = False
17 |
18 | torch.backends.cuda.matmul.allow_tf32 = False
19 | torch.backends.cudnn.allow_tf32 = False
20 |
21 |
22 | class GPTQ:
23 |
24 | def __init__(self, layer):
25 | self.layer = layer
26 | self.dev = self.layer.weight.device
27 | W = layer.weight.data.clone()
28 | if isinstance(self.layer, nn.Conv2d):
29 | W = W.flatten(1)
30 | if isinstance(self.layer, transformers.Conv1D):
31 | W = W.t()
32 | self.rows = W.shape[0]
33 | self.columns = W.shape[1]
34 | self.H = torch.zeros((self.columns, self.columns), device=self.dev)
35 | self.nsamples = 0
36 |
37 | def add_batch(self, inp, out):
38 | if DEBUG:
39 | self.inp1 = inp
40 | self.out1 = out
41 | if len(inp.shape) == 2:
42 | inp = inp.unsqueeze(0)
43 | tmp = inp.shape[0]
44 | if isinstance(self.layer, nn.Linear) or isinstance(self.layer, transformers.Conv1D):
45 | if len(inp.shape) == 3:
46 | inp = inp.reshape((-1, inp.shape[-1]))
47 | inp = inp.t()
48 | if isinstance(self.layer, nn.Conv2d):
49 | unfold = nn.Unfold(
50 | self.layer.kernel_size,
51 | dilation=self.layer.dilation,
52 | padding=self.layer.padding,
53 | stride=self.layer.stride
54 | )
55 | inp = unfold(inp)
56 | inp = inp.permute([1, 0, 2])
57 | inp = inp.flatten(1)
58 | self.H *= self.nsamples / (self.nsamples + tmp)
59 | self.nsamples += tmp
60 | # inp = inp.float()
61 | inp = math.sqrt(2 / self.nsamples) * inp.float()
62 | # self.H += 2 / self.nsamples * inp.matmul(inp.t())
63 | self.H += inp.matmul(inp.t())
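        # together with the rescaling of H above, this keeps H equal to the running average
        # (2 / nsamples) * sum_i x_i x_i^T: the old H is shrunk by n_old / (n_old + n_batch) and the
        # new batch enters as (sqrt(2/n) * X) @ (sqrt(2/n) * X)^T = (2 / n) * X @ X^T.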
64 |
65 | def fasterquant(
66 | self, blocksize=128, percdamp=.01, groupsize=-1, actorder=False, static_groups=False
67 | ):
68 | W = self.layer.weight.data.clone()
69 | if isinstance(self.layer, nn.Conv2d):
70 | W = W.flatten(1)
71 | if isinstance(self.layer, transformers.Conv1D):
72 | W = W.t()
73 | W = W.float()
74 |
75 | tick = time.time()
76 |
77 | if not self.quantizer.ready():
78 | self.quantizer.find_params(W, weight=True)
79 |
80 | H = self.H
81 | del self.H
82 | dead = torch.diag(H) == 0
83 | H[dead, dead] = 1
84 | W[:, dead] = 0
85 |
86 | if static_groups:
87 | import copy
88 | groups = []
89 | for i in range(0, self.columns, groupsize):
90 | quantizer = copy.deepcopy(self.quantizer)
91 | quantizer.find_params(W[:, i:(i + groupsize)], weight=True)
92 | groups.append(quantizer)
93 |
94 | if actorder:
95 | perm = torch.argsort(torch.diag(H), descending=True)
96 | W = W[:, perm]
97 | H = H[perm][:, perm]
98 | invperm = torch.argsort(perm)
99 |
100 | Losses = torch.zeros_like(W)
101 | Q = torch.zeros_like(W)
102 |
103 | damp = percdamp * torch.mean(torch.diag(H))
104 | diag = torch.arange(self.columns, device=self.dev)
105 | H[diag, diag] += damp
106 | H = torch.linalg.cholesky(H)
107 | H = torch.cholesky_inverse(H)
108 | H = torch.linalg.cholesky(H, upper=True)
109 | Hinv = H
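        # the three calls above damp the diagonal by percdamp * mean(diag(H)), invert H via
        # cholesky + cholesky_inverse, and take the upper Cholesky factor of H^{-1}; Hinv then
        # supplies the per-column scaling d = Hinv1[i, i] and the error-propagation rows used in
        # the block loop below.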
110 |
111 | for i1 in range(0, self.columns, blocksize):
112 | i2 = min(i1 + blocksize, self.columns)
113 | count = i2 - i1
114 |
115 | W1 = W[:, i1:i2].clone()
116 | Q1 = torch.zeros_like(W1)
117 | Err1 = torch.zeros_like(W1)
118 | Losses1 = torch.zeros_like(W1)
119 | Hinv1 = Hinv[i1:i2, i1:i2]
120 |
121 | for i in range(count):
122 | w = W1[:, i]
123 | d = Hinv1[i, i]
124 |
125 | if groupsize != -1:
126 | if not static_groups:
127 | if (i1 + i) % groupsize == 0:
128 | self.quantizer.find_params(W[:, (i1 + i):(i1 + i + groupsize)], weight=True)
129 | else:
130 | idx = i1 + i
131 | if actorder:
132 | idx = perm[idx]
133 | self.quantizer = groups[idx // groupsize]
134 |
135 | q = quantize(
136 | w.unsqueeze(1), self.quantizer.scale, self.quantizer.zero, self.quantizer.maxq
137 | ).flatten()
138 | Q1[:, i] = q
139 | Losses1[:, i] = (w - q) ** 2 / d ** 2
140 |
141 | err1 = (w - q) / d
142 | W1[:, i:] -= err1.unsqueeze(1).matmul(Hinv1[i, i:].unsqueeze(0))
143 | Err1[:, i] = err1
144 |
145 | Q[:, i1:i2] = Q1
146 | Losses[:, i1:i2] = Losses1 / 2
147 |
148 | W[:, i2:] -= Err1.matmul(Hinv[i1:i2, i2:])
149 |
150 | if DEBUG:
151 | self.layer.weight.data[:, :i2] = Q[:, :i2]
152 | self.layer.weight.data[:, i2:] = W[:, i2:]
153 | print(torch.sum((self.layer(self.inp1) - self.out1) ** 2))
154 | print(torch.sum(Losses))
155 |
156 | torch.cuda.synchronize()
157 | print('time %.2f' % (time.time() - tick))
158 | print('error', torch.sum(Losses).item())
159 |
160 | if actorder:
161 | Q = Q[:, invperm]
162 |
163 | if isinstance(self.layer, transformers.Conv1D):
164 | Q = Q.t()
165 | self.layer.weight.data = Q.reshape(self.layer.weight.shape).to(self.layer.weight.data.dtype)
166 | if DEBUG:
167 | print(torch.sum((self.layer(self.inp1) - self.out1) ** 2))
168 |
169 | def free(self):
170 | if DEBUG:
171 | self.inp1 = None
172 | self.out1 = None
173 | self.H = None
174 | self.Losses = None
175 | self.Trace = None
176 | torch.cuda.empty_cache()
177 |
--------------------------------------------------------------------------------
/gptq/quant.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn as nn
4 | import sys
5 | import os
6 | current_path = os.path.dirname(os.path.abspath(__file__))
7 | parent_path = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
8 | sys.path.append(current_path)
9 |
10 |
11 | def quantize(x, scale, zero, maxq):
12 | if maxq < 0:
13 | return (x > scale / 2).float() * scale + (x < zero / 2).float() * zero
14 | q = torch.clamp(torch.round(x / scale) + zero, 0, maxq)
15 | return scale * (q - zero)
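# affine fake-quantization: q = clamp(round(x / scale) + zero, 0, maxq), and the dequantized value
# scale * (q - zero) is returned. For example, with a 4-bit grid (maxq = 15), scale = 0.5 and
# zero = 8, x = 1.3 gives q = clamp(round(2.6) + 8, 0, 15) = 11, reconstructed as 0.5 * (11 - 8) = 1.5.
# The maxq < 0 branch handles the ternary ("trits") case configured in Quantizer.configure below.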
16 |
17 | class Quantizer(nn.Module):
18 |
19 | def __init__(self, shape=1):
20 | super(Quantizer, self).__init__()
21 | self.register_buffer('maxq', torch.tensor(0))
22 | self.register_buffer('scale', torch.zeros(shape))
23 | self.register_buffer('zero', torch.zeros(shape))
24 |
25 | def configure(
26 | self,
27 | bits, perchannel=False, sym=True,
28 | mse=False, norm=2.4, grid=100, maxshrink=.8,
29 | trits=False
30 | ):
31 | self.maxq = torch.tensor(2 ** bits - 1)
32 | self.perchannel = perchannel
33 | self.sym = sym
34 | self.mse = mse
35 | self.norm = norm
36 | self.grid = grid
37 | self.maxshrink = maxshrink
38 | if trits:
39 | self.maxq = torch.tensor(-1)
40 |
41 | def find_params(self, x, weight=False):
42 | dev = x.device
43 | self.maxq = self.maxq.to(dev)
44 |
45 | shape = x.shape
46 | if self.perchannel:
47 | if weight:
48 | x = x.flatten(1)
49 | else:
50 | if len(shape) == 4:
51 | x = x.permute([1, 0, 2, 3])
52 | x = x.flatten(1)
53 | if len(shape) == 3:
54 | x = x.reshape((-1, shape[-1])).t()
55 | if len(shape) == 2:
56 | x = x.t()
57 | else:
58 | x = x.flatten().unsqueeze(0)
59 |
60 | tmp = torch.zeros(x.shape[0], device=dev)
61 | xmin = torch.minimum(x.min(1)[0], tmp)
62 | xmax = torch.maximum(x.max(1)[0], tmp)
63 |
64 | if self.sym:
65 | xmax = torch.maximum(torch.abs(xmin), xmax)
66 | tmp = xmin < 0
67 | if torch.any(tmp):
68 | xmin[tmp] = -xmax[tmp]
69 | tmp = (xmin == 0) & (xmax == 0)
70 | xmin[tmp] = -1
71 | xmax[tmp] = +1
72 |
73 | if self.maxq < 0:
74 | self.scale = xmax
75 | self.zero = xmin
76 | else:
77 | self.scale = (xmax - xmin) / self.maxq
78 | if self.sym:
79 | self.zero = torch.full_like(self.scale, (self.maxq + 1) / 2)
80 | else:
81 | self.zero = torch.round(-xmin / self.scale)
82 |
83 | if self.mse:
84 | best = torch.full([x.shape[0]], float('inf'), device=dev)
85 | for i in range(int(self.maxshrink * self.grid)):
86 | p = 1 - i / self.grid
87 | xmin1 = p * xmin
88 | xmax1 = p * xmax
89 | scale1 = (xmax1 - xmin1) / self.maxq
90 | zero1 = torch.round(-xmin1 / scale1) if not self.sym else self.zero
91 | q = quantize(x, scale1.unsqueeze(1), zero1.unsqueeze(1), self.maxq)
92 | q -= x
93 | q.abs_()
94 | q.pow_(self.norm)
95 | err = torch.sum(q, 1)
96 | tmp = err < best
97 | if torch.any(tmp):
98 | best[tmp] = err[tmp]
99 | self.scale[tmp] = scale1[tmp]
100 | self.zero[tmp] = zero1[tmp]
101 | if not self.perchannel:
102 | if weight:
103 | tmp = shape[0]
104 | else:
105 | tmp = shape[1] if len(shape) != 3 else shape[2]
106 | self.scale = self.scale.repeat(tmp)
107 | self.zero = self.zero.repeat(tmp)
108 |
109 | if weight:
110 | shape = [-1] + [1] * (len(shape) - 1)
111 | self.scale = self.scale.reshape(shape)
112 | self.zero = self.zero.reshape(shape)
113 | return
114 | if len(shape) == 4:
115 | self.scale = self.scale.reshape((1, -1, 1, 1))
116 | self.zero = self.zero.reshape((1, -1, 1, 1))
117 | if len(shape) == 3:
118 | self.scale = self.scale.reshape((1, 1, -1))
119 | self.zero = self.zero.reshape((1, 1, -1))
120 | if len(shape) == 2:
121 | self.scale = self.scale.unsqueeze(0)
122 | self.zero = self.zero.unsqueeze(0)
123 |
124 | def quantize(self, x):
125 | if self.ready():
126 | return quantize(x, self.scale, self.zero, self.maxq)
127 | return x
128 |
129 | def enabled(self):
130 | return self.maxq > 0
131 |
132 | def ready(self):
133 | return torch.all(self.scale != 0)
134 |
135 |
136 | try:
137 | import quant_cuda
138 | except:
139 | print('CUDA extension not installed.')
140 |
141 | # Assumes layer is perfectly divisible into 1024 * 1024 blocks
142 | class Quant3Linear(nn.Module):
143 |
144 | def __init__(self, infeatures, outfeatures, faster=False):
145 | super().__init__()
146 | self.register_buffer('zeros', torch.zeros((outfeatures, 1)))
147 | self.register_buffer('scales', torch.zeros((outfeatures, 1)))
148 | self.register_buffer('bias', torch.zeros(outfeatures))
149 | self.register_buffer(
150 | 'qweight', torch.zeros((infeatures // 32 * 3, outfeatures), dtype=torch.int)
151 | )
152 | self.faster = faster
153 |
154 | def pack(self, linear, scales, zeros):
155 | self.zeros = zeros * scales
156 | self.scales = scales.clone()
157 | if linear.bias is not None:
158 | self.bias = linear.bias.clone()
159 |
160 | intweight = torch.round((linear.weight.data + self.zeros) / self.scales).to(torch.int)
161 | intweight = intweight.t().contiguous()
162 | intweight = intweight.numpy().astype(np.uint32)
163 | qweight = np.zeros(
164 | (intweight.shape[0] // 32 * 3, intweight.shape[1]), dtype=np.uint32
165 | )
166 | i = 0
167 | row = 0
168 | while row < qweight.shape[0]:
169 | for j in range(i, i + 10):
170 | qweight[row] |= intweight[j] << (3 * (j - i))
171 | i += 10
172 | qweight[row] |= intweight[i] << 30
173 | row += 1
174 | qweight[row] |= (intweight[i] >> 2) & 1
175 | i += 1
176 | for j in range(i, i + 10):
177 | qweight[row] |= intweight[j] << (3 * (j - i) + 1)
178 | i += 10
179 | qweight[row] |= intweight[i] << 31
180 | row += 1
181 | qweight[row] |= (intweight[i] >> 1) & 0x3
182 | i += 1
183 | for j in range(i, i + 10):
184 | qweight[row] |= intweight[j] << (3 * (j - i) + 2)
185 | i += 10
186 | row += 1
187 |
188 | qweight = qweight.astype(np.int32)
189 | self.qweight = torch.from_numpy(qweight)
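        # the loop above packs 3-bit weights into int32 words: every group of 32 quantized values
        # occupies 96 bits = three rows, with ten full values per row and two values split across
        # row boundaries (the `>> 2` / `>> 1` lines recover the carried bits).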
190 |
191 | def forward(self, x):
192 | if x.shape[-1] == x.numel():
193 | outshape = list(x.shape)
194 | y = self.bias.clone()
195 | outshape[-1] = self.bias.numel()
196 | dtype = x.dtype
197 | if self.faster:
198 | x = x.half()
199 | quant_cuda.vecquant3matmul_faster(x, self.qweight, y, self.scales, self.zeros)
200 | else:
201 | x = x.float()
202 | quant_cuda.vecquant3matmul(x, self.qweight, y, self.scales, self.zeros)
203 | y = y.to(dtype)
204 | return y.reshape(outshape)
205 | raise ValueError('Only supports a single token currently.')
206 |
207 | def make_quant3(module, names, name='', faster=False):
208 | if isinstance(module, Quant3Linear):
209 | return
210 | for attr in dir(module):
211 | tmp = getattr(module, attr)
212 | name1 = name + '.' + attr if name != '' else attr
213 | if name1 in names:
214 | setattr(
215 | module, attr, Quant3Linear(tmp.in_features, tmp.out_features, faster=faster)
216 | )
217 | for name1, child in module.named_children():
218 | make_quant3(child, names, name + '.' + name1 if name != '' else name1, faster=faster)
219 |
--------------------------------------------------------------------------------
/quant_llama.py:
--------------------------------------------------------------------------------
1 | import time
2 |
3 | import torch
4 | import torch.nn as nn
5 |
6 | from gptq.gptq import *
7 | from utils.model_utils import *
8 | from gptq.quant import *
9 | from evaluater import ppl_eval
10 |
11 | current_path = os.path.dirname(os.path.abspath(__file__))
12 | parent_path = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
13 | sys.path.append(current_path)
14 |
15 | @torch.no_grad()
16 | def llama_sequential(model, dataloader, dev):
17 | print('Starting ...')
18 |
19 | use_cache = model.config.use_cache
20 | model.config.use_cache = False
21 | layers = model.model.layers
22 |
23 | model.model.embed_tokens = model.model.embed_tokens.to(dev)
24 | model.model.norm = model.model.norm.to(dev)
25 | layers[0] = layers[0].to(dev)
26 |
27 | dtype = next(iter(model.parameters())).dtype
28 | inps = torch.zeros(
29 | (args.nsamples, model.seqlen, model.config.hidden_size), dtype=dtype, device=dev
30 | )
31 | cache = {'i': 0, 'attention_mask': None}
32 |
33 | class Catcher(nn.Module):
34 | def __init__(self, module):
35 | super().__init__()
36 | self.module = module
37 | def forward(self, inp, **kwargs):
38 | inps[cache['i']] = inp
39 | cache['i'] += 1
40 | cache['attention_mask'] = kwargs['attention_mask']
41 | cache['position_ids'] = kwargs['position_ids']
42 | raise ValueError
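    # Catcher deliberately aborts the forward pass with a ValueError once it has recorded the inputs,
    # attention mask and position ids reaching the first decoder layer; the try/except below swallows
    # the exception, so only the embedding layer needs to run while calibration inputs are collected
    # into `inps`.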
43 | layers[0] = Catcher(layers[0])
44 | for batch in dataloader:
45 | try:
46 | model(batch[0].to(dev))
47 | except ValueError:
48 | pass
49 | layers[0] = layers[0].module
50 |
51 | layers[0] = layers[0].cpu()
52 | model.model.embed_tokens = model.model.embed_tokens.cpu()
53 | model.model.norm = model.model.norm.cpu()
54 | torch.cuda.empty_cache()
55 |
56 | outs = torch.zeros_like(inps)
57 | attention_mask = cache['attention_mask']
58 | position_ids = cache['position_ids']
59 |
60 | print('Ready.')
61 |
62 | quantizers = {}
63 | for i in range(len(layers)):
64 | layer = layers[i].to(dev)
65 | full = find_layers(layer)
66 |
67 | if args.true_sequential:
68 | sequential = [
69 | ['self_attn.k_u_proj','self_attn.k_v_proj', 'self_attn.v_u_proj', 'self_attn.v_v_proj', 'self_attn.q_u_proj', 'self_attn.q_v_proj'],
70 | ['self_attn.o_u_proj', 'self_attn.o_v_proj'],
71 | ['mlp.up_u_proj', 'mlp.up_v_proj', 'mlp.gate_u_proj', 'mlp.gate_v_proj'],
72 | ['mlp.down_u_proj', 'mlp.down_v_proj']
73 | ]
74 | else:
75 | sequential = [list(full.keys())]
76 |
77 | for names in sequential:
78 | subset = {n: full[n] for n in names}
79 |
80 | gptq = {}
81 | for name in subset:
82 | gptq[name] = GPTQ(subset[name])
83 | gptq[name].quantizer = Quantizer()
84 | gptq[name].quantizer.configure(
85 | args.wbits, perchannel=True, sym=args.sym, mse=False
86 | )
87 |
88 | def add_batch(name):
89 | def tmp(_, inp, out):
90 | gptq[name].add_batch(inp[0].data, out.data)
91 | return tmp
92 | handles = []
93 | for name in subset:
94 | handles.append(subset[name].register_forward_hook(add_batch(name)))
95 | for j in range(args.nsamples):
96 | outs[j] = layer(inps[j].unsqueeze(0), attention_mask=attention_mask, position_ids=position_ids)[0]
97 | for h in handles:
98 | h.remove()
99 |
100 | for name in subset:
101 | print(i, name)
102 | print('Quantizing ...')
103 | gptq[name].fasterquant(
104 | percdamp=args.percdamp, groupsize=args.groupsize, actorder=args.act_order, static_groups=args.static_groups
105 | )
106 | quantizers['model.layers.%d.%s' % (i, name)] = gptq[name].quantizer
107 | gptq[name].free()
108 |
109 | for j in range(args.nsamples):
110 | outs[j] = layer(inps[j].unsqueeze(0), attention_mask=attention_mask, position_ids=position_ids)[0]
111 |
112 | layers[i] = layer.cpu()
113 | del layer
114 | del gptq
115 | torch.cuda.empty_cache()
116 |
117 | inps, outs = outs, inps
118 |
119 | model.config.use_cache = use_cache
120 |
121 | return quantizers
122 |
123 | @torch.no_grad()
124 | def llama_eval(model, testenc, dev):
125 | print('Evaluating ...')
126 |
127 | testenc = testenc.input_ids
128 | nsamples = testenc.numel() // model.seqlen
129 |
130 | use_cache = model.config.use_cache
131 | model.config.use_cache = False
132 | layers = model.model.layers
133 |
134 | model.model.embed_tokens = model.model.embed_tokens.to(dev)
135 | layers[0] = layers[0].to(dev)
136 |
137 | dtype = next(iter(model.parameters())).dtype
138 | inps = torch.zeros(
139 | (nsamples, model.seqlen, model.config.hidden_size), dtype=dtype, device=dev
140 | )
141 | cache = {'i': 0, 'attention_mask': None}
142 |
143 | class Catcher(nn.Module):
144 | def __init__(self, module):
145 | super().__init__()
146 | self.module = module
147 | def forward(self, inp, **kwargs):
148 | inps[cache['i']] = inp
149 | cache['i'] += 1
150 | cache['attention_mask'] = kwargs['attention_mask']
151 | cache['position_ids'] = kwargs['position_ids']
152 | raise ValueError
153 | layers[0] = Catcher(layers[0])
154 | for i in range(nsamples):
155 | batch = testenc[:, (i * model.seqlen):((i + 1) * model.seqlen)].to(dev)
156 | try:
157 | model(batch)
158 | except ValueError:
159 | pass
160 | layers[0] = layers[0].module
161 |
162 | layers[0] = layers[0].cpu()
163 | model.model.embed_tokens = model.model.embed_tokens.cpu()
164 | torch.cuda.empty_cache()
165 |
166 | outs = torch.zeros_like(inps)
167 | attention_mask = cache['attention_mask']
168 | position_ids = cache['position_ids']
169 |
170 | for i in range(len(layers)):
171 | print(i)
172 | layer = layers[i].to(dev)
173 |
174 | if args.nearest:
175 | subset = find_layers(layer)
176 | for name in subset:
177 | quantizer = Quantizer()
178 | quantizer.configure(
179 | args.wbits, perchannel=True, sym=False, mse=False
180 | )
181 | W = subset[name].weight.data
182 | quantizer.find_params(W, weight=True)
183 | subset[name].weight.data = quantize(
184 | W, quantizer.scale, quantizer.zero, quantizer.maxq
185 | ).to(next(iter(layer.parameters())).dtype)
186 |
187 | for j in range(nsamples):
188 | outs[j] = layer(inps[j].unsqueeze(0), attention_mask=attention_mask, position_ids=position_ids)[0]
189 | layers[i] = layer.cpu()
190 | del layer
191 | torch.cuda.empty_cache()
192 | inps, outs = outs, inps
193 |
194 | if model.model.norm is not None:
195 | model.model.norm = model.model.norm.to(dev)
196 | model.lm_head = model.lm_head.to(dev)
197 |
198 | testenc = testenc.to(dev)
199 | nlls = []
200 | for i in range(nsamples):
201 | hidden_states = inps[i].unsqueeze(0)
202 | if model.model.norm is not None:
203 | hidden_states = model.model.norm(hidden_states)
204 | lm_logits = model.lm_head(hidden_states)
205 | shift_logits = lm_logits[:, :-1, :].contiguous()
206 | shift_labels = testenc[
207 | :, (i * model.seqlen):((i + 1) * model.seqlen)
208 | ][:, 1:]
209 | loss_fct = nn.CrossEntropyLoss()
210 | loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
211 | neg_log_likelihood = loss.float() * model.seqlen
212 | nlls.append(neg_log_likelihood)
213 | ppl = torch.exp(torch.stack(nlls).sum() / (nsamples * model.seqlen))
214 | print(ppl.item())
215 |
216 | model.config.use_cache = use_cache
217 |
218 | def llama_pack3(model, quantizers):
219 | layers = find_layers(model)
220 | layers = {n: layers[n] for n in quantizers}
221 | make_quant3(model, quantizers)
222 | qlayers = find_layers(model, [Quant3Linear])
223 | print('Packing ...')
224 | for name in qlayers:
225 | print(name)
226 | quantizers[name] = quantizers[name].cpu()
227 | qlayers[name].pack(layers[name], quantizers[name].scale, quantizers[name].zero)
228 | print('Done.')
229 | return model
230 |
231 |
232 | if __name__ == '__main__':
233 | import argparse
234 | from utils.data_utils import *
235 |
236 | parser = argparse.ArgumentParser()
237 |
238 | parser.add_argument(
239 | '--model_path', type=str,
240 | help='path of the compressed model.'
241 | )
242 | parser.add_argument(
243 | '--dataset', type=str, choices=['wikitext2', 'ptb', 'c4'],
244 | help='Where to extract calibration data from.'
245 | )
246 | parser.add_argument(
247 | '--seed',
248 | type=int, default=0, help='Seed for sampling the calibration data.'
249 | )
250 | parser.add_argument(
251 | '--nsamples', type=int, default=128,
252 | help='Number of calibration data samples.'
253 | )
254 | parser.add_argument(
255 | '--percdamp', type=float, default=.01,
256 | help='Percent of the average Hessian diagonal to use for dampening.'
257 | )
258 | parser.add_argument(
259 | '--nearest', action='store_true',
260 | help='Whether to run the RTN baseline.'
261 | )
262 | parser.add_argument(
263 | '--wbits', type=int, default=16, choices=[2, 3, 4, 8, 16],
264 | help='#bits to use for quantization; use 16 for evaluating base model.'
265 | )
266 | parser.add_argument(
267 | '--groupsize', type=int, default=-1,
268 | help='Groupsize to use for quantization; default uses full row.'
269 | )
270 | parser.add_argument(
271 | '--sym', action='store_true',
272 | help='Whether to perform symmetric quantization.'
273 | )
274 | parser.add_argument(
275 | '--save', type=str, default='',
276 | help='Save quantized checkpoint under this name.'
277 | )
278 | parser.add_argument(
279 | '--new-eval', action='store_true',
280 | help='Whether to use the new PTB and C4 eval.'
281 | )
282 | parser.add_argument(
283 | '--act-order', action='store_true',
284 | help='Whether to apply the activation order GPTQ heuristic'
285 | )
286 | parser.add_argument(
287 | '--true-sequential', action='store_true',
288 |     help='Whether to run in true sequential mode.'
289 | )
290 | parser.add_argument(
291 | '--static-groups', action='store_true',
292 |     help='Whether to use static groups; recommended when using `--act-order` for more efficient inference.'
293 | )
294 | parser.add_argument('--DEV', type=str, default="cuda", help='device')
295 |
296 | args = parser.parse_args()
297 |
298 | model, tokenizer = get_model_from_local(args.model_path)
299 | model.eval()
300 |
301 | dataloader, testloader = get_loaders(args.dataset, nsamples=args.nsamples, seed=args.seed, tokenizer=tokenizer)
302 |
303 | if args.wbits < 16 and not args.nearest:
304 | tick = time.time()
305 | quantizers = llama_sequential(model, dataloader, args.DEV)
306 | print(time.time() - tick)
307 | # if args.save:
308 | # llama_pack3(model, quantizers)
309 | # torch.save(model.state_dict(), args.save)
310 | ppl_eval(model, tokenizer, datasets=['wikitext2'], model_seq_len=2048, batch_size=16, device=args.DEV)
311 | torch.save({
312 | 'model': model,
313 | 'tokenizer': tokenizer
314 | }, args.save)
315 |
316 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | datasets==2.16.1
2 | numpy==1.26.3
3 | torch>=2.0.1
4 | tqdm==4.65.0
5 | transformers==4.35.2
6 | sentencepiece==0.1.99
7 | matplotlib==3.4.3
8 | evaluate
9 | accelerate
--------------------------------------------------------------------------------
/svdllm_gptq.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | # run data whitening with 20% compression ratio
4 | python SVDLLM.py --model jeffwan/llama-7b-hf --step 1 --ratio 0.2 --whitening_nsamples 256 --dataset wikitext2 --seed 3 --model_seq_len 2048 --save_path .
5 |
6 | # further compress the model with GPTQ-4bit
7 | python quant_llama.py --model_path whitening/jeffwan_llama_7b_hf_whitening_0.2.pt --dataset c4 --wbits 4 --true-sequential --act-order --new-eval --save svdllm_gptq_4.pt
--------------------------------------------------------------------------------
/utils/LoRA.py:
--------------------------------------------------------------------------------
1 | '''
2 | Refer to
3 | https://github.com/tloen/alpaca-lora/blob/main/finetune.py
4 | '''
5 |
6 | import os
7 | import sys
8 | import argparse
9 | from typing import List
10 |
11 | import torch
12 | import transformers
13 | from datasets import load_dataset
14 |
15 | parent_path = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
16 | sys.path.append(parent_path)
17 | from peft import (
18 | LoraConfig,
19 | get_peft_model,
20 | get_peft_model_state_dict,
21 | prepare_model_for_int8_training,
22 | set_peft_model_state_dict,
23 | )
24 | from Prompter import Prompter, ZeroPrompter
25 |
26 | device = "cuda" if torch.cuda.is_available() else "cpu"
27 |
28 | def wikitext2():
29 | traindata = load_dataset('wikitext', 'wikitext-2-raw-v1', split='train')
30 | testdata = load_dataset('wikitext', 'wikitext-2-raw-v1', split='test')
31 | return traindata, testdata
32 |
33 | def ptb():
34 | traindata = load_dataset('ptb_text_only', 'penn_treebank', split='train')
35 | valdata = load_dataset('ptb_text_only', 'penn_treebank', split='validation')
36 | return traindata, valdata
37 |
38 | def apply_lora(model, tokenizer, batch_size=64, micro_batch_size=4, cutoff_len=256, add_eos_token=False,
39 | lora_r=2, lora_alpha=16, lora_target_modules="q_proj,k_proj,v_proj,o_proj,gate_proj,down_proj,up_proj",
40 | lora_dropout=0.05, val_set_size=2000, data_path="yahma/alpaca-cleaned",num_epochs=2, learning_rate=1e-4,
41 | output_dir="Checkpoints/tune", group_by_length=False, extra_val_dataset=None):
42 |
43 | gradient_accumulation_steps = batch_size // micro_batch_size
44 | prompter = ZeroPrompter()
45 |
46 | if device == 'cuda':
47 | model.half()
48 |
49 | tokenizer.pad_token_id = 0
50 | tokenizer.padding_side = "left"
51 |
52 | def tokenize(prompt, add_eos_token=True):
53 | result = tokenizer(
54 | prompt,
55 | truncation=True,
56 | max_length=cutoff_len,
57 | padding=False,
58 | return_tensors=None,
59 | )
60 | if (
61 | result["input_ids"][-1] != tokenizer.eos_token_id
62 | and len(result["input_ids"]) < cutoff_len
63 | and add_eos_token
64 | ):
65 | result["input_ids"].append(tokenizer.eos_token_id)
66 | result["attention_mask"].append(1)
67 |
68 | result["labels"] = result["input_ids"].copy()
69 |
70 | return result
71 |
72 | def generate_and_tokenize_prompt(data_point):
73 | full_prompt = prompter.generate_prompt(
74 | data_point["instruction"],
75 | data_point["input"],
76 | data_point["output"],
77 | )
78 | tokenized_full_prompt = tokenize(full_prompt)
79 | user_prompt = prompter.generate_prompt(
80 | data_point["instruction"], data_point["input"]
81 | )
82 | tokenized_user_prompt = tokenize(
83 | user_prompt, add_eos_token=add_eos_token
84 | )
85 | user_prompt_len = len(tokenized_user_prompt["input_ids"])
86 |
87 | if add_eos_token:
88 | user_prompt_len -= 1
89 |
90 | tokenized_full_prompt["labels"] = [
91 | -100
92 | ] * user_prompt_len + tokenized_full_prompt["labels"][
93 | user_prompt_len:
94 | ] # could be sped up, probably
95 | return tokenized_full_prompt
96 |
97 | def split_and_tokenizer(test_data, tokenizer, seq_len, field_name):
98 | test_ids = tokenizer("\n\n".join(test_data[field_name]), return_tensors='pt').input_ids[0]
99 | test_ids_batch = []
100 | nsamples = test_ids.numel() // seq_len
101 |
102 | test_set = []
103 | for i in range(nsamples):
104 | batch = test_ids[(i * seq_len):((i + 1) * seq_len)]
105 | test_set.append({
106 | 'input_ids': batch,
107 | 'labels': batch
108 | })
109 | return test_set
110 |
111 | # Prepare For LoRA
112 | model = prepare_model_for_int8_training(model)
113 | config = LoraConfig(
114 | r=lora_r,
115 | lora_alpha=lora_alpha,
116 | target_modules=lora_target_modules.split(","),
117 | lora_dropout=lora_dropout,
118 | bias="none",
119 | task_type="CAUSAL_LM",
120 | )
121 | model = get_peft_model(model, config)
122 | model.print_trainable_parameters()
123 |
124 | # Load Train Dataset
125 | data = load_dataset(data_path)
126 | train_val = data["train"].train_test_split(
127 | test_size=val_set_size, shuffle=True, seed=42
128 | )
129 | train_data = (
130 | train_val["train"].shuffle().map(generate_and_tokenize_prompt)
131 | )
132 | val_data = {
133 | data_path: train_val["test"].shuffle().map(generate_and_tokenize_prompt),
134 | }
135 |
136 | # Load Extra Validation Dataset
137 | if extra_val_dataset:
138 | seq_len = 128
139 | for extra_dataset in extra_val_dataset.split(','):
140 | if 'wikitext2' in extra_dataset:
141 | _, test_data = wikitext2()
142 | test_data = split_and_tokenizer(test_data, tokenizer, seq_len, field_name='text')
143 | if 'ptb' in extra_dataset:
144 | _, test_data = ptb()
145 | test_data = split_and_tokenizer(test_data, tokenizer, seq_len, field_name='sentence')
146 | val_data[extra_dataset] = test_data
147 |
148 | trainer = transformers.Trainer(
149 | model=model,
150 | train_dataset=train_data,
151 | eval_dataset=val_data,
152 | args=transformers.TrainingArguments(
153 | per_device_train_batch_size=micro_batch_size,
154 | gradient_accumulation_steps=gradient_accumulation_steps,
155 | warmup_steps=100,
156 | num_train_epochs=num_epochs,
157 | learning_rate=learning_rate,
158 | fp16=True,
159 | logging_steps=10,
160 | logging_first_step=True,
161 | optim="adamw_torch",
162 | evaluation_strategy="steps",
163 | save_strategy="steps",
164 | eval_steps=100,
165 | save_steps=200,
166 | output_dir=output_dir,
167 | save_total_limit=30,
168 | load_best_model_at_end=True,
169 | ddp_find_unused_parameters=None,
170 | group_by_length=group_by_length,
171 | report_to="none",
172 | run_name="none",
173 | metric_for_best_model="{}_loss".format(data_path),
174 | ),
175 | data_collator=transformers.DataCollatorForSeq2Seq(
176 | tokenizer, pad_to_multiple_of=8, return_tensors="pt", padding=True
177 | ),
178 | )
179 | model.config.use_cache = False
180 | old_state_dict = model.state_dict
181 | model.state_dict = (
182 | lambda self, *_, **__: get_peft_model_state_dict(
183 | self, old_state_dict()
184 | )
185 | ).__get__(model, type(model))
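    # state_dict is temporarily swapped so that checkpoints saved by the Trainer contain only the
    # LoRA adapter weights (via get_peft_model_state_dict) rather than the full model; the original
    # method is restored right after training.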
186 |
187 | trainer.train()
188 | model.state_dict = old_state_dict
189 | return model
190 |
191 | def main(args):
192 |     # Set WandB
193 | os.environ["WANDB_PROJECT"] = args.wandb_project
194 |
195 | # Load Pruned Model
196 | pruned_dict = torch.load(args.prune_model, map_location='cpu')
197 | tokenizer, model = pruned_dict['tokenizer'], pruned_dict['model']
198 | gradient_accumulation_steps = args.batch_size // args.micro_batch_size
199 | if not args.no_instruction:
200 | prompter = Prompter(args.prompt_template_name)
201 | else:
202 | prompter = ZeroPrompter()
203 |
204 | if device == 'cuda':
205 | model.half()
206 |
207 | tokenizer.pad_token_id = 0
208 | tokenizer.padding_side = "left"
209 |
210 | def tokenize(prompt, add_eos_token=True):
211 | result = tokenizer(
212 | prompt,
213 | truncation=True,
214 | max_length=args.cutoff_len,
215 | padding=False,
216 | return_tensors=None,
217 | )
218 | if (
219 | result["input_ids"][-1] != tokenizer.eos_token_id
220 | and len(result["input_ids"]) < args.cutoff_len
221 | and add_eos_token
222 | ):
223 | result["input_ids"].append(tokenizer.eos_token_id)
224 | result["attention_mask"].append(1)
225 |
226 | result["labels"] = result["input_ids"].copy()
227 |
228 | return result
229 |
230 | def generate_and_tokenize_prompt(data_point):
231 | full_prompt = prompter.generate_prompt(
232 | data_point["instruction"],
233 | data_point["input"],
234 | data_point["output"],
235 | )
236 | tokenized_full_prompt = tokenize(full_prompt)
237 | if not args.train_on_inputs:
238 | user_prompt = prompter.generate_prompt(
239 | data_point["instruction"], data_point["input"]
240 | )
241 | tokenized_user_prompt = tokenize(
242 | user_prompt, add_eos_token=args.add_eos_token
243 | )
244 | user_prompt_len = len(tokenized_user_prompt["input_ids"])
245 |
246 | if args.add_eos_token:
247 | user_prompt_len -= 1
248 |
249 | tokenized_full_prompt["labels"] = [
250 | -100
251 | ] * user_prompt_len + tokenized_full_prompt["labels"][
252 | user_prompt_len:
253 | ] # could be sped up, probably
254 | return tokenized_full_prompt
255 |
256 | def split_and_tokenizer(test_data, tokenizer, seq_len, field_name):
257 | test_ids = tokenizer("\n\n".join(test_data[field_name]), return_tensors='pt').input_ids[0]
258 | test_ids_batch = []
259 | nsamples = test_ids.numel() // seq_len
260 |
261 | test_set = []
262 | for i in range(nsamples):
263 | batch = test_ids[(i * seq_len):((i + 1) * seq_len)]
264 | test_set.append({
265 | 'input_ids': batch,
266 | 'labels': batch
267 | })
268 | return test_set
269 |
270 | # Prepare For LoRA
271 | model = prepare_model_for_int8_training(model)
272 | config = LoraConfig(
273 | r=args.lora_r,
274 | lora_alpha=args.lora_alpha,
275 | target_modules=args.lora_target_modules.split(","),
276 | lora_dropout=args.lora_dropout,
277 | bias="none",
278 | task_type="CAUSAL_LM",
279 | )
280 | model = get_peft_model(model, config)
281 | model.print_trainable_parameters()
282 |
283 | # Load Train Dataset
284 | data = load_dataset(args.data_path)
285 | train_val = data["train"].train_test_split(
286 | test_size=args.val_set_size, shuffle=True, seed=42
287 | )
288 | train_data = (
289 | train_val["train"].shuffle().map(generate_and_tokenize_prompt)
290 | )
291 | val_data = {
292 | args.data_path: train_val["test"].shuffle().map(generate_and_tokenize_prompt),
293 | }
294 |
295 | # Load Extra Validation Dataset
296 | if args.extra_val_dataset:
297 | seq_len = 128
298 | for extra_dataset in args.extra_val_dataset.split(','):
299 | if 'wikitext2' in extra_dataset:
300 | _, test_data = wikitext2()
301 | test_data = split_and_tokenizer(test_data, tokenizer, seq_len, field_name='text')
302 | if 'ptb' in extra_dataset:
303 | _, test_data = ptb()
304 | test_data = split_and_tokenizer(test_data, tokenizer, seq_len, field_name='sentence')
305 | val_data[extra_dataset] = test_data
306 |
307 | trainer = transformers.Trainer(
308 | model=model,
309 | train_dataset=train_data,
310 | eval_dataset=val_data,
311 | args=transformers.TrainingArguments(
312 | per_device_train_batch_size=args.micro_batch_size,
313 | gradient_accumulation_steps=gradient_accumulation_steps,
314 | warmup_steps=100,
315 | num_train_epochs=args.num_epochs,
316 | learning_rate=args.learning_rate,
317 | fp16=True,
318 | logging_steps=10,
319 | logging_first_step=True,
320 | optim="adamw_torch",
321 | evaluation_strategy="steps",
322 | save_strategy="steps",
323 | save_safetensors=False,
324 | eval_steps=100,
325 | save_steps=200,
326 | output_dir=args.output_dir,
327 | save_total_limit=20,
328 | load_best_model_at_end=True,
329 | ddp_find_unused_parameters=None,
330 | group_by_length=args.group_by_length,
331 | report_to="none",
332 | run_name="none",
333 | metric_for_best_model="{}_loss".format(args.data_path),
334 | ),
335 | data_collator=transformers.DataCollatorForSeq2Seq(
336 | tokenizer, pad_to_multiple_of=8, return_tensors="pt", padding=True
337 | ),
338 | )
339 | model.config.use_cache = False
340 | old_state_dict = model.state_dict
341 | model.state_dict = (
342 | lambda self, *_, **__: get_peft_model_state_dict(
343 | self, old_state_dict()
344 | )
345 | ).__get__(model, type(model))
346 |
347 | trainer.train(resume_from_checkpoint=args.resume_from_checkpoint)
348 |
349 | model.state_dict = old_state_dict
350 | model.save_pretrained(args.output_dir, safe_serialization=False)
351 |
352 |
353 | if __name__ == "__main__":
354 | parser = argparse.ArgumentParser(description='Tuning Pruned LLM')
355 |
356 | # Model Type&Path
357 | parser.add_argument('--base_model', type=str, default="decapoda-research/llama-7b-hf", help='base model name')
358 | parser.add_argument('--prune_model', type=str, help='prune model name')
359 | parser.add_argument('--data_path', type=str, default="yahma/alpaca-cleaned", help='data path')
360 | # parser.add_argument('--extra_val_dataset', type=str, default='wikitext2,ptb', help='validation datasets. Split with ","')
361 | parser.add_argument('--extra_val_dataset', type=str, default=None, help='validation datasets. Split with ","')
362 | parser.add_argument('--output_dir', type=str, default="./lora-alpaca", help='output directory')
363 |
364 | # Training Hyperparameters
365 | parser.add_argument('--batch_size', type=int, default=128, help='batch size')
366 | parser.add_argument('--micro_batch_size', type=int, default=4, help='micro batch size')
367 | parser.add_argument('--num_epochs', type=int, default=5, help='number of epochs')
368 | parser.add_argument('--learning_rate', type=float, default=3e-4, help='learning rate')
369 | parser.add_argument('--cutoff_len', type=int, default=256, help='cutoff length')
370 | parser.add_argument('--val_set_size', type=int, default=2000, help='validation set size')
371 | parser.add_argument('--prompt_template_name', type=str, default="alpaca", help="The prompt template to use, will default to alpaca.")
372 | parser.add_argument('--no_instruction', action='store_true', default=False, help="Whether to use the instruction template or not.")
373 |
374 | # Lora Configuration
375 | parser.add_argument('--lora_r', type=int, default=8, help='lora r')
376 | parser.add_argument('--lora_alpha', type=int, default=16, help='lora alpha')
377 | parser.add_argument('--lora_dropout', type=float, default=0.05, help='lora dropout')
378 | parser.add_argument('--lora_target_modules', type=str, default="q_v_proj,q_u_proj,k_v_proj,k_u_proj,v_u_proj,v_v_proj,o_u_proj,o_v_proj,gate_u_proj,gate_v_proj,down_u_proj,down_v_proj,up_u_proj,up_v_proj", help='lora target modules')
379 |
380 | # llm hyperparameters
381 | parser.add_argument('--train_on_inputs', default=False, action="store_true", help='Train on inputs. If False, masks out inputs in loss')
382 | parser.add_argument('--add_eos_token', default=False, action="store_true")
383 | parser.add_argument('--group_by_length', default=False, action="store_true", help="faster, but produces an odd training loss curve")
384 |
385 | # wandb params
386 | parser.add_argument('--wandb_project', type=str, default="")
387 | parser.add_argument('--resume_from_checkpoint', type=str, help="either training checkpoint or final adapter")
388 |
389 | args = parser.parse_args()
390 | torch_version = int(torch.__version__.split('.')[1])
391 | args.torch_version = torch_version
392 |
393 | main(args)
--------------------------------------------------------------------------------
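
A minimal sketch, not part of the repository, of the checkpoint layout that `--prune_model` expects: `torch.load` must return a dict carrying both a `tokenizer` and a `model` object, exactly as the script unpacks them in `main`. The model id and output path below are placeholders.

```py
# Sketch only: package a compressed model and its tokenizer into the dict
# layout that the tuning script later unpacks via torch.load(...)['tokenizer'/'model'].
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "path/or/hub-id-of-a-compressed-model"   # placeholder
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16)
tokenizer = AutoTokenizer.from_pretrained(model_id)

torch.save({"tokenizer": tokenizer, "model": model}, "compressed_model.pt")
```
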
/utils/Prompter.py:
--------------------------------------------------------------------------------
1 | """
2 | A dedicated helper to manage templates and prompt building.
3 | """
4 |
5 | import json
6 | import os.path as osp
7 | from typing import Union
8 |
9 | alpaca_template = {
10 | "description": "Template used by Alpaca-LoRA.",
11 | "prompt_input": "Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\n\n### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n",
12 | "prompt_no_input": "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n### Instruction:\n{instruction}\n\n### Response:\n",
13 | "response_split": "### Response:"
14 | }
15 |
16 | class Prompter(object):
17 | __slots__ = ("template", "_verbose")
18 |
19 | def __init__(self, template_name: str = "", verbose: bool = False):
20 | self._verbose = verbose
21 | if not template_name or template_name == 'alpaca':
22 | self.template = alpaca_template
23 | if self._verbose:
24 | print(
25 | f"Using prompt template {template_name}: {self.template['description']}"
26 | )
27 |
28 | def generate_prompt(
29 | self,
30 | instruction: str,
31 | input: Union[None, str] = None,
32 | label: Union[None, str] = None,
33 | ) -> str:
34 | # returns the full prompt from instruction and optional input
35 | # if a label (=response, =output) is provided, it's also appended.
36 | if input:
37 | res = self.template["prompt_input"].format(
38 | instruction=instruction, input=input
39 | )
40 | else:
41 | res = self.template["prompt_no_input"].format(
42 | instruction=instruction
43 | )
44 | if label:
45 | res = f"{res}{label}"
46 | if self._verbose:
47 | print(res)
48 | return res
49 |
50 | def get_response(self, output: str) -> str:
51 | return output.split(self.template["response_split"])[1].strip()
52 |
53 |
54 | class ZeroPrompter(object):
55 | __slots__ = ("_verbose")
56 |
57 | def __init__(self, verbose: bool = False):
58 | self._verbose = verbose
59 |
60 | if self._verbose:
61 | print(
62 |                 "Not using any prompt template."
63 | )
64 |
65 | def generate_prompt(
66 | self,
67 | instruction: str,
68 | input: Union[None, str] = None,
69 | label: Union[None, str] = None,
70 | ) -> str:
71 | # returns the full prompt from instruction and optional input
72 | # if a label (=response, =output) is provided, it's also appended.
73 | if instruction[-1] == '.':
74 | instruction = instruction[:-1] + ':'
75 | if instruction[-1] not in ['.', ':', '?', '!']:
76 | instruction = instruction + ':'
77 | instruction += ' '
78 |
79 | if input:
80 | if input[-1] not in ['.', ':', '?', '!']:
81 | input = input + '.'
82 | res = instruction + input
83 | else:
84 | res = instruction
85 | if label:
86 | res = f"{res} {label}"
87 | if self._verbose:
88 | print(res)
89 | return res
90 |
91 | def get_response(self, output: str) -> str:
92 | return output.strip()
93 |
--------------------------------------------------------------------------------
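
A brief usage sketch for the two prompter classes above, assuming the repository root is on `sys.path`; the instruction and input strings are arbitrary examples.

```py
# Sketch only: build an Alpaca-style prompt, then strip the model's answer afterwards.
from utils.Prompter import Prompter, ZeroPrompter

prompter = Prompter("alpaca")
prompt = prompter.generate_prompt(
    instruction="Summarize the following text.",
    input="LoRA adds trainable low-rank adapters to frozen weight matrices.",
)
# Given the model's raw completion, everything after "### Response:" is the answer:
# answer = prompter.get_response(model_output)

# The template-free variant simply normalizes punctuation and concatenates the fields.
zero = ZeroPrompter()
print(zero.generate_prompt("Summarize the following text.", "LoRA adds low-rank adapters."))
```
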
/utils/data_utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import random
3 | import torch
4 | import sys
5 | from datasets import load_dataset
6 | from torch.utils.data.dataset import Dataset
7 |
8 | current_path = os.path.dirname(os.path.abspath(__file__))
9 | parent_path = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
10 | sys.path.append(current_path)
11 |
12 | def get_calib_train_data(name, tokenizer, nsamples, seqlen=2048, seed=3, batch_size=1, dataset_cache_dir=None):
13 | import random
14 | random.seed(seed)
15 | cache_file = (
16 | f"cache/{name}_{nsamples}_{seqlen}_{seed}_{batch_size}.pt"
17 | )
18 |     nsamples += 1  # one extra draw so the final batch gets appended inside the loop below
19 | if not os.path.exists("cache"):
20 | os.makedirs("cache")
21 | if os.path.exists(cache_file):
22 | traindataset = torch.load(cache_file)
23 | return traindataset
24 | if name == "c4":
25 | traindata = load_dataset("json", data_files="utils/c4-train.json")['train']
26 | tot_text = "\n\n".join(traindata["text"])
27 | elif name == "ptb":
28 | traindata = load_dataset('ptb_text_only', 'penn_treebank', split='train', cache_dir=dataset_cache_dir)
29 | tot_text = "\n\n".join(traindata["sentence"])
30 | elif name == "wikitext2":
31 | traindata = load_dataset("wikitext", "wikitext-2-raw-v1", split="train", cache_dir=dataset_cache_dir)
32 | tot_text = "\n\n".join(traindata["text"])
33 | else:
34 | raise NotImplementedError
35 | traindataset = []
36 | for s in range(nsamples):
37 | i = random.randint(0, len(tot_text) - seqlen - 1)
38 | j = i + seqlen * 10
39 | trainenc = tokenizer(tot_text[i:j], return_tensors="pt")
40 | if trainenc.input_ids.shape[1] < seqlen:
41 |             s = s - 1  # note: reassigning the loop variable has no effect; a too-short draw is simply skipped
42 | continue
43 | if s % batch_size == 0:
44 | if s != 0:
45 | attention_mask = torch.ones_like(inp)
46 | traindataset.append({"input_ids": inp, "attention_mask": attention_mask})
47 | inp = trainenc.input_ids[:, :seqlen]
48 | else:
49 | inp = torch.cat((inp, trainenc.input_ids[:, :seqlen]), dim=0)
50 | torch.save(traindataset, cache_file)
51 | return traindataset
52 |
53 |
54 |
55 | def get_wikitext2(nsamples, seed, seqlen, tokenizer, dataset_cache_dir=None):
56 | traindata = load_dataset('wikitext', 'wikitext-2-raw-v1', split='train', cache_dir=dataset_cache_dir)
57 | testdata = load_dataset('wikitext', 'wikitext-2-raw-v1', split='test', cache_dir=dataset_cache_dir)
58 |
59 | trainenc = tokenizer("\n\n".join(traindata['text']), return_tensors='pt')
60 | testenc = tokenizer("\n\n".join(testdata['text']), return_tensors='pt')
61 |
62 | import random
63 | random.seed(seed)
64 | trainloader = []
65 | for _ in range(nsamples):
66 | i = random.randint(0, trainenc.input_ids.shape[1] - seqlen - 1)
67 | j = i + seqlen
68 | inp = trainenc.input_ids[:, i:j]
69 | tar = inp.clone()
70 | tar[:, :-1] = -100
71 | trainloader.append((inp, tar))
72 | return trainloader, testenc
73 |
74 | def get_ptb(nsamples, seed, seqlen, tokenizer, dataset_cache_dir=None):
75 | traindata = load_dataset('ptb_text_only', 'penn_treebank', split='train', cache_dir=dataset_cache_dir)
76 | valdata = load_dataset('ptb_text_only', 'penn_treebank', split='validation', cache_dir=dataset_cache_dir)
77 |
78 | trainenc = tokenizer("\n\n".join(traindata['sentence']), return_tensors='pt')
79 | testenc = tokenizer("\n\n".join(valdata['sentence']), return_tensors='pt')
80 |
81 | import random
82 | random.seed(seed)
83 | trainloader = []
84 | for _ in range(nsamples):
85 | i = random.randint(0, trainenc.input_ids.shape[1] - seqlen - 1)
86 | j = i + seqlen
87 | inp = trainenc.input_ids[:, i:j]
88 | tar = inp.clone()
89 | tar[:, :-1] = -100
90 | trainloader.append((inp, tar))
91 | return trainloader, testenc
92 |
93 | def get_c4(nsamples, seed, seqlen, tokenizer):
94 | traindata = load_dataset("json", data_files="utils/c4-train.json")['train']
95 | valdata = load_dataset("json", data_files="utils/c4-validation.json")['train']
96 |
97 | import random
98 | random.seed(seed)
99 | trainloader = []
100 | for _ in range(nsamples):
101 | while True:
102 | i = random.randint(0, len(traindata) - 1)
103 | trainenc = tokenizer(traindata[i]['text'], return_tensors='pt')
104 | if trainenc.input_ids.shape[1] >= seqlen:
105 | break
106 | i = random.randint(0, trainenc.input_ids.shape[1] - seqlen - 1)
107 | j = i + seqlen
108 | inp = trainenc.input_ids[:, i:j]
109 | tar = inp.clone()
110 | tar[:, :-1] = -100
111 | trainloader.append((inp, tar))
112 |
113 | import random
114 | random.seed(0)
115 | valenc = []
116 | for _ in range(256):
117 | while True:
118 | i = random.randint(0, len(valdata) - 1)
119 | tmp = tokenizer(valdata[i]['text'], return_tensors='pt')
120 | if tmp.input_ids.shape[1] >= seqlen:
121 | break
122 | i = random.randint(0, tmp.input_ids.shape[1] - seqlen - 1)
123 | j = i + seqlen
124 | valenc.append(tmp.input_ids[:, i:j])
125 | valenc = torch.hstack(valenc)
126 | class TokenizerWrapper:
127 | def __init__(self, input_ids):
128 | self.input_ids = input_ids
129 | valenc = TokenizerWrapper(valenc)
130 |
131 | return trainloader, valenc
132 |
133 |
134 |
135 | def get_ptb_new(nsamples, seed, seqlen, tokenizer, dataset_cache_dir=None):
136 | from datasets import load_dataset
137 | traindata = load_dataset('ptb_text_only', 'penn_treebank', split='train', cache_dir=dataset_cache_dir)
138 | testdata = load_dataset('ptb_text_only', 'penn_treebank', split='test', cache_dir=dataset_cache_dir)
139 |
140 | trainenc = tokenizer(" ".join(traindata['sentence']), return_tensors='pt')
141 | testenc = tokenizer(" ".join(testdata['sentence']), return_tensors='pt')
142 |
143 | import random
144 | random.seed(seed)
145 | trainloader = []
146 | for _ in range(nsamples):
147 | i = random.randint(0, trainenc.input_ids.shape[1] - seqlen - 1)
148 | j = i + seqlen
149 | inp = trainenc.input_ids[:, i:j]
150 | tar = inp.clone()
151 | tar[:, :-1] = -100
152 | trainloader.append((inp, tar))
153 | return trainloader, testenc
154 |
155 | def get_c4_new(nsamples, seed, seqlen, tokenizer):
156 | traindata = load_dataset("json", data_files="utils/c4-train.json")['train']
157 | valdata = load_dataset("json", data_files="utils/c4-validation.json")['train']
158 |
159 | import random
160 | random.seed(seed)
161 | trainloader = []
162 | for _ in range(nsamples):
163 | while True:
164 | i = random.randint(0, len(traindata) - 1)
165 | trainenc = tokenizer(traindata[i]['text'], return_tensors='pt')
166 | if trainenc.input_ids.shape[1] >= seqlen:
167 | break
168 | i = random.randint(0, trainenc.input_ids.shape[1] - seqlen - 1)
169 | j = i + seqlen
170 | inp = trainenc.input_ids[:, i:j]
171 | tar = inp.clone()
172 | tar[:, :-1] = -100
173 | trainloader.append((inp, tar))
174 |
175 | valenc = tokenizer(' '.join(valdata[:1100]['text']), return_tensors='pt')
176 | valenc = valenc.input_ids[:, :(256 * seqlen)]
177 |
178 | class TokenizerWrapper:
179 | def __init__(self, input_ids):
180 | self.input_ids = input_ids
181 | valenc = TokenizerWrapper(valenc)
182 |
183 | return trainloader, valenc
184 | def get_loaders(name, nsamples=128, seed=0, seqlen=2048, tokenizer=None):
185 | if 'wikitext2' in name:
186 | return get_wikitext2(nsamples, seed, seqlen, tokenizer)
187 | if 'ptb' in name:
188 | if 'new' in name:
189 | return get_ptb_new(nsamples, seed, seqlen, tokenizer)
190 | return get_ptb(nsamples, seed, seqlen, tokenizer)
191 | if 'c4' in name:
192 | if 'new' in name:
193 | return get_c4_new(nsamples, seed, seqlen, tokenizer)
194 | return get_c4(nsamples, seed, seqlen, tokenizer)
195 |
196 |
197 |
198 | def get_test_data(name, tokenizer, seq_len=2048, batch_size = 4):
199 | class IndexDataset(Dataset):
200 | def __init__(self, tensors):
201 | self.tensors = tensors
202 |
203 | def __getitem__(self, index):
204 | return self.tensors[index]
205 |
206 | def __len__(self):
207 | return len(self.tensors)
208 | ####
209 | def process_data(samples, tokenizer, seq_len, field_name):
210 | test_ids = tokenizer("\n\n".join(samples[field_name]), return_tensors='pt').input_ids[0]
211 | test_ids_batch = []
212 | nsamples = test_ids.numel() // seq_len
213 |
214 | for i in range(nsamples):
215 | batch = test_ids[(i * seq_len):((i + 1) * seq_len)]
216 | test_ids_batch.append(batch)
217 | test_ids_batch = torch.stack(test_ids_batch)
218 | return IndexDataset(tensors=test_ids_batch)
219 | ####
220 | if 'wikitext2' in name:
221 | test_data = load_dataset('wikitext', 'wikitext-2-raw-v1', split='test')
222 | test_dataset = process_data(test_data, tokenizer, seq_len, 'text')
223 | if 'ptb' in name:
224 | test_data = load_dataset('ptb_text_only', 'penn_treebank', split='test')
225 | test_dataset = process_data(test_data, tokenizer, seq_len, 'sentence')
226 | elif 'c4' in name:
227 | test_data = load_dataset("json", data_files="utils/c4-validation.json")['train']
228 | test_dataset = process_data(test_data[0:2000], tokenizer, seq_len, 'text')
229 | test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
230 | return test_loader
--------------------------------------------------------------------------------
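
A short sketch of driving the loaders above; the tokenizer id is a placeholder, and the `c4` options additionally assume the `utils/c4-*.json` files referenced in the code are present.

```py
# Sketch only: random calibration sequences plus a fixed-length test loader.
from transformers import AutoTokenizer
from utils.data_utils import get_loaders, get_test_data

tokenizer = AutoTokenizer.from_pretrained("some/llama-tokenizer")   # placeholder id

# 128 random (input, target) pairs of 2048 tokens and the WikiText-2 test encoding;
# all target positions except the last are masked to -100.
trainloader, testenc = get_loaders("wikitext2", nsamples=128, seed=0, seqlen=2048, tokenizer=tokenizer)

# Non-shuffled 2048-token batches for perplexity-style evaluation.
test_loader = get_test_data("wikitext2", tokenizer, seq_len=2048, batch_size=4)
```
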
/utils/model_utils.py:
--------------------------------------------------------------------------------
1 | #coding:utf8
2 | import os
3 | import sys
4 | import torch
5 | import torch.nn as nn
6 |
7 | current_path = os.path.dirname(os.path.abspath(__file__))
8 | parent_path = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
9 | sys.path.append(current_path)
10 |
11 | # bandaid fix: hard-code the default device to the first CUDA GPU
12 | dev = torch.device("cuda")
13 |
14 | def get_model_from_huggingface(model_id):
15 | from transformers import AutoModelForCausalLM, LlamaTokenizer, AutoTokenizer, LlamaForCausalLM
16 | if "opt" in model_id or "mistral" in model_id:
17 | tokenizer = AutoTokenizer.from_pretrained(model_id, device_map="cpu", trust_remote_code=True)
18 | else:
19 | tokenizer = LlamaTokenizer.from_pretrained(model_id, device_map="cpu", trust_remote_code=True)
20 | model = AutoModelForCausalLM.from_pretrained(model_id, device_map="cpu", torch_dtype=torch.float16, trust_remote_code=True, cache_dir=None)
21 | model.seqlen = 2048
22 | return model, tokenizer
23 |
24 | def get_model_from_local(model_id):
25 | pruned_dict = torch.load(model_id, weights_only=False, map_location='cpu')
26 | tokenizer, model = pruned_dict['tokenizer'], pruned_dict['model']
27 | return model, tokenizer
28 |
29 | def find_layers(module, layers=[nn.Conv2d, nn.Linear], name=''):
30 | if type(module) in layers:
31 | return {name: module}
32 | res = {}
33 | for name1, child in module.named_children():
34 | res.update(find_layers(
35 | child, layers=layers, name=name + '.' + name1 if name != '' else name1
36 | ))
37 | return res
38 |
--------------------------------------------------------------------------------
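
A small sketch of the two helpers above; the model id is a placeholder, and the `model.model.layers[0]` access assumes a LLaMA-style decoder layout.

```py
# Sketch only: load a model and enumerate the Linear sub-modules of one decoder layer.
from utils.model_utils import get_model_from_huggingface, find_layers

model, tokenizer = get_model_from_huggingface("some/llama-7b")   # placeholder id
subset = find_layers(model.model.layers[0])   # {dotted_name: nn.Linear, ...}
print(list(subset.keys()))
```
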
/utils/peft/__init__.py:
--------------------------------------------------------------------------------
1 | # flake8: noqa
2 | # There's no way to ignore "F401 '...' imported but unused" warnings in this
3 | # module while preserving the other warnings, so don't check this module at all.
4 |
5 | # coding=utf-8
6 | # Copyright 2023-present the HuggingFace Inc. team.
7 | #
8 | # Licensed under the Apache License, Version 2.0 (the "License");
9 | # you may not use this file except in compliance with the License.
10 | # You may obtain a copy of the License at
11 | #
12 | # http://www.apache.org/licenses/LICENSE-2.0
13 | #
14 | # Unless required by applicable law or agreed to in writing, software
15 | # distributed under the License is distributed on an "AS IS" BASIS,
16 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
17 | # See the License for the specific language governing permissions and
18 | # limitations under the License.
19 |
20 | __version__ = "0.3.0.dev0"
21 |
22 | from .mapping import MODEL_TYPE_TO_PEFT_MODEL_MAPPING, PEFT_TYPE_TO_CONFIG_MAPPING, get_peft_config, get_peft_model
23 | from .peft_model import (
24 | PeftModel,
25 | PeftModelForCausalLM,
26 | PeftModelForSeq2SeqLM,
27 | PeftModelForSequenceClassification,
28 | PeftModelForTokenClassification,
29 | )
30 | from .tuners import (
31 | LoraConfig,
32 | LoraModel,
33 | AdaLoraConfig,
34 | AdaLoraModel,
35 | PrefixEncoder,
36 | PrefixTuningConfig,
37 | PromptEmbedding,
38 | PromptEncoder,
39 | PromptEncoderConfig,
40 | PromptEncoderReparameterizationType,
41 | PromptTuningConfig,
42 | PromptTuningInit,
43 | )
44 | from .utils import (
45 | TRANSFORMERS_MODELS_TO_PREFIX_TUNING_POSTPROCESS_MAPPING,
46 | PeftConfig,
47 | PeftType,
48 | PromptLearningConfig,
49 | TaskType,
50 | bloom_model_postprocess_past_key_value,
51 | get_peft_model_state_dict,
52 | prepare_model_for_int8_training,
53 | set_peft_model_state_dict,
54 | shift_tokens_right,
55 | )
56 |
--------------------------------------------------------------------------------
/utils/peft/import_utils.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2023-present the HuggingFace Inc. team.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | import importlib
16 |
17 |
18 | def is_bnb_available():
19 | return importlib.util.find_spec("bitsandbytes") is not None
20 |
--------------------------------------------------------------------------------
/utils/peft/mapping.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2023-present the HuggingFace Inc. team.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | from .peft_model import (
17 | PeftModel,
18 | PeftModelForCausalLM,
19 | PeftModelForSeq2SeqLM,
20 | PeftModelForSequenceClassification,
21 | PeftModelForTokenClassification,
22 | )
23 | from .tuners import AdaLoraConfig, LoraConfig, PrefixTuningConfig, PromptEncoderConfig, PromptTuningConfig
24 | from .utils import PromptLearningConfig
25 |
26 |
27 | MODEL_TYPE_TO_PEFT_MODEL_MAPPING = {
28 | "SEQ_CLS": PeftModelForSequenceClassification,
29 | "SEQ_2_SEQ_LM": PeftModelForSeq2SeqLM,
30 | "CAUSAL_LM": PeftModelForCausalLM,
31 | "TOKEN_CLS": PeftModelForTokenClassification,
32 | }
33 |
34 | PEFT_TYPE_TO_CONFIG_MAPPING = {
35 | "PROMPT_TUNING": PromptTuningConfig,
36 | "PREFIX_TUNING": PrefixTuningConfig,
37 | "P_TUNING": PromptEncoderConfig,
38 | "LORA": LoraConfig,
39 | "ADALORA": AdaLoraConfig,
40 | }
41 |
42 |
43 | def get_peft_config(config_dict):
44 | """
45 | Returns a Peft config object from a dictionary.
46 |
47 | Args:
48 | config_dict (`Dict[str, Any]`): Dictionary containing the configuration parameters.
49 | """
50 |
51 | return PEFT_TYPE_TO_CONFIG_MAPPING[config_dict["peft_type"]](**config_dict)
52 |
53 |
54 | def _prepare_prompt_learning_config(peft_config, model_config):
55 | if peft_config.num_layers is None:
56 | if "num_hidden_layers" in model_config:
57 | num_layers = model_config["num_hidden_layers"]
58 | elif "num_layers" in model_config:
59 | num_layers = model_config["num_layers"]
60 | elif "n_layer" in model_config:
61 | num_layers = model_config["n_layer"]
62 | else:
63 | raise ValueError("Please specify `num_layers` in `peft_config`")
64 | peft_config.num_layers = num_layers
65 |
66 | if peft_config.token_dim is None:
67 | if "hidden_size" in model_config:
68 | token_dim = model_config["hidden_size"]
69 | elif "n_embd" in model_config:
70 | token_dim = model_config["n_embd"]
71 | elif "d_model" in model_config:
72 | token_dim = model_config["d_model"]
73 | else:
74 | raise ValueError("Please specify `token_dim` in `peft_config`")
75 | peft_config.token_dim = token_dim
76 |
77 | if peft_config.num_attention_heads is None:
78 | if "num_attention_heads" in model_config:
79 | num_attention_heads = model_config["num_attention_heads"]
80 | elif "n_head" in model_config:
81 | num_attention_heads = model_config["n_head"]
82 | elif "num_heads" in model_config:
83 | num_attention_heads = model_config["num_heads"]
84 | elif "encoder_attention_heads" in model_config:
85 | num_attention_heads = model_config["encoder_attention_heads"]
86 | else:
87 | raise ValueError("Please specify `num_attention_heads` in `peft_config`")
88 | peft_config.num_attention_heads = num_attention_heads
89 |
90 | if getattr(peft_config, "encoder_hidden_size", None) is None:
91 | setattr(peft_config, "encoder_hidden_size", token_dim)
92 |
93 | return peft_config
94 |
95 |
96 | def get_peft_model(model, peft_config):
97 | """
98 | Returns a Peft model object from a model and a config.
99 |
100 | Args:
101 | model ([`transformers.PreTrainedModel`]): Model to be wrapped.
102 | peft_config ([`PeftConfig`]): Configuration object containing the parameters of the Peft model.
103 | """
104 | model_config = model.config.to_dict() if hasattr(model.config, "to_dict") else model.config
105 | peft_config.base_model_name_or_path = model.__dict__.get("name_or_path", None)
106 | if peft_config.task_type not in MODEL_TYPE_TO_PEFT_MODEL_MAPPING.keys() and not isinstance(
107 | peft_config, PromptLearningConfig
108 | ):
109 | return PeftModel(model, peft_config)
110 | if isinstance(peft_config, PromptLearningConfig):
111 | peft_config = _prepare_prompt_learning_config(peft_config, model_config)
112 | return MODEL_TYPE_TO_PEFT_MODEL_MAPPING[peft_config.task_type](model, peft_config)
113 |
--------------------------------------------------------------------------------
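
A minimal sketch of the two entry points above; the dictionary keys mirror the `LoraConfig` fields defined in `tuners/lora.py`, and the target module names are illustrative only.

```py
# Sketch only: build a config from a plain dict, then wrap a causal-LM base model.
from utils.peft import get_peft_config, get_peft_model

lora_config = get_peft_config({
    "peft_type": "LORA",
    "task_type": "CAUSAL_LM",
    "r": 8,
    "lora_alpha": 16,
    "lora_dropout": 0.05,
    "target_modules": ["q_proj", "v_proj"],   # illustrative module names
})
# peft_model = get_peft_model(base_model, lora_config)   # base_model: any transformers causal LM
```
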
/utils/peft/tuners/__init__.py:
--------------------------------------------------------------------------------
1 | # flake8: noqa
2 | # There's no way to ignore "F401 '...' imported but unused" warnings in this
3 | # module while preserving the other warnings, so don't check this module at all.
4 |
5 | # coding=utf-8
6 | # Copyright 2023-present the HuggingFace Inc. team.
7 | #
8 | # Licensed under the Apache License, Version 2.0 (the "License");
9 | # you may not use this file except in compliance with the License.
10 | # You may obtain a copy of the License at
11 | #
12 | # http://www.apache.org/licenses/LICENSE-2.0
13 | #
14 | # Unless required by applicable law or agreed to in writing, software
15 | # distributed under the License is distributed on an "AS IS" BASIS,
16 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
17 | # See the License for the specific language governing permissions and
18 | # limitations under the License.
19 |
20 | from .lora import LoraConfig, LoraModel
21 | from .adalora import AdaLoraConfig, AdaLoraModel
22 | from .p_tuning import PromptEncoder, PromptEncoderConfig, PromptEncoderReparameterizationType
23 | from .prefix_tuning import PrefixEncoder, PrefixTuningConfig
24 | from .prompt_tuning import PromptEmbedding, PromptTuningConfig, PromptTuningInit
25 |
--------------------------------------------------------------------------------
/utils/peft/tuners/lora.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2023-present the HuggingFace Inc. team.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | import math
16 | import re
17 | import warnings
18 | from dataclasses import asdict, dataclass, field
19 | from enum import Enum
20 | from typing import List, Optional, Union
21 |
22 | import torch
23 | import torch.nn as nn
24 | import torch.nn.functional as F
25 | from transformers.pytorch_utils import Conv1D
26 |
27 | from ..import_utils import is_bnb_available
28 | from ..utils import (
29 | TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING,
30 | ModulesToSaveWrapper,
31 | PeftConfig,
32 | PeftType,
33 | _freeze_adapter,
34 | _get_submodules,
35 | transpose,
36 | )
37 |
38 |
39 | if is_bnb_available():
40 | import bitsandbytes as bnb
41 |
42 |
43 | @dataclass
44 | class LoraConfig(PeftConfig):
45 | """
46 | This is the configuration class to store the configuration of a [`LoraModel`].
47 |
48 | Args:
49 | r (`int`): Lora attention dimension.
50 | target_modules (`Union[List[str],str]`): The names of the modules to apply Lora to.
51 | lora_alpha (`float`): The alpha parameter for Lora scaling.
52 | lora_dropout (`float`): The dropout probability for Lora layers.
53 | fan_in_fan_out (`bool`): Set this to True if the layer to replace stores weight like (fan_in, fan_out).
54 |             For example, gpt-2 uses `Conv1D` which stores weights like (fan_in, fan_out) and hence this should be set to `True`.
55 | bias (`str`): Bias type for Lora. Can be 'none', 'all' or 'lora_only'
56 |         modules_to_save (`List[str]`): List of modules apart from LoRA layers to be set as trainable
57 | and saved in the final checkpoint.
58 | """
59 |
60 | r: int = field(default=8, metadata={"help": "Lora attention dimension"})
61 | target_modules: Optional[Union[List[str], str]] = field(
62 | default=None,
63 | metadata={
64 | "help": "List of module names or regex expression of the module names to replace with Lora."
65 | "For example, ['q', 'v'] or '.*decoder.*(SelfAttention|EncDecAttention).*(q|v)$' "
66 | },
67 | )
68 | lora_alpha: int = field(default=None, metadata={"help": "Lora alpha"})
69 | lora_dropout: float = field(default=None, metadata={"help": "Lora dropout"})
70 | fan_in_fan_out: bool = field(
71 | default=False,
72 | metadata={"help": "Set this to True if the layer to replace stores weight like (fan_in, fan_out)"},
73 | )
74 | bias: str = field(default="none", metadata={"help": "Bias type for Lora. Can be 'none', 'all' or 'lora_only'"})
75 | modules_to_save: Optional[List[str]] = field(
76 | default=None,
77 | metadata={
78 | "help": "List of modules apart from LoRA layers to be set as trainable and saved in the final checkpoint. "
79 | "For example, in Sequence Classification or Token Classification tasks, "
80 |             "the final layer `classifier/score` is randomly initialized and as such needs to be trainable and saved."
81 | },
82 | )
83 | init_lora_weights: bool = field(
84 | default=True,
85 | metadata={"help": "Whether to initialize the weights of the Lora layers."},
86 | )
87 |
88 | def __post_init__(self):
89 | self.peft_type = PeftType.LORA
90 |
91 |
92 | class LoraModel(torch.nn.Module):
93 | """
94 | Creates Low Rank Adapter (Lora) model from a pretrained transformers model.
95 |
96 | Args:
97 | model ([`~transformers.PreTrainedModel`]): The model to be adapted.
98 | config ([`LoraConfig`]): The configuration of the Lora model.
99 |
100 | Returns:
101 | `torch.nn.Module`: The Lora model.
102 |
103 | Example:
104 |
105 | ```py
106 |     >>> from transformers import AutoModelForSeq2SeqLM
107 | >>> from peft import LoraModel, LoraConfig
108 |
109 | >>> config = LoraConfig(
110 | ... peft_type="LORA",
111 | ... task_type="SEQ_2_SEQ_LM",
112 | ... r=8,
113 | ... lora_alpha=32,
114 | ... target_modules=["q", "v"],
115 | ... lora_dropout=0.01,
116 | ... )
117 |
118 | >>> model = AutoModelForSeq2SeqLM.from_pretrained("t5-base")
119 |     >>> lora_model = LoraModel(model, config, "default")
120 | ```
121 |
122 | **Attributes**:
123 | - **model** ([`~transformers.PreTrainedModel`]) -- The model to be adapted.
124 | - **peft_config** ([`LoraConfig`]): The configuration of the Lora model.
125 | """
126 |
127 | def __init__(self, model, config, adapter_name):
128 | super().__init__()
129 | self.model = model
130 | self.forward = self.model.forward
131 | self.peft_config = config
132 | self.add_adapter(adapter_name, self.peft_config[adapter_name])
133 |
134 | def add_adapter(self, adapter_name, config=None):
135 | if config is not None:
136 | model_config = self.model.config.to_dict() if hasattr(self.model.config, "to_dict") else self.model.config
137 | config = self._prepare_lora_config(config, model_config)
138 | self.peft_config[adapter_name] = config
139 | self._find_and_replace(adapter_name)
140 | if len(self.peft_config) > 1 and self.peft_config[adapter_name].bias != "none":
141 | raise ValueError(
142 | "LoraModel supports only 1 adapter with bias. When using multiple adapters, set bias to 'none' for all adapters."
143 | )
144 | mark_only_lora_as_trainable(self.model, self.peft_config[adapter_name].bias)
145 | if self.peft_config[adapter_name].inference_mode:
146 | _freeze_adapter(self.model, adapter_name)
147 |
148 | def _find_and_replace(self, adapter_name):
149 | lora_config = self.peft_config[adapter_name]
150 | loaded_in_8bit = getattr(self.model, "is_loaded_in_8bit", False)
151 | if loaded_in_8bit and not is_bnb_available():
152 | raise ImportError(
153 | "To use Lora with 8-bit quantization, please install the `bitsandbytes` package. "
154 | "You can install it with `pip install bitsandbytes`."
155 | )
156 | is_target_modules_in_base_model = False
157 | kwargs = {
158 | "r": lora_config.r,
159 | "lora_alpha": lora_config.lora_alpha,
160 | "lora_dropout": lora_config.lora_dropout,
161 | "fan_in_fan_out": lora_config.fan_in_fan_out,
162 | "init_lora_weights": lora_config.init_lora_weights,
163 | }
164 | key_list = [key for key, _ in self.model.named_modules()]
165 | for key in key_list:
166 | if isinstance(lora_config.target_modules, str):
167 | target_module_found = re.fullmatch(lora_config.target_modules, key)
168 | else:
169 | target_module_found = any(key.endswith(target_key) for target_key in lora_config.target_modules)
170 | if target_module_found:
171 | if not is_target_modules_in_base_model:
172 | is_target_modules_in_base_model = True
173 | parent, target, target_name = _get_submodules(self.model, key)
174 | bias = target.bias is not None
175 | if isinstance(target, LoraLayer):
176 | target.update_layer(
177 | adapter_name,
178 | lora_config.r,
179 | lora_config.lora_alpha,
180 | lora_config.lora_dropout,
181 | lora_config.init_lora_weights,
182 | )
183 | else:
184 | if loaded_in_8bit and isinstance(target, bnb.nn.Linear8bitLt):
185 | eightbit_kwargs = kwargs.copy()
186 | eightbit_kwargs.update(
187 | {
188 | "has_fp16_weights": target.state.has_fp16_weights,
189 | "memory_efficient_backward": target.state.memory_efficient_backward,
190 | "threshold": target.state.threshold,
191 | "index": target.index,
192 | }
193 | )
194 | new_module = Linear8bitLt(
195 | adapter_name, target.in_features, target.out_features, bias=bias, **eightbit_kwargs
196 | )
197 | else:
198 | if isinstance(target, torch.nn.Linear):
199 | in_features, out_features = target.in_features, target.out_features
200 | if kwargs["fan_in_fan_out"]:
201 | warnings.warn(
202 | "fan_in_fan_out is set to True but the target module is `torch.nn.Linear`. "
203 | "Setting fan_in_fan_out to False."
204 | )
205 | kwargs["fan_in_fan_out"] = lora_config.fan_in_fan_out = False
206 | elif isinstance(target, Conv1D):
207 | in_features, out_features = (
208 | target.weight.ds_shape if hasattr(target.weight, "ds_shape") else target.weight.shape
209 | )
210 | if not kwargs["fan_in_fan_out"]:
211 | warnings.warn(
212 | "fan_in_fan_out is set to False but the target module is `Conv1D`. "
213 | "Setting fan_in_fan_out to True."
214 | )
215 | kwargs["fan_in_fan_out"] = lora_config.fan_in_fan_out = True
216 | else:
217 | raise ValueError(
218 | f"Target module {target} is not supported. "
219 | f"Currently, only `torch.nn.Linear` and `Conv1D` are supported."
220 | )
221 | new_module = Linear(adapter_name, in_features, out_features, bias=bias, **kwargs)
222 |
223 | self._replace_module(parent, target_name, new_module, target)
224 | if not is_target_modules_in_base_model:
225 | raise ValueError(
226 | f"Target modules {lora_config.target_modules} not found in the base model. "
227 | f"Please check the target modules and try again."
228 | )
229 |
230 | def _replace_module(self, parent_module, child_name, new_module, old_module):
231 | setattr(parent_module, child_name, new_module)
232 | new_module.weight = old_module.weight
233 | if old_module.bias is not None:
234 | new_module.bias = old_module.bias
235 | if getattr(old_module, "state", None) is not None:
236 | new_module.state = old_module.state
237 | new_module.to(old_module.weight.device)
238 |
239 | # dispatch to correct device
240 | for name, module in new_module.named_modules():
241 | if "lora_" in name:
242 | module.to(old_module.weight.device)
243 |
244 | def __getattr__(self, name: str):
245 | """Forward missing attributes to the wrapped module."""
246 | try:
247 | return super().__getattr__(name) # defer to nn.Module's logic
248 | except AttributeError:
249 | return getattr(self.model, name)
250 |
251 | def get_peft_config_as_dict(self, inference: bool = False):
252 | config_dict = {}
253 | for key, value in self.peft_config.items():
254 | config = {k: v.value if isinstance(v, Enum) else v for k, v in asdict(value).items()}
255 | if inference:
256 | config["inference_mode"] = True
257 | config_dict[key] = config
258 |         return config_dict
259 |
260 | def _set_adapter_layers(self, enabled=True):
261 | for module in self.model.modules():
262 | if isinstance(module, LoraLayer):
263 | module.disable_adapters = False if enabled else True
264 |
265 | def enable_adapter_layers(self):
266 | self._set_adapter_layers(enabled=True)
267 |
268 | def disable_adapter_layers(self):
269 | self._set_adapter_layers(enabled=False)
270 |
271 | def set_adapter(self, adapter_name):
272 | for module in self.model.modules():
273 | if isinstance(module, LoraLayer):
274 | if module.merged:
275 | warnings.warn("Adapter cannot be set when the model is merged. Unmerging the model first.")
276 | module.unmerge()
277 | module.active_adapter = adapter_name
278 |
279 | def merge_adapter(self):
280 | for module in self.model.modules():
281 | if isinstance(module, LoraLayer):
282 | module.merge()
283 |
284 | def unmerge_adapter(self):
285 | for module in self.model.modules():
286 | if isinstance(module, LoraLayer):
287 | module.unmerge()
288 |
289 | @staticmethod
290 | def _prepare_lora_config(peft_config, model_config):
291 | if peft_config.target_modules is None:
292 | if model_config["model_type"] not in TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING:
293 | raise ValueError("Please specify `target_modules` in `peft_config`")
294 | peft_config.target_modules = TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING[model_config["model_type"]]
295 | if peft_config.inference_mode:
296 | peft_config.merge_weights = True
297 | return peft_config
298 |
299 | def merge_and_unload(self):
300 | r"""
301 |         This method merges the LoRA layers into the base model. This is needed if someone wants to use the base model
302 | as a standalone model.
303 | """
304 | if getattr(self.config, "model_type", None) == "gpt2":
305 | raise ValueError("GPT2 models are not supported for merging LORA layers")
306 |
307 | if getattr(self.model, "is_loaded_in_8bit", False):
308 | raise ValueError("Cannot merge LORA layers when the model is loaded in 8-bit mode")
309 |
310 | key_list = [key for key, _ in self.model.named_modules() if "lora" not in key]
311 | for key in key_list:
312 | try:
313 | parent, target, target_name = _get_submodules(self.model, key)
314 | except AttributeError:
315 | continue
316 | if isinstance(target, LoraLayer):
317 | bias = target.bias is not None
318 | new_module = torch.nn.Linear(target.in_features, target.out_features, bias=bias)
319 | target.merge()
320 | self._replace_module(parent, target_name, new_module, target)
321 |
322 | # save any additional trainable modules part of `modules_to_save`
323 | if isinstance(target, ModulesToSaveWrapper):
324 | setattr(parent, target_name, target.modules_to_save[target.active_adapter])
325 |
326 | return self.model
327 |
328 | def add_weighted_adapter(self, adapters, weights, adapter_name):
329 | if len({self.peft_config[adapter].r for adapter in adapters}) != 1:
330 | raise ValueError("All adapters must have the same r value")
331 | self.peft_config[adapter_name] = self.peft_config[adapters[0]]
332 | self.peft_config[adapter_name].lora_alpha = self.peft_config[adapters[0]].r
333 | self._find_and_replace(adapter_name)
334 | mark_only_lora_as_trainable(self.model, self.peft_config[adapter_name].bias)
335 | _freeze_adapter(self.model, adapter_name)
336 | key_list = [key for key, _ in self.model.named_modules() if "lora" not in key]
337 | for key in key_list:
338 | _, target, _ = _get_submodules(self.model, key)
339 | if isinstance(target, LoraLayer):
340 | target.lora_A[adapter_name].weight.data = target.lora_A[adapter_name].weight.data * 0.0
341 | target.lora_B[adapter_name].weight.data = target.lora_B[adapter_name].weight.data * 0.0
342 | for adapter, weight in zip(adapters, weights):
343 | if adapter not in target.lora_A:
344 | continue
345 | target.lora_A[adapter_name].weight.data += (
346 | target.lora_A[adapter].weight.data * weight * target.scaling[adapter]
347 | )
348 | target.lora_B[adapter_name].weight.data += target.lora_B[adapter].weight.data * weight
349 |
350 |
351 | # Below code is based on https://github.com/microsoft/LoRA/blob/main/loralib/layers.py
352 | # and modified to work with PyTorch FSDP
353 |
354 |
355 | # ------------------------------------------------------------------------------------------
356 | # Copyright (c) Microsoft Corporation. All rights reserved.
357 | # Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
358 | # ------------------------------------------------------------------------------------------
359 |
360 |
361 | # had to adapt it for `lora_only` to work
362 | def mark_only_lora_as_trainable(model: nn.Module, bias: str = "none") -> None:
363 | for n, p in model.named_parameters():
364 | if "lora_" not in n:
365 | p.requires_grad = False
366 | if bias == "none":
367 | return
368 | elif bias == "all":
369 | for n, p in model.named_parameters():
370 | if "bias" in n:
371 | p.requires_grad = True
372 | elif bias == "lora_only":
373 | for m in model.modules():
374 | if isinstance(m, LoraLayer) and hasattr(m, "bias") and m.bias is not None:
375 | m.bias.requires_grad = True
376 | else:
377 | raise NotImplementedError
378 |
379 |
380 | class LoraLayer:
381 | def __init__(
382 | self,
383 | in_features: int,
384 | out_features: int,
385 | ):
386 | self.r = {}
387 | self.lora_alpha = {}
388 | self.scaling = {}
389 | self.lora_dropout = nn.ModuleDict({})
390 | self.lora_A = nn.ModuleDict({})
391 | self.lora_B = nn.ModuleDict({})
392 | # Mark the weight as unmerged
393 | self.merged = False
394 | self.disable_adapters = False
395 | self.in_features = in_features
396 | self.out_features = out_features
397 |
398 | def update_layer(self, adapter_name, r, lora_alpha, lora_dropout, init_lora_weights):
399 | self.r[adapter_name] = r
400 | self.lora_alpha[adapter_name] = lora_alpha
401 | if lora_dropout > 0.0:
402 | lora_dropout_layer = nn.Dropout(p=lora_dropout)
403 | else:
404 | lora_dropout_layer = nn.Identity()
405 |
406 | self.lora_dropout.update(nn.ModuleDict({adapter_name: lora_dropout_layer}))
407 | # Actual trainable parameters
408 | if r > 0:
409 | self.lora_A.update(nn.ModuleDict({adapter_name: nn.Linear(self.in_features, r, bias=False)}))
410 | self.lora_B.update(nn.ModuleDict({adapter_name: nn.Linear(r, self.out_features, bias=False)}))
411 | self.scaling[adapter_name] = lora_alpha / r
412 | if init_lora_weights:
413 | self.reset_lora_parameters(adapter_name)
414 | self.to(self.weight.device)
415 |
416 | def reset_lora_parameters(self, adapter_name):
417 | if adapter_name in self.lora_A.keys():
418 | # initialize A the same way as the default for nn.Linear and B to zero
419 | nn.init.kaiming_uniform_(self.lora_A[adapter_name].weight, a=math.sqrt(5))
420 | nn.init.zeros_(self.lora_B[adapter_name].weight)
421 |
422 |
423 | class Linear(nn.Linear, LoraLayer):
424 | # Lora implemented in a dense layer
425 | def __init__(
426 | self,
427 | adapter_name: str,
428 | in_features: int,
429 | out_features: int,
430 | r: int = 0,
431 | lora_alpha: int = 1,
432 | lora_dropout: float = 0.0,
433 | fan_in_fan_out: bool = False, # Set this to True if the layer to replace stores weight like (fan_in, fan_out)
434 | **kwargs,
435 | ):
436 | init_lora_weights = kwargs.pop("init_lora_weights", True)
437 |
438 | nn.Linear.__init__(self, in_features, out_features, **kwargs)
439 | LoraLayer.__init__(self, in_features=in_features, out_features=out_features)
440 | # Freezing the pre-trained weight matrix
441 | self.weight.requires_grad = False
442 |
443 | self.fan_in_fan_out = fan_in_fan_out
444 | if fan_in_fan_out:
445 | self.weight.data = self.weight.data.T
446 |
447 | nn.Linear.reset_parameters(self)
448 | self.update_layer(adapter_name, r, lora_alpha, lora_dropout, init_lora_weights)
449 | self.active_adapter = adapter_name
450 |
451 | def merge(self):
452 | if self.active_adapter not in self.lora_A.keys():
453 | return
454 | if self.merged:
455 | warnings.warn("Already merged. Nothing to do.")
456 | return
457 | if self.r[self.active_adapter] > 0:
458 | self.weight.data += (
459 | transpose(
460 | self.lora_B[self.active_adapter].weight @ self.lora_A[self.active_adapter].weight,
461 | self.fan_in_fan_out,
462 | )
463 | * self.scaling[self.active_adapter]
464 | )
465 | self.merged = True
466 |
467 | def unmerge(self):
468 | if self.active_adapter not in self.lora_A.keys():
469 | return
470 | if not self.merged:
471 | warnings.warn("Already unmerged. Nothing to do.")
472 | return
473 | if self.r[self.active_adapter] > 0:
474 | self.weight.data -= (
475 | transpose(
476 | self.lora_B[self.active_adapter].weight @ self.lora_A[self.active_adapter].weight,
477 | self.fan_in_fan_out,
478 | )
479 | * self.scaling[self.active_adapter]
480 | )
481 | self.merged = False
482 |
483 | def forward(self, x: torch.Tensor):
484 | previous_dtype = x.dtype
485 |
486 | if self.active_adapter not in self.lora_A.keys():
487 | return F.linear(x, transpose(self.weight, self.fan_in_fan_out), bias=self.bias)
488 | if self.disable_adapters:
489 | if self.r[self.active_adapter] > 0 and self.merged:
490 | self.unmerge()
491 | result = F.linear(x, transpose(self.weight, self.fan_in_fan_out), bias=self.bias)
492 | elif self.r[self.active_adapter] > 0 and not self.merged:
493 | result = F.linear(x, transpose(self.weight, self.fan_in_fan_out), bias=self.bias)
494 |
495 | x = x.to(self.lora_A[self.active_adapter].weight.dtype)
496 |
497 | result += (
498 | self.lora_B[self.active_adapter](
499 | self.lora_A[self.active_adapter](self.lora_dropout[self.active_adapter](x))
500 | )
501 | * self.scaling[self.active_adapter]
502 | )
503 | else:
504 | result = F.linear(x, transpose(self.weight, self.fan_in_fan_out), bias=self.bias)
505 |
506 | result = result.to(previous_dtype)
507 |
508 | return result
509 |
510 |
511 | if is_bnb_available():
512 |
513 | class Linear8bitLt(bnb.nn.Linear8bitLt, LoraLayer):
514 | # Lora implemented in a dense layer
515 | def __init__(
516 | self,
517 | adapter_name,
518 | in_features,
519 | out_features,
520 | r: int = 0,
521 | lora_alpha: int = 1,
522 | lora_dropout: float = 0.0,
523 | **kwargs,
524 | ):
525 | bnb.nn.Linear8bitLt.__init__(
526 | self,
527 | in_features,
528 | out_features,
529 | bias=kwargs.get("bias", True),
530 | has_fp16_weights=kwargs.get("has_fp16_weights", True),
531 | memory_efficient_backward=kwargs.get("memory_efficient_backward", False),
532 | threshold=kwargs.get("threshold", 0.0),
533 | index=kwargs.get("index", None),
534 | )
535 | LoraLayer.__init__(self, in_features=in_features, out_features=out_features)
536 |
537 | # Freezing the pre-trained weight matrix
538 | self.weight.requires_grad = False
539 |
540 | init_lora_weights = kwargs.pop("init_lora_weights", True)
541 | self.update_layer(adapter_name, r, lora_alpha, lora_dropout, init_lora_weights)
542 | self.active_adapter = adapter_name
543 |
544 | def forward(self, x: torch.Tensor):
545 | result = super().forward(x)
546 |
547 | if self.disable_adapters or self.active_adapter not in self.lora_A.keys():
548 | return result
549 | elif self.r[self.active_adapter] > 0:
550 | if not torch.is_autocast_enabled():
551 | expected_dtype = result.dtype
552 |
553 | if x.dtype != torch.float32:
554 | x = x.float()
555 | output = (
556 | self.lora_B[self.active_adapter](
557 | self.lora_A[self.active_adapter](self.lora_dropout[self.active_adapter](x))
558 | ).to(expected_dtype)
559 | * self.scaling[self.active_adapter]
560 | )
561 | else:
562 | output = (
563 | self.lora_B[self.active_adapter](
564 | self.lora_A[self.active_adapter](self.lora_dropout[self.active_adapter](x))
565 | )
566 | * self.scaling[self.active_adapter]
567 | )
568 | result += output
569 | return result
570 |
--------------------------------------------------------------------------------
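
A tiny numeric sketch, using plain tensors rather than the classes above, of the update that `Linear.merge()` folds into the frozen weight for a plain `nn.Linear` target (fan_in_fan_out=False), `W += (lora_alpha / r) * B @ A`, and of why the merged and unmerged paths give identical outputs.

```py
# Sketch only: the merged weight W + (alpha/r) * B @ A reproduces the two-matmul
# adapter path used in Linear.forward (dropout omitted for clarity).
import torch

torch.manual_seed(0)
in_f, out_f, r, alpha = 16, 32, 4, 16
W = torch.randn(out_f, in_f)   # frozen base weight
A = torch.randn(r, in_f)       # lora_A.weight (Kaiming-initialized in the real layer)
B = torch.randn(out_f, r)      # lora_B.weight (zero-initialized in the real layer)
x = torch.randn(2, in_f)
scaling = alpha / r

unmerged = x @ W.T + (x @ A.T) @ B.T * scaling
merged = x @ (W + scaling * B @ A).T
print(torch.allclose(unmerged, merged, atol=1e-5))   # True
```

Because `lora_B` starts at zero, the adapter contributes nothing at initialization and training only has to learn the low-rank correction.
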
/utils/peft/tuners/p_tuning.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2023-present the HuggingFace Inc. team.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | import enum
17 | import warnings
18 | from dataclasses import dataclass, field
19 | from typing import Union
20 |
21 | import torch
22 |
23 | from ..utils import PeftType, PromptLearningConfig
24 |
25 |
26 | class PromptEncoderReparameterizationType(str, enum.Enum):
27 | MLP = "MLP"
28 | LSTM = "LSTM"
29 |
30 |
31 | @dataclass
32 | class PromptEncoderConfig(PromptLearningConfig):
33 | """
34 | This is the configuration class to store the configuration of a [`PromptEncoder`].
35 |
36 | Args:
37 | encoder_reparameterization_type (Union[[`PromptEncoderReparameterizationType`], `str`]):
38 | The type of reparameterization to use.
39 | encoder_hidden_size (`int`): The hidden size of the prompt encoder.
40 | encoder_num_layers (`int`): The number of layers of the prompt encoder.
41 | encoder_dropout (`float`): The dropout probability of the prompt encoder.
42 | """
43 |
44 | encoder_reparameterization_type: Union[str, PromptEncoderReparameterizationType] = field(
45 | default=PromptEncoderReparameterizationType.MLP,
46 | metadata={"help": "How to reparameterize the prompt encoder"},
47 | )
48 | encoder_hidden_size: int = field(
49 | default=None,
50 | metadata={"help": "The hidden size of the prompt encoder"},
51 | )
52 | encoder_num_layers: int = field(
53 | default=2,
54 | metadata={"help": "The number of layers of the prompt encoder"},
55 | )
56 | encoder_dropout: float = field(
57 | default=0.0,
58 | metadata={"help": "The dropout of the prompt encoder"},
59 | )
60 |
61 | def __post_init__(self):
62 | self.peft_type = PeftType.P_TUNING
63 |
64 |
65 | # Based on https://github.com/NVIDIA/NeMo/blob/main/nemo/collections/nlp/modules/common/prompt_encoder.py
66 | # with some refactor
67 | class PromptEncoder(torch.nn.Module):
68 | """
69 | The prompt encoder network that is used to generate the virtual token embeddings for p-tuning.
70 |
71 | Args:
72 | config ([`PromptEncoderConfig`]): The configuration of the prompt encoder.
73 |
74 | Example:
75 |
76 | ```py
77 | >>> from peft import PromptEncoder, PromptEncoderConfig
78 |
79 | >>> config = PromptEncoderConfig(
80 | ... peft_type="P_TUNING",
81 | ... task_type="SEQ_2_SEQ_LM",
82 | ... num_virtual_tokens=20,
83 | ... token_dim=768,
84 | ... num_transformer_submodules=1,
85 | ... num_attention_heads=12,
86 | ... num_layers=12,
87 | ... encoder_reparameterization_type="MLP",
88 | ... encoder_hidden_size=768,
89 | ... )
90 |
91 | >>> prompt_encoder = PromptEncoder(config)
92 | ```
93 |
94 | **Attributes**:
95 | - **embedding** (`torch.nn.Embedding`) -- The embedding layer of the prompt encoder.
96 | - **mlp_head** (`torch.nn.Sequential`) -- The MLP head of the prompt encoder if `inference_mode=False`.
97 | - **lstm_head** (`torch.nn.LSTM`) -- The LSTM head of the prompt encoder if `inference_mode=False` and
98 | `encoder_reparameterization_type="LSTM"`.
99 | - **token_dim** (`int`) -- The hidden embedding dimension of the base transformer model.
100 | - **input_size** (`int`) -- The input size of the prompt encoder.
101 | - **output_size** (`int`) -- The output size of the prompt encoder.
102 | - **hidden_size** (`int`) -- The hidden size of the prompt encoder.
103 |     - **total_virtual_tokens** (`int`) -- The total number of virtual tokens of the
104 |       prompt encoder.
105 |     - **encoder_type** (Union[[`PromptEncoderReparameterizationType`], `str`]) -- The encoder type of the
106 |       prompt encoder.
107 |
108 |
109 | Input shape: (`batch_size`, `total_virtual_tokens`)
110 |
111 | Output shape: (`batch_size`, `total_virtual_tokens`, `token_dim`)
112 | """
113 |
114 | def __init__(self, config):
115 | super().__init__()
116 | self.token_dim = config.token_dim
117 | self.input_size = self.token_dim
118 | self.output_size = self.token_dim
119 | self.hidden_size = config.encoder_hidden_size
120 | self.total_virtual_tokens = config.num_virtual_tokens * config.num_transformer_submodules
121 | self.encoder_type = config.encoder_reparameterization_type
122 |
123 | # embedding
124 | self.embedding = torch.nn.Embedding(self.total_virtual_tokens, self.token_dim)
125 | if not config.inference_mode:
126 | if self.encoder_type == PromptEncoderReparameterizationType.LSTM:
127 | lstm_dropout = config.encoder_dropout
128 | num_layers = config.encoder_num_layers
129 | # LSTM
130 | self.lstm_head = torch.nn.LSTM(
131 | input_size=self.input_size,
132 | hidden_size=self.hidden_size,
133 | num_layers=num_layers,
134 | dropout=lstm_dropout,
135 | bidirectional=True,
136 | batch_first=True,
137 | )
138 |
139 | self.mlp_head = torch.nn.Sequential(
140 | torch.nn.Linear(self.hidden_size * 2, self.hidden_size * 2),
141 | torch.nn.ReLU(),
142 | torch.nn.Linear(self.hidden_size * 2, self.output_size),
143 | )
144 |
145 | elif self.encoder_type == PromptEncoderReparameterizationType.MLP:
146 | warnings.warn(
147 |                     f"For {self.encoder_type}, `encoder_num_layers` is ignored; exactly two MLP layers are used."
148 | )
149 | layers = [
150 | torch.nn.Linear(self.input_size, self.hidden_size),
151 | torch.nn.ReLU(),
152 | torch.nn.Linear(self.hidden_size, self.hidden_size),
153 | torch.nn.ReLU(),
154 | torch.nn.Linear(self.hidden_size, self.output_size),
155 | ]
156 | self.mlp_head = torch.nn.Sequential(*layers)
157 |
158 | else:
159 | raise ValueError("Prompt encoder type not recognized. Please use one of MLP (recommended) or LSTM.")
160 |
161 | def forward(self, indices):
162 | input_embeds = self.embedding(indices)
163 | if self.encoder_type == PromptEncoderReparameterizationType.LSTM:
164 | output_embeds = self.mlp_head(self.lstm_head(input_embeds)[0])
165 | elif self.encoder_type == PromptEncoderReparameterizationType.MLP:
166 | output_embeds = self.mlp_head(input_embeds)
167 | else:
168 | raise ValueError("Prompt encoder type not recognized. Please use one of MLP (recommended) or LSTM.")
169 |
170 | return output_embeds
171 |
--------------------------------------------------------------------------------
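A quick sanity check for the `PromptEncoder` in /utils/peft/tuners/p_tuning.py above: the sketch builds an MLP-reparameterized encoder and verifies the documented shape contract. The `from utils.peft import ...` path is an assumption about how this vendored package is exposed when running from the repository root.

```py
import torch

# Import path is an assumption: it presumes the vendored package resolves as
# `utils.peft` when running from the repository root.
from utils.peft import PromptEncoder, PromptEncoderConfig

config = PromptEncoderConfig(
    task_type="CAUSAL_LM",
    num_virtual_tokens=20,
    token_dim=768,
    num_transformer_submodules=1,
    encoder_reparameterization_type="MLP",
    encoder_hidden_size=768,
)
encoder = PromptEncoder(config)

# The encoder maps virtual-token indices of shape (batch_size, total_virtual_tokens)
# to embeddings of shape (batch_size, total_virtual_tokens, token_dim).
indices = torch.arange(encoder.total_virtual_tokens).unsqueeze(0)  # batch of 1
print(encoder(indices).shape)  # torch.Size([1, 20, 768])
```

Swapping `encoder_reparameterization_type="LSTM"` keeps the same input/output shapes; only the reparameterization head changes.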
/utils/peft/tuners/prefix_tuning.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2023-present the HuggingFace Inc. team.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 |
17 | from dataclasses import dataclass, field
18 |
19 | import torch
20 |
21 | from ..utils import PeftType, PromptLearningConfig
22 |
23 |
24 | @dataclass
25 | class PrefixTuningConfig(PromptLearningConfig):
26 | """
27 | This is the configuration class to store the configuration of a [`PrefixEncoder`].
28 |
29 | Args:
30 | encoder_hidden_size (`int`): The hidden size of the prompt encoder.
31 | prefix_projection (`bool`): Whether to project the prefix embeddings.
32 | """
33 |
34 | encoder_hidden_size: int = field(
35 | default=None,
36 | metadata={"help": "The hidden size of the encoder"},
37 | )
38 | prefix_projection: bool = field(
39 | default=False,
40 | metadata={"help": "Whether to project the prefix tokens"},
41 | )
42 |
43 | def __post_init__(self):
44 | self.peft_type = PeftType.PREFIX_TUNING
45 |
46 |
47 | # Based on https://github.com/THUDM/P-tuning-v2/blob/main/model/prefix_encoder.py
48 | # with some refactor
49 | class PrefixEncoder(torch.nn.Module):
50 | r"""
51 | The `torch.nn` model to encode the prefix.
52 |
53 | Args:
54 | config ([`PrefixTuningConfig`]): The configuration of the prefix encoder.
55 |
56 | Example:
57 |
58 | ```py
59 | >>> from peft import PrefixEncoder, PrefixTuningConfig
60 |
61 | >>> config = PrefixTuningConfig(
62 | ... peft_type="PREFIX_TUNING",
63 | ... task_type="SEQ_2_SEQ_LM",
64 | ... num_virtual_tokens=20,
65 | ... token_dim=768,
66 | ... num_transformer_submodules=1,
67 | ... num_attention_heads=12,
68 | ... num_layers=12,
69 | ... encoder_hidden_size=768,
70 | ... )
71 | >>> prefix_encoder = PrefixEncoder(config)
72 | ```
73 |
74 | **Attributes**:
75 | - **embedding** (`torch.nn.Embedding`) -- The embedding layer of the prefix encoder.
76 | - **transform** (`torch.nn.Sequential`) -- The two-layer MLP to transform the prefix embeddings if
77 | `prefix_projection` is `True`.
78 | - **prefix_projection** (`bool`) -- Whether to project the prefix embeddings.
79 |
80 | Input shape: (`batch_size`, `num_virtual_tokens`)
81 |
82 | Output shape: (`batch_size`, `num_virtual_tokens`, `2*layers*hidden`)
83 | """
84 |
85 | def __init__(self, config):
86 | super().__init__()
87 | self.prefix_projection = config.prefix_projection
88 | token_dim = config.token_dim
89 | num_layers = config.num_layers
90 | encoder_hidden_size = config.encoder_hidden_size
91 | num_virtual_tokens = config.num_virtual_tokens
92 | if self.prefix_projection and not config.inference_mode:
93 | # Use a two-layer MLP to encode the prefix
94 | self.embedding = torch.nn.Embedding(num_virtual_tokens, token_dim)
95 | self.transform = torch.nn.Sequential(
96 | torch.nn.Linear(token_dim, encoder_hidden_size),
97 | torch.nn.Tanh(),
98 | torch.nn.Linear(encoder_hidden_size, num_layers * 2 * token_dim),
99 | )
100 | else:
101 | self.embedding = torch.nn.Embedding(num_virtual_tokens, num_layers * 2 * token_dim)
102 |
103 | def forward(self, prefix: torch.Tensor):
104 | if self.prefix_projection:
105 | prefix_tokens = self.embedding(prefix)
106 | past_key_values = self.transform(prefix_tokens)
107 | else:
108 | past_key_values = self.embedding(prefix)
109 | return past_key_values
110 |
--------------------------------------------------------------------------------
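A minimal sketch for the `PrefixEncoder` in /utils/peft/tuners/prefix_tuning.py above, showing that with `prefix_projection=True` the output matches the documented (`batch_size`, `num_virtual_tokens`, `2*layers*hidden`) shape. The import path is again an assumption about how the vendored package is exposed.

```py
import torch

# Import path is an assumption (vendored package assumed importable as `utils.peft`).
from utils.peft import PrefixEncoder, PrefixTuningConfig

config = PrefixTuningConfig(
    task_type="CAUSAL_LM",
    num_virtual_tokens=20,
    token_dim=768,
    num_layers=12,
    encoder_hidden_size=768,
    prefix_projection=True,  # route the prefix through the two-layer MLP
)
encoder = PrefixEncoder(config)

prefix = torch.arange(config.num_virtual_tokens).unsqueeze(0)  # (1, 20)
past_key_values = encoder(prefix)
# (batch_size, num_virtual_tokens, 2 * num_layers * token_dim) = (1, 20, 2 * 12 * 768)
print(past_key_values.shape)  # torch.Size([1, 20, 18432])
```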
/utils/peft/tuners/prompt_tuning.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2023-present the HuggingFace Inc. team.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | import enum
17 | import math
18 | from dataclasses import dataclass, field
19 | from typing import Optional, Union
20 |
21 | import torch
22 |
23 | from ..utils import PeftType, PromptLearningConfig
24 |
25 |
26 | class PromptTuningInit(str, enum.Enum):
27 | TEXT = "TEXT"
28 | RANDOM = "RANDOM"
29 |
30 |
31 | @dataclass
32 | class PromptTuningConfig(PromptLearningConfig):
33 | """
34 | This is the configuration class to store the configuration of a [`PromptEmbedding`].
35 |
36 | Args:
37 | prompt_tuning_init (Union[[`PromptTuningInit`], `str`]): The initialization of the prompt embedding.
38 | prompt_tuning_init_text (`str`, *optional*):
39 | The text to initialize the prompt embedding. Only used if `prompt_tuning_init` is `TEXT`.
40 | tokenizer_name_or_path (`str`, *optional*):
41 | The name or path of the tokenizer. Only used if `prompt_tuning_init` is `TEXT`.
42 | """
43 |
44 | prompt_tuning_init: Union[PromptTuningInit, str] = field(
45 | default=PromptTuningInit.RANDOM,
46 | metadata={"help": "How to initialize the prompt tuning parameters"},
47 | )
48 | prompt_tuning_init_text: Optional[str] = field(
49 | default=None,
50 | metadata={
51 | "help": "The text to use for prompt tuning initialization. Only used if prompt_tuning_init is `TEXT`"
52 | },
53 | )
54 | tokenizer_name_or_path: Optional[str] = field(
55 | default=None,
56 | metadata={
57 | "help": "The tokenizer to use for prompt tuning initialization. Only used if prompt_tuning_init is `TEXT`"
58 | },
59 | )
60 |
61 | def __post_init__(self):
62 | self.peft_type = PeftType.PROMPT_TUNING
63 |
64 |
65 | class PromptEmbedding(torch.nn.Module):
66 | """
67 | The model to encode virtual tokens into prompt embeddings.
68 |
69 | Args:
70 | config ([`PromptTuningConfig`]): The configuration of the prompt embedding.
71 | word_embeddings (`torch.nn.Module`): The word embeddings of the base transformer model.
72 |
73 | **Attributes**:
74 | - **embedding** (`torch.nn.Embedding`) -- The embedding layer of the prompt embedding.
75 |
76 | Example:
77 |
78 | ```py
79 | >>> from peft import PromptEmbedding, PromptTuningConfig
80 |
81 | >>> config = PromptTuningConfig(
82 | ... peft_type="PROMPT_TUNING",
83 | ... task_type="SEQ_2_SEQ_LM",
84 | ... num_virtual_tokens=20,
85 | ... token_dim=768,
86 | ... num_transformer_submodules=1,
87 | ... num_attention_heads=12,
88 | ... num_layers=12,
89 | ... prompt_tuning_init="TEXT",
90 | ... prompt_tuning_init_text="Predict if sentiment of this review is positive, negative or neutral",
91 | ... tokenizer_name_or_path="t5-base",
92 | ... )
93 |
94 | >>> # t5_model.shared is the word embeddings of the base model
95 | >>> prompt_embedding = PromptEmbedding(config, t5_model.shared)
96 | ```
97 |
98 | Input Shape: (`batch_size`, `total_virtual_tokens`)
99 |
100 | Output Shape: (`batch_size`, `total_virtual_tokens`, `token_dim`)
101 | """
102 |
103 | def __init__(self, config, word_embeddings):
104 | super().__init__()
105 |
106 | total_virtual_tokens = config.num_virtual_tokens * config.num_transformer_submodules
107 | self.embedding = torch.nn.Embedding(total_virtual_tokens, config.token_dim)
108 | if config.prompt_tuning_init == PromptTuningInit.TEXT:
109 | from transformers import AutoTokenizer
110 |
111 | tokenizer = AutoTokenizer.from_pretrained(config.tokenizer_name_or_path)
112 | init_text = config.prompt_tuning_init_text
113 | init_token_ids = tokenizer(init_text)["input_ids"]
114 |             # Trim or repeat the init token ids until num_text_tokens matches total_virtual_tokens
115 | num_text_tokens = len(init_token_ids)
116 | if num_text_tokens > total_virtual_tokens:
117 | init_token_ids = init_token_ids[:total_virtual_tokens]
118 | elif num_text_tokens < total_virtual_tokens:
119 | num_reps = math.ceil(total_virtual_tokens / num_text_tokens)
120 | init_token_ids = init_token_ids * num_reps
121 | init_token_ids = init_token_ids[:total_virtual_tokens]
122 |
123 | word_embedding_weights = word_embeddings(torch.LongTensor(init_token_ids)).detach().clone()
124 | word_embedding_weights = word_embedding_weights.to(torch.float32)
125 | self.embedding.weight = torch.nn.Parameter(word_embedding_weights)
126 |
127 | def forward(self, indices):
128 | # Just get embeddings
129 | prompt_embeddings = self.embedding(indices)
130 | return prompt_embeddings
131 |
--------------------------------------------------------------------------------
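The TEXT-initialization branch of `PromptEmbedding` above trims or tiles the tokenized init text so that there is exactly one token id per virtual token. A minimal, self-contained sketch of that step, using hypothetical token ids instead of a real tokenizer:

```py
import math

# Hypothetical token ids for the init text; in PromptEmbedding these come from
# AutoTokenizer.from_pretrained(config.tokenizer_name_or_path)(config.prompt_tuning_init_text).
init_token_ids = [37, 112, 5, 908, 19, 1465, 58]  # 7 ids
total_virtual_tokens = 20

num_text_tokens = len(init_token_ids)
if num_text_tokens > total_virtual_tokens:
    # Too many text tokens: keep only the first total_virtual_tokens ids.
    init_token_ids = init_token_ids[:total_virtual_tokens]
elif num_text_tokens < total_virtual_tokens:
    # Too few: tile the ids, then trim the overhang.
    num_reps = math.ceil(total_virtual_tokens / num_text_tokens)  # ceil(20 / 7) = 3
    init_token_ids = (init_token_ids * num_reps)[:total_virtual_tokens]

print(len(init_token_ids))  # 20 -- one id per virtual token, used to seed the embedding weights
```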
/utils/peft/utils/__init__.py:
--------------------------------------------------------------------------------
1 | # flake8: noqa
2 | # There's no way to ignore "F401 '...' imported but unused" warnings in this
3 | # module, but to preserve other warnings. So, don't check this module at all
4 |
5 | # coding=utf-8
6 | # Copyright 2023-present the HuggingFace Inc. team.
7 | #
8 | # Licensed under the Apache License, Version 2.0 (the "License");
9 | # you may not use this file except in compliance with the License.
10 | # You may obtain a copy of the License at
11 | #
12 | # http://www.apache.org/licenses/LICENSE-2.0
13 | #
14 | # Unless required by applicable law or agreed to in writing, software
15 | # distributed under the License is distributed on an "AS IS" BASIS,
16 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
17 | # See the License for the specific language governing permissions and
18 | # limitations under the License.
19 |
20 | from .config import PeftConfig, PeftType, PromptLearningConfig, TaskType
21 | from .other import (
22 | TRANSFORMERS_MODELS_TO_PREFIX_TUNING_POSTPROCESS_MAPPING,
23 | TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING,
24 | TRANSFORMERS_MODELS_TO_ADALORA_TARGET_MODULES_MAPPING,
25 | CONFIG_NAME,
26 | WEIGHTS_NAME,
27 | _set_trainable,
28 | bloom_model_postprocess_past_key_value,
29 | prepare_model_for_int8_training,
30 | shift_tokens_right,
31 | transpose,
32 | _get_submodules,
33 | _set_adapter,
34 | _freeze_adapter,
35 | ModulesToSaveWrapper,
36 | )
37 | from .save_and_load import get_peft_model_state_dict, set_peft_model_state_dict
38 |
--------------------------------------------------------------------------------
/utils/peft/utils/config.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2023-present the HuggingFace Inc. team.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | import enum
16 | import json
17 | import os
18 | from dataclasses import asdict, dataclass, field
19 | from typing import Optional, Union
20 |
21 | from huggingface_hub import hf_hub_download
22 | from transformers.utils import PushToHubMixin
23 |
24 | from .other import CONFIG_NAME
25 |
26 |
27 | class PeftType(str, enum.Enum):
28 | PROMPT_TUNING = "PROMPT_TUNING"
29 | P_TUNING = "P_TUNING"
30 | PREFIX_TUNING = "PREFIX_TUNING"
31 | LORA = "LORA"
32 | ADALORA = "ADALORA"
33 |
34 |
35 | class TaskType(str, enum.Enum):
36 | SEQ_CLS = "SEQ_CLS"
37 | SEQ_2_SEQ_LM = "SEQ_2_SEQ_LM"
38 | CAUSAL_LM = "CAUSAL_LM"
39 | TOKEN_CLS = "TOKEN_CLS"
40 |
41 |
42 | @dataclass
43 | class PeftConfigMixin(PushToHubMixin):
44 | r"""
45 | This is the base configuration class for PEFT adapter models. It contains all the methods that are common to all
46 | PEFT adapter models. This class inherits from [`~transformers.utils.PushToHubMixin`] which contains the methods to
47 | push your model to the Hub. The method `save_pretrained` will save the configuration of your adapter model in a
48 | directory. The method `from_pretrained` will load the configuration of your adapter model from a directory.
49 |
50 | Args:
51 | peft_type (Union[[`~peft.utils.config.PeftType`], `str`]): The type of Peft method to use.
52 | """
53 | peft_type: Optional[PeftType] = field(default=None, metadata={"help": "The type of PEFT model."})
54 |
55 | @property
56 | def __dict__(self):
57 | return asdict(self)
58 |
59 | def to_dict(self):
60 | return self.__dict__
61 |
62 | def save_pretrained(self, save_directory, **kwargs):
63 | r"""
64 | This method saves the configuration of your adapter model in a directory.
65 |
66 | Args:
67 | save_directory (`str`):
68 | The directory where the configuration will be saved.
69 | kwargs (additional keyword arguments, *optional*):
70 | Additional keyword arguments passed along to the [`~transformers.utils.PushToHubMixin.push_to_hub`]
71 | method.
72 | """
73 | if os.path.isfile(save_directory):
74 | raise AssertionError(f"Provided path ({save_directory}) should be a directory, not a file")
75 |
76 | os.makedirs(save_directory, exist_ok=True)
77 |
78 | output_dict = self.__dict__
79 | output_path = os.path.join(save_directory, CONFIG_NAME)
80 |
81 | # save it
82 | with open(output_path, "w") as writer:
83 | writer.write(json.dumps(output_dict, indent=2, sort_keys=True))
84 |
85 | @classmethod
86 | def from_pretrained(cls, pretrained_model_name_or_path, subfolder=None, **kwargs):
87 | r"""
88 | This method loads the configuration of your adapter model from a directory.
89 |
90 | Args:
91 | pretrained_model_name_or_path (`str`):
92 | The directory or the Hub repository id where the configuration is saved.
93 | kwargs (additional keyword arguments, *optional*):
94 | Additional keyword arguments passed along to the child class initialization.
95 | """
96 | path = (
97 | os.path.join(pretrained_model_name_or_path, subfolder)
98 | if subfolder is not None
99 | else pretrained_model_name_or_path
100 | )
101 | if os.path.isfile(os.path.join(path, CONFIG_NAME)):
102 | config_file = os.path.join(path, CONFIG_NAME)
103 | else:
104 | try:
105 | config_file = hf_hub_download(pretrained_model_name_or_path, CONFIG_NAME, subfolder=subfolder)
106 | except Exception:
107 | raise ValueError(f"Can't find '{CONFIG_NAME}' at '{pretrained_model_name_or_path}'")
108 |
109 | loaded_attributes = cls.from_json_file(config_file)
110 |
111 | config = cls(**kwargs)
112 |
113 | for key, value in loaded_attributes.items():
114 | if hasattr(config, key):
115 | setattr(config, key, value)
116 |
117 | return config
118 |
119 | @classmethod
120 | def from_json_file(cls, path_json_file, **kwargs):
121 | r"""
122 | Loads a configuration file from a json file.
123 |
124 | Args:
125 | path_json_file (`str`):
126 | The path to the json file.
127 | """
128 | with open(path_json_file, "r") as file:
129 | json_object = json.load(file)
130 |
131 | return json_object
132 |
133 |
134 | @dataclass
135 | class PeftConfig(PeftConfigMixin):
136 | """
137 | This is the base configuration class to store the configuration of a [`PeftModel`].
138 |
139 | Args:
140 | peft_type (Union[[`~peft.utils.config.PeftType`], `str`]): The type of Peft method to use.
141 | task_type (Union[[`~peft.utils.config.TaskType`], `str`]): The type of task to perform.
142 | inference_mode (`bool`, defaults to `False`): Whether to use the Peft model in inference mode.
143 | """
144 |
145 | base_model_name_or_path: str = field(default=None, metadata={"help": "The name of the base model to use."})
146 | peft_type: Union[str, PeftType] = field(default=None, metadata={"help": "Peft type"})
147 | task_type: Union[str, TaskType] = field(default=None, metadata={"help": "Task type"})
148 | inference_mode: bool = field(default=False, metadata={"help": "Whether to use inference mode"})
149 |
150 |
151 | @dataclass
152 | class PromptLearningConfig(PeftConfig):
153 | """
154 | This is the base configuration class to store the configuration of [`PrefixTuning`], [`PromptEncoder`], or
155 | [`PromptTuning`].
156 |
157 | Args:
158 | num_virtual_tokens (`int`): The number of virtual tokens to use.
159 | token_dim (`int`): The hidden embedding dimension of the base transformer model.
160 | num_transformer_submodules (`int`): The number of transformer submodules in the base transformer model.
161 | num_attention_heads (`int`): The number of attention heads in the base transformer model.
162 | num_layers (`int`): The number of layers in the base transformer model.
163 | """
164 |
165 | num_virtual_tokens: int = field(default=None, metadata={"help": "Number of virtual tokens"})
166 | token_dim: int = field(
167 | default=None, metadata={"help": "The hidden embedding dimension of the base transformer model"}
168 | )
169 | num_transformer_submodules: Optional[int] = field(
170 | default=None, metadata={"help": "Number of transformer submodules"}
171 | )
172 | num_attention_heads: Optional[int] = field(default=None, metadata={"help": "Number of attention heads"})
173 | num_layers: Optional[int] = field(default=None, metadata={"help": "Number of transformer layers"})
174 |
--------------------------------------------------------------------------------
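`PeftConfigMixin.save_pretrained` / `from_pretrained` above persist a config as `adapter_config.json` and rebuild it from that file. A small round-trip sketch, assuming the vendored package is importable as `utils.peft` (the import path and the choice of `PrefixTuningConfig` are illustrative):

```py
import tempfile

# Import path is an assumption (vendored package assumed importable as `utils.peft`).
from utils.peft import PrefixTuningConfig

config = PrefixTuningConfig(
    task_type="CAUSAL_LM",
    num_virtual_tokens=20,
    token_dim=768,
    num_layers=12,
    encoder_hidden_size=768,
)

with tempfile.TemporaryDirectory() as tmp_dir:
    # Serializes the dataclass fields to <tmp_dir>/adapter_config.json (CONFIG_NAME).
    config.save_pretrained(tmp_dir)
    # Reads the JSON back and copies every known attribute onto a fresh config.
    reloaded = PrefixTuningConfig.from_pretrained(tmp_dir)

assert reloaded.num_virtual_tokens == 20
assert reloaded.peft_type == "PREFIX_TUNING"  # set by __post_init__ and round-tripped
```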
/utils/peft/utils/other.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2023-present the HuggingFace Inc. team.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | import copy
17 | from functools import reduce
18 |
19 | import torch
20 |
21 |
22 | # needed for prefix-tuning of bloom model
23 | def bloom_model_postprocess_past_key_value(past_key_values):
24 | past_key_values = torch.cat(past_key_values)
25 | total_layers, batch_size, num_attention_heads, num_virtual_tokens, head_dim = past_key_values.shape
26 | keys = past_key_values[: total_layers // 2]
27 | keys = keys.transpose(2, 3).reshape(
28 | total_layers // 2, batch_size * num_attention_heads, head_dim, num_virtual_tokens
29 | )
30 | values = past_key_values[total_layers // 2 :]
31 | values = values.reshape(total_layers // 2, batch_size * num_attention_heads, num_virtual_tokens, head_dim)
32 |
33 | return tuple(zip(keys, values))
34 |
35 |
36 | def prepare_model_for_int8_training(
37 | model, output_embedding_layer_name="lm_head", use_gradient_checkpointing=True, layer_norm_names=["layer_norm"]
38 | ):
39 | r"""
40 |     This method wraps the entire protocol for preparing a model before training. This includes:
41 |         1- casting the layer norm in fp32, 2- making the output embedding layer require grads, 3- upcasting the
42 |         lm head to fp32
43 |
44 | Args:
45 |         model (`transformers.PreTrainedModel`):
46 | The loaded model from `transformers`
47 | """
48 | loaded_in_8bit = getattr(model, "is_loaded_in_8bit", False)
49 |
50 | for name, param in model.named_parameters():
51 | # freeze base model's layers
52 | param.requires_grad = False
53 |
54 | if loaded_in_8bit:
55 | # cast layer norm in fp32 for stability for 8bit models
56 | if param.ndim == 1 and any(layer_norm_name in name for layer_norm_name in layer_norm_names):
57 | param.data = param.data.to(torch.float32)
58 |
59 | if loaded_in_8bit and use_gradient_checkpointing:
60 | # For backward compatibility
61 | if hasattr(model, "enable_input_require_grads"):
62 | model.enable_input_require_grads()
63 | else:
64 |
65 | def make_inputs_require_grad(module, input, output):
66 | output.requires_grad_(True)
67 |
68 | model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)
69 |
70 | # enable gradient checkpointing for memory efficiency
71 | model.gradient_checkpointing_enable()
72 |
73 | if hasattr(model, output_embedding_layer_name):
74 | output_embedding_layer = getattr(model, output_embedding_layer_name)
75 | input_dtype = output_embedding_layer.weight.dtype
76 |
77 | class CastOutputToFloat(torch.nn.Sequential):
78 | r"""
79 |             Manually cast the input to the expected dtype of the lm_head, since a final layer norm is sometimes
80 |             cast in fp32
81 |
82 | """
83 |
84 | def forward(self, x):
85 | return super().forward(x.to(input_dtype)).to(torch.float32)
86 |
87 | setattr(model, output_embedding_layer_name, CastOutputToFloat(output_embedding_layer))
88 |
89 | return model
90 |
91 |
92 | # copied from transformers.models.bart.modeling_bart
93 | def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int, decoder_start_token_id: int):
94 | """
95 | Shift input ids one token to the right.
96 |
97 | Args:
98 | input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): input ids
99 | pad_token_id (`int`): The id of the `padding` token.
100 | decoder_start_token_id (`int`): The id of the `start` token.
101 | """
102 | shifted_input_ids = input_ids.new_zeros(input_ids.shape)
103 | shifted_input_ids[:, 1:] = input_ids[:, :-1].clone()
104 | shifted_input_ids[:, 0] = decoder_start_token_id
105 |
106 | if pad_token_id is None:
107 | raise ValueError("self.model.config.pad_token_id has to be defined.")
108 | # replace possible -100 values in labels by `pad_token_id`
109 | shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id)
110 |
111 | return shifted_input_ids
112 |
113 |
114 | class ModulesToSaveWrapper(torch.nn.Module):
115 | def __init__(self, module_to_save, adapter_name):
116 | super().__init__()
117 | self.original_module = module_to_save
118 | self.modules_to_save = torch.nn.ModuleDict({})
119 | self.update(adapter_name)
120 | self.active_adapter = adapter_name
121 |
122 | def update(self, adapter_name):
123 | self.modules_to_save.update(torch.nn.ModuleDict({adapter_name: copy.deepcopy(self.original_module)}))
124 |
125 | def forward(self, *args, **kwargs):
126 | if self.active_adapter not in self.modules_to_save:
127 | return self.original_module(*args, **kwargs)
128 | return self.modules_to_save[self.active_adapter](*args, **kwargs)
129 |
130 | def get_module_by_name(module, access_string):
131 | names = access_string.split(sep='.')
132 | return reduce(getattr, names, module)
133 |
134 |
135 | def _get_submodules(model, key):
136 | parent = get_module_by_name(model, ".".join(key.split(".")[:-1]))
137 | target_name = key.split(".")[-1]
138 | target = get_module_by_name(model, key)
139 | return parent, target, target_name
140 |
141 |
142 | def _freeze_adapter(model, adapter_name):
143 | for n, p in model.named_parameters():
144 | if adapter_name in n:
145 | p.requires_grad = False
146 |
147 |
148 | def _set_trainable(model, adapter_name):
149 | key_list = [key for key, _ in model.named_modules()]
150 | for key in key_list:
151 | target_module_found = any(key.endswith(target_key) for target_key in model.modules_to_save)
152 | if target_module_found:
153 | parent, target, target_name = _get_submodules(model, key)
154 | if isinstance(target, ModulesToSaveWrapper):
155 | target.update(adapter_name)
156 | else:
157 | for param in target.parameters():
158 | param.requires_grad = True
159 | setattr(parent, target_name, ModulesToSaveWrapper(target, adapter_name))
160 |
161 |
162 | def _set_adapter(model, adapter_name):
163 | for module in model.modules():
164 | if isinstance(module, ModulesToSaveWrapper):
165 | module.active_adapter = adapter_name
166 |
167 |
168 | def fsdp_auto_wrap_policy(model):
169 | import functools
170 | import os
171 |
172 | from accelerate import FullyShardedDataParallelPlugin
173 | from torch.distributed.fsdp.wrap import _or_policy, lambda_auto_wrap_policy, transformer_auto_wrap_policy
174 |
175 | from ..tuners import PrefixEncoder, PromptEmbedding, PromptEncoder
176 |
177 | def lambda_policy_fn(module):
178 | if (
179 | len(list(module.named_children())) == 0
180 | and getattr(module, "weight", None) is not None
181 | and module.weight.requires_grad
182 | ):
183 | return True
184 | return False
185 |
186 | lambda_policy = functools.partial(lambda_auto_wrap_policy, lambda_fn=lambda_policy_fn)
187 | transformer_wrap_policy = functools.partial(
188 | transformer_auto_wrap_policy,
189 | transformer_layer_cls=(
190 | PrefixEncoder,
191 | PromptEncoder,
192 | PromptEmbedding,
193 | FullyShardedDataParallelPlugin.get_module_class_from_name(
194 | model, os.environ.get("FSDP_TRANSFORMER_CLS_TO_WRAP", "")
195 | ),
196 | ),
197 | )
198 |
199 | auto_wrap_policy = functools.partial(_or_policy, policies=[lambda_policy, transformer_wrap_policy])
200 | return auto_wrap_policy
201 |
202 |
203 | def transpose(weight, fan_in_fan_out):
204 | return weight.T if fan_in_fan_out else weight
205 |
206 |
207 | TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING = {
208 | "t5": ["q", "v"],
209 | "mt5": ["q", "v"],
210 | "bart": ["q_proj", "v_proj"],
211 | "gpt2": ["c_attn"],
212 | "bloom": ["query_key_value"],
213 | "blip-2": ["q", "v", "q_proj", "v_proj"],
214 | "opt": ["q_proj", "v_proj"],
215 | "gptj": ["q_proj", "v_proj"],
216 | "gpt_neox": ["query_key_value"],
217 | "gpt_neo": ["q_proj", "v_proj"],
218 | "bert": ["query", "value"],
219 | "roberta": ["query", "value"],
220 | "xlm-roberta": ["query", "value"],
221 | "electra": ["query", "value"],
222 | "deberta-v2": ["query_proj", "value_proj"],
223 | "deberta": ["in_proj"],
224 | "layoutlm": ["query", "value"],
225 | "llama": ["q_proj", "v_proj"],
226 | "chatglm": ["query_key_value"],
227 | }
228 |
229 | TRANSFORMERS_MODELS_TO_ADALORA_TARGET_MODULES_MAPPING = {
230 | "t5": ["q", "k", "v", "o", "wi", "wo"],
231 | "mt5": ["q", "k", "v", "o", "wi_0", "wi_1", "wo"],
232 | "bart": ["q_proj", "k_proj", "v_proj", "out_proj", "fc1", "fc2"],
233 | # "gpt2": ["c_attn"],
234 | # "bloom": ["query_key_value"],
235 | "opt": ["q_proj", "k_proj", "v_proj", "out_proj", "fc1", "fc2"],
236 | # "gptj": ["q_proj", "v_proj"],
237 | # "gpt_neox": ["query_key_value"],
238 | # "gpt_neo": ["q_proj", "v_proj"],
239 | # "bert": ["query", "value"],
240 | "roberta": ["query", "key", "value", "dense"],
241 | # "xlm-roberta": ["query", "value"],
242 | # "electra": ["query", "value"],
243 | "deberta-v2": ["query_proj", "key_proj", "value_proj", "dense"],
244 | # "deberta": ["in_proj"],
245 | # "layoutlm": ["query", "value"],
246 | }
247 |
248 | TRANSFORMERS_MODELS_TO_PREFIX_TUNING_POSTPROCESS_MAPPING = {
249 | "bloom": bloom_model_postprocess_past_key_value,
250 | }
251 |
252 | WEIGHTS_NAME = "adapter_model.bin"
253 | CONFIG_NAME = "adapter_config.json"
254 |
--------------------------------------------------------------------------------
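To make `shift_tokens_right` in /utils/peft/utils/other.py above concrete, here is a small sketch on a toy label batch. The module path matches the repo layout but assumes the repository root is on `sys.path`; the label values are illustrative.

```py
import torch

# Assumes the repository root is on sys.path so the vendored module resolves.
from utils.peft.utils.other import shift_tokens_right

# Labels for a seq2seq batch; -100 marks positions that the loss ignores.
labels = torch.tensor([[101, 7, -100, -100, -100]])
decoder_input_ids = shift_tokens_right(labels, pad_token_id=0, decoder_start_token_id=2)

# Each token moves one position to the right, the decoder start token is prepended,
# and any surviving -100 values are replaced by the pad token id.
print(decoder_input_ids)  # tensor([[  2, 101,   7,   0,   0]])
```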
/utils/peft/utils/save_and_load.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2023-present the HuggingFace Inc. team.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | from .config import PeftType, PromptLearningConfig
17 |
18 |
19 | def get_peft_model_state_dict(model, state_dict=None, adapter_name="default"):
20 | """
21 | Get the state dict of the Peft model.
22 |
23 | Args:
24 |         model ([`PeftModel`]): The Peft model. When using torch.nn.DistributedDataParallel, DeepSpeed or FSDP,
25 |             the model should be the underlying (unwrapped) model, i.e. model.module.
26 |         state_dict (`dict`, *optional*, defaults to `None`):
27 |             A state dict to filter. If not provided, the state dict of `model`
28 |             (i.e. `model.state_dict()`) is used.
29 | """
30 | config = model.peft_config[adapter_name]
31 | if state_dict is None:
32 | state_dict = model.state_dict()
33 | if config.peft_type in (PeftType.LORA, PeftType.ADALORA):
34 | # to_return = lora_state_dict(model, bias=model.peft_config.bias)
35 | # adapted from `https://github.com/microsoft/LoRA/blob/main/loralib/utils.py`
36 | # to be used directly with the state dict which is necessary when using DeepSpeed or FSDP
37 | bias = config.bias
38 | if bias == "none":
39 | to_return = {k: state_dict[k] for k in state_dict if "lora_" in k}
40 | elif bias == "all":
41 | to_return = {k: state_dict[k] for k in state_dict if "lora_" in k or "bias" in k}
42 | elif bias == "lora_only":
43 | to_return = {}
44 | for k in state_dict:
45 | if "lora_" in k:
46 | to_return[k] = state_dict[k]
47 | bias_name = k.split("lora_")[0] + "bias"
48 | if bias_name in state_dict:
49 | to_return[bias_name] = state_dict[bias_name]
50 | else:
51 | raise NotImplementedError
52 | to_return = {k: v for k, v in to_return.items() if (("lora_" in k and adapter_name in k) or ("bias" in k))}
53 | if config.peft_type == PeftType.ADALORA:
54 | rank_pattern = config.rank_pattern
55 | if rank_pattern is not None:
56 | rank_pattern = {k.replace(f".{adapter_name}", ""): v for k, v in rank_pattern.items()}
57 | config.rank_pattern = rank_pattern
58 | to_return = model.resize_state_dict_by_rank_pattern(rank_pattern, to_return, adapter_name)
59 | elif isinstance(config, PromptLearningConfig):
60 | to_return = {}
61 | if config.inference_mode:
62 | prompt_embeddings = model.prompt_encoder[adapter_name].embedding.weight
63 | else:
64 | prompt_embeddings = model.get_prompt_embedding_to_save(adapter_name)
65 | to_return["prompt_embeddings"] = prompt_embeddings
66 | else:
67 | raise NotImplementedError
68 | if model.modules_to_save is not None:
69 | for key, value in state_dict.items():
70 | if any(f"{module_name}.modules_to_save.{adapter_name}" in key for module_name in model.modules_to_save):
71 | to_return[key.replace("modules_to_save.", "")] = value
72 |
73 | to_return = {k.replace(f".{adapter_name}", ""): v for k, v in to_return.items()}
74 | return to_return
75 |
76 |
77 | def set_peft_model_state_dict(model, peft_model_state_dict, adapter_name="default"):
78 | """
79 | Set the state dict of the Peft model.
80 |
81 | Args:
82 | model ([`PeftModel`]): The Peft model.
83 | peft_model_state_dict (`dict`): The state dict of the Peft model.
84 | """
85 | config = model.peft_config[adapter_name]
86 | state_dict = {}
87 | if model.modules_to_save is not None:
88 | for key, value in peft_model_state_dict.items():
89 | if any(module_name in key for module_name in model.modules_to_save):
90 | for module_name in model.modules_to_save:
91 | if module_name in key:
92 | key = key.replace(module_name, f"{module_name}.modules_to_save.{adapter_name}")
93 | break
94 | state_dict[key] = value
95 | else:
96 | state_dict = peft_model_state_dict
97 |
98 |     # print("config.peft_type: {}".format(config.peft_type))
99 | if config.peft_type in (PeftType.LORA, PeftType.ADALORA):
100 | peft_model_state_dict = {}
101 | for k, v in state_dict.items():
102 | if "lora_" in k:
103 | suffix = k.split("lora_")[1]
104 | if "." in suffix:
105 | suffix_to_replace = ".".join(suffix.split(".")[1:])
106 | k = k.replace(suffix_to_replace, f"{adapter_name}.{suffix_to_replace}")
107 | else:
108 | k = f"{k}.{adapter_name}"
109 | peft_model_state_dict[k] = v
110 | else:
111 | peft_model_state_dict[k] = v
112 | if config.peft_type == PeftType.ADALORA:
113 | rank_pattern = config.rank_pattern
114 | if rank_pattern is not None:
115 | model.resize_modules_by_rank_pattern(rank_pattern, adapter_name)
116 | elif isinstance(config, PromptLearningConfig):
117 | peft_model_state_dict = state_dict
118 | else:
119 | raise NotImplementedError
120 |
121 | model.load_state_dict(peft_model_state_dict, strict=False)
122 |
123 | if isinstance(config, PromptLearningConfig):
124 | model.prompt_encoder[adapter_name].embedding.load_state_dict(
125 | {"weight": peft_model_state_dict["prompt_embeddings"]}, strict=True
126 | )
127 |
--------------------------------------------------------------------------------
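To illustrate the LoRA branch of `get_peft_model_state_dict` above, the sketch below applies the same bias-filtering rules to a hypothetical state dict. The key names and the `filter_lora_keys` helper are illustrative only; they mirror, not replace, the function's logic.

```py
# Hypothetical LoRA state-dict keys for the "default" adapter; none of these names
# come from a real checkpoint.
state_dict = {
    "layers.0.q_proj.weight": "W",
    "layers.0.q_proj.bias": "b",
    "layers.0.q_proj.lora_A.default.weight": "A",
    "layers.0.q_proj.lora_B.default.weight": "B",
}
adapter_name = "default"


def filter_lora_keys(state_dict, bias):
    """Mirror of the bias handling in get_peft_model_state_dict for LORA/ADALORA."""
    if bias == "none":
        to_return = {k: v for k, v in state_dict.items() if "lora_" in k}
    elif bias == "all":
        to_return = {k: v for k, v in state_dict.items() if "lora_" in k or "bias" in k}
    elif bias == "lora_only":
        to_return = {}
        for k, v in state_dict.items():
            if "lora_" in k:
                to_return[k] = v
                bias_name = k.split("lora_")[0] + "bias"  # e.g. layers.0.q_proj.bias
                if bias_name in state_dict:
                    to_return[bias_name] = state_dict[bias_name]
    else:
        raise NotImplementedError(bias)
    # Keep only this adapter's lora_ keys (plus biases), then drop the adapter suffix.
    to_return = {k: v for k, v in to_return.items() if ("lora_" in k and adapter_name in k) or "bias" in k}
    return {k.replace(f".{adapter_name}", ""): v for k, v in to_return.items()}


print(sorted(filter_lora_keys(state_dict, "none")))
# ['layers.0.q_proj.lora_A.weight', 'layers.0.q_proj.lora_B.weight']
print(sorted(filter_lora_keys(state_dict, "lora_only")))
# ['layers.0.q_proj.bias', 'layers.0.q_proj.lora_A.weight', 'layers.0.q_proj.lora_B.weight']
```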