├── .gitignore ├── LICENSE ├── README.md ├── configs ├── llama_adamole_csr_test.config └── llama_adamole_csr_train.config ├── data.py ├── images └── adamole.png ├── poster.pdf ├── requirements.txt ├── src ├── __init__.py ├── adamole │ ├── __init__.py │ ├── config.py │ ├── layer.py │ └── model.py ├── config.py ├── lora │ ├── __init__.py │ ├── config.py │ ├── layer.py │ └── model.py ├── mapping.py ├── mole │ ├── __init__.py │ ├── config.py │ ├── layer.py │ └── model.py ├── peft_model.py ├── trainer.py └── utils │ ├── peft_types.py │ └── save_and_load.py ├── test.py └── train.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/#use-with-ide 110 | .pdm.toml 111 | 112 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 113 | __pypackages__/ 114 | 115 | # Celery stuff 116 | celerybeat-schedule 117 | celerybeat.pid 118 | 119 | # SageMath parsed files 120 | *.sage.py 121 | 122 | # Environments 123 | .env 124 | .venv 125 | env/ 126 | venv/ 127 | ENV/ 128 | env.bak/ 129 | venv.bak/ 130 | 131 | # Spyder project settings 132 | .spyderproject 133 | .spyproject 134 | 135 | # Rope project settings 136 | .ropeproject 137 | 138 | # mkdocs documentation 139 | /site 140 | 141 | # mypy 142 | .mypy_cache/ 143 | .dmypy.json 144 | dmypy.json 145 | 146 | # Pyre type checker 147 | .pyre/ 148 | 149 | # pytype static type analyzer 150 | .pytype/ 151 | 152 | # Cython debug symbols 153 | cython_debug/ 154 | 155 | # PyCharm 156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 158 | # and can be added to the global gitignore or merged into this file. For a more nuclear 159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 160 | .idea/ 161 | 162 | # Project files 163 | outputs/ 164 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 Zefang Liu 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # \[COLM 2024\] AdaMoLE: Adaptive Mixture of Low-Rank Adaptation Experts 2 | 3 | ## Introduction 4 | 5 | Explore AdaMoLE, a novel approach that integrates Low-Rank Adaptation (LoRA) with a dynamic Mixture of Experts (MoE) to enhance fine-tuning of Large Language Models (LLMs). AdaMoLE advances beyond static top-k expert activation by employing a dynamic thresholding mechanism, which adapts to the complexities of varied tasks to optimize model performance. This method efficiently selects and activates the most suitable experts based on input context, demonstrating superior performance in commonsense reasoning and natural language processing tasks. 6 | 7 | For more details regarding AdaMoLE, you are welcome to refer to our [paper](https://arxiv.org/abs/2405.00361) and [poster](poster.pdf). 8 | 9 |
10 | ![AdaMoLE Framework](images/adamole.png) 11 |
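
In code, AdaMoLE plugs into a causal language model much like a standard PEFT adapter. The snippet below is a minimal sketch rather than an excerpt from `train.py`: it assumes that `PeftModelForCausalLM`, exported from `src/__init__.py`, follows the familiar `(base_model, peft_config)` constructor of Hugging Face PEFT and that `TaskType.CAUSAL_LM` is defined in `src/utils/peft_types.py`; the hyperparameters mirror `configs/llama_adamole_csr_train.config`.

```python
# Minimal usage sketch; see train.py and the shipped config files for the exact pipeline.
from transformers import AutoModelForCausalLM, AutoTokenizer

from src import AdaMoleConfig, PeftModelForCausalLM, TaskType

model_path = "meta-llama/Llama-2-7b-hf"
tokenizer = AutoTokenizer.from_pretrained(model_path)
base_model = AutoModelForCausalLM.from_pretrained(model_path)

# Hyperparameters mirror configs/llama_adamole_csr_train.config:
# 8 LoRA experts of rank 32 on the attention projections, with the
# adaptive threshold capped at 0.125 (the --threshold value, i.e. 1 / num_experts).
peft_config = AdaMoleConfig(
    task_type=TaskType.CAUSAL_LM,  # assumed enum member, mirroring Hugging Face PEFT
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
    lora_rank=32,
    num_experts=8,
    max_threshold=0.125,
)
model = PeftModelForCausalLM(base_model, peft_config)  # assumed PEFT-style constructor
```

Under the hood, each adapted layer keeps a Switch-Transformer-style load-balancing loss, and `AdaMoleModel.get_aux_loss()` sums it over layers; the `--aux_loss_coeff` option in the training config suggests this term is added to the language-modeling loss during training.
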
12 | 13 | ## Features 14 | 15 | - **Dynamic Expert Activation:** Improves expert activation with a dynamic threshold network, adapting to task complexities for optimal expert engagement. 16 | - **Integration of LoRA and MoE:** Seamlessly combines Low-Rank Adaptation with the Mixture of Experts framework, enhancing the fine-tuning process for LLMs. 17 | - **Hugging Face Compatibility:** Designed to be compatible with the Hugging Face's [Transformers](https://github.com/huggingface/transformers) and [Parameter-Efficient Fine-Tuning (PEFT)](https://github.com/huggingface/peft) library, ensuring ease of use and integration into existing workflows. 18 | 19 | ## Installation 20 | 21 | ```bash 22 | # Navigate to the AdaMoLE directory 23 | cd AdaMoLE 24 | 25 | # Install required dependencies 26 | pip install -r requirements.txt 27 | ``` 28 | 29 | ## Usage 30 | 31 | ```bash 32 | # Train the model 33 | python train.py @configs/llama_adamole_csr_train.config 34 | 35 | # Test the model 36 | python test.py @configs/llama_adamole_csr_test.config 37 | ``` 38 | 39 | ## Citation 40 | 41 | If you find AdaMoLE useful in your projects, please consider citing our paper: 42 | 43 | Liu, Z., & Luo, J. (2024). AdaMoLE: Fine-Tuning Large Language Models with Adaptive Mixture of Low-Rank Adaptation Experts. arXiv preprint *arXiv:2405.00361*. 44 | 45 | ```bibtex 46 | @article{liu2024adamole, 47 | title={AdaMoLE: Fine-Tuning Large Language Models with Adaptive Mixture of Low-Rank Adaptation Experts}, 48 | author={Liu, Zefang and Luo, Jiahua}, 49 | journal={arXiv preprint arXiv:2405.00361}, 50 | year={2024} 51 | } 52 | ``` 53 | 54 | ## License 55 | 56 | This project is licensed under the MIT License. See the [LICENSE](LICENSE) file for details. 57 | -------------------------------------------------------------------------------- /configs/llama_adamole_csr_test.config: -------------------------------------------------------------------------------- 1 | --model_path=outputs/llama-2-7b-hf-adamole-the8-commonsense-qa 2 | --data_path=tau/commonsense_qa 3 | --max_new_tokens=16 4 | --batch_size=16 5 | --logits -------------------------------------------------------------------------------- /configs/llama_adamole_csr_train.config: -------------------------------------------------------------------------------- 1 | --model_path=meta-llama/Llama-2-7b-hf 2 | --data_path=tau/commonsense_qa 3 | --peft_type=adamole 4 | --lora_rank=32 5 | --target_modules 6 | q_proj 7 | k_proj 8 | v_proj 9 | o_proj 10 | --num_experts=8 11 | --threshold=0.125 12 | --max_length=256 13 | --batch_size=4 14 | --gradient_accumulation_steps=4 15 | --num_train_epochs=2 16 | --learning_rate=1e-4 17 | --lr_scheduler_type=constant_with_warmup 18 | --warmup_steps=200 19 | --weight_decay=0.0 20 | --aux_loss_coeff=1e-3 -------------------------------------------------------------------------------- /data.py: -------------------------------------------------------------------------------- 1 | """ 2 | Loading and Preprocessing Datasets 3 | """ 4 | import os 5 | 6 | from datasets import load_dataset, concatenate_datasets, DatasetDict 7 | 8 | 9 | def format_text(example, data_name: str, prompt_only: bool = True): 10 | """ 11 | Format an example into one text 12 | """ 13 | if data_name == 'boolq': 14 | """ 15 | Passage: Windows Movie Maker -- Windows Movie Maker (formerly known as Windows Live Movie Maker in Windows 7) 16 | is a discontinued video editing software by Microsoft. 
It is a part of Windows Essentials software suite and 17 | offers the ability to create and edit videos as well as to publish them on OneDrive, Facebook, Vimeo, YouTube, 18 | and Flickr. 19 | Question: is windows movie maker part of windows essentials 20 | Choices: 21 | A. No 22 | B. Yes 23 | Answer: B 24 | """ 25 | text = f"Passage: {example['passage']}\nQuestion: {example['question']}\nChoices:\n" 26 | text += "A. No\nB. Yes\n" 27 | text += "Answer: " 28 | example['answer'] = ['A', 'B'][example['label']] 29 | example['num_choices'] = 2 30 | 31 | elif data_name == 'cb': 32 | """ 33 | Text: It was a complex language. Not written down but handed down. One might say it was peeled down. 34 | Hypothesis: the language was peeled down 35 | Question: Does the text entail the hypothesis, contradict it, or is it neutral? 36 | Choices: 37 | A. Entailment 38 | B. Contradiction 39 | C. Neutral 40 | Answer: A 41 | """ 42 | text = f"Text: {example['premise']}\nHypothesis: {example['hypothesis']}\n" \ 43 | f"Question: Does the text entail the hypothesis, contradict it, or is it neutral?\nChoices:\n" 44 | text += "A. Entailment\nB. Contradiction\nC. Neutral\n" 45 | text += "Answer: " 46 | example['answer'] = ['A', 'B', 'C'][example['label']] 47 | example['num_choices'] = 3 48 | 49 | elif data_name == 'copa': 50 | """ 51 | Premise: My body cast a shadow over the grass. 52 | Question: What’s the cause for this? 53 | Choices: 54 | A. The sun was rising. 55 | B. The grass was cut. 56 | Answer: A 57 | """ 58 | text = f"Premise: {example['premise']}\nQuestion: What’s the {example['question']} for this?\nChoices:\n" 59 | text += f"A. {example['choice1']}\nB. {example['choice2']}\n" 60 | text += "Answer: " 61 | example['answer'] = ['A', 'B'][example['label']] 62 | example['num_choices'] = 2 63 | 64 | elif data_name == 'multirc': 65 | """ 66 | Paragraph: While this process moved along, diplomacy continued its rounds. Direct pressure on the Taliban had 67 | proved unsuccessful. As one NSC staff note put it, "Under the Taliban, Afghanistan is not so much a state 68 | sponsor of terrorism as it is a state sponsored by terrorists." ... 69 | Question: What did the high-level effort to persuade Pakistan include? 70 | Candidate Answer: Children, Gerd, or Dorian Popa 71 | Choices: 72 | A. False 73 | B. True 74 | Answer: A 75 | """ 76 | text = f"Paragraph: {example['paragraph']}\nQuestion: {example['question']}\n" \ 77 | f"Candidate Answer: {example['answer']}\nChoices:\n" 78 | text += f"A. False\nB. True\n" 79 | text += "Answer: " 80 | example['answer'] = ['A', 'B'][example['label']] 81 | example['num_choices'] = 2 82 | 83 | elif data_name == 'record': 84 | raise NotImplementedError 85 | 86 | elif data_name == 'rte': 87 | """ 88 | Text: No Weapons of Mass Destruction Found in Iraq Yet. 89 | Hypothesis: Weapons of Mass Destruction Found in Iraq. 90 | Question: Does the text entail the hypothesis or not? 91 | Choices: 92 | A. Entailment 93 | B. Not entailment 94 | Answer: B 95 | """ 96 | text = f"Text: {example['premise']}\nHypothesis: {example['hypothesis']}\n" \ 97 | f"Question: Does the text entail the hypothesis or not?\nChoices:\n" 98 | text += "A. Entailment\nB. Not entailment\n" 99 | text += "Answer: " 100 | example['answer'] = ['A', 'B'][example['label']] 101 | example['num_choices'] = 2 102 | 103 | elif data_name == 'wic': 104 | """ 105 | Context 1: Do you want to come over to my later? 106 | Context 2: A political system with no for the less prominent groups. 
107 | Question: Is the word in brackets used with the same meaning in both contexts? 108 | Choices: 109 | A. False 110 | B. True 111 | Answer: A 112 | """ 113 | sentence1 = example['sentence1'] 114 | sentence2 = example['sentence2'] 115 | marked_sentence1 = sentence1[:example['start1']] + '<' + sentence1[example['start1']:example['end1']] \ 116 | + '>' + sentence1[example['end1']:] 117 | marked_sentence2 = sentence2[:example['start2']] + '<' + sentence2[example['start2']:example['end2']] \ 118 | + '>' + sentence2[example['end2']:] 119 | text = f"Context 1: {marked_sentence1}\nContext 2: {marked_sentence2}\n" \ 120 | f"Question: Is the word in brackets used with the same meaning in both contexts?\nChoices:\n" 121 | text += "A. False\nB. True\n" 122 | text += "Answer: " 123 | example['answer'] = ['A', 'B'][example['label']] 124 | example['num_choices'] = 2 125 | 126 | elif data_name == 'wsc.fixed': 127 | """ 128 | Text: told Pete many lies about himself, which Pete included in his book. should have been more 129 | skeptical. 130 | Question: Is the pronoun in brackets referring to the correct entity as intended in the context? 131 | Choices: 132 | A. False 133 | B. True 134 | Answer: A 135 | """ 136 | tokens = example['text'].split() 137 | span1_start = example['span1_index'] 138 | span1_end = example['span1_index'] + len(example['span1_text'].split()) 139 | span2_start = example['span2_index'] 140 | span2_end = example['span2_index'] + len(example['span2_text'].split()) 141 | marked_tokens = tokens[:span1_start] + ['<' + example['span1_text'] + '>'] + tokens[span1_end:span2_start] \ 142 | + ['<' + example['span2_text'] + '>'] + tokens[span2_end:] 143 | marked_text = ' '.join(marked_tokens) 144 | text = f"Text: {marked_text}\n" \ 145 | f"Question: Is the pronoun in brackets referring to the correct entity as intended in the context?\n" \ 146 | f"Choices:\n" 147 | text += "A. False\nB. True\n" 148 | text += "Answer: " 149 | example['answer'] = ['A', 'B'][example['label']] 150 | example['num_choices'] = 2 151 | 152 | elif data_name == 'commonsense_qa': 153 | """ 154 | Question: The sanctions against the school were a punishing blow, and they seemed to what the efforts the 155 | school had made to change? 156 | Choices: 157 | A. ignore 158 | B. enforce 159 | C. authoritarian 160 | D. yell at 161 | E. avoid 162 | Answer: A 163 | """ 164 | text = f"Question: {example['question']}\nChoices:\n" 165 | choices = example['choices'] 166 | for label, choice in zip(choices['label'], choices['text']): 167 | text += f"{label}. {choice}\n" 168 | text += "Answer: " 169 | example['answer'] = example['answerKey'] 170 | example['num_choices'] = 5 171 | 172 | elif data_name == 'cosmos_qa': 173 | """ 174 | Context: Good Old War and person L : I saw both of these bands Wednesday night , and they both blew me away . 175 | seriously . Good Old War is acoustic and makes me smile . I really can not help but be happy when I listen to 176 | them ; I think it 's the fact that they seemed so happy themselves when they played . 177 | Question: In the future , will this person go to see other bands play ? 178 | Choices: 179 | A. None of the above choices . 180 | B. This person likes music and likes to see the show , they will see other bands play . 181 | C. This person only likes Good Old War and Person L , no other bands . 182 | D. Other Bands is not on tour and this person can not see them . 
183 | Answer: B 184 | """ 185 | text = f"Context: {example['context']}\nQuestion: {example['question']}\nChoices:\n" 186 | text += f"A. {example['answer0']}\n" 187 | text += f"B. {example['answer1']}\n" 188 | text += f"C. {example['answer2']}\n" 189 | text += f"D. {example['answer3']}\n" 190 | text += "Answer: " 191 | example['answer'] = chr(ord('A') + example['label']) 192 | example['num_choices'] = 4 193 | 194 | elif data_name == 'social_i_qa': 195 | """ 196 | Context: Cameron decided to have a barbecue and gathered her friends together. 197 | Question: How would Others feel as a result? 198 | Choices: 199 | A. like attending 200 | B. like staying home 201 | C. a good friend to have 202 | Answer: A 203 | """ 204 | text = f"Context: {example['context']}\nQuestion: {example['question']}\nChoices:\n" 205 | text += f"A. {example['answerA']}\n" 206 | text += f"B. {example['answerB']}\n" 207 | text += f"C. {example['answerC']}\n" 208 | text += "Answer: " 209 | example['answer'] = chr(ord('A') + int(example['label']) - 1) 210 | example['num_choices'] = 3 211 | 212 | elif data_name == 'piqa': 213 | """ 214 | Question: When boiling butter, when it's ready, you can 215 | Choices: 216 | A. Pour it onto a plate 217 | B. Pour it into a jar 218 | Answer: B 219 | """ 220 | text = f"Question: {example['goal']}\nChoices:\n" 221 | text += f"A. {example['sol1']}\n" 222 | text += f"B. {example['sol2']}\n" 223 | text += "Answer: " 224 | example['answer'] = chr(ord('A') + example['label']) 225 | example['num_choices'] = 2 226 | 227 | elif data_name == 'openbookqa': 228 | """ 229 | Fact: the sun is the source of energy for physical cycles on Earth 230 | Question: The sun is responsible for 231 | Choices: 232 | A. puppies learning new tricks 233 | B. children growing up and getting old 234 | C. flowers wilting in a vase 235 | D. plants sprouting, blooming and wilting 236 | Answer: D 237 | """ 238 | text = f"Fact: {example['fact1']}\nQuestion: {example['question_stem']}\nChoices:\n" 239 | choices = example['choices'] 240 | for label, choice in zip(choices['label'], choices['text']): 241 | text += f"{label}. {choice}\n" 242 | text += "Answer: " 243 | example['answer'] = example['answerKey'] 244 | example['num_choices'] = 4 245 | 246 | elif data_name == 'ai2_arc': 247 | """ 248 | Question: George wants to warm his hands quickly by rubbing them. Which skin surface will produce the most 249 | heat? 250 | Choices: 251 | A. dry palms 252 | B. wet palms 253 | C. palms covered with oil 254 | D. palms covered with lotion 255 | Answer: A 256 | """ 257 | text = f"Question: {example['question']}\nChoices:\n" 258 | choices = example['choices'] 259 | for label, choice in zip(choices['label'], choices['text']): 260 | text += f"{label}. {choice}\n" 261 | text += "Answer: " 262 | example['answer'] = example['answerKey'] 263 | example['num_choices'] = 4 264 | 265 | elif data_name == 'scienceqa': 266 | """ 267 | Question: Which tense does the sentence use? 268 | Mona will print her name with care. 269 | Choices: 270 | A. present tense 271 | B. future tense 272 | C. past tense 273 | Answer: B 274 | """ 275 | text = f"Question: {example['question']}\nChoices:\n" 276 | choices = example['choices'] 277 | for index, choice in enumerate(choices): 278 | text += f"{chr(ord('A') + index)}. 
{choice}\n" 279 | text += "Answer: " 280 | example['answer'] = chr(ord('A') + example['answer']) 281 | example['num_choices'] = 5 # undefined 282 | 283 | else: 284 | raise NotImplementedError 285 | 286 | if not prompt_only: 287 | text += f"{example['answer']}" 288 | example['data_name'] = data_name 289 | example['text'] = text 290 | return example 291 | 292 | 293 | def get_formatted_datasets(data_path: str, prompt_only: bool): 294 | """ 295 | Get formatted datasets 296 | """ 297 | data_name = os.path.basename(data_path).lower() 298 | 299 | # Load and format datasets 300 | if data_name == 'super_glue': 301 | data_names = ['boolq', 'cb', 'copa', 'rte', 'wic'] 302 | splits = ['train', 'validation', 'test'] 303 | formatted_datasets = {split: [] for split in splits} 304 | 305 | # Load and format datasets 306 | for _data_name in data_names: 307 | _datasets = load_dataset(path='super_glue', name=_data_name) 308 | print(f'Datasets: {_datasets}') 309 | _formatted_datasets = _datasets.map( 310 | lambda example: format_text(example, _data_name, prompt_only=prompt_only), 311 | batched=False, load_from_cache_file=False) 312 | for split in splits: 313 | formatted_datasets[split].append( 314 | _formatted_datasets[split].select_columns(['data_name', 'text', 'num_choices', 'answer'])) 315 | 316 | # Concatenate datasets 317 | for split in splits: 318 | formatted_datasets[split] = concatenate_datasets(formatted_datasets[split]) 319 | formatted_datasets = DatasetDict(formatted_datasets) 320 | print(f'Formatted datasets: {formatted_datasets}') 321 | print(f"Text example:\n{formatted_datasets['train']['text'][0]}") 322 | else: 323 | # Load datasets 324 | if data_name in [ 325 | 'axb', 'axg', 'boolq', 'cb', 'copa', 'multirc', 326 | 'record', 'rte', 'wic', 'wsc', 'wsc.fixed', 327 | ]: 328 | datasets = load_dataset(path='super_glue', name=data_name) 329 | elif data_name == 'openbookqa': 330 | datasets = load_dataset(path=data_path, name='additional') 331 | elif data_name == 'ai2_arc': 332 | datasets = load_dataset(path=data_path, name='ARC-Challenge') 333 | elif data_name == 'scienceqa': 334 | datasets = load_dataset(path=data_path) 335 | datasets = datasets.filter(lambda example: example["image"] is None) 336 | else: 337 | datasets = load_dataset(path=data_path) 338 | print(f'Datasets: {datasets}') 339 | print(f"Example: {datasets['train'][0]}") 340 | 341 | # Format datasets 342 | formatted_datasets = datasets.map( 343 | lambda example: format_text(example, data_name, prompt_only=prompt_only), 344 | batched=False, load_from_cache_file=False) 345 | print(f'Formatted datasets: {formatted_datasets}') 346 | print(f"Formatted example: {formatted_datasets['train'][0]}") 347 | print(f"Text example:\n{formatted_datasets['train']['text'][0]}") 348 | 349 | return formatted_datasets 350 | 351 | 352 | if __name__ == '__main__': 353 | data_path = 'cb' 354 | _ = get_formatted_datasets(data_path=data_path, prompt_only=False) 355 | -------------------------------------------------------------------------------- /images/adamole.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zefang-liu/AdaMoLE/5e8599a7ad6fafde03aaad8466f583f5f28c22ac/images/adamole.png -------------------------------------------------------------------------------- /poster.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zefang-liu/AdaMoLE/5e8599a7ad6fafde03aaad8466f583f5f28c22ac/poster.pdf 
-------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | datasets==2.18.0 2 | huggingface_hub==0.21.4 3 | numpy==1.25.2 4 | pandas==1.3.5 5 | peft==0.9.0 6 | torch==2.0.1 7 | tqdm==4.66.1 8 | transformers==4.38.2 9 | -------------------------------------------------------------------------------- /src/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Package Initialization 3 | """ 4 | from .adamole import AdaMoleConfig, AdaMoleModel 5 | from .config import PeftConfig 6 | from .lora import LoraConfig, LoraModel 7 | from .mole import MoleConfig, MoleModel 8 | from .peft_model import PeftModel, PeftModelForCausalLM 9 | from .trainer import PeftTrainer 10 | from .utils.peft_types import PeftType, TaskType 11 | -------------------------------------------------------------------------------- /src/adamole/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | AdaMoLE Initialization 3 | """ 4 | from .config import AdaMoleConfig 5 | from .layer import AdaMoleLayer, LinearAdaMoleLayer 6 | from .model import AdaMoleModel 7 | 8 | __all__ = ["AdaMoleConfig", "AdaMoleLayer", "LinearAdaMoleLayer", "AdaMoleModel"] 9 | 10 | 11 | def __getattr__(name): 12 | raise AttributeError(f"Module {__name__} has no attribute {name}.") 13 | -------------------------------------------------------------------------------- /src/adamole/config.py: -------------------------------------------------------------------------------- 1 | """ 2 | AdaMoLE Configuration 3 | """ 4 | from dataclasses import dataclass, field 5 | 6 | from ..lora import LoraConfig 7 | from ..utils.peft_types import PeftType 8 | 9 | 10 | @dataclass 11 | class AdaMoleConfig(LoraConfig): 12 | """ 13 | AdaMoLE Configuration 14 | """ 15 | num_experts: int = field(default=4, metadata={"help": "The number of experts in MoE."}) 16 | max_threshold: float = field(default=None, metadata={ 17 | "help": "The maximum threshold for selecting experts in the threshold function. 
" 18 | "The default value will be 1 / number of experts"}) 19 | 20 | def __post_init__(self): 21 | self.peft_type = PeftType.ADAMOLE 22 | self.target_modules = ( 23 | set(self.target_modules) if isinstance(self.target_modules, list) else self.target_modules 24 | ) 25 | -------------------------------------------------------------------------------- /src/adamole/layer.py: -------------------------------------------------------------------------------- 1 | """ 2 | AdaMoLE Layer 3 | """ 4 | import math 5 | from abc import ABC 6 | from typing import Optional 7 | 8 | import torch 9 | import torch.nn as nn 10 | import torch.nn.functional as F 11 | 12 | from ..lora import LoraLayer 13 | from ..mole.layer import LoraExpert 14 | 15 | 16 | class AdaMoeLayer(nn.Module): 17 | """ 18 | Adaptive Mixture of Experts (MoE) Layer 19 | """ 20 | 21 | def __init__(self, experts: nn.ModuleList, gate: nn.Module, threshold_fn: nn.Module, max_threshold: float): 22 | super().__init__() 23 | self.experts = experts 24 | self.gate = gate 25 | self.threshold_fn = threshold_fn 26 | self.max_threshold = max_threshold 27 | self.layer_loss = None 28 | 29 | def get_layer_loss(self, gate_logits: torch.Tensor, selected_experts: torch.Tensor) -> torch.Tensor: 30 | """ 31 | Get the load balancing loss by following the Switch Transformer 32 | """ 33 | num_inputs = gate_logits.shape[0] 34 | num_experts = len(self.experts) 35 | expert_counts = torch.sum(selected_experts, dim=0) 36 | expert_fractions = expert_counts / num_inputs 37 | expert_probs = torch.sum(gate_logits, dim=0) / num_inputs 38 | layer_loss = num_experts * torch.sum(expert_fractions * expert_probs) 39 | return layer_loss 40 | 41 | def forward(self, inputs: torch.Tensor) -> torch.Tensor: 42 | """ 43 | Forward propagation 44 | """ 45 | flattened_inputs = inputs.view((-1, inputs.shape[-1])) 46 | gate_logits = F.softmax(self.gate(flattened_inputs), dim=-1) 47 | thresholds = F.sigmoid(self.threshold_fn(flattened_inputs)) * self.max_threshold 48 | adapted_gate_logits = gate_logits - thresholds 49 | selected_experts = torch.ge(adapted_gate_logits, 0).to(torch.float) 50 | weights = adapted_gate_logits * selected_experts 51 | weight_sums = torch.sum(weights, dim=-1, keepdim=True, dtype=inputs.dtype) 52 | weight_sums = torch.where(weight_sums == 0, torch.ones_like(weight_sums), weight_sums) 53 | weights = weights / weight_sums 54 | results = torch.zeros_like(self.experts[0](flattened_inputs)) 55 | 56 | for i, expert in enumerate(self.experts): 57 | batch_idx = torch.where(selected_experts[:, i])[0] 58 | if len(batch_idx) > 0: 59 | results[batch_idx] += weights[batch_idx, i, None] * expert(flattened_inputs[batch_idx]) 60 | 61 | results = results.view((*inputs.shape[:-1], results.shape[-1])) 62 | if inputs.requires_grad: 63 | self.layer_loss = self.get_layer_loss(gate_logits=adapted_gate_logits, selected_experts=selected_experts) 64 | return results 65 | 66 | 67 | class AdaMoleLayer(LoraLayer, ABC): 68 | """ 69 | AdaMoLE Layer 70 | """ 71 | 72 | def __init__(self, base_layer: nn.Module, **kwargs): 73 | super().__init__(base_layer, **kwargs) 74 | self.lora_gating = nn.ModuleDict({}) 75 | self.lora_threshold = nn.ModuleDict({}) 76 | self.moe_layer = nn.ModuleDict({}) 77 | 78 | def update_layer( 79 | self, adapter_name: str, lora_rank: int, lora_alpha: int, lora_dropout: float, init_lora_weights: bool, 80 | num_experts: int, max_threshold: float, 81 | ) -> None: 82 | """ 83 | Update the layer 84 | """ 85 | if lora_rank <= 0: 86 | raise ValueError(f"The rank `r` should be a 
positive integer value but the value passed is {lora_rank}.") 87 | 88 | if max_threshold is None: 89 | max_threshold = 1 / num_experts 90 | 91 | self.lora_rank[adapter_name] = lora_rank 92 | self.lora_alpha[adapter_name] = lora_alpha 93 | 94 | if lora_dropout > 0.0: 95 | lora_dropout_layer = nn.ModuleList([nn.Dropout(p=lora_dropout) for _ in range(num_experts)]) 96 | else: 97 | lora_dropout_layer = nn.ModuleList([nn.Identity(p=lora_dropout) for _ in range(num_experts)]) 98 | 99 | self.lora_dropout[adapter_name] = lora_dropout_layer 100 | self.lora_A[adapter_name] = nn.ModuleList( 101 | [nn.Linear(self.in_features, lora_rank, bias=False) for _ in range(num_experts)]) 102 | self.lora_B[adapter_name] = nn.ModuleList( 103 | [nn.Linear(lora_rank, self.out_features, bias=False) for _ in range(num_experts)]) 104 | self.scaling[adapter_name] = lora_alpha / lora_rank 105 | self.lora_gating[adapter_name] = nn.Linear(self.in_features, num_experts, bias=False) 106 | self.lora_threshold[adapter_name] = nn.Linear(self.in_features, 1) 107 | 108 | experts = nn.ModuleList([LoraExpert( 109 | self.lora_A[adapter_name][i], 110 | self.lora_B[adapter_name][i], 111 | self.lora_dropout[adapter_name][i], 112 | self.scaling[adapter_name], 113 | ) for i in range(num_experts)]) 114 | self.moe_layer[adapter_name] = AdaMoeLayer( 115 | experts=experts, gate=self.lora_gating[adapter_name], 116 | threshold_fn=self.lora_threshold[adapter_name], max_threshold=max_threshold) 117 | 118 | self.reset_parameters(adapter_name, init_lora_weights) 119 | self.set_adapter(self.active_adapters) 120 | 121 | def reset_parameters(self, adapter_name: str, init_lora_weights: bool) -> None: 122 | """ 123 | Reset the parameters 124 | """ 125 | if init_lora_weights is False: 126 | return 127 | elif adapter_name in self.lora_A.keys(): 128 | for i in range(len(self.lora_A[adapter_name])): 129 | nn.init.kaiming_uniform_(self.lora_A[adapter_name][i].weight, a=math.sqrt(5)) 130 | nn.init.zeros_(self.lora_B[adapter_name][i].weight) 131 | 132 | 133 | class LinearAdaMoleLayer(nn.Module, AdaMoleLayer): 134 | """ 135 | AdaMoLE Implementation in a Linear Layer 136 | """ 137 | 138 | def __init__( 139 | self, 140 | base_layer: nn.Module, 141 | adapter_name: str, 142 | lora_rank: int = 0, 143 | lora_alpha: int = 1, 144 | lora_dropout: float = 0.0, 145 | init_lora_weights: bool = True, 146 | num_experts: int = 4, 147 | max_threshold: float = None, 148 | **kwargs, 149 | ) -> None: 150 | super().__init__() 151 | AdaMoleLayer.__init__(self, base_layer=base_layer, **kwargs) 152 | self._active_adapter = adapter_name 153 | self.update_layer( 154 | adapter_name, lora_rank, lora_alpha, lora_dropout, init_lora_weights, num_experts, max_threshold) 155 | 156 | def merge(self, safe_merge: bool = False, adapter_names: Optional[list[str]] = None) -> None: 157 | """ 158 | Merge the active adapter weights inside the base weights 159 | """ 160 | pass 161 | 162 | def unmerge(self) -> None: 163 | """ 164 | Unmerge all merged adapter layers from the base weights 165 | """ 166 | pass 167 | 168 | def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor: 169 | """ 170 | Forward propagation 171 | """ 172 | previous_dtype = x.dtype 173 | result = self.base_layer(x, *args, **kwargs) 174 | 175 | for active_adapter in self.active_adapters: 176 | if active_adapter not in self.lora_A.keys(): 177 | continue 178 | 179 | moe_layer = self.moe_layer[active_adapter] 180 | x = x.to(moe_layer.experts[0].lora_A.weight.dtype) 181 | result += moe_layer(x) 182 | 183 | result = 
result.to(previous_dtype) 184 | return result 185 | -------------------------------------------------------------------------------- /src/adamole/model.py: -------------------------------------------------------------------------------- 1 | """ 2 | AdaMoLE Model 3 | """ 4 | from typing import Any 5 | 6 | import torch 7 | from peft.tuners.tuners_utils import BaseTunerLayer 8 | from torch import nn 9 | 10 | from .config import AdaMoleConfig 11 | from .layer import AdaMoleLayer, LinearAdaMoleLayer 12 | from ..lora import LoraModel 13 | 14 | 15 | class AdaMoleModel(LoraModel): 16 | """ 17 | AdaMoLE (Adaptive Mixture of LoRA Experts) Model 18 | """ 19 | prefix: str = "lora_" 20 | 21 | def __init__(self, model: nn.Module, config: AdaMoleConfig, adapter_name: str = "default") -> None: 22 | super().__init__(model, config, adapter_name) 23 | 24 | def _create_and_replace( 25 | self, adamole_config: AdaMoleConfig, adapter_name: str, 26 | target: nn.Module, target_name: str, parent: nn.Module, **kwargs: Any, 27 | ) -> None: 28 | """ 29 | Inplace replacement of the target module with the adapter layer 30 | """ 31 | kwargs = { 32 | "lora_rank": adamole_config.lora_rank, 33 | "lora_alpha": adamole_config.lora_alpha, 34 | "lora_dropout": adamole_config.lora_dropout, 35 | "init_lora_weights": adamole_config.init_lora_weights, 36 | "num_experts": adamole_config.num_experts, 37 | "max_threshold": adamole_config.max_threshold, 38 | } 39 | 40 | if isinstance(target, AdaMoleLayer): 41 | target.update_layer(adapter_name, **kwargs) 42 | else: 43 | new_module = self._create_new_module(adapter_name, target, **kwargs) 44 | self._replace_module(parent, target_name, new_module, target) 45 | 46 | @staticmethod 47 | def _create_new_module(adapter_name: str, target: nn.Module, **kwargs: Any) -> nn.Module: 48 | """ 49 | Create the new LoRA module for the target module 50 | """ 51 | if isinstance(target, BaseTunerLayer): 52 | target_base_layer = target.get_base_layer() 53 | else: 54 | target_base_layer = target 55 | 56 | if isinstance(target_base_layer, torch.nn.Linear): 57 | new_module = LinearAdaMoleLayer(base_layer=target, adapter_name=adapter_name, **kwargs) 58 | else: 59 | raise ValueError( 60 | f"The target module `{target}` is not supported. " 61 | f"Currently, only the following modules are supported: `torch.nn.Linear`.") 62 | 63 | return new_module 64 | 65 | def get_aux_loss(self, adapter_name="default") -> torch.Tensor: 66 | """ 67 | Get the load balancing loss for the whole model 68 | """ 69 | model_loss = torch.tensor(0, dtype=torch.float).to(self.model.device) 70 | 71 | for name, module in self.model.named_modules(): 72 | if name.endswith('moe_layer'): 73 | layer_loss = module[adapter_name].layer_loss 74 | model_loss += layer_loss 75 | 76 | return model_loss 77 | -------------------------------------------------------------------------------- /src/config.py: -------------------------------------------------------------------------------- 1 | """ 2 | PEFT Configuration 3 | 4 | Portions of this file are modifications based on work created and 5 | shared by the HuggingFace Inc. team and used according to terms 6 | described in the Apache License 2.0. 
7 | """ 8 | import inspect 9 | import json 10 | import os 11 | from dataclasses import asdict, dataclass, field 12 | from typing import Dict, Optional, Union 13 | 14 | from huggingface_hub import hf_hub_download 15 | from peft.utils import CONFIG_NAME 16 | from transformers.utils import PushToHubMixin 17 | 18 | from .utils.peft_types import PeftType, TaskType 19 | 20 | 21 | @dataclass 22 | class PeftConfigMixin(PushToHubMixin): 23 | """ 24 | Base Configuration Class for PEFT Models 25 | """ 26 | peft_type: Optional[PeftType] = field( 27 | default=None, metadata={"help": "The type of Peft method to use."}) 28 | auto_mapping: Optional[dict] = field( 29 | default=None, metadata={"help": "An auto mapping dict to help retrieve the base model class if needed."} 30 | ) 31 | 32 | def to_dict(self) -> Dict: 33 | """ 34 | Return the configuration for the adapter model as a dictionary 35 | """ 36 | return asdict(self) 37 | 38 | def save_pretrained(self, save_directory: str, **kwargs) -> None: 39 | """ 40 | Save the configuration of the adapter model in a directory 41 | 42 | :param save_directory: the directory where the configuration will be saved 43 | :param kwargs: additional keyword arguments 44 | """ 45 | if os.path.isfile(save_directory): 46 | raise AssertionError(f"Provided path ({save_directory}) should be a directory, not a file") 47 | 48 | os.makedirs(save_directory, exist_ok=True) 49 | 50 | # Converting set type to list 51 | output_dict = asdict(self) 52 | for key, value in output_dict.items(): 53 | if isinstance(value, set): 54 | output_dict[key] = list(value) 55 | 56 | output_path = os.path.join(save_directory, CONFIG_NAME) 57 | 58 | # Add auto mapping details for custom models. 59 | auto_mapping_dict = kwargs.pop("auto_mapping_dict", None) 60 | if auto_mapping_dict is not None: 61 | output_dict["auto_mapping"] = auto_mapping_dict 62 | 63 | # Save the configuration 64 | with open(output_path, "w") as writer: 65 | writer.write(json.dumps(output_dict, indent=2, sort_keys=True)) 66 | 67 | @classmethod 68 | def from_pretrained(cls, pretrained_model_name_or_path: str, subfolder: Optional[str] = None, **kwargs): 69 | """ 70 | Loads the configuration of the adapter model from a directory 71 | 72 | :param pretrained_model_name_or_path: the directory or the Hub repository id where the configuration is saved 73 | :param subfolder: subfolder for the directory 74 | :param kwargs: additional keyword arguments passed along to the child class initialization 75 | :return: 76 | """ 77 | from .mapping import PEFT_TYPE_TO_CONFIG_MAPPING 78 | 79 | path = ( 80 | os.path.join(pretrained_model_name_or_path, subfolder) 81 | if subfolder is not None 82 | else pretrained_model_name_or_path 83 | ) 84 | 85 | hf_hub_download_kwargs, class_kwargs, _ = cls._split_kwargs(kwargs) 86 | 87 | if os.path.isfile(os.path.join(path, CONFIG_NAME)): 88 | config_file = os.path.join(path, CONFIG_NAME) 89 | else: 90 | try: 91 | config_file = hf_hub_download( 92 | pretrained_model_name_or_path, CONFIG_NAME, subfolder=subfolder, **hf_hub_download_kwargs 93 | ) 94 | except Exception: 95 | raise ValueError(f"Can't find '{CONFIG_NAME}' at '{pretrained_model_name_or_path}'") 96 | 97 | loaded_attributes = cls.from_json_file(config_file) 98 | if "peft_type" in loaded_attributes: 99 | peft_type = loaded_attributes["peft_type"] 100 | config_cls = PEFT_TYPE_TO_CONFIG_MAPPING[peft_type] 101 | else: 102 | config_cls = cls 103 | 104 | kwargs = {**class_kwargs, **loaded_attributes} 105 | config = config_cls(**kwargs) 106 | return config 107 | 108 
| @classmethod 109 | def from_json_file(cls, path_json_file: str): 110 | """ 111 | Load a configuration file from a JSON file 112 | 113 | :param path_json_file: the path to the JSON file 114 | :return: a JSON object 115 | """ 116 | with open(path_json_file, "r") as file: 117 | json_object = json.load(file) 118 | return json_object 119 | 120 | @classmethod 121 | def _split_kwargs(cls, kwargs): 122 | hf_hub_download_kwargs = {} 123 | class_kwargs = {} 124 | other_kwargs = {} 125 | 126 | for key, value in kwargs.items(): 127 | if key in inspect.signature(hf_hub_download).parameters: 128 | hf_hub_download_kwargs[key] = value 129 | elif key in list(cls.__annotations__): 130 | class_kwargs[key] = value 131 | else: 132 | other_kwargs[key] = value 133 | 134 | return hf_hub_download_kwargs, class_kwargs, other_kwargs 135 | 136 | @classmethod 137 | def _get_peft_type(cls, model_id: str, **hf_hub_download_kwargs): 138 | subfolder = hf_hub_download_kwargs.get("subfolder", None) 139 | path = os.path.join(model_id, subfolder) if subfolder is not None else model_id 140 | 141 | if os.path.isfile(os.path.join(path, CONFIG_NAME)): 142 | config_file = os.path.join(path, CONFIG_NAME) 143 | else: 144 | try: 145 | config_file = hf_hub_download( 146 | model_id, 147 | CONFIG_NAME, 148 | **hf_hub_download_kwargs, 149 | ) 150 | except Exception: 151 | raise ValueError(f"Can't find '{CONFIG_NAME}' at '{model_id}'") 152 | 153 | loaded_attributes = cls.from_json_file(config_file) 154 | return loaded_attributes["peft_type"] 155 | 156 | 157 | @dataclass 158 | class PeftConfig(PeftConfigMixin): 159 | """ 160 | Base configuration class to store the configuration of a PEFT model 161 | """ 162 | base_model_name_or_path: Optional[str] = field( 163 | default=None, metadata={"help": "The name of the base model to use."}) 164 | revision: Optional[str] = field( 165 | default=None, metadata={"help": "The specific model version to use."}) 166 | peft_type: Optional[Union[str, PeftType]] = field( 167 | default=None, metadata={"help": "The type of PEFT method to use."}) 168 | task_type: Optional[Union[str, TaskType]] = field( 169 | default=None, metadata={"help": "The type of task to perform."}) 170 | inference_mode: bool = field( 171 | default=False, metadata={"help": "Whether to use the PEFT model in inference mode."}) 172 | -------------------------------------------------------------------------------- /src/lora/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | LoRA Initialization 3 | 4 | Portions of this file are modifications based on work created and 5 | shared by the HuggingFace Inc. team and used according to terms 6 | described in the Apache License 2.0. 7 | """ 8 | from .config import LoraConfig 9 | from .layer import LoraLayer, LinearLoraLayer 10 | from .model import LoraModel 11 | 12 | __all__ = ["LoraConfig", "LoraLayer", "LinearLoraLayer", "LoraModel"] 13 | 14 | 15 | def __getattr__(name): 16 | raise AttributeError(f"Module {__name__} has no attribute {name}.") 17 | -------------------------------------------------------------------------------- /src/lora/config.py: -------------------------------------------------------------------------------- 1 | """ 2 | LoRA Configuration 3 | 4 | Portions of this file are modifications based on work created and 5 | shared by the HuggingFace Inc. team and used according to terms 6 | described in the Apache License 2.0. 
7 | """ 8 | from dataclasses import dataclass, field 9 | from typing import List, Literal, Optional, Union 10 | 11 | from ..config import PeftConfig 12 | from ..utils.peft_types import PeftType 13 | 14 | 15 | @dataclass 16 | class LoraConfig(PeftConfig): 17 | """ 18 | LoRA Configuration 19 | """ 20 | lora_rank: int = field(default=8, metadata={"help": "The Lora rank for the attention dimension."}) 21 | lora_alpha: int = field(default=8, metadata={"help": "The alpha parameter for Lora scaling."}) 22 | lora_dropout: float = field(default=0.0, metadata={"help": "The dropout probability for Lora layers."}) 23 | bias: Literal["none", "all", "lora_only"] = field( 24 | default="none", metadata={"help": "The bias type for Lora layers and can be 'none', 'all' or 'lora_only'."}) 25 | target_modules: Optional[Union[List[str], str]] = field( 26 | default=None, metadata={"help": "The names of the modules to apply the adapter to."}) 27 | init_lora_weights: bool = field( 28 | default=True, metadata={"help": "Whether to initialize the weights of the adapter layers."}) 29 | 30 | def __post_init__(self): 31 | self.peft_type = PeftType.LORA 32 | self.target_modules = ( 33 | set(self.target_modules) if isinstance(self.target_modules, list) else self.target_modules 34 | ) 35 | -------------------------------------------------------------------------------- /src/lora/layer.py: -------------------------------------------------------------------------------- 1 | """ 2 | LoRA Layer 3 | 4 | Portions of this file are modifications based on work created and 5 | shared by the HuggingFace Inc. team and used according to terms 6 | described in the Apache License 2.0. 7 | """ 8 | import math 9 | from abc import ABC 10 | from typing import Optional 11 | 12 | import torch 13 | import torch.nn as nn 14 | from peft.tuners.tuners_utils import BaseTunerLayer 15 | 16 | 17 | class LoraLayer(BaseTunerLayer, ABC): 18 | """ 19 | LoRA Layer 20 | """ 21 | 22 | def __init__(self, base_layer: nn.Module, **kwargs) -> None: 23 | self.base_layer = base_layer 24 | self.lora_rank = {} 25 | self.lora_alpha = {} 26 | self.scaling = {} 27 | 28 | self.lora_dropout = nn.ModuleDict({}) 29 | self.lora_A = nn.ModuleDict({}) 30 | self.lora_B = nn.ModuleDict({}) 31 | self.kwargs = kwargs 32 | 33 | if isinstance(base_layer, nn.Linear): 34 | in_features, out_features = base_layer.in_features, base_layer.out_features 35 | else: 36 | raise ValueError(f"Unsupported layer type {type(base_layer)}") 37 | 38 | self.in_features = in_features 39 | self.out_features = out_features 40 | 41 | def update_layer( 42 | self, adapter_name: str, lora_rank: int, lora_alpha: int, lora_dropout: float, init_lora_weights: bool, 43 | ) -> None: 44 | """ 45 | Update the layer 46 | """ 47 | if lora_rank <= 0: 48 | raise ValueError(f"The rank `r` should be a positive integer value but the value passed is {lora_rank}.") 49 | 50 | self.lora_rank[adapter_name] = lora_rank 51 | self.lora_alpha[adapter_name] = lora_alpha 52 | 53 | if lora_dropout > 0.0: 54 | lora_dropout_layer = nn.Dropout(p=lora_dropout) 55 | else: 56 | lora_dropout_layer = nn.Identity() 57 | 58 | self.lora_dropout[adapter_name] = lora_dropout_layer 59 | self.lora_A[adapter_name] = nn.Linear(self.in_features, lora_rank, bias=False) 60 | self.lora_B[adapter_name] = nn.Linear(lora_rank, self.out_features, bias=False) 61 | self.scaling[adapter_name] = lora_alpha / lora_rank 62 | 63 | self.reset_parameters(adapter_name, init_lora_weights) 64 | self.set_adapter(self.active_adapters) 65 | 66 | def reset_parameters(self, 
adapter_name: str, init_lora_weights: bool) -> None: 67 | """ 68 | Reset the parameters 69 | """ 70 | if init_lora_weights is False: 71 | return 72 | elif adapter_name in self.lora_A.keys(): 73 | nn.init.kaiming_uniform_(self.lora_A[adapter_name].weight, a=math.sqrt(5)) 74 | nn.init.zeros_(self.lora_B[adapter_name].weight) 75 | 76 | 77 | class LinearLoraLayer(nn.Module, LoraLayer): 78 | """ 79 | LoRA Implementation in a Linear Layer 80 | """ 81 | 82 | def __init__( 83 | self, 84 | base_layer: nn.Module, 85 | adapter_name: str, 86 | lora_rank: int = 0, 87 | lora_alpha: int = 1, 88 | lora_dropout: float = 0.0, 89 | init_lora_weights: bool = True, 90 | **kwargs, 91 | ) -> None: 92 | super().__init__() 93 | LoraLayer.__init__(self, base_layer, **kwargs) 94 | self._active_adapter = adapter_name 95 | self.update_layer(adapter_name, lora_rank, lora_alpha, lora_dropout, init_lora_weights) 96 | 97 | def merge(self, safe_merge: bool = False, adapter_names: Optional[list[str]] = None) -> None: 98 | """ 99 | Merge the active adapter weights inside the base weights 100 | """ 101 | raise NotImplementedError 102 | 103 | def unmerge(self) -> None: 104 | """ 105 | Unmerge all merged adapter layers from the base weights 106 | """ 107 | raise NotImplementedError 108 | 109 | def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor: 110 | """ 111 | Forward propagation 112 | """ 113 | previous_dtype = x.dtype 114 | result = self.base_layer(x, *args, **kwargs) 115 | 116 | for active_adapter in self.active_adapters: 117 | if active_adapter not in self.lora_A.keys(): 118 | continue 119 | 120 | lora_A = self.lora_A[active_adapter] 121 | lora_B = self.lora_B[active_adapter] 122 | dropout = self.lora_dropout[active_adapter] 123 | scaling = self.scaling[active_adapter] 124 | 125 | x = x.to(lora_A.weight.dtype) 126 | result += lora_B(lora_A(dropout(x))) * scaling 127 | 128 | result = result.to(previous_dtype) 129 | return result 130 | -------------------------------------------------------------------------------- /src/lora/model.py: -------------------------------------------------------------------------------- 1 | """ 2 | LoRA Model 3 | 4 | Portions of this file are modifications based on work created and 5 | shared by the HuggingFace Inc. team and used according to terms 6 | described in the Apache License 2.0. 
7 | """ 8 | from typing import Any 9 | 10 | import torch 11 | from peft import PeftConfig 12 | from peft.tuners.tuners_utils import BaseTuner, BaseTunerLayer, check_target_module_exists 13 | from peft.utils import TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING 14 | from torch import nn 15 | from transformers import PretrainedConfig 16 | 17 | from .config import LoraConfig 18 | from .layer import LoraLayer, LinearLoraLayer 19 | 20 | 21 | class LoraModel(BaseTuner): 22 | """ 23 | Low Rank Adapter (LoRA) Model 24 | """ 25 | prefix: str = "lora_" 26 | 27 | def __init__(self, model: nn.Module, config: LoraConfig, adapter_name: str = "default") -> None: 28 | """ 29 | Initialize LoraModel 30 | 31 | :param model: model to be adapted 32 | :param config: configuration of the LoRA model 33 | :param adapter_name: name of the adapter 34 | """ 35 | super().__init__(model, config, adapter_name) 36 | 37 | def __getattr__(self, name: str) -> Any: 38 | """ 39 | Forward missing attributes to the wrapped module 40 | """ 41 | try: 42 | return super().__getattr__(name) 43 | except AttributeError: 44 | return getattr(self.model, name) 45 | 46 | def _check_new_adapter_config(self, config: LoraConfig) -> None: 47 | """ 48 | Check the config when a new adapter is being added 49 | """ 50 | if (len(self.peft_config) > 1) and (config.bias != "none"): 51 | raise ValueError( 52 | f"{self.__class__.__name__} supports only 1 adapter with bias. " 53 | f"When using multiple adapters, set bias to 'none' for all adapters.") 54 | 55 | @staticmethod 56 | def _check_target_module_exists(lora_config: LoraConfig, key: str) -> bool: 57 | """ 58 | Check if the passed module's key name matches any of the target modules in the config target module list 59 | """ 60 | return check_target_module_exists(lora_config, key) 61 | 62 | @staticmethod 63 | def _prepare_adapter_config(peft_config: LoraConfig, model_config: PretrainedConfig) -> PeftConfig: 64 | """ 65 | Prepare the adapter config, such as automatically inferring target modules if it is none 66 | """ 67 | if peft_config.target_modules is None: 68 | if model_config["model_type"] not in TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING: 69 | raise ValueError("Please specify `target_modules` in `peft_config`.") 70 | peft_config.target_modules = set( 71 | TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING[model_config["model_type"]], 72 | ) 73 | return peft_config 74 | 75 | def _create_and_replace( 76 | self, lora_config: LoraConfig, adapter_name: str, 77 | target: nn.Module, target_name: str, parent: nn.Module, **kwargs: Any, 78 | ) -> None: 79 | """ 80 | Inplace replacement of the target module with the adapter layer 81 | """ 82 | kwargs = { 83 | "lora_rank": lora_config.lora_rank, 84 | "lora_alpha": lora_config.lora_alpha, 85 | "lora_dropout": lora_config.lora_dropout, 86 | "init_lora_weights": lora_config.init_lora_weights, 87 | } 88 | 89 | if isinstance(target, LoraLayer): 90 | target.update_layer(adapter_name, **kwargs) 91 | else: 92 | new_module = self._create_new_module(adapter_name, target, **kwargs) 93 | self._replace_module(parent, target_name, new_module, target) 94 | 95 | @staticmethod 96 | def _create_new_module(adapter_name: str, target: nn.Module, **kwargs: Any) -> nn.Module: 97 | """ 98 | Create the new LoRA module for the target module 99 | """ 100 | if isinstance(target, BaseTunerLayer): 101 | target_base_layer = target.get_base_layer() 102 | else: 103 | target_base_layer = target 104 | 105 | if isinstance(target_base_layer, torch.nn.Linear): 106 | new_module = 
LinearLoraLayer(base_layer=target, adapter_name=adapter_name, **kwargs) 107 | else: 108 | raise ValueError( 109 | f"The target module `{target}` is not supported. " 110 | f"Currently, only the following modules are supported: `torch.nn.Linear`.") 111 | 112 | return new_module 113 | 114 | def _replace_module(self, parent: nn.Module, child_name: str, new_module: nn.Module, child: nn.Module) -> None: 115 | """ 116 | Replace the module 117 | """ 118 | setattr(parent, child_name, new_module) 119 | 120 | if hasattr(child, "base_layer"): 121 | child = child.base_layer 122 | 123 | if not hasattr(new_module, "base_layer"): 124 | new_module.weight = child.weight 125 | if hasattr(child, "bias"): 126 | new_module.bias = child.bias 127 | 128 | if getattr(child, "state", None) is not None: 129 | if hasattr(new_module, "base_layer"): 130 | new_module.base_layer.state = child.state 131 | else: 132 | new_module.state = child.state 133 | new_module.to(child.weight.device) 134 | 135 | for name, module in new_module.named_modules(): 136 | if self.prefix in name: 137 | module.to(child.weight.device) 138 | 139 | def _mark_only_adapters_as_trainable(self, model: nn.Module) -> None: 140 | """ 141 | Make only adapters as trainable 142 | """ 143 | for name, param in model.named_parameters(): 144 | if self.prefix not in name: 145 | param.requires_grad = False 146 | 147 | for active_adapter in self.active_adapters: 148 | bias = self.peft_config[active_adapter].bias 149 | if bias == "none": 150 | continue 151 | elif bias == "all": 152 | for name, param in model.named_parameters(): 153 | if "bias" in name: 154 | param.requires_grad = True 155 | elif bias == "lora_only": 156 | for module in model.modules(): 157 | if isinstance(module, LoraLayer) and hasattr(module, "bias") and module.bias is not None: 158 | module.bias.requires_grad = True 159 | else: 160 | raise NotImplementedError(f"Requested bias: {bias}, is not implemented.") 161 | -------------------------------------------------------------------------------- /src/mapping.py: -------------------------------------------------------------------------------- 1 | """ 2 | Configure and Model Mappings 3 | 4 | Portions of this file are modifications based on work created and 5 | shared by the HuggingFace Inc. team and used according to terms 6 | described in the Apache License 2.0. 
7 | """ 8 | 9 | from .adamole import AdaMoleConfig, AdaMoleModel 10 | from .lora import LoraConfig, LoraModel 11 | from .mole import MoleConfig, MoleModel 12 | from .utils.peft_types import PeftType 13 | 14 | PEFT_TYPE_TO_CONFIG_MAPPING = { 15 | PeftType.LORA: LoraConfig, 16 | PeftType.MOLE: MoleConfig, 17 | PeftType.ADAMOLE: AdaMoleConfig, 18 | } 19 | PEFT_TYPE_TO_MODEL_MAPPING = { 20 | PeftType.LORA: LoraModel, 21 | PeftType.MOLE: MoleModel, 22 | PeftType.ADAMOLE: AdaMoleModel, 23 | } 24 | -------------------------------------------------------------------------------- /src/mole/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | MoLE Initialization 3 | """ 4 | from .config import MoleConfig 5 | from .layer import MoleLayer, LinearMoleLayer 6 | from .model import MoleModel 7 | 8 | __all__ = ["MoleConfig", "MoleLayer", "LinearMoleLayer", "MoleModel"] 9 | 10 | 11 | def __getattr__(name): 12 | raise AttributeError(f"Module {__name__} has no attribute {name}.") 13 | -------------------------------------------------------------------------------- /src/mole/config.py: -------------------------------------------------------------------------------- 1 | """ 2 | MoLE Configuration 3 | """ 4 | from dataclasses import dataclass, field 5 | 6 | from ..lora import LoraConfig 7 | from ..utils.peft_types import PeftType 8 | 9 | 10 | @dataclass 11 | class MoleConfig(LoraConfig): 12 | """ 13 | MoLE Configuration 14 | """ 15 | num_experts: int = field(default=4, metadata={"help": "The number of experts in MoE."}) 16 | top_k: int = field(default=None, metadata={ 17 | "help": "The k in top-k gating if the expert threshold is None."}) 18 | threshold: float = field(default=None, metadata={ 19 | "help": "The threshold for selecting experts if the top-k is None. 
" 20 | "The maximum threshold should be 1 / number of experts"}) 21 | 22 | def __post_init__(self): 23 | self.peft_type = PeftType.MOLE 24 | self.target_modules = ( 25 | set(self.target_modules) if isinstance(self.target_modules, list) else self.target_modules 26 | ) 27 | -------------------------------------------------------------------------------- /src/mole/layer.py: -------------------------------------------------------------------------------- 1 | """ 2 | MoLE Layer 3 | """ 4 | import math 5 | from abc import ABC 6 | from typing import Optional 7 | 8 | import torch 9 | import torch.nn as nn 10 | import torch.nn.functional as F 11 | 12 | from ..lora import LoraLayer 13 | 14 | 15 | class TopKMoeLayer(nn.Module): 16 | """ 17 | Mixture of Experts (MoE) Layer with the Top-k 18 | 19 | Adapted from https://github.com/mistralai/mistral-src 20 | """ 21 | 22 | def __init__(self, experts: nn.ModuleList, gate: nn.Module, top_k: int): 23 | super().__init__() 24 | self.experts = experts 25 | self.gate = gate 26 | self.top_k = top_k 27 | self.layer_loss = None 28 | 29 | def get_layer_loss(self, gate_logits: torch.Tensor, selected_experts: torch.Tensor) -> torch.Tensor: 30 | """ 31 | Get the load balancing loss by following the Switch Transformer 32 | """ 33 | num_inputs = gate_logits.shape[0] 34 | num_experts = len(self.experts) 35 | expert_counts = torch.bincount(selected_experts.reshape(-1), minlength=num_experts) 36 | expert_fractions = expert_counts / num_inputs 37 | expert_probs = torch.sum(gate_logits, dim=0) / num_inputs 38 | layer_loss = num_experts * torch.sum(expert_fractions * expert_probs) 39 | return layer_loss 40 | 41 | def forward(self, inputs: torch.Tensor) -> torch.Tensor: 42 | """ 43 | Forward propagation 44 | """ 45 | flattened_inputs = inputs.view((-1, inputs.shape[-1])) 46 | gate_logits = F.softmax(self.gate(flattened_inputs), dim=-1) 47 | weights, selected_experts = torch.topk(input=gate_logits, k=self.top_k, dim=-1) 48 | weights = weights / torch.sum(weights, dim=-1, keepdim=True, dtype=inputs.dtype) 49 | results = torch.zeros_like(self.experts[0](flattened_inputs)) 50 | 51 | for i, expert in enumerate(self.experts): 52 | batch_idx, nth_expert = torch.where(selected_experts == i) 53 | results[batch_idx] += \ 54 | weights[batch_idx, nth_expert, None] * expert(flattened_inputs[batch_idx]) 55 | 56 | results = results.view((*inputs.shape[:-1], results.shape[-1])) 57 | if inputs.requires_grad: 58 | self.layer_loss = self.get_layer_loss(gate_logits=gate_logits, selected_experts=selected_experts) 59 | return results 60 | 61 | 62 | class ThresholdMoeLayer(nn.Module): 63 | """ 64 | Mixture of Experts (MoE) Layer with the Threshold 65 | """ 66 | 67 | def __init__(self, experts: nn.ModuleList, gate: nn.Module, threshold: float): 68 | super().__init__() 69 | self.experts = experts 70 | self.gate = gate 71 | self.threshold = threshold 72 | self.layer_loss = None 73 | 74 | def get_layer_loss(self, gate_logits: torch.Tensor, selected_experts: torch.Tensor) -> torch.Tensor: 75 | """ 76 | Get the load balancing loss by following the Switch Transformer 77 | """ 78 | num_inputs = gate_logits.shape[0] 79 | num_experts = len(self.experts) 80 | expert_counts = torch.sum(selected_experts, dim=0) 81 | expert_fractions = expert_counts / num_inputs 82 | expert_probs = torch.sum(gate_logits, dim=0) / num_inputs 83 | layer_loss = num_experts * torch.sum(expert_fractions * expert_probs) 84 | return layer_loss 85 | 86 | def forward(self, inputs: torch.Tensor) -> torch.Tensor: 87 | """ 88 | Forward 
propagation 89 | """ 90 | flattened_inputs = inputs.view((-1, inputs.shape[-1])) 91 | gate_logits = F.softmax(self.gate(flattened_inputs), dim=-1) 92 | selected_experts = torch.ge(gate_logits, self.threshold).to(torch.float) 93 | weights = gate_logits * selected_experts 94 | weight_sums = torch.sum(weights, dim=-1, keepdim=True, dtype=inputs.dtype) 95 | weight_sums = torch.where(weight_sums == 0, torch.ones_like(weight_sums), weight_sums) 96 | weights = weights / weight_sums 97 | results = torch.zeros_like(self.experts[0](flattened_inputs)) 98 | 99 | for i, expert in enumerate(self.experts): 100 | batch_idx = torch.where(selected_experts[:, i])[0] 101 | if len(batch_idx) > 0: 102 | results[batch_idx] += weights[batch_idx, i, None] * expert(flattened_inputs[batch_idx]) 103 | 104 | results = results.view((*inputs.shape[:-1], results.shape[-1])) 105 | if inputs.requires_grad: 106 | self.layer_loss = self.get_layer_loss(gate_logits=gate_logits, selected_experts=selected_experts) 107 | return results 108 | 109 | 110 | class LoraExpert(nn.Module): 111 | """ 112 | LoRA Expert 113 | """ 114 | 115 | def __init__(self, lora_A: nn.Module, lora_B: nn.Module, lora_dropout: nn.Module, scaling: float): 116 | super().__init__() 117 | self.lora_A = lora_A 118 | self.lora_B = lora_B 119 | self.lora_dropout = lora_dropout 120 | self.scaling = scaling 121 | 122 | def forward(self, inputs: torch.Tensor) -> torch.Tensor: 123 | """ 124 | Forward propagation 125 | """ 126 | outputs = self.lora_B(self.lora_A(self.lora_dropout(inputs))) * self.scaling 127 | return outputs 128 | 129 | 130 | class MoleLayer(LoraLayer, ABC): 131 | """ 132 | MoLE Layer 133 | """ 134 | 135 | def __init__(self, base_layer: nn.Module, **kwargs): 136 | super().__init__(base_layer, **kwargs) 137 | self.lora_gating = nn.ModuleDict({}) 138 | self.moe_layer = nn.ModuleDict({}) 139 | 140 | def update_layer( 141 | self, adapter_name: str, lora_rank: int, lora_alpha: int, lora_dropout: float, init_lora_weights: bool, 142 | num_experts: int, top_k: int, threshold: float, 143 | ) -> None: 144 | """ 145 | Update the layer 146 | """ 147 | if lora_rank <= 0: 148 | raise ValueError(f"The rank `r` should be a positive integer value but the value passed is {lora_rank}.") 149 | 150 | if (top_k is not None) and (threshold is not None): 151 | raise ValueError(f"Only one of the top-k {top_k} and the threshold {threshold} can be used.") 152 | elif (top_k is None) and (threshold is None): 153 | raise ValueError(f"At least one of the top-k {top_k} and the threshold {threshold} should be used.") 154 | 155 | self.lora_rank[adapter_name] = lora_rank 156 | self.lora_alpha[adapter_name] = lora_alpha 157 | 158 | if lora_dropout > 0.0: 159 | lora_dropout_layer = nn.ModuleList(nn.Dropout(p=lora_dropout) for _ in range(num_experts)) 160 | else: 161 | lora_dropout_layer = nn.ModuleList(nn.Identity(p=lora_dropout) for _ in range(num_experts)) 162 | 163 | self.lora_dropout[adapter_name] = lora_dropout_layer 164 | self.lora_A[adapter_name] = nn.ModuleList( 165 | nn.Linear(self.in_features, lora_rank, bias=False) for _ in range(num_experts)) 166 | self.lora_B[adapter_name] = nn.ModuleList( 167 | nn.Linear(lora_rank, self.out_features, bias=False) for _ in range(num_experts)) 168 | self.scaling[adapter_name] = lora_alpha / lora_rank 169 | self.lora_gating[adapter_name] = nn.Linear(self.in_features, num_experts, bias=False) 170 | 171 | experts = nn.ModuleList(LoraExpert( 172 | self.lora_A[adapter_name][i], 173 | self.lora_B[adapter_name][i], 174 | 
self.lora_dropout[adapter_name][i], 175 | self.scaling[adapter_name], 176 | ) for i in range(num_experts)) 177 | 178 | if top_k is not None: 179 | self.moe_layer[adapter_name] = TopKMoeLayer( 180 | experts=experts, gate=self.lora_gating[adapter_name], top_k=top_k) 181 | elif threshold is not None: 182 | self.moe_layer[adapter_name] = ThresholdMoeLayer( 183 | experts=experts, gate=self.lora_gating[adapter_name], threshold=threshold) 184 | 185 | self.reset_parameters(adapter_name, init_lora_weights) 186 | self.set_adapter(self.active_adapters) 187 | 188 | def reset_parameters(self, adapter_name: str, init_lora_weights: bool) -> None: 189 | """ 190 | Reset the parameters 191 | """ 192 | if init_lora_weights is False: 193 | return 194 | elif adapter_name in self.lora_A.keys(): 195 | for i in range(len(self.lora_A[adapter_name])): 196 | nn.init.kaiming_uniform_(self.lora_A[adapter_name][i].weight, a=math.sqrt(5)) 197 | nn.init.zeros_(self.lora_B[adapter_name][i].weight) 198 | 199 | 200 | class LinearMoleLayer(nn.Module, MoleLayer): 201 | """ 202 | MoLE Implementation in a Linear Layer 203 | """ 204 | 205 | def __init__( 206 | self, 207 | base_layer: nn.Module, 208 | adapter_name: str, 209 | lora_rank: int = 0, 210 | lora_alpha: int = 1, 211 | lora_dropout: float = 0.0, 212 | init_lora_weights: bool = True, 213 | num_experts: int = 4, 214 | top_k: int = None, 215 | threshold: float = None, 216 | **kwargs, 217 | ) -> None: 218 | super().__init__() 219 | MoleLayer.__init__(self, base_layer=base_layer, **kwargs) 220 | self._active_adapter = adapter_name 221 | self.update_layer( 222 | adapter_name, lora_rank, lora_alpha, lora_dropout, init_lora_weights, num_experts, top_k, threshold) 223 | 224 | def merge(self, safe_merge: bool = False, adapter_names: Optional[list[str]] = None) -> None: 225 | """ 226 | Merge the active adapter weights inside the base weights 227 | """ 228 | pass 229 | 230 | def unmerge(self) -> None: 231 | """ 232 | Unmerge all merged adapter layers from the base weights 233 | """ 234 | pass 235 | 236 | def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor: 237 | """ 238 | Forward propagation 239 | """ 240 | previous_dtype = x.dtype 241 | result = self.base_layer(x, *args, **kwargs) 242 | 243 | for active_adapter in self.active_adapters: 244 | if active_adapter not in self.lora_A.keys(): 245 | continue 246 | 247 | moe_layer = self.moe_layer[active_adapter] 248 | x = x.to(moe_layer.experts[0].lora_A.weight.dtype) 249 | result += moe_layer(x) 250 | 251 | result = result.to(previous_dtype) 252 | return result 253 | -------------------------------------------------------------------------------- /src/mole/model.py: -------------------------------------------------------------------------------- 1 | """ 2 | MoLE Model 3 | """ 4 | from typing import Any 5 | 6 | import torch 7 | from peft.tuners.tuners_utils import BaseTunerLayer 8 | from torch import nn 9 | 10 | from .config import MoleConfig 11 | from .layer import MoleLayer, LinearMoleLayer 12 | from ..lora import LoraModel 13 | 14 | 15 | class MoleModel(LoraModel): 16 | """ 17 | MoLE (Mixture of LoRA Experts) Model 18 | """ 19 | prefix: str = "lora_" 20 | 21 | def __init__(self, model, config, adapter_name="default") -> None: 22 | super().__init__(model, config, adapter_name) 23 | 24 | def _create_and_replace( 25 | self, mole_config: MoleConfig, adapter_name: str, 26 | target: nn.Module, target_name: str, parent: nn.Module, **kwargs: Any, 27 | ) -> None: 28 | """ 29 | Inplace replacement of the target module with the 
adapter layer 30 | """ 31 | kwargs = { 32 | "lora_rank": mole_config.lora_rank, 33 | "lora_alpha": mole_config.lora_alpha, 34 | "lora_dropout": mole_config.lora_dropout, 35 | "init_lora_weights": mole_config.init_lora_weights, 36 | "num_experts": mole_config.num_experts, 37 | "top_k": mole_config.top_k, 38 | "threshold": mole_config.threshold, 39 | } 40 | 41 | if isinstance(target, MoleLayer): 42 | target.update_layer(adapter_name, **kwargs) 43 | else: 44 | new_module = self._create_new_module(adapter_name, target, **kwargs) 45 | self._replace_module(parent, target_name, new_module, target) 46 | 47 | @staticmethod 48 | def _create_new_module(adapter_name: str, target: nn.Module, **kwargs: Any) -> nn.Module: 49 | """ 50 | Create the new LoRA module for the target module 51 | """ 52 | if isinstance(target, BaseTunerLayer): 53 | target_base_layer = target.get_base_layer() 54 | else: 55 | target_base_layer = target 56 | 57 | if isinstance(target_base_layer, torch.nn.Linear): 58 | new_module = LinearMoleLayer(base_layer=target, adapter_name=adapter_name, **kwargs) 59 | else: 60 | raise ValueError( 61 | f"The target module `{target}` is not supported. " 62 | f"Currently, only the following modules are supported: `torch.nn.Linear`.") 63 | 64 | return new_module 65 | 66 | def get_aux_loss(self, adapter_name="default") -> torch.Tensor: 67 | """ 68 | Get the load balancing loss for the whole model 69 | """ 70 | model_loss = torch.tensor(0, dtype=torch.float).to(self.model.device) 71 | 72 | for name, module in self.model.named_modules(): 73 | if name.endswith('moe_layer'): 74 | layer_loss = module[adapter_name].layer_loss 75 | model_loss += layer_loss 76 | 77 | return model_loss 78 | -------------------------------------------------------------------------------- /src/peft_model.py: -------------------------------------------------------------------------------- 1 | """ 2 | PEFT Model 3 | 4 | Portions of this file are modifications based on work created and 5 | shared by the HuggingFace Inc. team and used according to terms 6 | described in the Apache License 2.0. 
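This module provides `PeftModel`, `PeftModelForCausalLM`, and `get_peft_model`, which wrap a base transformer with the tuner selected from `PEFT_TYPE_TO_MODEL_MAPPING` for a given `PeftConfig` and handle saving, loading, and switching adapters.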
7 | """ 8 | from __future__ import annotations 9 | 10 | import inspect 11 | import os 12 | from typing import Any, Dict, List, Optional, Union 13 | 14 | import torch 15 | from huggingface_hub import hf_hub_download 16 | from peft.utils import ( 17 | WEIGHTS_NAME, 18 | _set_adapter, 19 | _set_trainable, 20 | infer_device, 21 | load_peft_weights, 22 | ) 23 | from transformers import PreTrainedModel 24 | from transformers.utils import PushToHubMixin 25 | 26 | from .config import PeftConfig 27 | from .mapping import ( 28 | PEFT_TYPE_TO_CONFIG_MAPPING, 29 | PEFT_TYPE_TO_MODEL_MAPPING, 30 | ) 31 | from .utils.peft_types import TaskType 32 | from .utils.save_and_load import ( 33 | get_peft_model_state_dict, 34 | set_peft_model_state_dict 35 | ) 36 | 37 | 38 | class PeftModel(PushToHubMixin, torch.nn.Module): 39 | """ 40 | Parameter-Efficient Fine-Tuning (PEFT) Model 41 | 42 | :ivar base_model: base transformer model used for PEFT 43 | :ivar peft_config: configuration of the PEFT model 44 | :ivar modules_to_save: list of submodule names to save when saving the model 45 | """ 46 | base_model: [torch.nn.Module] 47 | peft_config: [PeftConfig] 48 | modules_to_save: [str] 49 | 50 | def __init__(self, model: PreTrainedModel, peft_config: PeftConfig, adapter_name: str = "default") -> None: 51 | """ 52 | Initialize PeftModel 53 | 54 | :param model: base transformer model used for PEFT 55 | :param peft_config: configuration of the PEFT model 56 | :param adapter_name: name of the adapter 57 | """ 58 | super().__init__() 59 | self.modules_to_save = None 60 | self.active_adapter = adapter_name 61 | self.peft_type = peft_config.peft_type 62 | peft_model = PEFT_TYPE_TO_MODEL_MAPPING[peft_config.peft_type] 63 | self.base_model = peft_model(model, {adapter_name: peft_config}, adapter_name) 64 | self.set_additional_trainable_modules(peft_config, adapter_name) 65 | 66 | if getattr(model, "is_gradient_checkpointing", True): 67 | _ = self._prepare_model_for_gradient_checkpointing(model) 68 | 69 | def __getattr__(self, name: str): 70 | """ 71 | Forward missing attributes to the wrapped module 72 | """ 73 | try: 74 | return super().__getattr__(name) 75 | except AttributeError: 76 | return getattr(self.base_model, name) 77 | 78 | @property 79 | def peft_config(self) -> Dict[str, PeftConfig]: 80 | """ 81 | Get the PEFT configuration 82 | """ 83 | return self.base_model.peft_config 84 | 85 | @peft_config.setter 86 | def peft_config(self, value: Dict[str, PeftConfig]): 87 | """ 88 | Set the PEFT configuration 89 | """ 90 | self.base_model.peft_config = value 91 | 92 | @property 93 | def active_adapters(self) -> list[str]: 94 | """ 95 | Active adapters 96 | """ 97 | try: 98 | adapters = self.base_model.active_adapters 99 | except AttributeError: 100 | adapters = self.active_adapter 101 | if isinstance(adapters, str): 102 | adapters = [adapters] 103 | return adapters 104 | 105 | def _get_base_model_class(self): 106 | """ 107 | Return the base model class 108 | """ 109 | return self.base_model.model.__class__ 110 | 111 | def get_base_model(self) -> torch.nn.Module: 112 | """ 113 | Return the base model 114 | """ 115 | return self.base_model.model 116 | 117 | def set_additional_trainable_modules(self, peft_config: PeftConfig, adapter_name: str) -> None: 118 | """ 119 | Set additional trainable modules 120 | """ 121 | if getattr(peft_config, "modules_to_save", None) is not None: 122 | if self.modules_to_save is None: 123 | self.modules_to_save = set(peft_config.modules_to_save) 124 | else: 125 | 
self.modules_to_save.update(peft_config.modules_to_save) 126 | _set_trainable(self, adapter_name) 127 | 128 | def _prepare_model_for_gradient_checkpointing(self, model: PreTrainedModel): 129 | """ 130 | Prepares the model for gradient checkpointing if necessary 131 | """ 132 | if hasattr(model, "enable_input_require_grads"): 133 | model.enable_input_require_grads() 134 | elif hasattr(model, "get_input_embeddings"): 135 | def make_inputs_require_grad(module, input, output): 136 | output.requires_grad_(True) 137 | 138 | model.get_input_embeddings().register_forward_hook(make_inputs_require_grad) 139 | return model 140 | 141 | def get_nb_trainable_parameters(self) -> tuple[int, int]: 142 | """ 143 | Return the number of trainable parameters and the number of all parameters in the model 144 | """ 145 | trainable_params = 0 146 | all_param = 0 147 | for _, param in self.named_parameters(): 148 | num_params = param.numel() 149 | all_param += num_params 150 | if param.requires_grad: 151 | trainable_params += num_params 152 | return trainable_params, all_param 153 | 154 | def print_trainable_parameters(self) -> None: 155 | """ 156 | Prints the number of trainable parameters in the model 157 | """ 158 | trainable_params, all_param = self.get_nb_trainable_parameters() 159 | print( 160 | f"trainable params: {trainable_params:,d} || all params: {all_param:,d} || " 161 | f"trainable: {trainable_params / all_param:.2%}") 162 | 163 | def save_pretrained( 164 | self, 165 | save_directory: str, 166 | selected_adapters: Optional[List[str]] = None, 167 | save_embedding_layers: Union[str, bool] = "auto", 168 | is_main_process: bool = True, 169 | **kwargs: Any, 170 | ) -> None: 171 | """ 172 | Save the adapter model and the adapter configuration files to a directory, so that it can be reloaded 173 | 174 | :param save_directory: a directory where the adapter model and configuration files will be saved 175 | :param selected_adapters: a list of adapters to be saved (default to all adapters) 176 | :param save_embedding_layers: if `True`, save the embedding layers in addition to adapter weights; 177 | if `auto`, checks the common embedding layers in config's `target_modules` when available 178 | :param is_main_process: whether the process calling this is the main process or not 179 | :param kwargs: additional keyword arguments 180 | """ 181 | if os.path.isfile(save_directory): 182 | raise ValueError(f"Provided path ({save_directory}) should be a directory, not a file") 183 | 184 | if selected_adapters is None: 185 | selected_adapters = list(self.peft_config.keys()) 186 | else: 187 | if any( 188 | selected_adapter_name not in list(self.peft_config.keys()) 189 | for selected_adapter_name in selected_adapters 190 | ): 191 | raise ValueError( 192 | f"You passed an invalid `selected_adapters` arguments, current supported adapter names are" 193 | f" {list(self.peft_config.keys())} - got {selected_adapters}." 
194 | ) 195 | 196 | if is_main_process: 197 | os.makedirs(save_directory, exist_ok=True) 198 | 199 | for adapter_name in selected_adapters: 200 | peft_config = self.peft_config[adapter_name] 201 | 202 | # Save only the trainable weights 203 | output_state_dict = get_peft_model_state_dict( 204 | self, 205 | state_dict=kwargs.get("state_dict", None), 206 | adapter_name=adapter_name, 207 | save_embedding_layers=save_embedding_layers, 208 | ) 209 | output_dir = os.path.join(save_directory, adapter_name) if adapter_name != "default" else save_directory 210 | os.makedirs(output_dir, exist_ok=True) 211 | 212 | if is_main_process: 213 | torch.save(output_state_dict, os.path.join(output_dir, WEIGHTS_NAME)) 214 | 215 | # Save the config and change the inference mode to `True` 216 | if peft_config.base_model_name_or_path is None: 217 | peft_config.base_model_name_or_path = ( 218 | self.base_model.model.__dict__.get("name_or_path", None) 219 | ) 220 | 221 | inference_mode = peft_config.inference_mode 222 | peft_config.inference_mode = True 223 | if is_main_process: 224 | peft_config.save_pretrained(output_dir) 225 | peft_config.inference_mode = inference_mode 226 | 227 | @classmethod 228 | def from_pretrained( 229 | cls, 230 | model: torch.nn.Module, 231 | model_id: Union[str, os.PathLike], 232 | adapter_name: str = "default", 233 | is_trainable: bool = False, 234 | config: Optional[PeftConfig] = None, 235 | **kwargs: Any, 236 | ) -> PeftModel: 237 | """ 238 | Instantiate a PEFT model from a pretrained model and loaded PEFT weights (Note that the passed `model` 239 | may be modified inplace.) 240 | 241 | :param model: the transformer model to be adapted 242 | :param model_id: the name of the PEFT configuration to use 243 | :param adapter_name: the name of the adapter to be loaded 244 | :param is_trainable: whether the adapter should be trainable or not 245 | :param config: the configuration object to use instead of an automatically loaded configuration 246 | :param kwargs: additional keyword arguments passed along to the specific PEFT configuration class 247 | :return: the PEFT model 248 | """ 249 | # Load the config 250 | if config is None: 251 | config = PEFT_TYPE_TO_CONFIG_MAPPING[ 252 | PeftConfig._get_peft_type( 253 | model_id, 254 | subfolder=kwargs.get("subfolder", None), 255 | revision=kwargs.get("revision", None), 256 | cache_dir=kwargs.get("cache_dir", None), 257 | use_auth_token=kwargs.get("use_auth_token", None), 258 | token=kwargs.get("token", None), 259 | ) 260 | ].from_pretrained(model_id, **kwargs) 261 | elif isinstance(config, PeftConfig): 262 | config.inference_mode = not is_trainable 263 | else: 264 | raise ValueError(f"The input config must be a PeftConfig, got {config.__class__}") 265 | 266 | config.inference_mode = not is_trainable 267 | model = MODEL_TYPE_TO_PEFT_MODEL_MAPPING[config.task_type](model, config, adapter_name) 268 | model.load_adapter(model_id, adapter_name, is_trainable=is_trainable, **kwargs) 269 | return model 270 | 271 | def load_adapter(self, model_id: str, adapter_name: str, is_trainable: bool = False, **kwargs: Any): 272 | """ 273 | Load a trained adapter into the model (The new adapter is not automatically set as the active adapter. 274 | Use `PeftModel.set_adapter` to set the active adapter.) 
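Illustrative usage with a hypothetical adapter directory: `model.load_adapter("outputs/extra-adapter", adapter_name="extra")` followed by `model.set_adapter("extra")`.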
275 | 276 | :param model_id: the name of the adapter to be added 277 | :param adapter_name: the configuration of the adapter to be added 278 | :param is_trainable: whether the adapter should be trainable or not 279 | :param kwargs: additional arguments to modify the way the adapter is loaded 280 | :return: 281 | """ 282 | hf_hub_download_kwargs, kwargs = self._split_kwargs(kwargs) 283 | torch_device = infer_device() 284 | 285 | if adapter_name not in self.peft_config: 286 | # Load the config 287 | peft_config = PEFT_TYPE_TO_CONFIG_MAPPING[ 288 | PeftConfig._get_peft_type( 289 | model_id, 290 | **hf_hub_download_kwargs, 291 | ) 292 | ].from_pretrained( 293 | model_id, 294 | **hf_hub_download_kwargs, 295 | ) 296 | peft_config.inference_mode = not is_trainable 297 | self.add_adapter(adapter_name, peft_config) 298 | 299 | # Load the weights into the model 300 | adapters_weights = load_peft_weights(model_id, device=torch_device, **hf_hub_download_kwargs) 301 | load_result = set_peft_model_state_dict(self, adapters_weights, adapter_name=adapter_name) 302 | 303 | # Set model in evaluation mode to deactivate dropout modules by default 304 | if not is_trainable: 305 | self.eval() 306 | 307 | return load_result 308 | 309 | @classmethod 310 | def _split_kwargs(cls, kwargs: Dict[str, Any]): 311 | """ 312 | Split keyword arguments 313 | """ 314 | _kwargs_not_in_hf_hub_download_signature = ("use_auth_token",) 315 | hf_hub_download_kwargs = {} 316 | other_kwargs = {} 317 | 318 | for key, value in kwargs.items(): 319 | if ( 320 | key in inspect.signature(hf_hub_download).parameters 321 | or key in _kwargs_not_in_hf_hub_download_signature 322 | ): 323 | hf_hub_download_kwargs[key] = value 324 | else: 325 | other_kwargs[key] = value 326 | 327 | return hf_hub_download_kwargs, other_kwargs 328 | 329 | def add_adapter(self, adapter_name: str, peft_config: PeftConfig) -> None: 330 | """ 331 | Add an adapter to the model based on the passed configuration (The new adapter is not automatically set as 332 | the active adapter. Use `PeftModel.set_adapter` to set the active adapter.) 333 | 334 | :param adapter_name: the name of the adapter to be added 335 | :param peft_config: the configuration of the adapter to be added 336 | """ 337 | if peft_config.peft_type != self.peft_type: 338 | raise ValueError( 339 | f"Cannot combine adapters with different peft types. " 340 | f"Found {self.peft_type} and {peft_config.peft_type}.") 341 | 342 | try: 343 | self.peft_config[adapter_name] = peft_config 344 | self.base_model.inject_adapter(self.base_model.model, adapter_name) 345 | except Exception: 346 | if adapter_name in self.peft_config: 347 | del self.peft_config[adapter_name] 348 | raise 349 | 350 | self.set_additional_trainable_modules(peft_config, adapter_name) 351 | 352 | def set_adapter(self, adapter_name: str) -> None: 353 | """ 354 | Sets the active adapter (Only one adapter can be active at a time.) 
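For example (names illustrative), after `model.add_adapter("extra", peft_config)`, calling `model.set_adapter("extra")` makes `"extra"` the adapter used in subsequent forward passes.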
355 | 356 | :param adapter_name: the name of the adapter to be set as active 357 | """ 358 | if adapter_name not in self.peft_config: 359 | raise ValueError(f"Adapter {adapter_name} not found.") 360 | self.active_adapter = adapter_name 361 | if not self.peft_config[adapter_name].is_prompt_learning: 362 | self.base_model.set_adapter(adapter_name) 363 | _set_adapter(self, adapter_name) 364 | 365 | def forward(self, *args: Any, **kwargs: Any) -> torch.Tensor: 366 | """ 367 | Forward pass of the model 368 | """ 369 | return self.get_base_model()(*args, **kwargs) 370 | 371 | 372 | class PeftModelForCausalLM(PeftModel): 373 | """ 374 | PEFT Model for Causal Language Modeling 375 | """ 376 | 377 | def __init__(self, model: PreTrainedModel, peft_config: PeftConfig, adapter_name: str = "default") -> None: 378 | """ 379 | Initialize PeftModelForCausalLM 380 | 381 | :param model: base transformer model 382 | :param peft_config: PEFT configuration 383 | :param adapter_name: adapter name 384 | """ 385 | super().__init__(model, peft_config, adapter_name) 386 | self.base_model_prepare_inputs_for_generation = self.base_model.prepare_inputs_for_generation 387 | 388 | def forward( 389 | self, 390 | input_ids=None, 391 | attention_mask=None, 392 | inputs_embeds=None, 393 | labels=None, 394 | output_attentions=None, 395 | output_hidden_states=None, 396 | return_dict=None, 397 | task_ids=None, 398 | **kwargs, 399 | ) -> torch.Tensor: 400 | """ 401 | Forward function 402 | """ 403 | return self.base_model( 404 | input_ids=input_ids, 405 | attention_mask=attention_mask, 406 | inputs_embeds=inputs_embeds, 407 | labels=labels, 408 | output_attentions=output_attentions, 409 | output_hidden_states=output_hidden_states, 410 | return_dict=return_dict, 411 | **kwargs, 412 | ) 413 | 414 | def generate(self, **kwargs): 415 | """ 416 | Generate the text 417 | """ 418 | self.base_model.prepare_inputs_for_generation = self.prepare_inputs_for_generation 419 | if hasattr(self.base_model, "model"): 420 | self.base_model.model.generation_config = self.generation_config 421 | else: 422 | self.base_model.generation_config = self.generation_config 423 | try: 424 | outputs = self.base_model.generate(**kwargs) 425 | except: 426 | self.base_model.prepare_inputs_for_generation = self.base_model_prepare_inputs_for_generation 427 | raise 428 | else: 429 | self.base_model.prepare_inputs_for_generation = self.base_model_prepare_inputs_for_generation 430 | return outputs 431 | 432 | def prepare_inputs_for_generation(self, *args, **kwargs): 433 | """ 434 | Prepare inputs for text generation 435 | """ 436 | model_kwargs = self.base_model_prepare_inputs_for_generation(*args, **kwargs) 437 | return model_kwargs 438 | 439 | 440 | MODEL_TYPE_TO_PEFT_MODEL_MAPPING = { 441 | TaskType.CAUSAL_LM: PeftModelForCausalLM, 442 | } 443 | 444 | 445 | def get_peft_model( 446 | model: PreTrainedModel, peft_config: PeftConfig, adapter_name: str = "default", 447 | ) -> PeftModel: 448 | """ 449 | Return a PEFT model object from a pre-trained model and a PEFT config 450 | 451 | :param model: model to be wrapped 452 | :param peft_config: configuration containing the parameters of the PEFT model 453 | :param adapter_name: name of the adapter to be injected 454 | :return: 455 | """ 456 | 457 | peft_config.base_model_name_or_path = model.__dict__.get("name_or_path", None) 458 | if peft_config.task_type not in MODEL_TYPE_TO_PEFT_MODEL_MAPPING.keys(): 459 | return PeftModel(model, peft_config, adapter_name=adapter_name) 460 | return 
MODEL_TYPE_TO_PEFT_MODEL_MAPPING[peft_config.task_type](model, peft_config, adapter_name=adapter_name) 461 | -------------------------------------------------------------------------------- /src/trainer.py: -------------------------------------------------------------------------------- 1 | """ 2 | Trainer 3 | """ 4 | import os 5 | from typing import Optional 6 | 7 | import torch 8 | from torch import nn 9 | from transformers import Trainer 10 | from transformers.trainer import TRAINING_ARGS_NAME, logger 11 | 12 | from .peft_model import PeftModel 13 | 14 | 15 | class PeftTrainer(Trainer): 16 | """ 17 | Trainer for the PEFT Model 18 | """ 19 | 20 | def __init__(self, aux_loss_coeff=1e-2, **kwargs): 21 | """ 22 | Initialize PeftTrainer 23 | 24 | :param aux_loss_coeff: a coefficient for the load balancing loss in Mixture-of-Experts (MoE) models 25 | :param kwargs: additional keyword arguments 26 | """ 27 | super().__init__(**kwargs) 28 | self.loss_fn = nn.CrossEntropyLoss() 29 | self.aux_loss_coeff = aux_loss_coeff 30 | 31 | def _save(self, output_dir: Optional[str] = None, state_dict=None): 32 | """ 33 | Save the model and tokenizer 34 | """ 35 | output_dir = output_dir if output_dir is not None else self.args.output_dir 36 | os.makedirs(output_dir, exist_ok=True) 37 | logger.info(f"Saving model checkpoint to {output_dir}") 38 | self.model.save_pretrained(output_dir, state_dict=state_dict) 39 | 40 | if self.tokenizer is not None: 41 | self.tokenizer.save_pretrained(output_dir) 42 | 43 | torch.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME)) 44 | 45 | def compute_loss(self, model: PeftModel, inputs, return_outputs=False): 46 | """ 47 | Compute the loss by the trainer 48 | """ 49 | outputs = model(**inputs) 50 | if "loss" in outputs: 51 | loss = outputs.get("loss") 52 | else: 53 | logits = outputs.get("logits") 54 | labels = inputs.get("labels") 55 | loss = self.loss_fn(logits.view(-1, logits.shape[-1]), labels.view(-1)) 56 | if hasattr(model, 'get_aux_loss'): 57 | aux_loss = model.get_aux_loss() 58 | loss += self.aux_loss_coeff * aux_loss 59 | return (loss, outputs) if return_outputs else loss 60 | -------------------------------------------------------------------------------- /src/utils/peft_types.py: -------------------------------------------------------------------------------- 1 | """ 2 | PEFT and Task Types 3 | 4 | Portions of this file are modifications based on work created and 5 | shared by the HuggingFace Inc. team and used according to terms 6 | described in the Apache License 2.0. 7 | """ 8 | import enum 9 | 10 | 11 | class PeftType(str, enum.Enum): 12 | """ 13 | PEFT Adapter Types 14 | """ 15 | LORA = "LORA" 16 | MOLE = "MOLE" 17 | ADAMOLE = "ADAMOLE" 18 | 19 | 20 | class TaskType(str, enum.Enum): 21 | """ 22 | PEFT Task Type 23 | """ 24 | CAUSAL_LM = "CAUSAL_LM" 25 | -------------------------------------------------------------------------------- /src/utils/save_and_load.py: -------------------------------------------------------------------------------- 1 | """ 2 | Saving and Loading Models 3 | 4 | Portions of this file are modifications based on work created and 5 | shared by the HuggingFace Inc. team and used according to terms 6 | described in the Apache License 2.0. 
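`get_peft_model_state_dict` filters a model state dict down to the adapter weights (keys containing the `lora_` prefix, plus any `modules_to_save` and, optionally, embedding layers) and strips the adapter name from the keys; `set_peft_model_state_dict` re-inserts the adapter name and loads the weights with `strict=False`.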
7 | """ 8 | import warnings 9 | 10 | from peft.utils.other import EMBEDDING_LAYER_NAMES 11 | from peft.utils.save_and_load import has_valid_embedding_base_layer, get_embedding_layer_name 12 | 13 | from ..utils.peft_types import PeftType 14 | 15 | 16 | def get_peft_model_state_dict( 17 | model, state_dict: dict = None, adapter_name: str = "default", unwrap_compiled: bool = False, 18 | save_embedding_layers: str | bool = "auto" 19 | ): 20 | """ 21 | Get the state dict of the PEFT model 22 | 23 | :param model: the PEFT model 24 | :param state_dict: the state dict of the model (If not provided, the state dict of the passed model will be used.) 25 | :param adapter_name: the name of the adapter whose state dict should be returned 26 | :param unwrap_compiled: whether to unwrap the model if `torch.compile` was used 27 | :param save_embedding_layers: if `True`, save the embedding layers in addition to adapter weights; 28 | if `auto`, checks the common embedding layers in config's `target_modules` when available 29 | :return: 30 | """ 31 | if unwrap_compiled: 32 | model = getattr(model, "_orig_mod", model) 33 | 34 | config = model.peft_config[adapter_name] 35 | 36 | if state_dict is None: 37 | state_dict = model.state_dict() 38 | 39 | if config.peft_type in (PeftType.LORA, PeftType.MOLE, PeftType.ADAMOLE): 40 | bias = config.bias 41 | if bias == "none": 42 | to_return = {k: state_dict[k] for k in state_dict if "lora_" in k} 43 | elif bias == "all": 44 | to_return = {k: state_dict[k] for k in state_dict if "lora_" in k or "bias" in k} 45 | elif bias == "lora_only": 46 | to_return = {} 47 | for k in state_dict: 48 | if "lora_" in k: 49 | to_return[k] = state_dict[k] 50 | bias_name = k.split("lora_")[0] + "bias" 51 | if bias_name in state_dict: 52 | to_return[bias_name] = state_dict[bias_name] 53 | else: 54 | raise NotImplementedError 55 | to_return = {k: v for k, v in to_return.items() if (("lora_" in k and adapter_name in k) or ("bias" in k))} 56 | else: 57 | raise NotImplementedError 58 | 59 | if getattr(model, "modules_to_save", None) is not None: 60 | for key, value in state_dict.items(): 61 | if any(f"{module_name}.modules_to_save.{adapter_name}" in key for module_name in model.modules_to_save): 62 | to_return[key.replace("modules_to_save.", "")] = value 63 | 64 | # Check the common embedding layers in `target_modules` to reset `save_embedding_layers` if necessary 65 | if ( 66 | save_embedding_layers == "auto" 67 | and hasattr(config, "target_modules") 68 | and any(k in config.target_modules for k in EMBEDDING_LAYER_NAMES) 69 | ): 70 | warnings.warn("Setting `save_embedding_layers` to `True` as embedding layers found in `target_modules`.") 71 | save_embedding_layers = True 72 | elif save_embedding_layers == "auto": 73 | save_embedding_layers = False 74 | 75 | if save_embedding_layers and hasattr(model, "get_input_embeddings"): 76 | for layer in [model.get_input_embeddings(), model.get_output_embeddings()]: 77 | if config.is_prompt_learning or has_valid_embedding_base_layer(layer): 78 | embedding_module_name = get_embedding_layer_name(model, layer, config.is_prompt_learning) 79 | if embedding_module_name: 80 | to_return.update({k: v for k, v in state_dict.items() if embedding_module_name in k}) 81 | elif save_embedding_layers: 82 | warnings.warn("Could not identify embedding layer(s) because the model is not a model in transformers.") 83 | 84 | to_return = {k.replace(f".{adapter_name}", ""): v for k, v in to_return.items()} 85 | return to_return 86 | 87 | 88 | def set_peft_model_state_dict(model, 
peft_model_state_dict: dict, adapter_name="default"): 89 | """ 90 | Set the state dict of the PEFT model 91 | 92 | :param model: the PEFT model. 93 | :param peft_model_state_dict: the state dict of the PEFT model 94 | :param adapter_name: the adapter name 95 | :return: 96 | """ 97 | config = model.peft_config[adapter_name] 98 | state_dict = {} 99 | 100 | if getattr(model, "modules_to_save", None) is not None: 101 | for key, value in peft_model_state_dict.items(): 102 | if any(module_name in key for module_name in model.modules_to_save): 103 | for module_name in model.modules_to_save: 104 | if module_name in key: 105 | key = key.replace(module_name, f"{module_name}.modules_to_save.{adapter_name}") 106 | break 107 | state_dict[key] = value 108 | else: 109 | state_dict = peft_model_state_dict 110 | 111 | if config.peft_type in (PeftType.LORA, PeftType.MOLE, PeftType.ADAMOLE): 112 | peft_model_state_dict = {} 113 | parameter_prefix = { 114 | PeftType.LORA: "lora_", 115 | PeftType.MOLE: "lora_", 116 | PeftType.ADAMOLE: "lora_", 117 | }[config.peft_type] 118 | for k, v in state_dict.items(): 119 | if parameter_prefix in k: 120 | suffix = k.split(parameter_prefix)[1] 121 | if "." in suffix: 122 | suffix_to_replace = ".".join(suffix.split(".")[1:]) 123 | k = k.replace(suffix_to_replace, f"{adapter_name}.{suffix_to_replace}") 124 | else: 125 | k = f"{k}.{adapter_name}" 126 | peft_model_state_dict[k] = v 127 | else: 128 | peft_model_state_dict[k] = v 129 | else: 130 | raise NotImplementedError 131 | 132 | load_result = model.load_state_dict(peft_model_state_dict, strict=False) 133 | return load_result 134 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | """ 2 | Testing LLMs on Benchmarks 3 | """ 4 | import argparse 5 | import json 6 | import os 7 | import re 8 | 9 | import pandas as pd 10 | import torch 11 | import transformers 12 | from tqdm import tqdm 13 | from transformers import ( 14 | AutoTokenizer, 15 | AutoModelForCausalLM, 16 | TextGenerationPipeline, 17 | GenerationConfig, 18 | ) 19 | from transformers.pipelines.pt_utils import KeyDataset 20 | 21 | from data import get_formatted_datasets 22 | from src import PeftConfig, PeftModelForCausalLM 23 | 24 | transformers.set_seed(0) 25 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 26 | 27 | 28 | def predict_choices(examples): 29 | """ 30 | Predict choices 31 | """ 32 | prompts = examples['text'] 33 | inputs = tokenizer(prompts, return_tensors="pt", padding=True) 34 | inputs = {key: value.to(device) for key, value in inputs.items()} 35 | 36 | with torch.no_grad(): 37 | outputs = model(**inputs) 38 | 39 | logits = outputs.logits[:, -1, :] 40 | choices = [chr(ord('A') + i) for i in range(max(examples['num_choices']))] 41 | choice_ids = [tokenizer.encode(choice, add_special_tokens=False)[-1] for choice in choices] 42 | 43 | predicted_ids = torch.argmax(logits[:, choice_ids], dim=-1) 44 | predictions = [choices[predicted_id] for predicted_id in predicted_ids.cpu().numpy()] 45 | examples['prediction'] = predictions 46 | 47 | return examples 48 | 49 | 50 | if __name__ == '__main__': 51 | # Add arguments 52 | parser = argparse.ArgumentParser( 53 | description='Fine-tuning LLMs on training data.', 54 | formatter_class=argparse.ArgumentDefaultsHelpFormatter, 55 | fromfile_prefix_chars='@') 56 | parser.add_argument( 57 | '--model_path', type=str, default='outputs/llama-2-7b-hf-adamole-the8-commonsense-qa', 
58 | help='huggingface model id or local model path') 59 | parser.add_argument( 60 | '--data_path', type=str, default='tau/commonsense_qa', 61 | help='huggingface data id or local data path') 62 | parser.add_argument( 63 | '--max_new_tokens', type=int, default=16, 64 | help='maximum number of new tokens') 65 | parser.add_argument( 66 | '--batch_size', type=int, default=16, 67 | help='batch size in the pipeline') 68 | parser.add_argument( 69 | '--logits', default=False, action='store_true', 70 | help='checking choice logits instead of generated texts') 71 | 72 | # Parse arguments 73 | args = parser.parse_args() 74 | model_path = args.model_path 75 | data_path = args.data_path 76 | model_name = os.path.basename(model_path).lower() 77 | data_name = os.path.basename(data_path).lower() 78 | max_new_tokens = args.max_new_tokens 79 | batch_size = args.batch_size 80 | if data_name in ['openbookqa', 'ai2_arc']: 81 | split = 'test' 82 | else: 83 | split = 'validation' 84 | 85 | # Load and format datasets 86 | formatted_datasets = get_formatted_datasets(data_path=data_path, prompt_only=True) 87 | 88 | # Load the configuration and model 89 | peft_config = PeftConfig.from_pretrained(model_path) 90 | base_model = AutoModelForCausalLM.from_pretrained( 91 | peft_config.base_model_name_or_path, 92 | # torch_dtype=torch.bfloat16, 93 | ) 94 | tokenizer = AutoTokenizer.from_pretrained( 95 | peft_config.base_model_name_or_path, 96 | padding_side="left", 97 | ) 98 | tokenizer.pad_token = tokenizer.eos_token 99 | model = PeftModelForCausalLM.from_pretrained(model=base_model, model_id=model_path) 100 | model.to(device) 101 | print(f'Model loaded from {model_path}') 102 | print(f'Model: {model}') 103 | 104 | if not args.logits: 105 | # Build the pipeline 106 | generation_config = GenerationConfig( 107 | max_new_tokens=max_new_tokens, 108 | do_sample=False, 109 | ) 110 | pipeline = TextGenerationPipeline( 111 | model=model, 112 | tokenizer=tokenizer, 113 | device=device, 114 | ) 115 | 116 | # Get the model responses 117 | responses = [] 118 | for response in tqdm( 119 | pipeline( 120 | KeyDataset(formatted_datasets[split], 'text'), 121 | generation_config=generation_config, 122 | return_full_text=False, 123 | batch_size=batch_size, 124 | ), 125 | total=len(formatted_datasets[split]), 126 | ): 127 | responses.append(response[0]['generated_text']) 128 | 129 | # Print one response 130 | print(f'Response example:\n{responses[0]}') 131 | 132 | # Get the results 133 | df = formatted_datasets[split].to_pandas() 134 | df['response'] = responses 135 | df['prediction'] = df['response'].str.extract(pat=r'\b([A-Z])\b')[0] 136 | 137 | else: 138 | # Get predictions 139 | dataset_with_predictions = formatted_datasets[split].map( 140 | predict_choices, batched=True, batch_size=batch_size) 141 | df = dataset_with_predictions.to_pandas() 142 | 143 | # Save the results 144 | result_path = os.path.join(model_path, f'{split}_results.csv') 145 | df.to_csv(result_path, index=False) 146 | print(f'Results saved to {result_path}') 147 | 148 | # Compute evaluation metrics 149 | metrics = {} 150 | for _data_name in df['data_name'].unique(): 151 | df_set = df[df['data_name'] == _data_name] 152 | accuracy = pd.Series(df_set['answer'] == df_set['prediction']).mean() 153 | print(f'Accuracy of {_data_name}: {accuracy:.2%}') 154 | metrics['accuracy_' + re.sub(r'\W', '_', _data_name)] = accuracy 155 | 156 | # Save evaluation metrics 157 | metric_path = os.path.join(model_path, f'{split}_metrics.json') 158 | with open(metric_path, 'w') as file: 159 | 
json.dump(metrics, file) 160 | print(f'Metrics saved to {metric_path}') 161 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | """ 2 | Fine-Tuning LLMs on Tasks 3 | """ 4 | import argparse 5 | import os 6 | import re 7 | 8 | import torch 9 | import transformers 10 | from transformers import ( 11 | AutoTokenizer, 12 | AutoModelForCausalLM, 13 | TrainingArguments, 14 | DataCollatorForLanguageModeling, 15 | ) 16 | 17 | from data import get_formatted_datasets 18 | from src import ( 19 | TaskType, 20 | LoraConfig, 21 | MoleConfig, 22 | AdaMoleConfig, 23 | PeftTrainer, 24 | PeftModelForCausalLM, 25 | ) 26 | 27 | transformers.set_seed(0) 28 | 29 | if __name__ == '__main__': 30 | # Add arguments 31 | parser = argparse.ArgumentParser( 32 | description='Fine-tuning LLMs on training data.', 33 | formatter_class=argparse.ArgumentDefaultsHelpFormatter, 34 | fromfile_prefix_chars='@') 35 | parser.add_argument( 36 | '--model_path', type=str, default='meta-llama/Llama-2-7b-hf', 37 | help='huggingface model id or local model path') 38 | parser.add_argument( 39 | '--data_path', type=str, default='tau/commonsense_qa', 40 | help='huggingface data id or local data path') 41 | parser.add_argument( 42 | '--peft_type', type=str, default='lora', choices=['lora', 'mole', 'adamole'], 43 | help='peft model type to be fine-tuned') 44 | parser.add_argument( 45 | '--lora_rank', type=int, default=32, 46 | help='lora rank if the peft type is lora or total lora rank if moe') 47 | parser.add_argument( 48 | '--target_modules', type=str, default=['q_proj', 'v_proj'], nargs='+', 49 | help='target modules in lora layers') 50 | parser.add_argument( 51 | '--num_experts', type=int, default=1, 52 | help='number of experts in each moe layer') 53 | parser.add_argument( 54 | '--top_k', type=int, default=None, 55 | help='top-k experts in moe (only one of top_k or threshold can be used)') 56 | parser.add_argument( 57 | '--threshold', type=float, default=None, 58 | help='threshold for expert gating in moe (only one of top_k or threshold can be used)') 59 | parser.add_argument( 60 | '--max_length', type=int, default=256, 61 | help='maximum number of tokens') 62 | parser.add_argument( 63 | '--batch_size', type=int, default=16, 64 | help='batch size in the trainer') 65 | parser.add_argument( 66 | '--gradient_accumulation_steps', type=int, default=1, 67 | help='gradient accumulation steps') 68 | parser.add_argument( 69 | '--num_train_epochs', type=int, default=1, 70 | help='number of training epochs') 71 | parser.add_argument( 72 | '--learning_rate', type=float, default=1e-4, 73 | help='learning rate for training') 74 | parser.add_argument( 75 | '--lr_scheduler_type', type=str, default="constant_with_warmup", 76 | help='learning rate scheduler type') 77 | parser.add_argument( 78 | '--warmup_steps', type=int, default=200, 79 | help='number of warmup steps for training') 80 | parser.add_argument( 81 | '--weight_decay', type=float, default=0.0, 82 | help='weight decay') 83 | parser.add_argument( 84 | '--aux_loss_coeff', type=float, default=None, 85 | help='auxiliary loss coefficient for moe') 86 | 87 | # Parse arguments 88 | args = parser.parse_args() 89 | print(f'Arguments: {args}') 90 | model_path = args.model_path 91 | data_path = args.data_path 92 | model_name = os.path.basename(model_path).lower() 93 | data_name = os.path.basename(data_path).lower() 94 | peft_type = args.peft_type 95 | num_experts = 
args.num_experts 96 | max_length = args.max_length 97 | lora_rank = args.lora_rank if peft_type == 'lora' else args.lora_rank // num_experts 98 | lora_alpha = 16 99 | lora_dropout = 0.05 100 | peft_type_name = peft_type 101 | if args.top_k is not None: 102 | peft_type_name += f'-top{args.top_k}' 103 | if args.threshold is not None: 104 | threshold_name = int(1 / args.threshold) 105 | peft_type_name += f'-the{threshold_name}' 106 | output_dir = os.path.join('outputs', re.sub(r'[^0-9a-zA-Z]', '-', f'{model_name}-{peft_type_name}-{data_name}')) 107 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 108 | 109 | # Load and format datasets 110 | formatted_datasets = get_formatted_datasets(data_path=data_path, prompt_only=False) 111 | 112 | # Load the tokenizer 113 | tokenizer = AutoTokenizer.from_pretrained( 114 | model_path, 115 | padding_side="left", 116 | # add_bos_token=True, 117 | add_eos_token=True, 118 | ) 119 | tokenizer.pad_token = tokenizer.eos_token 120 | 121 | # Tokenize datasets 122 | tokenize_text = lambda examples: tokenizer( 123 | examples["text"], 124 | truncation=True, 125 | max_length=max_length, 126 | # padding=True, 127 | # return_tensors="pt", 128 | ) 129 | tokenized_datasets = formatted_datasets.map( 130 | tokenize_text, 131 | batched=True, 132 | remove_columns=formatted_datasets["train"].column_names, 133 | ) 134 | print(f'Tokenized datasets: {tokenized_datasets}') 135 | 136 | # Set the data collator 137 | data_collator = DataCollatorForLanguageModeling( 138 | tokenizer, mlm=False, pad_to_multiple_of=8, return_tensors="pt") 139 | 140 | # Load the base model 141 | base_model = AutoModelForCausalLM.from_pretrained( 142 | model_path, 143 | # torch_dtype=torch.bfloat16, 144 | # device_map="auto", 145 | ) 146 | print(f'Base model loaded from {model_path}') 147 | print(f'Base model: {base_model}') 148 | 149 | # Get the PEFT model 150 | if peft_type == 'lora': 151 | peft_config = LoraConfig( 152 | lora_rank=lora_rank, 153 | lora_alpha=lora_alpha, 154 | lora_dropout=lora_dropout, 155 | target_modules=args.target_modules, 156 | task_type=TaskType.CAUSAL_LM, 157 | bias="none", 158 | ) 159 | elif peft_type == 'mole': 160 | peft_config = MoleConfig( 161 | lora_rank=lora_rank, 162 | lora_alpha=lora_alpha, 163 | lora_dropout=lora_dropout, 164 | target_modules=args.target_modules, 165 | task_type=TaskType.CAUSAL_LM, 166 | bias="none", 167 | num_experts=num_experts, 168 | top_k=args.top_k, 169 | threshold=args.threshold, 170 | ) 171 | elif peft_type == 'adamole': 172 | peft_config = AdaMoleConfig( 173 | lora_rank=lora_rank, 174 | lora_alpha=lora_alpha, 175 | lora_dropout=lora_dropout, 176 | target_modules=args.target_modules, 177 | task_type=TaskType.CAUSAL_LM, 178 | bias="none", 179 | num_experts=num_experts, 180 | max_threshold=args.threshold, 181 | ) 182 | else: 183 | raise KeyError(f'Unsupported PEFT type: {peft_type}') 184 | 185 | model = PeftModelForCausalLM(base_model, peft_config) 186 | model.enable_input_require_grads() 187 | model.print_trainable_parameters() 188 | print(f'Model: {model}') 189 | 190 | # Set the trainer 191 | training_args = TrainingArguments( 192 | output_dir=output_dir, 193 | overwrite_output_dir=True, 194 | group_by_length=True, 195 | remove_unused_columns=False, 196 | logging_strategy="steps", 197 | logging_steps=10, 198 | evaluation_strategy="steps", 199 | eval_steps=200, 200 | save_strategy="epoch", 201 | # save_steps=1000, 202 | optim="adamw_torch", 203 | per_device_train_batch_size=args.batch_size, 204 | 
per_device_eval_batch_size=args.batch_size, 205 | gradient_accumulation_steps=args.gradient_accumulation_steps, 206 | gradient_checkpointing=False, 207 | num_train_epochs=args.num_train_epochs, 208 | learning_rate=args.learning_rate, 209 | lr_scheduler_type=args.lr_scheduler_type, 210 | warmup_steps=args.warmup_steps, 211 | weight_decay=args.weight_decay, 212 | # fp16=True, 213 | seed=0, 214 | data_seed=0, 215 | report_to=["tensorboard"], 216 | ) 217 | trainer = PeftTrainer( 218 | model=model, 219 | tokenizer=tokenizer, 220 | args=training_args, 221 | data_collator=data_collator, 222 | train_dataset=tokenized_datasets["train"], 223 | eval_dataset=tokenized_datasets["validation"], 224 | aux_loss_coeff=args.aux_loss_coeff, 225 | ) 226 | with open(os.path.join(output_dir, 'training_args.json'), 'w') as output_file: 227 | output_file.write(training_args.to_json_string()) 228 | 229 | # Train the model 230 | model.config.use_cache = False 231 | trainer.train() 232 | model.config.use_cache = True 233 | 234 | # Save the model 235 | trainer.save_model() 236 | trainer.save_state() 237 | --------------------------------------------------------------------------------
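Minimal end-to-end sketch of how the pieces above fit together, following the `train.py` and `test.py` flows; the base model id, adapter directory, and hyperparameter values here are illustrative assumptions, and the fine-tuning step itself is left to `PeftTrainer` as in `train.py`:

# Wrap a base model with AdaMoLE adapters (mirrors train.py; values are illustrative)
from transformers import AutoModelForCausalLM

from src import AdaMoleConfig, PeftConfig, PeftModelForCausalLM, TaskType

base_model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf")
peft_config = AdaMoleConfig(
    lora_rank=8,  # rank of each LoRA expert (train.py divides the total rank by num_experts)
    lora_alpha=16,
    lora_dropout=0.05,
    target_modules=["q_proj", "v_proj"],
    task_type=TaskType.CAUSAL_LM,
    bias="none",
    num_experts=4,
    max_threshold=0.25,  # threshold cap for the adaptive gating (illustrative value)
)
model = PeftModelForCausalLM(base_model, peft_config)
model.print_trainable_parameters()
# ... fine-tune with PeftTrainer(..., aux_loss_coeff=1e-2) as in train.py, then:
model.save_pretrained("outputs/example-adapter")

# Reload the saved adapter for evaluation (mirrors test.py)
peft_config = PeftConfig.from_pretrained("outputs/example-adapter")
base_model = AutoModelForCausalLM.from_pretrained(peft_config.base_model_name_or_path)
model = PeftModelForCausalLM.from_pretrained(model=base_model, model_id="outputs/example-adapter")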