├── .gitignore
├── LICENSE
├── README.md
├── configs
│   ├── llama_adamole_csr_test.config
│   └── llama_adamole_csr_train.config
├── data.py
├── images
│   └── adamole.png
├── poster.pdf
├── requirements.txt
├── src
│   ├── __init__.py
│   ├── adamole
│   │   ├── __init__.py
│   │   ├── config.py
│   │   ├── layer.py
│   │   └── model.py
│   ├── config.py
│   ├── lora
│   │   ├── __init__.py
│   │   ├── config.py
│   │   ├── layer.py
│   │   └── model.py
│   ├── mapping.py
│   ├── mole
│   │   ├── __init__.py
│   │   ├── config.py
│   │   ├── layer.py
│   │   └── model.py
│   ├── peft_model.py
│   ├── trainer.py
│   └── utils
│       ├── peft_types.py
│       └── save_and_load.py
├── test.py
└── train.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | share/python-wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 | MANIFEST
28 |
29 | # PyInstaller
30 | # Usually these files are written by a python script from a template
31 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
32 | *.manifest
33 | *.spec
34 |
35 | # Installer logs
36 | pip-log.txt
37 | pip-delete-this-directory.txt
38 |
39 | # Unit test / coverage reports
40 | htmlcov/
41 | .tox/
42 | .nox/
43 | .coverage
44 | .coverage.*
45 | .cache
46 | nosetests.xml
47 | coverage.xml
48 | *.cover
49 | *.py,cover
50 | .hypothesis/
51 | .pytest_cache/
52 | cover/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | .pybuilder/
76 | target/
77 |
78 | # Jupyter Notebook
79 | .ipynb_checkpoints
80 |
81 | # IPython
82 | profile_default/
83 | ipython_config.py
84 |
85 | # pyenv
86 | # For a library or package, you might want to ignore these files since the code is
87 | # intended to run in multiple environments; otherwise, check them in:
88 | # .python-version
89 |
90 | # pipenv
91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
94 | # install all needed dependencies.
95 | #Pipfile.lock
96 |
97 | # poetry
98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
99 | # This is especially recommended for binary packages to ensure reproducibility, and is more
100 | # commonly ignored for libraries.
101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
102 | #poetry.lock
103 |
104 | # pdm
105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
106 | #pdm.lock
107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
108 | # in version control.
109 | # https://pdm.fming.dev/#use-with-ide
110 | .pdm.toml
111 |
112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
113 | __pypackages__/
114 |
115 | # Celery stuff
116 | celerybeat-schedule
117 | celerybeat.pid
118 |
119 | # SageMath parsed files
120 | *.sage.py
121 |
122 | # Environments
123 | .env
124 | .venv
125 | env/
126 | venv/
127 | ENV/
128 | env.bak/
129 | venv.bak/
130 |
131 | # Spyder project settings
132 | .spyderproject
133 | .spyproject
134 |
135 | # Rope project settings
136 | .ropeproject
137 |
138 | # mkdocs documentation
139 | /site
140 |
141 | # mypy
142 | .mypy_cache/
143 | .dmypy.json
144 | dmypy.json
145 |
146 | # Pyre type checker
147 | .pyre/
148 |
149 | # pytype static type analyzer
150 | .pytype/
151 |
152 | # Cython debug symbols
153 | cython_debug/
154 |
155 | # PyCharm
156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
158 | # and can be added to the global gitignore or merged into this file. For a more nuclear
159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
160 | .idea/
161 |
162 | # Project files
163 | outputs/
164 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2024 Zefang Liu
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # \[COLM 2024\] AdaMoLE: Adaptive Mixture of Low-Rank Adaptation Experts
2 |
3 | ## Introduction
4 |
5 | Explore AdaMoLE, a novel approach that integrates Low-Rank Adaptation (LoRA) with a dynamic Mixture of Experts (MoE) to enhance fine-tuning of Large Language Models (LLMs). AdaMoLE advances beyond static top-k expert activation by employing a dynamic thresholding mechanism, which adapts to the complexities of varied tasks to optimize model performance. This method efficiently selects and activates the most suitable experts based on input context, demonstrating superior performance in commonsense reasoning and natural language processing tasks.
6 |
7 | For more details regarding AdaMoLE, you are welcome to refer to our [paper](https://arxiv.org/abs/2405.00361) and [poster](poster.pdf).
8 |
9 |
10 | ![AdaMoLE](images/adamole.png)
11 |
12 |
13 | ## Features
14 |
15 | - **Dynamic Expert Activation:** Improves expert activation with a dynamic threshold network, adapting to task complexities for optimal expert engagement.
16 | - **Integration of LoRA and MoE:** Seamlessly combines Low-Rank Adaptation with the Mixture of Experts framework, enhancing the fine-tuning process for LLMs.
17 | - **Hugging Face Compatibility:** Designed to be compatible with Hugging Face's [Transformers](https://github.com/huggingface/transformers) and [Parameter-Efficient Fine-Tuning (PEFT)](https://github.com/huggingface/peft) libraries, ensuring ease of use and integration into existing workflows.
18 |
19 | ## Installation
20 |
21 | ```bash
22 | # Navigate to the AdaMoLE directory
23 | cd AdaMoLE
24 |
25 | # Install required dependencies
26 | pip install -r requirements.txt
27 | ```
28 |
29 | ## Usage
30 |
31 | ```bash
32 | # Train the model
33 | python train.py @configs/llama_adamole_csr_train.config
34 |
35 | # Test the model
36 | python test.py @configs/llama_adamole_csr_test.config
37 | ```
38 |
39 | ## Citation
40 |
41 | If you find AdaMoLE useful in your projects, please consider citing our paper:
42 |
43 | Liu, Z., & Luo, J. (2024). AdaMoLE: Fine-Tuning Large Language Models with Adaptive Mixture of Low-Rank Adaptation Experts. arXiv preprint *arXiv:2405.00361*.
44 |
45 | ```bibtex
46 | @article{liu2024adamole,
47 | title={AdaMoLE: Fine-Tuning Large Language Models with Adaptive Mixture of Low-Rank Adaptation Experts},
48 | author={Liu, Zefang and Luo, Jiahua},
49 | journal={arXiv preprint arXiv:2405.00361},
50 | year={2024}
51 | }
52 | ```
53 |
54 | ## License
55 |
56 | This project is licensed under the MIT License. See the [LICENSE](LICENSE) file for details.
57 |
--------------------------------------------------------------------------------
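The command-line workflow above is the documented entry point. For readers who prefer to wire things up in Python, here is a minimal sketch of programmatic usage based on the classes exported in `src/__init__.py`; the exact `PeftModelForCausalLM` constructor arguments and the `TaskType.CAUSAL_LM` member name are assumptions carried over from Hugging Face PEFT conventions, not details confirmed in this excerpt.

```python
# Hypothetical programmatic usage; constructor signatures follow Hugging Face
# PEFT conventions and are assumptions, not confirmed by the files shown here.
from transformers import AutoModelForCausalLM

from src import AdaMoleConfig, PeftModelForCausalLM, TaskType

base_model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf")

# Mirrors configs/llama_adamole_csr_train.config
peft_config = AdaMoleConfig(
    task_type=TaskType.CAUSAL_LM,  # assumed member name, as in upstream PEFT
    lora_rank=32,
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
    num_experts=8,
)
peft_model = PeftModelForCausalLM(base_model, peft_config)
```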
/configs/llama_adamole_csr_test.config:
--------------------------------------------------------------------------------
1 | --model_path=outputs/llama-2-7b-hf-adamole-the8-commonsense-qa
2 | --data_path=tau/commonsense_qa
3 | --max_new_tokens=16
4 | --batch_size=16
5 | --logits
--------------------------------------------------------------------------------
/configs/llama_adamole_csr_train.config:
--------------------------------------------------------------------------------
1 | --model_path=meta-llama/Llama-2-7b-hf
2 | --data_path=tau/commonsense_qa
3 | --peft_type=adamole
4 | --lora_rank=32
5 | --target_modules
6 | q_proj
7 | k_proj
8 | v_proj
9 | o_proj
10 | --num_experts=8
11 | --threshold=0.125
12 | --max_length=256
13 | --batch_size=4
14 | --gradient_accumulation_steps=4
15 | --num_train_epochs=2
16 | --learning_rate=1e-4
17 | --lr_scheduler_type=constant_with_warmup
18 | --warmup_steps=200
19 | --weight_decay=0.0
20 | --aux_loss_coeff=1e-3
--------------------------------------------------------------------------------
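The README invokes `train.py @configs/llama_adamole_csr_train.config`, which suggests these files are argparse argument files (one argument per line). A minimal sketch of how such a file can be consumed is shown below; whether `train.py` (not reproduced in this excerpt) does exactly this is an assumption.

```python
# Sketch of reading an @-prefixed argument file; the options below are a subset
# chosen for illustration, and parse_known_args ignores the remaining ones.
import argparse

parser = argparse.ArgumentParser(fromfile_prefix_chars='@')
parser.add_argument('--model_path')
parser.add_argument('--peft_type')
parser.add_argument('--lora_rank', type=int)
parser.add_argument('--target_modules', nargs='+')  # one module name per line in the file
parser.add_argument('--num_experts', type=int)
parser.add_argument('--threshold', type=float)

args, _ = parser.parse_known_args(['@configs/llama_adamole_csr_train.config'])
print(args.target_modules)  # ['q_proj', 'k_proj', 'v_proj', 'o_proj']
print(args.threshold)       # 0.125, i.e. 1 / num_experts
```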
/data.py:
--------------------------------------------------------------------------------
1 | """
2 | Loading and Preprocessing Datasets
3 | """
4 | import os
5 |
6 | from datasets import load_dataset, concatenate_datasets, DatasetDict
7 |
8 |
9 | def format_text(example, data_name: str, prompt_only: bool = True):
10 | """
11 | Format an example into one text
12 | """
13 | if data_name == 'boolq':
14 | """
15 | Passage: Windows Movie Maker -- Windows Movie Maker (formerly known as Windows Live Movie Maker in Windows 7)
16 | is a discontinued video editing software by Microsoft. It is a part of Windows Essentials software suite and
17 | offers the ability to create and edit videos as well as to publish them on OneDrive, Facebook, Vimeo, YouTube,
18 | and Flickr.
19 | Question: is windows movie maker part of windows essentials
20 | Choices:
21 | A. No
22 | B. Yes
23 | Answer: B
24 | """
25 | text = f"Passage: {example['passage']}\nQuestion: {example['question']}\nChoices:\n"
26 | text += "A. No\nB. Yes\n"
27 | text += "Answer: "
28 | example['answer'] = ['A', 'B'][example['label']]
29 | example['num_choices'] = 2
30 |
31 | elif data_name == 'cb':
32 | """
33 | Text: It was a complex language. Not written down but handed down. One might say it was peeled down.
34 | Hypothesis: the language was peeled down
35 | Question: Does the text entail the hypothesis, contradict it, or is it neutral?
36 | Choices:
37 | A. Entailment
38 | B. Contradiction
39 | C. Neutral
40 | Answer: A
41 | """
42 | text = f"Text: {example['premise']}\nHypothesis: {example['hypothesis']}\n" \
43 | f"Question: Does the text entail the hypothesis, contradict it, or is it neutral?\nChoices:\n"
44 | text += "A. Entailment\nB. Contradiction\nC. Neutral\n"
45 | text += "Answer: "
46 | example['answer'] = ['A', 'B', 'C'][example['label']]
47 | example['num_choices'] = 3
48 |
49 | elif data_name == 'copa':
50 | """
51 | Premise: My body cast a shadow over the grass.
52 | Question: What’s the cause for this?
53 | Choices:
54 | A. The sun was rising.
55 | B. The grass was cut.
56 | Answer: A
57 | """
58 | text = f"Premise: {example['premise']}\nQuestion: What’s the {example['question']} for this?\nChoices:\n"
59 | text += f"A. {example['choice1']}\nB. {example['choice2']}\n"
60 | text += "Answer: "
61 | example['answer'] = ['A', 'B'][example['label']]
62 | example['num_choices'] = 2
63 |
64 | elif data_name == 'multirc':
65 | """
66 | Paragraph: While this process moved along, diplomacy continued its rounds. Direct pressure on the Taliban had
67 | proved unsuccessful. As one NSC staff note put it, "Under the Taliban, Afghanistan is not so much a state
68 | sponsor of terrorism as it is a state sponsored by terrorists." ...
69 | Question: What did the high-level effort to persuade Pakistan include?
70 | Candidate Answer: Children, Gerd, or Dorian Popa
71 | Choices:
72 | A. False
73 | B. True
74 | Answer: A
75 | """
76 | text = f"Paragraph: {example['paragraph']}\nQuestion: {example['question']}\n" \
77 | f"Candidate Answer: {example['answer']}\nChoices:\n"
78 | text += f"A. False\nB. True\n"
79 | text += "Answer: "
80 | example['answer'] = ['A', 'B'][example['label']]
81 | example['num_choices'] = 2
82 |
83 | elif data_name == 'record':
84 | raise NotImplementedError
85 |
86 | elif data_name == 'rte':
87 | """
88 | Text: No Weapons of Mass Destruction Found in Iraq Yet.
89 | Hypothesis: Weapons of Mass Destruction Found in Iraq.
90 | Question: Does the text entail the hypothesis or not?
91 | Choices:
92 | A. Entailment
93 | B. Not entailment
94 | Answer: B
95 | """
96 | text = f"Text: {example['premise']}\nHypothesis: {example['hypothesis']}\n" \
97 | f"Question: Does the text entail the hypothesis or not?\nChoices:\n"
98 | text += "A. Entailment\nB. Not entailment\n"
99 | text += "Answer: "
100 | example['answer'] = ['A', 'B'][example['label']]
101 | example['num_choices'] = 2
102 |
103 | elif data_name == 'wic':
104 | """
105 | Context 1: Do you want to come over to my <place> later?
106 | Context 2: A political system with no <place> for the less prominent groups.
107 | Question: Is the word in brackets used with the same meaning in both contexts?
108 | Choices:
109 | A. False
110 | B. True
111 | Answer: A
112 | """
113 | sentence1 = example['sentence1']
114 | sentence2 = example['sentence2']
115 | marked_sentence1 = sentence1[:example['start1']] + '<' + sentence1[example['start1']:example['end1']] \
116 | + '>' + sentence1[example['end1']:]
117 | marked_sentence2 = sentence2[:example['start2']] + '<' + sentence2[example['start2']:example['end2']] \
118 | + '>' + sentence2[example['end2']:]
119 | text = f"Context 1: {marked_sentence1}\nContext 2: {marked_sentence2}\n" \
120 | f"Question: Is the word in brackets used with the same meaning in both contexts?\nChoices:\n"
121 | text += "A. False\nB. True\n"
122 | text += "Answer: "
123 | example['answer'] = ['A', 'B'][example['label']]
124 | example['num_choices'] = 2
125 |
126 | elif data_name == 'wsc.fixed':
127 | """
128 | Text: <Mark> told Pete many lies about himself, which Pete included in his book. <He> should have been
129 | more skeptical.
130 | Question: Is the pronoun in brackets referring to the correct entity as intended in the context?
131 | Choices:
132 | A. False
133 | B. True
134 | Answer: A
135 | """
136 | tokens = example['text'].split()
137 | span1_start = example['span1_index']
138 | span1_end = example['span1_index'] + len(example['span1_text'].split())
139 | span2_start = example['span2_index']
140 | span2_end = example['span2_index'] + len(example['span2_text'].split())
141 | marked_tokens = tokens[:span1_start] + ['<' + example['span1_text'] + '>'] + tokens[span1_end:span2_start] \
142 | + ['<' + example['span2_text'] + '>'] + tokens[span2_end:]
143 | marked_text = ' '.join(marked_tokens)
144 | text = f"Text: {marked_text}\n" \
145 | f"Question: Is the pronoun in brackets referring to the correct entity as intended in the context?\n" \
146 | f"Choices:\n"
147 | text += "A. False\nB. True\n"
148 | text += "Answer: "
149 | example['answer'] = ['A', 'B'][example['label']]
150 | example['num_choices'] = 2
151 |
152 | elif data_name == 'commonsense_qa':
153 | """
154 | Question: The sanctions against the school were a punishing blow, and they seemed to what the efforts the
155 | school had made to change?
156 | Choices:
157 | A. ignore
158 | B. enforce
159 | C. authoritarian
160 | D. yell at
161 | E. avoid
162 | Answer: A
163 | """
164 | text = f"Question: {example['question']}\nChoices:\n"
165 | choices = example['choices']
166 | for label, choice in zip(choices['label'], choices['text']):
167 | text += f"{label}. {choice}\n"
168 | text += "Answer: "
169 | example['answer'] = example['answerKey']
170 | example['num_choices'] = 5
171 |
172 | elif data_name == 'cosmos_qa':
173 | """
174 | Context: Good Old War and person L : I saw both of these bands Wednesday night , and they both blew me away .
175 | seriously . Good Old War is acoustic and makes me smile . I really can not help but be happy when I listen to
176 | them ; I think it 's the fact that they seemed so happy themselves when they played .
177 | Question: In the future , will this person go to see other bands play ?
178 | Choices:
179 | A. None of the above choices .
180 | B. This person likes music and likes to see the show , they will see other bands play .
181 | C. This person only likes Good Old War and Person L , no other bands .
182 | D. Other Bands is not on tour and this person can not see them .
183 | Answer: B
184 | """
185 | text = f"Context: {example['context']}\nQuestion: {example['question']}\nChoices:\n"
186 | text += f"A. {example['answer0']}\n"
187 | text += f"B. {example['answer1']}\n"
188 | text += f"C. {example['answer2']}\n"
189 | text += f"D. {example['answer3']}\n"
190 | text += "Answer: "
191 | example['answer'] = chr(ord('A') + example['label'])
192 | example['num_choices'] = 4
193 |
194 | elif data_name == 'social_i_qa':
195 | """
196 | Context: Cameron decided to have a barbecue and gathered her friends together.
197 | Question: How would Others feel as a result?
198 | Choices:
199 | A. like attending
200 | B. like staying home
201 | C. a good friend to have
202 | Answer: A
203 | """
204 | text = f"Context: {example['context']}\nQuestion: {example['question']}\nChoices:\n"
205 | text += f"A. {example['answerA']}\n"
206 | text += f"B. {example['answerB']}\n"
207 | text += f"C. {example['answerC']}\n"
208 | text += "Answer: "
209 | example['answer'] = chr(ord('A') + int(example['label']) - 1)
210 | example['num_choices'] = 3
211 |
212 | elif data_name == 'piqa':
213 | """
214 | Question: When boiling butter, when it's ready, you can
215 | Choices:
216 | A. Pour it onto a plate
217 | B. Pour it into a jar
218 | Answer: B
219 | """
220 | text = f"Question: {example['goal']}\nChoices:\n"
221 | text += f"A. {example['sol1']}\n"
222 | text += f"B. {example['sol2']}\n"
223 | text += "Answer: "
224 | example['answer'] = chr(ord('A') + example['label'])
225 | example['num_choices'] = 2
226 |
227 | elif data_name == 'openbookqa':
228 | """
229 | Fact: the sun is the source of energy for physical cycles on Earth
230 | Question: The sun is responsible for
231 | Choices:
232 | A. puppies learning new tricks
233 | B. children growing up and getting old
234 | C. flowers wilting in a vase
235 | D. plants sprouting, blooming and wilting
236 | Answer: D
237 | """
238 | text = f"Fact: {example['fact1']}\nQuestion: {example['question_stem']}\nChoices:\n"
239 | choices = example['choices']
240 | for label, choice in zip(choices['label'], choices['text']):
241 | text += f"{label}. {choice}\n"
242 | text += "Answer: "
243 | example['answer'] = example['answerKey']
244 | example['num_choices'] = 4
245 |
246 | elif data_name == 'ai2_arc':
247 | """
248 | Question: George wants to warm his hands quickly by rubbing them. Which skin surface will produce the most
249 | heat?
250 | Choices:
251 | A. dry palms
252 | B. wet palms
253 | C. palms covered with oil
254 | D. palms covered with lotion
255 | Answer: A
256 | """
257 | text = f"Question: {example['question']}\nChoices:\n"
258 | choices = example['choices']
259 | for label, choice in zip(choices['label'], choices['text']):
260 | text += f"{label}. {choice}\n"
261 | text += "Answer: "
262 | example['answer'] = example['answerKey']
263 | example['num_choices'] = 4
264 |
265 | elif data_name == 'scienceqa':
266 | """
267 | Question: Which tense does the sentence use?
268 | Mona will print her name with care.
269 | Choices:
270 | A. present tense
271 | B. future tense
272 | C. past tense
273 | Answer: B
274 | """
275 | text = f"Question: {example['question']}\nChoices:\n"
276 | choices = example['choices']
277 | for index, choice in enumerate(choices):
278 | text += f"{chr(ord('A') + index)}. {choice}\n"
279 | text += "Answer: "
280 | example['answer'] = chr(ord('A') + example['answer'])
281 | example['num_choices'] = 5  # ScienceQA has a variable number of choices; 5 is used as an upper bound
282 |
283 | else:
284 | raise NotImplementedError
285 |
286 | if not prompt_only:
287 | text += f"{example['answer']}"
288 | example['data_name'] = data_name
289 | example['text'] = text
290 | return example
291 |
292 |
293 | def get_formatted_datasets(data_path: str, prompt_only: bool):
294 | """
295 | Get formatted datasets
296 | """
297 | data_name = os.path.basename(data_path).lower()
298 |
299 | # Load and format datasets
300 | if data_name == 'super_glue':
301 | data_names = ['boolq', 'cb', 'copa', 'rte', 'wic']
302 | splits = ['train', 'validation', 'test']
303 | formatted_datasets = {split: [] for split in splits}
304 |
305 | # Load and format datasets
306 | for _data_name in data_names:
307 | _datasets = load_dataset(path='super_glue', name=_data_name)
308 | print(f'Datasets: {_datasets}')
309 | _formatted_datasets = _datasets.map(
310 | lambda example: format_text(example, _data_name, prompt_only=prompt_only),
311 | batched=False, load_from_cache_file=False)
312 | for split in splits:
313 | formatted_datasets[split].append(
314 | _formatted_datasets[split].select_columns(['data_name', 'text', 'num_choices', 'answer']))
315 |
316 | # Concatenate datasets
317 | for split in splits:
318 | formatted_datasets[split] = concatenate_datasets(formatted_datasets[split])
319 | formatted_datasets = DatasetDict(formatted_datasets)
320 | print(f'Formatted datasets: {formatted_datasets}')
321 | print(f"Text example:\n{formatted_datasets['train']['text'][0]}")
322 | else:
323 | # Load datasets
324 | if data_name in [
325 | 'axb', 'axg', 'boolq', 'cb', 'copa', 'multirc',
326 | 'record', 'rte', 'wic', 'wsc', 'wsc.fixed',
327 | ]:
328 | datasets = load_dataset(path='super_glue', name=data_name)
329 | elif data_name == 'openbookqa':
330 | datasets = load_dataset(path=data_path, name='additional')
331 | elif data_name == 'ai2_arc':
332 | datasets = load_dataset(path=data_path, name='ARC-Challenge')
333 | elif data_name == 'scienceqa':
334 | datasets = load_dataset(path=data_path)
335 | datasets = datasets.filter(lambda example: example["image"] is None)
336 | else:
337 | datasets = load_dataset(path=data_path)
338 | print(f'Datasets: {datasets}')
339 | print(f"Example: {datasets['train'][0]}")
340 |
341 | # Format datasets
342 | formatted_datasets = datasets.map(
343 | lambda example: format_text(example, data_name, prompt_only=prompt_only),
344 | batched=False, load_from_cache_file=False)
345 | print(f'Formatted datasets: {formatted_datasets}')
346 | print(f"Formatted example: {formatted_datasets['train'][0]}")
347 | print(f"Text example:\n{formatted_datasets['train']['text'][0]}")
348 |
349 | return formatted_datasets
350 |
351 |
352 | if __name__ == '__main__':
353 | data_path = 'cb'
354 | _ = get_formatted_datasets(data_path=data_path, prompt_only=False)
355 |
--------------------------------------------------------------------------------
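As a quick usage note, the `__main__` block above exercises the CB dataset; the same helper can be pointed at the CommonsenseQA path used in the provided configs.

```python
# Example call for the dataset referenced in configs/llama_adamole_csr_train.config.
from data import get_formatted_datasets

datasets = get_formatted_datasets(data_path='tau/commonsense_qa', prompt_only=True)
example = datasets['train'][0]
print(example['text'])    # multiple-choice prompt ending in "Answer: "
print(example['answer'])  # gold choice letter, e.g. "A"
```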
/images/adamole.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zefang-liu/AdaMoLE/5e8599a7ad6fafde03aaad8466f583f5f28c22ac/images/adamole.png
--------------------------------------------------------------------------------
/poster.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zefang-liu/AdaMoLE/5e8599a7ad6fafde03aaad8466f583f5f28c22ac/poster.pdf
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | datasets==2.18.0
2 | huggingface_hub==0.21.4
3 | numpy==1.25.2
4 | pandas==1.3.5
5 | peft==0.9.0
6 | torch==2.0.1
7 | tqdm==4.66.1
8 | transformers==4.38.2
9 |
--------------------------------------------------------------------------------
/src/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | Package Initialization
3 | """
4 | from .adamole import AdaMoleConfig, AdaMoleModel
5 | from .config import PeftConfig
6 | from .lora import LoraConfig, LoraModel
7 | from .mole import MoleConfig, MoleModel
8 | from .peft_model import PeftModel, PeftModelForCausalLM
9 | from .trainer import PeftTrainer
10 | from .utils.peft_types import PeftType, TaskType
11 |
--------------------------------------------------------------------------------
/src/adamole/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | AdaMoLE Initialization
3 | """
4 | from .config import AdaMoleConfig
5 | from .layer import AdaMoleLayer, LinearAdaMoleLayer
6 | from .model import AdaMoleModel
7 |
8 | __all__ = ["AdaMoleConfig", "AdaMoleLayer", "LinearAdaMoleLayer", "AdaMoleModel"]
9 |
10 |
11 | def __getattr__(name):
12 | raise AttributeError(f"Module {__name__} has no attribute {name}.")
13 |
--------------------------------------------------------------------------------
/src/adamole/config.py:
--------------------------------------------------------------------------------
1 | """
2 | AdaMoLE Configuration
3 | """
4 | from dataclasses import dataclass, field
5 |
6 | from ..lora import LoraConfig
7 | from ..utils.peft_types import PeftType
8 |
9 |
10 | @dataclass
11 | class AdaMoleConfig(LoraConfig):
12 | """
13 | AdaMoLE Configuration
14 | """
15 | num_experts: int = field(default=4, metadata={"help": "The number of experts in MoE."})
16 | max_threshold: float = field(default=None, metadata={
17 | "help": "The maximum threshold for selecting experts in the threshold function. "
18 | "The default value will be 1 / number of experts"})
19 |
20 | def __post_init__(self):
21 | self.peft_type = PeftType.ADAMOLE
22 | self.target_modules = (
23 | set(self.target_modules) if isinstance(self.target_modules, list) else self.target_modules
24 | )
25 |
--------------------------------------------------------------------------------
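A small illustration of the defaults described above; note that the `1 / num_experts` fallback for `max_threshold` is applied later, in `AdaMoleLayer.update_layer`, rather than in this dataclass.

```python
# Construction sketch for AdaMoleConfig; field names come from this file and LoraConfig.
from src.adamole import AdaMoleConfig

config = AdaMoleConfig(num_experts=8, target_modules=["q_proj", "v_proj"])
print(config.peft_type)       # PeftType.ADAMOLE (set in __post_init__)
print(config.max_threshold)   # None here; resolved to 1 / 8 = 0.125 in AdaMoleLayer.update_layer
print(config.target_modules)  # {'q_proj', 'v_proj'} -- the list is converted to a set
```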
/src/adamole/layer.py:
--------------------------------------------------------------------------------
1 | """
2 | AdaMoLE Layer
3 | """
4 | import math
5 | from abc import ABC
6 | from typing import Optional
7 |
8 | import torch
9 | import torch.nn as nn
10 | import torch.nn.functional as F
11 |
12 | from ..lora import LoraLayer
13 | from ..mole.layer import LoraExpert
14 |
15 |
16 | class AdaMoeLayer(nn.Module):
17 | """
18 | Adaptive Mixture of Experts (MoE) Layer
19 | """
20 |
21 | def __init__(self, experts: nn.ModuleList, gate: nn.Module, threshold_fn: nn.Module, max_threshold: float):
22 | super().__init__()
23 | self.experts = experts
24 | self.gate = gate
25 | self.threshold_fn = threshold_fn
26 | self.max_threshold = max_threshold
27 | self.layer_loss = None
28 |
29 | def get_layer_loss(self, gate_logits: torch.Tensor, selected_experts: torch.Tensor) -> torch.Tensor:
30 | """
31 | Get the load balancing loss by following the Switch Transformer
32 | """
33 | num_inputs = gate_logits.shape[0]
34 | num_experts = len(self.experts)
35 | expert_counts = torch.sum(selected_experts, dim=0)
36 | expert_fractions = expert_counts / num_inputs
37 | expert_probs = torch.sum(gate_logits, dim=0) / num_inputs
38 | layer_loss = num_experts * torch.sum(expert_fractions * expert_probs)
39 | return layer_loss
40 |
41 | def forward(self, inputs: torch.Tensor) -> torch.Tensor:
42 | """
43 | Forward propagation
44 | """
45 | flattened_inputs = inputs.view((-1, inputs.shape[-1]))
46 | gate_logits = F.softmax(self.gate(flattened_inputs), dim=-1)
47 | thresholds = F.sigmoid(self.threshold_fn(flattened_inputs)) * self.max_threshold
48 | adapted_gate_logits = gate_logits - thresholds
49 | selected_experts = torch.ge(adapted_gate_logits, 0).to(torch.float)
50 | weights = adapted_gate_logits * selected_experts
51 | weight_sums = torch.sum(weights, dim=-1, keepdim=True, dtype=inputs.dtype)
52 | weight_sums = torch.where(weight_sums == 0, torch.ones_like(weight_sums), weight_sums)
53 | weights = weights / weight_sums
54 | results = torch.zeros_like(self.experts[0](flattened_inputs))
55 |
56 | for i, expert in enumerate(self.experts):
57 | batch_idx = torch.where(selected_experts[:, i])[0]
58 | if len(batch_idx) > 0:
59 | results[batch_idx] += weights[batch_idx, i, None] * expert(flattened_inputs[batch_idx])
60 |
61 | results = results.view((*inputs.shape[:-1], results.shape[-1]))
62 | if inputs.requires_grad:
63 | self.layer_loss = self.get_layer_loss(gate_logits=adapted_gate_logits, selected_experts=selected_experts)
64 | return results
65 |
66 |
67 | class AdaMoleLayer(LoraLayer, ABC):
68 | """
69 | AdaMoLE Layer
70 | """
71 |
72 | def __init__(self, base_layer: nn.Module, **kwargs):
73 | super().__init__(base_layer, **kwargs)
74 | self.lora_gating = nn.ModuleDict({})
75 | self.lora_threshold = nn.ModuleDict({})
76 | self.moe_layer = nn.ModuleDict({})
77 |
78 | def update_layer(
79 | self, adapter_name: str, lora_rank: int, lora_alpha: int, lora_dropout: float, init_lora_weights: bool,
80 | num_experts: int, max_threshold: float,
81 | ) -> None:
82 | """
83 | Update the layer
84 | """
85 | if lora_rank <= 0:
86 | raise ValueError(f"The rank `r` should be a positive integer value but the value passed is {lora_rank}.")
87 |
88 | if max_threshold is None:
89 | max_threshold = 1 / num_experts
90 |
91 | self.lora_rank[adapter_name] = lora_rank
92 | self.lora_alpha[adapter_name] = lora_alpha
93 |
94 | if lora_dropout > 0.0:
95 | lora_dropout_layer = nn.ModuleList([nn.Dropout(p=lora_dropout) for _ in range(num_experts)])
96 | else:
97 | lora_dropout_layer = nn.ModuleList([nn.Identity(p=lora_dropout) for _ in range(num_experts)])
98 |
99 | self.lora_dropout[adapter_name] = lora_dropout_layer
100 | self.lora_A[adapter_name] = nn.ModuleList(
101 | [nn.Linear(self.in_features, lora_rank, bias=False) for _ in range(num_experts)])
102 | self.lora_B[adapter_name] = nn.ModuleList(
103 | [nn.Linear(lora_rank, self.out_features, bias=False) for _ in range(num_experts)])
104 | self.scaling[adapter_name] = lora_alpha / lora_rank
105 | self.lora_gating[adapter_name] = nn.Linear(self.in_features, num_experts, bias=False)
106 | self.lora_threshold[adapter_name] = nn.Linear(self.in_features, 1)
107 |
108 | experts = nn.ModuleList([LoraExpert(
109 | self.lora_A[adapter_name][i],
110 | self.lora_B[adapter_name][i],
111 | self.lora_dropout[adapter_name][i],
112 | self.scaling[adapter_name],
113 | ) for i in range(num_experts)])
114 | self.moe_layer[adapter_name] = AdaMoeLayer(
115 | experts=experts, gate=self.lora_gating[adapter_name],
116 | threshold_fn=self.lora_threshold[adapter_name], max_threshold=max_threshold)
117 |
118 | self.reset_parameters(adapter_name, init_lora_weights)
119 | self.set_adapter(self.active_adapters)
120 |
121 | def reset_parameters(self, adapter_name: str, init_lora_weights: bool) -> None:
122 | """
123 | Reset the parameters
124 | """
125 | if init_lora_weights is False:
126 | return
127 | elif adapter_name in self.lora_A.keys():
128 | for i in range(len(self.lora_A[adapter_name])):
129 | nn.init.kaiming_uniform_(self.lora_A[adapter_name][i].weight, a=math.sqrt(5))
130 | nn.init.zeros_(self.lora_B[adapter_name][i].weight)
131 |
132 |
133 | class LinearAdaMoleLayer(nn.Module, AdaMoleLayer):
134 | """
135 | AdaMoLE Implementation in a Linear Layer
136 | """
137 |
138 | def __init__(
139 | self,
140 | base_layer: nn.Module,
141 | adapter_name: str,
142 | lora_rank: int = 0,
143 | lora_alpha: int = 1,
144 | lora_dropout: float = 0.0,
145 | init_lora_weights: bool = True,
146 | num_experts: int = 4,
147 | max_threshold: float = None,
148 | **kwargs,
149 | ) -> None:
150 | super().__init__()
151 | AdaMoleLayer.__init__(self, base_layer=base_layer, **kwargs)
152 | self._active_adapter = adapter_name
153 | self.update_layer(
154 | adapter_name, lora_rank, lora_alpha, lora_dropout, init_lora_weights, num_experts, max_threshold)
155 |
156 | def merge(self, safe_merge: bool = False, adapter_names: Optional[list[str]] = None) -> None:
157 | """
158 | Merge the active adapter weights inside the base weights
159 | """
160 | pass
161 |
162 | def unmerge(self) -> None:
163 | """
164 | Unmerge all merged adapter layers from the base weights
165 | """
166 | pass
167 |
168 | def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
169 | """
170 | Forward propagation
171 | """
172 | previous_dtype = x.dtype
173 | result = self.base_layer(x, *args, **kwargs)
174 |
175 | for active_adapter in self.active_adapters:
176 | if active_adapter not in self.lora_A.keys():
177 | continue
178 |
179 | moe_layer = self.moe_layer[active_adapter]
180 | x = x.to(moe_layer.experts[0].lora_A.weight.dtype)
181 | result += moe_layer(x)
182 |
183 | result = result.to(previous_dtype)
184 | return result
185 |
--------------------------------------------------------------------------------
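To make the adaptive gating in `AdaMoeLayer.forward` concrete, here is a worked example with made-up numbers for one token and four experts (so `max_threshold = 1/4`); it reproduces only the weighting arithmetic, not the expert forward passes.

```python
# Worked example of the threshold gating used in AdaMoeLayer.forward.
import torch

gate_probs = torch.tensor([[0.50, 0.30, 0.15, 0.05]])  # softmax of the gate logits
threshold = torch.tensor([[0.20]])                      # sigmoid(threshold_fn(x)) * max_threshold
adapted = gate_probs - threshold                        # [[ 0.30,  0.10, -0.05, -0.15]]
selected = torch.ge(adapted, 0).float()                 # experts 0 and 1 clear the threshold
weights = adapted * selected                            # [[0.30, 0.10, 0.00, 0.00]]
weights = weights / weights.sum(dim=-1, keepdim=True)   # [[0.75, 0.25, 0.00, 0.00]]
print(weights)
```

Each active expert's LoRA output is then scaled by its normalized weight and summed, and the Switch-Transformer-style load-balancing loss is computed from the adapted gate probabilities and the selection mask.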
/src/adamole/model.py:
--------------------------------------------------------------------------------
1 | """
2 | AdaMoLE Model
3 | """
4 | from typing import Any
5 |
6 | import torch
7 | from peft.tuners.tuners_utils import BaseTunerLayer
8 | from torch import nn
9 |
10 | from .config import AdaMoleConfig
11 | from .layer import AdaMoleLayer, LinearAdaMoleLayer
12 | from ..lora import LoraModel
13 |
14 |
15 | class AdaMoleModel(LoraModel):
16 | """
17 | AdaMoLE (Adaptive Mixture of LoRA Experts) Model
18 | """
19 | prefix: str = "lora_"
20 |
21 | def __init__(self, model: nn.Module, config: AdaMoleConfig, adapter_name: str = "default") -> None:
22 | super().__init__(model, config, adapter_name)
23 |
24 | def _create_and_replace(
25 | self, adamole_config: AdaMoleConfig, adapter_name: str,
26 | target: nn.Module, target_name: str, parent: nn.Module, **kwargs: Any,
27 | ) -> None:
28 | """
29 | Inplace replacement of the target module with the adapter layer
30 | """
31 | kwargs = {
32 | "lora_rank": adamole_config.lora_rank,
33 | "lora_alpha": adamole_config.lora_alpha,
34 | "lora_dropout": adamole_config.lora_dropout,
35 | "init_lora_weights": adamole_config.init_lora_weights,
36 | "num_experts": adamole_config.num_experts,
37 | "max_threshold": adamole_config.max_threshold,
38 | }
39 |
40 | if isinstance(target, AdaMoleLayer):
41 | target.update_layer(adapter_name, **kwargs)
42 | else:
43 | new_module = self._create_new_module(adapter_name, target, **kwargs)
44 | self._replace_module(parent, target_name, new_module, target)
45 |
46 | @staticmethod
47 | def _create_new_module(adapter_name: str, target: nn.Module, **kwargs: Any) -> nn.Module:
48 | """
49 | Create the new LoRA module for the target module
50 | """
51 | if isinstance(target, BaseTunerLayer):
52 | target_base_layer = target.get_base_layer()
53 | else:
54 | target_base_layer = target
55 |
56 | if isinstance(target_base_layer, torch.nn.Linear):
57 | new_module = LinearAdaMoleLayer(base_layer=target, adapter_name=adapter_name, **kwargs)
58 | else:
59 | raise ValueError(
60 | f"The target module `{target}` is not supported. "
61 | f"Currently, only the following modules are supported: `torch.nn.Linear`.")
62 |
63 | return new_module
64 |
65 | def get_aux_loss(self, adapter_name="default") -> torch.Tensor:
66 | """
67 | Get the load balancing loss for the whole model
68 | """
69 | model_loss = torch.tensor(0, dtype=torch.float).to(self.model.device)
70 |
71 | for name, module in self.model.named_modules():
72 | if name.endswith('moe_layer'):
73 | layer_loss = module[adapter_name].layer_loss
74 | model_loss += layer_loss
75 |
76 | return model_loss
77 |
--------------------------------------------------------------------------------
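The auxiliary loss returned by `get_aux_loss` is meant to be added to the task loss during training (the provided config sets `--aux_loss_coeff=1e-3`). The sketch below shows one way this could look; the real logic lives in `src/trainer.py` (`PeftTrainer`), which is not reproduced in this excerpt, and the `base_model` attribute access follows the usual PEFT layout, so treat it as an assumption.

```python
# Hedged sketch of combining the LM loss with the MoE load-balancing loss.
import torch


def training_step(peft_model, batch: dict, aux_loss_coeff: float = 1e-3) -> torch.Tensor:
    """Compute the total loss for one batch: LM loss + weighted auxiliary loss."""
    outputs = peft_model(**batch)  # batch is assumed to include labels
    lm_loss = outputs.loss
    # AdaMoleModel.get_aux_loss sums layer_loss over every `moe_layer` module.
    aux_loss = peft_model.base_model.get_aux_loss()
    return lm_loss + aux_loss_coeff * aux_loss
```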
/src/config.py:
--------------------------------------------------------------------------------
1 | """
2 | PEFT Configuration
3 |
4 | Portions of this file are modifications based on work created and
5 | shared by the HuggingFace Inc. team and used according to terms
6 | described in the Apache License 2.0.
7 | """
8 | import inspect
9 | import json
10 | import os
11 | from dataclasses import asdict, dataclass, field
12 | from typing import Dict, Optional, Union
13 |
14 | from huggingface_hub import hf_hub_download
15 | from peft.utils import CONFIG_NAME
16 | from transformers.utils import PushToHubMixin
17 |
18 | from .utils.peft_types import PeftType, TaskType
19 |
20 |
21 | @dataclass
22 | class PeftConfigMixin(PushToHubMixin):
23 | """
24 | Base Configuration Class for PEFT Models
25 | """
26 | peft_type: Optional[PeftType] = field(
27 | default=None, metadata={"help": "The type of Peft method to use."})
28 | auto_mapping: Optional[dict] = field(
29 | default=None, metadata={"help": "An auto mapping dict to help retrieve the base model class if needed."}
30 | )
31 |
32 | def to_dict(self) -> Dict:
33 | """
34 | Return the configuration for the adapter model as a dictionary
35 | """
36 | return asdict(self)
37 |
38 | def save_pretrained(self, save_directory: str, **kwargs) -> None:
39 | """
40 | Save the configuration of the adapter model in a directory
41 |
42 | :param save_directory: the directory where the configuration will be saved
43 | :param kwargs: additional keyword arguments
44 | """
45 | if os.path.isfile(save_directory):
46 | raise AssertionError(f"Provided path ({save_directory}) should be a directory, not a file")
47 |
48 | os.makedirs(save_directory, exist_ok=True)
49 |
50 | # Converting set type to list
51 | output_dict = asdict(self)
52 | for key, value in output_dict.items():
53 | if isinstance(value, set):
54 | output_dict[key] = list(value)
55 |
56 | output_path = os.path.join(save_directory, CONFIG_NAME)
57 |
58 | # Add auto mapping details for custom models.
59 | auto_mapping_dict = kwargs.pop("auto_mapping_dict", None)
60 | if auto_mapping_dict is not None:
61 | output_dict["auto_mapping"] = auto_mapping_dict
62 |
63 | # Save the configuration
64 | with open(output_path, "w") as writer:
65 | writer.write(json.dumps(output_dict, indent=2, sort_keys=True))
66 |
67 | @classmethod
68 | def from_pretrained(cls, pretrained_model_name_or_path: str, subfolder: Optional[str] = None, **kwargs):
69 | """
70 | Loads the configuration of the adapter model from a directory
71 |
72 | :param pretrained_model_name_or_path: the directory or the Hub repository id where the configuration is saved
73 | :param subfolder: subfolder for the directory
74 | :param kwargs: additional keyword arguments passed along to the child class initialization
75 | :return: the loaded PEFT configuration object
76 | """
77 | from .mapping import PEFT_TYPE_TO_CONFIG_MAPPING
78 |
79 | path = (
80 | os.path.join(pretrained_model_name_or_path, subfolder)
81 | if subfolder is not None
82 | else pretrained_model_name_or_path
83 | )
84 |
85 | hf_hub_download_kwargs, class_kwargs, _ = cls._split_kwargs(kwargs)
86 |
87 | if os.path.isfile(os.path.join(path, CONFIG_NAME)):
88 | config_file = os.path.join(path, CONFIG_NAME)
89 | else:
90 | try:
91 | config_file = hf_hub_download(
92 | pretrained_model_name_or_path, CONFIG_NAME, subfolder=subfolder, **hf_hub_download_kwargs
93 | )
94 | except Exception:
95 | raise ValueError(f"Can't find '{CONFIG_NAME}' at '{pretrained_model_name_or_path}'")
96 |
97 | loaded_attributes = cls.from_json_file(config_file)
98 | if "peft_type" in loaded_attributes:
99 | peft_type = loaded_attributes["peft_type"]
100 | config_cls = PEFT_TYPE_TO_CONFIG_MAPPING[peft_type]
101 | else:
102 | config_cls = cls
103 |
104 | kwargs = {**class_kwargs, **loaded_attributes}
105 | config = config_cls(**kwargs)
106 | return config
107 |
108 | @classmethod
109 | def from_json_file(cls, path_json_file: str):
110 | """
111 | Load a configuration file from a JSON file
112 |
113 | :param path_json_file: the path to the JSON file
114 | :return: a JSON object
115 | """
116 | with open(path_json_file, "r") as file:
117 | json_object = json.load(file)
118 | return json_object
119 |
120 | @classmethod
121 | def _split_kwargs(cls, kwargs):
122 | hf_hub_download_kwargs = {}
123 | class_kwargs = {}
124 | other_kwargs = {}
125 |
126 | for key, value in kwargs.items():
127 | if key in inspect.signature(hf_hub_download).parameters:
128 | hf_hub_download_kwargs[key] = value
129 | elif key in list(cls.__annotations__):
130 | class_kwargs[key] = value
131 | else:
132 | other_kwargs[key] = value
133 |
134 | return hf_hub_download_kwargs, class_kwargs, other_kwargs
135 |
136 | @classmethod
137 | def _get_peft_type(cls, model_id: str, **hf_hub_download_kwargs):
138 | subfolder = hf_hub_download_kwargs.get("subfolder", None)
139 | path = os.path.join(model_id, subfolder) if subfolder is not None else model_id
140 |
141 | if os.path.isfile(os.path.join(path, CONFIG_NAME)):
142 | config_file = os.path.join(path, CONFIG_NAME)
143 | else:
144 | try:
145 | config_file = hf_hub_download(
146 | model_id,
147 | CONFIG_NAME,
148 | **hf_hub_download_kwargs,
149 | )
150 | except Exception:
151 | raise ValueError(f"Can't find '{CONFIG_NAME}' at '{model_id}'")
152 |
153 | loaded_attributes = cls.from_json_file(config_file)
154 | return loaded_attributes["peft_type"]
155 |
156 |
157 | @dataclass
158 | class PeftConfig(PeftConfigMixin):
159 | """
160 | Base configuration class to store the configuration of a PEFT model
161 | """
162 | base_model_name_or_path: Optional[str] = field(
163 | default=None, metadata={"help": "The name of the base model to use."})
164 | revision: Optional[str] = field(
165 | default=None, metadata={"help": "The specific model version to use."})
166 | peft_type: Optional[Union[str, PeftType]] = field(
167 | default=None, metadata={"help": "The type of PEFT method to use."})
168 | task_type: Optional[Union[str, TaskType]] = field(
169 | default=None, metadata={"help": "The type of task to perform."})
170 | inference_mode: bool = field(
171 | default=False, metadata={"help": "Whether to use the PEFT model in inference mode."})
172 |
--------------------------------------------------------------------------------
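A round-trip sketch for `save_pretrained`/`from_pretrained`, assuming (as in upstream PEFT) that `PeftType` is a string-valued enum so the config serializes cleanly to JSON; `src/utils/peft_types.py` is not shown in this excerpt.

```python
# Save/load round trip; from_pretrained dispatches on "peft_type" via
# PEFT_TYPE_TO_CONFIG_MAPPING even when called on the base PeftConfig class.
import tempfile

from src import AdaMoleConfig, PeftConfig

config = AdaMoleConfig(num_experts=8, target_modules=["q_proj", "v_proj"])

with tempfile.TemporaryDirectory() as save_directory:
    config.save_pretrained(save_directory)  # writes adapter_config.json (CONFIG_NAME)
    reloaded = PeftConfig.from_pretrained(save_directory)

print(type(reloaded).__name__)  # AdaMoleConfig
```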
/src/lora/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | LoRA Initialization
3 |
4 | Portions of this file are modifications based on work created and
5 | shared by the HuggingFace Inc. team and used according to terms
6 | described in the Apache License 2.0.
7 | """
8 | from .config import LoraConfig
9 | from .layer import LoraLayer, LinearLoraLayer
10 | from .model import LoraModel
11 |
12 | __all__ = ["LoraConfig", "LoraLayer", "LinearLoraLayer", "LoraModel"]
13 |
14 |
15 | def __getattr__(name):
16 | raise AttributeError(f"Module {__name__} has no attribute {name}.")
17 |
--------------------------------------------------------------------------------
/src/lora/config.py:
--------------------------------------------------------------------------------
1 | """
2 | LoRA Configuration
3 |
4 | Portions of this file are modifications based on work created and
5 | shared by the HuggingFace Inc. team and used according to terms
6 | described in the Apache License 2.0.
7 | """
8 | from dataclasses import dataclass, field
9 | from typing import List, Literal, Optional, Union
10 |
11 | from ..config import PeftConfig
12 | from ..utils.peft_types import PeftType
13 |
14 |
15 | @dataclass
16 | class LoraConfig(PeftConfig):
17 | """
18 | LoRA Configuration
19 | """
20 | lora_rank: int = field(default=8, metadata={"help": "The Lora rank for the attention dimension."})
21 | lora_alpha: int = field(default=8, metadata={"help": "The alpha parameter for Lora scaling."})
22 | lora_dropout: float = field(default=0.0, metadata={"help": "The dropout probability for Lora layers."})
23 | bias: Literal["none", "all", "lora_only"] = field(
24 | default="none", metadata={"help": "The bias type for Lora layers and can be 'none', 'all' or 'lora_only'."})
25 | target_modules: Optional[Union[List[str], str]] = field(
26 | default=None, metadata={"help": "The names of the modules to apply the adapter to."})
27 | init_lora_weights: bool = field(
28 | default=True, metadata={"help": "Whether to initialize the weights of the adapter layers."})
29 |
30 | def __post_init__(self):
31 | self.peft_type = PeftType.LORA
32 | self.target_modules = (
33 | set(self.target_modules) if isinstance(self.target_modules, list) else self.target_modules
34 | )
35 |
--------------------------------------------------------------------------------
/src/lora/layer.py:
--------------------------------------------------------------------------------
1 | """
2 | LoRA Layer
3 |
4 | Portions of this file are modifications based on work created and
5 | shared by the HuggingFace Inc. team and used according to terms
6 | described in the Apache License 2.0.
7 | """
8 | import math
9 | from abc import ABC
10 | from typing import Optional
11 |
12 | import torch
13 | import torch.nn as nn
14 | from peft.tuners.tuners_utils import BaseTunerLayer
15 |
16 |
17 | class LoraLayer(BaseTunerLayer, ABC):
18 | """
19 | LoRA Layer
20 | """
21 |
22 | def __init__(self, base_layer: nn.Module, **kwargs) -> None:
23 | self.base_layer = base_layer
24 | self.lora_rank = {}
25 | self.lora_alpha = {}
26 | self.scaling = {}
27 |
28 | self.lora_dropout = nn.ModuleDict({})
29 | self.lora_A = nn.ModuleDict({})
30 | self.lora_B = nn.ModuleDict({})
31 | self.kwargs = kwargs
32 |
33 | if isinstance(base_layer, nn.Linear):
34 | in_features, out_features = base_layer.in_features, base_layer.out_features
35 | else:
36 | raise ValueError(f"Unsupported layer type {type(base_layer)}")
37 |
38 | self.in_features = in_features
39 | self.out_features = out_features
40 |
41 | def update_layer(
42 | self, adapter_name: str, lora_rank: int, lora_alpha: int, lora_dropout: float, init_lora_weights: bool,
43 | ) -> None:
44 | """
45 | Update the layer
46 | """
47 | if lora_rank <= 0:
48 | raise ValueError(f"The rank `r` should be a positive integer value but the value passed is {lora_rank}.")
49 |
50 | self.lora_rank[adapter_name] = lora_rank
51 | self.lora_alpha[adapter_name] = lora_alpha
52 |
53 | if lora_dropout > 0.0:
54 | lora_dropout_layer = nn.Dropout(p=lora_dropout)
55 | else:
56 | lora_dropout_layer = nn.Identity()
57 |
58 | self.lora_dropout[adapter_name] = lora_dropout_layer
59 | self.lora_A[adapter_name] = nn.Linear(self.in_features, lora_rank, bias=False)
60 | self.lora_B[adapter_name] = nn.Linear(lora_rank, self.out_features, bias=False)
61 | self.scaling[adapter_name] = lora_alpha / lora_rank
62 |
63 | self.reset_parameters(adapter_name, init_lora_weights)
64 | self.set_adapter(self.active_adapters)
65 |
66 | def reset_parameters(self, adapter_name: str, init_lora_weights: bool) -> None:
67 | """
68 | Reset the parameters
69 | """
70 | if init_lora_weights is False:
71 | return
72 | elif adapter_name in self.lora_A.keys():
73 | nn.init.kaiming_uniform_(self.lora_A[adapter_name].weight, a=math.sqrt(5))
74 | nn.init.zeros_(self.lora_B[adapter_name].weight)
75 |
76 |
77 | class LinearLoraLayer(nn.Module, LoraLayer):
78 | """
79 | LoRA Implementation in a Linear Layer
80 | """
81 |
82 | def __init__(
83 | self,
84 | base_layer: nn.Module,
85 | adapter_name: str,
86 | lora_rank: int = 0,
87 | lora_alpha: int = 1,
88 | lora_dropout: float = 0.0,
89 | init_lora_weights: bool = True,
90 | **kwargs,
91 | ) -> None:
92 | super().__init__()
93 | LoraLayer.__init__(self, base_layer, **kwargs)
94 | self._active_adapter = adapter_name
95 | self.update_layer(adapter_name, lora_rank, lora_alpha, lora_dropout, init_lora_weights)
96 |
97 | def merge(self, safe_merge: bool = False, adapter_names: Optional[list[str]] = None) -> None:
98 | """
99 | Merge the active adapter weights inside the base weights
100 | """
101 | raise NotImplementedError
102 |
103 | def unmerge(self) -> None:
104 | """
105 | Unmerge all merged adapter layers from the base weights
106 | """
107 | raise NotImplementedError
108 |
109 | def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
110 | """
111 | Forward propagation
112 | """
113 | previous_dtype = x.dtype
114 | result = self.base_layer(x, *args, **kwargs)
115 |
116 | for active_adapter in self.active_adapters:
117 | if active_adapter not in self.lora_A.keys():
118 | continue
119 |
120 | lora_A = self.lora_A[active_adapter]
121 | lora_B = self.lora_B[active_adapter]
122 | dropout = self.lora_dropout[active_adapter]
123 | scaling = self.scaling[active_adapter]
124 |
125 | x = x.to(lora_A.weight.dtype)
126 | result += lora_B(lora_A(dropout(x))) * scaling
127 |
128 | result = result.to(previous_dtype)
129 | return result
130 |
--------------------------------------------------------------------------------
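A shape-level illustration of `LinearLoraLayer.forward`: the frozen base projection plus a rank-`r` update scaled by `alpha / r`, with `lora_B` zero-initialized so the adapter starts as a no-op. Numbers are illustrative only.

```python
# Standalone sketch of the LoRA forward computation mirrored by LinearLoraLayer.
import torch
import torch.nn as nn

in_features, out_features, lora_rank, lora_alpha = 16, 32, 4, 8
base = nn.Linear(in_features, out_features)             # frozen base layer
lora_A = nn.Linear(in_features, lora_rank, bias=False)  # Kaiming-initialized in reset_parameters
lora_B = nn.Linear(lora_rank, out_features, bias=False)
nn.init.zeros_(lora_B.weight)                           # zero init: adapter output starts at 0
scaling = lora_alpha / lora_rank

x = torch.randn(2, 10, in_features)                     # (batch, seq_len, hidden)
result = base(x) + lora_B(lora_A(x)) * scaling          # dropout omitted for brevity
print(result.shape)                                     # torch.Size([2, 10, 32])
```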
/src/lora/model.py:
--------------------------------------------------------------------------------
1 | """
2 | LoRA Model
3 |
4 | Portions of this file are modifications based on work created and
5 | shared by the HuggingFace Inc. team and used according to terms
6 | described in the Apache License 2.0.
7 | """
8 | from typing import Any
9 |
10 | import torch
11 | from peft import PeftConfig
12 | from peft.tuners.tuners_utils import BaseTuner, BaseTunerLayer, check_target_module_exists
13 | from peft.utils import TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING
14 | from torch import nn
15 | from transformers import PretrainedConfig
16 |
17 | from .config import LoraConfig
18 | from .layer import LoraLayer, LinearLoraLayer
19 |
20 |
21 | class LoraModel(BaseTuner):
22 | """
23 | Low-Rank Adaptation (LoRA) Model
24 | """
25 | prefix: str = "lora_"
26 |
27 | def __init__(self, model: nn.Module, config: LoraConfig, adapter_name: str = "default") -> None:
28 | """
29 | Initialize LoraModel
30 |
31 | :param model: model to be adapted
32 | :param config: configuration of the LoRA model
33 | :param adapter_name: name of the adapter
34 | """
35 | super().__init__(model, config, adapter_name)
36 |
37 | def __getattr__(self, name: str) -> Any:
38 | """
39 | Forward missing attributes to the wrapped module
40 | """
41 | try:
42 | return super().__getattr__(name)
43 | except AttributeError:
44 | return getattr(self.model, name)
45 |
46 | def _check_new_adapter_config(self, config: LoraConfig) -> None:
47 | """
48 | Check the config when a new adapter is being added
49 | """
50 | if (len(self.peft_config) > 1) and (config.bias != "none"):
51 | raise ValueError(
52 | f"{self.__class__.__name__} supports only 1 adapter with bias. "
53 | f"When using multiple adapters, set bias to 'none' for all adapters.")
54 |
55 | @staticmethod
56 | def _check_target_module_exists(lora_config: LoraConfig, key: str) -> bool:
57 | """
58 | Check if the passed module's key name matches any of the target modules in the config target module list
59 | """
60 | return check_target_module_exists(lora_config, key)
61 |
62 | @staticmethod
63 | def _prepare_adapter_config(peft_config: LoraConfig, model_config: PretrainedConfig) -> PeftConfig:
64 | """
65 | Prepare the adapter config, such as automatically inferring target modules if it is none
66 | """
67 | if peft_config.target_modules is None:
68 | if model_config["model_type"] not in TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING:
69 | raise ValueError("Please specify `target_modules` in `peft_config`.")
70 | peft_config.target_modules = set(
71 | TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING[model_config["model_type"]],
72 | )
73 | return peft_config
74 |
75 | def _create_and_replace(
76 | self, lora_config: LoraConfig, adapter_name: str,
77 | target: nn.Module, target_name: str, parent: nn.Module, **kwargs: Any,
78 | ) -> None:
79 | """
80 | Inplace replacement of the target module with the adapter layer
81 | """
82 | kwargs = {
83 | "lora_rank": lora_config.lora_rank,
84 | "lora_alpha": lora_config.lora_alpha,
85 | "lora_dropout": lora_config.lora_dropout,
86 | "init_lora_weights": lora_config.init_lora_weights,
87 | }
88 |
89 | if isinstance(target, LoraLayer):
90 | target.update_layer(adapter_name, **kwargs)
91 | else:
92 | new_module = self._create_new_module(adapter_name, target, **kwargs)
93 | self._replace_module(parent, target_name, new_module, target)
94 |
95 | @staticmethod
96 | def _create_new_module(adapter_name: str, target: nn.Module, **kwargs: Any) -> nn.Module:
97 | """
98 | Create the new LoRA module for the target module
99 | """
100 | if isinstance(target, BaseTunerLayer):
101 | target_base_layer = target.get_base_layer()
102 | else:
103 | target_base_layer = target
104 |
105 | if isinstance(target_base_layer, torch.nn.Linear):
106 | new_module = LinearLoraLayer(base_layer=target, adapter_name=adapter_name, **kwargs)
107 | else:
108 | raise ValueError(
109 | f"The target module `{target}` is not supported. "
110 | f"Currently, only the following modules are supported: `torch.nn.Linear`.")
111 |
112 | return new_module
113 |
114 | def _replace_module(self, parent: nn.Module, child_name: str, new_module: nn.Module, child: nn.Module) -> None:
115 | """
116 | Replace the module
117 | """
118 | setattr(parent, child_name, new_module)
119 |
120 | if hasattr(child, "base_layer"):
121 | child = child.base_layer
122 |
123 | if not hasattr(new_module, "base_layer"):
124 | new_module.weight = child.weight
125 | if hasattr(child, "bias"):
126 | new_module.bias = child.bias
127 |
128 | if getattr(child, "state", None) is not None:
129 | if hasattr(new_module, "base_layer"):
130 | new_module.base_layer.state = child.state
131 | else:
132 | new_module.state = child.state
133 | new_module.to(child.weight.device)
134 |
135 | for name, module in new_module.named_modules():
136 | if self.prefix in name:
137 | module.to(child.weight.device)
138 |
139 | def _mark_only_adapters_as_trainable(self, model: nn.Module) -> None:
140 | """
141 | Mark only the adapter parameters as trainable
142 | """
143 | for name, param in model.named_parameters():
144 | if self.prefix not in name:
145 | param.requires_grad = False
146 |
147 | for active_adapter in self.active_adapters:
148 | bias = self.peft_config[active_adapter].bias
149 | if bias == "none":
150 | continue
151 | elif bias == "all":
152 | for name, param in model.named_parameters():
153 | if "bias" in name:
154 | param.requires_grad = True
155 | elif bias == "lora_only":
156 | for module in model.modules():
157 | if isinstance(module, LoraLayer) and hasattr(module, "bias") and module.bias is not None:
158 | module.bias.requires_grad = True
159 | else:
160 | raise NotImplementedError(f"Requested bias: {bias}, is not implemented.")
161 |
--------------------------------------------------------------------------------
/src/mapping.py:
--------------------------------------------------------------------------------
1 | """
2 | Config and Model Mappings
3 |
4 | Portions of this file are modifications based on work created and
5 | shared by the HuggingFace Inc. team and used according to terms
6 | described in the Apache License 2.0.
7 | """
8 |
9 | from .adamole import AdaMoleConfig, AdaMoleModel
10 | from .lora import LoraConfig, LoraModel
11 | from .mole import MoleConfig, MoleModel
12 | from .utils.peft_types import PeftType
13 |
14 | PEFT_TYPE_TO_CONFIG_MAPPING = {
15 | PeftType.LORA: LoraConfig,
16 | PeftType.MOLE: MoleConfig,
17 | PeftType.ADAMOLE: AdaMoleConfig,
18 | }
19 | PEFT_TYPE_TO_MODEL_MAPPING = {
20 | PeftType.LORA: LoraModel,
21 | PeftType.MOLE: MoleModel,
22 | PeftType.ADAMOLE: AdaMoleModel,
23 | }
24 |
--------------------------------------------------------------------------------
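The config mapping above is used by `PeftConfigMixin.from_pretrained` shown earlier; the model mapping is presumably used by `PeftModel` in `src/peft_model.py`, which is not reproduced here. A minimal lookup example:

```python
# Dispatch example for the mapping tables above.
from src.mapping import PEFT_TYPE_TO_CONFIG_MAPPING, PEFT_TYPE_TO_MODEL_MAPPING
from src.utils.peft_types import PeftType

config_cls = PEFT_TYPE_TO_CONFIG_MAPPING[PeftType.ADAMOLE]
model_cls = PEFT_TYPE_TO_MODEL_MAPPING[PeftType.ADAMOLE]
print(config_cls.__name__, model_cls.__name__)  # AdaMoleConfig AdaMoleModel
```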
/src/mole/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | MoLE Initialization
3 | """
4 | from .config import MoleConfig
5 | from .layer import MoleLayer, LinearMoleLayer
6 | from .model import MoleModel
7 |
8 | __all__ = ["MoleConfig", "MoleLayer", "LinearMoleLayer", "MoleModel"]
9 |
10 |
11 | def __getattr__(name):
12 | raise AttributeError(f"Module {__name__} has no attribute {name}.")
13 |
--------------------------------------------------------------------------------
/src/mole/config.py:
--------------------------------------------------------------------------------
1 | """
2 | MoLE Configuration
3 | """
4 | from dataclasses import dataclass, field
5 |
6 | from ..lora import LoraConfig
7 | from ..utils.peft_types import PeftType
8 |
9 |
10 | @dataclass
11 | class MoleConfig(LoraConfig):
12 | """
13 | MoLE Configuration
14 | """
15 | num_experts: int = field(default=4, metadata={"help": "The number of experts in MoE."})
16 | top_k: int = field(default=None, metadata={
17 | "help": "The k in top-k gating if the expert threshold is None."})
18 | threshold: float = field(default=None, metadata={
19 | "help": "The threshold for selecting experts if the top-k is None. "
20 | "The maximum threshold should be 1 / the number of experts."})
21 |
22 | def __post_init__(self):
23 | self.peft_type = PeftType.MOLE
24 | self.target_modules = (
25 | set(self.target_modules) if isinstance(self.target_modules, list) else self.target_modules
26 | )
27 |
--------------------------------------------------------------------------------
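Note: because `update_layer` in `src/mole/layer.py` requires exactly one of `top_k` and `threshold`, a `MoleConfig` is built with one routing mode or the other. A minimal sketch, assuming it is run from the repository root (the rank, alpha, and dropout values are arbitrary):

from src import MoleConfig, TaskType

# Top-k routing: each token is dispatched to its two highest-scoring experts.
topk_config = MoleConfig(
    lora_rank=8, lora_alpha=16, lora_dropout=0.05,
    target_modules=["q_proj", "v_proj"],
    task_type=TaskType.CAUSAL_LM, bias="none",
    num_experts=4, top_k=2,
)

# Threshold routing: every expert whose gate probability reaches the threshold
# is active; with 4 experts, 1 / num_experts = 0.25 is the largest usable value.
threshold_config = MoleConfig(
    lora_rank=8, lora_alpha=16, lora_dropout=0.05,
    target_modules=["q_proj", "v_proj"],
    task_type=TaskType.CAUSAL_LM, bias="none",
    num_experts=4, threshold=0.25,
)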
/src/mole/layer.py:
--------------------------------------------------------------------------------
1 | """
2 | MoLE Layer
3 | """
4 | import math
5 | from abc import ABC
6 | from typing import Optional
7 |
8 | import torch
9 | import torch.nn as nn
10 | import torch.nn.functional as F
11 |
12 | from ..lora import LoraLayer
13 |
14 |
15 | class TopKMoeLayer(nn.Module):
16 | """
17 | Mixture of Experts (MoE) Layer with the Top-k
18 |
19 | Adapted from https://github.com/mistralai/mistral-src
20 | """
21 |
22 | def __init__(self, experts: nn.ModuleList, gate: nn.Module, top_k: int):
23 | super().__init__()
24 | self.experts = experts
25 | self.gate = gate
26 | self.top_k = top_k
27 | self.layer_loss = None
28 |
29 | def get_layer_loss(self, gate_logits: torch.Tensor, selected_experts: torch.Tensor) -> torch.Tensor:
30 | """
31 | Get the load balancing loss by following the Switch Transformer
32 | """
33 | num_inputs = gate_logits.shape[0]
34 | num_experts = len(self.experts)
35 | expert_counts = torch.bincount(selected_experts.reshape(-1), minlength=num_experts)
36 | expert_fractions = expert_counts / num_inputs
37 | expert_probs = torch.sum(gate_logits, dim=0) / num_inputs
38 | layer_loss = num_experts * torch.sum(expert_fractions * expert_probs)
39 | return layer_loss
40 |
41 | def forward(self, inputs: torch.Tensor) -> torch.Tensor:
42 | """
43 | Forward propagation
44 | """
45 | flattened_inputs = inputs.view((-1, inputs.shape[-1]))
46 | gate_logits = F.softmax(self.gate(flattened_inputs), dim=-1)
47 | weights, selected_experts = torch.topk(input=gate_logits, k=self.top_k, dim=-1)
48 | weights = weights / torch.sum(weights, dim=-1, keepdim=True, dtype=inputs.dtype)
49 | results = torch.zeros_like(self.experts[0](flattened_inputs))
50 |
51 | for i, expert in enumerate(self.experts):
52 | batch_idx, nth_expert = torch.where(selected_experts == i)
53 | results[batch_idx] += \
54 | weights[batch_idx, nth_expert, None] * expert(flattened_inputs[batch_idx])
55 |
56 | results = results.view((*inputs.shape[:-1], results.shape[-1]))
57 | if inputs.requires_grad:
58 | self.layer_loss = self.get_layer_loss(gate_logits=gate_logits, selected_experts=selected_experts)
59 | return results
60 |
61 |
62 | class ThresholdMoeLayer(nn.Module):
63 | """
64 | Mixture of Experts (MoE) Layer with the Threshold
65 | """
66 |
67 | def __init__(self, experts: nn.ModuleList, gate: nn.Module, threshold: float):
68 | super().__init__()
69 | self.experts = experts
70 | self.gate = gate
71 | self.threshold = threshold
72 | self.layer_loss = None
73 |
74 | def get_layer_loss(self, gate_logits: torch.Tensor, selected_experts: torch.Tensor) -> torch.Tensor:
75 | """
76 | Get the load balancing loss by following the Switch Transformer
77 | """
78 | num_inputs = gate_logits.shape[0]
79 | num_experts = len(self.experts)
80 | expert_counts = torch.sum(selected_experts, dim=0)
81 | expert_fractions = expert_counts / num_inputs
82 | expert_probs = torch.sum(gate_logits, dim=0) / num_inputs
83 | layer_loss = num_experts * torch.sum(expert_fractions * expert_probs)
84 | return layer_loss
85 |
86 | def forward(self, inputs: torch.Tensor) -> torch.Tensor:
87 | """
88 | Forward propagation
89 | """
90 | flattened_inputs = inputs.view((-1, inputs.shape[-1]))
91 | gate_logits = F.softmax(self.gate(flattened_inputs), dim=-1)
92 | selected_experts = torch.ge(gate_logits, self.threshold).to(torch.float)
93 | weights = gate_logits * selected_experts
94 | weight_sums = torch.sum(weights, dim=-1, keepdim=True, dtype=inputs.dtype)
95 | weight_sums = torch.where(weight_sums == 0, torch.ones_like(weight_sums), weight_sums)
96 | weights = weights / weight_sums
97 | results = torch.zeros_like(self.experts[0](flattened_inputs))
98 |
99 | for i, expert in enumerate(self.experts):
100 | batch_idx = torch.where(selected_experts[:, i])[0]
101 | if len(batch_idx) > 0:
102 | results[batch_idx] += weights[batch_idx, i, None] * expert(flattened_inputs[batch_idx])
103 |
104 | results = results.view((*inputs.shape[:-1], results.shape[-1]))
105 | if inputs.requires_grad:
106 | self.layer_loss = self.get_layer_loss(gate_logits=gate_logits, selected_experts=selected_experts)
107 | return results
108 |
109 |
110 | class LoraExpert(nn.Module):
111 | """
112 | LoRA Expert
113 | """
114 |
115 | def __init__(self, lora_A: nn.Module, lora_B: nn.Module, lora_dropout: nn.Module, scaling: float):
116 | super().__init__()
117 | self.lora_A = lora_A
118 | self.lora_B = lora_B
119 | self.lora_dropout = lora_dropout
120 | self.scaling = scaling
121 |
122 | def forward(self, inputs: torch.Tensor) -> torch.Tensor:
123 | """
124 | Forward propagation
125 | """
126 | outputs = self.lora_B(self.lora_A(self.lora_dropout(inputs))) * self.scaling
127 | return outputs
128 |
129 |
130 | class MoleLayer(LoraLayer, ABC):
131 | """
132 | MoLE Layer
133 | """
134 |
135 | def __init__(self, base_layer: nn.Module, **kwargs):
136 | super().__init__(base_layer, **kwargs)
137 | self.lora_gating = nn.ModuleDict({})
138 | self.moe_layer = nn.ModuleDict({})
139 |
140 | def update_layer(
141 | self, adapter_name: str, lora_rank: int, lora_alpha: int, lora_dropout: float, init_lora_weights: bool,
142 | num_experts: int, top_k: int, threshold: float,
143 | ) -> None:
144 | """
145 | Update the layer
146 | """
147 | if lora_rank <= 0:
148 | raise ValueError(f"The rank `r` should be a positive integer value but the value passed is {lora_rank}.")
149 |
150 | if (top_k is not None) and (threshold is not None):
151 | raise ValueError(f"Only one of the top-k {top_k} and the threshold {threshold} can be used.")
152 | elif (top_k is None) and (threshold is None):
153 | raise ValueError(f"At least one of the top-k {top_k} and the threshold {threshold} should be used.")
154 |
155 | self.lora_rank[adapter_name] = lora_rank
156 | self.lora_alpha[adapter_name] = lora_alpha
157 |
158 | if lora_dropout > 0.0:
159 | lora_dropout_layer = nn.ModuleList(nn.Dropout(p=lora_dropout) for _ in range(num_experts))
160 | else:
161 | lora_dropout_layer = nn.ModuleList(nn.Identity() for _ in range(num_experts))
162 |
163 | self.lora_dropout[adapter_name] = lora_dropout_layer
164 | self.lora_A[adapter_name] = nn.ModuleList(
165 | nn.Linear(self.in_features, lora_rank, bias=False) for _ in range(num_experts))
166 | self.lora_B[adapter_name] = nn.ModuleList(
167 | nn.Linear(lora_rank, self.out_features, bias=False) for _ in range(num_experts))
168 | self.scaling[adapter_name] = lora_alpha / lora_rank
169 | self.lora_gating[adapter_name] = nn.Linear(self.in_features, num_experts, bias=False)
170 |
171 | experts = nn.ModuleList(LoraExpert(
172 | self.lora_A[adapter_name][i],
173 | self.lora_B[adapter_name][i],
174 | self.lora_dropout[adapter_name][i],
175 | self.scaling[adapter_name],
176 | ) for i in range(num_experts))
177 |
178 | if top_k is not None:
179 | self.moe_layer[adapter_name] = TopKMoeLayer(
180 | experts=experts, gate=self.lora_gating[adapter_name], top_k=top_k)
181 | elif threshold is not None:
182 | self.moe_layer[adapter_name] = ThresholdMoeLayer(
183 | experts=experts, gate=self.lora_gating[adapter_name], threshold=threshold)
184 |
185 | self.reset_parameters(adapter_name, init_lora_weights)
186 | self.set_adapter(self.active_adapters)
187 |
188 | def reset_parameters(self, adapter_name: str, init_lora_weights: bool) -> None:
189 | """
190 | Reset the parameters
191 | """
192 | if init_lora_weights is False:
193 | return
194 | elif adapter_name in self.lora_A.keys():
195 | for i in range(len(self.lora_A[adapter_name])):
196 | nn.init.kaiming_uniform_(self.lora_A[adapter_name][i].weight, a=math.sqrt(5))
197 | nn.init.zeros_(self.lora_B[adapter_name][i].weight)
198 |
199 |
200 | class LinearMoleLayer(nn.Module, MoleLayer):
201 | """
202 | MoLE Implementation in a Linear Layer
203 | """
204 |
205 | def __init__(
206 | self,
207 | base_layer: nn.Module,
208 | adapter_name: str,
209 | lora_rank: int = 0,
210 | lora_alpha: int = 1,
211 | lora_dropout: float = 0.0,
212 | init_lora_weights: bool = True,
213 | num_experts: int = 4,
214 | top_k: int = None,
215 | threshold: float = None,
216 | **kwargs,
217 | ) -> None:
218 | super().__init__()
219 | MoleLayer.__init__(self, base_layer=base_layer, **kwargs)
220 | self._active_adapter = adapter_name
221 | self.update_layer(
222 | adapter_name, lora_rank, lora_alpha, lora_dropout, init_lora_weights, num_experts, top_k, threshold)
223 |
224 | def merge(self, safe_merge: bool = False, adapter_names: Optional[list[str]] = None) -> None:
225 | """
226 | Merge the active adapter weights inside the base weights
227 | """
228 | pass
229 |
230 | def unmerge(self) -> None:
231 | """
232 | Unmerge all merged adapter layers from the base weights
233 | """
234 | pass
235 |
236 | def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
237 | """
238 | Forward propagation
239 | """
240 | previous_dtype = x.dtype
241 | result = self.base_layer(x, *args, **kwargs)
242 |
243 | for active_adapter in self.active_adapters:
244 | if active_adapter not in self.lora_A.keys():
245 | continue
246 |
247 | moe_layer = self.moe_layer[active_adapter]
248 | x = x.to(moe_layer.experts[0].lora_A.weight.dtype)
249 | result += moe_layer(x)
250 |
251 | result = result.to(previous_dtype)
252 | return result
253 |
--------------------------------------------------------------------------------
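Note: a standalone sketch of the top-k routing implemented above, using plain `nn.Linear` experts in place of `LoraExpert` (the sizes are arbitrary; running from the repository root with the requirements installed is assumed):

import torch
import torch.nn as nn

from src.mole.layer import TopKMoeLayer

torch.manual_seed(0)
in_dim, out_dim, num_experts = 16, 16, 4

experts = nn.ModuleList(nn.Linear(in_dim, out_dim) for _ in range(num_experts))
gate = nn.Linear(in_dim, num_experts, bias=False)
moe = TopKMoeLayer(experts=experts, gate=gate, top_k=2)

# A (batch, seq_len, hidden) input; requires_grad so the load-balancing loss is recorded.
x = torch.randn(2, 5, in_dim, requires_grad=True)
y = moe(x)

print(y.shape)         # torch.Size([2, 5, 16])
print(moe.layer_loss)  # scalar auxiliary loss from get_layer_loss()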
/src/mole/model.py:
--------------------------------------------------------------------------------
1 | """
2 | MoLE Model
3 | """
4 | from typing import Any
5 |
6 | import torch
7 | from peft.tuners.tuners_utils import BaseTunerLayer
8 | from torch import nn
9 |
10 | from .config import MoleConfig
11 | from .layer import MoleLayer, LinearMoleLayer
12 | from ..lora import LoraModel
13 |
14 |
15 | class MoleModel(LoraModel):
16 | """
17 | MoLE (Mixture of LoRA Experts) Model
18 | """
19 | prefix: str = "lora_"
20 |
21 | def __init__(self, model, config, adapter_name="default") -> None:
22 | super().__init__(model, config, adapter_name)
23 |
24 | def _create_and_replace(
25 | self, mole_config: MoleConfig, adapter_name: str,
26 | target: nn.Module, target_name: str, parent: nn.Module, **kwargs: Any,
27 | ) -> None:
28 | """
29 | Inplace replacement of the target module with the adapter layer
30 | """
31 | kwargs = {
32 | "lora_rank": mole_config.lora_rank,
33 | "lora_alpha": mole_config.lora_alpha,
34 | "lora_dropout": mole_config.lora_dropout,
35 | "init_lora_weights": mole_config.init_lora_weights,
36 | "num_experts": mole_config.num_experts,
37 | "top_k": mole_config.top_k,
38 | "threshold": mole_config.threshold,
39 | }
40 |
41 | if isinstance(target, MoleLayer):
42 | target.update_layer(adapter_name, **kwargs)
43 | else:
44 | new_module = self._create_new_module(adapter_name, target, **kwargs)
45 | self._replace_module(parent, target_name, new_module, target)
46 |
47 | @staticmethod
48 | def _create_new_module(adapter_name: str, target: nn.Module, **kwargs: Any) -> nn.Module:
49 | """
50 | Create the new LoRA module for the target module
51 | """
52 | if isinstance(target, BaseTunerLayer):
53 | target_base_layer = target.get_base_layer()
54 | else:
55 | target_base_layer = target
56 |
57 | if isinstance(target_base_layer, torch.nn.Linear):
58 | new_module = LinearMoleLayer(base_layer=target, adapter_name=adapter_name, **kwargs)
59 | else:
60 | raise ValueError(
61 | f"The target module `{target}` is not supported. "
62 | f"Currently, only the following modules are supported: `torch.nn.Linear`.")
63 |
64 | return new_module
65 |
66 | def get_aux_loss(self, adapter_name="default") -> torch.Tensor:
67 | """
68 | Get the load balancing loss for the whole model
69 | """
70 | model_loss = torch.tensor(0, dtype=torch.float).to(self.model.device)
71 |
72 | for name, module in self.model.named_modules():
73 | if name.endswith('moe_layer'):
74 | layer_loss = module[adapter_name].layer_loss
75 | model_loss += layer_loss if layer_loss is not None else 0.0
76 |
77 | return model_loss
78 |
--------------------------------------------------------------------------------
/src/peft_model.py:
--------------------------------------------------------------------------------
1 | """
2 | PEFT Model
3 |
4 | Portions of this file are modifications based on work created and
5 | shared by the HuggingFace Inc. team and used according to terms
6 | described in the Apache License 2.0.
7 | """
8 | from __future__ import annotations
9 |
10 | import inspect
11 | import os
12 | from typing import Any, Dict, List, Optional, Union
13 |
14 | import torch
15 | from huggingface_hub import hf_hub_download
16 | from peft.utils import (
17 | WEIGHTS_NAME,
18 | _set_adapter,
19 | _set_trainable,
20 | infer_device,
21 | load_peft_weights,
22 | )
23 | from transformers import PreTrainedModel
24 | from transformers.utils import PushToHubMixin
25 |
26 | from .config import PeftConfig
27 | from .mapping import (
28 | PEFT_TYPE_TO_CONFIG_MAPPING,
29 | PEFT_TYPE_TO_MODEL_MAPPING,
30 | )
31 | from .utils.peft_types import TaskType
32 | from .utils.save_and_load import (
33 | get_peft_model_state_dict,
34 | set_peft_model_state_dict
35 | )
36 |
37 |
38 | class PeftModel(PushToHubMixin, torch.nn.Module):
39 | """
40 | Parameter-Efficient Fine-Tuning (PEFT) Model
41 |
42 | :ivar base_model: base transformer model used for PEFT
43 | :ivar peft_config: configuration of the PEFT model
44 | :ivar modules_to_save: list of submodule names to save when saving the model
45 | """
46 | base_model: torch.nn.Module
47 | peft_config: Dict[str, PeftConfig]
48 | modules_to_save: Optional[List[str]]
49 |
50 | def __init__(self, model: PreTrainedModel, peft_config: PeftConfig, adapter_name: str = "default") -> None:
51 | """
52 | Initialize PeftModel
53 |
54 | :param model: base transformer model used for PEFT
55 | :param peft_config: configuration of the PEFT model
56 | :param adapter_name: name of the adapter
57 | """
58 | super().__init__()
59 | self.modules_to_save = None
60 | self.active_adapter = adapter_name
61 | self.peft_type = peft_config.peft_type
62 | peft_model = PEFT_TYPE_TO_MODEL_MAPPING[peft_config.peft_type]
63 | self.base_model = peft_model(model, {adapter_name: peft_config}, adapter_name)
64 | self.set_additional_trainable_modules(peft_config, adapter_name)
65 |
66 | if getattr(model, "is_gradient_checkpointing", True):
67 | _ = self._prepare_model_for_gradient_checkpointing(model)
68 |
69 | def __getattr__(self, name: str):
70 | """
71 | Forward missing attributes to the wrapped module
72 | """
73 | try:
74 | return super().__getattr__(name)
75 | except AttributeError:
76 | return getattr(self.base_model, name)
77 |
78 | @property
79 | def peft_config(self) -> Dict[str, PeftConfig]:
80 | """
81 | Get the PEFT configuration
82 | """
83 | return self.base_model.peft_config
84 |
85 | @peft_config.setter
86 | def peft_config(self, value: Dict[str, PeftConfig]):
87 | """
88 | Set the PEFT configuration
89 | """
90 | self.base_model.peft_config = value
91 |
92 | @property
93 | def active_adapters(self) -> list[str]:
94 | """
95 | Active adapters
96 | """
97 | try:
98 | adapters = self.base_model.active_adapters
99 | except AttributeError:
100 | adapters = self.active_adapter
101 | if isinstance(adapters, str):
102 | adapters = [adapters]
103 | return adapters
104 |
105 | def _get_base_model_class(self):
106 | """
107 | Return the base model class
108 | """
109 | return self.base_model.model.__class__
110 |
111 | def get_base_model(self) -> torch.nn.Module:
112 | """
113 | Return the base model
114 | """
115 | return self.base_model.model
116 |
117 | def set_additional_trainable_modules(self, peft_config: PeftConfig, adapter_name: str) -> None:
118 | """
119 | Set additional trainable modules
120 | """
121 | if getattr(peft_config, "modules_to_save", None) is not None:
122 | if self.modules_to_save is None:
123 | self.modules_to_save = set(peft_config.modules_to_save)
124 | else:
125 | self.modules_to_save.update(peft_config.modules_to_save)
126 | _set_trainable(self, adapter_name)
127 |
128 | def _prepare_model_for_gradient_checkpointing(self, model: PreTrainedModel):
129 | """
130 | Prepares the model for gradient checkpointing if necessary
131 | """
132 | if hasattr(model, "enable_input_require_grads"):
133 | model.enable_input_require_grads()
134 | elif hasattr(model, "get_input_embeddings"):
135 | def make_inputs_require_grad(module, input, output):
136 | output.requires_grad_(True)
137 |
138 | model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)
139 | return model
140 |
141 | def get_nb_trainable_parameters(self) -> tuple[int, int]:
142 | """
143 | Return the number of trainable parameters and the number of all parameters in the model
144 | """
145 | trainable_params = 0
146 | all_param = 0
147 | for _, param in self.named_parameters():
148 | num_params = param.numel()
149 | all_param += num_params
150 | if param.requires_grad:
151 | trainable_params += num_params
152 | return trainable_params, all_param
153 |
154 | def print_trainable_parameters(self) -> None:
155 | """
156 | Prints the number of trainable parameters in the model
157 | """
158 | trainable_params, all_param = self.get_nb_trainable_parameters()
159 | print(
160 | f"trainable params: {trainable_params:,d} || all params: {all_param:,d} || "
161 | f"trainable: {trainable_params / all_param:.2%}")
162 |
163 | def save_pretrained(
164 | self,
165 | save_directory: str,
166 | selected_adapters: Optional[List[str]] = None,
167 | save_embedding_layers: Union[str, bool] = "auto",
168 | is_main_process: bool = True,
169 | **kwargs: Any,
170 | ) -> None:
171 | """
172 | Save the adapter model and the adapter configuration files to a directory, so that it can be reloaded
173 |
174 | :param save_directory: a directory where the adapter model and configuration files will be saved
175 | :param selected_adapters: a list of adapters to be saved (default to all adapters)
176 | :param save_embedding_layers: if `True`, save the embedding layers in addition to adapter weights;
177 | if `auto`, checks the common embedding layers in config's `target_modules` when available
178 | :param is_main_process: whether the process calling this is the main process or not
179 | :param kwargs: additional keyword arguments
180 | """
181 | if os.path.isfile(save_directory):
182 | raise ValueError(f"Provided path ({save_directory}) should be a directory, not a file")
183 |
184 | if selected_adapters is None:
185 | selected_adapters = list(self.peft_config.keys())
186 | else:
187 | if any(
188 | selected_adapter_name not in list(self.peft_config.keys())
189 | for selected_adapter_name in selected_adapters
190 | ):
191 | raise ValueError(
192 | f"You passed an invalid `selected_adapters` argument; the currently supported adapter names are"
193 | f" {list(self.peft_config.keys())} - got {selected_adapters}."
194 | )
195 |
196 | if is_main_process:
197 | os.makedirs(save_directory, exist_ok=True)
198 |
199 | for adapter_name in selected_adapters:
200 | peft_config = self.peft_config[adapter_name]
201 |
202 | # Save only the trainable weights
203 | output_state_dict = get_peft_model_state_dict(
204 | self,
205 | state_dict=kwargs.get("state_dict", None),
206 | adapter_name=adapter_name,
207 | save_embedding_layers=save_embedding_layers,
208 | )
209 | output_dir = os.path.join(save_directory, adapter_name) if adapter_name != "default" else save_directory
210 | os.makedirs(output_dir, exist_ok=True)
211 |
212 | if is_main_process:
213 | torch.save(output_state_dict, os.path.join(output_dir, WEIGHTS_NAME))
214 |
215 | # Save the config and change the inference mode to `True`
216 | if peft_config.base_model_name_or_path is None:
217 | peft_config.base_model_name_or_path = (
218 | self.base_model.model.__dict__.get("name_or_path", None)
219 | )
220 |
221 | inference_mode = peft_config.inference_mode
222 | peft_config.inference_mode = True
223 | if is_main_process:
224 | peft_config.save_pretrained(output_dir)
225 | peft_config.inference_mode = inference_mode
226 |
227 | @classmethod
228 | def from_pretrained(
229 | cls,
230 | model: torch.nn.Module,
231 | model_id: Union[str, os.PathLike],
232 | adapter_name: str = "default",
233 | is_trainable: bool = False,
234 | config: Optional[PeftConfig] = None,
235 | **kwargs: Any,
236 | ) -> PeftModel:
237 | """
238 | Instantiate a PEFT model from a pretrained model and loaded PEFT weights (Note that the passed `model`
239 | may be modified inplace.)
240 |
241 | :param model: the transformer model to be adapted
242 | :param model_id: the name of the PEFT configuration to use
243 | :param adapter_name: the name of the adapter to be loaded
244 | :param is_trainable: whether the adapter should be trainable or not
245 | :param config: the configuration object to use instead of an automatically loaded configuration
246 | :param kwargs: additional keyword arguments passed along to the specific PEFT configuration class
247 | :return: the PEFT model
248 | """
249 | # Load the config
250 | if config is None:
251 | config = PEFT_TYPE_TO_CONFIG_MAPPING[
252 | PeftConfig._get_peft_type(
253 | model_id,
254 | subfolder=kwargs.get("subfolder", None),
255 | revision=kwargs.get("revision", None),
256 | cache_dir=kwargs.get("cache_dir", None),
257 | use_auth_token=kwargs.get("use_auth_token", None),
258 | token=kwargs.get("token", None),
259 | )
260 | ].from_pretrained(model_id, **kwargs)
261 | elif isinstance(config, PeftConfig):
262 | config.inference_mode = not is_trainable
263 | else:
264 | raise ValueError(f"The input config must be a PeftConfig, got {config.__class__}")
265 |
266 | config.inference_mode = not is_trainable
267 | model = MODEL_TYPE_TO_PEFT_MODEL_MAPPING[config.task_type](model, config, adapter_name)
268 | model.load_adapter(model_id, adapter_name, is_trainable=is_trainable, **kwargs)
269 | return model
270 |
271 | def load_adapter(self, model_id: str, adapter_name: str, is_trainable: bool = False, **kwargs: Any):
272 | """
273 | Load a trained adapter into the model (The new adapter is not automatically set as the active adapter.
274 | Use `PeftModel.set_adapter` to set the active adapter.)
275 |
276 | :param model_id: the Hub identifier or local path of the adapter to be loaded
277 | :param adapter_name: the name under which the adapter is added
278 | :param is_trainable: whether the adapter should be trainable or not
279 | :param kwargs: additional arguments to modify the way the adapter is loaded
280 | :return:
281 | """
282 | hf_hub_download_kwargs, kwargs = self._split_kwargs(kwargs)
283 | torch_device = infer_device()
284 |
285 | if adapter_name not in self.peft_config:
286 | # Load the config
287 | peft_config = PEFT_TYPE_TO_CONFIG_MAPPING[
288 | PeftConfig._get_peft_type(
289 | model_id,
290 | **hf_hub_download_kwargs,
291 | )
292 | ].from_pretrained(
293 | model_id,
294 | **hf_hub_download_kwargs,
295 | )
296 | peft_config.inference_mode = not is_trainable
297 | self.add_adapter(adapter_name, peft_config)
298 |
299 | # Load the weights into the model
300 | adapters_weights = load_peft_weights(model_id, device=torch_device, **hf_hub_download_kwargs)
301 | load_result = set_peft_model_state_dict(self, adapters_weights, adapter_name=adapter_name)
302 |
303 | # Set model in evaluation mode to deactivate dropout modules by default
304 | if not is_trainable:
305 | self.eval()
306 |
307 | return load_result
308 |
309 | @classmethod
310 | def _split_kwargs(cls, kwargs: Dict[str, Any]):
311 | """
312 | Split keyword arguments
313 | """
314 | _kwargs_not_in_hf_hub_download_signature = ("use_auth_token",)
315 | hf_hub_download_kwargs = {}
316 | other_kwargs = {}
317 |
318 | for key, value in kwargs.items():
319 | if (
320 | key in inspect.signature(hf_hub_download).parameters
321 | or key in _kwargs_not_in_hf_hub_download_signature
322 | ):
323 | hf_hub_download_kwargs[key] = value
324 | else:
325 | other_kwargs[key] = value
326 |
327 | return hf_hub_download_kwargs, other_kwargs
328 |
329 | def add_adapter(self, adapter_name: str, peft_config: PeftConfig) -> None:
330 | """
331 | Add an adapter to the model based on the passed configuration (The new adapter is not automatically set as
332 | the active adapter. Use `PeftModel.set_adapter` to set the active adapter.)
333 |
334 | :param adapter_name: the name of the adapter to be added
335 | :param peft_config: the configuration of the adapter to be added
336 | """
337 | if peft_config.peft_type != self.peft_type:
338 | raise ValueError(
339 | f"Cannot combine adapters with different peft types. "
340 | f"Found {self.peft_type} and {peft_config.peft_type}.")
341 |
342 | try:
343 | self.peft_config[adapter_name] = peft_config
344 | self.base_model.inject_adapter(self.base_model.model, adapter_name)
345 | except Exception:
346 | if adapter_name in self.peft_config:
347 | del self.peft_config[adapter_name]
348 | raise
349 |
350 | self.set_additional_trainable_modules(peft_config, adapter_name)
351 |
352 | def set_adapter(self, adapter_name: str) -> None:
353 | """
354 | Sets the active adapter (Only one adapter can be active at a time.)
355 |
356 | :param adapter_name: the name of the adapter to be set as active
357 | """
358 | if adapter_name not in self.peft_config:
359 | raise ValueError(f"Adapter {adapter_name} not found.")
360 | self.active_adapter = adapter_name
361 | if not self.peft_config[adapter_name].is_prompt_learning:
362 | self.base_model.set_adapter(adapter_name)
363 | _set_adapter(self, adapter_name)
364 |
365 | def forward(self, *args: Any, **kwargs: Any) -> torch.Tensor:
366 | """
367 | Forward pass of the model
368 | """
369 | return self.get_base_model()(*args, **kwargs)
370 |
371 |
372 | class PeftModelForCausalLM(PeftModel):
373 | """
374 | PEFT Model for Causal Language Modeling
375 | """
376 |
377 | def __init__(self, model: PreTrainedModel, peft_config: PeftConfig, adapter_name: str = "default") -> None:
378 | """
379 | Initialize PeftModelForCausalLM
380 |
381 | :param model: base transformer model
382 | :param peft_config: PEFT configuration
383 | :param adapter_name: adapter name
384 | """
385 | super().__init__(model, peft_config, adapter_name)
386 | self.base_model_prepare_inputs_for_generation = self.base_model.prepare_inputs_for_generation
387 |
388 | def forward(
389 | self,
390 | input_ids=None,
391 | attention_mask=None,
392 | inputs_embeds=None,
393 | labels=None,
394 | output_attentions=None,
395 | output_hidden_states=None,
396 | return_dict=None,
397 | task_ids=None,
398 | **kwargs,
399 | ) -> torch.Tensor:
400 | """
401 | Forward function
402 | """
403 | return self.base_model(
404 | input_ids=input_ids,
405 | attention_mask=attention_mask,
406 | inputs_embeds=inputs_embeds,
407 | labels=labels,
408 | output_attentions=output_attentions,
409 | output_hidden_states=output_hidden_states,
410 | return_dict=return_dict,
411 | **kwargs,
412 | )
413 |
414 | def generate(self, **kwargs):
415 | """
416 | Generate the text
417 | """
418 | self.base_model.prepare_inputs_for_generation = self.prepare_inputs_for_generation
419 | if hasattr(self.base_model, "model"):
420 | self.base_model.model.generation_config = self.generation_config
421 | else:
422 | self.base_model.generation_config = self.generation_config
423 | try:
424 | outputs = self.base_model.generate(**kwargs)
425 | except:
426 | self.base_model.prepare_inputs_for_generation = self.base_model_prepare_inputs_for_generation
427 | raise
428 | else:
429 | self.base_model.prepare_inputs_for_generation = self.base_model_prepare_inputs_for_generation
430 | return outputs
431 |
432 | def prepare_inputs_for_generation(self, *args, **kwargs):
433 | """
434 | Prepare inputs for text generation
435 | """
436 | model_kwargs = self.base_model_prepare_inputs_for_generation(*args, **kwargs)
437 | return model_kwargs
438 |
439 |
440 | MODEL_TYPE_TO_PEFT_MODEL_MAPPING = {
441 | TaskType.CAUSAL_LM: PeftModelForCausalLM,
442 | }
443 |
444 |
445 | def get_peft_model(
446 | model: PreTrainedModel, peft_config: PeftConfig, adapter_name: str = "default",
447 | ) -> PeftModel:
448 | """
449 | Return a PEFT model object from a pre-trained model and a PEFT config
450 |
451 | :param model: model to be wrapped
452 | :param peft_config: configuration containing the parameters of the PEFT model
453 | :param adapter_name: name of the adapter to be injected
454 | :return:
455 | """
456 |
457 | peft_config.base_model_name_or_path = model.__dict__.get("name_or_path", None)
458 | if peft_config.task_type not in MODEL_TYPE_TO_PEFT_MODEL_MAPPING.keys():
459 | return PeftModel(model, peft_config, adapter_name=adapter_name)
460 | return MODEL_TYPE_TO_PEFT_MODEL_MAPPING[peft_config.task_type](model, peft_config, adapter_name=adapter_name)
461 |
--------------------------------------------------------------------------------
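Note: end to end, `get_peft_model` wraps a causal LM in a `PeftModelForCausalLM`, and `save_pretrained`/`from_pretrained` round-trip only the adapter weights and config. A minimal sketch, assuming it is run from the repository root; the Llama-2 checkpoint is just the repo's default and any causal LM can be substituted, and the `outputs/mole-demo` path is made up:

from transformers import AutoModelForCausalLM

from src import MoleConfig, TaskType
from src.peft_model import get_peft_model, PeftModelForCausalLM

base_model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf")

peft_config = MoleConfig(
    lora_rank=8, lora_alpha=16, lora_dropout=0.05,
    target_modules=["q_proj", "v_proj"],
    task_type=TaskType.CAUSAL_LM, bias="none",
    num_experts=4, top_k=2,
)

model = get_peft_model(base_model, peft_config)  # PeftModelForCausalLM for CAUSAL_LM
model.print_trainable_parameters()

# Only the adapter weights and the PEFT config are written to disk.
model.save_pretrained("outputs/mole-demo")
reloaded = PeftModelForCausalLM.from_pretrained(
    model=AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf"),
    model_id="outputs/mole-demo",
)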
/src/trainer.py:
--------------------------------------------------------------------------------
1 | """
2 | Trainer
3 | """
4 | import os
5 | from typing import Optional
6 |
7 | import torch
8 | from torch import nn
9 | from transformers import Trainer
10 | from transformers.trainer import TRAINING_ARGS_NAME, logger
11 |
12 | from .peft_model import PeftModel
13 |
14 |
15 | class PeftTrainer(Trainer):
16 | """
17 | Trainer for the PEFT Model
18 | """
19 |
20 | def __init__(self, aux_loss_coeff=1e-2, **kwargs):
21 | """
22 | Initialize PeftTrainer
23 |
24 | :param aux_loss_coeff: a coefficient for the load balancing loss in Mixture-of-Experts (MoE) models
25 | :param kwargs: additional keyword arguments
26 | """
27 | super().__init__(**kwargs)
28 | self.loss_fn = nn.CrossEntropyLoss()
29 | self.aux_loss_coeff = aux_loss_coeff
30 |
31 | def _save(self, output_dir: Optional[str] = None, state_dict=None):
32 | """
33 | Save the model and tokenizer
34 | """
35 | output_dir = output_dir if output_dir is not None else self.args.output_dir
36 | os.makedirs(output_dir, exist_ok=True)
37 | logger.info(f"Saving model checkpoint to {output_dir}")
38 | self.model.save_pretrained(output_dir, state_dict=state_dict)
39 |
40 | if self.tokenizer is not None:
41 | self.tokenizer.save_pretrained(output_dir)
42 |
43 | torch.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME))
44 |
45 | def compute_loss(self, model: PeftModel, inputs, return_outputs=False):
46 | """
47 | Compute the loss by the trainer
48 | """
49 | outputs = model(**inputs)
50 | if "loss" in outputs:
51 | loss = outputs.get("loss")
52 | else:
53 | logits = outputs.get("logits")
54 | labels = inputs.get("labels")
55 | loss = self.loss_fn(logits.view(-1, logits.shape[-1]), labels.view(-1))
56 | if hasattr(model, 'get_aux_loss') and self.aux_loss_coeff:  # skip if the coefficient is None or zero
57 | aux_loss = model.get_aux_loss()
58 | loss += self.aux_loss_coeff * aux_loss
59 | return (loss, outputs) if return_outputs else loss
60 |
--------------------------------------------------------------------------------
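Note: `compute_loss` adds the model-level load-balancing loss, scaled by `aux_loss_coeff`, on top of the language-modeling loss. A toy illustration of that combination with made-up tensors:

import torch
import torch.nn as nn

loss_fn = nn.CrossEntropyLoss()
aux_loss_coeff = 1e-2

# Pretend flattened (batch * seq_len, vocab) logits and labels from the model.
logits = torch.randn(8, 32)
labels = torch.randint(0, 32, (8,))
lm_loss = loss_fn(logits, labels)

# Pretend summed per-layer loss as returned by MoleModel.get_aux_loss().
aux_loss = torch.tensor(1.3)

total_loss = lm_loss + aux_loss_coeff * aux_loss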
/src/utils/peft_types.py:
--------------------------------------------------------------------------------
1 | """
2 | PEFT and Task Types
3 |
4 | Portions of this file are modifications based on work created and
5 | shared by the HuggingFace Inc. team and used according to terms
6 | described in the Apache License 2.0.
7 | """
8 | import enum
9 |
10 |
11 | class PeftType(str, enum.Enum):
12 | """
13 | PEFT Adapter Types
14 | """
15 | LORA = "LORA"
16 | MOLE = "MOLE"
17 | ADAMOLE = "ADAMOLE"
18 |
19 |
20 | class TaskType(str, enum.Enum):
21 | """
22 | PEFT Task Type
23 | """
24 | CAUSAL_LM = "CAUSAL_LM"
25 |
--------------------------------------------------------------------------------
/src/utils/save_and_load.py:
--------------------------------------------------------------------------------
1 | """
2 | Saving and Loading Models
3 |
4 | Portions of this file are modifications based on work created and
5 | shared by the HuggingFace Inc. team and used according to terms
6 | described in the Apache License 2.0.
7 | """
8 | import warnings
9 |
10 | from peft.utils.other import EMBEDDING_LAYER_NAMES
11 | from peft.utils.save_and_load import has_valid_embedding_base_layer, get_embedding_layer_name
12 |
13 | from ..utils.peft_types import PeftType
14 |
15 |
16 | def get_peft_model_state_dict(
17 | model, state_dict: dict = None, adapter_name: str = "default", unwrap_compiled: bool = False,
18 | save_embedding_layers: str | bool = "auto"
19 | ):
20 | """
21 | Get the state dict of the PEFT model
22 |
23 | :param model: the PEFT model
24 | :param state_dict: the state dict of the model (If not provided, the state dict of the passed model will be used.)
25 | :param adapter_name: the name of the adapter whose state dict should be returned
26 | :param unwrap_compiled: whether to unwrap the model if `torch.compile` was used
27 | :param save_embedding_layers: if `True`, save the embedding layers in addition to adapter weights;
28 | if `auto`, checks the common embedding layers in config's `target_modules` when available
29 | :return:
30 | """
31 | if unwrap_compiled:
32 | model = getattr(model, "_orig_mod", model)
33 |
34 | config = model.peft_config[adapter_name]
35 |
36 | if state_dict is None:
37 | state_dict = model.state_dict()
38 |
39 | if config.peft_type in (PeftType.LORA, PeftType.MOLE, PeftType.ADAMOLE):
40 | bias = config.bias
41 | if bias == "none":
42 | to_return = {k: state_dict[k] for k in state_dict if "lora_" in k}
43 | elif bias == "all":
44 | to_return = {k: state_dict[k] for k in state_dict if "lora_" in k or "bias" in k}
45 | elif bias == "lora_only":
46 | to_return = {}
47 | for k in state_dict:
48 | if "lora_" in k:
49 | to_return[k] = state_dict[k]
50 | bias_name = k.split("lora_")[0] + "bias"
51 | if bias_name in state_dict:
52 | to_return[bias_name] = state_dict[bias_name]
53 | else:
54 | raise NotImplementedError
55 | to_return = {k: v for k, v in to_return.items() if (("lora_" in k and adapter_name in k) or ("bias" in k))}
56 | else:
57 | raise NotImplementedError
58 |
59 | if getattr(model, "modules_to_save", None) is not None:
60 | for key, value in state_dict.items():
61 | if any(f"{module_name}.modules_to_save.{adapter_name}" in key for module_name in model.modules_to_save):
62 | to_return[key.replace("modules_to_save.", "")] = value
63 |
64 | # Check the common embedding layers in `target_modules` to reset `save_embedding_layers` if necessary
65 | if (
66 | save_embedding_layers == "auto"
67 | and hasattr(config, "target_modules")
68 | and any(k in config.target_modules for k in EMBEDDING_LAYER_NAMES)
69 | ):
70 | warnings.warn("Setting `save_embedding_layers` to `True` as embedding layers found in `target_modules`.")
71 | save_embedding_layers = True
72 | elif save_embedding_layers == "auto":
73 | save_embedding_layers = False
74 |
75 | if save_embedding_layers and hasattr(model, "get_input_embeddings"):
76 | for layer in [model.get_input_embeddings(), model.get_output_embeddings()]:
77 | if config.is_prompt_learning or has_valid_embedding_base_layer(layer):
78 | embedding_module_name = get_embedding_layer_name(model, layer, config.is_prompt_learning)
79 | if embedding_module_name:
80 | to_return.update({k: v for k, v in state_dict.items() if embedding_module_name in k})
81 | elif save_embedding_layers:
82 | warnings.warn("Could not identify embedding layer(s) because the model is not a model in transformers.")
83 |
84 | to_return = {k.replace(f".{adapter_name}", ""): v for k, v in to_return.items()}
85 | return to_return
86 |
87 |
88 | def set_peft_model_state_dict(model, peft_model_state_dict: dict, adapter_name="default"):
89 | """
90 | Set the state dict of the PEFT model
91 |
92 | :param model: the PEFT model.
93 | :param peft_model_state_dict: the state dict of the PEFT model
94 | :param adapter_name: the adapter name
95 | :return:
96 | """
97 | config = model.peft_config[adapter_name]
98 | state_dict = {}
99 |
100 | if getattr(model, "modules_to_save", None) is not None:
101 | for key, value in peft_model_state_dict.items():
102 | if any(module_name in key for module_name in model.modules_to_save):
103 | for module_name in model.modules_to_save:
104 | if module_name in key:
105 | key = key.replace(module_name, f"{module_name}.modules_to_save.{adapter_name}")
106 | break
107 | state_dict[key] = value
108 | else:
109 | state_dict = peft_model_state_dict
110 |
111 | if config.peft_type in (PeftType.LORA, PeftType.MOLE, PeftType.ADAMOLE):
112 | peft_model_state_dict = {}
113 | parameter_prefix = {
114 | PeftType.LORA: "lora_",
115 | PeftType.MOLE: "lora_",
116 | PeftType.ADAMOLE: "lora_",
117 | }[config.peft_type]
118 | for k, v in state_dict.items():
119 | if parameter_prefix in k:
120 | suffix = k.split(parameter_prefix)[1]
121 | if "." in suffix:
122 | suffix_to_replace = ".".join(suffix.split(".")[1:])
123 | k = k.replace(suffix_to_replace, f"{adapter_name}.{suffix_to_replace}")
124 | else:
125 | k = f"{k}.{adapter_name}"
126 | peft_model_state_dict[k] = v
127 | else:
128 | peft_model_state_dict[k] = v
129 | else:
130 | raise NotImplementedError
131 |
132 | load_result = model.load_state_dict(peft_model_state_dict, strict=False)
133 | return load_result
134 |
--------------------------------------------------------------------------------
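Note: the two functions above translate between on-disk keys (adapter name stripped) and in-model keys (adapter name inserted right after the `lora_*` component). A small illustration of that mapping on a single weight key (the module path is made up):

adapter_name = "default"

# Key as it appears in the live model's state dict:
model_key = "base_model.model.layers.0.self_attn.q_proj.lora_A.default.0.weight"

# get_peft_model_state_dict drops the adapter name before saving:
saved_key = model_key.replace(f".{adapter_name}", "")
print(saved_key)  # ...q_proj.lora_A.0.weight

# set_peft_model_state_dict re-inserts it when loading, as in the loop above:
prefix = "lora_"
suffix = saved_key.split(prefix)[1]                   # "A.0.weight"
suffix_to_replace = ".".join(suffix.split(".")[1:])   # "0.weight"
restored_key = saved_key.replace(suffix_to_replace, f"{adapter_name}.{suffix_to_replace}")
print(restored_key == model_key)  # True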
/test.py:
--------------------------------------------------------------------------------
1 | """
2 | Testing LLMs on Benchmarks
3 | """
4 | import argparse
5 | import json
6 | import os
7 | import re
8 |
9 | import pandas as pd
10 | import torch
11 | import transformers
12 | from tqdm import tqdm
13 | from transformers import (
14 | AutoTokenizer,
15 | AutoModelForCausalLM,
16 | TextGenerationPipeline,
17 | GenerationConfig,
18 | )
19 | from transformers.pipelines.pt_utils import KeyDataset
20 |
21 | from data import get_formatted_datasets
22 | from src import PeftConfig, PeftModelForCausalLM
23 |
24 | transformers.set_seed(0)
25 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
26 |
27 |
28 | def predict_choices(examples):
29 | """
30 | Predict choices
31 | """
32 | prompts = examples['text']
33 | inputs = tokenizer(prompts, return_tensors="pt", padding=True)
34 | inputs = {key: value.to(device) for key, value in inputs.items()}
35 |
36 | with torch.no_grad():
37 | outputs = model(**inputs)
38 |
39 | logits = outputs.logits[:, -1, :]
40 | choices = [chr(ord('A') + i) for i in range(max(examples['num_choices']))]
41 | choice_ids = [tokenizer.encode(choice, add_special_tokens=False)[-1] for choice in choices]
42 |
43 | predicted_ids = torch.argmax(logits[:, choice_ids], dim=-1)
44 | predictions = [choices[predicted_id] for predicted_id in predicted_ids.cpu().numpy()]
45 | examples['prediction'] = predictions
46 |
47 | return examples
48 |
49 |
50 | if __name__ == '__main__':
51 | # Add arguments
52 | parser = argparse.ArgumentParser(
53 | description='Testing LLMs on benchmark data.',
54 | formatter_class=argparse.ArgumentDefaultsHelpFormatter,
55 | fromfile_prefix_chars='@')
56 | parser.add_argument(
57 | '--model_path', type=str, default='outputs/llama-2-7b-hf-adamole-the8-commonsense-qa',
58 | help='huggingface model id or local model path')
59 | parser.add_argument(
60 | '--data_path', type=str, default='tau/commonsense_qa',
61 | help='huggingface data id or local data path')
62 | parser.add_argument(
63 | '--max_new_tokens', type=int, default=16,
64 | help='maximum number of new tokens')
65 | parser.add_argument(
66 | '--batch_size', type=int, default=16,
67 | help='batch size in the pipeline')
68 | parser.add_argument(
69 | '--logits', default=False, action='store_true',
70 | help='checking choice logits instead of generated texts')
71 |
72 | # Parse arguments
73 | args = parser.parse_args()
74 | model_path = args.model_path
75 | data_path = args.data_path
76 | model_name = os.path.basename(model_path).lower()
77 | data_name = os.path.basename(data_path).lower()
78 | max_new_tokens = args.max_new_tokens
79 | batch_size = args.batch_size
80 | if data_name in ['openbookqa', 'ai2_arc']:
81 | split = 'test'
82 | else:
83 | split = 'validation'
84 |
85 | # Load and format datasets
86 | formatted_datasets = get_formatted_datasets(data_path=data_path, prompt_only=True)
87 |
88 | # Load the configuration and model
89 | peft_config = PeftConfig.from_pretrained(model_path)
90 | base_model = AutoModelForCausalLM.from_pretrained(
91 | peft_config.base_model_name_or_path,
92 | # torch_dtype=torch.bfloat16,
93 | )
94 | tokenizer = AutoTokenizer.from_pretrained(
95 | peft_config.base_model_name_or_path,
96 | padding_side="left",
97 | )
98 | tokenizer.pad_token = tokenizer.eos_token
99 | model = PeftModelForCausalLM.from_pretrained(model=base_model, model_id=model_path)
100 | model.to(device)
101 | print(f'Model loaded from {model_path}')
102 | print(f'Model: {model}')
103 |
104 | if not args.logits:
105 | # Build the pipeline
106 | generation_config = GenerationConfig(
107 | max_new_tokens=max_new_tokens,
108 | do_sample=False,
109 | )
110 | pipeline = TextGenerationPipeline(
111 | model=model,
112 | tokenizer=tokenizer,
113 | device=device,
114 | )
115 |
116 | # Get the model responses
117 | responses = []
118 | for response in tqdm(
119 | pipeline(
120 | KeyDataset(formatted_datasets[split], 'text'),
121 | generation_config=generation_config,
122 | return_full_text=False,
123 | batch_size=batch_size,
124 | ),
125 | total=len(formatted_datasets[split]),
126 | ):
127 | responses.append(response[0]['generated_text'])
128 |
129 | # Print one response
130 | print(f'Response example:\n{responses[0]}')
131 |
132 | # Get the results
133 | df = formatted_datasets[split].to_pandas()
134 | df['response'] = responses
135 | df['prediction'] = df['response'].str.extract(pat=r'\b([A-Z])\b')[0]
136 |
137 | else:
138 | # Get predictions
139 | dataset_with_predictions = formatted_datasets[split].map(
140 | predict_choices, batched=True, batch_size=batch_size)
141 | df = dataset_with_predictions.to_pandas()
142 |
143 | # Save the results
144 | result_path = os.path.join(model_path, f'{split}_results.csv')
145 | df.to_csv(result_path, index=False)
146 | print(f'Results saved to {result_path}')
147 |
148 | # Compute evaluation metrics
149 | metrics = {}
150 | for _data_name in df['data_name'].unique():
151 | df_set = df[df['data_name'] == _data_name]
152 | accuracy = pd.Series(df_set['answer'] == df_set['prediction']).mean()
153 | print(f'Accuracy of {_data_name}: {accuracy:.2%}')
154 | metrics['accuracy_' + re.sub(r'\W', '_', _data_name)] = accuracy
155 |
156 | # Save evaluation metrics
157 | metric_path = os.path.join(model_path, f'{split}_metrics.json')
158 | with open(metric_path, 'w') as file:
159 | json.dump(metrics, file)
160 | print(f'Metrics saved to {metric_path}')
161 |
--------------------------------------------------------------------------------
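Note: with `--logits`, `predict_choices` skips generation and picks the answer letter whose token has the highest last-position logit. A toy sketch of that restricted argmax with made-up logits and token ids:

import torch

# Pretend last-token logits for two prompts over a 10-token vocabulary.
logits = torch.tensor([
    [0.1, 2.0, 0.3, 0.2, 0.9, 0.0, 0.0, 0.0, 0.0, 0.0],
    [0.1, 0.2, 0.3, 0.2, 0.9, 0.0, 0.0, 3.0, 0.0, 0.0],
])

choices = ["A", "B", "C", "D", "E"]
choice_ids = [1, 4, 7, 8, 9]  # hypothetical token ids for "A".."E"

predicted = torch.argmax(logits[:, choice_ids], dim=-1)
print([choices[i] for i in predicted.tolist()])  # ['A', 'C']

Since the parser uses `fromfile_prefix_chars='@'`, the script can also be launched with the bundled argument file, e.g. `python test.py @configs/llama_adamole_csr_test.config --logits`.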
/train.py:
--------------------------------------------------------------------------------
1 | """
2 | Fine-Tuning LLMs on Tasks
3 | """
4 | import argparse
5 | import os
6 | import re
7 |
8 | import torch
9 | import transformers
10 | from transformers import (
11 | AutoTokenizer,
12 | AutoModelForCausalLM,
13 | TrainingArguments,
14 | DataCollatorForLanguageModeling,
15 | )
16 |
17 | from data import get_formatted_datasets
18 | from src import (
19 | TaskType,
20 | LoraConfig,
21 | MoleConfig,
22 | AdaMoleConfig,
23 | PeftTrainer,
24 | PeftModelForCausalLM,
25 | )
26 |
27 | transformers.set_seed(0)
28 |
29 | if __name__ == '__main__':
30 | # Add arguments
31 | parser = argparse.ArgumentParser(
32 | description='Fine-tuning LLMs on training data.',
33 | formatter_class=argparse.ArgumentDefaultsHelpFormatter,
34 | fromfile_prefix_chars='@')
35 | parser.add_argument(
36 | '--model_path', type=str, default='meta-llama/Llama-2-7b-hf',
37 | help='huggingface model id or local model path')
38 | parser.add_argument(
39 | '--data_path', type=str, default='tau/commonsense_qa',
40 | help='huggingface data id or local data path')
41 | parser.add_argument(
42 | '--peft_type', type=str, default='lora', choices=['lora', 'mole', 'adamole'],
43 | help='peft model type to be fine-tuned')
44 | parser.add_argument(
45 | '--lora_rank', type=int, default=32,
46 | help='lora rank if the peft type is lora or total lora rank if moe')
47 | parser.add_argument(
48 | '--target_modules', type=str, default=['q_proj', 'v_proj'], nargs='+',
49 | help='target modules in lora layers')
50 | parser.add_argument(
51 | '--num_experts', type=int, default=1,
52 | help='number of experts in each moe layer')
53 | parser.add_argument(
54 | '--top_k', type=int, default=None,
55 | help='top-k experts in moe (only one of top_k or threshold can be used)')
56 | parser.add_argument(
57 | '--threshold', type=float, default=None,
58 | help='threshold for expert gating in moe (only one of top_k or threshold can be used)')
59 | parser.add_argument(
60 | '--max_length', type=int, default=256,
61 | help='maximum number of tokens')
62 | parser.add_argument(
63 | '--batch_size', type=int, default=16,
64 | help='batch size in the trainer')
65 | parser.add_argument(
66 | '--gradient_accumulation_steps', type=int, default=1,
67 | help='gradient accumulation steps')
68 | parser.add_argument(
69 | '--num_train_epochs', type=int, default=1,
70 | help='number of training epochs')
71 | parser.add_argument(
72 | '--learning_rate', type=float, default=1e-4,
73 | help='learning rate for training')
74 | parser.add_argument(
75 | '--lr_scheduler_type', type=str, default="constant_with_warmup",
76 | help='learning rate scheduler type')
77 | parser.add_argument(
78 | '--warmup_steps', type=int, default=200,
79 | help='number of warmup steps for training')
80 | parser.add_argument(
81 | '--weight_decay', type=float, default=0.0,
82 | help='weight decay')
83 | parser.add_argument(
84 | '--aux_loss_coeff', type=float, default=None,
85 | help='auxiliary loss coefficient for moe')
86 |
87 | # Parse arguments
88 | args = parser.parse_args()
89 | print(f'Arguments: {args}')
90 | model_path = args.model_path
91 | data_path = args.data_path
92 | model_name = os.path.basename(model_path).lower()
93 | data_name = os.path.basename(data_path).lower()
94 | peft_type = args.peft_type
95 | num_experts = args.num_experts
96 | max_length = args.max_length
97 | lora_rank = args.lora_rank if peft_type == 'lora' else args.lora_rank // num_experts
98 | lora_alpha = 16
99 | lora_dropout = 0.05
100 | peft_type_name = peft_type
101 | if args.top_k is not None:
102 | peft_type_name += f'-top{args.top_k}'
103 | if args.threshold is not None:
104 | threshold_name = int(1 / args.threshold)
105 | peft_type_name += f'-the{threshold_name}'
106 | output_dir = os.path.join('outputs', re.sub(r'[^0-9a-zA-Z]', '-', f'{model_name}-{peft_type_name}-{data_name}'))
107 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
108 |
109 | # Load and format datasets
110 | formatted_datasets = get_formatted_datasets(data_path=data_path, prompt_only=False)
111 |
112 | # Load the tokenizer
113 | tokenizer = AutoTokenizer.from_pretrained(
114 | model_path,
115 | padding_side="left",
116 | # add_bos_token=True,
117 | add_eos_token=True,
118 | )
119 | tokenizer.pad_token = tokenizer.eos_token
120 |
121 | # Tokenize datasets
122 | tokenize_text = lambda examples: tokenizer(
123 | examples["text"],
124 | truncation=True,
125 | max_length=max_length,
126 | # padding=True,
127 | # return_tensors="pt",
128 | )
129 | tokenized_datasets = formatted_datasets.map(
130 | tokenize_text,
131 | batched=True,
132 | remove_columns=formatted_datasets["train"].column_names,
133 | )
134 | print(f'Tokenized datasets: {tokenized_datasets}')
135 |
136 | # Set the data collator
137 | data_collator = DataCollatorForLanguageModeling(
138 | tokenizer, mlm=False, pad_to_multiple_of=8, return_tensors="pt")
139 |
140 | # Load the base model
141 | base_model = AutoModelForCausalLM.from_pretrained(
142 | model_path,
143 | # torch_dtype=torch.bfloat16,
144 | # device_map="auto",
145 | )
146 | print(f'Base model loaded from {model_path}')
147 | print(f'Base model: {base_model}')
148 |
149 | # Get the PEFT model
150 | if peft_type == 'lora':
151 | peft_config = LoraConfig(
152 | lora_rank=lora_rank,
153 | lora_alpha=lora_alpha,
154 | lora_dropout=lora_dropout,
155 | target_modules=args.target_modules,
156 | task_type=TaskType.CAUSAL_LM,
157 | bias="none",
158 | )
159 | elif peft_type == 'mole':
160 | peft_config = MoleConfig(
161 | lora_rank=lora_rank,
162 | lora_alpha=lora_alpha,
163 | lora_dropout=lora_dropout,
164 | target_modules=args.target_modules,
165 | task_type=TaskType.CAUSAL_LM,
166 | bias="none",
167 | num_experts=num_experts,
168 | top_k=args.top_k,
169 | threshold=args.threshold,
170 | )
171 | elif peft_type == 'adamole':
172 | peft_config = AdaMoleConfig(
173 | lora_rank=lora_rank,
174 | lora_alpha=lora_alpha,
175 | lora_dropout=lora_dropout,
176 | target_modules=args.target_modules,
177 | task_type=TaskType.CAUSAL_LM,
178 | bias="none",
179 | num_experts=num_experts,
180 | max_threshold=args.threshold,
181 | )
182 | else:
183 | raise KeyError(f'Unsupported PEFT type: {peft_type}')
184 |
185 | model = PeftModelForCausalLM(base_model, peft_config)
186 | model.enable_input_require_grads()
187 | model.print_trainable_parameters()
188 | print(f'Model: {model}')
189 |
190 | # Set the trainer
191 | training_args = TrainingArguments(
192 | output_dir=output_dir,
193 | overwrite_output_dir=True,
194 | group_by_length=True,
195 | remove_unused_columns=False,
196 | logging_strategy="steps",
197 | logging_steps=10,
198 | evaluation_strategy="steps",
199 | eval_steps=200,
200 | save_strategy="epoch",
201 | # save_steps=1000,
202 | optim="adamw_torch",
203 | per_device_train_batch_size=args.batch_size,
204 | per_device_eval_batch_size=args.batch_size,
205 | gradient_accumulation_steps=args.gradient_accumulation_steps,
206 | gradient_checkpointing=False,
207 | num_train_epochs=args.num_train_epochs,
208 | learning_rate=args.learning_rate,
209 | lr_scheduler_type=args.lr_scheduler_type,
210 | warmup_steps=args.warmup_steps,
211 | weight_decay=args.weight_decay,
212 | # fp16=True,
213 | seed=0,
214 | data_seed=0,
215 | report_to=["tensorboard"],
216 | )
217 | trainer = PeftTrainer(
218 | model=model,
219 | tokenizer=tokenizer,
220 | args=training_args,
221 | data_collator=data_collator,
222 | train_dataset=tokenized_datasets["train"],
223 | eval_dataset=tokenized_datasets["validation"],
224 | aux_loss_coeff=args.aux_loss_coeff,
225 | )
226 | with open(os.path.join(output_dir, 'training_args.json'), 'w') as output_file:
227 | output_file.write(training_args.to_json_string())
228 |
229 | # Train the model
230 | model.config.use_cache = False
231 | trainer.train()
232 | model.config.use_cache = True
233 |
234 | # Save the model
235 | trainer.save_model()
236 | trainer.save_state()
237 |
--------------------------------------------------------------------------------
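Note: as with `test.py`, the parser accepts an `@file` argument list, so a training run can be launched with the bundled config, e.g. `python train.py @configs/llama_adamole_csr_train.config`; flags given after the `@file` on the command line override the values read from it.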