├── .gitignore ├── LICENSE ├── README.md ├── SVDLLM.py ├── component ├── svd_llama.py ├── svd_llama_kvcache.py ├── svd_mistral.py └── svd_opt.py ├── compress_llama.sh ├── evaluater.py ├── figures ├── framework_v1.jpg ├── framework_v2.jpg └── logo.png ├── gptq ├── gptq.py └── quant.py ├── quant_llama.py ├── requirements.txt ├── svdllm_gptq.sh └── utils ├── LoRA.py ├── Prompter.py ├── data_utils.py ├── model_utils.py └── peft ├── __init__.py ├── import_utils.py ├── mapping.py ├── peft_model.py ├── tuners ├── __init__.py ├── adalora.py ├── lora.py ├── p_tuning.py ├── prefix_tuning.py └── prompt_tuning.py └── utils ├── __init__.py ├── config.py ├── other.py └── save_and_load.py /.gitignore: -------------------------------------------------------------------------------- 1 | .vscode/* 2 | cache/* 3 | **/__pycache__ 4 | utils/c4-* -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 
47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |

# SVD-LLM: Singular Value Decomposition for Large Language Model Compression

License: Apache 2.0 | PyTorch >= 1.7.1 | transformers == 4.35.2 | Supported models: LLaMA, Llama-2, Mistral, OPT

## ✨Roadmap
We are working on the following tasks; please stay tuned!

- [X] Release the code of SVD-LLM.
- [ ] Release the code of SVD-LLM V2.
- [ ] Upgrade the transformers package to the latest version.
- [ ] Update the efficiency evaluation code.
- [ ] Optimize the compression of GQA-based LLMs.


## Introduction

> **[SVD-LLM: Truncation-aware Singular Value Decomposition for Large Language Model Compression](https://openreview.net/forum?id=LNYIUouhdt)**
>
> *Xin Wang¹, Yu Zheng², Zhongwei Wan¹, Mi Zhang¹*
> *¹The Ohio State University, ²Michigan State University*
>
> International Conference on Learning Representations (ICLR) 2025


> **[SVD-LLM V2: Optimizing Singular Value Truncation for Large Language Model Compression](https://arxiv.org/abs/2503.12340)**
>
> *Xin Wang, Samiul Alam, Zhongwei Wan, Hui Shen, Mi Zhang*
> *The Ohio State University*
>
> Annual Conference of the Nations of the Americas Chapter of the Association for Computational Linguistics (NAACL) 2025


## Quick Start

### Installation
Please keep the transformers package at exactly version 4.35.2, since the SVD-compressed models use a slightly modified model structure (defined in the `component/` folder).
Create and activate a conda environment with Python 3.9 (newer versions break some dependencies):
```
conda create -n compress python=3.9
conda activate compress
```
Clone and navigate to the repository:
```
git clone https://github.com/AIoT-MLSys-Lab/SVD-LLM.git
```
Install the requirements:
```
pip install -r requirements.txt
```

### Quick Example
```
bash compress_llama.sh
```
This script compresses the LLaMA-7B model at a 20% compression ratio and automatically runs the evaluation code, reporting both the perplexity and the efficiency of the compressed model.


## Step-by-Step Instructions of SVD-LLM

### 1. Truncation-Aware Data Whitening + SVD Compression
At a low compression ratio (recommended ratio <= 0.3), we first run data whitening on the LLM and save the compressed weights along with the whitening information (a minimal sketch of this step is given after Step 3 below).
```
python SVDLLM.py \
    --step 1 \
    --ratio COMPRESSION_RATIO \
    --model HUGGINGFACE_MODEL_REPO \
    --whitening_nsamples WHITENING_SAMPLE_NUMBER \
    --dataset WHITENING_DATASET \
    --seed SAMPLING_SEED \
    --model_seq_len MODEL_SEQ_LEN \
    --save_path WHITENING_INFO_SAVING_PATH
```


### 2. Parameter Update with Sequential Low-rank Approximation
We update the compressed weight matrices sequentially with LoRA fine-tuning: first U, then V.
```
python LoRA.py \
    --prune_model COMPRESSED_MODEL_PATH \
    --data_path yahma/alpaca-cleaned \
    --output_dir LORA_OUTPUT_PATH \
    --lora_r 8 \
    --num_epochs 2 \
    --learning_rate 1e-4 \
    --batch_size 64
```

### 3. SVD-LLM + GPTQ
SVD-LLM can also be integrated with quantization methods to achieve better compression. Here is an example of integrating SVD-LLM (20% compression ratio) with 4-bit GPTQ to compress LLaMA-7B:
```
bash svdllm_gptq.sh
```

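As referenced in Step 1, the snippet below is a minimal, self-contained sketch of the truncation-aware whitening and SVD truncation applied to a single linear layer. It is illustrative only: the function name `whiten_and_factorize`, the diagonal jitter, and the toy shapes are assumptions, not the actual `SVDLLM.py` API. The rank choice, however, mirrors the `low_rank` computation in the `component/` modules, and the returned factor pair plays the role of the `*_v_proj` / `*_u_proj` layers used there.
```
import torch


def whiten_and_factorize(W: torch.Tensor, calib_inputs: torch.Tensor, keep_ratio: float):
    """Truncation-aware whitening + SVD for one linear layer (illustrative sketch).

    W            : weight of shape (out_features, in_features)
    calib_inputs : calibration activations of shape (n_tokens, in_features)
    keep_ratio   : fraction of the original parameter count to keep (1.0 = no truncation)
    """
    out_f, in_f = W.shape
    # Rank chosen so the two factors together hold `keep_ratio` of the original
    # parameters: rank * (out_f + in_f) ~= keep_ratio * out_f * in_f.
    # This mirrors the `low_rank` computation in component/svd_llama.py.
    rank = int(keep_ratio * out_f * in_f / (out_f + in_f))

    # 1. Whitening matrix S: Cholesky factor of the calibration Gram matrix X^T X.
    gram = calib_inputs.T @ calib_inputs              # (in_f, in_f)
    gram = gram + 1e-6 * torch.eye(in_f)              # small jitter for stability (assumption)
    S = torch.linalg.cholesky(gram)                   # lower-triangular, gram = S @ S.T

    # 2. SVD of the whitened weight; truncating W @ S (rather than W) ties the
    #    compression loss directly to the discarded singular values.
    U, sigma, Vh = torch.linalg.svd(W @ S, full_matrices=False)
    U, sigma, Vh = U[:, :rank], sigma[:rank], Vh[:rank, :]

    # 3. Split the truncated SVD into the two small linear layers used by the
    #    SVD_* modules: y = W_u @ (W_v @ x), with W_u (out_f, rank) -> `*_u_proj`
    #    and W_v (rank, in_f) -> `*_v_proj`.
    sqrt_sigma = torch.diag(sigma.sqrt())
    W_u = U @ sqrt_sigma
    W_v = sqrt_sigma @ Vh @ torch.linalg.inv(S)
    return W_u, W_v


if __name__ == "__main__":
    torch.manual_seed(0)
    W = torch.randn(256, 64)     # toy layer: 64 -> 256
    X = torch.randn(1024, 64)    # toy calibration activations
    W_u, W_v = whiten_and_factorize(W, X, keep_ratio=0.8)
    x = torch.randn(8, 64)
    rel_err = (x @ W.T - x @ W_v.T @ W_u.T).norm() / (x @ W.T).norm()
    print(f"relative output error: {rel_err.item():.4f}")
```
Whitening with the Cholesky factor of the calibration Gram matrix is what makes plain singular-value truncation a good proxy for the layer's output error; without it, truncating the SVD of `W` alone ignores how the calibration inputs actually excite each direction.

### 4. 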
Evaluation 124 | - Perplexity Evaluation: 125 | ``` 126 | python SVDLLM.py \ 127 | --step 4 \ 128 | --model_path COMPRESSD_MODEL_SAVING_PATH \ 129 | ``` 130 | We use the same c4 dataset as in [SparseGPT](https://github.com/IST-DASLab/sparsegpt). Since the original dowload link is invalid, please directly download it from this [link](https://drive.google.com/drive/folders/123Id1MkZVsKySGy_sMO4RgiJKrtPcvUp?usp=sharing) and add the two json files under the `utils/.` folder. 131 | - Efficiency Evaluation: 132 | ``` 133 | python SVDLLM.py \ 134 | --step 5 \ 135 | --model_path COMPRESSD_MODEL_SAVING_PATH \ 136 | ``` 137 | ## Citation 138 | If you find this work useful, please cite 139 | ``` 140 | @inproceedings{wang2025svdllm, 141 | title={{SVD}-{LLM}: Truncation-aware Singular Value Decomposition for Large Language Model Compression}, 142 | author={Xin Wang and Yu Zheng and Zhongwei Wan and Mi Zhang}, 143 | booktitle={International Conference on Learning Representations (ICLR)}, 144 | year={2025}, 145 | url={https://openreview.net/forum?id=LNYIUouhdt} 146 | } 147 | ``` 148 | -------------------------------------------------------------------------------- /component/svd_llama.py: -------------------------------------------------------------------------------- 1 | import math 2 | from typing import Optional, Tuple 3 | 4 | import torch 5 | import torch.utils.checkpoint 6 | from torch import nn 7 | 8 | from transformers.activations import ACT2FN 9 | from transformers.utils import logging 10 | from transformers import LlamaConfig 11 | 12 | logger = logging.get_logger(__name__) 13 | 14 | _CONFIG_FOR_DOC = "LlamaConfig" 15 | 16 | class LlamaRMSNorm(nn.Module): 17 | def __init__(self, hidden_size, eps=1e-6): 18 | """ 19 | LlamaRMSNorm is equivalent to T5LayerNorm 20 | """ 21 | super().__init__() 22 | self.weight = nn.Parameter(torch.ones(hidden_size)) 23 | self.variance_epsilon = eps 24 | 25 | def forward(self, hidden_states): 26 | variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True) 27 | hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) 28 | 29 | # convert into half-precision if necessary 30 | if self.weight.dtype in [torch.float16, torch.bfloat16]: 31 | hidden_states = hidden_states.to(self.weight.dtype) 32 | 33 | return self.weight * hidden_states 34 | 35 | 36 | class LlamaRotaryEmbedding(torch.nn.Module): 37 | def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): 38 | super().__init__() 39 | inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim)) 40 | self.register_buffer("inv_freq", inv_freq) 41 | 42 | # Build here to make `torch.jit.trace` work. 43 | self.max_seq_len_cached = max_position_embeddings 44 | t = torch.arange(self.max_seq_len_cached, device=self.inv_freq.device, dtype=self.inv_freq.dtype) 45 | freqs = torch.einsum("i,j->ij", t, self.inv_freq) 46 | # Different from paper, but it uses a different permutation in order to obtain the same calculation 47 | emb = torch.cat((freqs, freqs), dim=-1) 48 | self.register_buffer("cos_cached", emb.cos()[None, None, :, :], persistent=False) 49 | self.register_buffer("sin_cached", emb.sin()[None, None, :, :], persistent=False) 50 | 51 | def forward(self, x, seq_len=None): 52 | # x: [bs, num_attention_heads, seq_len, head_size] 53 | # This `if` block is unlikely to be run after we build sin/cos in `__init__`. Keep the logic here just in case. 
54 | if seq_len > self.max_seq_len_cached: 55 | self.max_seq_len_cached = seq_len 56 | t = torch.arange(self.max_seq_len_cached, device=x.device, dtype=self.inv_freq.dtype) 57 | freqs = torch.einsum("i,j->ij", t, self.inv_freq) 58 | # Different from paper, but it uses a different permutation in order to obtain the same calculation 59 | emb = torch.cat((freqs, freqs), dim=-1).to(x.device) 60 | self.register_buffer("cos_cached", emb.cos()[None, None, :, :], persistent=False) 61 | self.register_buffer("sin_cached", emb.sin()[None, None, :, :], persistent=False) 62 | return ( 63 | self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype), 64 | self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype), 65 | ) 66 | 67 | 68 | def rotate_half(x): 69 | """Rotates half the hidden dims of the input.""" 70 | x1 = x[..., : x.shape[-1] // 2] 71 | x2 = x[..., x.shape[-1] // 2 :] 72 | return torch.cat((-x2, x1), dim=-1) 73 | 74 | 75 | def apply_rotary_pos_emb(q, k, cos, sin, position_ids): 76 | gather_indices = position_ids[:, None, :, None] # [bs, 1, seq_len, 1] 77 | gather_indices = gather_indices.repeat(1, cos.shape[1], 1, cos.shape[3]) 78 | cos = torch.gather(cos.repeat(gather_indices.shape[0], 1, 1, 1), 2, gather_indices) 79 | sin = torch.gather(sin.repeat(gather_indices.shape[0], 1, 1, 1), 2, gather_indices) 80 | 81 | q_embed = (q * cos) + (rotate_half(q) * sin) 82 | k_embed = (k * cos) + (rotate_half(k) * sin) 83 | return q_embed, k_embed 84 | 85 | 86 | class SVD_LlamaMLP(nn.Module): 87 | def __init__( 88 | self, 89 | hidden_size: int, 90 | intermediate_size: int, 91 | hidden_act: str, 92 | ratio=1 93 | ): 94 | super().__init__() 95 | self.ratio = ratio 96 | low_rank = int(intermediate_size * hidden_size * self.ratio / (intermediate_size + hidden_size)) 97 | self.gate_u_proj = nn.Linear(low_rank, intermediate_size, bias=False) 98 | self.gate_v_proj = nn.Linear(hidden_size, low_rank, bias=False) 99 | 100 | self.down_u_proj = nn.Linear(low_rank, hidden_size, bias=False) 101 | self.down_v_proj = nn.Linear(intermediate_size, low_rank, bias=False) 102 | 103 | self.up_u_proj = nn.Linear(low_rank, intermediate_size, bias=False) 104 | self.up_v_proj = nn.Linear(hidden_size, low_rank, bias=False) 105 | self.act_fn = ACT2FN[hidden_act] 106 | 107 | def forward(self, x): 108 | up = self.up_u_proj(self.up_v_proj(x)) 109 | gate = self.gate_u_proj(self.gate_v_proj(x)) 110 | return self.down_u_proj(self.down_v_proj(self.act_fn(gate) * up)) 111 | 112 | 113 | class SVD_LlamaAttention(nn.Module): 114 | """Multi-headed attention from 'Attention Is All You Need' paper""" 115 | 116 | def __init__(self, config: LlamaConfig, ratio=1): 117 | super().__init__() 118 | self.config = config 119 | self.hidden_size = config.hidden_size 120 | self.num_heads = config.num_attention_heads 121 | self.head_dim = self.hidden_size // self.num_heads 122 | self.max_position_embeddings = config.max_position_embeddings 123 | self.ratio = ratio # 1 means no truncate, just keep normal attn 124 | 125 | if (self.head_dim * self.num_heads) != self.hidden_size: 126 | raise ValueError( 127 | f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}" 128 | f" and `num_heads`: {self.num_heads})." 
129 | ) 130 | low_rank = int(self.hidden_size * self.ratio/2) 131 | self.q_u_proj = nn.Linear(low_rank, self.num_heads * self.head_dim, bias=False) 132 | self.q_v_proj = nn.Linear(self.hidden_size, low_rank, bias=False) 133 | 134 | self.k_u_proj = nn.Linear(low_rank, self.num_heads * self.head_dim, bias=False) 135 | self.k_v_proj = nn.Linear(self.hidden_size, low_rank, bias=False) 136 | 137 | self.v_u_proj = nn.Linear(low_rank, self.num_heads * self.head_dim, bias=False) 138 | self.v_v_proj = nn.Linear(self.hidden_size, low_rank, bias=False) 139 | 140 | self.o_u_proj = nn.Linear(low_rank, self.hidden_size, bias=False) 141 | self.o_v_proj = nn.Linear(self.num_heads * self.head_dim, low_rank, bias=False) 142 | 143 | self.rotary_emb = LlamaRotaryEmbedding(self.head_dim, max_position_embeddings=self.max_position_embeddings) 144 | 145 | def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): 146 | return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() 147 | 148 | def forward( 149 | self, 150 | hidden_states: torch.Tensor, 151 | attention_mask: Optional[torch.Tensor] = None, 152 | position_ids: Optional[torch.LongTensor] = None, 153 | past_key_value: Optional[Tuple[torch.Tensor]] = None, 154 | output_attentions: bool = False, 155 | use_cache: bool = False, 156 | ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: 157 | bsz, q_len, _ = hidden_states.size() 158 | 159 | query_states = self.q_u_proj(self.q_v_proj(hidden_states)).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) 160 | 161 | key_states = self.k_u_proj(self.k_v_proj(hidden_states)).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) 162 | 163 | value_states = self.v_u_proj(self.v_v_proj(hidden_states)).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) 164 | 165 | kv_seq_len = key_states.shape[-2] 166 | if past_key_value is not None: 167 | kv_seq_len += past_key_value[0].shape[-2] 168 | cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) 169 | 170 | query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) 171 | # [bsz, nh, t, hd] 172 | 173 | if past_key_value is not None: 174 | # reuse k, v, self_attention 175 | key_states = torch.cat([past_key_value[0], key_states], dim=2) 176 | value_states = torch.cat([past_key_value[1], value_states], dim=2) 177 | 178 | past_key_value = (key_states, value_states) if use_cache else None 179 | 180 | attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) 181 | 182 | if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len): 183 | raise ValueError( 184 | f"Attention weights should be of size {(bsz * self.num_heads, q_len, kv_seq_len)}, but is" 185 | f" {attn_weights.size()}" 186 | ) 187 | 188 | if attention_mask is not None: 189 | if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): 190 | raise ValueError( 191 | f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}" 192 | ) 193 | attn_weights = attn_weights + attention_mask 194 | attn_weights = torch.max(attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min, device=attn_weights.device)) 195 | 196 | # upcast attention to fp32 197 | attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) 198 | attn_output = torch.matmul(attn_weights, value_states) 199 | 200 | if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): 201 | raise 
ValueError( 202 | f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" 203 | f" {attn_output.size()}" 204 | ) 205 | 206 | attn_output = attn_output.transpose(1, 2) 207 | attn_output = attn_output.reshape(bsz, q_len, -1) 208 | 209 | attn_output = self.o_u_proj(self.o_v_proj(attn_output)) 210 | 211 | if not output_attentions: 212 | attn_weights = None 213 | 214 | return attn_output, attn_weights, past_key_value 215 | -------------------------------------------------------------------------------- /component/svd_llama_kvcache.py: -------------------------------------------------------------------------------- 1 | import math 2 | from typing import Optional, Tuple 3 | 4 | import torch 5 | import torch.utils.checkpoint 6 | from torch import nn 7 | 8 | from transformers.activations import ACT2FN 9 | from transformers.utils import logging 10 | from transformers import LlamaConfig 11 | 12 | logger = logging.get_logger(__name__) 13 | 14 | _CONFIG_FOR_DOC = "LlamaConfig" 15 | 16 | class LlamaRMSNorm(nn.Module): 17 | def __init__(self, hidden_size, eps=1e-6): 18 | """ 19 | LlamaRMSNorm is equivalent to T5LayerNorm 20 | """ 21 | super().__init__() 22 | self.weight = nn.Parameter(torch.ones(hidden_size)) 23 | self.variance_epsilon = eps 24 | 25 | def forward(self, hidden_states): 26 | variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True) 27 | hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) 28 | 29 | # convert into half-precision if necessary 30 | if self.weight.dtype in [torch.float16, torch.bfloat16]: 31 | hidden_states = hidden_states.to(self.weight.dtype) 32 | 33 | return self.weight * hidden_states 34 | 35 | 36 | class LlamaRotaryEmbedding(torch.nn.Module): 37 | def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): 38 | super().__init__() 39 | inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim)) 40 | self.register_buffer("inv_freq", inv_freq) 41 | 42 | # Build here to make `torch.jit.trace` work. 43 | self.max_seq_len_cached = max_position_embeddings 44 | t = torch.arange(self.max_seq_len_cached, device=self.inv_freq.device, dtype=self.inv_freq.dtype) 45 | freqs = torch.einsum("i,j->ij", t, self.inv_freq) 46 | # Different from paper, but it uses a different permutation in order to obtain the same calculation 47 | emb = torch.cat((freqs, freqs), dim=-1) 48 | self.register_buffer("cos_cached", emb.cos()[None, None, :, :], persistent=False) 49 | self.register_buffer("sin_cached", emb.sin()[None, None, :, :], persistent=False) 50 | 51 | def forward(self, x, seq_len=None): 52 | # x: [bs, num_attention_heads, seq_len, head_size] 53 | # This `if` block is unlikely to be run after we build sin/cos in `__init__`. Keep the logic here just in case. 
54 | if seq_len > self.max_seq_len_cached: 55 | self.max_seq_len_cached = seq_len 56 | t = torch.arange(self.max_seq_len_cached, device=x.device, dtype=self.inv_freq.dtype) 57 | freqs = torch.einsum("i,j->ij", t, self.inv_freq) 58 | # Different from paper, but it uses a different permutation in order to obtain the same calculation 59 | emb = torch.cat((freqs, freqs), dim=-1).to(x.device) 60 | self.register_buffer("cos_cached", emb.cos()[None, None, :, :], persistent=False) 61 | self.register_buffer("sin_cached", emb.sin()[None, None, :, :], persistent=False) 62 | return ( 63 | self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype), 64 | self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype), 65 | ) 66 | 67 | 68 | def rotate_half(x): 69 | """Rotates half the hidden dims of the input.""" 70 | x1 = x[..., : x.shape[-1] // 2] 71 | x2 = x[..., x.shape[-1] // 2 :] 72 | return torch.cat((-x2, x1), dim=-1) 73 | 74 | 75 | def apply_rotary_pos_emb(q, k, cos, sin, position_ids): 76 | gather_indices = position_ids[:, None, :, None] # [bs, 1, seq_len, 1] 77 | gather_indices = gather_indices.repeat(1, cos.shape[1], 1, cos.shape[3]) 78 | cos = torch.gather(cos.repeat(gather_indices.shape[0], 1, 1, 1), 2, gather_indices) 79 | sin = torch.gather(sin.repeat(gather_indices.shape[0], 1, 1, 1), 2, gather_indices) 80 | 81 | q_embed = (q * cos) + (rotate_half(q) * sin) 82 | k_embed = (k * cos) + (rotate_half(k) * sin) 83 | return q_embed, k_embed 84 | 85 | 86 | class SVD_LlamaMLP(nn.Module): 87 | def __init__( 88 | self, 89 | hidden_size: int, 90 | intermediate_size: int, 91 | hidden_act: str, 92 | compression_ratio=1, 93 | ): 94 | super().__init__() 95 | self.compression_ratio = compression_ratio 96 | if self.compression_ratio != 1: 97 | low_rank = int(intermediate_size * hidden_size * self.compression_ratio / (intermediate_size + hidden_size)) 98 | self.gate_u_proj = nn.Linear(low_rank, intermediate_size, bias=False) 99 | self.gate_v_proj = nn.Linear(hidden_size, low_rank, bias=False) 100 | else: 101 | self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=False) 102 | if self.compression_ratio != 1: 103 | low_rank = int(intermediate_size * hidden_size * self.compression_ratio / (intermediate_size + hidden_size)) 104 | self.down_u_proj = nn.Linear(low_rank, hidden_size, bias=False) 105 | self.down_v_proj = nn.Linear(intermediate_size, low_rank, bias=False) 106 | else: 107 | self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False) 108 | if self.compression_ratio != 1: 109 | low_rank = int(intermediate_size * hidden_size * self.compression_ratio / (intermediate_size + hidden_size)) 110 | self.up_u_proj = nn.Linear(low_rank, intermediate_size, bias=False) 111 | self.up_v_proj = nn.Linear(hidden_size, low_rank, bias=False) 112 | else: 113 | self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=False) 114 | self.act_fn = ACT2FN[hidden_act] 115 | 116 | def forward(self, x): 117 | if self.compression_ratio != 1: 118 | up = self.up_u_proj(self.up_v_proj(x)) 119 | else: 120 | up = self.up_proj(x) 121 | if self.compression_ratio != 1: 122 | gate = self.gate_u_proj(self.gate_v_proj(x)) 123 | else: 124 | gate = self.gate_proj(x) 125 | if self.compression_ratio != 1: 126 | return self.down_u_proj(self.down_v_proj(self.act_fn(gate) * up)) 127 | else: 128 | return self.down_proj(self.act_fn(gate) * up) 129 | 130 | 131 | class SVD_LlamaAttention(nn.Module): 132 | """Multi-headed attention from 'Attention Is All You Need' paper""" 133 | 134 | def __init__(self, config: LlamaConfig, 
compression_ratio=1): 135 | super().__init__() 136 | self.config = config 137 | self.hidden_size = config.hidden_size 138 | self.num_heads = config.num_attention_heads 139 | self.head_dim = self.hidden_size // self.num_heads 140 | self.max_position_embeddings = config.max_position_embeddings 141 | self.compression_ratio = compression_ratio 142 | if (self.head_dim * self.num_heads) != self.hidden_size: 143 | raise ValueError( 144 | f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}" 145 | f" and `num_heads`: {self.num_heads})." 146 | ) 147 | if self.compression_ratio != 1: 148 | low_rank = int(self.hidden_size * self.compression_ratio/2) 149 | self.q_u_proj = nn.Linear(low_rank, self.num_heads * self.head_dim, bias=False) 150 | self.q_v_proj = nn.Linear(self.hidden_size, low_rank, bias=False) 151 | else: 152 | self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False) 153 | if self.compression_ratio != 1: 154 | low_rank = int(self.hidden_size * self.compression_ratio/2) 155 | self.k_u_proj = nn.Linear(low_rank, self.num_heads * self.head_dim, bias=False) 156 | self.k_v_proj = nn.Linear(self.hidden_size, low_rank, bias=False) 157 | else: 158 | self.k_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False) 159 | if self.compression_ratio != 1: 160 | low_rank = int(self.hidden_size * self.compression_ratio/2) 161 | self.v_u_proj = nn.Linear(low_rank, self.num_heads * self.head_dim, bias=False) 162 | self.v_v_proj = nn.Linear(self.hidden_size, low_rank, bias=False) 163 | else: 164 | self.v_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False) 165 | if self.compression_ratio != 1: 166 | low_rank = int(self.hidden_size * self.compression_ratio/2) 167 | self.o_u_proj = nn.Linear(low_rank, self.hidden_size, bias=False) 168 | self.o_v_proj = nn.Linear(self.num_heads * self.head_dim, low_rank, bias=False) 169 | else: 170 | self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False) 171 | self.rotary_emb = LlamaRotaryEmbedding(self.head_dim, max_position_embeddings=self.max_position_embeddings) 172 | 173 | def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): 174 | return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() 175 | 176 | def forward( 177 | self, 178 | hidden_states: torch.Tensor, 179 | attention_mask: Optional[torch.Tensor] = None, 180 | position_ids: Optional[torch.LongTensor] = None, 181 | past_key_value: Optional[Tuple[torch.Tensor]] = None, 182 | output_attentions: bool = False, 183 | use_cache: bool = False, 184 | ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: 185 | bsz, q_len, _ = hidden_states.size() 186 | 187 | if self.compression_ratio != 1: 188 | query_states = self.q_u_proj(self.q_v_proj(hidden_states)).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) 189 | else: 190 | query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) 191 | if self.compression_ratio != 1: 192 | # TODO 193 | if use_cache: 194 | key_v_x_states = self.k_v_proj(hidden_states) 195 | key_states = self.k_u_proj(key_v_x_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) 196 | else: 197 | key_states = self.k_u_proj(self.k_v_proj(hidden_states)).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) 198 | else: 199 | key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_heads, 
self.head_dim).transpose(1, 2) 200 | if self.compression_ratio != 1: 201 | # TODO 202 | if use_cache: 203 | value_v_x_states = self.v_v_proj(hidden_states) 204 | value_states = self.v_u_proj(value_v_x_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) 205 | else: 206 | value_states = self.v_u_proj(self.v_v_proj(hidden_states)).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) 207 | else: 208 | value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) 209 | 210 | # key_states.shape = query_states = value_states[8,32,2048,128] 211 | # cos.shape = sin.shape = [1,1,2048,128] 212 | kv_seq_len = key_states.shape[-2] 213 | if past_key_value is not None: 214 | kv_seq_len += past_key_value[0].shape[-2] 215 | # TODO: pass the shared cos / sin cache from top module (original code only require dev, increasing kv_seq_len, value datatype) 216 | cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) 217 | query_states = apply_rotary_pos_emb(query_states, cos, sin, position_ids) 218 | key_states = apply_rotary_pos_emb(key_states, cos, sin, position_ids) 219 | ################################################################################################################################ 220 | # [bsz, nh, t, hd] 221 | # TODO: store the key_v_x_states, value_v_x_states in past_key_value rather than the key,value states 222 | if past_key_value is not None: 223 | # reuse k, v, self_attention 224 | 225 | key_v_x_states = torch.cat([past_key_value[0], key_v_x_states], dim=(1 if self.compression_ratio != 1 else 2)) 226 | value_v_x_states = torch.cat([past_key_value[1], value_v_x_states], dim=(1 if self.compression_ratio != 1 else 2)) 227 | 228 | past_key_value = (key_v_x_states, value_v_x_states) if use_cache else None 229 | ################################################################################## 230 | # TODO: restore the k and v 231 | if self.compression_ratio != 1: 232 | key_states = self.k_u_proj(key_v_x_states).view(bsz, kv_seq_len, self.num_heads, self.head_dim).transpose(1, 2) 233 | value_states = self.v_u_proj(value_v_x_states).view(bsz, kv_seq_len, self.num_heads, self.head_dim).transpose(1, 2) 234 | attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) 235 | 236 | if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len): 237 | raise ValueError( 238 | f"Attention weights should be of size {(bsz * self.num_heads, q_len, kv_seq_len)}, but is" 239 | f" {attn_weights.size()}" 240 | ) 241 | 242 | if attention_mask is not None: 243 | if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): 244 | raise ValueError( 245 | f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}" 246 | ) 247 | attn_weights = attn_weights + attention_mask 248 | attn_weights = torch.max(attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min, device=attn_weights.device)) 249 | 250 | # upcast attention to fp32 251 | attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) 252 | attn_output = torch.matmul(attn_weights, value_states) 253 | # # TODO: clean 254 | # del key_states, value_states 255 | # key_states = None 256 | # value_states = None 257 | # torch.cuda.empty_cache() 258 | 259 | if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): 260 | raise ValueError( 261 | f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" 262 | f" 
{attn_output.size()}" 263 | ) 264 | 265 | attn_output = attn_output.transpose(1, 2) 266 | attn_output = attn_output.reshape(bsz, q_len, -1) 267 | 268 | if self.compression_ratio != 1: 269 | attn_output = self.o_u_proj(self.o_v_proj(attn_output)) 270 | else: 271 | attn_output = self.o_proj(attn_output) 272 | 273 | if not output_attentions: 274 | attn_weights = None 275 | 276 | return attn_output, attn_weights, past_key_value -------------------------------------------------------------------------------- /component/svd_mistral.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2023 Mistral AI and the HuggingFace Inc. team. All rights reserved. 3 | # 4 | # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX 5 | # and OPT implementations in this library. It has been modified from its 6 | # original forms to accommodate minor architectural differences compared 7 | # to GPT-NeoX and OPT used by the Meta AI team that trained the model. 8 | # 9 | # Licensed under the Apache License, Version 2.0 (the "License"); 10 | # you may not use this file except in compliance with the License. 11 | # You may obtain a copy of the License at 12 | # 13 | # http://www.apache.org/licenses/LICENSE-2.0 14 | # 15 | # Unless required by applicable law or agreed to in writing, software 16 | # distributed under the License is distributed on an "AS IS" BASIS, 17 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 18 | # See the License for the specific language governing permissions and 19 | # limitations under the License. 20 | """ PyTorch Mistral model.""" 21 | import inspect 22 | import math 23 | import warnings 24 | from typing import List, Optional, Tuple, Union 25 | 26 | import torch 27 | import torch.nn.functional as F 28 | import torch.utils.checkpoint 29 | from torch import nn 30 | from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss 31 | 32 | from transformers.activations import ACT2FN 33 | from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask 34 | # from transormers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast 35 | from transformers.modeling_utils import PreTrainedModel 36 | from transformers.utils import ( 37 | add_start_docstrings, 38 | add_start_docstrings_to_model_forward, 39 | is_flash_attn_2_available, 40 | logging, 41 | replace_return_docstrings, 42 | ) 43 | from transformers.models.mistral import MistralConfig 44 | 45 | 46 | if is_flash_attn_2_available(): 47 | from flash_attn import flash_attn_func, flash_attn_varlen_func 48 | from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa 49 | 50 | _flash_supports_window_size = "window_size" in list(inspect.signature(flash_attn_func).parameters) 51 | 52 | 53 | logger = logging.get_logger(__name__) 54 | 55 | _CONFIG_FOR_DOC = "MistralConfig" 56 | 57 | 58 | # Copied from transformers.models.llama.modeling_llama._get_unpad_data 59 | def _get_unpad_data(attention_mask): 60 | seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32) 61 | indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten() 62 | max_seqlen_in_batch = seqlens_in_batch.max().item() 63 | cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0)) 64 | return ( 65 | indices, 66 | cu_seqlens, 67 | max_seqlen_in_batch, 68 | ) 69 | 70 | 71 | # Copied from 
transformers.models.llama.modeling_llama.LlamaRMSNorm with Llama->Mistral 72 | class MistralRMSNorm(nn.Module): 73 | def __init__(self, hidden_size, eps=1e-6): 74 | """ 75 | MistralRMSNorm is equivalent to T5LayerNorm 76 | """ 77 | super().__init__() 78 | self.weight = nn.Parameter(torch.ones(hidden_size)) 79 | self.variance_epsilon = eps 80 | 81 | def forward(self, hidden_states): 82 | input_dtype = hidden_states.dtype 83 | hidden_states = hidden_states.to(torch.float32) 84 | variance = hidden_states.pow(2).mean(-1, keepdim=True) 85 | hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) 86 | return self.weight * hidden_states.to(input_dtype) 87 | 88 | 89 | # Copied from transformers.models.llama.modeling_llama.LlamaRotaryEmbedding with Llama->Mistral 90 | class MistralRotaryEmbedding(nn.Module): 91 | def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): 92 | super().__init__() 93 | 94 | self.dim = dim 95 | self.max_position_embeddings = max_position_embeddings 96 | self.base = base 97 | inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim)) 98 | self.register_buffer("inv_freq", inv_freq, persistent=False) 99 | 100 | # Build here to make `torch.jit.trace` work. 101 | self._set_cos_sin_cache( 102 | seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype() 103 | ) 104 | 105 | def _set_cos_sin_cache(self, seq_len, device, dtype): 106 | self.max_seq_len_cached = seq_len 107 | t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype) 108 | 109 | freqs = torch.einsum("i,j->ij", t, self.inv_freq) 110 | # Different from paper, but it uses a different permutation in order to obtain the same calculation 111 | emb = torch.cat((freqs, freqs), dim=-1) 112 | self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False) 113 | self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False) 114 | 115 | def forward(self, x, seq_len=None): 116 | # x: [bs, num_attention_heads, seq_len, head_size] 117 | if seq_len > self.max_seq_len_cached: 118 | self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype) 119 | 120 | return ( 121 | self.cos_cached[:seq_len].to(dtype=x.dtype), 122 | self.sin_cached[:seq_len].to(dtype=x.dtype), 123 | ) 124 | 125 | 126 | # Copied from transformers.models.llama.modeling_llama.rotate_half 127 | def rotate_half(x): 128 | """Rotates half the hidden dims of the input.""" 129 | x1 = x[..., : x.shape[-1] // 2] 130 | x2 = x[..., x.shape[-1] // 2 :] 131 | return torch.cat((-x2, x1), dim=-1) 132 | 133 | 134 | # Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb 135 | def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1): 136 | """Applies Rotary Position Embedding to the query and key tensors. 137 | 138 | Args: 139 | q (`torch.Tensor`): The query tensor. 140 | k (`torch.Tensor`): The key tensor. 141 | cos (`torch.Tensor`): The cosine part of the rotary embedding. 142 | sin (`torch.Tensor`): The sine part of the rotary embedding. 143 | position_ids (`torch.Tensor`): 144 | The position indices of the tokens corresponding to the query and key tensors. For example, this can be 145 | used to pass offsetted position ids when working with a KV-cache. 
146 | unsqueeze_dim (`int`, *optional*, defaults to 1): 147 | The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and 148 | sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note 149 | that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and 150 | k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes 151 | cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have 152 | the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. 153 | Returns: 154 | `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. 155 | """ 156 | cos = cos[position_ids].unsqueeze(unsqueeze_dim) 157 | sin = sin[position_ids].unsqueeze(unsqueeze_dim) 158 | q_embed = (q * cos) + (rotate_half(q) * sin) 159 | k_embed = (k * cos) + (rotate_half(k) * sin) 160 | return q_embed, k_embed 161 | 162 | 163 | class SVD_MistralMLP(nn.Module): 164 | def __init__(self, config, 165 | ratio=1 # 1 means no truncate, just keep normal MLP 166 | ): 167 | super().__init__() 168 | self.config = config 169 | self.hidden_size = config.hidden_size 170 | self.intermediate_size = config.intermediate_size 171 | self.ratio = ratio 172 | low_rank = int(self.intermediate_size * self.hidden_size * self.ratio / (self.intermediate_size + self.hidden_size)) 173 | self.gate_u_proj = nn.Linear(low_rank, self.intermediate_size, bias=False) 174 | self.gate_v_proj = nn.Linear(self.hidden_size, low_rank, bias=False) 175 | 176 | self.down_u_proj = nn.Linear(low_rank, self.hidden_size, bias=False) 177 | self.down_v_proj = nn.Linear(self.intermediate_size, low_rank, bias=False) 178 | 179 | self.up_u_proj = nn.Linear(low_rank, self.intermediate_size, bias=False) 180 | self.up_v_proj = nn.Linear(self.hidden_size, low_rank, bias=False) 181 | self.act_fn = ACT2FN[config.hidden_act] 182 | 183 | def forward(self, x): 184 | up = self.up_u_proj(self.up_v_proj(x)) 185 | gate = self.gate_u_proj(self.gate_v_proj(x)) 186 | return self.down_u_proj(self.down_v_proj(self.act_fn(gate) * up)) 187 | 188 | 189 | # Copied from transformers.models.llama.modeling_llama.repeat_kv 190 | def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: 191 | """ 192 | This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, 193 | num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) 194 | """ 195 | batch, num_key_value_heads, slen, head_dim = hidden_states.shape 196 | if n_rep == 1: 197 | return hidden_states 198 | hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) 199 | return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) 200 | 201 | 202 | class SVD_MistralAttention(nn.Module): 203 | """ 204 | Multi-headed attention from 'Attention Is All You Need' paper. Modified to use sliding window attention: Longformer 205 | and "Generating Long Sequences with Sparse Transformers". 
206 | """ 207 | 208 | def __init__(self, config: MistralConfig, 209 | ratio=1): 210 | super().__init__() 211 | self.config = config 212 | self.hidden_size = config.hidden_size 213 | self.num_heads = config.num_attention_heads 214 | self.head_dim = self.hidden_size // self.num_heads 215 | self.num_key_value_heads = config.num_key_value_heads 216 | self.num_key_value_groups = self.num_heads // self.num_key_value_heads 217 | self.max_position_embeddings = config.max_position_embeddings 218 | self.rope_theta = config.rope_theta 219 | self.is_causal = True 220 | self.ratio = ratio # 1 means no truncate, just keep normal attn 221 | 222 | if (self.head_dim * self.num_heads) != self.hidden_size: 223 | raise ValueError( 224 | f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}" 225 | f" and `num_heads`: {self.num_heads})." 226 | ) 227 | low_rank = int(self.hidden_size * self.ratio/2) 228 | self.q_u_proj = nn.Linear(low_rank, self.num_heads * self.head_dim, bias=False) 229 | self.q_v_proj = nn.Linear(self.hidden_size, low_rank, bias=False) 230 | self.k_u_proj = nn.Linear(low_rank, self.num_key_value_heads * self.head_dim, bias=False) 231 | self.k_v_proj = nn.Linear(self.hidden_size, low_rank, bias=False) 232 | self.v_u_proj = nn.Linear(low_rank, self.num_key_value_heads * self.head_dim, bias=False) 233 | self.v_v_proj = nn.Linear(self.hidden_size, low_rank, bias=False) 234 | self.o_u_proj = nn.Linear(low_rank, self.hidden_size, bias=False) 235 | self.o_v_proj = nn.Linear(self.num_heads * self.head_dim, low_rank, bias=False) 236 | 237 | self.rotary_emb = MistralRotaryEmbedding( 238 | self.head_dim, 239 | max_position_embeddings=self.max_position_embeddings, 240 | base=self.rope_theta, 241 | ) 242 | 243 | def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): 244 | return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() 245 | 246 | def forward( 247 | self, 248 | hidden_states: torch.Tensor, 249 | attention_mask: Optional[torch.Tensor] = None, 250 | position_ids: Optional[torch.LongTensor] = None, 251 | past_key_value: Optional[Tuple[torch.Tensor]] = None, 252 | output_attentions: bool = False, 253 | use_cache: bool = False, 254 | **kwargs, 255 | ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: 256 | if "padding_mask" in kwargs: 257 | warnings.warn( 258 | "Passing `padding_mask` is deprecated and will be removed in v4.37. 
Please make sure use `attention_mask` instead.`" 259 | ) 260 | bsz, q_len, _ = hidden_states.size() 261 | query_states = self.q_u_proj(self.q_v_proj(hidden_states)) 262 | key_states = self.k_u_proj(self.k_v_proj(hidden_states)) 263 | value_states = self.v_u_proj(self.v_v_proj(hidden_states)) 264 | 265 | query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) 266 | key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) 267 | value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) 268 | 269 | kv_seq_len = key_states.shape[-2] 270 | if past_key_value is not None: 271 | kv_seq_len += past_key_value[0].shape[-2] 272 | cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) 273 | query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) 274 | 275 | if past_key_value is not None: 276 | # reuse k, v, self_attention 277 | key_states = torch.cat([past_key_value[0], key_states], dim=2) 278 | value_states = torch.cat([past_key_value[1], value_states], dim=2) 279 | 280 | past_key_value = (key_states, value_states) if use_cache else None 281 | 282 | # repeat k/v heads if n_kv_heads < n_heads 283 | key_states = repeat_kv(key_states, self.num_key_value_groups) 284 | value_states = repeat_kv(value_states, self.num_key_value_groups) 285 | 286 | attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) 287 | 288 | if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len): 289 | raise ValueError( 290 | f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is" 291 | f" {attn_weights.size()}" 292 | ) 293 | 294 | if attention_mask is not None: 295 | if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): 296 | raise ValueError( 297 | f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}" 298 | ) 299 | 300 | attn_weights = attn_weights + attention_mask 301 | 302 | # upcast attention to fp32 303 | attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) 304 | attn_output = torch.matmul(attn_weights, value_states) 305 | 306 | if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): 307 | raise ValueError( 308 | f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" 309 | f" {attn_output.size()}" 310 | ) 311 | 312 | attn_output = attn_output.transpose(1, 2).contiguous() 313 | attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) 314 | 315 | attn_output = self.o_u_proj(self.o_v_proj(attn_output)) 316 | 317 | if not output_attentions: 318 | attn_weights = None 319 | 320 | return attn_output, attn_weights, past_key_value 321 | -------------------------------------------------------------------------------- /component/svd_opt.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2022 The Fairseq Authors and The HuggingFace Inc. team. All rights reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """ PyTorch OPT model.""" 16 | from typing import List, Optional, Tuple, Union 17 | 18 | import torch 19 | import torch.nn.functional as F 20 | import torch.utils.checkpoint 21 | from torch import nn 22 | 23 | from transformers.activations import ACT2FN 24 | from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask 25 | 26 | from transformers.models.opt.configuration_opt import OPTConfig 27 | 28 | 29 | from transformers.utils import logging 30 | 31 | _CHECKPOINT_FOR_DOC = "facebook/opt-350m" 32 | _CONFIG_FOR_DOC = "OPTConfig" 33 | 34 | # Base model docstring 35 | _EXPECTED_OUTPUT_SHAPE = [1, 8, 1024] 36 | 37 | # SequenceClassification docstring 38 | _CHECKPOINT_FOR_SEQUENCE_CLASSIFICATION = "ArthurZ/opt-350m-dummy-sc" 39 | _SEQ_CLASS_EXPECTED_LOSS = 1.71 40 | _SEQ_CLASS_EXPECTED_OUTPUT = "'LABEL_0'" 41 | 42 | OPT_PRETRAINED_MODEL_ARCHIVE_LIST = [ 43 | "facebook/opt-125m", 44 | "facebook/opt-350m", 45 | "facebook/opt-1.3b", 46 | "facebook/opt-2.7b", 47 | "facebook/opt-6.7b", 48 | "facebook/opt-13b", 49 | "facebook/opt-30b", 50 | # See all OPT models at https://huggingface.co/models?filter=opt 51 | ] 52 | 53 | 54 | # Copied from transformers.models.llama.modeling_llama._get_unpad_data 55 | def _get_unpad_data(attention_mask): 56 | seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32) 57 | indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten() 58 | max_seqlen_in_batch = seqlens_in_batch.max().item() 59 | cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0)) 60 | return ( 61 | indices, 62 | cu_seqlens, 63 | max_seqlen_in_batch, 64 | ) 65 | 66 | 67 | class OPTLearnedPositionalEmbedding(nn.Embedding): 68 | """ 69 | This module learns positional embeddings up to a fixed maximum size. 70 | """ 71 | 72 | def __init__(self, num_embeddings: int, embedding_dim: int): 73 | # OPT is set up so that if padding_idx is specified then offset the embedding ids by 2 74 | # and adjust num_embeddings appropriately. 
Other models don't have this hack 75 | self.offset = 2 76 | super().__init__(num_embeddings + self.offset, embedding_dim) 77 | 78 | def forward(self, attention_mask: torch.LongTensor, past_key_values_length: int = 0): 79 | """`attention_mask` is expected to be [bsz x seqlen].""" 80 | attention_mask = attention_mask.long() 81 | 82 | # create positions depending on attention_mask 83 | positions = (torch.cumsum(attention_mask, dim=1).type_as(attention_mask) * attention_mask).long() - 1 84 | 85 | # cut positions if `past_key_values_length` is > 0 86 | positions = positions[:, past_key_values_length:] 87 | 88 | return super().forward(positions + self.offset) 89 | 90 | 91 | class SVDOPTAttention(nn.Module): 92 | """Multi-headed attention from 'Attention Is All You Need' paper""" 93 | 94 | def __init__( 95 | self, 96 | config: OPTConfig, 97 | is_decoder: bool = False, 98 | ratio=1, 99 | **kwargs, 100 | ): 101 | super().__init__() 102 | self.config = config 103 | 104 | def _handle_deprecated_argument(config_arg_name, config, fn_arg_name, kwargs): 105 | """ 106 | If the deprecated argument `fn_arg_name` is passed, log a deprecation 107 | warning and return that value; otherwise take the equivalent config.config_arg_name 108 | """ 109 | val = None 110 | if fn_arg_name in kwargs: 111 | logging.warning( 112 | f"Passing in {fn_arg_name} to {self.__class__.__name__} is deprecated and won't be supported from v4.38." 113 | " Please set it in the config instead" 114 | ) 115 | val = kwargs.pop(fn_arg_name) 116 | else: 117 | val = getattr(config, config_arg_name) 118 | return val 119 | 120 | self.embed_dim = _handle_deprecated_argument("hidden_size", config, "embed_dim", kwargs) 121 | self.num_heads = _handle_deprecated_argument("num_attention_heads", config, "num_heads", kwargs) 122 | self.dropout = _handle_deprecated_argument("attention_dropout", config, "dropout", kwargs) 123 | self.enable_bias = _handle_deprecated_argument("enable_bias", config, "bias", kwargs) 124 | 125 | self.head_dim = self.embed_dim // self.num_heads 126 | self.is_causal = True 127 | self.ratio = ratio 128 | 129 | if (self.head_dim * self.num_heads) != self.embed_dim: 130 | raise ValueError( 131 | f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}" 132 | f" and `num_heads`: {self.num_heads})."
133 | ) 134 | self.scaling = self.head_dim**-0.5 135 | self.is_decoder = is_decoder 136 | 137 | if self.ratio != 1: 138 | low_rank = int(self.embed_dim * self.ratio/2) 139 | self.q_u_proj = nn.Linear(low_rank, self.embed_dim, bias=self.enable_bias) 140 | self.q_v_proj = nn.Linear(self.embed_dim, low_rank, bias=False) 141 | else: 142 | self.q_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=self.enable_bias) 143 | if self.ratio != 1: 144 | low_rank = int(self.embed_dim * self.ratio/2) 145 | self.k_u_proj = nn.Linear(low_rank, self.embed_dim, bias=self.enable_bias) 146 | self.k_v_proj = nn.Linear(self.embed_dim, low_rank, bias=False) 147 | else: 148 | self.k_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=self.enable_bias) 149 | if self.ratio != 1: 150 | low_rank = int(self.embed_dim * self.ratio/2) 151 | self.v_u_proj = nn.Linear(low_rank, self.embed_dim, bias=self.enable_bias) 152 | self.v_v_proj = nn.Linear(self.embed_dim, low_rank, bias=False) 153 | else: 154 | self.v_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=self.enable_bias) 155 | if self.ratio != 1: 156 | low_rank = int(self.embed_dim * self.ratio/2) 157 | self.out_u_proj = nn.Linear(low_rank, self.embed_dim, bias=self.enable_bias) 158 | self.out_v_proj = nn.Linear(self.embed_dim, low_rank, bias=False) 159 | else: 160 | self.out_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=self.enable_bias) 161 | 162 | def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): 163 | return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() 164 | 165 | def forward( 166 | self, 167 | hidden_states: torch.Tensor, 168 | key_value_states: Optional[torch.Tensor] = None, 169 | past_key_value: Optional[Tuple[torch.Tensor]] = None, 170 | attention_mask: Optional[torch.Tensor] = None, 171 | layer_head_mask: Optional[torch.Tensor] = None, 172 | output_attentions: bool = False, 173 | ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: 174 | """Input shape: Batch x Time x Channel""" 175 | 176 | # if key_value_states are provided this layer is used as a cross-attention layer 177 | # for the decoder 178 | is_cross_attention = key_value_states is not None 179 | 180 | bsz, tgt_len, _ = hidden_states.size() 181 | 182 | # get query proj 183 | if self.ratio != 1: 184 | query_states = self.q_u_proj(self.q_v_proj(hidden_states)) * self.scaling 185 | else: 186 | query_states = self.q_proj(hidden_states) * self.scaling 187 | # get key, value proj 188 | if is_cross_attention and past_key_value is not None: 189 | # reuse k,v, cross_attentions 190 | key_states = past_key_value[0] 191 | value_states = past_key_value[1] 192 | elif is_cross_attention: 193 | # cross_attentions 194 | if self.ratio != 1: 195 | key_states = self._shape(self.k_u_proj(self.k_v_proj(key_value_states)), -1, bsz) 196 | else: 197 | key_states = self._shape(self.k_proj(key_value_states), -1, bsz) 198 | if self.ratio != 1: 199 | value_states = self._shape(self.v_u_proj(self.v_v_proj(key_value_states)), -1, bsz) 200 | else: 201 | value_states = self._shape(self.v_proj(key_value_states), -1, bsz) 202 | elif past_key_value is not None: 203 | # reuse k, v, self_attention 204 | if self.ratio != 1: 205 | key_states = self._shape(self.k_u_proj(self.k_v_proj(hidden_states)), -1, bsz) 206 | else: 207 | key_states = self._shape(self.k_proj(hidden_states), -1, bsz) 208 | if self.ratio != 1: 209 | value_states = self._shape(self.v_u_proj(self.v_v_proj(hidden_states)), -1, bsz) 210 | else: 211 | value_states = 
self._shape(self.v_proj(hidden_states), -1, bsz) 212 | key_states = torch.cat([past_key_value[0], key_states], dim=2) 213 | value_states = torch.cat([past_key_value[1], value_states], dim=2) 214 | else: 215 | # self_attention 216 | if self.ratio != 1: 217 | key_states = self._shape(self.k_u_proj(self.k_v_proj(hidden_states)), -1, bsz) 218 | else: 219 | key_states = self._shape(self.k_proj(hidden_states), -1, bsz) 220 | if self.ratio != 1: 221 | value_states = self._shape(self.v_u_proj(self.v_v_proj(hidden_states)), -1, bsz) 222 | else: 223 | value_states = self._shape(self.v_proj(hidden_states), -1, bsz) 224 | 225 | if self.is_decoder: 226 | # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. 227 | # Further calls to cross_attention layer can then reuse all cross-attention 228 | # key/value_states (first "if" case) 229 | # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of 230 | # all previous decoder key/value_states. Further calls to uni-directional self-attention 231 | # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) 232 | # if encoder bi-directional self-attention `past_key_value` is always `None` 233 | past_key_value = (key_states, value_states) 234 | 235 | proj_shape = (bsz * self.num_heads, -1, self.head_dim) 236 | query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape) 237 | key_states = key_states.view(*proj_shape) 238 | value_states = value_states.view(*proj_shape) 239 | 240 | src_len = key_states.size(1) 241 | attn_weights = torch.bmm(query_states, key_states.transpose(1, 2)) 242 | 243 | if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len): 244 | raise ValueError( 245 | f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is" 246 | f" {attn_weights.size()}" 247 | ) 248 | 249 | if attention_mask is not None: 250 | if attention_mask.size() != (bsz, 1, tgt_len, src_len): 251 | raise ValueError( 252 | f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}" 253 | ) 254 | attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask 255 | attn_weights = torch.max( 256 | attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min, device=attn_weights.device) 257 | ) 258 | attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) 259 | 260 | # upcast to fp32 if the weights are in fp16. Please see https://github.com/huggingface/transformers/pull/17437 261 | if attn_weights.dtype == torch.float16: 262 | attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(torch.float16) 263 | else: 264 | attn_weights = nn.functional.softmax(attn_weights, dim=-1) 265 | 266 | if layer_head_mask is not None: 267 | if layer_head_mask.size() != (self.num_heads,): 268 | raise ValueError( 269 | f"Head mask for a single layer should be of size {(self.num_heads,)}, but is" 270 | f" {layer_head_mask.size()}" 271 | ) 272 | attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len) 273 | attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) 274 | 275 | if output_attentions: 276 | # this operation is a bit awkward, but it's required to 277 | # make sure that attn_weights keeps its gradient. 
278 | # In order to do so, attn_weights have to be reshaped 279 | # twice and have to be reused in the following 280 | attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) 281 | attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len) 282 | else: 283 | attn_weights_reshaped = None 284 | 285 | attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training) 286 | 287 | attn_output = torch.bmm(attn_probs, value_states) 288 | 289 | if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim): 290 | raise ValueError( 291 | f"`attn_output` should be of size {(bsz * self.num_heads, tgt_len, self.head_dim)}, but is" 292 | f" {attn_output.size()}" 293 | ) 294 | 295 | attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim) 296 | attn_output = attn_output.transpose(1, 2) 297 | 298 | # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be 299 | # partitioned across GPUs when using tensor-parallelism. 300 | attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim) 301 | 302 | if self.ratio != 1: 303 | attn_output = self.out_u_proj(self.out_v_proj(attn_output)) 304 | else: 305 | attn_output = self.out_proj(attn_output) 306 | 307 | return attn_output, attn_weights_reshaped, past_key_value 308 | 309 | 310 | 311 | class SVDOPTDecoderLayer(nn.Module): 312 | def __init__(self, config: OPTConfig, ratio = 1): 313 | super().__init__() 314 | self.embed_dim = config.hidden_size 315 | 316 | self.self_attn = SVDOPTAttention(config=config, ratio=ratio, is_decoder=True) 317 | 318 | self.do_layer_norm_before = config.do_layer_norm_before 319 | self.dropout = config.dropout 320 | self.activation_fn = ACT2FN[config.activation_function] 321 | 322 | self.self_attn_layer_norm = nn.LayerNorm( 323 | self.embed_dim, elementwise_affine=config.layer_norm_elementwise_affine 324 | ) 325 | self.ratio = ratio 326 | if self.ratio != 1: 327 | low_rank = int(config.ffn_dim * self.embed_dim * self.ratio / (config.ffn_dim + self.embed_dim)) 328 | self.fc1_u_proj = nn.Linear(low_rank, config.ffn_dim, bias=config.enable_bias) 329 | self.fc1_v_proj = nn.Linear(self.embed_dim, low_rank, bias=False) 330 | else: 331 | self.fc1 = nn.Linear(self.embed_dim, config.ffn_dim, bias=config.enable_bias) 332 | if self.ratio != 1: 333 | low_rank = int(config.ffn_dim * self.embed_dim * self.ratio / (config.ffn_dim + self.embed_dim)) 334 | self.fc2_u_proj = nn.Linear(low_rank, self.embed_dim, bias=config.enable_bias) 335 | self.fc2_v_proj = nn.Linear(config.ffn_dim, low_rank, bias=False) 336 | else: 337 | self.fc2 = nn.Linear(config.ffn_dim, self.embed_dim, bias=config.enable_bias) 338 | self.final_layer_norm = nn.LayerNorm(self.embed_dim, elementwise_affine=config.layer_norm_elementwise_affine) 339 | 340 | def forward( 341 | self, 342 | hidden_states: torch.Tensor, 343 | attention_mask: Optional[torch.Tensor] = None, 344 | layer_head_mask: Optional[torch.Tensor] = None, 345 | past_key_value: Optional[Tuple[torch.Tensor]] = None, 346 | output_attentions: Optional[bool] = False, 347 | use_cache: Optional[bool] = False, 348 | ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: 349 | """ 350 | Args: 351 | hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` 352 | attention_mask (`torch.FloatTensor`, *optional*): attention mask of size 353 | `(batch, 1, tgt_len, src_len)` where padding elements are indicated by
very large negative values. 354 | layer_head_mask (`torch.FloatTensor`, *optional*): mask for attention heads in a given layer of size 355 | `(encoder_attention_heads,)`. 356 | output_attentions (`bool`, *optional*): 357 | Whether or not to return the attentions tensors of all attention layers. See `attentions` under 358 | returned tensors for more detail. 359 | use_cache (`bool`, *optional*): 360 | If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding 361 | (see `past_key_values`). 362 | past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states 363 | """ 364 | 365 | residual = hidden_states 366 | 367 | # 125m, 1.7B, ..., 175B applies layer norm BEFORE attention 368 | if self.do_layer_norm_before: 369 | hidden_states = self.self_attn_layer_norm(hidden_states) 370 | 371 | # Self Attention 372 | hidden_states, self_attn_weights, present_key_value = self.self_attn( 373 | hidden_states=hidden_states, 374 | past_key_value=past_key_value, 375 | attention_mask=attention_mask, 376 | layer_head_mask=layer_head_mask, 377 | output_attentions=output_attentions, 378 | ) 379 | hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) 380 | hidden_states = residual + hidden_states 381 | 382 | # 350m applies layer norm AFTER attention 383 | if not self.do_layer_norm_before: 384 | hidden_states = self.self_attn_layer_norm(hidden_states) 385 | 386 | # Fully Connected 387 | hidden_states_shape = hidden_states.shape 388 | hidden_states = hidden_states.reshape(-1, hidden_states.size(-1)) 389 | residual = hidden_states 390 | 391 | # 125m, 1.7B, ..., 175B applies layer norm BEFORE attention 392 | if self.do_layer_norm_before: 393 | hidden_states = self.final_layer_norm(hidden_states) 394 | 395 | if self.ratio != 1: 396 | hidden_states = self.fc1_u_proj(self.fc1_v_proj(hidden_states)) 397 | else: 398 | hidden_states = self.fc1(hidden_states) 399 | hidden_states = self.activation_fn(hidden_states) 400 | 401 | if self.ratio != 1: 402 | hidden_states = self.fc2_u_proj(self.fc2_v_proj(hidden_states)) 403 | else: 404 | hidden_states = self.fc2(hidden_states) 405 | hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) 406 | 407 | hidden_states = (residual + hidden_states).view(hidden_states_shape) 408 | 409 | # 350m applies layer norm AFTER attention 410 | if not self.do_layer_norm_before: 411 | hidden_states = self.final_layer_norm(hidden_states) 412 | 413 | outputs = (hidden_states,) 414 | 415 | if output_attentions: 416 | outputs += (self_attn_weights,) 417 | 418 | if use_cache: 419 | outputs += (present_key_value,) 420 | 421 | return outputs 422 | 423 | -------------------------------------------------------------------------------- /compress_llama.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # example of compressing LLaMA-7B with SVDLLM 4 | FINE_TUNE_PATH="." 5 | # run data whitening with 20% compression ratio 6 | python SVDLLM.py --model jeffwan/llama-7b-hf --step 1 --ratio 0.2 --whitening_nsamples 256 --dataset wikitext2 --seed 3 --model_seq_len 2048 --save_path . 7 | ## you can also run the following command for low-resource gpu (ex. llama 7b will only need 15G gpu memory to compress) or to compress large-scale llm (ex. 
llama 65b) 8 | # python SVDLLM.py --model jeffwan/llama-7b-hf --step 1 --ratio 0.2 --whitening_nsamples 256 --dataset wikitext2 --model_seq_len 2048 --save_path ./ --run_low_resource 9 | python SVDLLM.py --step 4 --model_path jeffwan_llama_7b_hf_whitening_only_0.8.pt 10 | # finetune the compressed model with lora 11 | python utils/LoRA.py --prune_model jeffwan_llama_7b_hf_whitening_only_0.8.pt --data_path yahma/alpaca-cleaned --output_dir $FINE_TUNE_PATH/first_half --lora_target_modules q_u_proj,k_u_proj,v_u_proj,o_u_proj,gate_u_proj,down_u_proj,up_u_proj --lora_r 8 --num_epochs 3 --learning_rate 1e-4 --batch_size 64 12 | python SVDLLM.py --model_path jeffwan_llama_7b_hf_whitening_only_0.8.pt --lora $FINE_TUNE_PATH/first_half --step 4 13 | python utils/LoRA.py --prune_model $FINE_TUNE_PATH/first_half/merge.pt --data_path yahma/alpaca-cleaned --output_dir $FINE_TUNE_PATH/second_half --lora_target_modules q_v_proj,k_v_proj,v_v_proj,o_v_proj,gate_v_proj,down_v_proj,up_v_proj --lora_r 8 --num_epochs 3 --learning_rate 1e-4 --batch_size 64 14 | python SVDLLM.py --model_path jeffwan_llama_7b_hf_whitening_only_0.8.pt --lora $FINE_TUNE_PATH/first_half --step 4 15 | python SVDLLM.py --model_path $FINE_TUNE_PATH/first_half/merge.pt --lora $FINE_TUNE_PATH/second_half --step 4 -------------------------------------------------------------------------------- /evaluater.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from tqdm import tqdm 4 | import time 5 | import itertools 6 | from utils.data_utils import get_test_data 7 | import os 8 | import sys 9 | 10 | current_path = os.path.dirname(os.path.abspath(__file__)) 11 | parent_path = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) 12 | sys.path.append(current_path) 13 | 14 | @torch.no_grad() 15 | def ppl_eval(model, tokenizer, datasets=['wikitext2', 'ptb', 'c4'], model_seq_len=2048, batch_size=32, device="cuda"): 16 | model.to(device) 17 | model.eval() 18 | ppls = {} 19 | for dataset in datasets: 20 | test_loader = get_test_data(dataset, tokenizer, seq_len=model_seq_len, batch_size = batch_size) 21 | nlls = [] 22 | for batch in tqdm(test_loader): 23 | batch = batch.to(device) 24 | output = model(batch, use_cache=False) 25 | lm_logits = output.logits 26 | if torch.isfinite(lm_logits).all(): 27 | shift_logits = lm_logits[:, :-1, :].contiguous() 28 | shift_labels = batch[:, 1:].contiguous() 29 | 30 | loss_fct = torch.nn.CrossEntropyLoss(reduction="none") 31 | loss = loss_fct(shift_logits.reshape(-1, shift_logits.size(-1)), shift_labels.view(-1)) 32 | nlls.append(loss) 33 | ppl = np.exp(torch.cat(nlls, dim=-1).mean().item()) 34 | ppls[dataset] = ppl 35 | print("PPL after pruning: {}".format(ppls)) 36 | print("Weight Memory: {} MiB\n".format(torch.cuda.memory_allocated()/1024/1024)) 37 | 38 | # only call this function for 65b or larger models 39 | @torch.no_grad() 40 | def ppl_eval_large(model, tokenizer, datasets=['wikitext2', 'ptb', 'c4'], seq_len=2048, batch_size=32, device="cuda"): 41 | import torch.nn as nn 42 | class LlamaRMSNorm(nn.Module): 43 | def __init__(self, hidden_size=model.config.hidden_size, eps=model.config.rms_norm_eps): 44 | """ 45 | LlamaRMSNorm is equivalent to T5LayerNorm 46 | """ 47 | super().__init__() 48 | self.weight = nn.Parameter(torch.ones(hidden_size)) 49 | self.variance_epsilon = eps 50 | 51 | def forward(self, hidden_states): 52 | input_dtype = hidden_states.dtype 53 | hidden_states = hidden_states.to(torch.float32) 54 | variance =
hidden_states.pow(2).mean(-1, keepdim=True) 55 | hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) 56 | return self.weight * hidden_states.to(input_dtype) 57 | norm = LlamaRMSNorm().half().cuda() 58 | lm_head = model.lm_head.cuda() 59 | model.eval() 60 | ppls = {} 61 | layers = model.model.layers 62 | for dataset in datasets: 63 | test_loader = get_test_data(dataset, tokenizer, seq_len=seq_len, batch_size = batch_size) 64 | nlls = [] 65 | for batch in tqdm(test_loader): 66 | model.model.embed_tokens = model.model.embed_tokens.cuda() 67 | model.model.norm = model.model.norm.cuda() 68 | layers[0] = layers[0].cuda() 69 | 70 | dtype = next(iter(model.parameters())).dtype 71 | inps = torch.zeros( 72 | (batch.shape[0], model.seqlen, model.config.hidden_size), dtype=dtype, device="cuda" 73 | ) 74 | cache = {'i': 0, 'attention_mask': None, "position_ids": None} 75 | class Catcher(nn.Module): 76 | def __init__(self, module): 77 | super().__init__() 78 | self.module = module 79 | def forward(self, inp, **kwargs): 80 | inps[cache['i']] = inp 81 | cache['i'] += 1 82 | if cache['attention_mask'] is None: 83 | cache['attention_mask'] = kwargs['attention_mask'] 84 | cache['position_ids'] = kwargs['position_ids'] 85 | else: 86 | cache['attention_mask'] = torch.cat((cache['attention_mask'], kwargs['attention_mask']), dim=0) 87 | cache['position_ids'] = torch.cat((cache['position_ids'], kwargs['position_ids']), dim=0) 88 | raise ValueError 89 | layers[0] = Catcher(layers[0]) 90 | for j in range(batch.shape[0]): 91 | try: 92 | model(batch[j].unsqueeze(0).cuda()) 93 | except ValueError: 94 | pass 95 | layers[0] = layers[0].module 96 | layers[0] = layers[0].cpu() 97 | model.model.embed_tokens = model.model.embed_tokens.cpu() 98 | model.model.norm = model.model.norm.cpu() 99 | torch.cuda.empty_cache() 100 | attention_masks = cache['attention_mask'] 101 | position_ids = cache['position_ids'] 102 | for i in range(len(layers)): 103 | layer = layers[i].cuda() 104 | outs = layer(inps, attention_mask=attention_masks, position_ids=position_ids)[0] 105 | layers[i] = layer.cpu() 106 | inps = outs 107 | torch.cuda.empty_cache() 108 | hidden_states = norm(outs) 109 | lm_logits = lm_head(hidden_states) 110 | if torch.isfinite(lm_logits).all(): 111 | shift_logits = lm_logits[:, :-1, :].contiguous() 112 | shift_labels = batch[:, 1:].contiguous().cuda() 113 | 114 | loss_fct = torch.nn.CrossEntropyLoss(reduction="none") 115 | loss = loss_fct(shift_logits.reshape(-1, shift_logits.size(-1)), shift_labels.view(-1)) 116 | nlls.append(loss) 117 | else: 118 | print("warning: nan or inf in lm_logits") 119 | ppl = np.exp(torch.cat(nlls, dim=-1).mean().item()) 120 | ppls[dataset] = ppl 121 | print("PPL after pruning: {}".format(ppls)) 122 | print("Weight Memory: {} MiB\n".format(torch.cuda.memory_allocated()/1024/1024)) 123 | 124 | @torch.no_grad() 125 | def eff_eval(model, tokenizer, dataset='wikitext2', original_len=4, generated_len=128, batch_size=1, device="cuda"): 126 | model.eval() 127 | throughput = 0 128 | token_num = 0 129 | end_memory = 0 130 | num_batches_to_fetch = 10 131 | test_loader = get_test_data(dataset, tokenizer, seq_len=original_len, batch_size = batch_size) 132 | weight_memory = torch.cuda.memory_allocated() 133 | for batch_idx, batch_data in enumerate(itertools.islice(test_loader, num_batches_to_fetch)): 134 | batch = batch_data.to(device) 135 | token_num += batch.shape[0] * generated_len 136 | torch.cuda.empty_cache() 137 | start_memory = torch.cuda.memory_allocated() 138 | 
torch.cuda.reset_peak_memory_stats(0) 139 | torch.cuda.synchronize() 140 | start_time = time.time() 141 | generation_output = model.generate( 142 | input_ids=batch, 143 | pad_token_id=tokenizer.eos_token_id, 144 | do_sample=True, 145 | use_cache=True, 146 | top_k=50, 147 | max_length=original_len+generated_len, 148 | top_p=0.95, 149 | temperature=1, 150 | ) 151 | torch.cuda.synchronize() 152 | end_time = time.time() 153 | end_memory = max(torch.cuda.max_memory_allocated(0), end_memory) 154 | if torch.isfinite(generation_output[0]).all(): # check if the generation is successful since fp16 may cause nan 155 | throughput += end_time - start_time 156 | print("time: {}".format(end_time - start_time)) 157 | print("Total Memory: {} GB".format(end_memory/(1024 ** 3))) 158 | print("Weight Memory: {} GB".format(weight_memory/(1024 ** 3))) 159 | print("Activation Memory: {} GB".format((end_memory - start_memory)/(1024 ** 3))) 160 | print("Throughput: {} tokens/sec".format(token_num / throughput)) 161 | 162 | -------------------------------------------------------------------------------- /figures/framework_v1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AIoT-MLSys-Lab/SVD-LLM/1c1009a99bea54c22c27a4dd5d9fba7116dc456b/figures/framework_v1.jpg -------------------------------------------------------------------------------- /figures/framework_v2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AIoT-MLSys-Lab/SVD-LLM/1c1009a99bea54c22c27a4dd5d9fba7116dc456b/figures/framework_v2.jpg -------------------------------------------------------------------------------- /figures/logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AIoT-MLSys-Lab/SVD-LLM/1c1009a99bea54c22c27a4dd5d9fba7116dc456b/figures/logo.png -------------------------------------------------------------------------------- /gptq/gptq.py: -------------------------------------------------------------------------------- 1 | import math 2 | import time 3 | 4 | import torch 5 | import torch.nn as nn 6 | import transformers 7 | 8 | import sys 9 | import os 10 | current_path = os.path.dirname(os.path.abspath(__file__)) 11 | parent_path = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) 12 | sys.path.append(current_path) 13 | from quant import * 14 | 15 | 16 | DEBUG = False 17 | 18 | torch.backends.cuda.matmul.allow_tf32 = False 19 | torch.backends.cudnn.allow_tf32 = False 20 | 21 | 22 | class GPTQ: 23 | 24 | def __init__(self, layer): 25 | self.layer = layer 26 | self.dev = self.layer.weight.device 27 | W = layer.weight.data.clone() 28 | if isinstance(self.layer, nn.Conv2d): 29 | W = W.flatten(1) 30 | if isinstance(self.layer, transformers.Conv1D): 31 | W = W.t() 32 | self.rows = W.shape[0] 33 | self.columns = W.shape[1] 34 | self.H = torch.zeros((self.columns, self.columns), device=self.dev) 35 | self.nsamples = 0 36 | 37 | def add_batch(self, inp, out): 38 | if DEBUG: 39 | self.inp1 = inp 40 | self.out1 = out 41 | if len(inp.shape) == 2: 42 | inp = inp.unsqueeze(0) 43 | tmp = inp.shape[0] 44 | if isinstance(self.layer, nn.Linear) or isinstance(self.layer, transformers.Conv1D): 45 | if len(inp.shape) == 3: 46 | inp = inp.reshape((-1, inp.shape[-1])) 47 | inp = inp.t() 48 | if isinstance(self.layer, nn.Conv2d): 49 | unfold = nn.Unfold( 50 | self.layer.kernel_size, 51 | dilation=self.layer.dilation, 52 | 
padding=self.layer.padding, 53 | stride=self.layer.stride 54 | ) 55 | inp = unfold(inp) 56 | inp = inp.permute([1, 0, 2]) 57 | inp = inp.flatten(1) 58 | self.H *= self.nsamples / (self.nsamples + tmp) 59 | self.nsamples += tmp 60 | # inp = inp.float() 61 | inp = math.sqrt(2 / self.nsamples) * inp.float() 62 | # self.H += 2 / self.nsamples * inp.matmul(inp.t()) 63 | self.H += inp.matmul(inp.t()) 64 | 65 | def fasterquant( 66 | self, blocksize=128, percdamp=.01, groupsize=-1, actorder=False, static_groups=False 67 | ): 68 | W = self.layer.weight.data.clone() 69 | if isinstance(self.layer, nn.Conv2d): 70 | W = W.flatten(1) 71 | if isinstance(self.layer, transformers.Conv1D): 72 | W = W.t() 73 | W = W.float() 74 | 75 | tick = time.time() 76 | 77 | if not self.quantizer.ready(): 78 | self.quantizer.find_params(W, weight=True) 79 | 80 | H = self.H 81 | del self.H 82 | dead = torch.diag(H) == 0 83 | H[dead, dead] = 1 84 | W[:, dead] = 0 85 | 86 | if static_groups: 87 | import copy 88 | groups = [] 89 | for i in range(0, self.columns, groupsize): 90 | quantizer = copy.deepcopy(self.quantizer) 91 | quantizer.find_params(W[:, i:(i + groupsize)], weight=True) 92 | groups.append(quantizer) 93 | 94 | if actorder: 95 | perm = torch.argsort(torch.diag(H), descending=True) 96 | W = W[:, perm] 97 | H = H[perm][:, perm] 98 | invperm = torch.argsort(perm) 99 | 100 | Losses = torch.zeros_like(W) 101 | Q = torch.zeros_like(W) 102 | 103 | damp = percdamp * torch.mean(torch.diag(H)) 104 | diag = torch.arange(self.columns, device=self.dev) 105 | H[diag, diag] += damp 106 | H = torch.linalg.cholesky(H) 107 | H = torch.cholesky_inverse(H) 108 | H = torch.linalg.cholesky(H, upper=True) 109 | Hinv = H 110 | 111 | for i1 in range(0, self.columns, blocksize): 112 | i2 = min(i1 + blocksize, self.columns) 113 | count = i2 - i1 114 | 115 | W1 = W[:, i1:i2].clone() 116 | Q1 = torch.zeros_like(W1) 117 | Err1 = torch.zeros_like(W1) 118 | Losses1 = torch.zeros_like(W1) 119 | Hinv1 = Hinv[i1:i2, i1:i2] 120 | 121 | for i in range(count): 122 | w = W1[:, i] 123 | d = Hinv1[i, i] 124 | 125 | if groupsize != -1: 126 | if not static_groups: 127 | if (i1 + i) % groupsize == 0: 128 | self.quantizer.find_params(W[:, (i1 + i):(i1 + i + groupsize)], weight=True) 129 | else: 130 | idx = i1 + i 131 | if actorder: 132 | idx = perm[idx] 133 | self.quantizer = groups[idx // groupsize] 134 | 135 | q = quantize( 136 | w.unsqueeze(1), self.quantizer.scale, self.quantizer.zero, self.quantizer.maxq 137 | ).flatten() 138 | Q1[:, i] = q 139 | Losses1[:, i] = (w - q) ** 2 / d ** 2 140 | 141 | err1 = (w - q) / d 142 | W1[:, i:] -= err1.unsqueeze(1).matmul(Hinv1[i, i:].unsqueeze(0)) 143 | Err1[:, i] = err1 144 | 145 | Q[:, i1:i2] = Q1 146 | Losses[:, i1:i2] = Losses1 / 2 147 | 148 | W[:, i2:] -= Err1.matmul(Hinv[i1:i2, i2:]) 149 | 150 | if DEBUG: 151 | self.layer.weight.data[:, :i2] = Q[:, :i2] 152 | self.layer.weight.data[:, i2:] = W[:, i2:] 153 | print(torch.sum((self.layer(self.inp1) - self.out1) ** 2)) 154 | print(torch.sum(Losses)) 155 | 156 | torch.cuda.synchronize() 157 | print('time %.2f' % (time.time() - tick)) 158 | print('error', torch.sum(Losses).item()) 159 | 160 | if actorder: 161 | Q = Q[:, invperm] 162 | 163 | if isinstance(self.layer, transformers.Conv1D): 164 | Q = Q.t() 165 | self.layer.weight.data = Q.reshape(self.layer.weight.shape).to(self.layer.weight.data.dtype) 166 | if DEBUG: 167 | print(torch.sum((self.layer(self.inp1) - self.out1) ** 2)) 168 | 169 | def free(self): 170 | if DEBUG: 171 | self.inp1 = None 172 | self.out1 = 
None 173 | self.H = None 174 | self.Losses = None 175 | self.Trace = None 176 | torch.cuda.empty_cache() 177 | -------------------------------------------------------------------------------- /gptq/quant.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | import sys 5 | import os 6 | current_path = os.path.dirname(os.path.abspath(__file__)) 7 | parent_path = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) 8 | sys.path.append(current_path) 9 | 10 | 11 | def quantize(x, scale, zero, maxq): 12 | if maxq < 0: 13 | return (x > scale / 2).float() * scale + (x < zero / 2).float() * zero 14 | q = torch.clamp(torch.round(x / scale) + zero, 0, maxq) 15 | return scale * (q - zero) 16 | 17 | class Quantizer(nn.Module): 18 | 19 | def __init__(self, shape=1): 20 | super(Quantizer, self).__init__() 21 | self.register_buffer('maxq', torch.tensor(0)) 22 | self.register_buffer('scale', torch.zeros(shape)) 23 | self.register_buffer('zero', torch.zeros(shape)) 24 | 25 | def configure( 26 | self, 27 | bits, perchannel=False, sym=True, 28 | mse=False, norm=2.4, grid=100, maxshrink=.8, 29 | trits=False 30 | ): 31 | self.maxq = torch.tensor(2 ** bits - 1) 32 | self.perchannel = perchannel 33 | self.sym = sym 34 | self.mse = mse 35 | self.norm = norm 36 | self.grid = grid 37 | self.maxshrink = maxshrink 38 | if trits: 39 | self.maxq = torch.tensor(-1) 40 | 41 | def find_params(self, x, weight=False): 42 | dev = x.device 43 | self.maxq = self.maxq.to(dev) 44 | 45 | shape = x.shape 46 | if self.perchannel: 47 | if weight: 48 | x = x.flatten(1) 49 | else: 50 | if len(shape) == 4: 51 | x = x.permute([1, 0, 2, 3]) 52 | x = x.flatten(1) 53 | if len(shape) == 3: 54 | x = x.reshape((-1, shape[-1])).t() 55 | if len(shape) == 2: 56 | x = x.t() 57 | else: 58 | x = x.flatten().unsqueeze(0) 59 | 60 | tmp = torch.zeros(x.shape[0], device=dev) 61 | xmin = torch.minimum(x.min(1)[0], tmp) 62 | xmax = torch.maximum(x.max(1)[0], tmp) 63 | 64 | if self.sym: 65 | xmax = torch.maximum(torch.abs(xmin), xmax) 66 | tmp = xmin < 0 67 | if torch.any(tmp): 68 | xmin[tmp] = -xmax[tmp] 69 | tmp = (xmin == 0) & (xmax == 0) 70 | xmin[tmp] = -1 71 | xmax[tmp] = +1 72 | 73 | if self.maxq < 0: 74 | self.scale = xmax 75 | self.zero = xmin 76 | else: 77 | self.scale = (xmax - xmin) / self.maxq 78 | if self.sym: 79 | self.zero = torch.full_like(self.scale, (self.maxq + 1) / 2) 80 | else: 81 | self.zero = torch.round(-xmin / self.scale) 82 | 83 | if self.mse: 84 | best = torch.full([x.shape[0]], float('inf'), device=dev) 85 | for i in range(int(self.maxshrink * self.grid)): 86 | p = 1 - i / self.grid 87 | xmin1 = p * xmin 88 | xmax1 = p * xmax 89 | scale1 = (xmax1 - xmin1) / self.maxq 90 | zero1 = torch.round(-xmin1 / scale1) if not self.sym else self.zero 91 | q = quantize(x, scale1.unsqueeze(1), zero1.unsqueeze(1), self.maxq) 92 | q -= x 93 | q.abs_() 94 | q.pow_(self.norm) 95 | err = torch.sum(q, 1) 96 | tmp = err < best 97 | if torch.any(tmp): 98 | best[tmp] = err[tmp] 99 | self.scale[tmp] = scale1[tmp] 100 | self.zero[tmp] = zero1[tmp] 101 | if not self.perchannel: 102 | if weight: 103 | tmp = shape[0] 104 | else: 105 | tmp = shape[1] if len(shape) != 3 else shape[2] 106 | self.scale = self.scale.repeat(tmp) 107 | self.zero = self.zero.repeat(tmp) 108 | 109 | if weight: 110 | shape = [-1] + [1] * (len(shape) - 1) 111 | self.scale = self.scale.reshape(shape) 112 | self.zero = self.zero.reshape(shape) 113 | return 114 | if len(shape) == 
4: 115 | self.scale = self.scale.reshape((1, -1, 1, 1)) 116 | self.zero = self.zero.reshape((1, -1, 1, 1)) 117 | if len(shape) == 3: 118 | self.scale = self.scale.reshape((1, 1, -1)) 119 | self.zero = self.zero.reshape((1, 1, -1)) 120 | if len(shape) == 2: 121 | self.scale = self.scale.unsqueeze(0) 122 | self.zero = self.zero.unsqueeze(0) 123 | 124 | def quantize(self, x): 125 | if self.ready(): 126 | return quantize(x, self.scale, self.zero, self.maxq) 127 | return x 128 | 129 | def enabled(self): 130 | return self.maxq > 0 131 | 132 | def ready(self): 133 | return torch.all(self.scale != 0) 134 | 135 | 136 | try: 137 | import quant_cuda 138 | except: 139 | print('CUDA extension not installed.') 140 | 141 | # Assumes layer is perfectly divisible into 1024 * 1024 blocks 142 | class Quant3Linear(nn.Module): 143 | 144 | def __init__(self, infeatures, outfeatures, faster=False): 145 | super().__init__() 146 | self.register_buffer('zeros', torch.zeros((outfeatures, 1))) 147 | self.register_buffer('scales', torch.zeros((outfeatures, 1))) 148 | self.register_buffer('bias', torch.zeros(outfeatures)) 149 | self.register_buffer( 150 | 'qweight', torch.zeros((infeatures // 32 * 3, outfeatures), dtype=torch.int) 151 | ) 152 | self.faster = faster 153 | 154 | def pack(self, linear, scales, zeros): 155 | self.zeros = zeros * scales 156 | self.scales = scales.clone() 157 | if linear.bias is not None: 158 | self.bias = linear.bias.clone() 159 | 160 | intweight = torch.round((linear.weight.data + self.zeros) / self.scales).to(torch.int) 161 | intweight = intweight.t().contiguous() 162 | intweight = intweight.numpy().astype(np.uint32) 163 | qweight = np.zeros( 164 | (intweight.shape[0] // 32 * 3, intweight.shape[1]), dtype=np.uint32 165 | ) 166 | i = 0 167 | row = 0 168 | while row < qweight.shape[0]: 169 | for j in range(i, i + 10): 170 | qweight[row] |= intweight[j] << (3 * (j - i)) 171 | i += 10 172 | qweight[row] |= intweight[i] << 30 173 | row += 1 174 | qweight[row] |= (intweight[i] >> 2) & 1 175 | i += 1 176 | for j in range(i, i + 10): 177 | qweight[row] |= intweight[j] << (3 * (j - i) + 1) 178 | i += 10 179 | qweight[row] |= intweight[i] << 31 180 | row += 1 181 | qweight[row] |= (intweight[i] >> 1) & 0x3 182 | i += 1 183 | for j in range(i, i + 10): 184 | qweight[row] |= intweight[j] << (3 * (j - i) + 2) 185 | i += 10 186 | row += 1 187 | 188 | qweight = qweight.astype(np.int32) 189 | self.qweight = torch.from_numpy(qweight) 190 | 191 | def forward(self, x): 192 | if x.shape[-1] == x.numel(): 193 | outshape = list(x.shape) 194 | y = self.bias.clone() 195 | outshape[-1] = self.bias.numel() 196 | dtype = x.dtype 197 | if self.faster: 198 | x = x.half() 199 | quant_cuda.vecquant3matmul_faster(x, self.qweight, y, self.scales, self.zeros) 200 | else: 201 | x = x.float() 202 | quant_cuda.vecquant3matmul(x, self.qweight, y, self.scales, self.zeros) 203 | y = y.to(dtype) 204 | return y.reshape(outshape) 205 | raise ValueError('Only supports a single token currently.') 206 | 207 | def make_quant3(module, names, name='', faster=False): 208 | if isinstance(module, Quant3Linear): 209 | return 210 | for attr in dir(module): 211 | tmp = getattr(module, attr) 212 | name1 = name + '.' + attr if name != '' else attr 213 | if name1 in names: 214 | setattr( 215 | module, attr, Quant3Linear(tmp.in_features, tmp.out_features, faster=faster) 216 | ) 217 | for name1, child in module.named_children(): 218 | make_quant3(child, names, name + '.' 
+ name1 if name != '' else name1, faster=faster) 219 | -------------------------------------------------------------------------------- /quant_llama.py: -------------------------------------------------------------------------------- 1 | import time 2 | 3 | import torch 4 | import torch.nn as nn 5 | 6 | from gptq.gptq import * 7 | from utils.model_utils import * 8 | from gptq.quant import * 9 | from evaluater import ppl_eval 10 | 11 | current_path = os.path.dirname(os.path.abspath(__file__)) 12 | parent_path = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) 13 | sys.path.append(current_path) 14 | 15 | @torch.no_grad() 16 | def llama_sequential(model, dataloader, dev): 17 | print('Starting ...') 18 | 19 | use_cache = model.config.use_cache 20 | model.config.use_cache = False 21 | layers = model.model.layers 22 | 23 | model.model.embed_tokens = model.model.embed_tokens.to(dev) 24 | model.model.norm = model.model.norm.to(dev) 25 | layers[0] = layers[0].to(dev) 26 | 27 | dtype = next(iter(model.parameters())).dtype 28 | inps = torch.zeros( 29 | (args.nsamples, model.seqlen, model.config.hidden_size), dtype=dtype, device=dev 30 | ) 31 | cache = {'i': 0, 'attention_mask': None} 32 | 33 | class Catcher(nn.Module): 34 | def __init__(self, module): 35 | super().__init__() 36 | self.module = module 37 | def forward(self, inp, **kwargs): 38 | inps[cache['i']] = inp 39 | cache['i'] += 1 40 | cache['attention_mask'] = kwargs['attention_mask'] 41 | cache['position_ids'] = kwargs['position_ids'] 42 | raise ValueError 43 | layers[0] = Catcher(layers[0]) 44 | for batch in dataloader: 45 | try: 46 | model(batch[0].to(dev)) 47 | except ValueError: 48 | pass 49 | layers[0] = layers[0].module 50 | 51 | layers[0] = layers[0].cpu() 52 | model.model.embed_tokens = model.model.embed_tokens.cpu() 53 | model.model.norm = model.model.norm.cpu() 54 | torch.cuda.empty_cache() 55 | 56 | outs = torch.zeros_like(inps) 57 | attention_mask = cache['attention_mask'] 58 | position_ids = cache['position_ids'] 59 | 60 | print('Ready.') 61 | 62 | quantizers = {} 63 | for i in range(len(layers)): 64 | layer = layers[i].to(dev) 65 | full = find_layers(layer) 66 | 67 | if args.true_sequential: 68 | sequential = [ 69 | ['self_attn.k_u_proj','self_attn.k_v_proj', 'self_attn.v_u_proj', 'self_attn.v_v_proj', 'self_attn.q_u_proj', 'self_attn.q_v_proj'], 70 | ['self_attn.o_u_proj', 'self_attn.o_v_proj'], 71 | ['mlp.up_u_proj', 'mlp.up_v_proj', 'mlp.gate_u_proj', 'mlp.gate_v_proj'], 72 | ['mlp.down_u_proj', 'mlp.down_v_proj'] 73 | ] 74 | else: 75 | sequential = [list(full.keys())] 76 | 77 | for names in sequential: 78 | subset = {n: full[n] for n in names} 79 | 80 | gptq = {} 81 | for name in subset: 82 | gptq[name] = GPTQ(subset[name]) 83 | gptq[name].quantizer = Quantizer() 84 | gptq[name].quantizer.configure( 85 | args.wbits, perchannel=True, sym=args.sym, mse=False 86 | ) 87 | 88 | def add_batch(name): 89 | def tmp(_, inp, out): 90 | gptq[name].add_batch(inp[0].data, out.data) 91 | return tmp 92 | handles = [] 93 | for name in subset: 94 | handles.append(subset[name].register_forward_hook(add_batch(name))) 95 | for j in range(args.nsamples): 96 | outs[j] = layer(inps[j].unsqueeze(0), attention_mask=attention_mask, position_ids=position_ids)[0] 97 | for h in handles: 98 | h.remove() 99 | 100 | for name in subset: 101 | print(i, name) 102 | print('Quantizing ...') 103 | gptq[name].fasterquant( 104 | percdamp=args.percdamp, groupsize=args.groupsize, actorder=args.act_order, static_groups=args.static_groups 105 | ) 106 | 
quantizers['model.layers.%d.%s' % (i, name)] = gptq[name].quantizer 107 | gptq[name].free() 108 | 109 | for j in range(args.nsamples): 110 | outs[j] = layer(inps[j].unsqueeze(0), attention_mask=attention_mask, position_ids=position_ids)[0] 111 | 112 | layers[i] = layer.cpu() 113 | del layer 114 | del gptq 115 | torch.cuda.empty_cache() 116 | 117 | inps, outs = outs, inps 118 | 119 | model.config.use_cache = use_cache 120 | 121 | return quantizers 122 | 123 | @torch.no_grad() 124 | def llama_eval(model, testenc, dev): 125 | print('Evaluating ...') 126 | 127 | testenc = testenc.input_ids 128 | nsamples = testenc.numel() // model.seqlen 129 | 130 | use_cache = model.config.use_cache 131 | model.config.use_cache = False 132 | layers = model.model.layers 133 | 134 | model.model.embed_tokens = model.model.embed_tokens.to(dev) 135 | layers[0] = layers[0].to(dev) 136 | 137 | dtype = next(iter(model.parameters())).dtype 138 | inps = torch.zeros( 139 | (nsamples, model.seqlen, model.config.hidden_size), dtype=dtype, device=dev 140 | ) 141 | cache = {'i': 0, 'attention_mask': None} 142 | 143 | class Catcher(nn.Module): 144 | def __init__(self, module): 145 | super().__init__() 146 | self.module = module 147 | def forward(self, inp, **kwargs): 148 | inps[cache['i']] = inp 149 | cache['i'] += 1 150 | cache['attention_mask'] = kwargs['attention_mask'] 151 | cache['position_ids'] = kwargs['position_ids'] 152 | raise ValueError 153 | layers[0] = Catcher(layers[0]) 154 | for i in range(nsamples): 155 | batch = testenc[:, (i * model.seqlen):((i + 1) * model.seqlen)].to(dev) 156 | try: 157 | model(batch) 158 | except ValueError: 159 | pass 160 | layers[0] = layers[0].module 161 | 162 | layers[0] = layers[0].cpu() 163 | model.model.embed_tokens = model.model.embed_tokens.cpu() 164 | torch.cuda.empty_cache() 165 | 166 | outs = torch.zeros_like(inps) 167 | attention_mask = cache['attention_mask'] 168 | position_ids = cache['position_ids'] 169 | 170 | for i in range(len(layers)): 171 | print(i) 172 | layer = layers[i].to(dev) 173 | 174 | if args.nearest: 175 | subset = find_layers(layer) 176 | for name in subset: 177 | quantizer = Quantizer() 178 | quantizer.configure( 179 | args.wbits, perchannel=True, sym=False, mse=False 180 | ) 181 | W = subset[name].weight.data 182 | quantizer.find_params(W, weight=True) 183 | subset[name].weight.data = quantize( 184 | W, quantizer.scale, quantizer.zero, quantizer.maxq 185 | ).to(next(iter(layer.parameters())).dtype) 186 | 187 | for j in range(nsamples): 188 | outs[j] = layer(inps[j].unsqueeze(0), attention_mask=attention_mask, position_ids=position_ids)[0] 189 | layers[i] = layer.cpu() 190 | del layer 191 | torch.cuda.empty_cache() 192 | inps, outs = outs, inps 193 | 194 | if model.model.norm is not None: 195 | model.model.norm = model.model.norm.to(dev) 196 | model.lm_head = model.lm_head.to(dev) 197 | 198 | testenc = testenc.to(dev) 199 | nlls = [] 200 | for i in range(nsamples): 201 | hidden_states = inps[i].unsqueeze(0) 202 | if model.model.norm is not None: 203 | hidden_states = model.model.norm(hidden_states) 204 | lm_logits = model.lm_head(hidden_states) 205 | shift_logits = lm_logits[:, :-1, :].contiguous() 206 | shift_labels = testenc[ 207 | :, (i * model.seqlen):((i + 1) * model.seqlen) 208 | ][:, 1:] 209 | loss_fct = nn.CrossEntropyLoss() 210 | loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) 211 | neg_log_likelihood = loss.float() * model.seqlen 212 | nlls.append(neg_log_likelihood) 213 | ppl = 
torch.exp(torch.stack(nlls).sum() / (nsamples * model.seqlen)) 214 | print(ppl.item()) 215 | 216 | model.config.use_cache = use_cache 217 | 218 | def llama_pack3(model, quantizers): 219 | layers = find_layers(model) 220 | layers = {n: layers[n] for n in quantizers} 221 | make_quant3(model, quantizers) 222 | qlayers = find_layers(model, [Quant3Linear]) 223 | print('Packing ...') 224 | for name in qlayers: 225 | print(name) 226 | quantizers[name] = quantizers[name].cpu() 227 | qlayers[name].pack(layers[name], quantizers[name].scale, quantizers[name].zero) 228 | print('Done.') 229 | return model 230 | 231 | 232 | if __name__ == '__main__': 233 | import argparse 234 | from utils.data_utils import * 235 | 236 | parser = argparse.ArgumentParser() 237 | 238 | parser.add_argument( 239 | '--model_path', type=str, 240 | help='path of the compressed model.' 241 | ) 242 | parser.add_argument( 243 | '--dataset', type=str, choices=['wikitext2', 'ptb', 'c4'], 244 | help='Where to extract calibration data from.' 245 | ) 246 | parser.add_argument( 247 | '--seed', 248 | type=int, default=0, help='Seed for sampling the calibration data.' 249 | ) 250 | parser.add_argument( 251 | '--nsamples', type=int, default=128, 252 | help='Number of calibration data samples.' 253 | ) 254 | parser.add_argument( 255 | '--percdamp', type=float, default=.01, 256 | help='Percent of the average Hessian diagonal to use for dampening.' 257 | ) 258 | parser.add_argument( 259 | '--nearest', action='store_true', 260 | help='Whether to run the RTN baseline.' 261 | ) 262 | parser.add_argument( 263 | '--wbits', type=int, default=16, choices=[2, 3, 4, 8, 16], 264 | help='#bits to use for quantization; use 16 for evaluating base model.' 265 | ) 266 | parser.add_argument( 267 | '--groupsize', type=int, default=-1, 268 | help='Groupsize to use for quantization; default uses full row.' 269 | ) 270 | parser.add_argument( 271 | '--sym', action='store_true', 272 | help='Whether to perform symmetric quantization.' 273 | ) 274 | parser.add_argument( 275 | '--save', type=str, default='', 276 | help='Save quantized checkpoint under this name.' 277 | ) 278 | parser.add_argument( 279 | '--new-eval', action='store_true', 280 | help='Whether to use the new PTB and C4 eval.' 281 | ) 282 | parser.add_argument( 283 | '--act-order', action='store_true', 284 | help='Whether to apply the activation order GPTQ heuristic' 285 | ) 286 | parser.add_argument( 287 | '--true-sequential', action='store_true', 288 | help='Whether to run in true sequential mode.' 289 | ) 290 | parser.add_argument( 291 | '--static-groups', action='store_true', 292 | help='Whether to use static groups; recommended when using `--act-order` for more efficient inference.'
293 | ) 294 | parser.add_argument('--DEV', type=str, default="cuda", help='device') 295 | 296 | args = parser.parse_args() 297 | 298 | model, tokenizer = get_model_from_local(args.model_path) 299 | model.eval() 300 | 301 | dataloader, testloader = get_loaders(args.dataset, nsamples=args.nsamples, seed=args.seed, tokenizer=tokenizer) 302 | 303 | if args.wbits < 16 and not args.nearest: 304 | tick = time.time() 305 | quantizers = llama_sequential(model, dataloader, args.DEV) 306 | print(time.time() - tick) 307 | # if args.save: 308 | # llama_pack3(model, quantizers) 309 | # torch.save(model.state_dict(), args.save) 310 | ppl_eval(model, tokenizer, datasets=['wikitext2'], model_seq_len=2048, batch_size=16, device=args.DEV) 311 | torch.save({ 312 | 'model': model, 313 | 'tokenizer': tokenizer 314 | }, args.save) 315 | 316 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | datasets==2.16.1 2 | numpy==1.26.3 3 | torch>=2.0.1 4 | tqdm==4.65.0 5 | transformers==4.35.2 6 | sentencepiece==0.1.99 7 | matplotlib==3.4.3 8 | evaluate 9 | accelerate -------------------------------------------------------------------------------- /svdllm_gptq.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # run data whitening with 20% compression ratio 4 | python SVDLLM.py --model jeffwan/llama-7b-hf --step 1 --ratio 0.2 --whitening_nsamples 256 --dataset wikitext2 --seed 3 --model_seq_len 2048 --save_path . 5 | 6 | # further compress the model with GPTQ-4bit 7 | python quant_llama.py --model_path whitening/jeffwan_llama_7b_hf_whitening_0.2.pt --dataset c4 --wbits 4 --true-sequential --act-order --new-eval --save svdllm_gptq_4.pt -------------------------------------------------------------------------------- /utils/LoRA.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Refer to 3 | https://github.com/tloen/alpaca-lora/blob/main/finetune.py 4 | ''' 5 | 6 | import os 7 | import sys 8 | import argparse 9 | from typing import List 10 | 11 | import torch 12 | import transformers 13 | from datasets import load_dataset 14 | 15 | parent_path = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) 16 | sys.path.append(parent_path) 17 | from peft import ( 18 | LoraConfig, 19 | get_peft_model, 20 | get_peft_model_state_dict, 21 | prepare_model_for_int8_training, 22 | set_peft_model_state_dict, 23 | ) 24 | from Prompter import Prompter, ZeroPrompter 25 | 26 | device = "cuda" if torch.cuda.is_available() else "cpu" 27 | 28 | def wikitext2(): 29 | traindata = load_dataset('wikitext', 'wikitext-2-raw-v1', split='train') 30 | testdata = load_dataset('wikitext', 'wikitext-2-raw-v1', split='test') 31 | return traindata, testdata 32 | 33 | def ptb(): 34 | traindata = load_dataset('ptb_text_only', 'penn_treebank', split='train') 35 | valdata = load_dataset('ptb_text_only', 'penn_treebank', split='validation') 36 | return traindata, valdata 37 | 38 | def apply_lora(model, tokenizer, batch_size=64, micro_batch_size=4, cutoff_len=256, add_eos_token=False, 39 | lora_r=2, lora_alpha=16, lora_target_modules="q_proj,k_proj,v_proj,o_proj,gate_proj,down_proj,up_proj", 40 | lora_dropout=0.05, val_set_size=2000, data_path="yahma/alpaca-cleaned",num_epochs=2, learning_rate=1e-4, 41 | output_dir="Checkpoints/tune", group_by_length=False, extra_val_dataset=None): 42 | 43 | gradient_accumulation_steps = 
batch_size // micro_batch_size 44 | prompter = ZeroPrompter() 45 | 46 | if device == 'cuda': 47 | model.half() 48 | 49 | tokenizer.pad_token_id = 0 50 | tokenizer.padding_side = "left" 51 | 52 | def tokenize(prompt, add_eos_token=True): 53 | result = tokenizer( 54 | prompt, 55 | truncation=True, 56 | max_length=cutoff_len, 57 | padding=False, 58 | return_tensors=None, 59 | ) 60 | if ( 61 | result["input_ids"][-1] != tokenizer.eos_token_id 62 | and len(result["input_ids"]) < cutoff_len 63 | and add_eos_token 64 | ): 65 | result["input_ids"].append(tokenizer.eos_token_id) 66 | result["attention_mask"].append(1) 67 | 68 | result["labels"] = result["input_ids"].copy() 69 | 70 | return result 71 | 72 | def generate_and_tokenize_prompt(data_point): 73 | full_prompt = prompter.generate_prompt( 74 | data_point["instruction"], 75 | data_point["input"], 76 | data_point["output"], 77 | ) 78 | tokenized_full_prompt = tokenize(full_prompt) 79 | user_prompt = prompter.generate_prompt( 80 | data_point["instruction"], data_point["input"] 81 | ) 82 | tokenized_user_prompt = tokenize( 83 | user_prompt, add_eos_token=add_eos_token 84 | ) 85 | user_prompt_len = len(tokenized_user_prompt["input_ids"]) 86 | 87 | if add_eos_token: 88 | user_prompt_len -= 1 89 | 90 | tokenized_full_prompt["labels"] = [ 91 | -100 92 | ] * user_prompt_len + tokenized_full_prompt["labels"][ 93 | user_prompt_len: 94 | ] # could be sped up, probably 95 | return tokenized_full_prompt 96 | 97 | def split_and_tokenizer(test_data, tokenizer, seq_len, field_name): 98 | test_ids = tokenizer("\n\n".join(test_data[field_name]), return_tensors='pt').input_ids[0] 99 | test_ids_batch = [] 100 | nsamples = test_ids.numel() // seq_len 101 | 102 | test_set = [] 103 | for i in range(nsamples): 104 | batch = test_ids[(i * seq_len):((i + 1) * seq_len)] 105 | test_set.append({ 106 | 'input_ids': batch, 107 | 'labels': batch 108 | }) 109 | return test_set 110 | 111 | # Prepare For LoRA 112 | model = prepare_model_for_int8_training(model) 113 | config = LoraConfig( 114 | r=lora_r, 115 | lora_alpha=lora_alpha, 116 | target_modules=lora_target_modules.split(","), 117 | lora_dropout=lora_dropout, 118 | bias="none", 119 | task_type="CAUSAL_LM", 120 | ) 121 | model = get_peft_model(model, config) 122 | model.print_trainable_parameters() 123 | 124 | # Load Train Dataset 125 | data = load_dataset(data_path) 126 | train_val = data["train"].train_test_split( 127 | test_size=val_set_size, shuffle=True, seed=42 128 | ) 129 | train_data = ( 130 | train_val["train"].shuffle().map(generate_and_tokenize_prompt) 131 | ) 132 | val_data = { 133 | data_path: train_val["test"].shuffle().map(generate_and_tokenize_prompt), 134 | } 135 | 136 | # Load Extra Validation Dataset 137 | if extra_val_dataset: 138 | seq_len = 128 139 | for extra_dataset in extra_val_dataset.split(','): 140 | if 'wikitext2' in extra_dataset: 141 | _, test_data = wikitext2() 142 | test_data = split_and_tokenizer(test_data, tokenizer, seq_len, field_name='text') 143 | if 'ptb' in extra_dataset: 144 | _, test_data = ptb() 145 | test_data = split_and_tokenizer(test_data, tokenizer, seq_len, field_name='sentence') 146 | val_data[extra_dataset] = test_data 147 | 148 | trainer = transformers.Trainer( 149 | model=model, 150 | train_dataset=train_data, 151 | eval_dataset=val_data, 152 | args=transformers.TrainingArguments( 153 | per_device_train_batch_size=micro_batch_size, 154 | gradient_accumulation_steps=gradient_accumulation_steps, 155 | warmup_steps=100, 156 | num_train_epochs=num_epochs, 157 | 
learning_rate=learning_rate, 158 | fp16=True, 159 | logging_steps=10, 160 | logging_first_step=True, 161 | optim="adamw_torch", 162 | evaluation_strategy="steps", 163 | save_strategy="steps", 164 | eval_steps=100, 165 | save_steps=200, 166 | output_dir=output_dir, 167 | save_total_limit=30, 168 | load_best_model_at_end=True, 169 | ddp_find_unused_parameters=None, 170 | group_by_length=group_by_length, 171 | report_to="none", 172 | run_name="none", 173 | metric_for_best_model="{}_loss".format(data_path), 174 | ), 175 | data_collator=transformers.DataCollatorForSeq2Seq( 176 | tokenizer, pad_to_multiple_of=8, return_tensors="pt", padding=True 177 | ), 178 | ) 179 | model.config.use_cache = False 180 | old_state_dict = model.state_dict 181 | model.state_dict = ( 182 | lambda self, *_, **__: get_peft_model_state_dict( 183 | self, old_state_dict() 184 | ) 185 | ).__get__(model, type(model)) 186 | 187 | trainer.train() 188 | model.state_dict = old_state_dict 189 | return model 190 | 191 | def main(args): 192 | # Set WanDB 193 | os.environ["WANDB_PROJECT"] = args.wandb_project 194 | 195 | # Load Pruned Model 196 | pruned_dict = torch.load(args.prune_model, map_location='cpu') 197 | tokenizer, model = pruned_dict['tokenizer'], pruned_dict['model'] 198 | gradient_accumulation_steps = args.batch_size // args.micro_batch_size 199 | if not args.no_instruction: 200 | prompter = Prompter(args.prompt_template_name) 201 | else: 202 | prompter = ZeroPrompter() 203 | 204 | if device == 'cuda': 205 | model.half() 206 | 207 | tokenizer.pad_token_id = 0 208 | tokenizer.padding_side = "left" 209 | 210 | def tokenize(prompt, add_eos_token=True): 211 | result = tokenizer( 212 | prompt, 213 | truncation=True, 214 | max_length=args.cutoff_len, 215 | padding=False, 216 | return_tensors=None, 217 | ) 218 | if ( 219 | result["input_ids"][-1] != tokenizer.eos_token_id 220 | and len(result["input_ids"]) < args.cutoff_len 221 | and add_eos_token 222 | ): 223 | result["input_ids"].append(tokenizer.eos_token_id) 224 | result["attention_mask"].append(1) 225 | 226 | result["labels"] = result["input_ids"].copy() 227 | 228 | return result 229 | 230 | def generate_and_tokenize_prompt(data_point): 231 | full_prompt = prompter.generate_prompt( 232 | data_point["instruction"], 233 | data_point["input"], 234 | data_point["output"], 235 | ) 236 | tokenized_full_prompt = tokenize(full_prompt) 237 | if not args.train_on_inputs: 238 | user_prompt = prompter.generate_prompt( 239 | data_point["instruction"], data_point["input"] 240 | ) 241 | tokenized_user_prompt = tokenize( 242 | user_prompt, add_eos_token=args.add_eos_token 243 | ) 244 | user_prompt_len = len(tokenized_user_prompt["input_ids"]) 245 | 246 | if args.add_eos_token: 247 | user_prompt_len -= 1 248 | 249 | tokenized_full_prompt["labels"] = [ 250 | -100 251 | ] * user_prompt_len + tokenized_full_prompt["labels"][ 252 | user_prompt_len: 253 | ] # could be sped up, probably 254 | return tokenized_full_prompt 255 | 256 | def split_and_tokenizer(test_data, tokenizer, seq_len, field_name): 257 | test_ids = tokenizer("\n\n".join(test_data[field_name]), return_tensors='pt').input_ids[0] 258 | test_ids_batch = [] 259 | nsamples = test_ids.numel() // seq_len 260 | 261 | test_set = [] 262 | for i in range(nsamples): 263 | batch = test_ids[(i * seq_len):((i + 1) * seq_len)] 264 | test_set.append({ 265 | 'input_ids': batch, 266 | 'labels': batch 267 | }) 268 | return test_set 269 | 270 | # Prepare For LoRA 271 | model = prepare_model_for_int8_training(model) 272 | config = LoraConfig( 273 
| r=args.lora_r, 274 | lora_alpha=args.lora_alpha, 275 | target_modules=args.lora_target_modules.split(","), 276 | lora_dropout=args.lora_dropout, 277 | bias="none", 278 | task_type="CAUSAL_LM", 279 | ) 280 | model = get_peft_model(model, config) 281 | model.print_trainable_parameters() 282 | 283 | # Load Train Dataset 284 | data = load_dataset(args.data_path) 285 | train_val = data["train"].train_test_split( 286 | test_size=args.val_set_size, shuffle=True, seed=42 287 | ) 288 | train_data = ( 289 | train_val["train"].shuffle().map(generate_and_tokenize_prompt) 290 | ) 291 | val_data = { 292 | args.data_path: train_val["test"].shuffle().map(generate_and_tokenize_prompt), 293 | } 294 | 295 | # Load Extra Validation Dataset 296 | if args.extra_val_dataset: 297 | seq_len = 128 298 | for extra_dataset in args.extra_val_dataset.split(','): 299 | if 'wikitext2' in extra_dataset: 300 | _, test_data = wikitext2() 301 | test_data = split_and_tokenizer(test_data, tokenizer, seq_len, field_name='text') 302 | if 'ptb' in extra_dataset: 303 | _, test_data = ptb() 304 | test_data = split_and_tokenizer(test_data, tokenizer, seq_len, field_name='sentence') 305 | val_data[extra_dataset] = test_data 306 | 307 | trainer = transformers.Trainer( 308 | model=model, 309 | train_dataset=train_data, 310 | eval_dataset=val_data, 311 | args=transformers.TrainingArguments( 312 | per_device_train_batch_size=args.micro_batch_size, 313 | gradient_accumulation_steps=gradient_accumulation_steps, 314 | warmup_steps=100, 315 | num_train_epochs=args.num_epochs, 316 | learning_rate=args.learning_rate, 317 | fp16=True, 318 | logging_steps=10, 319 | logging_first_step=True, 320 | optim="adamw_torch", 321 | evaluation_strategy="steps", 322 | save_strategy="steps", 323 | save_safetensors=False, 324 | eval_steps=100, 325 | save_steps=200, 326 | output_dir=args.output_dir, 327 | save_total_limit=20, 328 | load_best_model_at_end=True, 329 | ddp_find_unused_parameters=None, 330 | group_by_length=args.group_by_length, 331 | report_to="none", 332 | run_name="none", 333 | metric_for_best_model="{}_loss".format(args.data_path), 334 | ), 335 | data_collator=transformers.DataCollatorForSeq2Seq( 336 | tokenizer, pad_to_multiple_of=8, return_tensors="pt", padding=True 337 | ), 338 | ) 339 | model.config.use_cache = False 340 | old_state_dict = model.state_dict 341 | model.state_dict = ( 342 | lambda self, *_, **__: get_peft_model_state_dict( 343 | self, old_state_dict() 344 | ) 345 | ).__get__(model, type(model)) 346 | 347 | trainer.train(resume_from_checkpoint=args.resume_from_checkpoint) 348 | 349 | model.state_dict = old_state_dict 350 | model.save_pretrained(args.output_dir, safe_serialization=False) 351 | 352 | 353 | if __name__ == "__main__": 354 | parser = argparse.ArgumentParser(description='Tuning Pruned LLM') 355 | 356 | # Model Type&Path 357 | parser.add_argument('--base_model', type=str, default="decapoda-research/llama-7b-hf", help='base model name') 358 | parser.add_argument('--prune_model', type=str, help='prune model name') 359 | parser.add_argument('--data_path', type=str, default="yahma/alpaca-cleaned", help='data path') 360 | # parser.add_argument('--extra_val_dataset', type=str, default='wikitext2,ptb', help='validation datasets. Split with ","') 361 | parser.add_argument('--extra_val_dataset', type=str, default=None, help='validation datasets. 
Split with ","') 362 | parser.add_argument('--output_dir', type=str, default="./lora-alpaca", help='output directory') 363 | 364 | # Training Hyperparameters 365 | parser.add_argument('--batch_size', type=int, default=128, help='batch size') 366 | parser.add_argument('--micro_batch_size', type=int, default=4, help='micro batch size') 367 | parser.add_argument('--num_epochs', type=int, default=5, help='number of epochs') 368 | parser.add_argument('--learning_rate', type=float, default=3e-4, help='learning rate') 369 | parser.add_argument('--cutoff_len', type=int, default=256, help='cutoff length') 370 | parser.add_argument('--val_set_size', type=int, default=2000, help='validation set size') 371 | parser.add_argument('--prompt_template_name', type=str, default="alpaca", help="The prompt template to use, will default to alpaca.") 372 | parser.add_argument('--no_instruction', action='store_true', default=False, help="Whether to use the instruction template or not.") 373 | 374 | # Lora Configuration 375 | parser.add_argument('--lora_r', type=int, default=8, help='lora r') 376 | parser.add_argument('--lora_alpha', type=int, default=16, help='lora alpha') 377 | parser.add_argument('--lora_dropout', type=float, default=0.05, help='lora dropout') 378 | parser.add_argument('--lora_target_modules', type=str, default="q_v_proj,q_u_proj,k_v_proj,k_u_proj,v_u_proj,v_v_proj,o_u_proj,o_v_proj,gate_u_proj,gate_v_proj,down_u_proj,down_v_proj,up_u_proj,up_v_proj", help='lora target modules') 379 | 380 | # llm hyperparameters 381 | parser.add_argument('--train_on_inputs', default=False, action="store_true", help='Train on inputs. If False, masks out inputs in loss') 382 | parser.add_argument('--add_eos_token', default=False, action="store_true") 383 | parser.add_argument('--group_by_length', default=False, action="store_true", help="faster, but produces an odd training loss curve") 384 | 385 | # wandb params 386 | parser.add_argument('--wandb_project', type=str, default="") 387 | parser.add_argument('--resume_from_checkpoint', type=str, help="either training checkpoint or final adapter") 388 | 389 | args = parser.parse_args() 390 | torch_version = int(torch.__version__.split('.')[1]) 391 | args.torch_version = torch_version 392 | 393 | main(args) -------------------------------------------------------------------------------- /utils/Prompter.py: -------------------------------------------------------------------------------- 1 | """ 2 | A dedicated helper to manage templates and prompt building. 3 | """ 4 | 5 | import json 6 | import os.path as osp 7 | from typing import Union 8 | 9 | alpaca_template = { 10 | "description": "Template used by Alpaca-LoRA.", 11 | "prompt_input": "Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\n\n### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n", 12 | "prompt_no_input": "Below is an instruction that describes a task. 
Write a response that appropriately completes the request.\n\n### Instruction:\n{instruction}\n\n### Response:\n", 13 | "response_split": "### Response:" 14 | } 15 | 16 | class Prompter(object): 17 | __slots__ = ("template", "_verbose") 18 | 19 | def __init__(self, template_name: str = "", verbose: bool = False): 20 | self._verbose = verbose 21 | if not template_name or template_name == 'alpaca': 22 | self.template = alpaca_template 23 | if self._verbose: 24 | print( 25 | f"Using prompt template {template_name}: {self.template['description']}" 26 | ) 27 | 28 | def generate_prompt( 29 | self, 30 | instruction: str, 31 | input: Union[None, str] = None, 32 | label: Union[None, str] = None, 33 | ) -> str: 34 | # returns the full prompt from instruction and optional input 35 | # if a label (=response, =output) is provided, it's also appended. 36 | if input: 37 | res = self.template["prompt_input"].format( 38 | instruction=instruction, input=input 39 | ) 40 | else: 41 | res = self.template["prompt_no_input"].format( 42 | instruction=instruction 43 | ) 44 | if label: 45 | res = f"{res}{label}" 46 | if self._verbose: 47 | print(res) 48 | return res 49 | 50 | def get_response(self, output: str) -> str: 51 | return output.split(self.template["response_split"])[1].strip() 52 | 53 | 54 | class ZeroPrompter(object): 55 | __slots__ = ("_verbose") 56 | 57 | def __init__(self, verbose: bool = False): 58 | self._verbose = verbose 59 | 60 | if self._verbose: 61 | print( 62 | f"Without using prompt template!" 63 | ) 64 | 65 | def generate_prompt( 66 | self, 67 | instruction: str, 68 | input: Union[None, str] = None, 69 | label: Union[None, str] = None, 70 | ) -> str: 71 | # returns the full prompt from instruction and optional input 72 | # if a label (=response, =output) is provided, it's also appended. 73 | if instruction[-1] == '.': 74 | instruction = instruction[:-1] + ':' 75 | if instruction[-1] not in ['.', ':', '?', '!']: 76 | instruction = instruction + ':' 77 | instruction += ' ' 78 | 79 | if input: 80 | if input[-1] not in ['.', ':', '?', '!']: 81 | input = input + '.' 
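# (Editorial annotation, not part of the original utils/Prompter.py.)
# At this point the instruction has been normalized to end with ':' plus a
# trailing space, and the optional input (when given) ends with a sentence
# terminator, e.g. "Summarize the text." becomes "Summarize the text: ".
# The two pieces are simply concatenated below, and the optional label is
# appended last.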
82 | res = instruction + input 83 | else: 84 | res = instruction 85 | if label: 86 | res = f"{res} {label}" 87 | if self._verbose: 88 | print(res) 89 | return res 90 | 91 | def get_response(self, output: str) -> str: 92 | return output.strip() 93 | -------------------------------------------------------------------------------- /utils/data_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | import torch 4 | import sys 5 | from datasets import load_dataset 6 | from torch.utils.data.dataset import Dataset 7 | 8 | current_path = os.path.dirname(os.path.abspath(__file__)) 9 | parent_path = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) 10 | sys.path.append(current_path) 11 | 12 | def get_calib_train_data(name, tokenizer, nsamples, seqlen=2048, seed=3, batch_size=1, dataset_cache_dir=None): 13 | import random 14 | random.seed(seed) 15 | cache_file = ( 16 | f"cache/{name}_{nsamples}_{seqlen}_{seed}_{batch_size}.pt" 17 | ) 18 | nsamples += 1 ############################# 19 | if not os.path.exists("cache"): 20 | os.makedirs("cache") 21 | if os.path.exists(cache_file): 22 | traindataset = torch.load(cache_file) 23 | return traindataset 24 | if name == "c4": 25 | traindata = load_dataset("json", data_files="utils/c4-train.json")['train'] 26 | tot_text = "\n\n".join(traindata["text"]) 27 | elif name == "ptb": 28 | traindata = load_dataset('ptb_text_only', 'penn_treebank', split='train', cache_dir=dataset_cache_dir) 29 | tot_text = "\n\n".join(traindata["sentence"]) 30 | elif name == "wikitext2": 31 | traindata = load_dataset("wikitext", "wikitext-2-raw-v1", split="train", cache_dir=dataset_cache_dir) 32 | tot_text = "\n\n".join(traindata["text"]) 33 | else: 34 | raise NotImplementedError 35 | traindataset = [] 36 | for s in range(nsamples): 37 | i = random.randint(0, len(tot_text) - seqlen - 1) 38 | j = i + seqlen * 10 39 | trainenc = tokenizer(tot_text[i:j], return_tensors="pt") 40 | if trainenc.input_ids.shape[1] < seqlen: 41 | s = s - 1 42 | continue 43 | if s % batch_size == 0: 44 | if s != 0: 45 | attention_mask = torch.ones_like(inp) 46 | traindataset.append({"input_ids": inp, "attention_mask": attention_mask}) 47 | inp = trainenc.input_ids[:, :seqlen] 48 | else: 49 | inp = torch.cat((inp, trainenc.input_ids[:, :seqlen]), dim=0) 50 | torch.save(traindataset, cache_file) 51 | return traindataset 52 | 53 | 54 | 55 | def get_wikitext2(nsamples, seed, seqlen, tokenizer, dataset_cache_dir=None): 56 | traindata = load_dataset('wikitext', 'wikitext-2-raw-v1', split='train', cache_dir=dataset_cache_dir) 57 | testdata = load_dataset('wikitext', 'wikitext-2-raw-v1', split='test', cache_dir=dataset_cache_dir) 58 | 59 | trainenc = tokenizer("\n\n".join(traindata['text']), return_tensors='pt') 60 | testenc = tokenizer("\n\n".join(testdata['text']), return_tensors='pt') 61 | 62 | import random 63 | random.seed(seed) 64 | trainloader = [] 65 | for _ in range(nsamples): 66 | i = random.randint(0, trainenc.input_ids.shape[1] - seqlen - 1) 67 | j = i + seqlen 68 | inp = trainenc.input_ids[:, i:j] 69 | tar = inp.clone() 70 | tar[:, :-1] = -100 71 | trainloader.append((inp, tar)) 72 | return trainloader, testenc 73 | 74 | def get_ptb(nsamples, seed, seqlen, tokenizer, dataset_cache_dir=None): 75 | traindata = load_dataset('ptb_text_only', 'penn_treebank', split='train', cache_dir=dataset_cache_dir) 76 | valdata = load_dataset('ptb_text_only', 'penn_treebank', split='validation', cache_dir=dataset_cache_dir) 77 | 78 | trainenc 
= tokenizer("\n\n".join(traindata['sentence']), return_tensors='pt') 79 | testenc = tokenizer("\n\n".join(valdata['sentence']), return_tensors='pt') 80 | 81 | import random 82 | random.seed(seed) 83 | trainloader = [] 84 | for _ in range(nsamples): 85 | i = random.randint(0, trainenc.input_ids.shape[1] - seqlen - 1) 86 | j = i + seqlen 87 | inp = trainenc.input_ids[:, i:j] 88 | tar = inp.clone() 89 | tar[:, :-1] = -100 90 | trainloader.append((inp, tar)) 91 | return trainloader, testenc 92 | 93 | def get_c4(nsamples, seed, seqlen, tokenizer): 94 | traindata = load_dataset("json", data_files="utils/c4-train.json")['train'] 95 | valdata = load_dataset("json", data_files="utils/c4-validation.json")['train'] 96 | 97 | import random 98 | random.seed(seed) 99 | trainloader = [] 100 | for _ in range(nsamples): 101 | while True: 102 | i = random.randint(0, len(traindata) - 1) 103 | trainenc = tokenizer(traindata[i]['text'], return_tensors='pt') 104 | if trainenc.input_ids.shape[1] >= seqlen: 105 | break 106 | i = random.randint(0, trainenc.input_ids.shape[1] - seqlen - 1) 107 | j = i + seqlen 108 | inp = trainenc.input_ids[:, i:j] 109 | tar = inp.clone() 110 | tar[:, :-1] = -100 111 | trainloader.append((inp, tar)) 112 | 113 | import random 114 | random.seed(0) 115 | valenc = [] 116 | for _ in range(256): 117 | while True: 118 | i = random.randint(0, len(valdata) - 1) 119 | tmp = tokenizer(valdata[i]['text'], return_tensors='pt') 120 | if tmp.input_ids.shape[1] >= seqlen: 121 | break 122 | i = random.randint(0, tmp.input_ids.shape[1] - seqlen - 1) 123 | j = i + seqlen 124 | valenc.append(tmp.input_ids[:, i:j]) 125 | valenc = torch.hstack(valenc) 126 | class TokenizerWrapper: 127 | def __init__(self, input_ids): 128 | self.input_ids = input_ids 129 | valenc = TokenizerWrapper(valenc) 130 | 131 | return trainloader, valenc 132 | 133 | 134 | 135 | def get_ptb_new(nsamples, seed, seqlen, tokenizer, dataset_cache_dir=None): 136 | from datasets import load_dataset 137 | traindata = load_dataset('ptb_text_only', 'penn_treebank', split='train', cache_dir=dataset_cache_dir) 138 | testdata = load_dataset('ptb_text_only', 'penn_treebank', split='test', cache_dir=dataset_cache_dir) 139 | 140 | trainenc = tokenizer(" ".join(traindata['sentence']), return_tensors='pt') 141 | testenc = tokenizer(" ".join(testdata['sentence']), return_tensors='pt') 142 | 143 | import random 144 | random.seed(seed) 145 | trainloader = [] 146 | for _ in range(nsamples): 147 | i = random.randint(0, trainenc.input_ids.shape[1] - seqlen - 1) 148 | j = i + seqlen 149 | inp = trainenc.input_ids[:, i:j] 150 | tar = inp.clone() 151 | tar[:, :-1] = -100 152 | trainloader.append((inp, tar)) 153 | return trainloader, testenc 154 | 155 | def get_c4_new(nsamples, seed, seqlen, tokenizer): 156 | traindata = load_dataset("json", data_files="utils/c4-train.json")['train'] 157 | valdata = load_dataset("json", data_files="utils/c4-validation.json")['train'] 158 | 159 | import random 160 | random.seed(seed) 161 | trainloader = [] 162 | for _ in range(nsamples): 163 | while True: 164 | i = random.randint(0, len(traindata) - 1) 165 | trainenc = tokenizer(traindata[i]['text'], return_tensors='pt') 166 | if trainenc.input_ids.shape[1] >= seqlen: 167 | break 168 | i = random.randint(0, trainenc.input_ids.shape[1] - seqlen - 1) 169 | j = i + seqlen 170 | inp = trainenc.input_ids[:, i:j] 171 | tar = inp.clone() 172 | tar[:, :-1] = -100 173 | trainloader.append((inp, tar)) 174 | 175 | valenc = tokenizer(' '.join(valdata[:1100]['text']), return_tensors='pt') 
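# (Editorial annotation, not part of the original utils/data_utils.py.)
# The joined C4 validation text tokenized above is truncated below to the
# first 256 * seqlen tokens, so the downstream evaluation always sees a
# fixed-size validation slice regardless of how long the raw documents are.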
176 | valenc = valenc.input_ids[:, :(256 * seqlen)] 177 | 178 | class TokenizerWrapper: 179 | def __init__(self, input_ids): 180 | self.input_ids = input_ids 181 | valenc = TokenizerWrapper(valenc) 182 | 183 | return trainloader, valenc 184 | def get_loaders(name, nsamples=128, seed=0, seqlen=2048, tokenizer=None): 185 | if 'wikitext2' in name: 186 | return get_wikitext2(nsamples, seed, seqlen, tokenizer) 187 | if 'ptb' in name: 188 | if 'new' in name: 189 | return get_ptb_new(nsamples, seed, seqlen, tokenizer) 190 | return get_ptb(nsamples, seed, seqlen, tokenizer) 191 | if 'c4' in name: 192 | if 'new' in name: 193 | return get_c4_new(nsamples, seed, seqlen, tokenizer) 194 | return get_c4(nsamples, seed, seqlen, tokenizer) 195 | 196 | 197 | 198 | def get_test_data(name, tokenizer, seq_len=2048, batch_size = 4): 199 | class IndexDataset(Dataset): 200 | def __init__(self, tensors): 201 | self.tensors = tensors 202 | 203 | def __getitem__(self, index): 204 | return self.tensors[index] 205 | 206 | def __len__(self): 207 | return len(self.tensors) 208 | #### 209 | def process_data(samples, tokenizer, seq_len, field_name): 210 | test_ids = tokenizer("\n\n".join(samples[field_name]), return_tensors='pt').input_ids[0] 211 | test_ids_batch = [] 212 | nsamples = test_ids.numel() // seq_len 213 | 214 | for i in range(nsamples): 215 | batch = test_ids[(i * seq_len):((i + 1) * seq_len)] 216 | test_ids_batch.append(batch) 217 | test_ids_batch = torch.stack(test_ids_batch) 218 | return IndexDataset(tensors=test_ids_batch) 219 | #### 220 | if 'wikitext2' in name: 221 | test_data = load_dataset('wikitext', 'wikitext-2-raw-v1', split='test') 222 | test_dataset = process_data(test_data, tokenizer, seq_len, 'text') 223 | if 'ptb' in name: 224 | test_data = load_dataset('ptb_text_only', 'penn_treebank', split='test') 225 | test_dataset = process_data(test_data, tokenizer, seq_len, 'sentence') 226 | elif 'c4' in name: 227 | test_data = load_dataset("json", data_files="utils/c4-validation.json")['train'] 228 | test_dataset = process_data(test_data[0:2000], tokenizer, seq_len, 'text') 229 | test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=False) 230 | return test_loader -------------------------------------------------------------------------------- /utils/model_utils.py: -------------------------------------------------------------------------------- 1 | #coding:utf8 2 | import os 3 | import sys 4 | import torch 5 | import torch.nn as nn 6 | 7 | current_path = os.path.dirname(os.path.abspath(__file__)) 8 | parent_path = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) 9 | sys.path.append(current_path) 10 | 11 | # bandaid fix 12 | dev = torch.device("cuda") 13 | 14 | def get_model_from_huggingface(model_id): 15 | from transformers import AutoModelForCausalLM, LlamaTokenizer, AutoTokenizer, LlamaForCausalLM 16 | if "opt" in model_id or "mistral" in model_id: 17 | tokenizer = AutoTokenizer.from_pretrained(model_id, device_map="cpu", trust_remote_code=True) 18 | else: 19 | tokenizer = LlamaTokenizer.from_pretrained(model_id, device_map="cpu", trust_remote_code=True) 20 | model = AutoModelForCausalLM.from_pretrained(model_id, device_map="cpu", torch_dtype=torch.float16, trust_remote_code=True, cache_dir=None) 21 | model.seqlen = 2048 22 | return model, tokenizer 23 | 24 | def get_model_from_local(model_id): 25 | pruned_dict = torch.load(model_id, weights_only=False, map_location='cpu') 26 | tokenizer, model = pruned_dict['tokenizer'], pruned_dict['model'] 27 | return 
model, tokenizer 28 | 29 | def find_layers(module, layers=[nn.Conv2d, nn.Linear], name=''): 30 | if type(module) in layers: 31 | return {name: module} 32 | res = {} 33 | for name1, child in module.named_children(): 34 | res.update(find_layers( 35 | child, layers=layers, name=name + '.' + name1 if name != '' else name1 36 | )) 37 | return res 38 | -------------------------------------------------------------------------------- /utils/peft/__init__.py: -------------------------------------------------------------------------------- 1 | # flake8: noqa 2 | # There's no way to ignore "F401 '...' imported but unused" warnings in this 3 | # module, but to preserve other warnings. So, don't check this module at all. 4 | 5 | # coding=utf-8 6 | # Copyright 2023-present the HuggingFace Inc. team. 7 | # 8 | # Licensed under the Apache License, Version 2.0 (the "License"); 9 | # you may not use this file except in compliance with the License. 10 | # You may obtain a copy of the License at 11 | # 12 | # http://www.apache.org/licenses/LICENSE-2.0 13 | # 14 | # Unless required by applicable law or agreed to in writing, software 15 | # distributed under the License is distributed on an "AS IS" BASIS, 16 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 17 | # See the License for the specific language governing permissions and 18 | # limitations under the License. 19 | 20 | __version__ = "0.3.0.dev0" 21 | 22 | from .mapping import MODEL_TYPE_TO_PEFT_MODEL_MAPPING, PEFT_TYPE_TO_CONFIG_MAPPING, get_peft_config, get_peft_model 23 | from .peft_model import ( 24 | PeftModel, 25 | PeftModelForCausalLM, 26 | PeftModelForSeq2SeqLM, 27 | PeftModelForSequenceClassification, 28 | PeftModelForTokenClassification, 29 | ) 30 | from .tuners import ( 31 | LoraConfig, 32 | LoraModel, 33 | AdaLoraConfig, 34 | AdaLoraModel, 35 | PrefixEncoder, 36 | PrefixTuningConfig, 37 | PromptEmbedding, 38 | PromptEncoder, 39 | PromptEncoderConfig, 40 | PromptEncoderReparameterizationType, 41 | PromptTuningConfig, 42 | PromptTuningInit, 43 | ) 44 | from .utils import ( 45 | TRANSFORMERS_MODELS_TO_PREFIX_TUNING_POSTPROCESS_MAPPING, 46 | PeftConfig, 47 | PeftType, 48 | PromptLearningConfig, 49 | TaskType, 50 | bloom_model_postprocess_past_key_value, 51 | get_peft_model_state_dict, 52 | prepare_model_for_int8_training, 53 | set_peft_model_state_dict, 54 | shift_tokens_right, 55 | ) 56 | -------------------------------------------------------------------------------- /utils/peft/import_utils.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2023-present the HuggingFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
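# (Editorial annotation, not part of the original utils/peft/import_utils.py.)
# is_bnb_available() below is a small capability probe: it reports whether the
# optional bitsandbytes package is installed without importing it up front.
# A minimal usage sketch, mirroring how tuners/lora.py guards its 8-bit path
# later in this dump:
#
#     from .import_utils import is_bnb_available
#
#     if is_bnb_available():
#         import bitsandbytes as bnb  # only pulled in when 8-bit support exists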
15 | import importlib 16 | 17 | 18 | def is_bnb_available(): 19 | return importlib.util.find_spec("bitsandbytes") is not None 20 | -------------------------------------------------------------------------------- /utils/peft/mapping.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2023-present the HuggingFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | from .peft_model import ( 17 | PeftModel, 18 | PeftModelForCausalLM, 19 | PeftModelForSeq2SeqLM, 20 | PeftModelForSequenceClassification, 21 | PeftModelForTokenClassification, 22 | ) 23 | from .tuners import AdaLoraConfig, LoraConfig, PrefixTuningConfig, PromptEncoderConfig, PromptTuningConfig 24 | from .utils import PromptLearningConfig 25 | 26 | 27 | MODEL_TYPE_TO_PEFT_MODEL_MAPPING = { 28 | "SEQ_CLS": PeftModelForSequenceClassification, 29 | "SEQ_2_SEQ_LM": PeftModelForSeq2SeqLM, 30 | "CAUSAL_LM": PeftModelForCausalLM, 31 | "TOKEN_CLS": PeftModelForTokenClassification, 32 | } 33 | 34 | PEFT_TYPE_TO_CONFIG_MAPPING = { 35 | "PROMPT_TUNING": PromptTuningConfig, 36 | "PREFIX_TUNING": PrefixTuningConfig, 37 | "P_TUNING": PromptEncoderConfig, 38 | "LORA": LoraConfig, 39 | "ADALORA": AdaLoraConfig, 40 | } 41 | 42 | 43 | def get_peft_config(config_dict): 44 | """ 45 | Returns a Peft config object from a dictionary. 46 | 47 | Args: 48 | config_dict (`Dict[str, Any]`): Dictionary containing the configuration parameters. 
49 | """ 50 | 51 | return PEFT_TYPE_TO_CONFIG_MAPPING[config_dict["peft_type"]](**config_dict) 52 | 53 | 54 | def _prepare_prompt_learning_config(peft_config, model_config): 55 | if peft_config.num_layers is None: 56 | if "num_hidden_layers" in model_config: 57 | num_layers = model_config["num_hidden_layers"] 58 | elif "num_layers" in model_config: 59 | num_layers = model_config["num_layers"] 60 | elif "n_layer" in model_config: 61 | num_layers = model_config["n_layer"] 62 | else: 63 | raise ValueError("Please specify `num_layers` in `peft_config`") 64 | peft_config.num_layers = num_layers 65 | 66 | if peft_config.token_dim is None: 67 | if "hidden_size" in model_config: 68 | token_dim = model_config["hidden_size"] 69 | elif "n_embd" in model_config: 70 | token_dim = model_config["n_embd"] 71 | elif "d_model" in model_config: 72 | token_dim = model_config["d_model"] 73 | else: 74 | raise ValueError("Please specify `token_dim` in `peft_config`") 75 | peft_config.token_dim = token_dim 76 | 77 | if peft_config.num_attention_heads is None: 78 | if "num_attention_heads" in model_config: 79 | num_attention_heads = model_config["num_attention_heads"] 80 | elif "n_head" in model_config: 81 | num_attention_heads = model_config["n_head"] 82 | elif "num_heads" in model_config: 83 | num_attention_heads = model_config["num_heads"] 84 | elif "encoder_attention_heads" in model_config: 85 | num_attention_heads = model_config["encoder_attention_heads"] 86 | else: 87 | raise ValueError("Please specify `num_attention_heads` in `peft_config`") 88 | peft_config.num_attention_heads = num_attention_heads 89 | 90 | if getattr(peft_config, "encoder_hidden_size", None) is None: 91 | setattr(peft_config, "encoder_hidden_size", token_dim) 92 | 93 | return peft_config 94 | 95 | 96 | def get_peft_model(model, peft_config): 97 | """ 98 | Returns a Peft model object from a model and a config. 99 | 100 | Args: 101 | model ([`transformers.PreTrainedModel`]): Model to be wrapped. 102 | peft_config ([`PeftConfig`]): Configuration object containing the parameters of the Peft model. 103 | """ 104 | model_config = model.config.to_dict() if hasattr(model.config, "to_dict") else model.config 105 | peft_config.base_model_name_or_path = model.__dict__.get("name_or_path", None) 106 | if peft_config.task_type not in MODEL_TYPE_TO_PEFT_MODEL_MAPPING.keys() and not isinstance( 107 | peft_config, PromptLearningConfig 108 | ): 109 | return PeftModel(model, peft_config) 110 | if isinstance(peft_config, PromptLearningConfig): 111 | peft_config = _prepare_prompt_learning_config(peft_config, model_config) 112 | return MODEL_TYPE_TO_PEFT_MODEL_MAPPING[peft_config.task_type](model, peft_config) 113 | -------------------------------------------------------------------------------- /utils/peft/tuners/__init__.py: -------------------------------------------------------------------------------- 1 | # flake8: noqa 2 | # There's no way to ignore "F401 '...' imported but unused" warnings in this 3 | # module, but to preserve other warnings. So, don't check this module at all 4 | 5 | # coding=utf-8 6 | # Copyright 2023-present the HuggingFace Inc. team. 7 | # 8 | # Licensed under the Apache License, Version 2.0 (the "License"); 9 | # you may not use this file except in compliance with the License. 
10 | # You may obtain a copy of the License at 11 | # 12 | # http://www.apache.org/licenses/LICENSE-2.0 13 | # 14 | # Unless required by applicable law or agreed to in writing, software 15 | # distributed under the License is distributed on an "AS IS" BASIS, 16 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 17 | # See the License for the specific language governing permissions and 18 | # limitations under the License. 19 | 20 | from .lora import LoraConfig, LoraModel 21 | from .adalora import AdaLoraConfig, AdaLoraModel 22 | from .p_tuning import PromptEncoder, PromptEncoderConfig, PromptEncoderReparameterizationType 23 | from .prefix_tuning import PrefixEncoder, PrefixTuningConfig 24 | from .prompt_tuning import PromptEmbedding, PromptTuningConfig, PromptTuningInit 25 | -------------------------------------------------------------------------------- /utils/peft/tuners/lora.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2023-present the HuggingFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | import math 16 | import re 17 | import warnings 18 | from dataclasses import asdict, dataclass, field 19 | from enum import Enum 20 | from typing import List, Optional, Union 21 | 22 | import torch 23 | import torch.nn as nn 24 | import torch.nn.functional as F 25 | from transformers.pytorch_utils import Conv1D 26 | 27 | from ..import_utils import is_bnb_available 28 | from ..utils import ( 29 | TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING, 30 | ModulesToSaveWrapper, 31 | PeftConfig, 32 | PeftType, 33 | _freeze_adapter, 34 | _get_submodules, 35 | transpose, 36 | ) 37 | 38 | 39 | if is_bnb_available(): 40 | import bitsandbytes as bnb 41 | 42 | 43 | @dataclass 44 | class LoraConfig(PeftConfig): 45 | """ 46 | This is the configuration class to store the configuration of a [`LoraModel`]. 47 | 48 | Args: 49 | r (`int`): Lora attention dimension. 50 | target_modules (`Union[List[str],str]`): The names of the modules to apply Lora to. 51 | lora_alpha (`float`): The alpha parameter for Lora scaling. 52 | lora_dropout (`float`): The dropout probability for Lora layers. 53 | fan_in_fan_out (`bool`): Set this to True if the layer to replace stores weight like (fan_in, fan_out). 54 | For example, gpt-2 uses `Conv1D` which stores weights like (fan_in, fan_out) and hence this should be set to `True`.: 55 | bias (`str`): Bias type for Lora. Can be 'none', 'all' or 'lora_only' 56 | modules_to_save (`List[str]`):List of modules apart from LoRA layers to be set as trainable 57 | and saved in the final checkpoint. 58 | """ 59 | 60 | r: int = field(default=8, metadata={"help": "Lora attention dimension"}) 61 | target_modules: Optional[Union[List[str], str]] = field( 62 | default=None, 63 | metadata={ 64 | "help": "List of module names or regex expression of the module names to replace with Lora." 
65 | "For example, ['q', 'v'] or '.*decoder.*(SelfAttention|EncDecAttention).*(q|v)$' " 66 | }, 67 | ) 68 | lora_alpha: int = field(default=None, metadata={"help": "Lora alpha"}) 69 | lora_dropout: float = field(default=None, metadata={"help": "Lora dropout"}) 70 | fan_in_fan_out: bool = field( 71 | default=False, 72 | metadata={"help": "Set this to True if the layer to replace stores weight like (fan_in, fan_out)"}, 73 | ) 74 | bias: str = field(default="none", metadata={"help": "Bias type for Lora. Can be 'none', 'all' or 'lora_only'"}) 75 | modules_to_save: Optional[List[str]] = field( 76 | default=None, 77 | metadata={ 78 | "help": "List of modules apart from LoRA layers to be set as trainable and saved in the final checkpoint. " 79 | "For example, in Sequence Classification or Token Classification tasks, " 80 | "the final layer `classifier/score` are randomly initialized and as such need to be trainable and saved." 81 | }, 82 | ) 83 | init_lora_weights: bool = field( 84 | default=True, 85 | metadata={"help": "Whether to initialize the weights of the Lora layers."}, 86 | ) 87 | 88 | def __post_init__(self): 89 | self.peft_type = PeftType.LORA 90 | 91 | 92 | class LoraModel(torch.nn.Module): 93 | """ 94 | Creates Low Rank Adapter (Lora) model from a pretrained transformers model. 95 | 96 | Args: 97 | model ([`~transformers.PreTrainedModel`]): The model to be adapted. 98 | config ([`LoraConfig`]): The configuration of the Lora model. 99 | 100 | Returns: 101 | `torch.nn.Module`: The Lora model. 102 | 103 | Example: 104 | 105 | ```py 106 | >>> from transformers import AutoModelForSeq2SeqLM, LoraConfig 107 | >>> from peft import LoraModel, LoraConfig 108 | 109 | >>> config = LoraConfig( 110 | ... peft_type="LORA", 111 | ... task_type="SEQ_2_SEQ_LM", 112 | ... r=8, 113 | ... lora_alpha=32, 114 | ... target_modules=["q", "v"], 115 | ... lora_dropout=0.01, 116 | ... ) 117 | 118 | >>> model = AutoModelForSeq2SeqLM.from_pretrained("t5-base") 119 | >>> lora_model = LoraModel(config, model) 120 | ``` 121 | 122 | **Attributes**: 123 | - **model** ([`~transformers.PreTrainedModel`]) -- The model to be adapted. 124 | - **peft_config** ([`LoraConfig`]): The configuration of the Lora model. 125 | """ 126 | 127 | def __init__(self, model, config, adapter_name): 128 | super().__init__() 129 | self.model = model 130 | self.forward = self.model.forward 131 | self.peft_config = config 132 | self.add_adapter(adapter_name, self.peft_config[adapter_name]) 133 | 134 | def add_adapter(self, adapter_name, config=None): 135 | if config is not None: 136 | model_config = self.model.config.to_dict() if hasattr(self.model.config, "to_dict") else self.model.config 137 | config = self._prepare_lora_config(config, model_config) 138 | self.peft_config[adapter_name] = config 139 | self._find_and_replace(adapter_name) 140 | if len(self.peft_config) > 1 and self.peft_config[adapter_name].bias != "none": 141 | raise ValueError( 142 | "LoraModel supports only 1 adapter with bias. When using multiple adapters, set bias to 'none' for all adapters." 
143 | ) 144 | mark_only_lora_as_trainable(self.model, self.peft_config[adapter_name].bias) 145 | if self.peft_config[adapter_name].inference_mode: 146 | _freeze_adapter(self.model, adapter_name) 147 | 148 | def _find_and_replace(self, adapter_name): 149 | lora_config = self.peft_config[adapter_name] 150 | loaded_in_8bit = getattr(self.model, "is_loaded_in_8bit", False) 151 | if loaded_in_8bit and not is_bnb_available(): 152 | raise ImportError( 153 | "To use Lora with 8-bit quantization, please install the `bitsandbytes` package. " 154 | "You can install it with `pip install bitsandbytes`." 155 | ) 156 | is_target_modules_in_base_model = False 157 | kwargs = { 158 | "r": lora_config.r, 159 | "lora_alpha": lora_config.lora_alpha, 160 | "lora_dropout": lora_config.lora_dropout, 161 | "fan_in_fan_out": lora_config.fan_in_fan_out, 162 | "init_lora_weights": lora_config.init_lora_weights, 163 | } 164 | key_list = [key for key, _ in self.model.named_modules()] 165 | for key in key_list: 166 | if isinstance(lora_config.target_modules, str): 167 | target_module_found = re.fullmatch(lora_config.target_modules, key) 168 | else: 169 | target_module_found = any(key.endswith(target_key) for target_key in lora_config.target_modules) 170 | if target_module_found: 171 | if not is_target_modules_in_base_model: 172 | is_target_modules_in_base_model = True 173 | parent, target, target_name = _get_submodules(self.model, key) 174 | bias = target.bias is not None 175 | if isinstance(target, LoraLayer): 176 | target.update_layer( 177 | adapter_name, 178 | lora_config.r, 179 | lora_config.lora_alpha, 180 | lora_config.lora_dropout, 181 | lora_config.init_lora_weights, 182 | ) 183 | else: 184 | if loaded_in_8bit and isinstance(target, bnb.nn.Linear8bitLt): 185 | eightbit_kwargs = kwargs.copy() 186 | eightbit_kwargs.update( 187 | { 188 | "has_fp16_weights": target.state.has_fp16_weights, 189 | "memory_efficient_backward": target.state.memory_efficient_backward, 190 | "threshold": target.state.threshold, 191 | "index": target.index, 192 | } 193 | ) 194 | new_module = Linear8bitLt( 195 | adapter_name, target.in_features, target.out_features, bias=bias, **eightbit_kwargs 196 | ) 197 | else: 198 | if isinstance(target, torch.nn.Linear): 199 | in_features, out_features = target.in_features, target.out_features 200 | if kwargs["fan_in_fan_out"]: 201 | warnings.warn( 202 | "fan_in_fan_out is set to True but the target module is `torch.nn.Linear`. " 203 | "Setting fan_in_fan_out to False." 204 | ) 205 | kwargs["fan_in_fan_out"] = lora_config.fan_in_fan_out = False 206 | elif isinstance(target, Conv1D): 207 | in_features, out_features = ( 208 | target.weight.ds_shape if hasattr(target.weight, "ds_shape") else target.weight.shape 209 | ) 210 | if not kwargs["fan_in_fan_out"]: 211 | warnings.warn( 212 | "fan_in_fan_out is set to False but the target module is `Conv1D`. " 213 | "Setting fan_in_fan_out to True." 214 | ) 215 | kwargs["fan_in_fan_out"] = lora_config.fan_in_fan_out = True 216 | else: 217 | raise ValueError( 218 | f"Target module {target} is not supported. " 219 | f"Currently, only `torch.nn.Linear` and `Conv1D` are supported." 220 | ) 221 | new_module = Linear(adapter_name, in_features, out_features, bias=bias, **kwargs) 222 | 223 | self._replace_module(parent, target_name, new_module, target) 224 | if not is_target_modules_in_base_model: 225 | raise ValueError( 226 | f"Target modules {lora_config.target_modules} not found in the base model. " 227 | f"Please check the target modules and try again." 
228 | ) 229 | 230 | def _replace_module(self, parent_module, child_name, new_module, old_module): 231 | setattr(parent_module, child_name, new_module) 232 | new_module.weight = old_module.weight 233 | if old_module.bias is not None: 234 | new_module.bias = old_module.bias 235 | if getattr(old_module, "state", None) is not None: 236 | new_module.state = old_module.state 237 | new_module.to(old_module.weight.device) 238 | 239 | # dispatch to correct device 240 | for name, module in new_module.named_modules(): 241 | if "lora_" in name: 242 | module.to(old_module.weight.device) 243 | 244 | def __getattr__(self, name: str): 245 | """Forward missing attributes to the wrapped module.""" 246 | try: 247 | return super().__getattr__(name) # defer to nn.Module's logic 248 | except AttributeError: 249 | return getattr(self.model, name) 250 | 251 | def get_peft_config_as_dict(self, inference: bool = False): 252 | config_dict = {} 253 | for key, value in self.peft_config.items(): 254 | config = {k: v.value if isinstance(v, Enum) else v for k, v in asdict(value).items()} 255 | if inference: 256 | config["inference_mode"] = True 257 | config_dict[key] = config 258 | return config 259 | 260 | def _set_adapter_layers(self, enabled=True): 261 | for module in self.model.modules(): 262 | if isinstance(module, LoraLayer): 263 | module.disable_adapters = False if enabled else True 264 | 265 | def enable_adapter_layers(self): 266 | self._set_adapter_layers(enabled=True) 267 | 268 | def disable_adapter_layers(self): 269 | self._set_adapter_layers(enabled=False) 270 | 271 | def set_adapter(self, adapter_name): 272 | for module in self.model.modules(): 273 | if isinstance(module, LoraLayer): 274 | if module.merged: 275 | warnings.warn("Adapter cannot be set when the model is merged. Unmerging the model first.") 276 | module.unmerge() 277 | module.active_adapter = adapter_name 278 | 279 | def merge_adapter(self): 280 | for module in self.model.modules(): 281 | if isinstance(module, LoraLayer): 282 | module.merge() 283 | 284 | def unmerge_adapter(self): 285 | for module in self.model.modules(): 286 | if isinstance(module, LoraLayer): 287 | module.unmerge() 288 | 289 | @staticmethod 290 | def _prepare_lora_config(peft_config, model_config): 291 | if peft_config.target_modules is None: 292 | if model_config["model_type"] not in TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING: 293 | raise ValueError("Please specify `target_modules` in `peft_config`") 294 | peft_config.target_modules = TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING[model_config["model_type"]] 295 | if peft_config.inference_mode: 296 | peft_config.merge_weights = True 297 | return peft_config 298 | 299 | def merge_and_unload(self): 300 | r""" 301 | This method merges the LoRa layers into the base model. This is needed if someone wants to use the base model 302 | as a standalone model. 
303 | """ 304 | if getattr(self.config, "model_type", None) == "gpt2": 305 | raise ValueError("GPT2 models are not supported for merging LORA layers") 306 | 307 | if getattr(self.model, "is_loaded_in_8bit", False): 308 | raise ValueError("Cannot merge LORA layers when the model is loaded in 8-bit mode") 309 | 310 | key_list = [key for key, _ in self.model.named_modules() if "lora" not in key] 311 | for key in key_list: 312 | try: 313 | parent, target, target_name = _get_submodules(self.model, key) 314 | except AttributeError: 315 | continue 316 | if isinstance(target, LoraLayer): 317 | bias = target.bias is not None 318 | new_module = torch.nn.Linear(target.in_features, target.out_features, bias=bias) 319 | target.merge() 320 | self._replace_module(parent, target_name, new_module, target) 321 | 322 | # save any additional trainable modules part of `modules_to_save` 323 | if isinstance(target, ModulesToSaveWrapper): 324 | setattr(parent, target_name, target.modules_to_save[target.active_adapter]) 325 | 326 | return self.model 327 | 328 | def add_weighted_adapter(self, adapters, weights, adapter_name): 329 | if len({self.peft_config[adapter].r for adapter in adapters}) != 1: 330 | raise ValueError("All adapters must have the same r value") 331 | self.peft_config[adapter_name] = self.peft_config[adapters[0]] 332 | self.peft_config[adapter_name].lora_alpha = self.peft_config[adapters[0]].r 333 | self._find_and_replace(adapter_name) 334 | mark_only_lora_as_trainable(self.model, self.peft_config[adapter_name].bias) 335 | _freeze_adapter(self.model, adapter_name) 336 | key_list = [key for key, _ in self.model.named_modules() if "lora" not in key] 337 | for key in key_list: 338 | _, target, _ = _get_submodules(self.model, key) 339 | if isinstance(target, LoraLayer): 340 | target.lora_A[adapter_name].weight.data = target.lora_A[adapter_name].weight.data * 0.0 341 | target.lora_B[adapter_name].weight.data = target.lora_B[adapter_name].weight.data * 0.0 342 | for adapter, weight in zip(adapters, weights): 343 | if adapter not in target.lora_A: 344 | continue 345 | target.lora_A[adapter_name].weight.data += ( 346 | target.lora_A[adapter].weight.data * weight * target.scaling[adapter] 347 | ) 348 | target.lora_B[adapter_name].weight.data += target.lora_B[adapter].weight.data * weight 349 | 350 | 351 | # Below code is based on https://github.com/microsoft/LoRA/blob/main/loralib/layers.py 352 | # and modified to work with PyTorch FSDP 353 | 354 | 355 | # ------------------------------------------------------------------------------------------ 356 | # Copyright (c) Microsoft Corporation. All rights reserved. 357 | # Licensed under the MIT License (MIT). See LICENSE in the repo root for license information. 
358 | # ------------------------------------------------------------------------------------------ 359 | 360 | 361 | # had to adapt it for `lora_only` to work 362 | def mark_only_lora_as_trainable(model: nn.Module, bias: str = "none") -> None: 363 | for n, p in model.named_parameters(): 364 | if "lora_" not in n: 365 | p.requires_grad = False 366 | if bias == "none": 367 | return 368 | elif bias == "all": 369 | for n, p in model.named_parameters(): 370 | if "bias" in n: 371 | p.requires_grad = True 372 | elif bias == "lora_only": 373 | for m in model.modules(): 374 | if isinstance(m, LoraLayer) and hasattr(m, "bias") and m.bias is not None: 375 | m.bias.requires_grad = True 376 | else: 377 | raise NotImplementedError 378 | 379 | 380 | class LoraLayer: 381 | def __init__( 382 | self, 383 | in_features: int, 384 | out_features: int, 385 | ): 386 | self.r = {} 387 | self.lora_alpha = {} 388 | self.scaling = {} 389 | self.lora_dropout = nn.ModuleDict({}) 390 | self.lora_A = nn.ModuleDict({}) 391 | self.lora_B = nn.ModuleDict({}) 392 | # Mark the weight as unmerged 393 | self.merged = False 394 | self.disable_adapters = False 395 | self.in_features = in_features 396 | self.out_features = out_features 397 | 398 | def update_layer(self, adapter_name, r, lora_alpha, lora_dropout, init_lora_weights): 399 | self.r[adapter_name] = r 400 | self.lora_alpha[adapter_name] = lora_alpha 401 | if lora_dropout > 0.0: 402 | lora_dropout_layer = nn.Dropout(p=lora_dropout) 403 | else: 404 | lora_dropout_layer = nn.Identity() 405 | 406 | self.lora_dropout.update(nn.ModuleDict({adapter_name: lora_dropout_layer})) 407 | # Actual trainable parameters 408 | if r > 0: 409 | self.lora_A.update(nn.ModuleDict({adapter_name: nn.Linear(self.in_features, r, bias=False)})) 410 | self.lora_B.update(nn.ModuleDict({adapter_name: nn.Linear(r, self.out_features, bias=False)})) 411 | self.scaling[adapter_name] = lora_alpha / r 412 | if init_lora_weights: 413 | self.reset_lora_parameters(adapter_name) 414 | self.to(self.weight.device) 415 | 416 | def reset_lora_parameters(self, adapter_name): 417 | if adapter_name in self.lora_A.keys(): 418 | # initialize A the same way as the default for nn.Linear and B to zero 419 | nn.init.kaiming_uniform_(self.lora_A[adapter_name].weight, a=math.sqrt(5)) 420 | nn.init.zeros_(self.lora_B[adapter_name].weight) 421 | 422 | 423 | class Linear(nn.Linear, LoraLayer): 424 | # Lora implemented in a dense layer 425 | def __init__( 426 | self, 427 | adapter_name: str, 428 | in_features: int, 429 | out_features: int, 430 | r: int = 0, 431 | lora_alpha: int = 1, 432 | lora_dropout: float = 0.0, 433 | fan_in_fan_out: bool = False, # Set this to True if the layer to replace stores weight like (fan_in, fan_out) 434 | **kwargs, 435 | ): 436 | init_lora_weights = kwargs.pop("init_lora_weights", True) 437 | 438 | nn.Linear.__init__(self, in_features, out_features, **kwargs) 439 | LoraLayer.__init__(self, in_features=in_features, out_features=out_features) 440 | # Freezing the pre-trained weight matrix 441 | self.weight.requires_grad = False 442 | 443 | self.fan_in_fan_out = fan_in_fan_out 444 | if fan_in_fan_out: 445 | self.weight.data = self.weight.data.T 446 | 447 | nn.Linear.reset_parameters(self) 448 | self.update_layer(adapter_name, r, lora_alpha, lora_dropout, init_lora_weights) 449 | self.active_adapter = adapter_name 450 | 451 | def merge(self): 452 | if self.active_adapter not in self.lora_A.keys(): 453 | return 454 | if self.merged: 455 | warnings.warn("Already merged. 
Nothing to do.") 456 | return 457 | if self.r[self.active_adapter] > 0: 458 | self.weight.data += ( 459 | transpose( 460 | self.lora_B[self.active_adapter].weight @ self.lora_A[self.active_adapter].weight, 461 | self.fan_in_fan_out, 462 | ) 463 | * self.scaling[self.active_adapter] 464 | ) 465 | self.merged = True 466 | 467 | def unmerge(self): 468 | if self.active_adapter not in self.lora_A.keys(): 469 | return 470 | if not self.merged: 471 | warnings.warn("Already unmerged. Nothing to do.") 472 | return 473 | if self.r[self.active_adapter] > 0: 474 | self.weight.data -= ( 475 | transpose( 476 | self.lora_B[self.active_adapter].weight @ self.lora_A[self.active_adapter].weight, 477 | self.fan_in_fan_out, 478 | ) 479 | * self.scaling[self.active_adapter] 480 | ) 481 | self.merged = False 482 | 483 | def forward(self, x: torch.Tensor): 484 | previous_dtype = x.dtype 485 | 486 | if self.active_adapter not in self.lora_A.keys(): 487 | return F.linear(x, transpose(self.weight, self.fan_in_fan_out), bias=self.bias) 488 | if self.disable_adapters: 489 | if self.r[self.active_adapter] > 0 and self.merged: 490 | self.unmerge() 491 | result = F.linear(x, transpose(self.weight, self.fan_in_fan_out), bias=self.bias) 492 | elif self.r[self.active_adapter] > 0 and not self.merged: 493 | result = F.linear(x, transpose(self.weight, self.fan_in_fan_out), bias=self.bias) 494 | 495 | x = x.to(self.lora_A[self.active_adapter].weight.dtype) 496 | 497 | result += ( 498 | self.lora_B[self.active_adapter]( 499 | self.lora_A[self.active_adapter](self.lora_dropout[self.active_adapter](x)) 500 | ) 501 | * self.scaling[self.active_adapter] 502 | ) 503 | else: 504 | result = F.linear(x, transpose(self.weight, self.fan_in_fan_out), bias=self.bias) 505 | 506 | result = result.to(previous_dtype) 507 | 508 | return result 509 | 510 | 511 | if is_bnb_available(): 512 | 513 | class Linear8bitLt(bnb.nn.Linear8bitLt, LoraLayer): 514 | # Lora implemented in a dense layer 515 | def __init__( 516 | self, 517 | adapter_name, 518 | in_features, 519 | out_features, 520 | r: int = 0, 521 | lora_alpha: int = 1, 522 | lora_dropout: float = 0.0, 523 | **kwargs, 524 | ): 525 | bnb.nn.Linear8bitLt.__init__( 526 | self, 527 | in_features, 528 | out_features, 529 | bias=kwargs.get("bias", True), 530 | has_fp16_weights=kwargs.get("has_fp16_weights", True), 531 | memory_efficient_backward=kwargs.get("memory_efficient_backward", False), 532 | threshold=kwargs.get("threshold", 0.0), 533 | index=kwargs.get("index", None), 534 | ) 535 | LoraLayer.__init__(self, in_features=in_features, out_features=out_features) 536 | 537 | # Freezing the pre-trained weight matrix 538 | self.weight.requires_grad = False 539 | 540 | init_lora_weights = kwargs.pop("init_lora_weights", True) 541 | self.update_layer(adapter_name, r, lora_alpha, lora_dropout, init_lora_weights) 542 | self.active_adapter = adapter_name 543 | 544 | def forward(self, x: torch.Tensor): 545 | result = super().forward(x) 546 | 547 | if self.disable_adapters or self.active_adapter not in self.lora_A.keys(): 548 | return result 549 | elif self.r[self.active_adapter] > 0: 550 | if not torch.is_autocast_enabled(): 551 | expected_dtype = result.dtype 552 | 553 | if x.dtype != torch.float32: 554 | x = x.float() 555 | output = ( 556 | self.lora_B[self.active_adapter]( 557 | self.lora_A[self.active_adapter](self.lora_dropout[self.active_adapter](x)) 558 | ).to(expected_dtype) 559 | * self.scaling[self.active_adapter] 560 | ) 561 | else: 562 | output = ( 563 | self.lora_B[self.active_adapter]( 564 
| self.lora_A[self.active_adapter](self.lora_dropout[self.active_adapter](x)) 565 | ) 566 | * self.scaling[self.active_adapter] 567 | ) 568 | result += output 569 | return result 570 | -------------------------------------------------------------------------------- /utils/peft/tuners/p_tuning.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2023-present the HuggingFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | import enum 17 | import warnings 18 | from dataclasses import dataclass, field 19 | from typing import Union 20 | 21 | import torch 22 | 23 | from ..utils import PeftType, PromptLearningConfig 24 | 25 | 26 | class PromptEncoderReparameterizationType(str, enum.Enum): 27 | MLP = "MLP" 28 | LSTM = "LSTM" 29 | 30 | 31 | @dataclass 32 | class PromptEncoderConfig(PromptLearningConfig): 33 | """ 34 | This is the configuration class to store the configuration of a [`PromptEncoder`]. 35 | 36 | Args: 37 | encoder_reparameterization_type (Union[[`PromptEncoderReparameterizationType`], `str`]): 38 | The type of reparameterization to use. 39 | encoder_hidden_size (`int`): The hidden size of the prompt encoder. 40 | encoder_num_layers (`int`): The number of layers of the prompt encoder. 41 | encoder_dropout (`float`): The dropout probability of the prompt encoder. 42 | """ 43 | 44 | encoder_reparameterization_type: Union[str, PromptEncoderReparameterizationType] = field( 45 | default=PromptEncoderReparameterizationType.MLP, 46 | metadata={"help": "How to reparameterize the prompt encoder"}, 47 | ) 48 | encoder_hidden_size: int = field( 49 | default=None, 50 | metadata={"help": "The hidden size of the prompt encoder"}, 51 | ) 52 | encoder_num_layers: int = field( 53 | default=2, 54 | metadata={"help": "The number of layers of the prompt encoder"}, 55 | ) 56 | encoder_dropout: float = field( 57 | default=0.0, 58 | metadata={"help": "The dropout of the prompt encoder"}, 59 | ) 60 | 61 | def __post_init__(self): 62 | self.peft_type = PeftType.P_TUNING 63 | 64 | 65 | # Based on https://github.com/NVIDIA/NeMo/blob/main/nemo/collections/nlp/modules/common/prompt_encoder.py 66 | # with some refactor 67 | class PromptEncoder(torch.nn.Module): 68 | """ 69 | The prompt encoder network that is used to generate the virtual token embeddings for p-tuning. 70 | 71 | Args: 72 | config ([`PromptEncoderConfig`]): The configuration of the prompt encoder. 73 | 74 | Example: 75 | 76 | ```py 77 | >>> from peft import PromptEncoder, PromptEncoderConfig 78 | 79 | >>> config = PromptEncoderConfig( 80 | ... peft_type="P_TUNING", 81 | ... task_type="SEQ_2_SEQ_LM", 82 | ... num_virtual_tokens=20, 83 | ... token_dim=768, 84 | ... num_transformer_submodules=1, 85 | ... num_attention_heads=12, 86 | ... num_layers=12, 87 | ... encoder_reparameterization_type="MLP", 88 | ... encoder_hidden_size=768, 89 | ... 
) 90 | 91 | >>> prompt_encoder = PromptEncoder(config) 92 | ``` 93 | 94 | **Attributes**: 95 | - **embedding** (`torch.nn.Embedding`) -- The embedding layer of the prompt encoder. 96 | - **mlp_head** (`torch.nn.Sequential`) -- The MLP head of the prompt encoder if `inference_mode=False`. 97 | - **lstm_head** (`torch.nn.LSTM`) -- The LSTM head of the prompt encoder if `inference_mode=False` and 98 | `encoder_reparameterization_type="LSTM"`. 99 | - **token_dim** (`int`) -- The hidden embedding dimension of the base transformer model. 100 | - **input_size** (`int`) -- The input size of the prompt encoder. 101 | - **output_size** (`int`) -- The output size of the prompt encoder. 102 | - **hidden_size** (`int`) -- The hidden size of the prompt encoder. 103 | - **total_virtual_tokens** (`int`): The total number of virtual tokens of the 104 | prompt encoder. 105 | - **encoder_type** (Union[[`PromptEncoderReparameterizationType`], `str`]): The encoder type of the prompt 106 | encoder. 107 | 108 | 109 | Input shape: (`batch_size`, `total_virtual_tokens`) 110 | 111 | Output shape: (`batch_size`, `total_virtual_tokens`, `token_dim`) 112 | """ 113 | 114 | def __init__(self, config): 115 | super().__init__() 116 | self.token_dim = config.token_dim 117 | self.input_size = self.token_dim 118 | self.output_size = self.token_dim 119 | self.hidden_size = config.encoder_hidden_size 120 | self.total_virtual_tokens = config.num_virtual_tokens * config.num_transformer_submodules 121 | self.encoder_type = config.encoder_reparameterization_type 122 | 123 | # embedding 124 | self.embedding = torch.nn.Embedding(self.total_virtual_tokens, self.token_dim) 125 | if not config.inference_mode: 126 | if self.encoder_type == PromptEncoderReparameterizationType.LSTM: 127 | lstm_dropout = config.encoder_dropout 128 | num_layers = config.encoder_num_layers 129 | # LSTM 130 | self.lstm_head = torch.nn.LSTM( 131 | input_size=self.input_size, 132 | hidden_size=self.hidden_size, 133 | num_layers=num_layers, 134 | dropout=lstm_dropout, 135 | bidirectional=True, 136 | batch_first=True, 137 | ) 138 | 139 | self.mlp_head = torch.nn.Sequential( 140 | torch.nn.Linear(self.hidden_size * 2, self.hidden_size * 2), 141 | torch.nn.ReLU(), 142 | torch.nn.Linear(self.hidden_size * 2, self.output_size), 143 | ) 144 | 145 | elif self.encoder_type == PromptEncoderReparameterizationType.MLP: 146 | warnings.warn( 147 | f"for {self.encoder_type}, the `encoder_num_layers` is ignored. Exactly 2 MLP layers are used." 148 | ) 149 | layers = [ 150 | torch.nn.Linear(self.input_size, self.hidden_size), 151 | torch.nn.ReLU(), 152 | torch.nn.Linear(self.hidden_size, self.hidden_size), 153 | torch.nn.ReLU(), 154 | torch.nn.Linear(self.hidden_size, self.output_size), 155 | ] 156 | self.mlp_head = torch.nn.Sequential(*layers) 157 | 158 | else: 159 | raise ValueError("Prompt encoder type not recognized. Please use one of MLP (recommended) or LSTM.") 160 | 161 | def forward(self, indices): 162 | input_embeds = self.embedding(indices) 163 | if self.encoder_type == PromptEncoderReparameterizationType.LSTM: 164 | output_embeds = self.mlp_head(self.lstm_head(input_embeds)[0]) 165 | elif self.encoder_type == PromptEncoderReparameterizationType.MLP: 166 | output_embeds = self.mlp_head(input_embeds) 167 | else: 168 | raise ValueError("Prompt encoder type not recognized. 
Please use one of MLP (recommended) or LSTM.") 169 | 170 | return output_embeds 171 | -------------------------------------------------------------------------------- /utils/peft/tuners/prefix_tuning.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2023-present the HuggingFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | 17 | from dataclasses import dataclass, field 18 | 19 | import torch 20 | 21 | from ..utils import PeftType, PromptLearningConfig 22 | 23 | 24 | @dataclass 25 | class PrefixTuningConfig(PromptLearningConfig): 26 | """ 27 | This is the configuration class to store the configuration of a [`PrefixEncoder`]. 28 | 29 | Args: 30 | encoder_hidden_size (`int`): The hidden size of the prompt encoder. 31 | prefix_projection (`bool`): Whether to project the prefix embeddings. 32 | """ 33 | 34 | encoder_hidden_size: int = field( 35 | default=None, 36 | metadata={"help": "The hidden size of the encoder"}, 37 | ) 38 | prefix_projection: bool = field( 39 | default=False, 40 | metadata={"help": "Whether to project the prefix tokens"}, 41 | ) 42 | 43 | def __post_init__(self): 44 | self.peft_type = PeftType.PREFIX_TUNING 45 | 46 | 47 | # Based on https://github.com/THUDM/P-tuning-v2/blob/main/model/prefix_encoder.py 48 | # with some refactor 49 | class PrefixEncoder(torch.nn.Module): 50 | r""" 51 | The `torch.nn` model to encode the prefix. 52 | 53 | Args: 54 | config ([`PrefixTuningConfig`]): The configuration of the prefix encoder. 55 | 56 | Example: 57 | 58 | ```py 59 | >>> from peft import PrefixEncoder, PrefixTuningConfig 60 | 61 | >>> config = PrefixTuningConfig( 62 | ... peft_type="PREFIX_TUNING", 63 | ... task_type="SEQ_2_SEQ_LM", 64 | ... num_virtual_tokens=20, 65 | ... token_dim=768, 66 | ... num_transformer_submodules=1, 67 | ... num_attention_heads=12, 68 | ... num_layers=12, 69 | ... encoder_hidden_size=768, 70 | ... ) 71 | >>> prefix_encoder = PrefixEncoder(config) 72 | ``` 73 | 74 | **Attributes**: 75 | - **embedding** (`torch.nn.Embedding`) -- The embedding layer of the prefix encoder. 76 | - **transform** (`torch.nn.Sequential`) -- The two-layer MLP to transform the prefix embeddings if 77 | `prefix_projection` is `True`. 78 | - **prefix_projection** (`bool`) -- Whether to project the prefix embeddings. 
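    With the example configuration above (`num_layers=12`, `token_dim=768`, `num_virtual_tokens=20`;
    illustrative values only), the encoder maps prefix indices of shape `(batch_size, 20)` to past key
    values of shape `(batch_size, 20, 2 * 12 * 768) = (batch_size, 20, 18432)`.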
79 | 80 | Input shape: (`batch_size`, `num_virtual_tokens`) 81 | 82 | Output shape: (`batch_size`, `num_virtual_tokens`, `2*layers*hidden`) 83 | """ 84 | 85 | def __init__(self, config): 86 | super().__init__() 87 | self.prefix_projection = config.prefix_projection 88 | token_dim = config.token_dim 89 | num_layers = config.num_layers 90 | encoder_hidden_size = config.encoder_hidden_size 91 | num_virtual_tokens = config.num_virtual_tokens 92 | if self.prefix_projection and not config.inference_mode: 93 | # Use a two-layer MLP to encode the prefix 94 | self.embedding = torch.nn.Embedding(num_virtual_tokens, token_dim) 95 | self.transform = torch.nn.Sequential( 96 | torch.nn.Linear(token_dim, encoder_hidden_size), 97 | torch.nn.Tanh(), 98 | torch.nn.Linear(encoder_hidden_size, num_layers * 2 * token_dim), 99 | ) 100 | else: 101 | self.embedding = torch.nn.Embedding(num_virtual_tokens, num_layers * 2 * token_dim) 102 | 103 | def forward(self, prefix: torch.Tensor): 104 | if self.prefix_projection: 105 | prefix_tokens = self.embedding(prefix) 106 | past_key_values = self.transform(prefix_tokens) 107 | else: 108 | past_key_values = self.embedding(prefix) 109 | return past_key_values 110 | -------------------------------------------------------------------------------- /utils/peft/tuners/prompt_tuning.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2023-present the HuggingFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | import enum 17 | import math 18 | from dataclasses import dataclass, field 19 | from typing import Optional, Union 20 | 21 | import torch 22 | 23 | from ..utils import PeftType, PromptLearningConfig 24 | 25 | 26 | class PromptTuningInit(str, enum.Enum): 27 | TEXT = "TEXT" 28 | RANDOM = "RANDOM" 29 | 30 | 31 | @dataclass 32 | class PromptTuningConfig(PromptLearningConfig): 33 | """ 34 | This is the configuration class to store the configuration of a [`PromptEmbedding`]. 35 | 36 | Args: 37 | prompt_tuning_init (Union[[`PromptTuningInit`], `str`]): The initialization of the prompt embedding. 38 | prompt_tuning_init_text (`str`, *optional*): 39 | The text to initialize the prompt embedding. Only used if `prompt_tuning_init` is `TEXT`. 40 | tokenizer_name_or_path (`str`, *optional*): 41 | The name or path of the tokenizer. Only used if `prompt_tuning_init` is `TEXT`. 42 | """ 43 | 44 | prompt_tuning_init: Union[PromptTuningInit, str] = field( 45 | default=PromptTuningInit.RANDOM, 46 | metadata={"help": "How to initialize the prompt tuning parameters"}, 47 | ) 48 | prompt_tuning_init_text: Optional[str] = field( 49 | default=None, 50 | metadata={ 51 | "help": "The text to use for prompt tuning initialization. Only used if prompt_tuning_init is `TEXT`" 52 | }, 53 | ) 54 | tokenizer_name_or_path: Optional[str] = field( 55 | default=None, 56 | metadata={ 57 | "help": "The tokenizer to use for prompt tuning initialization. 
Only used if prompt_tuning_init is `TEXT`" 58 | }, 59 | ) 60 | 61 | def __post_init__(self): 62 | self.peft_type = PeftType.PROMPT_TUNING 63 | 64 | 65 | class PromptEmbedding(torch.nn.Module): 66 | """ 67 | The model to encode virtual tokens into prompt embeddings. 68 | 69 | Args: 70 | config ([`PromptTuningConfig`]): The configuration of the prompt embedding. 71 | word_embeddings (`torch.nn.Module`): The word embeddings of the base transformer model. 72 | 73 | **Attributes**: 74 | - **embedding** (`torch.nn.Embedding`) -- The embedding layer of the prompt embedding. 75 | 76 | Example: 77 | 78 | ```py 79 | >>> from peft import PromptEmbedding, PromptTuningConfig 80 | 81 | >>> config = PromptTuningConfig( 82 | ... peft_type="PROMPT_TUNING", 83 | ... task_type="SEQ_2_SEQ_LM", 84 | ... num_virtual_tokens=20, 85 | ... token_dim=768, 86 | ... num_transformer_submodules=1, 87 | ... num_attention_heads=12, 88 | ... num_layers=12, 89 | ... prompt_tuning_init="TEXT", 90 | ... prompt_tuning_init_text="Predict if sentiment of this review is positive, negative or neutral", 91 | ... tokenizer_name_or_path="t5-base", 92 | ... ) 93 | 94 | >>> # t5_model.shared is the word embeddings of the base model 95 | >>> prompt_embedding = PromptEmbedding(config, t5_model.shared) 96 | ``` 97 | 98 | Input Shape: (`batch_size`, `total_virtual_tokens`) 99 | 100 | Output Shape: (`batch_size`, `total_virtual_tokens`, `token_dim`) 101 | """ 102 | 103 | def __init__(self, config, word_embeddings): 104 | super().__init__() 105 | 106 | total_virtual_tokens = config.num_virtual_tokens * config.num_transformer_submodules 107 | self.embedding = torch.nn.Embedding(total_virtual_tokens, config.token_dim) 108 | if config.prompt_tuning_init == PromptTuningInit.TEXT: 109 | from transformers import AutoTokenizer 110 | 111 | tokenizer = AutoTokenizer.from_pretrained(config.tokenizer_name_or_path) 112 | init_text = config.prompt_tuning_init_text 113 | init_token_ids = tokenizer(init_text)["input_ids"] 114 | # Trim or iterate until num_text_tokens matches total_virtual_tokens 115 | num_text_tokens = len(init_token_ids) 116 | if num_text_tokens > total_virtual_tokens: 117 | init_token_ids = init_token_ids[:total_virtual_tokens] 118 | elif num_text_tokens < total_virtual_tokens: 119 | num_reps = math.ceil(total_virtual_tokens / num_text_tokens) 120 | init_token_ids = init_token_ids * num_reps 121 | init_token_ids = init_token_ids[:total_virtual_tokens] 122 | 123 | word_embedding_weights = word_embeddings(torch.LongTensor(init_token_ids)).detach().clone() 124 | word_embedding_weights = word_embedding_weights.to(torch.float32) 125 | self.embedding.weight = torch.nn.Parameter(word_embedding_weights) 126 | 127 | def forward(self, indices): 128 | # Just get embeddings 129 | prompt_embeddings = self.embedding(indices) 130 | return prompt_embeddings 131 | -------------------------------------------------------------------------------- /utils/peft/utils/__init__.py: -------------------------------------------------------------------------------- 1 | # flake8: noqa 2 | # There's no way to ignore "F401 '...' imported but unused" warnings in this 3 | # module, but to preserve other warnings. So, don't check this module at all 4 | 5 | # coding=utf-8 6 | # Copyright 2023-present the HuggingFace Inc. team. 7 | # 8 | # Licensed under the Apache License, Version 2.0 (the "License"); 9 | # you may not use this file except in compliance with the License. 
10 | # You may obtain a copy of the License at 11 | # 12 | # http://www.apache.org/licenses/LICENSE-2.0 13 | # 14 | # Unless required by applicable law or agreed to in writing, software 15 | # distributed under the License is distributed on an "AS IS" BASIS, 16 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 17 | # See the License for the specific language governing permissions and 18 | # limitations under the License. 19 | 20 | from .config import PeftConfig, PeftType, PromptLearningConfig, TaskType 21 | from .other import ( 22 | TRANSFORMERS_MODELS_TO_PREFIX_TUNING_POSTPROCESS_MAPPING, 23 | TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING, 24 | TRANSFORMERS_MODELS_TO_ADALORA_TARGET_MODULES_MAPPING, 25 | CONFIG_NAME, 26 | WEIGHTS_NAME, 27 | _set_trainable, 28 | bloom_model_postprocess_past_key_value, 29 | prepare_model_for_int8_training, 30 | shift_tokens_right, 31 | transpose, 32 | _get_submodules, 33 | _set_adapter, 34 | _freeze_adapter, 35 | ModulesToSaveWrapper, 36 | ) 37 | from .save_and_load import get_peft_model_state_dict, set_peft_model_state_dict 38 | -------------------------------------------------------------------------------- /utils/peft/utils/config.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2023-present the HuggingFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | import enum 16 | import json 17 | import os 18 | from dataclasses import asdict, dataclass, field 19 | from typing import Optional, Union 20 | 21 | from huggingface_hub import hf_hub_download 22 | from transformers.utils import PushToHubMixin 23 | 24 | from .other import CONFIG_NAME 25 | 26 | 27 | class PeftType(str, enum.Enum): 28 | PROMPT_TUNING = "PROMPT_TUNING" 29 | P_TUNING = "P_TUNING" 30 | PREFIX_TUNING = "PREFIX_TUNING" 31 | LORA = "LORA" 32 | ADALORA = "ADALORA" 33 | 34 | 35 | class TaskType(str, enum.Enum): 36 | SEQ_CLS = "SEQ_CLS" 37 | SEQ_2_SEQ_LM = "SEQ_2_SEQ_LM" 38 | CAUSAL_LM = "CAUSAL_LM" 39 | TOKEN_CLS = "TOKEN_CLS" 40 | 41 | 42 | @dataclass 43 | class PeftConfigMixin(PushToHubMixin): 44 | r""" 45 | This is the base configuration class for PEFT adapter models. It contains all the methods that are common to all 46 | PEFT adapter models. This class inherits from [`~transformers.utils.PushToHubMixin`] which contains the methods to 47 | push your model to the Hub. The method `save_pretrained` will save the configuration of your adapter model in a 48 | directory. The method `from_pretrained` will load the configuration of your adapter model from a directory. 49 | 50 | Args: 51 | peft_type (Union[[`~peft.utils.config.PeftType`], `str`]): The type of Peft method to use. 
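        Example (a minimal sketch; assumes a concrete subclass such as [`PromptTuningConfig`] is exported
        by the package, and uses an arbitrary target directory):

        ```py
        >>> from peft import PromptTuningConfig

        >>> config = PromptTuningConfig(task_type="CAUSAL_LM", num_virtual_tokens=20)
        >>> config.save_pretrained("./my_adapter")  # writes adapter_config.json into the directory
        >>> reloaded = PromptTuningConfig.from_pretrained("./my_adapter")
        ```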
52 | """ 53 | peft_type: Optional[PeftType] = field(default=None, metadata={"help": "The type of PEFT model."}) 54 | 55 | @property 56 | def __dict__(self): 57 | return asdict(self) 58 | 59 | def to_dict(self): 60 | return self.__dict__ 61 | 62 | def save_pretrained(self, save_directory, **kwargs): 63 | r""" 64 | This method saves the configuration of your adapter model in a directory. 65 | 66 | Args: 67 | save_directory (`str`): 68 | The directory where the configuration will be saved. 69 | kwargs (additional keyword arguments, *optional*): 70 | Additional keyword arguments passed along to the [`~transformers.utils.PushToHubMixin.push_to_hub`] 71 | method. 72 | """ 73 | if os.path.isfile(save_directory): 74 | raise AssertionError(f"Provided path ({save_directory}) should be a directory, not a file") 75 | 76 | os.makedirs(save_directory, exist_ok=True) 77 | 78 | output_dict = self.__dict__ 79 | output_path = os.path.join(save_directory, CONFIG_NAME) 80 | 81 | # save it 82 | with open(output_path, "w") as writer: 83 | writer.write(json.dumps(output_dict, indent=2, sort_keys=True)) 84 | 85 | @classmethod 86 | def from_pretrained(cls, pretrained_model_name_or_path, subfolder=None, **kwargs): 87 | r""" 88 | This method loads the configuration of your adapter model from a directory. 89 | 90 | Args: 91 | pretrained_model_name_or_path (`str`): 92 | The directory or the Hub repository id where the configuration is saved. 93 | kwargs (additional keyword arguments, *optional*): 94 | Additional keyword arguments passed along to the child class initialization. 95 | """ 96 | path = ( 97 | os.path.join(pretrained_model_name_or_path, subfolder) 98 | if subfolder is not None 99 | else pretrained_model_name_or_path 100 | ) 101 | if os.path.isfile(os.path.join(path, CONFIG_NAME)): 102 | config_file = os.path.join(path, CONFIG_NAME) 103 | else: 104 | try: 105 | config_file = hf_hub_download(pretrained_model_name_or_path, CONFIG_NAME, subfolder=subfolder) 106 | except Exception: 107 | raise ValueError(f"Can't find '{CONFIG_NAME}' at '{pretrained_model_name_or_path}'") 108 | 109 | loaded_attributes = cls.from_json_file(config_file) 110 | 111 | config = cls(**kwargs) 112 | 113 | for key, value in loaded_attributes.items(): 114 | if hasattr(config, key): 115 | setattr(config, key, value) 116 | 117 | return config 118 | 119 | @classmethod 120 | def from_json_file(cls, path_json_file, **kwargs): 121 | r""" 122 | Loads a configuration file from a json file. 123 | 124 | Args: 125 | path_json_file (`str`): 126 | The path to the json file. 127 | """ 128 | with open(path_json_file, "r") as file: 129 | json_object = json.load(file) 130 | 131 | return json_object 132 | 133 | 134 | @dataclass 135 | class PeftConfig(PeftConfigMixin): 136 | """ 137 | This is the base configuration class to store the configuration of a [`PeftModel`]. 138 | 139 | Args: 140 | peft_type (Union[[`~peft.utils.config.PeftType`], `str`]): The type of Peft method to use. 141 | task_type (Union[[`~peft.utils.config.TaskType`], `str`]): The type of task to perform. 142 | inference_mode (`bool`, defaults to `False`): Whether to use the Peft model in inference mode. 
143 | """ 144 | 145 | base_model_name_or_path: str = field(default=None, metadata={"help": "The name of the base model to use."}) 146 | peft_type: Union[str, PeftType] = field(default=None, metadata={"help": "Peft type"}) 147 | task_type: Union[str, TaskType] = field(default=None, metadata={"help": "Task type"}) 148 | inference_mode: bool = field(default=False, metadata={"help": "Whether to use inference mode"}) 149 | 150 | 151 | @dataclass 152 | class PromptLearningConfig(PeftConfig): 153 | """ 154 | This is the base configuration class to store the configuration of [`PrefixTuning`], [`PromptEncoder`], or 155 | [`PromptTuning`]. 156 | 157 | Args: 158 | num_virtual_tokens (`int`): The number of virtual tokens to use. 159 | token_dim (`int`): The hidden embedding dimension of the base transformer model. 160 | num_transformer_submodules (`int`): The number of transformer submodules in the base transformer model. 161 | num_attention_heads (`int`): The number of attention heads in the base transformer model. 162 | num_layers (`int`): The number of layers in the base transformer model. 163 | """ 164 | 165 | num_virtual_tokens: int = field(default=None, metadata={"help": "Number of virtual tokens"}) 166 | token_dim: int = field( 167 | default=None, metadata={"help": "The hidden embedding dimension of the base transformer model"} 168 | ) 169 | num_transformer_submodules: Optional[int] = field( 170 | default=None, metadata={"help": "Number of transformer submodules"} 171 | ) 172 | num_attention_heads: Optional[int] = field(default=None, metadata={"help": "Number of attention heads"}) 173 | num_layers: Optional[int] = field(default=None, metadata={"help": "Number of transformer layers"}) 174 | -------------------------------------------------------------------------------- /utils/peft/utils/other.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2023-present the HuggingFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | 16 | import copy 17 | from functools import reduce 18 | 19 | import torch 20 | 21 | 22 | # needed for prefix-tuning of bloom model 23 | def bloom_model_postprocess_past_key_value(past_key_values): 24 | past_key_values = torch.cat(past_key_values) 25 | total_layers, batch_size, num_attention_heads, num_virtual_tokens, head_dim = past_key_values.shape 26 | keys = past_key_values[: total_layers // 2] 27 | keys = keys.transpose(2, 3).reshape( 28 | total_layers // 2, batch_size * num_attention_heads, head_dim, num_virtual_tokens 29 | ) 30 | values = past_key_values[total_layers // 2 :] 31 | values = values.reshape(total_layers // 2, batch_size * num_attention_heads, num_virtual_tokens, head_dim) 32 | 33 | return tuple(zip(keys, values)) 34 | 35 | 36 | def prepare_model_for_int8_training( 37 | model, output_embedding_layer_name="lm_head", use_gradient_checkpointing=True, layer_norm_names=["layer_norm"] 38 | ): 39 | r""" 40 | This method wraps the entire protocol for preparing a model before running a training. This includes: 41 | 1- Cast the layernorm in fp32 2- making output embedding layer require grads 3- Add the upcasting of the lm 42 | head to fp32 43 | 44 | Args: 45 | model, (`transformers.PreTrainedModel`): 46 | The loaded model from `transformers` 47 | """ 48 | loaded_in_8bit = getattr(model, "is_loaded_in_8bit", False) 49 | 50 | for name, param in model.named_parameters(): 51 | # freeze base model's layers 52 | param.requires_grad = False 53 | 54 | if loaded_in_8bit: 55 | # cast layer norm in fp32 for stability for 8bit models 56 | if param.ndim == 1 and any(layer_norm_name in name for layer_norm_name in layer_norm_names): 57 | param.data = param.data.to(torch.float32) 58 | 59 | if loaded_in_8bit and use_gradient_checkpointing: 60 | # For backward compatibility 61 | if hasattr(model, "enable_input_require_grads"): 62 | model.enable_input_require_grads() 63 | else: 64 | 65 | def make_inputs_require_grad(module, input, output): 66 | output.requires_grad_(True) 67 | 68 | model.get_input_embeddings().register_forward_hook(make_inputs_require_grad) 69 | 70 | # enable gradient checkpointing for memory efficiency 71 | model.gradient_checkpointing_enable() 72 | 73 | if hasattr(model, output_embedding_layer_name): 74 | output_embedding_layer = getattr(model, output_embedding_layer_name) 75 | input_dtype = output_embedding_layer.weight.dtype 76 | 77 | class CastOutputToFloat(torch.nn.Sequential): 78 | r""" 79 | Manually cast to the expected dtype of the lm_head as sometimes there is a final layer norm that is casted 80 | in fp32 81 | 82 | """ 83 | 84 | def forward(self, x): 85 | return super().forward(x.to(input_dtype)).to(torch.float32) 86 | 87 | setattr(model, output_embedding_layer_name, CastOutputToFloat(output_embedding_layer)) 88 | 89 | return model 90 | 91 | 92 | # copied from transformers.models.bart.modeling_bart 93 | def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int, decoder_start_token_id: int): 94 | """ 95 | Shift input ids one token to the right. 96 | 97 | Args: 98 | input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): input ids 99 | pad_token_id (`int`): The id of the `padding` token. 100 | decoder_start_token_id (`int`): The id of the `start` token. 
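    Example (illustrative values): with `decoder_start_token_id=0` and `pad_token_id=1`, the labels
    `[[5, 6, -100, -100]]` are shifted to `[[0, 5, 6, -100]]`, and the remaining `-100` is then
    replaced by the pad id, giving `[[0, 5, 6, 1]]`.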
101 | """ 102 | shifted_input_ids = input_ids.new_zeros(input_ids.shape) 103 | shifted_input_ids[:, 1:] = input_ids[:, :-1].clone() 104 | shifted_input_ids[:, 0] = decoder_start_token_id 105 | 106 | if pad_token_id is None: 107 | raise ValueError("self.model.config.pad_token_id has to be defined.") 108 | # replace possible -100 values in labels by `pad_token_id` 109 | shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id) 110 | 111 | return shifted_input_ids 112 | 113 | 114 | class ModulesToSaveWrapper(torch.nn.Module): 115 | def __init__(self, module_to_save, adapter_name): 116 | super().__init__() 117 | self.original_module = module_to_save 118 | self.modules_to_save = torch.nn.ModuleDict({}) 119 | self.update(adapter_name) 120 | self.active_adapter = adapter_name 121 | 122 | def update(self, adapter_name): 123 | self.modules_to_save.update(torch.nn.ModuleDict({adapter_name: copy.deepcopy(self.original_module)})) 124 | 125 | def forward(self, *args, **kwargs): 126 | if self.active_adapter not in self.modules_to_save: 127 | return self.original_module(*args, **kwargs) 128 | return self.modules_to_save[self.active_adapter](*args, **kwargs) 129 | 130 | def get_module_by_name(module, access_string): 131 | names = access_string.split(sep='.') 132 | return reduce(getattr, names, module) 133 | 134 | 135 | def _get_submodules(model, key): 136 | parent = get_module_by_name(model, ".".join(key.split(".")[:-1])) 137 | target_name = key.split(".")[-1] 138 | target = get_module_by_name(model, key) 139 | return parent, target, target_name 140 | 141 | 142 | def _freeze_adapter(model, adapter_name): 143 | for n, p in model.named_parameters(): 144 | if adapter_name in n: 145 | p.requires_grad = False 146 | 147 | 148 | def _set_trainable(model, adapter_name): 149 | key_list = [key for key, _ in model.named_modules()] 150 | for key in key_list: 151 | target_module_found = any(key.endswith(target_key) for target_key in model.modules_to_save) 152 | if target_module_found: 153 | parent, target, target_name = _get_submodules(model, key) 154 | if isinstance(target, ModulesToSaveWrapper): 155 | target.update(adapter_name) 156 | else: 157 | for param in target.parameters(): 158 | param.requires_grad = True 159 | setattr(parent, target_name, ModulesToSaveWrapper(target, adapter_name)) 160 | 161 | 162 | def _set_adapter(model, adapter_name): 163 | for module in model.modules(): 164 | if isinstance(module, ModulesToSaveWrapper): 165 | module.active_adapter = adapter_name 166 | 167 | 168 | def fsdp_auto_wrap_policy(model): 169 | import functools 170 | import os 171 | 172 | from accelerate import FullyShardedDataParallelPlugin 173 | from torch.distributed.fsdp.wrap import _or_policy, lambda_auto_wrap_policy, transformer_auto_wrap_policy 174 | 175 | from ..tuners import PrefixEncoder, PromptEmbedding, PromptEncoder 176 | 177 | def lambda_policy_fn(module): 178 | if ( 179 | len(list(module.named_children())) == 0 180 | and getattr(module, "weight", None) is not None 181 | and module.weight.requires_grad 182 | ): 183 | return True 184 | return False 185 | 186 | lambda_policy = functools.partial(lambda_auto_wrap_policy, lambda_fn=lambda_policy_fn) 187 | transformer_wrap_policy = functools.partial( 188 | transformer_auto_wrap_policy, 189 | transformer_layer_cls=( 190 | PrefixEncoder, 191 | PromptEncoder, 192 | PromptEmbedding, 193 | FullyShardedDataParallelPlugin.get_module_class_from_name( 194 | model, os.environ.get("FSDP_TRANSFORMER_CLS_TO_WRAP", "") 195 | ), 196 | ), 197 | ) 198 | 199 | 
auto_wrap_policy = functools.partial(_or_policy, policies=[lambda_policy, transformer_wrap_policy]) 200 | return auto_wrap_policy 201 | 202 | 203 | def transpose(weight, fan_in_fan_out): 204 | return weight.T if fan_in_fan_out else weight 205 | 206 | 207 | TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING = { 208 | "t5": ["q", "v"], 209 | "mt5": ["q", "v"], 210 | "bart": ["q_proj", "v_proj"], 211 | "gpt2": ["c_attn"], 212 | "bloom": ["query_key_value"], 213 | "blip-2": ["q", "v", "q_proj", "v_proj"], 214 | "opt": ["q_proj", "v_proj"], 215 | "gptj": ["q_proj", "v_proj"], 216 | "gpt_neox": ["query_key_value"], 217 | "gpt_neo": ["q_proj", "v_proj"], 218 | "bert": ["query", "value"], 219 | "roberta": ["query", "value"], 220 | "xlm-roberta": ["query", "value"], 221 | "electra": ["query", "value"], 222 | "deberta-v2": ["query_proj", "value_proj"], 223 | "deberta": ["in_proj"], 224 | "layoutlm": ["query", "value"], 225 | "llama": ["q_proj", "v_proj"], 226 | "chatglm": ["query_key_value"], 227 | } 228 | 229 | TRANSFORMERS_MODELS_TO_ADALORA_TARGET_MODULES_MAPPING = { 230 | "t5": ["q", "k", "v", "o", "wi", "wo"], 231 | "mt5": ["q", "k", "v", "o", "wi_0", "wi_1", "wo"], 232 | "bart": ["q_proj", "k_proj", "v_proj", "out_proj", "fc1", "fc2"], 233 | # "gpt2": ["c_attn"], 234 | # "bloom": ["query_key_value"], 235 | "opt": ["q_proj", "k_proj", "v_proj", "out_proj", "fc1", "fc2"], 236 | # "gptj": ["q_proj", "v_proj"], 237 | # "gpt_neox": ["query_key_value"], 238 | # "gpt_neo": ["q_proj", "v_proj"], 239 | # "bert": ["query", "value"], 240 | "roberta": ["query", "key", "value", "dense"], 241 | # "xlm-roberta": ["query", "value"], 242 | # "electra": ["query", "value"], 243 | "deberta-v2": ["query_proj", "key_proj", "value_proj", "dense"], 244 | # "deberta": ["in_proj"], 245 | # "layoutlm": ["query", "value"], 246 | } 247 | 248 | TRANSFORMERS_MODELS_TO_PREFIX_TUNING_POSTPROCESS_MAPPING = { 249 | "bloom": bloom_model_postprocess_past_key_value, 250 | } 251 | 252 | WEIGHTS_NAME = "adapter_model.bin" 253 | CONFIG_NAME = "adapter_config.json" 254 | -------------------------------------------------------------------------------- /utils/peft/utils/save_and_load.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2023-present the HuggingFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | from .config import PeftType, PromptLearningConfig 17 | 18 | 19 | def get_peft_model_state_dict(model, state_dict=None, adapter_name="default"): 20 | """ 21 | Get the state dict of the Peft model. 22 | 23 | Args: 24 | model ([`PeftModel`]): The Peft model. When using torch.nn.DistributedDataParallel, DeepSpeed or FSDP, 25 | the model should be the underlying model/unwrapped model (i.e. model.module). 26 | state_dict (`dict`, *optional*, defaults to `None`): 27 | The state dict of the model. If not provided, the state dict of the model 28 | will be used. 
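        adapter_name (`str`, *optional*, defaults to `"default"`):
            The name of the adapter whose weights should be returned.

    Example (a minimal sketch; assumes `peft_model` is a LoRA-wrapped [`PeftModel`] and pairs with
    [`set_peft_model_state_dict`] for loading):

    ```py
    >>> import torch

    >>> adapter_state_dict = get_peft_model_state_dict(peft_model)  # only the adapter tensors
    >>> torch.save(adapter_state_dict, "adapter_model.bin")
    >>> set_peft_model_state_dict(peft_model, torch.load("adapter_model.bin"))
    ```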
29 | """ 30 | config = model.peft_config[adapter_name] 31 | if state_dict is None: 32 | state_dict = model.state_dict() 33 | if config.peft_type in (PeftType.LORA, PeftType.ADALORA): 34 | # to_return = lora_state_dict(model, bias=model.peft_config.bias) 35 | # adapted from `https://github.com/microsoft/LoRA/blob/main/loralib/utils.py` 36 | # to be used directly with the state dict which is necessary when using DeepSpeed or FSDP 37 | bias = config.bias 38 | if bias == "none": 39 | to_return = {k: state_dict[k] for k in state_dict if "lora_" in k} 40 | elif bias == "all": 41 | to_return = {k: state_dict[k] for k in state_dict if "lora_" in k or "bias" in k} 42 | elif bias == "lora_only": 43 | to_return = {} 44 | for k in state_dict: 45 | if "lora_" in k: 46 | to_return[k] = state_dict[k] 47 | bias_name = k.split("lora_")[0] + "bias" 48 | if bias_name in state_dict: 49 | to_return[bias_name] = state_dict[bias_name] 50 | else: 51 | raise NotImplementedError 52 | to_return = {k: v for k, v in to_return.items() if (("lora_" in k and adapter_name in k) or ("bias" in k))} 53 | if config.peft_type == PeftType.ADALORA: 54 | rank_pattern = config.rank_pattern 55 | if rank_pattern is not None: 56 | rank_pattern = {k.replace(f".{adapter_name}", ""): v for k, v in rank_pattern.items()} 57 | config.rank_pattern = rank_pattern 58 | to_return = model.resize_state_dict_by_rank_pattern(rank_pattern, to_return, adapter_name) 59 | elif isinstance(config, PromptLearningConfig): 60 | to_return = {} 61 | if config.inference_mode: 62 | prompt_embeddings = model.prompt_encoder[adapter_name].embedding.weight 63 | else: 64 | prompt_embeddings = model.get_prompt_embedding_to_save(adapter_name) 65 | to_return["prompt_embeddings"] = prompt_embeddings 66 | else: 67 | raise NotImplementedError 68 | if model.modules_to_save is not None: 69 | for key, value in state_dict.items(): 70 | if any(f"{module_name}.modules_to_save.{adapter_name}" in key for module_name in model.modules_to_save): 71 | to_return[key.replace("modules_to_save.", "")] = value 72 | 73 | to_return = {k.replace(f".{adapter_name}", ""): v for k, v in to_return.items()} 74 | return to_return 75 | 76 | 77 | def set_peft_model_state_dict(model, peft_model_state_dict, adapter_name="default"): 78 | """ 79 | Set the state dict of the Peft model. 80 | 81 | Args: 82 | model ([`PeftModel`]): The Peft model. 83 | peft_model_state_dict (`dict`): The state dict of the Peft model. 84 | """ 85 | config = model.peft_config[adapter_name] 86 | state_dict = {} 87 | if model.modules_to_save is not None: 88 | for key, value in peft_model_state_dict.items(): 89 | if any(module_name in key for module_name in model.modules_to_save): 90 | for module_name in model.modules_to_save: 91 | if module_name in key: 92 | key = key.replace(module_name, f"{module_name}.modules_to_save.{adapter_name}") 93 | break 94 | state_dict[key] = value 95 | else: 96 | state_dict = peft_model_state_dict 97 | 98 | #print("config.peft_type: ".format(config.peft_type)) 99 | if config.peft_type in (PeftType.LORA, PeftType.ADALORA): 100 | peft_model_state_dict = {} 101 | for k, v in state_dict.items(): 102 | if "lora_" in k: 103 | suffix = k.split("lora_")[1] 104 | if "." 
in suffix: 105 | suffix_to_replace = ".".join(suffix.split(".")[1:]) 106 | k = k.replace(suffix_to_replace, f"{adapter_name}.{suffix_to_replace}") 107 | else: 108 | k = f"{k}.{adapter_name}" 109 | peft_model_state_dict[k] = v 110 | else: 111 | peft_model_state_dict[k] = v 112 | if config.peft_type == PeftType.ADALORA: 113 | rank_pattern = config.rank_pattern 114 | if rank_pattern is not None: 115 | model.resize_modules_by_rank_pattern(rank_pattern, adapter_name) 116 | elif isinstance(config, PromptLearningConfig): 117 | peft_model_state_dict = state_dict 118 | else: 119 | raise NotImplementedError 120 | 121 | model.load_state_dict(peft_model_state_dict, strict=False) 122 | #exit() 123 | if isinstance(config, PromptLearningConfig): 124 | model.prompt_encoder[adapter_name].embedding.load_state_dict( 125 | {"weight": peft_model_state_dict["prompt_embeddings"]}, strict=True 126 | ) 127 | --------------------------------------------------------------------------------