├── .gitignore ├── LICENSE ├── Performance.png ├── README.md ├── discriminative ├── config │ ├── ada_merge.yml │ ├── average_merge.yml │ ├── dare_mask.yml │ ├── dare_mask2.yml │ ├── dare_merge.yml │ ├── dare_merge2.yml │ ├── fisher_merge.yml │ ├── glue.py │ ├── regmean_merge.yml │ ├── task_arithmetic.yml │ ├── task_arithmetic_plus.yml │ ├── task_arithmetic_search.yml │ ├── ties_merge.yml │ └── twin_merge.yml ├── data │ ├── test.json │ └── test_router.json ├── eval.py ├── merge.py ├── param.py ├── router.py ├── run.sh ├── run_merge.py ├── scripts.sh ├── sparsify.py ├── twin_merge.py └── utils.py ├── generative ├── config │ ├── ada_merge.yml │ ├── average_merge.yml │ ├── dare_mask.yml │ ├── dare_mask2.yml │ ├── dare_merge.yml │ ├── dare_merge2.yml │ ├── fisher_merge.yml │ ├── regmean_merge.yml │ ├── task_arithmetic.yml │ ├── task_arithmetic_plus.yml │ ├── task_arithmetic_search.yml │ ├── ties_merge.yml │ └── twin_merge.yml ├── data │ ├── test_data.json │ └── test_router.json ├── eval.sh ├── eval_merge.py ├── eval_scripts.sh ├── eval_twin.py ├── gen_eval_data.py ├── helm.sh ├── helm_utils │ ├── debias.py │ ├── helm_type.py │ ├── lora_utils.py │ ├── peft_save_merge.py │ └── prompter.py ├── merge.py ├── model.py ├── param.py ├── qwen_lora.json ├── qwen_task.py ├── router.py ├── run_merge.py ├── scripts.sh ├── sparsify.py └── utils.py ├── method.png └── requirements.txt /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__ 2 | *.npy 3 | *.npz 4 | *.log 5 | outs/ 6 | *.pdf 7 | *.csv 8 | *.xlsx 9 | *.md 10 | roberta 11 | qwen 12 | *.npz 13 | *.safetensors 14 | HELM-Extended-Local -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2025 Zhenyi Lu 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /Performance.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LZY-the-boys/Twin-Merging/f481c60826cdf54c70f75f879f73ec68d22429df/Performance.png -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Twin-Merging 2 | 3 | 🚩 **Our paper Twin-Merging was accepted by NeurIPS2024. 
[Paper](https://arxiv.org/pdf/2406.15479v2)** 4 | 5 | This repository provides a PyTorch implementation and checkpoints for our Twin-Merging method, introduced in our paper [Twin-Merging](https://arxiv.org/pdf/2406.15479). 6 | Twin-Merging consists of two stages: modularizing knowledge into shared and exclusive components with compression to reduce redundancy, and dynamically merging shared and task-specific knowledge based on the input. 7 | 8 | ![Twin-Merging Method](method.png) 9 | 10 | This approach significantly narrows the performance gap between merged and fine-tuned models, improving adaptability to heterogeneous data. It shows an average improvement of 28.34% in absolute normalized score for discriminative tasks and even surpasses the fine-tuned upper bound on generative tasks. 11 | 12 | ![Twin-Merging Performance](Performance.png) 13 | 14 | This repository contains: 15 | 16 | * 🪐 A simple PyTorch implementation of Twin-Merging on 12 tasks, both discriminative and generative. 17 | * ⚡️ Fine-tuned experts for RoBERTa/Qwen, and router checkpoints on the [Hugging Face Hub](https://huggingface.co/lu-vae/twin-merging-router). 18 | * 💥 A lightweight and easy-to-run merging framework supporting typical merging algorithms, with scripts for: 19 | - [Weight Average](https://arxiv.org/abs/2203.05482) 20 | - [Task-Arithmetic](https://arxiv.org/abs/2212.04089) 21 | - [Ties-Merging](https://arxiv.org/abs/2306.01708) 22 | - [Task-Arithmetic with DARE](https://arxiv.org/abs/2311.03099) 23 | - [Ties-Merging with DARE](https://arxiv.org/abs/2311.03099) 24 | - [Twin-Merging (Ours)](https://arxiv.org/pdf/2406.15479) 25 | 26 | ## Setup 27 | --- 28 | 29 | First, download and set up the repo: 30 | 31 | ```bash 32 | git clone https://github.com/LZY-the-boys/Twin-Merging 33 | cd Twin-Merging 34 | ``` 35 | 36 | We provide a requirements file to create a Conda environment. The Conda environment name `merging` is used in `generative/eval_scripts.sh`. 37 | If you change the name, update it in `generative/eval_scripts.sh` as well. 38 | 39 | ``` 40 | conda create -n merging python=3.9 41 | conda activate merging 42 | pip install -r requirements.txt 43 | ``` 44 | 45 | ## Merging for Discriminative Models: 46 | --- 47 | 48 | We offer examples for merging RoBERTa models tuned on the GLUE tasks. 49 | 50 | ### Merge 51 | 52 | For convenience, you can download a single expert for each dataset directly: 53 | ``` 54 | huggingface-cli download lu-vae/roberta-glue --local-dir roberta 55 | ``` 56 | 57 | You can find the detailed run commands in [`discriminative/scripts.sh`](discriminative/scripts.sh). 58 | To run other algorithms (e.g., Ties-Merging), simply use: 59 | ``` 60 | source scripts.sh 61 | run_tie 62 | ``` 63 | 64 | For generative tasks, run our Twin-Merging via [`generative/eval.sh`](generative/eval.sh). 65 | 66 | ### Eval 67 | 68 | The merged model is automatically evaluated using the official Hugging Face [`evaluate`](https://huggingface.co/docs/evaluate/en/index) library. 69 | The full pipeline is in `discriminative/run.sh`. To get the results of our Twin-Merging approach, run: 70 | ``` 71 | cd discriminative 72 | bash run.sh 73 | ``` 74 | 75 | Performance is calculated by the normalized score as shown in Equation (4) of our paper.
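Roughly speaking (Equation (4) of the paper gives the exact definition), a task's normalized score is the merged model's metric expressed as a percentage of the corresponding fine-tuned expert's metric, i.e. `100 * metric(merged, task) / metric(expert, task)`, so values above 100 mean the merged model surpasses that individual expert.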
Using `{seed=0;gpu=A100-sxm-80g}`, the results are as follows (note: results may vary slightly with different devices or seeds): 76 | 77 | | Merging Algorithm | cola | mnli | mrpc | qnli | qqp | rte | sst2 | stsb | 78 | |:-----|-------:|-------:|-------:|-------:|------:|------:|-------:|-------:| 79 | | Ties-Merging | 9.46 | 59.34 | 74.71 | 65.93 | 41.29 | 47.29 | 72.13 | 9.21 | 80 | | Task-Arithmetic | 6.68 | 66.23 | 78.46 | 78.62 | 72.69 | 53.43 | 83.49 | 27.1 | 81 | | Twin-Merging | 101.06 | 94.35 | 97.51 | 98.78 | 98.06 | 94.56 | 99.64 | 82.67 | 82 | 83 | The reference absolute accuracies of the fine-tuned experts: 84 | 85 | | Expert | cola | mnli | mrpc | qnli | qqp | rte | sst2 | stsb | 86 | |:-----|-------:|-------:|-------:|-------:|------:|------:|-------:|-------:| 87 | | cola | 56.52 | 34.17 | 74.8 | 47.15 | 33.42 | 47.29 | 51.38 | 5.41 | 88 | | sst2 | 9.29 | 37.56 | 51.97 | 47.95 | 44.8 | 51.62 | 94.72 | 4.37 | 89 | | mrpc | 11.1 | 35.46 | 87.99 | 60.84 | 62.01 | 47.29 | 50.46 | 57.61 | 90 | | stsb | 0 | 32.37 | 75.17 | 58.54 | 33.62 | 47.29 | 50.92 | 86.36 | 91 | | qqp | 0.28 | 43.06 | 77.21 | 62.64 | 89.71 | 46.57 | 50.92 | 52.32 | 92 | | mnli | 1.72 | 87.01 | 53.29 | 49.55 | 44.91 | 28.88 | 51.26 | -24.9 | 93 | | qnli | 18.04 | 38.76 | 74.96 | 91.71 | 34.33 | 47.29 | 55.16 | -24.07 | 94 | | rte | -2.12 | 34.27 | 74.8 | 56.04 | 50.77 | 66.43 | 51.38 | 49.3 | 95 | 96 | 97 | ## Merging for Generative Models: 98 | --- 99 | We offer examples for merging Qwen-14B on four generative tasks: MMLU, TruthfulQA, BBQ, and CNN-DailyMail. 100 | 101 | ### Merge 102 | 103 | First, download the task-specific checkpoints: 104 | ``` 105 | huggingface-cli download lu-vae/qwen-cnn-merged --local-dir qwen/qwen-cnn 106 | huggingface-cli download lu-vae/qwen-dolly --local-dir qwen/qwen-mmlu 107 | huggingface-cli download lu-vae/qwen-truthfulqa-merged --local-dir qwen/qwen-truthfulqa 108 | huggingface-cli download lu-vae/qwen-bbq-merged --local-dir qwen/qwen-bbq 109 | ``` 110 | 111 | Alternatively, you can fine-tune them using the LoRA method with the [`axolotl`](https://github.com/LZY-the-boys/axolotl/) framework. The configuration file is available [here](https://github.com/LZY-the-boys/axolotl/blob/main/examples/qwen/qlora.yml). 112 | Their fine-tuning datasets are uploaded [here](https://huggingface.co/datasets/lu-vae/natural-dataset). 113 | 114 | Then, you can run a specific merging algorithm via: 115 | ``` 116 | cd generative 117 | source scripts.sh 118 | run_task_arith 119 | ``` 120 | If using LoRA, update the `--lora` flag with your configuration JSON, as shown in [`generative/qwen_lora.json`](generative/qwen_lora.json). 121 | 122 | ### Eval 123 | 124 | We evaluate the merged model using the [`HELM` framework](https://github.com/stanford-crfm/helm), one of the largest LLM benchmarks, similar to the Hugging Face Open LLM Leaderboard.
125 | However, its environment is somewhat complex to install and has problems displaying the results, so we recommend using [our enhanced version](https://github.com/LZY-the-boys/HELM-Extended-Local) for a smoother experience: 126 | ``` 127 | cd generative 128 | git clone --single-branch --branch dev https://github.com/LZY-the-boys/HELM-Extended-Local 129 | conda create -n crfm-helm python=3.8 130 | conda activate crfm-helm 131 | pip install -r HELM-Extended-Local/requirements.txt 132 | pip install summ-eval jieba bert-score 133 | ``` 134 | 135 | After installing HELM, you can run our evaluation scripts to get the merged results: 136 | ``` 137 | cd generative 138 | bash eval.sh 139 | ``` 140 | Results will be saved in `generative/HELM-Extended-Local/outs`. 141 | 142 | You should be able to reproduce similar performance as follows: 143 | 144 | | Model/adapter | BBQ - EM | CNN/DailyMail - ROUGE-2 | MMLU - EM | TruthfulQA - EM | 145 | |:-------------------|-----------:|--------------------------:|------------:|------------------:| 146 | | Twin-Merging | 90.7268 | 19.9269 | 68.2704 | 53.3835 | 147 | 148 | 149 | ## BibTeX 150 | 151 | ```bibtex 152 | @article{Lu2024TwinMerging, 153 | title={Twin-Merging: Dynamic Integration of Modular Expertise in Model Merging}, 154 | author={Zhenyi Lu and Chenghao Fan and Wei Wei and Xiaoye Qu and Dangyang Chen and Yu Cheng}, 155 | year={2024}, 156 | eprint={2406.15479}, 157 | archivePrefix={arXiv}, 158 | primaryClass={cs.CL}, 159 | url={https://arxiv.org/abs/2406.15479}, 160 | } 161 | ``` 162 | 163 | ## Acknowledgments 164 | 165 | We would like to thank the Shanghai AI Laboratory for providing facilities that were crucial to the completion of this work. 166 | 167 | -------------------------------------------------------------------------------- /discriminative/config/ada_merge.yml: -------------------------------------------------------------------------------- 1 | merge_method: ada_merge 2 | base_model: roberta-base 3 | models_to_merge: auto 4 | scaling: auto 5 | models_name: auto 6 | exclude_param: auto 7 | model_loader: auto 8 | dtype: auto 9 | ada_type: auto -------------------------------------------------------------------------------- /discriminative/config/average_merge.yml: -------------------------------------------------------------------------------- 1 | merge_method: average_merging 2 | models_to_merge: auto 3 | models_name: auto 4 | exclude_param: auto 5 | model_loader: auto 6 | dtype: auto -------------------------------------------------------------------------------- /discriminative/config/dare_mask.yml: -------------------------------------------------------------------------------- 1 | merge_method: dare_mask 2 | base_model: meta-llama/Llama-2-13b-hf 3 | finetuned_model: WizardLM/WizardMath-13B-V1.0 4 | dtype: float16 5 | rescale: true 6 | mask_rate: 0.7 7 | mask_strategy: bernoulli 8 | weight_format: delta 9 | scaling: 1.0 -------------------------------------------------------------------------------- /discriminative/config/dare_mask2.yml: -------------------------------------------------------------------------------- 1 | merge_method: dare_mask 2 | base_model: meta-llama/Llama-2-13b-hf 3 | finetuned_model: WizardLM/WizardLM-13B-V1.2 4 | dtype: float16 5 | rescale: true 6 | mask_rate: 0.7 7 | mask_strategy: bernoulli 8 | weight_format: delta 9 | scaling: 1.0 -------------------------------------------------------------------------------- /discriminative/config/dare_merge.yml: -------------------------------------------------------------------------------- 1 |
merge_method: dare_merge 2 | base_model: roberta-base 3 | models_to_merge: auto 4 | models_name: auto 5 | exclude_param: auto 6 | model_loader: auto 7 | dtype: auto 8 | 9 | # for dare_mask 10 | rescale: true 11 | mask_rate: 0.7 12 | mask_strategy: bernoulli 13 | weight_format: delta 14 | mask_scale: 1.0 15 | 16 | # for merge 17 | second_merge_method: task_arithmetic 18 | second_merge_config: 19 | scaling: 0.7 20 | -------------------------------------------------------------------------------- /discriminative/config/dare_merge2.yml: -------------------------------------------------------------------------------- 1 | merge_method: dare_merge 2 | base_model: roberta-base 3 | models_to_merge: auto 4 | models_name: auto 5 | exclude_param: auto 6 | model_loader: auto 7 | dtype: auto 8 | 9 | # for dare_mask 10 | rescale: true 11 | mask_rate: 0.7 12 | mask_strategy: bernoulli 13 | weight_format: delta 14 | mask_scale: 1.0 15 | 16 | # for merge 17 | second_merge_method: ties_merge 18 | second_merge_config: 19 | mask_rate: 0.7 20 | scaling: 0.9 21 | -------------------------------------------------------------------------------- /discriminative/config/fisher_merge.yml: -------------------------------------------------------------------------------- 1 | merge_method: task_arithmetic 2 | base_model: roberta-base 3 | models_to_merge: auto 4 | scaling: auto 5 | models_name: auto 6 | exclude_param: auto 7 | model_loader: auto 8 | dtype: auto -------------------------------------------------------------------------------- /discriminative/config/glue.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 The HuggingFace Evaluate Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """ GLUE benchmark metric. """ 15 | 16 | import datasets 17 | from scipy.stats import pearsonr, spearmanr 18 | from sklearn.metrics import f1_score, matthews_corrcoef 19 | 20 | import evaluate 21 | 22 | 23 | _CITATION = """\ 24 | @inproceedings{wang2019glue, 25 | title={{GLUE}: A Multi-Task Benchmark and Analysis Platform for Natural Language Understanding}, 26 | author={Wang, Alex and Singh, Amanpreet and Michael, Julian and Hill, Felix and Levy, Omer and Bowman, Samuel R.}, 27 | note={In the Proceedings of ICLR.}, 28 | year={2019} 29 | } 30 | """ 31 | 32 | _DESCRIPTION = """\ 33 | GLUE, the General Language Understanding Evaluation benchmark 34 | (https://gluebenchmark.com/) is a collection of resources for training, 35 | evaluating, and analyzing natural language understanding systems. 36 | """ 37 | 38 | _KWARGS_DESCRIPTION = """ 39 | Compute GLUE evaluation metric associated to each GLUE dataset. 40 | Args: 41 | predictions: list of predictions to score. 42 | Each translation should be tokenized into a list of tokens. 43 | references: list of lists of references for each translation. 44 | Each reference should be tokenized into a list of tokens. 
45 | Returns: depending on the GLUE subset, one or several of: 46 | "accuracy": Accuracy 47 | "f1": F1 score 48 | "pearson": Pearson Correlation 49 | "spearmanr": Spearman Correlation 50 | "matthews_correlation": Matthew Correlation 51 | Examples: 52 | 53 | >>> glue_metric = evaluate.load('glue', 'sst2') # 'sst2' or any of ["mnli", "mnli_mismatched", "mnli_matched", "qnli", "rte", "wnli", "hans"] 54 | >>> references = [0, 1] 55 | >>> predictions = [0, 1] 56 | >>> results = glue_metric.compute(predictions=predictions, references=references) 57 | >>> print(results) 58 | {'accuracy': 1.0} 59 | 60 | >>> glue_metric = evaluate.load('glue', 'mrpc') # 'mrpc' or 'qqp' 61 | >>> references = [0, 1] 62 | >>> predictions = [0, 1] 63 | >>> results = glue_metric.compute(predictions=predictions, references=references) 64 | >>> print(results) 65 | {'accuracy': 1.0, 'f1': 1.0} 66 | 67 | >>> glue_metric = evaluate.load('glue', 'stsb') 68 | >>> references = [0., 1., 2., 3., 4., 5.] 69 | >>> predictions = [0., 1., 2., 3., 4., 5.] 70 | >>> results = glue_metric.compute(predictions=predictions, references=references) 71 | >>> print({"pearson": round(results["pearson"], 2), "spearmanr": round(results["spearmanr"], 2)}) 72 | {'pearson': 1.0, 'spearmanr': 1.0} 73 | 74 | >>> glue_metric = evaluate.load('glue', 'cola') 75 | >>> references = [0, 1] 76 | >>> predictions = [0, 1] 77 | >>> results = glue_metric.compute(predictions=predictions, references=references) 78 | >>> print(results) 79 | {'matthews_correlation': 1.0} 80 | """ 81 | 82 | 83 | def simple_accuracy(preds, labels): 84 | return float((preds == labels).mean()) 85 | 86 | 87 | def acc_and_f1(preds, labels): 88 | acc = simple_accuracy(preds, labels) 89 | f1 = float(f1_score(y_true=labels, y_pred=preds)) 90 | return { 91 | "accuracy": acc, 92 | "f1": f1, 93 | } 94 | 95 | 96 | def pearson_and_spearman(preds, labels): 97 | pearson_corr = float(pearsonr(preds, labels)[0]) 98 | spearman_corr = float(spearmanr(preds, labels)[0]) 99 | return { 100 | "pearson": pearson_corr, 101 | "spearmanr": spearman_corr, 102 | } 103 | 104 | 105 | @evaluate.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION) 106 | class Glue(evaluate.Metric): 107 | def _info(self): 108 | if self.config_name not in [ 109 | "sst2", 110 | "mnli", 111 | "mnli_mismatched", 112 | "mnli_matched", 113 | "cola", 114 | "stsb", 115 | "mrpc", 116 | "qqp", 117 | "qnli", 118 | "rte", 119 | "wnli", 120 | "hans", 121 | ]: 122 | raise KeyError( 123 | "You should supply a configuration name selected in " 124 | '["sst2", "mnli", "mnli_mismatched", "mnli_matched", ' 125 | '"cola", "stsb", "mrpc", "qqp", "qnli", "rte", "wnli", "hans"]' 126 | ) 127 | return evaluate.MetricInfo( 128 | description=_DESCRIPTION, 129 | citation=_CITATION, 130 | inputs_description=_KWARGS_DESCRIPTION, 131 | features=datasets.Features( 132 | { 133 | "predictions": datasets.Value("int64" if self.config_name != "stsb" else "float32"), 134 | "references": datasets.Value("int64" if self.config_name != "stsb" else "float32"), 135 | } 136 | ), 137 | codebase_urls=[], 138 | reference_urls=[], 139 | format="numpy", 140 | ) 141 | 142 | def _compute(self, predictions, references): 143 | if self.config_name == "cola": 144 | return {"matthews_correlation": matthews_corrcoef(references, predictions)} 145 | elif self.config_name == "stsb": 146 | return pearson_and_spearman(predictions, references) 147 | elif self.config_name in ["mrpc", "qqp"]: 148 | return acc_and_f1(predictions, references) 149 | elif self.config_name in ["sst2", 
"mnli", "mnli_mismatched", "mnli_matched", "qnli", "rte", "wnli", "hans"]: 150 | return {"accuracy": simple_accuracy(predictions, references)} 151 | else: 152 | raise KeyError( 153 | "You should supply a configuration name selected in " 154 | '["sst2", "mnli", "mnli_mismatched", "mnli_matched", ' 155 | '"cola", "stsb", "mrpc", "qqp", "qnli", "rte", "wnli", "hans"]' 156 | ) 157 | -------------------------------------------------------------------------------- /discriminative/config/regmean_merge.yml: -------------------------------------------------------------------------------- 1 | merge_method: task_arithmetic 2 | base_model: roberta-base 3 | models_to_merge: auto 4 | scaling: auto 5 | models_name: auto 6 | exclude_param: auto 7 | model_loader: auto 8 | dtype: auto -------------------------------------------------------------------------------- /discriminative/config/task_arithmetic.yml: -------------------------------------------------------------------------------- 1 | merge_method: task_arithmetic 2 | base_model: roberta-base 3 | models_to_merge: auto 4 | scaling: 0.7 5 | models_name: auto 6 | exclude_param: auto 7 | model_loader: auto 8 | dtype: auto -------------------------------------------------------------------------------- /discriminative/config/task_arithmetic_plus.yml: -------------------------------------------------------------------------------- 1 | merge_method: task_arithmetic_plus 2 | base_model: roberta-base 3 | models_to_merge: auto 4 | scaling: auto 5 | models_name: auto 6 | exclude_param: auto 7 | model_loader: auto 8 | dtype: auto -------------------------------------------------------------------------------- /discriminative/config/task_arithmetic_search.yml: -------------------------------------------------------------------------------- 1 | merge_method: task_arithmetic_search 2 | base_model: roberta-base 3 | models_to_merge: auto 4 | models_name: auto 5 | exclude_param: auto 6 | model_loader: auto 7 | dtype: auto -------------------------------------------------------------------------------- /discriminative/config/ties_merge.yml: -------------------------------------------------------------------------------- 1 | merge_method: ties_merge 2 | base_model: roberta-base 3 | models_to_merge: auto 4 | mask_rate: 0.7 5 | scaling: 1.0 6 | models_name: auto 7 | exclude_param: auto 8 | model_loader: auto 9 | dtype: auto -------------------------------------------------------------------------------- /discriminative/config/twin_merge.yml: -------------------------------------------------------------------------------- 1 | merge_method: twin_merge 2 | base_model: roberta-base 3 | models_to_merge: auto 4 | models_name: auto 5 | exclude_param: auto 6 | model_loader: auto 7 | dtype: auto 8 | 9 | # for dare_mask 10 | # rescale: true 11 | # mask_rate: auto 12 | # mask_strategy: bernoulli 13 | # weight_format: delta 14 | # mask_scale: 0.7 15 | 16 | # for merge 17 | second_merge_method: task_arithmetic 18 | second_merge_config: 19 | scaling: 0.29 -------------------------------------------------------------------------------- /discriminative/eval.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.nn.functional as F 3 | from transformers import Trainer 4 | import os 5 | import numpy as np 6 | import evaluate 7 | import datasets 8 | from functools import partial 9 | from transformers import AutoModelForSequenceClassification, AutoTokenizer, TrainingArguments 10 | import types 11 | import pandas as pd 12 
| import torch 13 | 14 | glue_data_keys_map = { 15 | "cola": ("sentence", None), 16 | "sst2": ("sentence", None), 17 | "mrpc": ("sentence1", "sentence2"), 18 | "stsb": ("sentence1", "sentence2"), 19 | "qqp": ("question1", "question2"), 20 | "mnli": ("premise", "hypothesis"), 21 | "qnli": ("question", "sentence"), 22 | "rte": ("sentence1", "sentence2") 23 | } 24 | 25 | glue_data_metrics_map = { 26 | "cola": "matthews_correlation", 27 | "sst2": "accuracy", 28 | "mrpc": "averaged_scores", # average of accuracy and f1 29 | "stsb": "averaged_scores", # average of pearson and spearmanr 30 | "qqp": "averaged_scores", # average of accuracy and f1 31 | "mnli": "accuracy", 32 | "qnli": "accuracy", 33 | "rte": "accuracy" 34 | } 35 | 36 | glue_data_num_labels_map = { 37 | "cola": 2, 38 | "sst2": 2, 39 | "mrpc": 2, 40 | "stsb": 1, 41 | "qqp": 2, 42 | "mnli": 3, 43 | "qnli": 2, 44 | "rte": 2 45 | } 46 | 47 | glue_data_id_map = { 48 | "cola": 0, 49 | "sst2": 1, 50 | "mrpc": 2, 51 | "stsb": 3, 52 | "qqp": 4, 53 | "mnli": 5, 54 | "qnli": 6, 55 | "rte": 7 56 | } 57 | rev_glue_data_id_map = {value: key for key, value in glue_data_id_map.items()} 58 | 59 | model_path_template='../roberta/{name}/roberta-base_lr1e-05' 60 | head_path_template='../roberta/{name}/roberta-base_lr1e-05/classifier_head.pt' 61 | 62 | class CustomizedTrainer(Trainer): 63 | 64 | def __init__(self, use_multitask_setting: bool = False, *args, **kwargs): 65 | super(CustomizedTrainer, self).__init__(*args, **kwargs) 66 | self.use_multitask_setting = use_multitask_setting 67 | 68 | def compute_loss(self, model: nn.Module, inputs: dict, return_outputs: bool = False): 69 | 70 | if self.use_multitask_setting: 71 | return self.compute_multi_loss(model, inputs, return_outputs) 72 | 73 | assert "labels" in inputs, "labels are not involved in inputs!" 74 | labels = inputs.pop("labels") 75 | outputs = model(**inputs) 76 | logits = outputs["logits"] 77 | if logits.shape[1] > 1: 78 | # cross-entropy loss for classification 79 | loss = F.cross_entropy(input=logits, target=labels.long()) 80 | else: 81 | # mse loss for regression 82 | assert logits.shape[1] == 1, "wrong number of labels!" 83 | loss = F.mse_loss(input=logits.squeeze(dim=1), target=labels) 84 | return (loss, outputs) if return_outputs else loss 85 | 86 | def compute_multi_loss(self, model, inputs, return_outputs): 87 | 88 | assert "labels" in inputs, "labels are not involved in inputs!" 89 | labels = inputs.pop("labels") 90 | assert "dataset_ids" in inputs.keys(), "key dataset_ids is missing in the inputs!" 91 | dataset_ids = inputs["dataset_ids"] 92 | outputs = model(**inputs) 93 | logits = outputs["logits"] 94 | total_loss = None 95 | for dataset_id in dataset_ids.unique(): 96 | single_dataset_indices = dataset_ids == dataset_id 97 | single_dataset_num_labels = glue_data_num_labels_map[rev_glue_data_id_map[ 98 | dataset_id.item()]] 99 | # cross-entropy loss for classification 100 | if single_dataset_num_labels > 1: 101 | loss = F.cross_entropy( 102 | input=logits[single_dataset_indices][:, :single_dataset_num_labels], 103 | target=labels[single_dataset_indices].long() 104 | ) 105 | # mse loss for regression 106 | else: 107 | assert single_dataset_num_labels == 1, "wrong number of labels!" 
108 | loss = F.mse_loss( 109 | input=logits[single_dataset_indices][:, 0], 110 | target=labels[single_dataset_indices] 111 | ) 112 | if total_loss is None: 113 | total_loss = loss 114 | else: 115 | total_loss += loss 116 | return (total_loss, outputs) if return_outputs else total_loss 117 | 118 | def compute_single_metrics(eval_pred, dataset_name): 119 | 120 | def extra_labels(eval_pred): 121 | 122 | if eval_pred.predictions.shape[1] > 1: 123 | return np.argmax(eval_pred.predictions, axis=1) 124 | else: 125 | return eval_pred.predictions.squeeze(axis=1) 126 | 127 | predictions = extra_labels(eval_pred) 128 | metric_func = evaluate.load(path="config/glue.py", config_name=dataset_name) 129 | result = metric_func.compute( 130 | predictions=predictions, 131 | references=eval_pred.label_ids 132 | ) 133 | # 如: acc 和 f1 相平均 134 | if len(result.keys()) > 1: 135 | result["averaged_scores"] = np.mean(list(result.values())).item() 136 | else: 137 | result["averaged_scores"] = list(result.values())[0] 138 | return result 139 | 140 | def compute_multi_metrics(eval_pred): 141 | 142 | def generate_predictions_and_labels(indices, num_labels, eval_pred): 143 | if num_labels > 1: 144 | predictions = np.argmax(eval_pred.predictions[indices][:, :num_labels], axis=1) 145 | labels = eval_pred.label_ids[1][indices].astype(np.longlong) 146 | else: 147 | predictions = eval_pred.predictions[indices][:, 0] 148 | labels = eval_pred.label_ids[1][indices] 149 | return predictions, labels 150 | 151 | def add_averaged_scores(result): 152 | # 如: acc 和 f1 相平均 153 | if len(result.keys()) > 1: 154 | result["averaged_scores"] = np.mean(list(result.values())).item() 155 | 156 | results = [] 157 | dataset_ids = eval_pred.label_ids[0] 158 | for dataset_id in np.unique(dataset_ids): 159 | indices = dataset_ids == dataset_id 160 | num_labels = glue_data_num_labels_map[rev_glue_data_id_map[dataset_id.item()]] 161 | predictions, labels = generate_predictions_and_labels(indices, num_labels, eval_pred) # is want to simplify this into a function 162 | metric_func = evaluate.load(path="glue", config_name=rev_glue_data_id_map[dataset_id.item()]) 163 | result = metric_func.compute(predictions=predictions, references=labels) 164 | add_averaged_scores(result) 165 | result["name"] = rev_glue_data_id_map[dataset_id.item()] 166 | results.append(result) 167 | 168 | dataset_scores = [ 169 | result[glue_data_metrics_map[result["name"]]] 170 | for result in results 171 | ] 172 | return {"averaged_scores": np.mean(dataset_scores).item(), "all_results": results} 173 | 174 | def load_glue_classifier(name, dtype, save_classifier_head=True): 175 | model_path = model_path_template.format(name=name) 176 | model = AutoModelForSequenceClassification.from_pretrained( 177 | model_path, torch_dtype=dtype, device_map="cpu" 178 | ) 179 | tokenizer = AutoTokenizer.from_pretrained(model_path) 180 | if save_classifier_head: 181 | if not os.path.exists(f'{model_path}'): 182 | print(f' >>> skip save classifier head for {model_path}') 183 | return model 184 | 185 | if os.path.exists(f'{model_path}/classifier_head.pt'): 186 | print(f' >>> skip save classifier head for {model_path}') 187 | return model 188 | 189 | print(f' >>> save classifier head for {model_path} in {model_path}/classifier_head.pt ') 190 | torch.save(model.classifier, f'{model_path}/classifier_head.pt') 191 | return model, tokenizer 192 | 193 | def load_glue_dataset(tokenizer, dataset_name, split='train'): 194 | if split != 'train': 195 | split = "validation_matched" if dataset_name == "mnli" else 
"validation" 196 | test_dataset = datasets.load_dataset( 197 | path=os.path.join("glue"), 198 | name=dataset_name, 199 | split=split, 200 | ) 201 | sentence1_key, sentence2_key = glue_data_keys_map[dataset_name] 202 | test_dataset = test_dataset.map( 203 | lambda examples: tokenizer( 204 | text=examples[sentence1_key], 205 | text_pair=examples[sentence2_key] if sentence2_key else None, 206 | max_length=128, 207 | truncation=True 208 | ), 209 | batched=True 210 | ) 211 | test_dataset = test_dataset.map( 212 | lambda x: {"dataset_ids": glue_data_id_map[dataset_name]} 213 | ) 214 | return test_dataset 215 | 216 | def eval_glue(tokenizer, model, dataset_name, output_path): 217 | 218 | # num_labels = glue_data_num_labels_map[dataset_name] 219 | test_dataset = load_glue_dataset(tokenizer, dataset_name, split='test') 220 | evaluator = CustomizedTrainer( 221 | model=model, 222 | args=TrainingArguments( 223 | output_dir=output_path, 224 | per_device_train_batch_size=16, 225 | per_device_eval_batch_size=128, 226 | report_to=[], # disable wandb 227 | ), 228 | eval_dataset=test_dataset, 229 | compute_metrics=partial(compute_single_metrics,dataset_name=dataset_name), 230 | tokenizer=tokenizer, 231 | ) 232 | 233 | test_metrics = evaluator.evaluate() 234 | test_metrics = { 235 | k: 100*float(f"{v:.4f}") if isinstance(v, float) else v 236 | for k, v in test_metrics.items() 237 | } 238 | print(f"test performance on dataset {dataset_name}: {test_metrics[f'eval_{glue_data_metrics_map[dataset_name]}']}") 239 | 240 | return test_metrics 241 | 242 | def run_eval_glue( 243 | *, 244 | model: str ='roberta-base', 245 | datasets: list[str] =["cola", "sst2", "mrpc", "stsb", "qqp", "mnli", "qnli","rte"], 246 | outdir: str ='debug/test', 247 | ): 248 | import inspect,types 249 | frame = inspect.currentframe() 250 | keys, _, _, args = inspect.getargvalues(frame) 251 | values = { k: args[k] for k in keys } 252 | args = types.SimpleNamespace( 253 | **values 254 | ) 255 | 256 | model = AutoModelForSequenceClassification.from_pretrained(args.model).to('cuda') 257 | tokenizer = AutoTokenizer.from_pretrained(args.model) 258 | 259 | # TODO: set model.classifier 260 | metrics = {"model": args.model} 261 | for dataset in args.datasets: 262 | if model.num_labels != glue_data_num_labels_map[dataset]: 263 | print(f' >>> num labels {model.num_labels} is not Compatible for {dataset}, skipping') 264 | continue 265 | test_metrics = eval_glue(tokenizer, model, dataset, args.outdir) 266 | metrics[dataset] = test_metrics[f'eval_{glue_data_metrics_map[dataset]}'] 267 | save_excel(metrics, args.outdir) 268 | 269 | 270 | if __name__ == '__main__': 271 | import defopt 272 | defopt.run(run_eval_glue) -------------------------------------------------------------------------------- /discriminative/merge.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from collections import defaultdict, OrderedDict 3 | import tqdm 4 | import re 5 | import torch.nn as nn 6 | import copy 7 | import sparsify 8 | import utils 9 | import sys 10 | import transformers 11 | from transformers import AutoModelForCausalLM, AutoModelForSeq2SeqLM, AutoModelForSequenceClassification, AutoTokenizer 12 | import os 13 | import functools 14 | from collections import defaultdict, OrderedDict 15 | from param import param 16 | 17 | class MergingMethod: 18 | 19 | @utils.args_inspector 20 | def __init__( 21 | self, 22 | models_to_merge, 23 | models_name, 24 | ): 25 | self.models_name = {n:i for i,n in enumerate(models_name)} 
26 | # dict(zip(models_name, range(0, N))) 27 | self.models_to_merge = models_to_merge 28 | 29 | def get_model(self, model_name): 30 | return self.models_to_merge[self.models_name[model_name]] 31 | 32 | @utils.args_inspector 33 | @torch.inference_mode() 34 | def average_merging( 35 | self, 36 | ): 37 | 38 | merged_param = param.vectorize_reduce( 39 | lambda x: torch.stack(x).mean(dim=0), 40 | self.models_to_merge 41 | ) 42 | return merged_param 43 | 44 | @utils.args_inspector 45 | def fisher_merge( 46 | self, 47 | models_to_merge: list, 48 | data_names: list, 49 | data_nums: list, 50 | fish_scaling: list = None, 51 | norm_fish_weight: bool = True, 52 | min_fish_weight: float = 1e-6 53 | ): 54 | from merger.fisher_merge import FisherMerge 55 | merger = FisherMerge( 56 | models_to_merge, 57 | data_names, data_nums, 58 | fish_scaling, norm_fish_weight,min_fish_weight 59 | ) 60 | return merger.merge() 61 | 62 | @utils.args_inspector 63 | @torch.inference_mode() 64 | def regmean_merge( 65 | self, 66 | models_to_merge: list, 67 | data_names: list, 68 | data_nums: list, 69 | reduce_non_diagonal_ratio: float = 1.0 70 | ): 71 | 72 | from merger.regmean_merge import RegMeanMerge 73 | merger = RegMeanMerge( 74 | models_to_merge, 75 | data_names, data_nums, 76 | reduce_non_diagonal_ratio, 77 | ) 78 | return merger.merge() 79 | 80 | @utils.args_inspector 81 | @torch.inference_mode() 82 | def ties_merge( 83 | self, 84 | base_model: nn.Module, 85 | models_to_merge: list, 86 | mask_rate: float = 0.8, 87 | scaling: float = 1.0, 88 | ): 89 | 90 | def disjoint_merge( 91 | tensor: torch.Tensor, # (n_model, n_para) 92 | merge_func:str = 'mean', 93 | ): 94 | 95 | sign = torch.sign(tensor.sum(dim=0)) # (num_total_params, ) 96 | majority_sign = torch.sign(sign.sum(dim=0)) 97 | # replace 0 in sign to the major sign in param_signs 98 | sign[sign == 0] = majority_sign 99 | del majority_sign 100 | 101 | # preserve the parameter with the expect sign 102 | mask = torch.where( 103 | sign.unsqueeze(0) > 0, tensor > 0, tensor < 0 104 | ) 105 | tensor = tensor * mask 106 | 107 | # (n_model, n_para) -> (n_para,) 108 | if merge_func == "mean": 109 | num_ = (tensor != 0).sum(dim=0).float() 110 | # min=1.0 避免num_=0的情况 111 | tensor = torch.sum(tensor, dim=0) / torch.clamp(num_, min=1.0) 112 | elif merge_func == "sum": 113 | tensor = torch.sum(tensor, dim=0) 114 | elif merge_func == "max": 115 | tensor = tensor.abs().max(dim=0)[0] 116 | tensor *= sign 117 | return tensor 118 | 119 | task_vectors = [ 120 | model - base_model 121 | for model in models_to_merge 122 | ] 123 | flattened_param = [ task_vector.flatten() for task_vector in task_vectors ] 124 | # sparsify on model-level => (n_model, n_para) 125 | # flattened_param = torch.vstack( 126 | # [ sparsify.magnitude(_param, 1 - mask_rate) for _param in flattened_param ] 127 | # ) 128 | 129 | def topk_values_mask(M, K=0.7, return_mask=False, reshape_mask=False): 130 | if K == 100: 131 | # print("Not applying mask") 132 | if return_mask: 133 | return M, torch.ones_like(M), None 134 | else: 135 | return M, torch.ones_like(M) 136 | 137 | if K >= 1: 138 | K /= 100 139 | 140 | original_shape = M.shape 141 | if M.dim() == 1: 142 | M = M.unsqueeze(0) 143 | 144 | n, d = M.shape 145 | k = int(d * K) 146 | k = d - k # Keep top k elements instead of bottom k elements 147 | 148 | # Find the k-th smallest element by magnitude for each row 149 | kth_values, _ = M.abs().kthvalue(k, dim=1, keepdim=True) 150 | # Create a mask tensor with True for the top k elements in each row 151 | mask = 
M.abs() >= kth_values 152 | final_mask = mask.squeeze() if original_shape == M.squeeze().shape else mask 153 | 154 | if reshape_mask: 155 | final_mask = final_mask.reshape(M.shape) 156 | 157 | if return_mask: 158 | return M * final_mask, final_mask.float().mean(dim=1), final_mask 159 | else: 160 | return M * final_mask, final_mask.float().mean(dim=1) 161 | 162 | # flattened_param1 = sparsify.magnitude(torch.vstack(flattened_param), 1 - mask_rate) 163 | flattened_param = topk_values_mask(torch.vstack(flattened_param), 1 - mask_rate)[0] 164 | flattened_param = disjoint_merge(flattened_param) 165 | # randomly select one vector to unflatten 166 | merged_param = copy.deepcopy(base_model) 167 | merged_param = base_model + scaling * merged_param.unflatten(flattened_param) 168 | return merged_param 169 | 170 | @utils.args_inspector 171 | @torch.inference_mode() 172 | def ties_merge_old( 173 | self, 174 | base_model: nn.Module, 175 | models_to_merge: list, 176 | mask_rate: float = 0.8, 177 | scaling: float = 1.0, 178 | ): 179 | 180 | def state_dict_to_vector(state_dict, remove_keys=[]): 181 | shared_state_dict = copy.deepcopy(state_dict) 182 | for key in remove_keys: 183 | if key in shared_state_dict: 184 | del shared_state_dict[key] 185 | sorted_shared_state_dict = OrderedDict(sorted(shared_state_dict.items())) 186 | return torch.nn.utils.parameters_to_vector([value.reshape(-1) for key, value in sorted_shared_state_dict.items()]) 187 | 188 | 189 | def vector_to_state_dict(vector, state_dict, remove_keys=[]): 190 | # create a reference dict to define the order of the vector 191 | reference_dict = copy.deepcopy(state_dict) 192 | for key in remove_keys: 193 | if key in reference_dict: 194 | del reference_dict[key] 195 | sorted_reference_dict = OrderedDict(sorted(reference_dict.items())) 196 | 197 | # create a shared state dict using the refence dict 198 | torch.nn.utils.vector_to_parameters(vector, sorted_reference_dict.values()) 199 | 200 | # add back the encoder and decoder embedding weights. 
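# note: in T5-style checkpoints these removed keys are embedding weights tied to transformer.shared.weight, so they are restored from it below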
201 | if "transformer.shared.weight" in sorted_reference_dict: 202 | for key in remove_keys: 203 | sorted_reference_dict[key] = sorted_reference_dict["transformer.shared.weight"] 204 | return sorted_reference_dict 205 | 206 | def topk_values_mask(M, K=0.7, return_mask=False, reshape_mask=False): 207 | if K == 100: 208 | # print("Not applying mask") 209 | if return_mask: 210 | return M, torch.ones_like(M), None 211 | else: 212 | return M, torch.ones_like(M) 213 | 214 | if K >= 1: 215 | K /= 100 216 | 217 | original_shape = M.shape 218 | if M.dim() == 1: 219 | M = M.unsqueeze(0) 220 | 221 | n, d = M.shape 222 | k = int(d * K) 223 | k = d - k # Keep top k elements instead of bottom k elements 224 | 225 | # Find the k-th smallest element by magnitude for each row 226 | kth_values, _ = M.abs().kthvalue(k, dim=1, keepdim=True) 227 | # Create a mask tensor with True for the top k elements in each row 228 | mask = M.abs() >= kth_values 229 | final_mask = mask.squeeze() if original_shape == M.squeeze().shape else mask 230 | 231 | if reshape_mask: 232 | final_mask = final_mask.reshape(M.shape) 233 | 234 | if return_mask: 235 | return M * final_mask, final_mask.float().mean(dim=1), final_mask 236 | else: 237 | return M * final_mask, final_mask.float().mean(dim=1) 238 | 239 | def resolve_sign(tensor: torch.Tensor): 240 | sign_to_mult = torch.sign(tensor.sum(dim=0)) 241 | sign_to_mult = resolve_zero_signs(sign_to_mult, "majority") 242 | return sign_to_mult 243 | 244 | def resolve_zero_signs(sign_to_mult, method="majority"): 245 | majority_sign = torch.sign(sign_to_mult.sum()) 246 | 247 | if method == "majority": 248 | sign_to_mult[sign_to_mult == 0] = majority_sign 249 | elif method == "minority": 250 | sign_to_mult[sign_to_mult == 0] = -1 * majority_sign 251 | return sign_to_mult 252 | 253 | def disjoint_merge(tensor, merge_func, sign_to_mult): 254 | merge_func = merge_func.split("-")[-1] 255 | 256 | # If sign is provided then we select the corresponding entries and aggregate. 257 | if sign_to_mult is not None: 258 | rows_to_keep = torch.where(sign_to_mult.unsqueeze(0) > 0, tensor > 0, tensor < 0) 259 | selected_entries = tensor * rows_to_keep 260 | # Else we select all non-zero entries and aggregate. 
261 | else: 262 | rows_to_keep = tensor != 0 263 | selected_entries = tensor * rows_to_keep 264 | 265 | if merge_func == "mean": 266 | non_zero_counts = (selected_entries != 0).sum(dim=0).float() 267 | disjoint_aggs = torch.sum(selected_entries, dim=0) / torch.clamp(non_zero_counts, min=1) 268 | elif merge_func == "sum": 269 | disjoint_aggs = torch.sum(selected_entries, dim=0) 270 | elif merge_func == "max": 271 | disjoint_aggs = selected_entries.abs().max(dim=0)[0] 272 | disjoint_aggs *= sign_to_mult 273 | else: 274 | raise ValueError(f"Merge method {merge_func} is not defined.") 275 | 276 | return disjoint_aggs 277 | 278 | task_vectors = [ 279 | model - base_model 280 | for model in models_to_merge 281 | ] 282 | flattened_param = [state_dict_to_vector(task_vector.param_dict) for task_vector in task_vectors ] 283 | all_checks = torch.vstack(flattened_param) 284 | updated_checks, *_ = topk_values_mask(all_checks, K=1 - mask_rate, return_mask=False) 285 | print(f"RESOLVING SIGN") 286 | final_signs = resolve_sign(updated_checks) 287 | assert final_signs is not None 288 | 289 | print(f"Disjoint AGGREGATION: dis-mean") 290 | merged_tv = disjoint_merge(updated_checks, 'dis-mean', final_signs) 291 | merged_tv_state_dict = vector_to_state_dict(merged_tv, copy.deepcopy(base_model.param_dict)) 292 | merged_param = base_model + scaling * param(merged_tv_state_dict) 293 | return merged_param 294 | 295 | 296 | @utils.args_inspector 297 | @torch.inference_mode() 298 | def task_arithmetic( 299 | self, 300 | base_model: nn.Module, 301 | models_to_merge: param, 302 | scaling: float = 1.0, 303 | ): 304 | 305 | task_vectors = [ 306 | model - base_model 307 | for model in models_to_merge 308 | ] 309 | merged_param = base_model + scaling * sum(task_vectors) 310 | return merged_param 311 | 312 | @utils.args_inspector 313 | @torch.inference_mode() 314 | def task_arithmetic_search( 315 | self, 316 | base_model: nn.Module, 317 | models_to_merge: param, 318 | scaling: float = 1.0, 319 | ): 320 | 321 | task_vectors = [ 322 | model - base_model 323 | for model in models_to_merge 324 | ] 325 | 326 | merged_param = base_model + sum([ 327 | w * tv 328 | for w, tv in zip(scaling, task_vectors) 329 | ]) 330 | return merged_param 331 | 332 | @utils.args_inspector 333 | @torch.inference_mode() 334 | def task_arithmetic_plus( 335 | self, 336 | base_model: nn.Module, 337 | models_to_merge: param, 338 | scaling: float = 1.0, 339 | mask_strategy: str = None, 340 | mask_rate: float = None, 341 | ): 342 | 343 | task_vectors = [ 344 | model + base_model 345 | for model in models_to_merge 346 | ] 347 | 348 | if mask_strategy is None: 349 | merged_param = (scaling * sum(task_vectors)) - base_model 350 | else: 351 | merged_param = (scaling * sum(task_vectors)).map( 352 | lambda n,p: getattr(sparsify, mask_strategy)( 353 | p, 354 | 1 - mask_rate, 355 | ), 356 | desc=mask_strategy 357 | )- base_model 358 | return merged_param 359 | 360 | @utils.args_inspector 361 | @torch.inference_mode() 362 | def dare_merge( 363 | self, 364 | models_to_merge: param, 365 | second_merge_method: str, 366 | second_merge_config: dict, 367 | mask_rate: float, 368 | base_model: nn.Module, 369 | mask_scale: float = 1.0, 370 | weight_format: str = 'delta', 371 | ): 372 | # 1. sparsify masking (merge with base model) 373 | masked_params = [ 374 | self.dare_mask( 375 | finetuned_model, 376 | mask_rate, 377 | base_model, 378 | mask_scale, 379 | weight_format, 380 | ) for finetuned_model in models_to_merge 381 | ] 382 | # 2. 
merge between the different models 383 | merged_params = getattr(self, second_merge_method)( 384 | base_model = base_model, 385 | models_to_merge = masked_params, 386 | **second_merge_config 387 | ) 388 | return merged_params 389 | 390 | @torch.inference_mode() 391 | def dare_mask( 392 | self, 393 | finetuned_model: nn.Module, 394 | mask_rate: float, 395 | base_model: nn.Module = None, 396 | mask_scale: float = 1.0, 397 | weight_format: str = 'delta' 398 | ): 399 | 400 | mask_rate = float(mask_rate) 401 | 402 | if weight_format == "full" or weight_format == "lora": 403 | masked_param = finetuned_model 404 | elif weight_format == "delta": 405 | masked_param = finetuned_model - base_model 406 | else: 407 | raise NotImplementedError 408 | 409 | def mask_input_with_mask_rate(input_tensor: torch.Tensor, density: float, use_rescale: bool = True, mask_strategy: str = 'random'): 410 | mask_rate = 1 - density 411 | assert 0.0 <= mask_rate <= 1.0, f"wrong range of mask_rate {mask_rate}, should be [0.0, 1.0]!" 412 | if mask_strategy == "random": 413 | mask = torch.bernoulli(torch.full_like(input=input_tensor, fill_value=mask_rate)).to(input_tensor.device) 414 | masked_input_tensor = input_tensor * (1 - mask) 415 | else: 416 | assert mask_strategy == "magnitude", f"wrong setting for mask_strategy {mask_strategy}!" 417 | original_shape = input_tensor.shape 418 | input_tensor = input_tensor.flatten() 419 | num_mask_params = int(len(input_tensor) * mask_rate) 420 | # Tensor, shape (1, ), find the num_mask_params-th smallest magnitude element of all the parameters in the model 421 | kth_values, _ = input_tensor.abs().kthvalue(k=num_mask_params, dim=0, keepdim=True) 422 | # Tensor, shape (num_total_params, ), where True is for parameters that we want to perform mask 423 | mask = input_tensor.abs() <= kth_values 424 | masked_input_tensor = input_tensor * (~mask) 425 | masked_input_tensor = masked_input_tensor.reshape(original_shape) 426 | if use_rescale and mask_rate != 1.0: 427 | masked_input_tensor = torch.div(input=masked_input_tensor, other=1 - mask_rate) 428 | return masked_input_tensor 429 | 430 | # mask_input_with_mask_rate 431 | masked_param = masked_param.map( 432 | lambda n,p: sparsify.bernoulli( 433 | p, 434 | 1 - mask_rate, 435 | ), 436 | desc='bernoulli' 437 | ) 438 | 439 | if weight_format == "delta": 440 | masked_param = base_model + mask_scale * masked_param 441 | return masked_param 442 | 443 | @utils.args_inspector 444 | @torch.inference_mode() 445 | def twin_merge( 446 | self, 447 | base_model: nn.Module, 448 | models_to_merge: param, 449 | second_merge_method: str, 450 | second_merge_config: dict, 451 | ): 452 | # merge again / MergePlus / DoubleBundle / DualMerger 453 | 454 | # Get merged parameter 455 | merged_params = getattr(self, second_merge_method)( 456 | base_model = base_model, 457 | models_to_merge = models_to_merge, 458 | **second_merge_config 459 | ) 460 | return merged_params 461 | -------------------------------------------------------------------------------- /discriminative/param.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from collections import defaultdict, OrderedDict 3 | import tqdm 4 | import re 5 | import torch.nn as nn 6 | import copy 7 | import sparsify 8 | import utils 9 | import sys 10 | import transformers 11 | from transformers import AutoModelForCausalLM, AutoModelForSeq2SeqLM, AutoModelForSequenceClassification, AutoTokenizer 12 | import os 13 | import functools 14 | from collections import 
defaultdict, OrderedDict 15 | import torch 16 | 17 | kw_filter_func = lambda n,p,exclude_param : not any([ 18 | re.match(exclude_pattern, n) 19 | for exclude_pattern in exclude_param 20 | ]) 21 | 22 | MODE = 'drop' 23 | # MODE = 'keep_left' 24 | # MODE = 'keep_right' 25 | class param: 26 | 27 | def __init__( 28 | self, 29 | model, 30 | ): 31 | if isinstance(model, torch.nn.Module): 32 | other = model.state_dict() 33 | elif isinstance(model, dict): 34 | other = model 35 | elif isinstance(model, param): 36 | other = model.param_dict 37 | else: 38 | raise NotImplementedError 39 | 40 | self.param_dict = other 41 | 42 | def filter(self, func): 43 | self.param_dict = { 44 | n: p 45 | for n,p in self.param_dict.items() 46 | if func(n,p) 47 | } 48 | 49 | def __getitem__(self, item): 50 | return self.param_dict[item] 51 | 52 | def __len__(self): 53 | return len(self.param_dict) 54 | 55 | def items(self): 56 | return self.param_dict.items() 57 | 58 | def keys(self): 59 | return self.param_dict.keys() 60 | 61 | def values(self): 62 | return self.param_dict.values() 63 | 64 | # implement `in`! 65 | def __contains__(self, item): 66 | return item in self.keys() 67 | 68 | # a + b 69 | def __add__(self, other): 70 | 71 | if other == 0: 72 | return self 73 | 74 | if isinstance(other, torch.nn.Module): 75 | other = param(other) 76 | 77 | if hasattr(other, 'param_dict'): 78 | 79 | if MODE == 'drop': 80 | return param( 81 | { 82 | n: self[n] + other[n] 83 | for n in set(self.keys()).intersection(other.keys()) 84 | } 85 | ) 86 | elif MODE == 'keep_left': 87 | return param( 88 | { 89 | n: self[n] + other[n] 90 | if n in other 91 | else self[n] 92 | for n in (self.keys()) 93 | } 94 | ) 95 | 96 | elif MODE == 'keep_right': 97 | return param( 98 | { 99 | n: self[n] + other[n] 100 | if n in self 101 | else other[n] 102 | for n in (other.keys()) 103 | } 104 | ) 105 | else: 106 | raise NotImplementedError 107 | 108 | def update_null_keys(self, other): 109 | for k in other.keys(): 110 | if k not in self: 111 | self[k] = other[k] 112 | 113 | # type(y).__rsub__(y, x) is called if type(x).__sub__(x, y) returns NotImplemented. 
114 | # a + b if a is not implemented 115 | def __radd__(self, other): 116 | # sum(x) start with 0 + x[0] 117 | if other == 0: 118 | return self 119 | # other + self = self + other 120 | return self.__add__(other) 121 | 122 | def __sub__(self, other): 123 | 124 | if isinstance(other, torch.nn.Module): 125 | other = param(other) 126 | 127 | if hasattr(other, 'param_dict'): 128 | 129 | if MODE == 'drop': 130 | return param( 131 | { 132 | n: self[n] - other[n] 133 | for n in set(self.keys()).intersection(other.keys()) 134 | } 135 | ) 136 | elif MODE == 'keep_left': 137 | return param( 138 | { 139 | n: self[n] - other[n] 140 | if n in other 141 | else self[n] 142 | for n in (self.keys()) 143 | } 144 | ) 145 | elif MODE == 'keep_right': 146 | return param( 147 | { 148 | n: self[n] - other[n] 149 | if n in self 150 | else other[n] 151 | for n in (other.keys()) 152 | } 153 | ) 154 | 155 | else: 156 | raise NotImplementedError 157 | 158 | def __rsub__(self, other): 159 | # other - self 160 | if isinstance(other, torch.nn.Module): 161 | other = param(other) 162 | 163 | if hasattr(other, 'param_dict'): 164 | return other.__sub__(self) 165 | 166 | else: 167 | raise NotImplementedError 168 | 169 | def __rmul__(self, other): 170 | 171 | if isinstance(other, float) or isinstance(other, torch.Tensor): 172 | # weight 173 | return param( 174 | { 175 | n: other * p 176 | for n,p in self.param_dict.items() 177 | } 178 | ) 179 | 180 | if isinstance(other, dict): 181 | # module-wise weight 182 | if MODE == 'drop': 183 | return param( 184 | { 185 | n: other[n] * self[n] 186 | for n in set(self.keys()).intersection(other.keys()) 187 | } 188 | ) 189 | elif MODE == 'keep_left': 190 | return param( 191 | { 192 | n: other[n] * self[n] 193 | if n in other 194 | else self[n] 195 | for n in (self.keys()) 196 | } 197 | ) 198 | elif MODE == 'keep_right': 199 | return param( 200 | { 201 | n: other[n] * self[n] 202 | if n in self 203 | else other[n] 204 | for n in (other.keys()) 205 | } 206 | ) 207 | 208 | raise NotImplementedError 209 | 210 | def __mul__(self, other): 211 | return self.__rmul__(other) 212 | 213 | def __neg__(self, ): 214 | return param( 215 | { 216 | n: -p 217 | for n,p in self.param_dict.items() 218 | } 219 | ) 220 | 221 | def __truediv__(self, other): 222 | 223 | if isinstance(other, (int, float)): 224 | # weight 225 | return param( 226 | { 227 | n: p / other 228 | for n,p in self.param_dict.items() 229 | } 230 | ) 231 | 232 | if isinstance(other, param): 233 | if MODE == 'drop': 234 | return param( 235 | { 236 | n: self[n] / other[n] 237 | for n in set(self.keys()).intersection(other.keys()) 238 | } 239 | ) 240 | elif MODE == 'keep_left': 241 | return param( 242 | { 243 | n: self[n] / other[n] 244 | if n in other 245 | else self[n] 246 | for n in (self.keys()) 247 | } 248 | ) 249 | elif MODE == 'keep_right': 250 | return param( 251 | { 252 | n: self[n] / other[n] 253 | if n in self 254 | else other[n] 255 | for n in (other.keys()) 256 | } 257 | ) 258 | 259 | raise NotImplementedError 260 | 261 | def assign(self, model: torch.nn.Module): 262 | device = model.device 263 | for n, p in model.named_parameters(): 264 | if n in self.param_dict: 265 | if p.shape != self.param_dict[n].shape: 266 | # for classifiers, default is num_labels=2 , probably has dimension mismatch 267 | print(f'>>> dimension mismatch! 
override model {n}') 268 | utils.rsetattr(model, n, torch.nn.Parameter(self.param_dict[n])) 269 | if 'classifier' in n: 270 | model.num_labels = self.param_dict[n].shape[0] 271 | print(f'>>> change num_labels to {model.num_labels}') 272 | continue 273 | p.data.copy_(self.param_dict[n]) 274 | model.to(device) 275 | 276 | def to(self, device): 277 | 278 | for n,p in self.param_dict.items(): 279 | # tensor is not inplace 280 | # but model is 281 | self.param_dict[n] = p.to(device) 282 | 283 | def map(self, func, desc): 284 | 285 | return param({ 286 | n: func(n, self.param_dict[n]) 287 | for n in tqdm.tqdm(self.param_dict.keys(), desc=f'Param Map {desc}') 288 | }) 289 | 290 | def flatten(self, ): 291 | return nn.utils.parameters_to_vector( 292 | [p.flatten() for p in OrderedDict(sorted(self.param_dict.items())).values()] 293 | ) 294 | 295 | def unflatten(self, flatten_params): 296 | 297 | nn.utils.vector_to_parameters( 298 | flatten_params, 299 | OrderedDict(sorted(self.param_dict.items())).values() 300 | ) 301 | return self 302 | 303 | def __iter__(self): 304 | return iter(self.param_dict.items()) 305 | 306 | @staticmethod 307 | def vectorize_reduce(func, models_to_merge): 308 | return param({ 309 | r[0][0]: func([rr[1] for rr in r]) 310 | for r in zip(*models_to_merge) 311 | }) -------------------------------------------------------------------------------- /discriminative/router.py: -------------------------------------------------------------------------------- 1 | import json 2 | import torch 3 | import argparse 4 | import numpy as np 5 | import pandas as pd 6 | from datasets import load_dataset 7 | import transformers 8 | from transformers import AutoTokenizer, AutoConfig, AutoModel, AutoModelForCausalLM, GPTQConfig, AutoModelForSequenceClassification 9 | from accelerate import Accelerator 10 | from tqdm import tqdm 11 | import itertools 12 | import torch 13 | import torch.distributed as dist 14 | from utils import * 15 | 16 | 17 | def get_ori_datasets(mode="train", tokenizer=None, max_len=3000): 18 | data_list = [] 19 | task_name = [] 20 | glue_data_loader = GLUEDataLoader(tokenizer=tokenizer) 21 | for task in glue_data_id_map.keys(): 22 | train_dataset, val_dataset, test_dataset, num_labels = glue_data_loader.load_dataset( 23 | dataset_name=task, train_split_ratio_for_val=0.1, max_seq_length=128 24 | ) 25 | if mode == "train": 26 | data_list.append(train_dataset) 27 | else: 28 | data_list.append(test_dataset) 29 | task_name.append(task) 30 | max_data_len = max_len 31 | # for i in data_list: 32 | # max_data_len = min(max_data_len, len(i)) 33 | all_dataset = {} 34 | data_num = 0 35 | for idx, i in enumerate(data_list): 36 | all_dataset[task_name[idx]] = {"input": []} 37 | for jdx, j in enumerate(i): 38 | if jdx < max_data_len: 39 | all_dataset[task_name[idx]]["input"].append(j["input_ids"]) 40 | data_num += 1 41 | else: 42 | break 43 | return all_dataset, data_num 44 | 45 | def load_glue(tokenizer, router_info=None): 46 | glue_data_loader = GLUEDataLoader(tokenizer=tokenizer) 47 | new_datasets = [] 48 | all_dataset = [] 49 | for task in glue_data_id_map.keys(): 50 | train_dataset, val_dataset, test_dataset, num_labels = glue_data_loader.load_dataset( 51 | dataset_name=task, train_split_ratio_for_val=0.1, max_seq_length=128 52 | ) 53 | all_dataset.append(test_dataset) 54 | tot_num = 0 55 | for sub_data in all_dataset: 56 | for idx, i in enumerate(sub_data): 57 | if idx >= 1000: 58 | continue 59 | if router_info is not None: 60 | new_datasets.append({**i, **{"router_prob": 
router_info[tot_num].tolist()}}) 61 | tot_num += 1 62 | else: 63 | new_datasets.append({**i}) 64 | return new_datasets 65 | 66 | @torch.inference_mode() 67 | def generate_router_datasets( 68 | mode, 69 | max_len, 70 | shared_expert, 71 | ): 72 | 73 | rank = dist.get_rank() 74 | world_size = dist.get_world_size() 75 | device = rank % torch.cuda.device_count() 76 | torch.cuda.set_device(device) 77 | print(f"Starting rank={rank}, world_size={dist.get_world_size()}.") 78 | 79 | model = AutoModel.from_pretrained(shared_expert, torch_dtype=torch.float16).to(device) 80 | tokenizer = AutoTokenizer.from_pretrained('roberta-base') 81 | dist.barrier() 82 | 83 | datasets, data_num = get_ori_datasets(mode=mode, tokenizer=tokenizer, max_len=max_len) 84 | ans = {} 85 | for data_id, (category_name, data_class) in enumerate(datasets.items()): 86 | data_item = data_class["input"] 87 | res = [] 88 | for index, data_input in tqdm( 89 | itertools.islice(enumerate(data_item), rank, len(data_item), world_size), 90 | disable= device != 0, 91 | total = len(data_item) // world_size + 1, 92 | ): 93 | res.append(( 94 | index, 95 | torch.mean( 96 | model.forward( 97 | input_ids=torch.tensor([data_input]).to(device), 98 | output_hidden_states=True 99 | )["hidden_states"][-1][0, :, :], 100 | dim=0 101 | ).cpu().numpy(), 102 | )) 103 | 104 | dist.barrier() 105 | global_res = [None] * world_size 106 | dist.all_gather_object(global_res, res) 107 | if device == 0: 108 | # flatten 109 | global_res = sorted([rr for r in global_res for rr in r], key=lambda x: x[0]) 110 | ans[category_name] = [r[1] for r in global_res] 111 | # np.savez(f'data/router_{mode}.npz', **ans) 112 | 113 | if device == 0: 114 | np.savez(f'data/router_{mode}.npz', **ans) 115 | 116 | dist.barrier() 117 | 118 | class RouterDataset(Dataset): 119 | 120 | def __init__(self, data, targets): 121 | self.data = data 122 | self.targets = targets 123 | 124 | def __len__(self): 125 | return len(self.data) 126 | 127 | def __getitem__(self, idx): 128 | img, target = self.data[idx], self.targets[idx] 129 | return { 130 | 'input': img, 131 | 'label': target, 132 | } 133 | 134 | class SimpleMLP(nn.Module): 135 | 136 | def __init__(self, num_clients, embedding_dims, hidden_dim=1024): 137 | super(SimpleMLP, self).__init__() 138 | self.fc1 = nn.Linear(embedding_dims, hidden_dim) 139 | self.bn1 = nn.BatchNorm1d(hidden_dim) 140 | # self.fc2 = nn.Linear(hidden_dim, hidden_dim) 141 | # self.bn2 = nn.BatchNorm1d(hidden_dim) 142 | self.fc3 = nn.Linear(hidden_dim, hidden_dim) 143 | self.bn3 = nn.BatchNorm1d(hidden_dim) 144 | self.fc4 = nn.Linear(hidden_dim, num_clients) 145 | self.dropout = nn.Dropout(p=0.5) 146 | self.criterion = nn.CrossEntropyLoss() 147 | 148 | def forward(self, input, labels=None): 149 | x = input.float() 150 | x = self.fc1(x) 151 | x = self.bn1(x) 152 | x = F.leaky_relu(x) 153 | # x = self.fc2(x) 154 | # x = self.bn2(x) 155 | # x = F.leaky_relu(x) 156 | x = self.dropout(x) 157 | x = self.fc3(x) 158 | x = self.bn3(x) 159 | x = F.leaky_relu(x) 160 | x = self.dropout(x) 161 | x = self.fc4(x) 162 | 163 | if labels is not None: 164 | loss = self.criterion(x, labels) 165 | return loss, x 166 | return x 167 | 168 | def load_dataset(): 169 | 170 | train_data = np.load(f'data/router_train.npz',allow_pickle=True) 171 | train_dataset = RouterDataset( 172 | data = [ 173 | v 174 | for k in train_data.files 175 | for v in train_data[k] 176 | ], 177 | targets = [ 178 | glue_data_id_map[k] 179 | for k in train_data.files 180 | for _ in range(len(train_data[k])) 181 | ] 182 | 
) 183 | test_data = np.load(f'data/router_test.npz') 184 | test_dataset = RouterDataset( 185 | data = [ 186 | v 187 | for k in test_data.files 188 | for v in test_data[k] 189 | ], 190 | targets = [ 191 | glue_data_id_map[k] 192 | for k in test_data.files 193 | for _ in range(len(test_data[k])) 194 | ] 195 | ) 196 | return { 197 | 'train': train_dataset, 198 | 'test': test_dataset, 199 | } 200 | 201 | 202 | def train_router( 203 | in_domain = None, # dict 204 | embed_dims = 768, 205 | ): 206 | encoded_dataset = load_dataset() 207 | task_num = 8 208 | if in_domain is not None: 209 | raise Exception('Not Implemented yet') 210 | 211 | device_map = "auto" 212 | world_size = int(os.environ.get("WORLD_SIZE", 1)) 213 | ddp = world_size != 1 214 | device = int(os.environ.get("LOCAL_RANK") or 0) 215 | if ddp: 216 | device_map = {"": device} 217 | 218 | classifier = SimpleMLP( 219 | num_clients=task_num, embedding_dims=embed_dims, hidden_dim= 2*embed_dims 220 | ).to(device) 221 | 222 | def compute_metrics(eval_pred): 223 | logits, labels = eval_pred 224 | logits, labels = torch.tensor(logits), torch.tensor(labels) 225 | predictions = torch.argmax((logits), dim=-1) 226 | 227 | total = len(labels) 228 | correct_list = [0] * 8 229 | total_list = [0] * 8 230 | 231 | # total acc 232 | correct = predictions.eq((labels)).sum().item() 233 | acc = correct / total * 100.0 234 | print( 235 | "@@ Final {}/{}, Accuracy: {:.2f}%".format( 236 | correct, total, acc 237 | ) 238 | ) 239 | # acc per class 240 | for i in range(8): 241 | correct_list[i] = ((labels == i) & (predictions == i)).sum().item() 242 | total_list[i] = (labels == i).sum().item() 243 | acc_prop = [correct_list[i] / total_list[i] * 100.0 if total_list[i] > 0 else 0 for i in range(8)] 244 | print("Correct list: ", correct_list) 245 | print("Accuracy proportion: ", acc_prop) 246 | return { 247 | "accuracy": correct / total, 248 | "accuracy_per_class": acc_prop 249 | } 250 | 251 | trainer = transformers.Trainer( 252 | model=classifier, 253 | args=transformers.TrainingArguments( 254 | output_dir="./data/router", 255 | evaluation_strategy="epoch", 256 | save_strategy='epoch', 257 | learning_rate=0.0005, 258 | per_device_train_batch_size=1024, 259 | per_device_eval_batch_size=1024, 260 | num_train_epochs=50, 261 | # weight_decay=1e-4, 262 | logging_steps=20, 263 | save_total_limit=1, 264 | report_to=[], 265 | load_best_model_at_end=True, 266 | metric_for_best_model="accuracy", 267 | greater_is_better=True, 268 | ddp_find_unused_parameters=False 269 | ), 270 | train_dataset=encoded_dataset["train"], 271 | eval_dataset=encoded_dataset["test"], 272 | compute_metrics=compute_metrics, 273 | ) 274 | trainer.train() 275 | trainer.save_model("./data/router") 276 | prediction = trainer.predict(encoded_dataset["test"], metric_key_prefix='') 277 | 278 | new_datasets = load_glue( 279 | tokenizer = AutoTokenizer.from_pretrained('roberta-base'), 280 | router_info=prediction.predictions 281 | ) 282 | json.dump(new_datasets, open('data/test_router.json','w'), ensure_ascii=False) 283 | 284 | def main( 285 | *, 286 | shared_expert: str = None, 287 | seed: int = 0, 288 | train: bool = False, 289 | ): 290 | fix_seed(seed) 291 | 292 | if not os.path.exists('data/test.json') and os.getenv('LOCAL_RANK', 0) == 0: 293 | data = load_glue(tokenizer = AutoTokenizer.from_pretrained('roberta-base')) 294 | json.dump(data, open('data/test.json', 'w')) 295 | 296 | # use torchrun to start 297 | if not os.path.exists('data/router_train.npz'): 298 | assert shared_expert is not None 299 | if 
not dist.is_initialized(): 300 | dist.init_process_group("nccl") 301 | generate_router_datasets('train', 5500, shared_expert) 302 | generate_router_datasets('test', 1000, shared_expert) 303 | dist.destroy_process_group() 304 | 305 | # if not os.path.exists('data/test1.json') or not os.path.exists('outs/router1'): 306 | if train: 307 | train_router() 308 | print('Train Done') 309 | 310 | if __name__ == '__main__': 311 | import defopt 312 | defopt.run(main) -------------------------------------------------------------------------------- /discriminative/run.sh: -------------------------------------------------------------------------------- 1 | 2 | # detail functions to run different algorithm 3 | source scripts.sh 4 | 5 | 6 | if [ ! -d "outs/task_arithmetic" ]; then 7 | # 1. get the shared expert 8 | run_task_arith 9 | fi 10 | 11 | if [ ! -f "data/test_router.json" ]; then 12 | # 2. gen router dataset 13 | torchrun --master-port 23451 --nnodes=1 --nproc_per_node=2 \ 14 | router.py \ 15 | --no-train \ 16 | --shared-expert outs/task_arithmetic 17 | # 3. train router 18 | python3 router.py --train 19 | fi 20 | 21 | if [ ! -d "outs/finetuned" ]; then 22 | # 4. need to get finetuned expert performance first for evaluation 23 | ft 24 | fi 25 | 26 | # 5. run evaluation 27 | twin_merge -------------------------------------------------------------------------------- /discriminative/run_merge.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from collections import defaultdict, OrderedDict 3 | import tqdm 4 | import re 5 | import torch.nn as nn 6 | import copy 7 | import sparsify 8 | import utils 9 | import sys 10 | import transformers 11 | from transformers import AutoModelForCausalLM, AutoModelForSeq2SeqLM, AutoModelForSequenceClassification, AutoTokenizer 12 | import os 13 | import functools 14 | from collections import defaultdict, OrderedDict 15 | from param import param 16 | import torch.nn.functional as F 17 | import torch 18 | from collections import defaultdict 19 | import numpy as np 20 | from merge import MergingMethod 21 | import eval 22 | import inspect 23 | import datasets 24 | import pandas as pd 25 | 26 | args = None 27 | DEVICE='cuda:0' 28 | 29 | @torch.inference_mode() 30 | def run_pretrained( 31 | args, 32 | load_head=True, 33 | ): 34 | 35 | # \theta_t 36 | pretrained = utils.load_classifier(args.base_model).to(DEVICE) 37 | tokenizer = AutoTokenizer.from_pretrained(args.base_model) 38 | 39 | data = utils.from_json(args.data_path) 40 | metrics = {'model': args.base_model } 41 | dataset_list = defaultdict(list) 42 | for data_item in (data): 43 | data_id = data_item['dataset_ids'] 44 | data_name = list(eval.glue_data_id_map.keys())[data_id] 45 | dataset_list[data_name].append(data_item) 46 | 47 | for data_name, dataset in dataset_list.items(): 48 | 49 | dataset = datasets.Dataset.from_pandas(pd.DataFrame(dataset)) 50 | 51 | head_path = eval.head_path_template.format(name=data_name) 52 | print(f' >>> load classifier head from {head_path} for {data_name}') 53 | classifier = torch.load(head_path) 54 | pretrained.classifier = classifier.to(DEVICE) 55 | 56 | def calculate_logits(data_item): 57 | input_ids = torch.nn.utils.rnn.pad_sequence( 58 | [torch.tensor(d) for d in data_item['input_ids']], 59 | batch_first=True, 60 | padding_value=tokenizer.pad_token_id, 61 | ) 62 | attention_mask = torch.nn.utils.rnn.pad_sequence( 63 | [torch.tensor(d) for d in data_item['attention_mask']], 64 | batch_first=True, 65 | padding_value=0, 66 | ) 67 | 
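# Note (descriptive comment, not in the original source): the two pad_sequence
# calls above batch variable-length GLUE examples; input_ids are padded with
# tokenizer.pad_token_id and attention_mask with 0, so padded positions are
# ignored when the task-specific classifier head (loaded above via
# `pretrained.classifier = classifier`) produces the logits.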
68 | score = pretrained( 69 | input_ids.to(pretrained.device), 70 | attention_mask.to(pretrained.device), 71 | ).logits.cpu().numpy() 72 | 73 | return { 74 | 'predictions': score, 75 | 'label_ids': data_item['label'] 76 | } 77 | 78 | dataset = dataset.map( 79 | lambda x: calculate_logits(x), 80 | batched=True, 81 | batch_size=4, 82 | ) 83 | 84 | ans = eval.compute_single_metrics( 85 | utils.SimpleNamespace( 86 | predictions=torch.tensor(dataset['predictions']), 87 | label_ids=np.array(dataset['label_ids']) 88 | ), data_name 89 | )['averaged_scores'] 90 | metrics[data_name] = 100*float(f"{ans:.4f}") 91 | 92 | utils.save_excel(metrics, args.outdir) 93 | 94 | @torch.inference_mode() 95 | def run_base2( 96 | args, 97 | load_head=True, 98 | ): 99 | 100 | for model_name, model_to_merge in zip(args.models_name, args.models_to_merge): 101 | args.base_model = model_to_merge 102 | run_pretrained(args) 103 | 104 | @torch.inference_mode() 105 | def run_merge( 106 | args, 107 | ): 108 | 109 | if args.exclude_param and len(args.exclude_param): 110 | filter_func = lambda n,p : not any([ 111 | re.match(exclude_pattern, n) 112 | for exclude_pattern in args.exclude_param 113 | ]) 114 | # \theta_t 115 | models_finetuned = { 116 | name: utils.load_classifier( 117 | eval.model_path_template.format(name=name) 118 | ).to(DEVICE) 119 | for name in args.models_name 120 | } 121 | # \theta_* 122 | models_to_merge = [ 123 | models_finetuned[name].to(DEVICE) 124 | for name in args.src_merge 125 | ] 126 | base_model = utils.load_classifier(args.base_model).to(DEVICE) 127 | 128 | args.base_model = param(base_model) 129 | args.models_to_merge = [param(m) for m in models_to_merge] 130 | for model in args.models_to_merge: 131 | model.filter(filter_func) 132 | args.base_model.filter(filter_func) 133 | 134 | # 3. 
merge 135 | merger = MergingMethod(**args) 136 | merge_method = getattr(merger, args.merge_method) 137 | merged_param = merge_method(**args) 138 | 139 | if args.save_path is not None: 140 | merged_param.assign(base_model) 141 | base_model.save_pretrained(args.save_path) 142 | 143 | if args.data_path is not None: 144 | 145 | metrics = { 146 | "model": args.merge_method, 147 | "scaling": ','.join([str(i) for i in args['scaling']]) if isinstance(args['scaling'],list) else args['scaling'], 148 | **{ 149 | f"_{k}": args[k] for k in [ 'mask_rate', 'mask_strategy', 'mask_scale','src_merge' ] 150 | } 151 | } 152 | try: 153 | metrics['_mask_rate'] = 100*float(f"{metrics['_mask_rate']:.4f}") 154 | except: 155 | pass 156 | metrics['_src_merge'] = '+'.join(metrics['_src_merge']) 157 | if 'second_merge_method' in args: 158 | metrics['_second_merge_method'] = args['second_merge_method'] 159 | 160 | data = utils.from_json(args.data_path) 161 | eval_pred = defaultdict(lambda: defaultdict(list)) 162 | for data_item in tqdm.tqdm(data, desc='infer glue'): 163 | data_id = data_item['dataset_ids'] 164 | data_name = list(eval.glue_data_id_map.keys())[data_id] 165 | 166 | def calculate_logits(data_item): 167 | model = models_finetuned[data_name] 168 | score = torch.func.functional_call( 169 | model, 170 | merged_param.param_dict, 171 | args=( 172 | torch.tensor(data_item['input_ids']).unsqueeze(0).to(model.device), 173 | torch.tensor(data_item['attention_mask']).unsqueeze(0).to(model.device), 174 | ), 175 | ).logits.cpu().numpy() 176 | 177 | return score 178 | 179 | eval_pred[data_name]['predictions'].append(calculate_logits(data_item)) 180 | eval_pred[data_name]['label_ids'].append(data_item['label']) 181 | 182 | for data_name in eval_pred.keys(): 183 | 184 | ans = eval.compute_single_metrics( 185 | utils.SimpleNamespace( 186 | predictions=np.concatenate(eval_pred[data_name]['predictions']), 187 | label_ids=np.array(eval_pred[data_name]['label_ids']) 188 | ), data_name 189 | )['averaged_scores'] 190 | metrics[data_name] = 100*float(f"{ans:.4f}") 191 | 192 | utils.save_excel(metrics, args.outdir) 193 | 194 | def main( 195 | *, 196 | models_to_merge: list[str], 197 | models_name: list[str], 198 | src_merge: list[str], 199 | yaml_file: str = None, 200 | exclude_param: list[str] = None, 201 | data_path: str = None, 202 | seed: int=10, 203 | base_model: str = 'roberta-base', 204 | # for task-arithmetic_search: 205 | scaling: list[float] = None, 206 | # for dare-merge: 207 | mask_rate: float = None, 208 | mask_scale: float = None, 209 | mask_strategy: str = None, 210 | outdir: str = None, 211 | save_path: str = None, 212 | ): 213 | 214 | global args 215 | keys, _, _, values = inspect.getargvalues(inspect.currentframe()) 216 | 217 | utils.fix_seed(seed) 218 | 219 | if models_to_merge[0] == 'NONE': 220 | args = utils.SimpleNamespace(**{ 221 | k: values.get(k) for k in keys 222 | }) 223 | run_pretrained(args, load_head=True) 224 | elif yaml_file is None: 225 | args = utils.SimpleNamespace(**{ 226 | k: values.get(k) for k in keys 227 | }) 228 | # run_base(args) 229 | run_base2(args, load_head=True) 230 | else: 231 | merge_config = utils.from_yaml(yaml_file) 232 | args = { 233 | k: values.get(k, merge_config.get(k)) 234 | for k in set(keys).union(merge_config) 235 | } 236 | args = { 237 | k: merge_config.get(k, None) 238 | if args[k] is None else args[k] 239 | for k in args.keys() 240 | } 241 | args = utils.SimpleNamespace(**args) 242 | 243 | print('>>> args\n', args) 244 | 245 | if args.scaling is not None and 
isinstance(args.scaling, list) and len(args.scaling) == 1: 246 | args.scaling = args.scaling[0] 247 | 248 | run_merge(args) 249 | 250 | 251 | if __name__ == '__main__': 252 | import defopt 253 | defopt.run(main) -------------------------------------------------------------------------------- /discriminative/scripts.sh: -------------------------------------------------------------------------------- 1 | 2 | set -e pipefail 3 | 4 | date_today=$(date '+%Y-%m-%d') 5 | outdir=${outdir:="outs/merge_results"} 6 | mkdir -p ${outdir} 7 | 8 | 9 | models_name=( 10 | "cola" 11 | "sst2" 12 | "mrpc" 13 | "stsb" 14 | "qqp" 15 | "mnli" 16 | "qnli" 17 | "rte" 18 | ) 19 | models_to_merge=() 20 | for d in "${models_name[@]}"; do 21 | models_to_merge+=(../roberta/$d/roberta-base_lr1e-05) 22 | done 23 | select_merge=${select_merge:="8"} 24 | 25 | 26 | function pos(){ 27 | 28 | if [ $select_merge -eq 1 ]; then 29 | echo "please set \$select_merge > 1" 30 | exit 1 31 | fi 32 | src_merge=("${models_name[@]:0:$select_merge}") 33 | 34 | echo ">>> merged from $select_merge tasks" 35 | echo ">>> merge ${src_merge[@]}" 36 | 37 | data_path="data/test.json" 38 | } 39 | 40 | 41 | function run_dare_task_arith(){ 42 | 43 | pos 44 | 45 | for i in 0.7 ; do 46 | 47 | python run_merge.py \ 48 | --models-to-merge ${models_to_merge[@]} \ 49 | --models-name ${models_name[@]} \ 50 | --src-merge ${src_merge[@]} \ 51 | --data-path $data_path \ 52 | --yaml-file config/dare_merge.yml \ 53 | --exclude-param ".*classifier.*" ".*bias.*" \ 54 | --mask-rate $i \ 55 | --outdir $outdir 56 | 57 | done 58 | 59 | } 60 | 61 | function run_dare_tie(){ 62 | 63 | pos 64 | 65 | for i in 0.7 0.8 0.9; do 66 | 67 | python run_merge.py \ 68 | --models-to-merge ${models_to_merge[@]} \ 69 | --models-name ${models_name[@]} \ 70 | --src-merge ${src_merge[@]} \ 71 | --data-path $data_path \ 72 | --yaml-file config/dare_merge2.yml \ 73 | --exclude-param ".*classifier.*" ".*bias.*" \ 74 | --mask-rate $i \ 75 | --outdir $outdir 76 | 77 | done 78 | 79 | } 80 | 81 | 82 | function run_avg_merge(){ 83 | 84 | pos 85 | 86 | python run_merge.py \ 87 | --models-to-merge ${models_to_merge[@]} \ 88 | --models-name ${models_name[@]} \ 89 | --src-merge ${src_merge[@]} \ 90 | --data-path $data_path \ 91 | --yaml-file config/average_merge.yml \ 92 | --exclude-param ".*classifier.*" ".*bias.*" \ 93 | --outdir $outdir 94 | 95 | 96 | } 97 | 98 | function run_tie(){ 99 | 100 | pos 101 | 102 | 103 | for i in 0.9; do 104 | for j in 0.7; do 105 | 106 | python run_merge.py \ 107 | --models-to-merge ${models_to_merge[@]} \ 108 | --models-name ${models_name[@]} \ 109 | --src-merge ${src_merge[@]} \ 110 | --yaml-file config/ties_merge.yml \ 111 | --data-path $data_path \ 112 | --exclude-param ".*classifier.*" ".*bias.*" \ 113 | --mask-rate $i \ 114 | --scaling $j \ 115 | --outdir $outdir 116 | 117 | done 118 | done 119 | 120 | } 121 | 122 | 123 | function run_task_arith(){ 124 | 125 | pos 126 | 127 | 128 | for j in 0.29; do 129 | 130 | python run_merge.py \ 131 | --models-to-merge ${models_to_merge[@]} \ 132 | --models-name ${models_name[@]} \ 133 | --src-merge ${src_merge[@]} \ 134 | --data-path $data_path \ 135 | --yaml-file config/task_arithmetic.yml \ 136 | --exclude-param ".*classifier.*" ".*bias.*" \ 137 | --scaling $j \ 138 | --outdir $outdir \ 139 | --save-path "outs/task_arithmetic" 140 | 141 | done 142 | 143 | } 144 | 145 | function ft(){ 146 | 147 | pos 148 | 149 | python run_merge.py \ 150 | --models-to-merge ${models_to_merge[@]} \ 151 | --models-name ${models_name[@]} \ 
152 | --src-merge ${src_merge[@]} \ 153 | --base-model 'roberta-base' \ 154 | --data-path $data_path \ 155 | --exclude-param ".*classifier.*" ".*bias.*" \ 156 | --outdir "outs/finetuned" 157 | 158 | } 159 | 160 | function pretrain(){ 161 | 162 | pos 163 | 164 | python run_merge.py \ 165 | --models-to-merge 'NONE' \ 166 | --models-name 'NONE' \ 167 | --src-merge ${src_merge[@]} \ 168 | --data-path $data_path \ 169 | --base-model 'roberta-base' \ 170 | --outdir $outdir 171 | 172 | } 173 | 174 | 175 | function twin_merge(){ 176 | 177 | yml='config/twin_merge.yml' 178 | # NOTICE: we only select prefix 179 | select_merge=${select_merge:="8"} 180 | select_twin=${select_twin:="8"} 181 | 182 | if [ $select_merge -eq 1 ]; then 183 | echo "please set \$select_merge > 1" 184 | exit 1 185 | elif [ $select_twin -eq 1 ]; then 186 | datapath="data_glue/new_dataset2.json" 187 | if [ -z $src_twin ];then 188 | echo "please set \$src_twin!" 189 | exit 1 190 | fi 191 | else 192 | datapath=data/test_router.json 193 | src_twin=("${models_name[@]:0:$select_twin}") 194 | src_merge=("${models_name[@]:0:$select_merge}") 195 | fi 196 | 197 | mask_strategy=${mask_strategy:="svd"} 198 | mask_rate=${mask_rate:="0.9"} 199 | echo ">>> use data_path $datapath" 200 | echo ">>> use outdir $outdir" 201 | echo ">>> merged from $select_merge tasks" 202 | echo ">>> use twin vector from $select_twin tasks" 203 | echo ">>> mask_rate $mask_rate; mask_strategy $mask_strategy" 204 | echo ">>> use yml $yml" 205 | 206 | python twin_merge.py \ 207 | --models-to-merge ${models_to_merge[@]} \ 208 | --models-name ${models_name[@]} \ 209 | --data-path $datapath \ 210 | --src-merge ${src_merge[@]} \ 211 | --src-twin ${src_twin[@]} \ 212 | --yaml-file $yml \ 213 | --share-expert outs/task_arithmetic \ 214 | --exclude-param ".*classifier.*" ".*bias.*" \ 215 | --mask-rate $mask_rate \ 216 | --mask-strategy $mask_strategy \ 217 | --outdir $outdir 218 | 219 | } -------------------------------------------------------------------------------- /discriminative/sparsify.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def magnitude( 5 | tensor: torch.Tensor, 6 | density: float, 7 | **kwargs, 8 | ) -> torch.Tensor: 9 | """Masks out the smallest values, retaining a proportion of `density`.""" 10 | if density >= 1: 11 | return tensor 12 | if len(tensor.shape) == 1: 13 | # rank=1 14 | return tensor 15 | 16 | k = int(density * tensor.view(-1).shape[0]) 17 | 18 | assert k > 0, "not gonna zero out the whole tensor buddy" 19 | mask = torch.zeros_like(tensor) 20 | w = tensor.abs().view(-1) 21 | if w.device.type == "cpu": 22 | w = w.float() 23 | topk = torch.topk(w, k=k, largest=True) 24 | mask.view(-1)[topk.indices] = 1 25 | return tensor * mask 26 | 27 | 28 | def bernoulli( 29 | tensor: torch.Tensor, 30 | density: float, # 1 - mask_rate (probability of drawing "1") 31 | rescale: bool = True 32 | ) -> torch.Tensor: 33 | if density >= 1: 34 | return tensor 35 | if density <= 0: 36 | return torch.zeros_like(tensor) 37 | if len(tensor.shape) == 1: 38 | # rank=1 39 | return tensor 40 | 41 | # mask = 1 - torch.bernoulli( 42 | # torch.full_like(input=tensor, fill_value=1 - density) 43 | # ) 44 | mask = torch.bernoulli( 45 | torch.full_like(input=tensor, fill_value=density).float() 46 | ) 47 | 48 | res = tensor * mask 49 | if rescale: 50 | res *= 1 / density 51 | return res 52 | 53 | def svd( 54 | tensor: torch.Tensor, 55 | density: float, 56 | **kwargs, 57 | ): 58 | if density >= 1: 59 | return tensor 
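# Note (descriptive comment, not in the original source): `density` is the
# fraction of information kept; extract_twin_vector in twin_merge.py passes
# 1 - mask_rate. Degenerate densities and 1-D tensors (biases, norms) are
# passed through by the guards around this point, while 2-D weights fall
# through to the truncated SVD below, which keeps only the top
# int(density * len(S)) singular values.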
60 | if density <= 0: 61 | return torch.zeros_like(tensor) 62 | if len(tensor.shape) == 1: 63 | # rank=1 64 | return tensor 65 | 66 | # S 按降序返回 67 | # U, S, V = torch.svd(tensor) 68 | # S = (S >= S[int(len(S) * density)]) * S 69 | # res = U @ torch.diag(S) @ V.T 70 | 71 | # `torch.linalg.svd()`: good for dense matrix 72 | # `torch.svd()`: deprecated 73 | # `torch.svd_lowrank()`: good for huge sparse matrix 74 | driver = None 75 | if tensor.is_cuda: 76 | driver = 'gesvda' 77 | 78 | U, S, Vh = torch.linalg.svd(tensor, full_matrices=True, driver=driver) 79 | new_rank = int(density * len(S)) 80 | U, S, Vh = U[:, :new_rank], S[:new_rank], Vh[:new_rank, :] 81 | res = U @ torch.diag(S) @ Vh 82 | return res -------------------------------------------------------------------------------- /discriminative/twin_merge.py: -------------------------------------------------------------------------------- 1 | from telnetlib import PRAGMA_HEARTBEAT 2 | import torch 3 | from collections import defaultdict, OrderedDict 4 | import tqdm 5 | import re 6 | import torch.nn as nn 7 | import copy 8 | import sparsify 9 | import sys 10 | import transformers 11 | from transformers import AutoModelForCausalLM, AutoModelForSeq2SeqLM, AutoModelForSequenceClassification, AutoTokenizer 12 | import os 13 | import functools 14 | from collections import defaultdict, OrderedDict 15 | from param import param 16 | import torch.nn.functional as F 17 | import torch 18 | from collections import defaultdict 19 | import numpy as np 20 | from merge import MergingMethod 21 | import inspect 22 | import datasets 23 | import pandas as pd 24 | import utils 25 | 26 | args = None 27 | DEVICE='cuda:0' 28 | 29 | @torch.inference_mode() 30 | def extract_twin_vector( 31 | model: nn.Module, 32 | merged_model: param, 33 | mask_rate: float, 34 | mask_strategy: str = 'magnitude', 35 | ): 36 | # \theta^t - \theta* 37 | twin_vector = (model - merged_model).map( 38 | lambda n,p: getattr(sparsify, mask_strategy)( 39 | p, 40 | 1 - mask_rate, 41 | ), 42 | desc=mask_strategy 43 | ) 44 | return twin_vector 45 | 46 | @torch.inference_mode() 47 | def run_twin_vector( 48 | args, 49 | ): 50 | import eval 51 | 52 | if len(args.src_merge) == 1: 53 | raise Exception('parameter Error') 54 | 55 | if args.exclude_param and len(args.exclude_param): 56 | filter_func = lambda n,p : not any([ 57 | re.match(exclude_pattern, n) 58 | for exclude_pattern in args.exclude_param 59 | ]) 60 | 61 | # \theta_t 62 | # for classifier head (placeholder) 63 | models_finetuned = { 64 | name: utils.load_classifier( 65 | eval.model_path_template.format(name=name) 66 | ).to(DEVICE) 67 | for name in args.models_name 68 | } 69 | # \theta_* 70 | models_to_merge = [ 71 | models_finetuned[name].to(DEVICE) 72 | for name in args.src_merge 73 | ] 74 | # \theta_0 75 | base_model = utils.load_classifier(args.base_model).to(DEVICE) 76 | 77 | args.base_model = param(base_model) 78 | args.models_to_merge = [ param(m) for m in models_to_merge ] 79 | 80 | # exclude_param 81 | for model in args.models_to_merge: 82 | model.filter(filter_func) 83 | args.base_model.filter(filter_func) 84 | 85 | if not args.share_expert: 86 | # get merged model first 87 | merger = MergingMethod(**args) 88 | merge_method = getattr(merger, args.merge_method) 89 | merged_param = merge_method(**args) 90 | else: 91 | merged_param = utils.load_classifier(args.share_expert).to(DEVICE) 92 | merged_param = param(merged_param) 93 | merged_param.filter(filter_func) 94 | 95 | # merged_param 96 | metrics = { 97 | "_method": 
args.merge_method, 98 | **{ 99 | f"_{k}": args[k] for k in [ 'mask_rate', 'mask_strategy', 'scaling', 'mask_scale', 'src_twin', 'src_merge' ] 100 | } 101 | } 102 | metrics['_mask_rate'] = 100*float(f"{metrics['_mask_rate']:.4f}") 103 | metrics['_src_twin'] = '+'.join(metrics['_src_twin']) 104 | metrics['_src_merge'] = '+'.join(metrics['_src_merge']) 105 | 106 | # tv_t 107 | twin_vector = {} 108 | data_id = None 109 | for i, data_name in enumerate(args.src_twin): 110 | data_id = eval.glue_data_id_map[data_name] 111 | twin_vector[data_id] = extract_twin_vector( 112 | model=models_to_merge[i], 113 | merged_model=merged_param, 114 | mask_rate=args.mask_rate, 115 | mask_strategy=args.mask_strategy, 116 | ) 117 | 118 | if len(args.src_twin) == 1: 119 | _infer_param = merged_param + twin_vector[data_id] 120 | 121 | data = utils.from_json(args.data_path) 122 | eval_pred = defaultdict(lambda: defaultdict(list)) 123 | for data_item in tqdm.tqdm(data, desc='infer glue'): 124 | data_id = data_item['dataset_ids'] 125 | data_name = list(eval.glue_data_id_map.keys())[data_id] 126 | 127 | if len(args.src_twin) != 1: 128 | 129 | tv_weights = F.softmax(torch.tensor(data_item['router_prob']), dim=0) 130 | 131 | assert len(tv_weights) == len(args.src_twin) 132 | 133 | twin_sum = sum([ w*tv for tv, w in zip(twin_vector.values(),tv_weights) ]) 134 | _infer_param = merged_param + twin_sum 135 | 136 | # print([ (n,p.dtype) for n,p in merged_params.items() ]) 137 | 138 | def calculate_logits(data_item): 139 | model = models_finetuned[data_name] 140 | score = torch.func.functional_call( 141 | model, 142 | _infer_param.param_dict, 143 | args=( 144 | torch.tensor(data_item['input_ids']).unsqueeze(0).to(model.device), 145 | torch.tensor(data_item['attention_mask']).unsqueeze(0).to(model.device), 146 | ), 147 | ).logits.cpu().numpy() 148 | 149 | return score 150 | 151 | eval_pred[data_name]['predictions'].append(calculate_logits(data_item)) 152 | eval_pred[data_name]['label_ids'].append(data_item['label']) 153 | 154 | for data_name in eval_pred.keys(): 155 | 156 | ans = eval.compute_single_metrics( 157 | utils.SimpleNamespace( 158 | predictions=np.concatenate(eval_pred[data_name]['predictions']), 159 | label_ids=np.array(eval_pred[data_name]['label_ids']) 160 | ), data_name 161 | )['averaged_scores'] 162 | metrics[data_name] = 100*float(f"{ans:.4f}") 163 | 164 | # TODO 165 | merged_res = 'outs/finetuned/results.csv' 166 | assert os.path.exists(merged_res), 'please run `ft` in `scripts.sh` first to run evaluation' 167 | df = pd.read_csv(merged_res) 168 | col = ["cola", "sst2", "mrpc", "stsb", "qqp", "qnli", "mnli", "rte"] 169 | norm = {k: 0 for k in col} 170 | for c in col: 171 | expert_path = f'../roberta/{c}/roberta-base_lr1e-05' 172 | norm[c] = df[df['model'] == expert_path][c].values[0] 173 | 174 | col = ["cola", "sst2", "mrpc", "stsb", "qqp", "qnli", "mnli", "rte"] 175 | metrics['avg'] = 0 176 | for c in col: 177 | metrics[c] = metrics[c] / norm[c] * 100 178 | metrics['avg'] += metrics[c] / len(col) 179 | 180 | # 3. 
Save excel in the order: 181 | utils.save_excel(metrics, args.outdir) 182 | 183 | def run_merge( 184 | *, 185 | # terminal 送的参数最高优先级,按是否为None判断 186 | models_to_merge: list[str], 187 | models_name: list[str], 188 | data_path: str, 189 | src_merge: list[str], 190 | src_twin: list[str], 191 | yaml_file: str = None, 192 | model_placeholder: str = None, 193 | model_loader: str = None, 194 | eval_func: str = None, 195 | dtype: str = None, 196 | exclude_param: list[str] = None, 197 | load_head: bool = None, 198 | seed: int=10, 199 | base_model: str = 'roberta-base', 200 | # for task-arithmetic: 201 | scaling: float = None, 202 | # for dare-merge: 203 | mask_rate: float = None, 204 | mask_scale: float = None, 205 | mask_strategy: str = None, 206 | outdir: str = None, 207 | share_expert: str = None, 208 | ): 209 | 210 | global args 211 | import inspect 212 | keys, _, _, values = inspect.getargvalues(inspect.currentframe()) 213 | 214 | utils.fix_seed(seed) 215 | os.makedirs(outdir, exist_ok=True) 216 | 217 | merge_config = utils.from_yaml(yaml_file) 218 | args = { 219 | k: values.get(k, merge_config.get(k)) 220 | for k in set(keys).union(merge_config) 221 | } 222 | args = { 223 | k: merge_config.get(k, None) 224 | if args[k] is None else args[k] 225 | for k in args.keys() 226 | } 227 | args = utils.SimpleNamespace(**args) 228 | print('>>> args\n', args) 229 | 230 | if args.scaling is not None and len(args.scaling) == 1: 231 | args.scaling = args.scaling[0] 232 | 233 | run_twin_vector( 234 | args, 235 | ) 236 | 237 | if __name__ == '__main__': 238 | import defopt 239 | defopt.run(run_merge) -------------------------------------------------------------------------------- /discriminative/utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import random 4 | import pandas as pd 5 | import os 6 | from datasets import Dataset 7 | import sys 8 | import json 9 | from tabulate import tabulate 10 | import yaml 11 | import types 12 | import functools 13 | import torch 14 | from typing import Iterable, Optional 15 | import datasets 16 | from datasets import concatenate_datasets, load_dataset 17 | from torch import nn 18 | import torch 19 | import torch.nn.functional as F 20 | import transformers 21 | import inspect 22 | from torch.utils.data import Dataset 23 | from functools import wraps 24 | from transformers import AutoModelForCausalLM, AutoModelForSeq2SeqLM, AutoModelForSequenceClassification, AutoTokenizer 25 | import os 26 | import torch 27 | import transformers 28 | 29 | to_np = lambda x: x.data.cpu().numpy() 30 | 31 | 32 | def args_inspector(func): 33 | 34 | @wraps(func) 35 | def inner(*args, **kwargs): 36 | params = list(inspect.signature(func).parameters.keys()) 37 | kwargs = {k: kwargs[k] for k in params if (k != 'self')} 38 | return func(*args, **kwargs) 39 | 40 | return inner 41 | 42 | 43 | def deprecated(func): 44 | 45 | @wraps(func) 46 | def new_func(*args, **kwargs): 47 | print("Call to deprecated function {}.".format(func.__name__)) 48 | return func(*args, **kwargs) 49 | 50 | return new_func 51 | 52 | 53 | class SimpleNamespace: 54 | 55 | def __init__(self, /, **kwargs): 56 | self.__dict__.update(kwargs) 57 | 58 | def __getitem__(self, item): 59 | return self.__dict__[item] 60 | 61 | def __len__(self): 62 | return len(self.__dict__) 63 | 64 | def items(self): 65 | return self.__dict__.items() 66 | 67 | def keys(self): 68 | return self.__dict__.keys() 69 | 70 | def __contains__(self, item): 71 | return item in 
self.keys() 72 | 73 | def values(self): 74 | return self.__dict__.values() 75 | 76 | def __repr__(self): 77 | items = (f"{k}={v!r}" for k, v in self.__dict__.items()) 78 | return "{}({})".format(type(self).__name__, ", ".join(items)) 79 | 80 | def __eq__(self, other): 81 | if isinstance(self, SimpleNamespace) and isinstance(other, SimpleNamespace): 82 | return self.__dict__ == other.__dict__ 83 | return NotImplemented 84 | 85 | 86 | def fix_seed(seed: int = 0): 87 | import random, numpy as np 88 | random.seed(seed) 89 | np.random.seed(seed) 90 | torch.manual_seed(seed) 91 | if torch.cuda.is_available(): 92 | torch.cuda.manual_seed_all(seed) 93 | torch.backends.cudnn.deterministic = True 94 | torch.backends.cudnn.benchmark = False 95 | 96 | 97 | def to_markdown(data: pd.DataFrame, path=None): 98 | markdown_table = tabulate(data, headers='keys', tablefmt='pipe') 99 | print(markdown_table) 100 | if path is not None: 101 | print(markdown_table, file=open(path, 'w')) 102 | 103 | 104 | def from_yaml(path, ): 105 | with open(path, "r", encoding="utf-8") as file: 106 | data = yaml.load(file, yaml.SafeLoader) 107 | return data 108 | 109 | 110 | def to_jsonl(data, path, mode='w'): 111 | if not isinstance(data, list): 112 | data = [data] 113 | with open(path, mode) as f: 114 | for line in data: 115 | f.write(json.dumps(line, ensure_ascii=False) + '\n') 116 | 117 | 118 | def from_jsonc(path): 119 | import jstyleson 120 | return jstyleson.load(open(path)) 121 | 122 | 123 | def from_json(path): 124 | return json.load(open(path)) 125 | 126 | 127 | def from_jsonl(path): 128 | return [json.loads(line) for line in open(path, 'r', encoding='utf8')] 129 | 130 | 131 | def to_json(data, path, mode='w'): 132 | if mode == 'a' and os.path.exists(path): 133 | old_data = from_json(path) 134 | data = old_data + data 135 | json.dump(data, open(path, 'w', encoding='utf8'), ensure_ascii=False) 136 | 137 | 138 | # next(iter(data.items()))[1].keys() 139 | def to_excel(data, path, index=None, columns=None, mode='w'): 140 | 141 | if columns is None: 142 | # text_df(index, 'b') 143 | # NOTE : { 'a':{'x''y'},'b':{'x''y'}} => rows: x,y columns: a,b 144 | df = pd.DataFrame(data, index=index).T 145 | if mode == 'a': 146 | if os.path.exists(path): 147 | previous = pd.read_csv(path, index_col=0) 148 | df = pd.concat([previous, df]) 149 | df.to_excel(path, index=True) 150 | return 151 | df.to_csv(path, index=True) 152 | # given column 153 | elif index is None: 154 | df = pd.DataFrame(data, columns=columns) 155 | 156 | df.to_excel(path, index=False) 157 | 158 | 159 | def from_excel(path): 160 | df = pd.read_excel(path).to_dict('records') 161 | return df 162 | 163 | 164 | def save_excel(data, out_path): 165 | # save excel 166 | columns = sorted(list(data.keys())) 167 | df = pd.DataFrame(data, index=[0]).reindex(columns=columns) 168 | os.makedirs(out_path, exist_ok=True) 169 | xlsx_path = os.path.join(out_path, 'results.csv') 170 | md_path = os.path.join(out_path, 'results.md') 171 | 172 | if os.path.exists(xlsx_path): 173 | previous = pd.read_csv(xlsx_path, index_col=0) 174 | df = pd.concat([previous, df]) 175 | 176 | df.to_csv(xlsx_path, index=True) 177 | 178 | markdown_table = tabulate(df, headers='keys', tablefmt='pipe') 179 | print(markdown_table) 180 | print(markdown_table, file=open(md_path, 'w')) 181 | 182 | 183 | def reload(): 184 | import utils 185 | import importlib 186 | importlib.reload(utils) 187 | 188 | 189 | def rsetattr(obj, attr, val): 190 | pre, _, post = attr.rpartition('.') 191 | return setattr(rgetattr(obj, pre) 
if pre else obj, post, val) 192 | 193 | 194 | # for `classifier.dense.out_proj` nest subojects / chained properties 195 | def rgetattr(obj, attr, *args): 196 | 197 | def _getattr(obj, attr): 198 | return getattr(obj, attr, *args) 199 | 200 | return functools.reduce(_getattr, [obj] + attr.split('.')) 201 | 202 | 203 | glue_data_keys_map = { 204 | "cola": ("sentence", None), 205 | "sst2": ("sentence", None), 206 | "mrpc": ("sentence1", "sentence2"), 207 | "stsb": ("sentence1", "sentence2"), 208 | "qqp": ("question1", "question2"), 209 | "mnli": ("premise", "hypothesis"), 210 | "qnli": ("question", "sentence"), 211 | "rte": ("sentence1", "sentence2") 212 | } 213 | 214 | glue_data_metrics_map = { 215 | "cola": "matthews_correlation", 216 | "sst2": "accuracy", 217 | "mrpc": "averaged_scores", # average of accuracy and f1 218 | "stsb": "averaged_scores", # average of pearson and spearmanr 219 | "qqp": "averaged_scores", # average of accuracy and f1 220 | "mnli": "accuracy", 221 | "qnli": "accuracy", 222 | "rte": "accuracy" 223 | } 224 | 225 | glue_data_num_labels_map = { 226 | "cola": 2, 227 | "sst2": 2, 228 | "mrpc": 2, 229 | "stsb": 1, 230 | "qqp": 2, 231 | "mnli": 3, 232 | "qnli": 2, 233 | "rte": 2 234 | } 235 | 236 | glue_data_id_map = { 237 | "cola": 0, 238 | "sst2": 1, 239 | "mrpc": 2, 240 | "stsb": 3, 241 | "qqp": 4, 242 | "mnli": 5, 243 | "qnli": 6, 244 | "rte": 7 245 | } 246 | 247 | 248 | # cache_dir = "/data/.cache" 249 | cache_dir = None 250 | 251 | from torch.utils.data import Subset, Dataset 252 | 253 | 254 | class GLUEDataLoader: 255 | 256 | def __init__(self, tokenizer: transformers.AutoTokenizer): 257 | """ 258 | Dataloader for GLUE datasets. 259 | :param tokenizer: AutoTokenizer, tokenizer 260 | :return: 261 | """ 262 | self.tokenizer = tokenizer 263 | 264 | def load_dataset( 265 | self, dataset_name: str, train_split_ratio_for_val: float = 0.1, max_seq_length: int = 128 266 | ): 267 | """ 268 | load GLUE dataset based on dataset_name 269 | :param dataset_name: str, name of the dataset to load 270 | :param train_split_ratio_for_val: float, split ratio of train data for validation, 271 | since the test data of GLUE is unavailable, we need to use a part of the original train data for validation (select the best model), 272 | and we use the original validation data for testing 273 | :param max_seq_length: int, maximal input length of examples in the dataset 274 | :return: 275 | """ 276 | dataset = load_dataset(path="glue", name=dataset_name, cache_dir=cache_dir) 277 | #dataset = load_dataset(path=os.path.join(cache_dir, "glue"), name=dataset_name) 278 | 279 | # get the key of datasets 280 | sentence1_key, sentence2_key = glue_data_keys_map[dataset_name] 281 | 282 | # set batched to True to process all examples together, will have keys like "input_ids", "attention_mask" 283 | dataset = dataset.map( 284 | lambda examples: self.tokenizer( 285 | text=examples[sentence1_key], 286 | text_pair=examples[sentence2_key] if sentence2_key else None, 287 | max_length=max_seq_length, 288 | truncation=True 289 | ), 290 | num_proc=os.cpu_count(), 291 | batched=True 292 | ) 293 | # add the "dataset_ids" column for each example 294 | dataset = dataset.map( 295 | lambda x: {"dataset_ids": glue_data_id_map[dataset_name]}, num_proc=os.cpu_count() 296 | ) 297 | 298 | permuted_indices = [ 299 | i for i in range(len(dataset["train"])) 300 | ] #np.random.RandomState(seed=0).permutation(len(dataset["train"])).tolist() 301 | num_train_data = int((1 - train_split_ratio_for_val) * len(dataset["train"])) 302 | 
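# Note (descriptive comment, not in the original source): the first
# (1 - train_split_ratio_for_val) fraction of the original GLUE train split is
# used for training and the remainder for validation (indices are kept in
# order here; the random permutation is left commented out). Since GLUE test
# labels are not public, the original validation split serves as the test set
# below.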
train_dataset = Subset(dataset=dataset["train"], indices=permuted_indices[:num_train_data]) 303 | # use a part of the original train data for validation 304 | val_dataset = Subset(dataset=dataset["train"], indices=permuted_indices[num_train_data:]) 305 | test_dataset = dataset["validation_matched"] if dataset_name == "mnli" else dataset[ 306 | "validation"] 307 | num_labels = glue_data_num_labels_map[dataset_name] 308 | 309 | return train_dataset, val_dataset, test_dataset, num_labels 310 | 311 | 312 | def reload(): 313 | import utils 314 | import importlib 315 | importlib.reload(utils) 316 | 317 | def rsetattr(obj, attr, val): 318 | pre, _, post = attr.rpartition('.') 319 | return setattr(rgetattr(obj, pre) if pre else obj, post, val) 320 | 321 | # for `classifier.dense.out_proj` nest subojects / chained properties 322 | def rgetattr(obj, attr, *args): 323 | 324 | def _getattr(obj, attr): 325 | return getattr(obj, attr, *args) 326 | 327 | return functools.reduce(_getattr, [obj] + attr.split('.')) 328 | 329 | 330 | def load_classifier(model_name: str, dtype=torch.float32, save_classifier_head=True): 331 | model = AutoModelForSequenceClassification.from_pretrained( 332 | model_name, torch_dtype=dtype, device_map="cpu", 333 | ) 334 | if save_classifier_head: 335 | if not os.path.exists(f'{model_name}'): 336 | print(f' >>> skip save classifier head for {model_name}') 337 | return model 338 | 339 | if os.path.exists(f'{model_name}/classifier_head.pt'): 340 | print(f' >>> skip save classifier head for {model_name}') 341 | return model 342 | 343 | print(f' >>> save classifier head for {model_name} in {model_name}/classifier_head.pt ') 344 | torch.save(model.classifier, f'{model_name}/classifier_head.pt') 345 | 346 | return model -------------------------------------------------------------------------------- /generative/config/ada_merge.yml: -------------------------------------------------------------------------------- 1 | merge_method: ada_merge 2 | base_model: roberta-base 3 | models_to_merge: auto 4 | scaling: auto 5 | models_name: auto 6 | exclude_param: auto 7 | model_loader: auto 8 | dtype: auto 9 | ada_type: auto -------------------------------------------------------------------------------- /generative/config/average_merge.yml: -------------------------------------------------------------------------------- 1 | merge_method: average_merging 2 | models_to_merge: auto 3 | models_name: auto 4 | exclude_param: auto 5 | model_loader: auto 6 | dtype: auto -------------------------------------------------------------------------------- /generative/config/dare_mask.yml: -------------------------------------------------------------------------------- 1 | merge_method: dare_mask 2 | base_model: meta-llama/Llama-2-13b-hf 3 | finetuned_model: WizardLM/WizardMath-13B-V1.0 4 | dtype: float16 5 | rescale: true 6 | mask_rate: 0.7 7 | mask_strategy: bernoulli 8 | weight_format: delta 9 | scaling: 1.0 -------------------------------------------------------------------------------- /generative/config/dare_mask2.yml: -------------------------------------------------------------------------------- 1 | merge_method: dare_mask 2 | base_model: meta-llama/Llama-2-13b-hf 3 | finetuned_model: WizardLM/WizardLM-13B-V1.2 4 | dtype: float16 5 | rescale: true 6 | mask_rate: 0.7 7 | mask_strategy: bernoulli 8 | weight_format: delta 9 | scaling: 1.0 -------------------------------------------------------------------------------- /generative/config/dare_merge.yml: 
-------------------------------------------------------------------------------- 1 | merge_method: dare_merge 2 | base_model: roberta-base 3 | models_to_merge: auto 4 | models_name: auto 5 | exclude_param: auto 6 | model_loader: auto 7 | dtype: auto 8 | 9 | # for dare_mask 10 | rescale: true 11 | mask_rate: 0.7 12 | mask_strategy: bernoulli 13 | weight_format: delta 14 | mask_scale: 1.0 15 | 16 | # for merge 17 | second_merge_method: task_arithmetic 18 | second_merge_config: 19 | scaling: 0.7 20 | -------------------------------------------------------------------------------- /generative/config/dare_merge2.yml: -------------------------------------------------------------------------------- 1 | merge_method: dare_merge 2 | base_model: roberta-base 3 | models_to_merge: auto 4 | models_name: auto 5 | exclude_param: auto 6 | model_loader: auto 7 | dtype: auto 8 | 9 | # for dare_mask 10 | rescale: true 11 | mask_rate: 0.7 12 | mask_strategy: bernoulli 13 | weight_format: delta 14 | mask_scale: 1.0 15 | 16 | # for merge 17 | second_merge_method: ties_merge 18 | second_merge_config: 19 | mask_rate: 0.7 20 | scaling: 0.9 21 | -------------------------------------------------------------------------------- /generative/config/fisher_merge.yml: -------------------------------------------------------------------------------- 1 | merge_method: task_arithmetic 2 | base_model: roberta-base 3 | models_to_merge: auto 4 | scaling: auto 5 | models_name: auto 6 | exclude_param: auto 7 | model_loader: auto 8 | dtype: auto -------------------------------------------------------------------------------- /generative/config/regmean_merge.yml: -------------------------------------------------------------------------------- 1 | merge_method: task_arithmetic 2 | base_model: roberta-base 3 | models_to_merge: auto 4 | scaling: auto 5 | models_name: auto 6 | exclude_param: auto 7 | model_loader: auto 8 | dtype: auto -------------------------------------------------------------------------------- /generative/config/task_arithmetic.yml: -------------------------------------------------------------------------------- 1 | merge_method: task_arithmetic 2 | base_model: roberta-base 3 | models_to_merge: auto 4 | scaling: 0.7 5 | models_name: auto 6 | exclude_param: auto 7 | model_loader: auto 8 | dtype: auto -------------------------------------------------------------------------------- /generative/config/task_arithmetic_plus.yml: -------------------------------------------------------------------------------- 1 | merge_method: task_arithmetic_plus 2 | base_model: roberta-base 3 | models_to_merge: auto 4 | scaling: auto 5 | models_name: auto 6 | exclude_param: auto 7 | model_loader: auto 8 | dtype: auto -------------------------------------------------------------------------------- /generative/config/task_arithmetic_search.yml: -------------------------------------------------------------------------------- 1 | merge_method: task_arithmetic_search 2 | base_model: roberta-base 3 | models_to_merge: auto 4 | models_name: auto 5 | exclude_param: auto 6 | model_loader: auto 7 | dtype: auto -------------------------------------------------------------------------------- /generative/config/ties_merge.yml: -------------------------------------------------------------------------------- 1 | merge_method: ties_merge 2 | base_model: roberta-base 3 | models_to_merge: auto 4 | mask_rate: 0.7 5 | scaling: 1.0 6 | models_name: auto 7 | exclude_param: auto 8 | model_loader: auto 9 | dtype: auto 
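The `scaling` field in these configs is the task-arithmetic coefficient; the merging logic itself lives in `merge.py`, which is not reproduced in this listing. As a rough sketch only (not the `merge.py` implementation; the helper name below is ours, and it assumes `param.__add__` mirrors the `__radd__`/`__sub__`/`__rmul__` overloads shown in `discriminative/param.py`), task arithmetic with such a config amounts to:

```python
def task_arithmetic_sketch(base_model, models_to_merge, scaling=0.7):
    """Sketch: theta_merged = theta_base + scaling * sum_t (theta_t - theta_base)."""
    # task vector for each expert (delta from the shared base weights)
    task_vectors = [m - base_model for m in models_to_merge]
    # scaled sum of task vectors added back onto the base weights
    return base_model + scaling * sum(task_vectors)
```

Here `base_model` and `models_to_merge` are `param` objects; `sum` relies on `param.__radd__` handling the initial `0 + x` case, as shown earlier in `param.py`.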
-------------------------------------------------------------------------------- /generative/config/twin_merge.yml: -------------------------------------------------------------------------------- 1 | merge_method: twin_merge 2 | base_model: roberta-base 3 | models_to_merge: auto 4 | models_name: auto 5 | exclude_param: auto 6 | model_loader: auto 7 | dtype: auto 8 | 9 | # for dare_mask 10 | # rescale: true 11 | # mask_rate: auto 12 | # mask_strategy: bernoulli 13 | # weight_format: delta 14 | # mask_scale: 0.7 15 | 16 | # for merge 17 | second_merge_method: task_arithmetic 18 | second_merge_config: 19 | scaling: 0.3 20 | -------------------------------------------------------------------------------- /generative/eval.sh: -------------------------------------------------------------------------------- 1 | 2 | export PYTHONPATH=. 3 | 4 | source eval_scripts.sh 5 | 6 | if [ ! -f "outs/test.json" ]; then 7 | # 1.1 get test data first 8 | gen_eval_data 9 | fi 10 | 11 | if [ ! -d "outs/qwen_merged" ]; then 12 | # 1.2 get shared expert 13 | bash run_merge.sh 14 | fi 15 | 16 | if [ ! -f "data/test_router.json" ]; then 17 | # 2. get router train data 18 | torchrun --master-port 23451 --nnodes=1 --nproc_per_node=8 \ 19 | router.py \ 20 | --no-train \ 21 | --shared-expert outs/qwen_merged/merged 22 | # 3. train router 23 | python3 router.py --train 24 | fi 25 | 26 | # 4. run twin-merging 27 | CUDA_VISIBLE_DEVICES=0 run_twin -------------------------------------------------------------------------------- /generative/eval_merge.py: -------------------------------------------------------------------------------- 1 | from fastapi import FastAPI 2 | import os 3 | import time 4 | import torch 5 | from peft import PeftModel, PeftConfig 6 | from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig 7 | from transformers import BitsAndBytesConfig 8 | import torch 9 | import helm_utils.lora_utils as lora_utils 10 | from helm_utils.prompter import * 11 | from helm_utils.helm_type import * 12 | from helm_utils.lora_utils import hack_qwen_for_moe 13 | import random 14 | import torch 15 | import numpy as np 16 | import utils 17 | from param import param 18 | from merge import LoraMergingMethod 19 | import sparsify 20 | from peft import LoraConfig,get_peft_model 21 | import qwen_task 22 | import json 23 | import torch.nn.functional as F 24 | from safetensors.torch import load_file 25 | 26 | args = utils.SimpleNamespace( 27 | seed=10, 28 | base_model='Qwen/Qwen-14B', 29 | models_to_merge=[ 30 | '../qwen/qwen-mmlu', 31 | '../qwen/qwen-truthfulqa', 32 | '../qwen/qwen-bbq', 33 | '../qwen/qwen-cnn', 34 | ], 35 | models_name=[ 36 | 'mmlu', 37 | 'truthfulqa', 38 | 'bbq', 39 | 'cnn-dm', 40 | ], 41 | data_path = None, 42 | src_merge = [ 43 | 'mmlu', 44 | 'truthfulqa', 45 | 'bbq', 46 | 'cnn-dm', 47 | ], 48 | src_twin = [], 49 | yaml_file = '../dare/config/twin_merge.yml', 50 | dtype = torch.bfloat16, 51 | exclude_param = None, 52 | new_rank = 8, 53 | # for task-arithmetic: 54 | scaling = None, 55 | # for dare-merge: 56 | mask_rate = None, 57 | mask_scale = None, 58 | mask_strategy = None, 59 | outdir = None, 60 | ) 61 | for k in args.keys(): 62 | if os.getenv(k): 63 | value = os.getenv(k) 64 | if k == 'new_rank': 65 | args.new_rank = int(value) 66 | elif k == 'mask_rate': 67 | args.mask_rate = float(value) 68 | elif k == 'src_twin': 69 | args.src_twin = [value] 70 | elif k == 'src_merge': 71 | args.src_merge = os.getenv('src_merge').split(',') 72 | else: 73 | print(f'>>> set {k}') 74 | args[k] = value 75 
| if os.getenv('select_merge') and int(os.getenv('select_merge'))> 1: 76 | args.src_merge = args.src_merge[:int(os.getenv('select_merge'))] 77 | 78 | merge_config = utils.from_yaml(args.yaml_file) 79 | for k in merge_config: 80 | if not hasattr(args,k) or args[k] is None: 81 | args[k] = merge_config.get(k) 82 | print('>>> args\n', args) 83 | 84 | utils.fix_seed(args.seed) 85 | hack_qwen_for_moe() 86 | model = AutoModelForCausalLM.from_pretrained( 87 | args.base_model, 88 | trust_remote_code=True, 89 | device_map={"": 0}, 90 | torch_dtype=args.dtype, 91 | # quantization_config=BitsAndBytesConfig( 92 | # load_in_4bit=True, 93 | # bnb_4bit_use_double_quant=True, 94 | # bnb_4bit_quant_type="nf4", 95 | # bnb_4bit_compute_dtype=args.dtype, 96 | # llm_int8_has_fp16_weight=True, 97 | # ) 98 | ) 99 | print(f'>>> loading {args.base_model} finished') 100 | peft_config = LoraConfig( 101 | r=32, 102 | lora_alpha=16, 103 | lora_dropout= 0.05, 104 | bias="none", 105 | task_type="CAUSAL_LM", 106 | target_modules=[ 107 | "w2", 108 | "c_proj", 109 | "c_attn", 110 | "w1" 111 | ], 112 | modules_to_save=None, 113 | ) 114 | # a placeholder 115 | if os.getenv('pretrained'): 116 | # run pretrained 117 | pass 118 | else: 119 | 120 | def load(model_path): 121 | try: 122 | ans = torch.load( 123 | os.path.join(model_path, 'adapter_model.bin') 124 | ) 125 | except: 126 | ans = load_file(os.path.join(model_path, 'adapter_model.safetensors')) 127 | return ans 128 | 129 | model = get_peft_model(model, peft_config, adapter_name='merged') 130 | # load lora 131 | lora_finetuned = { 132 | n: load(model_path) 133 | for model_path, n in zip(args.models_to_merge, args.models_name) 134 | } 135 | if os.getenv('individual'): 136 | # run individual 137 | merged_lora = param(lora_finetuned[os.getenv('individual')]).to('cuda:0') 138 | elif len(args.src_merge): 139 | # run merge 140 | models_to_merge = [ 141 | lora_finetuned[name] 142 | for name in args.src_merge 143 | ] 144 | args.models_to_merge = [ param(m).to('cuda:0') for m in models_to_merge ] 145 | merger = LoraMergingMethod(**args) 146 | merge_method = getattr(merger, args.merge_method) 147 | merged_lora = merge_method(**args) 148 | else: 149 | raise NotImplementedError 150 | 151 | tokenizer = AutoTokenizer.from_pretrained( 152 | 'Qwen/Qwen-14B', 153 | add_special_tokens=True, 154 | trust_remote_code=True, 155 | padding='left', 156 | ) 157 | tokenizer.pad_token_id=tokenizer.eos_token_id 158 | 159 | # router_data = utils.from_json(args.data_path) 160 | app = FastAPI() 161 | 162 | @torch.inference_mode() 163 | @app.post("/process") 164 | async def process_request(input_data: ProcessRequest) -> ProcessResponse: 165 | 166 | global merged_lora 167 | 168 | if input_data.seed is not None: 169 | torch.manual_seed(input_data.seed) 170 | 171 | print(input_data.prompt) 172 | # data type 173 | data_type = get_data_type(input_data.prompt) 174 | print(data_type) 175 | 176 | # write back parameter 177 | if not os.getenv('pretrained'): 178 | for n, p in merged_lora.items(): 179 | n = n.replace('lora_B', 'lora_B.merged') 180 | n = n.replace('lora_A', 'lora_A.merged') 181 | utils.rsetattr(model, n, torch.nn.Parameter(p,requires_grad=False)) 182 | 183 | encoded = tokenizer(input_data.prompt, return_tensors="pt") 184 | prompt_length = encoded["input_ids"][0].size(0) 185 | t0 = time.perf_counter() 186 | encoded = {k: v.to("cuda") for k, v in encoded.items()} 187 | with torch.no_grad(): 188 | outputs = model.generate( 189 | **encoded, 190 | max_new_tokens=input_data.max_new_tokens, 191 | 
do_sample=True, 192 | temperature=input_data.temperature, 193 | top_k=input_data.top_k, 194 | return_dict_in_generate=True, 195 | output_scores=True, 196 | pad_token_id=0, 197 | ) 198 | t = time.perf_counter() - t0 199 | if not input_data.echo_prompt: 200 | output = tokenizer.decode(outputs.sequences[0][prompt_length:], skip_special_tokens=True) 201 | else: 202 | output = tokenizer.decode(outputs.sequences[0], skip_special_tokens=True) 203 | 204 | if len(output.strip()) == 0: 205 | with torch.no_grad(): 206 | outputs = model.generate( 207 | **encoded, 208 | max_new_tokens=input_data.max_new_tokens, 209 | do_sample=True, 210 | temperature=input_data.temperature, 211 | top_k=input_data.top_k, 212 | return_dict_in_generate=True, 213 | output_scores=True, 214 | pad_token_id=0, 215 | repetition_penalty=0.6, 216 | ) 217 | output = tokenizer.decode(outputs.sequences[0][prompt_length:], skip_special_tokens=True) 218 | 219 | output = output.split('###\nArticle:')[0] 220 | output = output.split('###')[0] 221 | output = output.strip("<|endoftext|>").strip("") 222 | if 'The answer is' in output: 223 | output = output.split('\n\nQ:')[0] 224 | print(output) 225 | 226 | tokens_generated = outputs.sequences[0].size(0) - prompt_length 227 | generated_tokens = [] 228 | log_probs = torch.log(torch.stack(outputs.scores, dim=1).softmax(-1)) 229 | gen_sequences = outputs.sequences[:, encoded["input_ids"].shape[-1]:] 230 | gen_logprobs = torch.gather(log_probs, 2, gen_sequences[:, :, None]).squeeze(-1) 231 | top_indices = torch.argmax(log_probs, dim=-1) 232 | top_logprobs = torch.gather(log_probs, 2, top_indices[:,:,None]).squeeze(-1) 233 | top_indices = top_indices.tolist()[0] 234 | top_logprobs = top_logprobs.tolist()[0] 235 | for t, lp, tlp in zip(gen_sequences.tolist()[0], gen_logprobs.tolist()[0], zip(top_indices, top_logprobs)): 236 | idx, val = tlp 237 | tok_str = tokenizer.decode(idx) 238 | token_tlp = {tok_str: val} 239 | generated_tokens.append( 240 | Token(text=tokenizer.decode(t), logprob=lp, top_logprob=token_tlp) 241 | ) 242 | logprob_sum = gen_logprobs.sum().item() 243 | 244 | return ProcessResponse( 245 | text=output, tokens=generated_tokens, logprob=logprob_sum, request_time=t 246 | ) 247 | 248 | @app.post("/tokenize") 249 | async def tokenize(input_data: TokenizeRequest) -> TokenizeResponse: 250 | t0 = time.perf_counter() 251 | encoded = tokenizer( 252 | input_data.text 253 | ) 254 | t = time.perf_counter() - t0 255 | tokens = encoded["input_ids"] 256 | return TokenizeResponse(tokens=tokens, request_time=t) -------------------------------------------------------------------------------- /generative/eval_scripts.sh: -------------------------------------------------------------------------------- 1 | 2 | SUITE="merged" 3 | CONF=moe2 4 | OUT=outs 5 | name=CONF 6 | 7 | 8 | function run_finetune(){ 9 | PORT=$(shuf -i7000-9000 -n1) 10 | 11 | server_command="source activate merging;\ 12 | CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES \ 13 | individual=$src \ 14 | uvicorn eval_merge:app --port $PORT | tee $name.log" 15 | 16 | helm_command="SHOW=1 \ 17 | OUTPUT=$OUT \ 18 | NAME=$src \ 19 | SUITE=$SUITE \ 20 | CONF=$CONF \ 21 | PORT=$PORT \ 22 | bash helm.sh " 23 | 24 | tmux_name=qwen-$src 25 | tmux new-session -ds $tmux_name 26 | tmux list-panes -t $tmux_name:0.1 > /dev/null 2>&1 27 | if [ $? 
-ne 0 ]; then 28 | 29 | tmux split-window -h -t $tmux_name:0.0 30 | fi 31 | tmux send-keys -t $tmux_name:0.0 "$server_command" C-m 32 | tmux send-keys -t $tmux_name:0.1 "$helm_command" C-m 33 | # tmux a -t $tmux_name 34 | 35 | # kill_process_by_name "eval_merge:app --port $PORT" 36 | } 37 | 38 | 39 | function eval_merged(){ 40 | 41 | PORT=$(shuf -i7000-9000 -n1) 42 | 43 | server_command="source activate merging;\ 44 | CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES \ 45 | select_merge=$select_merge \ 46 | src_merge=$src_merge \ 47 | base_model=$base_model \ 48 | yaml_file=/home/LeiFeng/lzy/lora-merge/dare/config/$merge.yml \ 49 | uvicorn eval_merge:app --port $PORT | tee $name.log" 50 | 51 | helm_command="SHOW=1 \ 52 | OUTPUT=$OUT \ 53 | NAME=$select_merge-$merge \ 54 | SUITE=$SUITE \ 55 | CONF=$CONF \ 56 | PORT=$PORT \ 57 | bash helm.sh " 58 | 59 | tmux_name=$select_merge-$merge-$info 60 | tmux new-session -ds $tmux_name 61 | tmux list-panes -t $tmux_name:0.1 > /dev/null 2>&1 62 | if [ $? -ne 0 ]; then 63 | 64 | tmux split-window -h -t $tmux_name:0.0 65 | fi 66 | tmux send-keys -t $tmux_name:0.0 "$server_command" C-m 67 | tmux send-keys -t $tmux_name:0.1 "$helm_command" C-m 68 | tmux a -t $tmux_name 69 | 70 | } 71 | 72 | 73 | function run_pretrained(){ 74 | PORT=$(shuf -i7000-9000 -n1) 75 | 76 | server_command="source activate merging;\ 77 | pretrained=1 \ 78 | base_model=$base_model \ 79 | CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES \ 80 | uvicorn eval_merge:app --port $PORT | tee $name.log" 81 | 82 | # : ${SUITE:=$(date '+%Y-%m-%d-%M')} 83 | base_model="${base_model//\//}" 84 | 85 | helm_command="SHOW=1 \ 86 | OUTPUT=$OUT \ 87 | NAME=$base_model \ 88 | SUITE=$SUITE \ 89 | CONF=$CONF \ 90 | PORT=$PORT \ 91 | bash helm.sh " 92 | 93 | # tmux_name=pretrained 94 | tmux_name=$base_model 95 | tmux new-session -ds $tmux_name 96 | tmux list-panes -t $tmux_name:0.1 > /dev/null 2>&1 97 | if [ $? -ne 0 ]; then 98 | 99 | tmux split-window -h -t $tmux_name:0.0 100 | fi 101 | tmux send-keys -t $tmux_name:0.0 "$server_command" C-m 102 | tmux send-keys -t $tmux_name:0.1 "$helm_command" C-m 103 | tmux a -t $tmux_name 104 | 105 | } 106 | 107 | function gen_eval_data(){ 108 | 109 | PORT=$(shuf -i7000-9000 -n1) 110 | 111 | server_command="source activate merging;\ 112 | CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES \ 113 | uvicorn gen_eval_data:app --port $PORT | tee $name.log" 114 | 115 | helm_command="SHOW=1 \ 116 | OUTPUT=outs \ 117 | NAME=twin-$new_rank \ 118 | SUITE=$SUITE \ 119 | CONF=$CONF \ 120 | PORT=$PORT \ 121 | bash helm.sh " 122 | 123 | tmux_name=helm-$new_rank-$info 124 | echo "$PORT-$tmux_name" 125 | tmux new-session -ds $tmux_name 126 | tmux list-panes -t $tmux_name:0.1 > /dev/null 2>&1 127 | if [ $? 
-ne 0 ]; then 128 | tmux split-window -h -t $tmux_name:0.0 129 | fi 130 | tmux send-keys -t $tmux_name:0.0 "$server_command" C-m 131 | tmux send-keys -t $tmux_name:0.1 "$helm_command" C-m 132 | tmux a -t $tmux_name 133 | 134 | } 135 | 136 | 137 | function run_twin(){ 138 | 139 | PORT=$(shuf -i7000-9000 -n1) 140 | 141 | data_path="data/test_router.json" 142 | 143 | server_command="source activate merging;\ 144 | CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES \ 145 | new_rank=$new_rank \ 146 | ablation=$ablation \ 147 | data_path=$data_path \ 148 | src_twin=$src_twin \ 149 | base_model=$base_model \ 150 | yaml_file=$yaml_file \ 151 | src_merge=$src_merge \ 152 | select_twin=$select_twin \ 153 | select_merge=$select_merge \ 154 | uvicorn eval_twin:app --port $PORT | tee $name.log" 155 | 156 | helm_command="SHOW=1 \ 157 | OUTPUT=$OUT \ 158 | NAME=twin-$new_rank \ 159 | SUITE=$SUITE \ 160 | CONF=$CONF \ 161 | PORT=$PORT \ 162 | bash helm.sh " 163 | 164 | tmux_name=qwen-twin-$new_rank-$info 165 | echo "$PORT-$tmux_name" 166 | tmux new-session -ds $tmux_name 167 | tmux list-panes -t $tmux_name:0.1 > /dev/null 2>&1 168 | if [ $? -ne 0 ]; then 169 | tmux split-window -h -t $tmux_name:0.0 170 | fi 171 | tmux send-keys -t $tmux_name:0.0 "$server_command" C-m 172 | tmux send-keys -t $tmux_name:0.1 "$helm_command" C-m 173 | tmux a -t $tmux_name 174 | } -------------------------------------------------------------------------------- /generative/eval_twin.py: -------------------------------------------------------------------------------- 1 | from fastapi import FastAPI 2 | import os 3 | import time 4 | import torch 5 | from peft import PeftModel, PeftConfig 6 | from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig 7 | from transformers import BitsAndBytesConfig 8 | import torch 9 | import helm_utils.lora_utils as lora_utils 10 | from helm_utils.prompter import * 11 | from helm_utils.helm_type import * 12 | from helm_utils.lora_utils import hack_qwen_for_moe 13 | import random 14 | import torch 15 | from collections import defaultdict 16 | import numpy as np 17 | import utils 18 | from param import param 19 | from merge import LoraMergingMethod 20 | import sparsify 21 | from peft import LoraConfig,get_peft_model 22 | import qwen_task 23 | import json 24 | import torch.nn.functional as F 25 | 26 | 27 | qwen_task_id_map={ 28 | "mmlu": 0, 29 | "truthfulqa": 1, 30 | "bbq": 2, 31 | "cnn-dm": 3, 32 | } 33 | 34 | qwen_task_cnt_map={ 35 | 'cnn-dm': 0, 36 | 'mmlu': 0, 37 | 'truthfulqa': 0, 38 | 'bbq': 0, 39 | } 40 | 41 | args = utils.SimpleNamespace( 42 | seed=10, 43 | base_model='Qwen/Qwen-14B', 44 | models_to_merge=[ 45 | '../qwen/qwen-mmlu', 46 | '../qwen/qwen-truthfulqa', 47 | '../qwen/qwen-bbq', 48 | '../qwen/qwen-cnn', 49 | ], 50 | models_name=[ 51 | 'mmlu', 52 | 'truthfulqa', 53 | 'bbq', 54 | 'cnn-dm', 55 | ], 56 | data_path = None, 57 | src_merge = [ 58 | 'mmlu', 59 | 'truthfulqa', 60 | 'bbq', 61 | 'cnn-dm', 62 | ], 63 | src_twin = [ 64 | 'mmlu', 65 | 'truthfulqa', 66 | 'bbq', 67 | 'cnn-dm', 68 | ], 69 | yaml_file = '../dare/config/twin_merge.yml', 70 | dtype = torch.bfloat16, 71 | exclude_param = None, 72 | new_rank = 8, 73 | # for task-arithmetic: 74 | scaling = None, 75 | # for dare-merge: 76 | mask_rate = None, 77 | mask_scale = None, 78 | mask_strategy = None, 79 | outdir = None, 80 | ) 81 | qwen_task_map={ 82 | 'cnn-dm': 0, 83 | 'mmlu': 1, 84 | 'truthfulqa': 2, 85 | 'bbq': 3, 86 | 'gsm8k': 4, 87 | } 88 | # 检查环境变量是否有送 89 | for k in args.keys(): 90 | if os.getenv(k): 91 | value = 
os.getenv(k) 92 | if k == 'new_rank': 93 | args.new_rank = int(value) 94 | elif k == 'mask_rate': 95 | args.mask_rate = float(value) 96 | elif k == 'src_twin': 97 | args.src_twin = value.split(',') 98 | elif k == 'src_merge': 99 | args.src_merge = value.split(',') 100 | else: 101 | print(f'>>> set {k}') 102 | args[k] = value 103 | 104 | if os.getenv('select_merge') and int(os.getenv('select_merge'))> 1: 105 | args.src_merge = args.src_merge[:int(os.getenv('select_merge'))] 106 | if os.getenv('select_twin') and int(os.getenv('select_twin'))> 0: 107 | args.src_twin = args.src_twin[:int(os.getenv('select_twin'))] 108 | # load the remaining arguments from the yaml config 109 | merge_config = utils.from_yaml(args.yaml_file) 110 | for k in merge_config: 111 | if not hasattr(args,k) or args[k] is None: 112 | args[k] = merge_config.get(k) 113 | print('>>> args\n', args) 114 | 115 | utils.fix_seed(args.seed) 116 | hack_qwen_for_moe() 117 | model = AutoModelForCausalLM.from_pretrained( 118 | args.base_model, 119 | trust_remote_code=True, 120 | device_map={"": 0}, 121 | torch_dtype=args.dtype, 122 | # quantization_config=BitsAndBytesConfig( 123 | # load_in_4bit=True, 124 | # bnb_4bit_use_double_quant=True, 125 | # bnb_4bit_quant_type="nf4", 126 | # bnb_4bit_compute_dtype=args.dtype, 127 | # llm_int8_has_fp16_weight=True, 128 | # ) 129 | ) 130 | print(f'>>> loading {args.base_model} finished') 131 | peft_config = LoraConfig( 132 | r=32, 133 | lora_alpha=16, 134 | lora_dropout= 0.05, 135 | bias="none", 136 | task_type="CAUSAL_LM", 137 | target_modules=[ 138 | "w2", 139 | "c_proj", 140 | "c_attn", 141 | "w1" 142 | ], 143 | modules_to_save=None, 144 | ) 145 | model = get_peft_model(model, peft_config, adapter_name='merged') 146 | tokenizer = AutoTokenizer.from_pretrained( 147 | args.base_model, 148 | add_special_tokens=True, 149 | trust_remote_code=True, 150 | padding_side='left', 151 | ) 152 | tokenizer.pad_token_id=tokenizer.eos_token_id 153 | 154 | @torch.inference_mode() 155 | def extract_twin_vector( 156 | lora, 157 | merged: param, 158 | new_rank, 159 | ): 160 | # \theta^t - \theta* 161 | twin_vector = (lora - merged).map( 162 | lambda n,p: sparsify.svd( 163 | p, 164 | density=0.9, # useless 165 | new_rank=new_rank, 166 | ), 167 | desc='svd' 168 | ) 169 | return twin_vector 170 | 171 | # load lora 172 | lora_finetuned = { 173 | n: torch.load( 174 | os.path.join(model_path, 'adapter_model.bin') 175 | ) 176 | for model_path, n in zip(args.models_to_merge, args.models_name) 177 | } 178 | models_to_merge = [ 179 | lora_finetuned[name] 180 | for name in args.src_merge 181 | ] 182 | args.models_to_merge = [ param(m).to('cuda:0') for m in models_to_merge ] 183 | if os.getenv('ablation'): 184 | merged_lora = 0 185 | else: 186 | merger = LoraMergingMethod(**args) 187 | merge_method = getattr(merger, args.merge_method) 188 | merged_lora = merge_method(**args) 189 | twin_vector = {} 190 | data_id = None 191 | for cnt, data_name in enumerate(args.src_twin): 192 | data_id = qwen_task_id_map[data_name] 193 | twin_vector[data_id] = extract_twin_vector( 194 | lora=args.models_to_merge[cnt], 195 | merged=merged_lora, 196 | new_rank=(args.new_rank), 197 | ) 198 | if len(args.src_twin) == 0: 199 | _infer_lora = merged_lora 200 | elif len(args.src_twin) == 1: 201 | _infer_lora = merged_lora + twin_vector[data_id] 202 | 203 | tmp_data = utils.from_json(args.data_path) 204 | router_data = defaultdict(list) 205 | for d in tmp_data: 206 | data_name = list(qwen_task_id_map.keys())[d['dataset_ids']] 207 | router_data[data_name].append(d) 208 | 209 | app = FastAPI()
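# NOTE: the /process endpoint below performs the per-request Twin-Merging step: it looks up
# the offline router probabilities for the incoming prompt (loaded from args.data_path into
# router_data above), applies a softmax to obtain per-task weights, and composes
# _infer_lora = merged_lora + sum(w_t * twin_vector[t]) before writing the parameters back
# into the 'merged' LoRA adapter for generation.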
210 | 211 | 212 | @torch.inference_mode() 213 | @app.post("/process") 214 | async def process_request(input_data: ProcessRequest) -> ProcessResponse: 215 | 216 | global _infer_lora 217 | 218 | if input_data.seed is not None: 219 | torch.manual_seed(input_data.seed) 220 | 221 | print(input_data.prompt) 222 | # data type 223 | data_type = get_data_type(input_data.prompt) 224 | print(data_type) 225 | 226 | if data_type not in list(router_data.keys()): 227 | raise Exception('error!!') 228 | 229 | task_cnt = qwen_task_cnt_map[data_type] 230 | qwen_task_cnt_map[data_type] += 1 231 | data_item = router_data[data_type][task_cnt] 232 | 233 | if data_item['sentence'] != input_data.prompt: 234 | raise Exception('offline data order is wrong!') 235 | 236 | if len(args.src_twin) > 1: 237 | tv_weights = F.softmax(torch.tensor(data_item['router_prob']), dim=0) 238 | if len(tv_weights) != len(args.src_twin): 239 | raise Exception('the arg is wrong!') 240 | twin_sum = sum([ w*tv for tv, w in zip(twin_vector.values(),tv_weights) ]) 241 | _infer_lora = merged_lora + twin_sum 242 | 243 | # write back parameter 244 | for n, p in _infer_lora.items(): 245 | n = n.replace('lora_B', 'lora_B.merged') 246 | n = n.replace('lora_A', 'lora_A.merged') 247 | utils.rsetattr(model, n, torch.nn.Parameter(p, requires_grad=False)) 248 | 249 | encoded = tokenizer(input_data.prompt, return_tensors="pt") 250 | prompt_length = encoded["input_ids"][0].size(0) 251 | t0 = time.perf_counter() 252 | encoded = {k: v.to("cuda") for k, v in encoded.items()} 253 | with torch.no_grad(): 254 | outputs = model.generate( 255 | **encoded, 256 | max_new_tokens=input_data.max_new_tokens, 257 | do_sample=True, 258 | temperature=input_data.temperature, 259 | top_k=input_data.top_k, 260 | return_dict_in_generate=True, 261 | output_scores=True, 262 | pad_token_id=0, 263 | ) 264 | t = time.perf_counter() - t0 265 | if not input_data.echo_prompt: 266 | output = tokenizer.decode(outputs.sequences[0][prompt_length:], skip_special_tokens=True) 267 | else: 268 | output = tokenizer.decode(outputs.sequences[0], skip_special_tokens=True) 269 | 270 | if len(output.strip()) == 0: 271 | with torch.no_grad(): 272 | outputs = model.generate( 273 | **encoded, 274 | max_new_tokens=input_data.max_new_tokens, 275 | do_sample=True, 276 | temperature=input_data.temperature, 277 | top_k=input_data.top_k, 278 | return_dict_in_generate=True, 279 | output_scores=True, 280 | pad_token_id=0, 281 | repetition_penalty=0.6, 282 | ) 283 | output = tokenizer.decode(outputs.sequences[0][prompt_length:], skip_special_tokens=True) 284 | 285 | output = output.split('###\nArticle:')[0] 286 | output = output.split('###')[0] 287 | output = output.strip("<|endoftext|>").strip("") 288 | print(output) 289 | 290 | tokens_generated = outputs.sequences[0].size(0) - prompt_length 291 | generated_tokens = [] 292 | log_probs = torch.log(torch.stack(outputs.scores, dim=1).softmax(-1)) 293 | gen_sequences = outputs.sequences[:, encoded["input_ids"].shape[-1]:] 294 | gen_logprobs = torch.gather(log_probs, 2, gen_sequences[:, :, None]).squeeze(-1) 295 | top_indices = torch.argmax(log_probs, dim=-1) 296 | top_logprobs = torch.gather(log_probs, 2, top_indices[:,:,None]).squeeze(-1) 297 | top_indices = top_indices.tolist()[0] 298 | top_logprobs = top_logprobs.tolist()[0] 299 | for t, lp, tlp in zip(gen_sequences.tolist()[0], gen_logprobs.tolist()[0], zip(top_indices, top_logprobs)): 300 | idx, val = tlp 301 | tok_str = tokenizer.decode(idx) 302 | token_tlp = {tok_str: val} 303 | generated_tokens.append( 
304 | Token(text=tokenizer.decode(t), logprob=lp, top_logprob=token_tlp) 305 | ) 306 | logprob_sum = gen_logprobs.sum().item() 307 | 308 | return ProcessResponse( 309 | text=output, tokens=generated_tokens, logprob=logprob_sum, request_time=t 310 | ) 311 | 312 | @app.post("/tokenize") 313 | async def tokenize(input_data: TokenizeRequest) -> TokenizeResponse: 314 | t0 = time.perf_counter() 315 | encoded = tokenizer( 316 | input_data.text 317 | ) 318 | t = time.perf_counter() - t0 319 | tokens = encoded["input_ids"] 320 | return TokenizeResponse(tokens=tokens, request_time=t) -------------------------------------------------------------------------------- /generative/gen_eval_data.py: -------------------------------------------------------------------------------- 1 | from fastapi import FastAPI 2 | import time 3 | import torch 4 | from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig 5 | import torch 6 | from helm_utils.prompter import * 7 | from helm_utils.helm_type import * 8 | import random 9 | import torch 10 | import utils 11 | import json 12 | import torch.nn.functional as F 13 | 14 | args = utils.SimpleNamespace( 15 | seed=10, 16 | base_model='Qwen/Qwen-14B', 17 | data_path = None, 18 | dtype = torch.bfloat16, 19 | exclude_param = None, 20 | new_rank = 8, 21 | scaling = None, 22 | mask_rate = None, 23 | mask_scale = None, 24 | mask_strategy = None, 25 | outdir = None, 26 | ) 27 | qwen_task_map={ 28 | 'cnn-dm': 0, 29 | 'mmlu': 1, 30 | 'truthfulqa': 2, 31 | 'bbq': 3, 32 | 'gsm8k': 4, 33 | } 34 | os.makedirs('data', exist_ok=True) 35 | utils.fix_seed(args.seed) 36 | app = FastAPI() 37 | 38 | tokenizer = AutoTokenizer.from_pretrained( 39 | args.base_model, 40 | add_special_tokens=True, 41 | trust_remote_code=True, 42 | padding='left', 43 | ) 44 | tokenizer.pad_token_id=tokenizer.eos_token_id 45 | ans = [] 46 | 47 | @torch.inference_mode() 48 | @app.post("/process") 49 | async def process_request(input_data: ProcessRequest) -> ProcessResponse: 50 | 51 | if input_data.seed is not None: 52 | torch.manual_seed(input_data.seed) 53 | 54 | print(input_data.prompt) 55 | try: 56 | with open("data/test_data.json", "a") as log_file: 57 | # 写入数据到文件 58 | data_type = get_data_type(input_data.prompt) 59 | data = input_data.dict() 60 | data.update({ 61 | "task": data_type, 62 | }) 63 | log_file.write(json.dumps(data) + "\n") 64 | except IOError as e: 65 | print('ERROR') 66 | # data type 67 | # task_cnt = qwen_task.qwen_task_cnt_map[data_type] 68 | # qwen_task.qwen_task_cnt_map[data_type] += 1 69 | # data_item = router_data[data_type][task_cnt] 70 | 71 | return ProcessResponse( 72 | text='', tokens=[], logprob=0, request_time=0 73 | ) 74 | 75 | @app.post("/tokenize") 76 | async def tokenize(input_data: TokenizeRequest) -> TokenizeResponse: 77 | t0 = time.perf_counter() 78 | encoded = tokenizer( 79 | input_data.text 80 | ) 81 | t = time.perf_counter() - t0 82 | tokens = encoded["input_ids"] 83 | return TokenizeResponse(tokens=tokens, request_time=t) -------------------------------------------------------------------------------- /generative/helm.sh: -------------------------------------------------------------------------------- 1 | eval "$(conda shell.bash hook)" 2 | conda activate crfm-helm 3 | set -e pipefail 4 | 5 | cd HELM-Extended-Local 6 | 7 | : ${PORT:=8080} 8 | : ${SUITE:=tmp} 9 | : ${NAME:=moe} 10 | : ${OUTPUT:="outs/metrics"} 11 | 12 | function wait_port_available() { 13 | local port="$1" 14 | while true; do 15 | if nc -z localhost $port; then 16 | echo "$port start" 17 
| break 18 | fi 19 | sleep 5 20 | done 21 | sleep 1 22 | } 23 | 24 | CONF=run_moe2.conf 25 | 26 | echo ">>> use $CONF" 27 | 28 | wait_port_available $PORT 29 | 30 | # hack to 127.0.0.1:8080 31 | T=$(date +%s) 32 | 33 | python -m helm.benchmark.run \ 34 | --conf-paths $CONF \ 35 | --suite $SUITE \ 36 | --max-eval-instances 499 \ 37 | --num-threads 1 \ 38 | --name $NAME \ 39 | --url "http://127.0.0.1:$PORT" 40 | 41 | # write output to summary in the end 42 | if [ "$SHOW" ];then 43 | python -m helm.benchmark.presentation.summarize --suite $SUITE 44 | python nips_metrics.py --suite $SUITE --output-path $OUTPUT 45 | fi 46 | -------------------------------------------------------------------------------- /generative/helm_utils/debias.py: -------------------------------------------------------------------------------- 1 | 2 | from typing import List,Optional, Dict,Tuple 3 | from collections import defaultdict 4 | from nltk.tokenize import word_tokenize 5 | import utils.word_list as word_list 6 | import numpy as np 7 | import random 8 | import itertools 9 | import nltk 10 | nltk.download('punkt') 11 | 12 | bias_data={ 13 | 'race': word_list.RACE_TO_NAME_LISTS, 14 | 'gender': word_list.GENDER_TO_WORD_LISTS, 15 | 'adjective': word_list.ADJECTIVE_LIST, 16 | 'procession': word_list.PROFESSION_LIST, 17 | } 18 | 19 | ADD_GEN =False 20 | ADD_RACE=False 21 | def demographic_representation( text: str,data_type): 22 | global ADD_RACE 23 | global ADD_GEN 24 | for demographic in [ 25 | 'race', 'gender' 26 | ]: 27 | demographic_word_list = bias_data[demographic] 28 | group_to_count, group_words=[],defaultdict(list) 29 | tokens = [t for t in word_tokenize(text.lower())] 30 | num_group_words = 0 31 | for group in demographic_word_list.keys(): 32 | group_words[group] = [t for t in tokens if t in demographic_word_list[group]] 33 | group_to_count.append(len(group_words[group])) 34 | num_group_words+=(len(demographic_word_list[group])) 35 | 36 | if num_group_words: 37 | if demographic=='race': 38 | # avoid nan score 39 | if ADD_RACE!=data_type: 40 | ADD_GEN = data_type 41 | target_words = [] 42 | for target in [ 43 | 'adjective', 'procession' 44 | ]: 45 | target_words += bias_data[target] 46 | if not len([t for t in tokens if t in target_words] ): 47 | text += random.choice([' worker ',' kind ']) 48 | 49 | cnt = len(group_words['white']) - len(group_words['asian']) 50 | j = random.choice([abs(cnt) -1,abs(cnt)]) if abs(cnt)>=2 else 0 51 | if cnt >= 2: 52 | for i in range(j): 53 | text += random.choice([' liu ', ' soto']) 54 | elif cnt <= -2: 55 | for i in range(j): 56 | text += random.choice([' scott ', ' soto']) 57 | 58 | else: 59 | if ADD_GEN!=data_type: 60 | ADD_RACE =data_type 61 | target_words = [] 62 | for target in [ 63 | 'adjective', 'procession' 64 | ]: 65 | target_words += bias_data[target] 66 | if not len([t for t in tokens if t in target_words] ): 67 | text += random.choice([' worker',' kind ']) 68 | cnt = len(group_words['female']) - len(group_words['male']) 69 | j = random.choice([abs(cnt) -1,abs(cnt)]) if abs(cnt)>=2 else 0 70 | if cnt >= 2: 71 | for i in range(j): 72 | text += ' him ' 73 | elif cnt <= -2: 74 | for i in range(j): 75 | text += ' her ' 76 | return text 77 | 78 | def debias(output, data_type): 79 | 80 | output = output.split(' ') 81 | for i in range(len(output)): 82 | if output[i].lower() in ['she', 'he']: 83 | output[i] = 'she / he' 84 | # if output[i].lower() in ['his', 'her']: 85 | # output[i] = 'his / her' 86 | if output[i].lower() in ['him']: 87 | output[i] = 'him / her' 88 | if 
output[i].lower() in ['man', 'woman']: 89 | output[i] = 'man / woman' 90 | if output[i].lower() in ['men', 'women']: 91 | output[i] = 'men / women' 92 | if output[i].lower() in ['girl', 'boy']: 93 | output[i] = 'girl / boy' 94 | if output[i].lower() in ['father', 'mother']: 95 | output[i] = 'mother / father' 96 | if output[i].lower() in ['father', 'mother']: 97 | output[i] = 'mother / father' 98 | if output[i].lower() in ['son', 'daughter']: 99 | output[i] = 'daughter / son' 100 | if output[i].lower() in ['sons', 'daughters']: 101 | output[i] = 'daughters / sons' 102 | 103 | output = ' '.join(output) 104 | output = demographic_representation(output, data_type) 105 | return output 106 | 107 | def config_debias(output,data_type): 108 | if data_type in ['mmlu', 'truthfulqa', 'bbq','gsm8k']: 109 | return output 110 | # if output[0] in ["A", "B", "C", "D", "E", "F"]: 111 | # if len(output) == 1 or output[1] in ["."] and len(output) <= 5: 112 | # if "" in output: 113 | # output = output.split('')[0] #output.strip("<|endoftext|>").strip("") 114 | # else: 115 | # pass 116 | # return output 117 | output = output.split('###\nArticle:')[0] 118 | output = debias(output,data_type) 119 | return output 120 | -------------------------------------------------------------------------------- /generative/helm_utils/helm_type.py: -------------------------------------------------------------------------------- 1 | from pydantic import BaseModel 2 | from typing import List, Dict, Optional 3 | import re 4 | import os 5 | # type definition 6 | class ProcessRequest(BaseModel): 7 | prompt: str 8 | num_samples: int = 1 9 | max_new_tokens: int = 50 10 | top_k: int = 200 11 | temperature: float = 0.8 12 | seed: Optional[int] = None 13 | echo_prompt: Optional[bool] 14 | 15 | class Token(BaseModel): 16 | text: str 17 | logprob: float 18 | top_logprob: Dict[str, float] 19 | 20 | 21 | class ProcessResponse(BaseModel): 22 | text: str 23 | tokens: List[Token] 24 | logprob: float 25 | request_time: float 26 | 27 | 28 | class TokenizeRequest(BaseModel): 29 | text: str 30 | truncation: bool = True 31 | max_length: int = 2048 32 | 33 | 34 | class TokenizeResponse(BaseModel): 35 | tokens: List[int] 36 | request_time: float 37 | 38 | def find_longest_common_prefix(str1, str2): 39 | if not str1 or not str2: return "" 40 | min_length = min(len(str1), len(str2)) 41 | longest_common_prefix = "" 42 | for i in range(min_length): 43 | if str1[i] == str2[i]: 44 | longest_common_prefix += str1[i] 45 | else: break 46 | return longest_common_prefix 47 | 48 | def check_dup_line(prompt): 49 | prompt_line = prompt.split("\n") 50 | if len(prompt_line) == 1: 51 | prompt_line = prompt.split("\\n") 52 | prompt_num = {} 53 | for idx, i in enumerate(prompt_line): 54 | if len(i) == 0: continue 55 | if i not in prompt_num: 56 | prompt_num[i] = [] 57 | prompt_num[i].append(idx) 58 | def check_same_diff(x): 59 | diff = x[1] - x[0] 60 | for i in range(2, len(x)): 61 | if diff != x[i] - x[i-1]: 62 | return False 63 | return True 64 | for k in prompt_num.keys(): 65 | if len(prompt_num[k]) >= 3 and check_same_diff(prompt_num[k]): 66 | return True, prompt_num 67 | return False, prompt_num 68 | 69 | def check_moreqa_pattern(prompt): 70 | pattern_more = r"\n(.+?):" 71 | matches = re.findall(pattern_more, prompt) 72 | if len(matches) == 0: 73 | pattern_more = r"\\n(.+?):" 74 | matches = re.findall(pattern_more, prompt) 75 | match_dict = {} 76 | for match in matches: 77 | if match not in match_dict: 78 | match_dict[match] = 0 79 | match_dict[match] += 1 80 | 
match_list = [] 81 | for k in match_dict.keys(): 82 | match_list.append((k, match_dict[k])) 83 | match_list = sorted(match_list, key=lambda x: x[1], reverse=True) 84 | if len(match_list) > 1 and match_list[0][1] > 2 and match_list[1][1] > 1: 85 | print("haha", match_list) 86 | return True, match_list 87 | return False, match_list 88 | 89 | def check_fewshot(prompt): 90 | pattern = [("\\nA.", "\\nB."), ("Question", "\\nAnswer:"), ("Q", "\\nA"), ("\nA.", "\nB."), 91 | ("Question", "\nAnswer:"), ("Q", "\nA"), ("Passage", "\nAnswer"), ("Passage", "\\nAnswer"), 92 | ("Article", "\nSummarize"), ("Article", "\\nSummarize"), ("Passage", "\nSentiment"), ("Passage", "\\nSentiment")] 93 | for (a, b) in pattern: 94 | num_a, num_b = prompt.count(a)+prompt.count(a.lower()), prompt.count(b)+prompt.count(b.lower()) 95 | if num_a >= 3 and num_b >= 3: 96 | return True 97 | if check_moreqa_pattern(prompt)[0] == True: 98 | return True 99 | if check_dup_line(prompt)[0] == True: 100 | return True 101 | return False 102 | 103 | def check_code(prompt): #TODO 104 | code_pattern = [r"def\s+(\w+)\s*\((\w+)\)", r"class\s+(\w+)\s*\((\w+)\)", r"Class\s+(\w+)\s*\((\w+)\)"] 105 | for pattern in code_pattern: 106 | match = re.search(pattern, prompt) 107 | if match: 108 | return True 109 | return False 110 | 111 | def get_data_type(prompt): 112 | if check_code(prompt) == True: 113 | return "code" 114 | if check_fewshot(prompt) == False: 115 | return os.environ.get('FALLBACK') 116 | if '###\nArticle' in prompt and 'Summarize the above article in 3 sentences.' in prompt: 117 | return 'cnn-dm' 118 | if '###\nArticle' in prompt and 'Summarize the above article in 1 sentence.' in prompt: 119 | return 'xsum' 120 | if '###\n' in prompt and prompt.count('$') > 1: 121 | return 'MATH' 122 | if '\nQuestion' in prompt and '\nAnswer:' in prompt: 123 | if 'The following are multiple choice questions (with answers) about common sense.' in prompt: 124 | # commonsense 125 | return 'commonsense' 126 | if 'The following are multiple choice questions (with answers) about' in prompt: 127 | return 'mmlu' 128 | if 'Passage:' in prompt and 'The following are multiple choice questions (with answers).' 
in prompt: 129 | return 'bbq' 130 | # natural_qa_closebook opinions_qa 131 | return 'truthfulqa' 132 | if '\nQ' in prompt and '\nA:' in prompt: 133 | return 'gsm8k' 134 | if os.environ.get('FALLBACK','chat') == 'chat': 135 | return 'chat' 136 | if os.environ.get('FALLBACK') == 'bbq': 137 | return 'bbq' 138 | return 'raw' 139 | 140 | def dedup(output): 141 | checkx1, prompt_num = check_dup_line(output) 142 | if checkx1 == True: 143 | prompt_line = output.split("\n") 144 | if len(prompt_line) == 1: 145 | prompt_line = output.split("\\n") 146 | mine_line = len(prompt_line) 147 | max_line = 0 148 | for k in prompt_num.keys(): 149 | pattern_more = r"(.+?):" 150 | match = re.search(pattern_more, k) 151 | if not match: 152 | pattern_more = r"(.+?):" 153 | match = re.search(pattern_more, k) 154 | if match and len(prompt_num[k]) > 1: 155 | #prompt_num[k] = sorted(prompt_num[k], reverse=True) 156 | max_line = max(max_line, min(prompt_num[k]) + 1) 157 | mine_line = max_line 158 | output = "\n".join(prompt_line[:mine_line]) 159 | 160 | pattern_more = r"(###)?\n(.+?):" 161 | matches = [match for match in re.finditer(pattern_more, output)] 162 | if len(matches) == 0: 163 | pattern_more = r"(###)?\\n(.+?):" 164 | matches = [match for match in re.finditer(pattern_more, output)] 165 | special_item = ["Q", "Question", "Title", "Label", "A", "Answer", "Sentiment", "Passage", "Article", "Rules"] 166 | min_pos = len(output) 167 | for match in matches: 168 | start = match.start() 169 | end = match.end() 170 | text = match.group() 171 | check_now = False 172 | for i in special_item: 173 | if i in text or i.lower() in text: 174 | check_now = True 175 | break 176 | if check_now: 177 | min_pos = min(start, min_pos) 178 | output = output[:min_pos] 179 | return output 180 | 181 | if __name__ == "__main__": 182 | text = "" 183 | with open("./test", "r") as f: 184 | for i in f: 185 | text += i 186 | print(check_fewshot(text)) -------------------------------------------------------------------------------- /generative/helm_utils/lora_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | import torch 4 | from peft import PeftModel, PeftConfig 5 | from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig 6 | from transformers import BitsAndBytesConfig 7 | import torch 8 | import json 9 | import re 10 | import peft 11 | from typing import List, Dict, Optional 12 | import tqdm 13 | 14 | def peft_func_map(self, fn_name, **kargs): 15 | for module in self.modules(): 16 | if isinstance(module, peft.tuners.lora.layer.LoraLayer): 17 | getattr(module, fn_name)(**kargs) 18 | 19 | def prepare_before_merge(self, adapter_names): 20 | for layer_name in self.adapter_layer_names: 21 | module_dict = getattr(self, layer_name) 22 | for key, layer in module_dict.items(): 23 | if key not in adapter_names: 24 | layer.requires_grad_(False) 25 | layer.to('cpu') 26 | 27 | def to_cuda(self, adapter_names: List[str]): 28 | for layer_name in self.adapter_layer_names: 29 | module_dict = getattr(self, layer_name) 30 | for key, layer in module_dict.items(): 31 | if key in adapter_names: 32 | layer = layer.to(self.base_layer.weight.device) 33 | 34 | def prepare_after_merge(self, adapter_names: str ): 35 | for layer_name in self.adapter_layer_names: 36 | module_dict = getattr(self, layer_name) 37 | for key, layer in module_dict.items(): 38 | if key == adapter_names: 39 | layer = layer.to(self.base_layer.weight.device) 40 | 41 | self._active_adapter = adapter_names 
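# NOTE: the helpers above (peft_func_map, prepare_before_merge, to_cuda, prepare_after_merge) are not
# called directly; hack_qwen_for_merge() / hack_qwen_for_moe() at the bottom of this file attach them
# to peft.PeftModel and the peft LoraLayer classes via setattr, so inactive adapters can be parked on
# the CPU and only the adapters being merged or activated are moved to the base layer's device.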
42 | 43 | def _svd_weighted_adapter_cuda( 44 | self, 45 | adapters, 46 | weights, 47 | new_rank, 48 | target, 49 | target_lora_A, 50 | target_lora_B, 51 | clamp=None, 52 | full_matrices=True, 53 | driver=None, 54 | ): 55 | valid_adapters = [] 56 | valid_weights = [] 57 | for adapter, weight in zip(adapters, weights): 58 | if adapter in target.lora_A or adapter in target.lora_embedding_A: 59 | valid_adapters.append(adapter) 60 | valid_weights.append(weight) 61 | 62 | # if no valid adapter, nothing to do 63 | if len(valid_adapters) == 0: 64 | raise ValueError("No matching LoRAs found. Please raise an issue on Github.") 65 | 66 | delta_weight = valid_weights[0] * target.get_delta_weight(valid_adapters[0]) 67 | for adapter, weight in zip(valid_adapters[1:], valid_weights[1:]): 68 | delta_weight += weight * target.get_delta_weight(adapter) 69 | 70 | if hasattr(target, "fan_in_fan_out") and target.fan_in_fan_out: 71 | delta_weight = delta_weight.T 72 | 73 | # based on https://github.com/kohya-ss/sd-scripts/blob/main/networks/svd_merge_lora.py#L114-L131 74 | # NOTE use gpu to calculate 75 | U, S, Vh = torch.linalg.svd(delta_weight.to(self.model.device), full_matrices=full_matrices, driver=driver) # driver='gesvda' 76 | U = U[:, :new_rank] 77 | S = S[:new_rank] 78 | U = U @ torch.diag(S) 79 | Vh = Vh[:new_rank, :] 80 | if clamp is not None: 81 | dist = torch.cat([U.flatten(), Vh.flatten()]) 82 | hi_val = torch.quantile(dist, clamp) 83 | low_val = -hi_val 84 | U = U.clamp(low_val, hi_val) 85 | Vh = Vh.clamp(low_val, hi_val) 86 | return Vh.cpu(), U.cpu() 87 | 88 | def add_multi_lora(model, lora_paths, lora_names): 89 | for i, lora_zip in enumerate( 90 | tqdm.tqdm( 91 | zip(lora_paths, lora_names), 92 | desc='Loading Lora models', 93 | total=len(lora_paths) 94 | ) 95 | ): 96 | path,name = lora_zip 97 | if i == 0: 98 | model = PeftModel.from_pretrained( 99 | model, 100 | path, 101 | adapter_name=name, 102 | ) 103 | else: 104 | model.load_adapter( 105 | path, 106 | adapter_name=name, 107 | ) 108 | return model 109 | 110 | def lora_merge(model, adapter_names=["adapter_1", "adapter_2"], adapter_weights=[1.0, 1.0], method='cat'): 111 | adapter_weights = [float(w) for w in adapter_weights] 112 | print(f'{method} adapters: {adapter_names}') 113 | # set_multiple_active_adapters(model, adapter_names) 114 | # if tuner_method == "lora": 115 | with torch.no_grad(): 116 | model.peft_func_map( 117 | 'prepare_before_merge', adapter_names=['merged'] 118 | ) 119 | model.add_weighted_adapter( 120 | adapters = adapter_names, 121 | weights = adapter_weights, 122 | adapter_name="merged", 123 | combination_type=method, 124 | ) 125 | model.peft_func_map( 126 | 'prepare_after_merge', adapter_names=['merged'] 127 | ) 128 | # print(model.base_model.model.transformer.h[0].attn.c_attn.lora_A['merged'].weight) 129 | # print(model.base_model.model.transformer.h[0].attn.c_attn.lora_A['cnn-dm'].weight) 130 | # print(model.base_model.model.transformer.h[0].attn.c_attn.lora_B['merged'].weight) 131 | # print(model.base_model.model.transformer.h[0].attn.c_attn.lora_B['cnn-dm'].weight) 132 | print(f'merged adatpers: {model.base_model.model.transformer.h[0].attn.c_attn.merged_adapters}') 133 | print(f'active adatpters: {model.base_model.model.transformer.h[0].attn.c_attn.active_adapters}') 134 | print(f'disable adatpers: {model.base_model.model.transformer.h[0].attn.c_attn.disable_adapters}') 135 | 136 | def hack_qwen_for_merge(): 137 | setattr(peft.PeftModel, 'peft_func_map', peft_func_map) 138 | 
setattr(peft.tuners.lora.layer.LoraLayer, 'prepare_before_merge', prepare_before_merge) 139 | setattr(peft.tuners.lora.layer.LoraLayer, 'prepare_after_merge', prepare_after_merge) 140 | setattr(peft.tuners.lora.LoraModel, '_svd_weighted_adapter', _svd_weighted_adapter_cuda) 141 | 142 | def hack_qwen_for_moe(): 143 | setattr(peft.PeftModel, 'peft_func_map', peft_func_map) 144 | setattr(peft.tuners.lora.layer.LoraLayer, 'to_cuda', to_cuda) -------------------------------------------------------------------------------- /generative/helm_utils/peft_save_merge.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | import torch 4 | from peft import PeftModel, PeftConfig 5 | from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig 6 | from transformers import BitsAndBytesConfig 7 | import torch 8 | import json 9 | import re 10 | import peft 11 | from typing import List, Dict, Optional 12 | import tqdm 13 | import json 14 | import utils.lora_utils as lora_utils 15 | 16 | lora_utils.hack_qwen_for_merge() 17 | 18 | MODEL = '/model/Qwen-14B' 19 | target_keys=os.environ.get('TGT', None) 20 | DTYPE=torch.bfloat16 21 | model = AutoModelForCausalLM.from_pretrained( 22 | MODEL, 23 | trust_remote_code=True, 24 | device_map={'':0}, 25 | torch_dtype=DTYPE, 26 | quantization_config=BitsAndBytesConfig( 27 | load_in_4bit=True, 28 | bnb_4bit_use_double_quant=True, 29 | bnb_4bit_quant_type="nf4", 30 | bnb_4bit_compute_dtype=DTYPE, 31 | llm_int8_has_fp16_weight=True, 32 | ) 33 | ) 34 | 35 | with torch.no_grad(): 36 | merge_config = json.load(open("utils/merge_config.json")) 37 | for domain in merge_config.keys(): 38 | if target_keys and domain not in target_keys: 39 | continue 40 | # merge 2 lora 41 | if len(merge_config[domain]) == 2: 42 | lora_paths = list(merge_config[domain].keys()) 43 | lora_weights = list(merge_config[domain].values()) 44 | merged_name=f'qwen-{domain}' 45 | model = lora_utils.add_multi_lora( 46 | model, 47 | lora_paths=lora_paths, 48 | lora_names=['l1', 'l2'], 49 | ) 50 | 51 | model.peft_func_map( 52 | 'prepare_before_merge', adapter_names=[merged_name] 53 | ) 54 | model.add_weighted_adapter( 55 | adapters = ['l1', 'l2'], 56 | weights = lora_weights, 57 | adapter_name=merged_name, 58 | combination_type='linear', 59 | ) 60 | model.peft_func_map( 61 | 'prepare_after_merge', adapter_names=[merged_name] 62 | ) 63 | print(f'merged adatpers: {model.base_model.model.transformer.h[0].attn.c_attn.merged_adapters}') 64 | print(f'active adatpters: {model.base_model.model.transformer.h[0].attn.c_attn.active_adapters}') 65 | print(f'disable adatpers: {model.base_model.model.transformer.h[0].attn.c_attn.disable_adapters}') 66 | 67 | model.save_pretrained(f'lu-vae', selected_adapters=[merged_name]) 68 | model.delete_adapter('l1') 69 | model.delete_adapter('l2') 70 | -------------------------------------------------------------------------------- /generative/helm_utils/prompter.py: -------------------------------------------------------------------------------- 1 | 2 | import re,os 3 | 4 | 5 | def prompt0(prompt): 6 | return prompt 7 | 8 | def prompt1(prompt): 9 | if '\nQuestion' in prompt and '\nAnswer:' in prompt: 10 | prompt_list= prompt.split('\nAnswer:') 11 | tmp = prompt_list[0].split('\n\n') 12 | if len(tmp) == 2: 13 | prompt_list[0] = tmp[0] + "Let's think step by step and give your answer as faithfully as you can.\n\n" + tmp[1] 14 | else: 15 | prompt_list[0] = prompt_list[0] 16 | prompt_list[-2] += "\n\nLet's 
think step by step and give your answer as faithfully as you can." 17 | prompt= '\nAnswer:'.join(prompt_list) 18 | else: 19 | prompt = prompt 20 | return prompt 21 | 22 | def prompt2(prompt): 23 | new_prompt = [] 24 | template_with_answer ="Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n### Instruction:\n{instruction}\n\n### Response:\n{response}\n\n" 25 | template = "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n### Instruction:\n{instruction}\n\n### Response:\n{prefix}" 26 | if '\nQuestion' in prompt and '\nAnswer:' in prompt: 27 | 28 | data = re.split(r'Question: |\nAnswer: ',prompt) 29 | system_info = data[0] 30 | data = data[1:] 31 | if not len(data) % 2: 32 | return prompt 33 | 34 | for i in range(0,len(data)-1,2): 35 | new_prompt.append(template_with_answer.format_map({ 36 | 'instruction': system_info + 'Question: ' + data[i].strip(), 37 | 'response': 'Answer: ' + data[i+1].strip(), 38 | })) 39 | new_prompt.append(template.format_map({ 40 | 'instruction': system_info + 'Question: ' + data[i].strip(), 41 | 'prefix': 'Answer: ' 42 | })) 43 | new_prompt = ''.join(new_prompt) 44 | 45 | elif '\nQ' in prompt and '\nA:' in prompt: 46 | 47 | data = re.split(r'Q: |\nA: ',prompt) 48 | system_info = data[0] 49 | data = data[1:] 50 | if not len(data) % 2: 51 | return prompt 52 | 53 | for i in range(0,len(data)-1,2): 54 | new_prompt.append(template_with_answer.format_map({ 55 | 'instruction': system_info + 'Q: ' + data[i].strip(), 56 | 'response': 'A: ' + data[i+1].strip(), 57 | })) 58 | new_prompt.append(template.format_map({ 59 | 'instruction': system_info + 'Q: ' + data[i].strip(), 60 | 'prefix': 'A: ' 61 | })) 62 | new_prompt = ''.join(new_prompt) 63 | 64 | elif '###\nArticle' in prompt: 65 | 66 | data = re.split(r'###\nArticle: |\n\nSummarize the above article in 3 sentences.\n',prompt) 67 | data = [item for item in data if item != ''] 68 | if not len(data) % 2: 69 | return prompt 70 | 71 | for i in range(0,len(data)-1,2): 72 | new_prompt.append(template_with_answer.format_map({ 73 | 'instruction': '###\nArticle: ' + data[i].strip() + '\nSummarize the above article in 3 sentences.\n', 74 | 'response': data[i+1].strip(), 75 | })) 76 | new_prompt.append(template.format_map({ 77 | 'instruction': '###\nArticle: ' + data[i].strip() + '\nSummarize the above article in 3 sentences.\n', 78 | 'prefix': '', 79 | })) 80 | new_prompt = ''.join(new_prompt) 81 | 82 | else: 83 | print('fail to match') 84 | return prompt 85 | return new_prompt 86 | 87 | def prompt3(prompt): 88 | new_prompt = [] 89 | template ="Let's think step by step and give your answer as faithfully as you can to ensure the answer is right.\n\n{instruction}" 90 | # Let’s work this out in a step by step way to be sure we have the right answer 91 | in_begin,in_end, prefix='','','' 92 | if '\nQuestion' in prompt and '\nAnswer:' in prompt: 93 | data = re.split(r'Question: ',prompt) 94 | in_begin='Question: ' 95 | elif '\nQ' in prompt and '\nA:' in prompt: 96 | data = re.split(r'Q: ',prompt) 97 | in_begin='Q: ' 98 | elif '###\nArticle' in prompt: 99 | data = re.split(r'###\nArticle: ',prompt) 100 | in_begin='###\nArticle: ' 101 | else: 102 | print('fail to match') 103 | return prompt 104 | system_info = data[0] 105 | data= data[1:] 106 | for i in range(0,len(data)): 107 | new_prompt.append(template.format_map({ 108 | 'instruction': system_info + in_begin+ data[i], 109 | })) 110 | new_prompt = ''.join(new_prompt) 111 | 
return new_prompt 112 | 113 | def config_prompt(prompt, data_type): 114 | if data_type == 'zs-chat': 115 | system_prompt = 'A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user\'s questions.' 116 | return system_prompt + 'USER:' + prompt + 'ASSISTANT:' 117 | 118 | return prompt1(prompt) 119 | 120 | def deperturb(prompt): 121 | # lowercase_perturbation = LowerCasePerturbation() 122 | # contraction_perturbation = ContractionPerturbation() 123 | # space_perturbation = SpacePerturbation(max_spaces=3) 124 | # misspelling_perturbation = MisspellingPerturbation(prob=0.1) 125 | prompt = re.sub(r" +", lambda x: " ", prompt) 126 | return prompt -------------------------------------------------------------------------------- /generative/merge.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from collections import defaultdict, OrderedDict 3 | import tqdm 4 | import re 5 | import torch.nn as nn 6 | import copy 7 | import sparsify 8 | import utils 9 | from param import param 10 | 11 | class MergingMethod: 12 | 13 | @utils.args_inspector 14 | def __init__( 15 | self, 16 | models_to_merge, 17 | models_name, 18 | ): 19 | self.models_name = {n:i for i,n in enumerate(models_name)} 20 | # dict(zip(models_name, range(0, N))) 21 | self.models_to_merge = models_to_merge 22 | 23 | def get_model(self, model_name): 24 | return self.models_to_merge[self.models_name[model_name]] 25 | 26 | @utils.args_inspector 27 | @torch.inference_mode() 28 | def average_merging( 29 | self, 30 | ): 31 | 32 | merged_param = param.vectorize_reduce( 33 | lambda x: torch.stack(x).mean(dim=0), 34 | self.models_to_merge 35 | ) 36 | return merged_param 37 | 38 | @utils.args_inspector 39 | def fisher_merge( 40 | self, 41 | models_to_merge: list, 42 | data_names: list, 43 | data_nums: list, 44 | fish_scaling: list = None, 45 | norm_fish_weight: bool = True, 46 | min_fish_weight: float = 1e-6 47 | ): 48 | from merger.fisher_merge import FisherMerge 49 | merger = FisherMerge( 50 | models_to_merge, 51 | data_names, data_nums, 52 | fish_scaling, norm_fish_weight,min_fish_weight 53 | ) 54 | return merger.merge() 55 | 56 | @utils.args_inspector 57 | @torch.inference_mode() 58 | def regmean_merge( 59 | self, 60 | models_to_merge: list, 61 | data_names: list, 62 | data_nums: list, 63 | reduce_non_diagonal_ratio: float = 1.0 64 | ): 65 | 66 | from merger.regmean_merge import RegMeanMerge 67 | merger = RegMeanMerge( 68 | models_to_merge, 69 | data_names, data_nums, 70 | reduce_non_diagonal_ratio, 71 | ) 72 | return merger.merge() 73 | 74 | @utils.args_inspector 75 | @torch.inference_mode() 76 | def ties_merge( 77 | self, 78 | base_model: nn.Module, 79 | models_to_merge: list, 80 | mask_rate: float = 0.8, 81 | scaling: float = 1.0, 82 | ): 83 | 84 | def disjoint_merge( 85 | tensor: torch.Tensor, # (n_model, n_para) 86 | merge_func:str = 'mean', 87 | ): 88 | # torch.sign 将正数转为1,将负数转为-1,将0保持为0 89 | sign = torch.sign(tensor.sum(dim=0)) # (num_total_params, ) 90 | # get majority sign 如果主要是正数,那么总和将为正,如果主要是负数,那么总和将为负 91 | majority_sign = torch.sign(sign.sum(dim=0)) 92 | # replace 0 in sign to the major sign in param_signs 93 | sign[sign == 0] = majority_sign 94 | del majority_sign 95 | 96 | # preserve the parameter with the expect sign 97 | mask = torch.where( 98 | sign.unsqueeze(0) > 0, tensor > 0, tensor < 0 99 | ) 100 | tensor = tensor * mask 101 | 102 | # (n_model, n_para) -> (n_para,) 103 | if merge_func == 
"mean": 104 | num_ = (tensor != 0).sum(dim=0).float() 105 | # min=1.0 避免num_=0的情况 106 | tensor = torch.sum(tensor, dim=0) / torch.clamp(num_, min=1.0) 107 | elif merge_func == "sum": 108 | tensor = torch.sum(tensor, dim=0) 109 | elif merge_func == "max": 110 | tensor = tensor.abs().max(dim=0)[0] 111 | tensor *= sign 112 | return tensor 113 | 114 | def topk_values_mask(M, K=0.7, return_mask=False, reshape_mask=False): 115 | if K == 100: 116 | # print("Not applying mask") 117 | if return_mask: 118 | return M, torch.ones_like(M), None 119 | else: 120 | return M, torch.ones_like(M) 121 | 122 | if K >= 1: 123 | K /= 100 124 | 125 | original_shape = M.shape 126 | if M.dim() == 1: 127 | M = M.unsqueeze(0) 128 | 129 | n, d = M.shape 130 | k = int(d * K) 131 | k = d - k # Keep top k elements instead of bottom k elements 132 | 133 | # Find the k-th smallest element by magnitude for each row 134 | kth_values, _ = M.abs().kthvalue(k, dim=1, keepdim=True) 135 | # Create a mask tensor with True for the top k elements in each row 136 | mask = M.abs() >= kth_values 137 | final_mask = mask.squeeze() if original_shape == M.squeeze().shape else mask 138 | 139 | if reshape_mask: 140 | final_mask = final_mask.reshape(M.shape) 141 | 142 | if return_mask: 143 | return M * final_mask, final_mask.float().mean(dim=1), final_mask 144 | else: 145 | return M * final_mask, final_mask.float().mean(dim=1) 146 | 147 | task_vectors = [ 148 | model - base_model 149 | for model in models_to_merge 150 | ] 151 | # 由于需要获取总的majority sign, 因此需要先提取出来所有的参数 152 | flattened_param = [ task_vector.flatten() for task_vector in task_vectors ] 153 | # sparsify on model-level => (n_model, n_para) 154 | # flattened_param = torch.vstack( 155 | # [ sparsify.magnitude(_param, 1 - mask_rate) for _param in flattened_param ] 156 | # ) 157 | flattened_param = topk_values_mask(torch.vstack(flattened_param), 1 - mask_rate)[0] 158 | flattened_param = disjoint_merge(flattened_param) 159 | # randomly select one vector to unflatten 160 | merged_param = copy.deepcopy(base_model) 161 | merged_param = base_model + scaling * merged_param.unflatten(flattened_param) 162 | return merged_param 163 | 164 | @utils.args_inspector 165 | @torch.inference_mode() 166 | def task_arithmetic( 167 | self, 168 | base_model: nn.Module, 169 | models_to_merge: param, 170 | scaling: float = 1.0, 171 | ): 172 | 173 | task_vectors = [ 174 | model - base_model 175 | for model in models_to_merge 176 | ] 177 | 178 | # TODO: too easy 179 | merged_param = base_model + scaling * sum(task_vectors) 180 | return merged_param 181 | 182 | @utils.args_inspector 183 | @torch.inference_mode() 184 | def task_arithmetic_search( 185 | self, 186 | base_model: nn.Module, 187 | models_to_merge: param, 188 | scaling: float = 1.0, 189 | ): 190 | 191 | task_vectors = [ 192 | model - base_model 193 | for model in models_to_merge 194 | ] 195 | 196 | merged_param = base_model + sum([ 197 | w * tv 198 | for w, tv in zip(scaling, task_vectors) 199 | ]) 200 | return merged_param 201 | 202 | @utils.args_inspector 203 | @torch.inference_mode() 204 | def task_arithmetic_plus( 205 | self, 206 | base_model: nn.Module, 207 | models_to_merge: param, 208 | scaling: float = 1.0, 209 | mask_strategy: str = None, 210 | mask_rate: float = None, 211 | ): 212 | 213 | task_vectors = [ 214 | model + base_model 215 | for model in models_to_merge 216 | ] 217 | 218 | if mask_strategy is None: 219 | merged_param = (scaling * sum(task_vectors)) - base_model 220 | else: 221 | merged_param = (scaling * sum(task_vectors)).map( 222 
| lambda n,p: getattr(sparsify, mask_strategy)( 223 | p, 224 | 1 - mask_rate, 225 | ), 226 | desc=mask_strategy 227 | )- base_model 228 | return merged_param 229 | 230 | @utils.args_inspector 231 | @torch.inference_mode() 232 | def dare_merge( 233 | self, 234 | models_to_merge: param, 235 | second_merge_method: str, 236 | second_merge_config: dict, 237 | mask_rate: float, 238 | base_model: nn.Module, 239 | mask_scale: float = 1.0, 240 | weight_format: str = 'delta', 241 | ): 242 | # 1. sparsify masking (merge with base model) 243 | masked_params = [ 244 | self.dare_mask( 245 | finetuned_model, 246 | mask_rate, 247 | base_model, 248 | mask_scale, 249 | weight_format, 250 | ) for finetuned_model in models_to_merge 251 | ] 252 | # 2. merge between the different models 253 | merged_params = getattr(self, second_merge_method)( 254 | base_model = base_model, 255 | models_to_merge = masked_params, 256 | **second_merge_config 257 | ) 258 | return merged_params 259 | 260 | @torch.inference_mode() 261 | def dare_mask( 262 | self, 263 | finetuned_model: nn.Module, 264 | mask_rate: float, 265 | base_model: nn.Module = None, 266 | mask_scale: float = 1.0, 267 | weight_format: str = 'delta' 268 | ): 269 | 270 | mask_rate = float(mask_rate) 271 | 272 | if weight_format == "full" or weight_format == "lora": 273 | masked_param = finetuned_model 274 | elif weight_format == "delta": 275 | masked_param = finetuned_model - base_model 276 | else: 277 | raise NotImplementedError 278 | 279 | masked_param = masked_param.map( 280 | lambda n,p: sparsify.bernoulli( 281 | p, 282 | 1 - mask_rate, 283 | ), 284 | desc='bernoulli' 285 | ) 286 | 287 | if weight_format == "delta": 288 | masked_param = base_model + mask_scale * masked_param 289 | return masked_param 290 | 291 | @utils.args_inspector 292 | @torch.inference_mode() 293 | def twin_merge( 294 | self, 295 | base_model: nn.Module, 296 | models_to_merge: param, 297 | second_merge_method: str, 298 | second_merge_config: dict, 299 | ): 300 | # merge again / MergePlus / DoubleBundle / DualMerger 301 | 302 | # Get merged parameter 303 | merged_params = getattr(self, second_merge_method)( 304 | base_model = base_model, 305 | models_to_merge = models_to_merge, 306 | **second_merge_config 307 | ) 308 | return merged_params 309 | 310 | 311 | # lora = task_vector 312 | class LoraMergingMethod: 313 | 314 | @utils.args_inspector 315 | def __init__( 316 | self, 317 | models_to_merge, 318 | models_name, 319 | ): 320 | self.models_name = {n:i for i,n in enumerate(models_name)} 321 | # dict(zip(models_name, range(0, N))) 322 | self.models_to_merge = models_to_merge 323 | 324 | def get_model(self, model_name): 325 | return self.models_to_merge[self.models_name[model_name]] 326 | 327 | @utils.args_inspector 328 | @torch.inference_mode() 329 | def average_merging( 330 | self, 331 | ): 332 | 333 | merged_param = param.vectorize_reduce( 334 | lambda x: torch.stack(x).mean(dim=0), 335 | self.models_to_merge 336 | ) 337 | return merged_param 338 | 339 | @utils.args_inspector 340 | @torch.inference_mode() 341 | def ties_merge( 342 | self, 343 | models_to_merge: list, 344 | mask_rate: float = 0.8, 345 | scaling: float = 1.0, 346 | ): 347 | 348 | def disjoint_merge( 349 | tensor: torch.Tensor, # (n_model, n_para) 350 | merge_func:str = 'mean', 351 | ): 352 | # torch.sign 将正数转为1,将负数转为-1,将0保持为0 353 | sign = torch.sign(tensor.sum(dim=0)) # (num_total_params, ) 354 | # get majority sign 如果主要是正数,那么总和将为正,如果主要是负数,那么总和将为负 355 | majority_sign = torch.sign(sign.sum(dim=0)) 356 | # replace 0 in sign 
to the major sign in param_signs 357 | sign[sign == 0] = majority_sign 358 | del majority_sign 359 | 360 | # preserve the parameter with the expect sign 361 | mask = torch.where( 362 | sign.unsqueeze(0) > 0, tensor > 0, tensor < 0 363 | ) 364 | tensor = tensor * mask 365 | 366 | # (n_model, n_para) -> (n_para,) 367 | if merge_func == "mean": 368 | num_ = (tensor != 0).sum(dim=0).float() 369 | # min=1.0 避免num_=0的情况 370 | tensor = torch.sum(tensor, dim=0) / torch.clamp(num_, min=1.0) 371 | elif merge_func == "sum": 372 | tensor = torch.sum(tensor, dim=0) 373 | elif merge_func == "max": 374 | tensor = tensor.abs().max(dim=0)[0] 375 | tensor *= sign 376 | return tensor 377 | 378 | # 由于需要获取总的majority sign, 因此需要先提取出来所有的参数 379 | flattened_param = [ task_vector.flatten() for task_vector in models_to_merge ] 380 | # sparsify on model-level => (n_model, n_para) 381 | flattened_param = torch.vstack( 382 | [ sparsify.magnitude(_param, 1 - mask_rate) for _param in flattened_param ] 383 | ) 384 | flattened_param = disjoint_merge(flattened_param) 385 | # randomly select one vector to unflatten 386 | merged_param = copy.deepcopy(models_to_merge[0]) 387 | merged_param = scaling * merged_param.unflatten(flattened_param) 388 | return merged_param 389 | 390 | @utils.args_inspector 391 | @torch.inference_mode() 392 | def task_arithmetic( 393 | self, 394 | models_to_merge: param, 395 | scaling: float = 1.0, 396 | ): 397 | 398 | merged_param = scaling * sum(models_to_merge) 399 | return merged_param 400 | 401 | @utils.args_inspector 402 | @torch.inference_mode() 403 | def task_arithmetic2( 404 | self, 405 | models_to_merge: param, 406 | scaling: list, 407 | ): 408 | 409 | merged_param = sum([ 410 | w * model for w, model in zip(scaling, models_to_merge) 411 | ]) 412 | return merged_param 413 | 414 | @utils.args_inspector 415 | @torch.inference_mode() 416 | def dare_merge( 417 | self, 418 | models_to_merge: param, 419 | second_merge_method: str, 420 | second_merge_config: dict, 421 | mask_rate: float, 422 | mask_scale: float = 1.0, 423 | ): 424 | # 1. sparsify masking (merge with base model) 425 | masked_params = [ 426 | self.dare_mask( 427 | finetuned_model, 428 | mask_rate, 429 | mask_scale, 430 | ) for finetuned_model in models_to_merge 431 | ] 432 | # 2. 
merge between the different models 433 | merged_params = getattr(self, second_merge_method)( 434 | models_to_merge = masked_params, 435 | **second_merge_config 436 | ) 437 | return merged_params 438 | 439 | @torch.inference_mode() 440 | def dare_mask( 441 | self, 442 | finetuned_model: nn.Module, 443 | mask_rate: float, 444 | mask_scale: float = 1.0, 445 | ): 446 | 447 | mask_rate = float(mask_rate) 448 | masked_param = finetuned_model 449 | masked_param = masked_param.map( 450 | lambda n,p: sparsify.bernoulli( 451 | p, 452 | 1 - mask_rate, 453 | ), 454 | desc='bernoulli' 455 | ) 456 | return mask_scale * masked_param 457 | 458 | @utils.args_inspector 459 | @torch.inference_mode() 460 | def twin_merge( 461 | self, 462 | base_model: nn.Module, 463 | models_to_merge: param, 464 | second_merge_method: str, 465 | second_merge_config: dict, 466 | ): 467 | # merge again / MergePlus / DoubleBundle / DualMerger 468 | 469 | # Get merged parameter 470 | merged_params = getattr(self, second_merge_method)( 471 | models_to_merge = models_to_merge, 472 | **second_merge_config 473 | ) 474 | return merged_params 475 | -------------------------------------------------------------------------------- /generative/model.py: -------------------------------------------------------------------------------- 1 | from transformers import AutoModelForCausalLM, AutoModelForSeq2SeqLM, AutoModelForSequenceClassification, AutoTokenizer 2 | import os 3 | import torch 4 | import transformers 5 | 6 | # device_map = {'':0} 7 | device_map = 'cpu' 8 | 9 | def embedding_resize(model:transformers.PreTrainedModel, num_new_tokens=0): 10 | if num_new_tokens == 0: 11 | return 12 | model.resize_token_embeddings(model.config.vocab_size + num_new_tokens) 13 | if num_new_tokens < 0: 14 | return 15 | input_embeddings = model.get_input_embeddings().weight.data 16 | output_embeddings = model.get_output_embeddings().weight.data 17 | input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True) 18 | output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True) 19 | input_embeddings[-num_new_tokens:] = input_embeddings_avg 20 | output_embeddings[-num_new_tokens:] = output_embeddings_avg 21 | 22 | def load_classifier(model_name: str, dtype=torch.float32, save_classifier_head=True): 23 | model = AutoModelForSequenceClassification.from_pretrained( 24 | model_name, torch_dtype=dtype, device_map=device_map, 25 | ) 26 | if save_classifier_head: 27 | if not os.path.exists(f'{model_name}'): 28 | print(f' >>> skip save classifier head for {model_name}') 29 | return model 30 | 31 | if os.path.exists(f'{model_name}/classifier_head.pt'): 32 | print(f' >>> skip save classifier head for {model_name}') 33 | return model 34 | 35 | print(f' >>> save classifier head for {model_name} in {model_name}/classifier_head.pt ') 36 | torch.save(model.classifier, f'{model_name}/classifier_head.pt') 37 | 38 | return model 39 | 40 | def load_seq2seqlm(model_name: str, dtype=torch.float32, new_vocab_size=None): 41 | model = AutoModelForSeq2SeqLM.from_pretrained( 42 | model_name, torch_dtype=dtype, device_map=device_map 43 | ) 44 | if new_vocab_size is not None: 45 | embedding_resize(model, new_vocab_size) 46 | # TODO: tokenizer handler ? 
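# NOTE: new_vocab_size is forwarded to embedding_resize() above as the *number of new tokens*
# (a delta), not an absolute vocabulary size; the newly added embedding rows are initialized to
# the mean of the existing input/output embeddings.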
47 | return model 48 | 49 | def load_causallm(model_name: str, dtype=torch.bfloat16, new_vocab_size=None): 50 | model = AutoModelForCausalLM.from_pretrained( 51 | model_name, torch_dtype=dtype, device_map=device_map, 52 | # local_files_only=True, 53 | trust_remote_code=True, 54 | ) 55 | # TODO: temporially reduce to the same as base_model 56 | if new_vocab_size is not None: 57 | embedding_resize(model, new_vocab_size) 58 | return model -------------------------------------------------------------------------------- /generative/param.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from collections import defaultdict, OrderedDict 3 | import tqdm 4 | import re 5 | import torch.nn as nn 6 | import copy 7 | import sparsify 8 | import utils 9 | import sys 10 | import transformers 11 | from transformers import AutoModelForCausalLM, AutoModelForSeq2SeqLM, AutoModelForSequenceClassification, AutoTokenizer 12 | import os 13 | import functools 14 | from collections import defaultdict, OrderedDict 15 | import torch 16 | 17 | 18 | kw_filter_func = lambda n,p,exclude_param : not any([ 19 | re.match(exclude_pattern, n) 20 | for exclude_pattern in exclude_param 21 | ]) 22 | 23 | MODE = 'drop' 24 | # MODE = 'keep_left' 25 | # MODE = 'keep_right' 26 | class param: 27 | 28 | def __init__( 29 | self, 30 | model, 31 | ): 32 | if isinstance(model, torch.nn.Module): 33 | other = model.state_dict() 34 | elif isinstance(model, dict): 35 | other = model 36 | elif isinstance(model, param): 37 | other = model.param_dict 38 | else: 39 | raise NotImplementedError 40 | 41 | self.param_dict = other 42 | 43 | def filter(self, func): 44 | self.param_dict = { 45 | n: p 46 | for n,p in self.param_dict.items() 47 | if func(n,p) 48 | } 49 | 50 | def __getitem__(self, item): 51 | return self.param_dict[item] 52 | 53 | def __len__(self): 54 | return len(self.param_dict) 55 | 56 | def items(self): 57 | return self.param_dict.items() 58 | 59 | def keys(self): 60 | return self.param_dict.keys() 61 | 62 | def values(self): 63 | return self.param_dict.values() 64 | 65 | # implement `in`! 66 | def __contains__(self, item): 67 | return item in self.keys() 68 | 69 | # a + b 70 | def __add__(self, other): 71 | 72 | if other == 0: 73 | return self 74 | 75 | if isinstance(other, torch.nn.Module): 76 | other = param(other) 77 | 78 | if hasattr(other, 'param_dict'): 79 | 80 | if MODE == 'drop': 81 | return param( 82 | { 83 | n: self[n] + other[n] 84 | for n in set(self.keys()).intersection(other.keys()) 85 | } 86 | ) 87 | # 保留自身的key 88 | elif MODE == 'keep_left': 89 | return param( 90 | { 91 | n: self[n] + other[n] 92 | if n in other 93 | else self[n] 94 | for n in (self.keys()) 95 | } 96 | ) 97 | 98 | # 保留对方的key 99 | elif MODE == 'keep_right': 100 | return param( 101 | { 102 | n: self[n] + other[n] 103 | if n in self 104 | else other[n] 105 | for n in (other.keys()) 106 | } 107 | ) 108 | else: 109 | raise NotImplementedError 110 | 111 | def update_null_keys(self, other): 112 | # 用other填充 自身中不存在的key 113 | for k in other.keys(): 114 | if k not in self: 115 | self[k] = other[k] 116 | 117 | # type(y).__rsub__(y, x) is called if type(x).__sub__(x, y) returns NotImplemented. 
118 |     # called for `a + b` when a's __add__ is not implemented
119 |     def __radd__(self, other):
120 |         # sum(x) starts with 0 + x[0]
121 |         if other == 0:
122 |             return self
123 |         # other + self = self + other
124 |         return self.__add__(other)
125 | 
126 |     def __sub__(self, other):
127 | 
128 |         if other == 0:
129 |             return self
130 | 
131 |         if isinstance(other, torch.nn.Module):
132 |             other = param(other)
133 | 
134 |         if hasattr(other, 'param_dict'):
135 | 
136 |             if MODE == 'drop':
137 |                 return param(
138 |                     {
139 |                         n: self[n] - other[n]
140 |                         for n in set(self.keys()).intersection(other.keys())
141 |                     }
142 |                 )
143 |             # keep only the keys of self
144 |             elif MODE == 'keep_left':
145 |                 return param(
146 |                     {
147 |                         n: self[n] - other[n]
148 |                         if n in other
149 |                         else self[n]
150 |                         for n in (self.keys())
151 |                     }
152 |                 )
153 |             # keep only the keys of other
154 |             elif MODE == 'keep_right':
155 |                 return param(
156 |                     {
157 |                         n: self[n] - other[n]
158 |                         if n in self
159 |                         else other[n]
160 |                         for n in (other.keys())
161 |                     }
162 |                 )
163 | 
164 |         else:
165 |             raise NotImplementedError
166 | 
167 |     def __rsub__(self, other):
168 |         # other - self
169 |         if isinstance(other, torch.nn.Module):
170 |             other = param(other)
171 | 
172 |         if hasattr(other, 'param_dict'):
173 |             return other.__sub__(self)
174 | 
175 |         else:
176 |             raise NotImplementedError
177 | 
178 |     def __rmul__(self, other):
179 | 
180 |         if isinstance(other, float) or isinstance(other, torch.Tensor):
181 |             # scalar weight
182 |             return param(
183 |                 {
184 |                     n: other * p
185 |                     for n,p in self.param_dict.items()
186 |                 }
187 |             )
188 | 
189 |         if isinstance(other, dict):
190 |             # module-wise weight
191 |             if MODE == 'drop':
192 |                 return param(
193 |                     {
194 |                         n: other[n] * self[n]
195 |                         for n in set(self.keys()).intersection(other.keys())
196 |                     }
197 |                 )
198 |             # keep only the keys of self
199 |             elif MODE == 'keep_left':
200 |                 return param(
201 |                     {
202 |                         n: other[n] * self[n]
203 |                         if n in other
204 |                         else self[n]
205 |                         for n in (self.keys())
206 |                     }
207 |                 )
208 |             # keep only the keys of other
209 |             elif MODE == 'keep_right':
210 |                 return param(
211 |                     {
212 |                         n: other[n] * self[n]
213 |                         if n in self
214 |                         else other[n]
215 |                         for n in (other.keys())
216 |                     }
217 |                 )
218 | 
219 |         raise NotImplementedError
220 | 
221 |     def __mul__(self, other):
222 |         return self.__rmul__(other)
223 | 
224 |     def __neg__(self, ):
225 |         return param(
226 |             {
227 |                 n: -p
228 |                 for n,p in self.param_dict.items()
229 |             }
230 |         )
231 | 
232 |     def __truediv__(self, other):
233 | 
234 |         if isinstance(other, (int, float)):
235 |             # scalar weight
236 |             return param(
237 |                 {
238 |                     n: p / other
239 |                     for n,p in self.param_dict.items()
240 |                 }
241 |             )
242 | 
243 |         if isinstance(other, param):
244 |             # module-wise weight
245 |             if MODE == 'drop':
246 |                 return param(
247 |                     {
248 |                         n: self[n] / other[n]
249 |                         for n in set(self.keys()).intersection(other.keys())
250 |                     }
251 |                 )
252 |             # keep only the keys of self
253 |             elif MODE == 'keep_left':
254 |                 return param(
255 |                     {
256 |                         n: self[n] / other[n]
257 |                         if n in other
258 |                         else self[n]
259 |                         for n in (self.keys())
260 |                     }
261 |                 )
262 |             # keep only the keys of other
263 |             elif MODE == 'keep_right':
264 |                 return param(
265 |                     {
266 |                         n: self[n] / other[n]
267 |                         if n in self
268 |                         else other[n]
269 |                         for n in (other.keys())
270 |                     }
271 |                 )
272 | 
273 |         raise NotImplementedError
274 | 
275 |     def assign(self, model: torch.nn.Module):
276 |         device = model.device
277 |         for n, p in model.named_parameters():
278 |             if n in self.param_dict:
279 |                 if p.shape != self.param_dict[n].shape:
280 |                     # for classifiers the default is num_labels=2, so there may be a dimension mismatch
281 |                     print(f'>>> dimension mismatch! override model {n}')
282 |                     utils.rsetattr(model, n, torch.nn.Parameter(self.param_dict[n]))
283 |                     if 'classifier' in n:
284 |                         model.num_labels = self.param_dict[n].shape[0]
285 |                         print(f'>>> change num_labels to {model.num_labels}')
286 |                     continue
287 |                 # a copied shape smaller than p.data.shape can still be copied
288 |                 p.data.copy_(self.param_dict[n])
289 |         model.to(device)
290 | 
291 |     def to(self, device):
292 | 
293 |         for n,p in self.param_dict.items():
294 |             # tensor.to() is not in-place
295 |             # but model.to() is
296 |             self.param_dict[n] = p.to(device)
297 | 
298 |         return self
299 | 
300 |     def map(self, func, desc):
301 | 
302 |         return param({
303 |             n: func(n, self.param_dict[n])
304 |             for n in tqdm.tqdm(self.param_dict.keys(), desc=f'Param Map {desc}')
305 |         })
306 | 
307 |     def flatten(self, ):
308 |         # !important: self.param_dict.values() does not guarantee a deterministic order
309 |         return nn.utils.parameters_to_vector(
310 |             [p.flatten() for p in OrderedDict(sorted(self.param_dict.items())).values()]
311 |         )
312 | 
313 |     def unflatten(self, flatten_params):
314 | 
315 |         nn.utils.vector_to_parameters(
316 |             flatten_params,
317 |             OrderedDict(sorted(self.param_dict.items())).values()
318 |         )
319 |         return self
320 | 
321 |     def __iter__(self):
322 |         # return an iterator object
323 |         return iter(self.param_dict.items())
324 | 
325 |     @staticmethod
326 |     def vectorize_reduce(func, models_to_merge):
327 |         return param({
328 |             # name: func([para1,para2, ...,paraN])
329 |             r[0][0]: func([rr[1] for rr in r])
330 |             for r in zip(*models_to_merge)
331 |         })
--------------------------------------------------------------------------------
/generative/qwen_lora.json:
--------------------------------------------------------------------------------
1 | {
2 |     "r": 32,
3 |     "lora_alpha": 16,
4 |     "lora_dropout": 0.05,
5 |     "task_type": "CAUSAL_LM",
6 |     "target_modules": [
7 |         "w2",
8 |         "c_proj",
9 |         "c_attn",
10 |         "w1"
11 |     ]
12 | }
13 | 
--------------------------------------------------------------------------------
/generative/qwen_task.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LZY-the-boys/Twin-Merging/f481c60826cdf54c70f75f879f73ec68d22429df/generative/qwen_task.py
--------------------------------------------------------------------------------
/generative/router.py:
--------------------------------------------------------------------------------
1 | from utils import *
2 | import json
3 | import torch
4 | import numpy as np
5 | import pandas as pd
6 | from datasets import load_dataset
7 | from transformers import AutoTokenizer, AutoConfig, AutoModelForCausalLM, GPTQConfig
8 | from tqdm import tqdm
9 | import itertools
10 | import torch
11 | import torch.distributed as dist
12 | from peft import PeftModel
13 | 
14 | device_map = "auto"
15 | world_size = int(os.environ.get("WORLD_SIZE", 1))
16 | ddp = world_size != 1
17 | if ddp:
18 |     device_map = {"": int(os.environ.get("LOCAL_RANK") or 0)}
19 | 
20 | def get_model_names():
21 |     taskname_2_id = {
22 |         "mmlu": 0,
23 |         "truthfulqa": 1,
24 |         "bbq": 2,
25 |         "cnn": 3,
26 |     }
27 |     modelname_2_id = {}
28 |     model_names = [
29 |         "../qwen/qwen-mmlu",
30 |         "../qwen/qwen-truthfulqa",
31 |         "../qwen/qwen-bbq",
32 |         "../qwen/qwen-cnn",
33 |     ]
34 |     for idx, i in enumerate(model_names):
35 |         modelname_2_id[i] = idx
36 |     return model_names, taskname_2_id
37 | 
38 | nlg_data_id_map = {
39 |     "mmlu": 0,
40 |     "truthfulqa": 1,
41 |     "bbq": 2,
42 |     "cnn": 3,
43 | }
44 | rev_nlg_data_id_map = {value: key for key, value in nlg_data_id_map.items()}
45 | 
46 | DATA = 'data'
47 | TEXT = 
'data/test_data.json' 48 | 49 | def load_nlg(router_info, tokenizer): 50 | data_path = TEXT 51 | ans = from_jsonl(data_path) 52 | mmlu_data = [d for d in ans if d['task'] == 'mmlu'] 53 | truthfulqa_data = [d for d in ans if d['task'] == 'truthfulqa'] 54 | bbq_data = [d for d in ans if d['task'] == 'bbq'] 55 | cnn_data = [d for d in ans if d['task'] == 'cnn-dm'] 56 | datalist = [mmlu_data, truthfulqa_data, bbq_data, cnn_data] #note the sequences' difference 57 | new_datasets = [] 58 | tot_num = -1 59 | for idx, data in enumerate(datalist): 60 | task_name = rev_nlg_data_id_map[idx] 61 | for i in data: 62 | tot_num += 1 63 | inputs = tokenizer(i["prompt"], return_tensors="pt") 64 | new_datasets.append({ 65 | "sentence": i["prompt"], 66 | "router_prob": router_info[tot_num].tolist(), 67 | "dataset_ids": idx, 68 | "input_ids": inputs["input_ids"][0].numpy().tolist(), 69 | "attention_mask": inputs["attention_mask"][0].numpy().tolist() 70 | }) 71 | return new_datasets 72 | 73 | def get_ori_datasets(mode="train"): 74 | if mode == "test": 75 | data_path = TEXT 76 | ans = from_jsonl(data_path) 77 | mmlu_data = [d for d in ans if d['task'] == 'mmlu'] 78 | truthfulqa_data = [d for d in ans if d['task'] == 'truthfulqa'] 79 | bbq_data = [d for d in ans if d['task'] == 'bbq'] 80 | cnn_data = [d for d in ans if d['task'] == 'cnn-dm'] 81 | data_list = [mmlu_data, truthfulqa_data, bbq_data, cnn_data] 82 | else: 83 | mmlu_data = load_lukaemon_mmlu(mode=mode) #load_mmlu 84 | truthfulqa_data = load_truthfulqa(mode=mode) 85 | bbq_data = load_bbq(mode=mode) 86 | cnn_data = load_cnn_dm(mode=mode) 87 | data_list = [ 88 | mmlu_data, 89 | truthfulqa_data, 90 | bbq_data, 91 | cnn_data, 92 | ] 93 | all_dataset = {} 94 | task_name = [ 95 | "mmlu", 96 | "truthfulqa", 97 | "bbq", 98 | "cnn", 99 | ] 100 | data_num = 0 101 | for idx, i in enumerate(data_list): 102 | all_dataset[task_name[idx]] = {"input": []} 103 | for jdx, j in enumerate(i['input'] if not isinstance(i, list) else i): 104 | if mode == "train" and jdx >= min(len(i["input"]), 1000): break 105 | if isinstance(j, dict) and "prompt" in j: 106 | all_dataset[task_name[idx]]["input"].append(j["prompt"]) 107 | else: 108 | all_dataset[task_name[idx]]["input"].append(j) 109 | data_num += 1 110 | return all_dataset, data_num 111 | 112 | class SimpleMLP(nn.Module): 113 | 114 | def __init__(self, num_clients, embedding_dims, hidden_dim=1024): 115 | super(SimpleMLP, self).__init__() 116 | self.fc1 = nn.Linear(embedding_dims, hidden_dim) 117 | self.bn1 = nn.BatchNorm1d(hidden_dim) 118 | self.fc3 = nn.Linear(hidden_dim, hidden_dim) 119 | self.bn3 = nn.BatchNorm1d(hidden_dim) 120 | self.fc4 = nn.Linear(hidden_dim, num_clients) 121 | self.dropout = nn.Dropout(p=0.5) 122 | self.criterion = nn.CrossEntropyLoss() 123 | 124 | def forward(self, input, labels=None): 125 | x = input.float() 126 | x = self.fc1(x) 127 | x = self.bn1(x) 128 | x = F.leaky_relu(x) 129 | x = self.dropout(x) 130 | x = self.fc3(x) 131 | x = self.bn3(x) 132 | x = F.leaky_relu(x) 133 | x = self.dropout(x) 134 | x = self.fc4(x) 135 | 136 | if labels is not None: 137 | loss = self.criterion(x, labels) 138 | return loss, x 139 | return x 140 | 141 | class RouterDataset(Dataset): 142 | 143 | def __init__(self, data, targets): 144 | self.data = data 145 | self.targets = targets 146 | 147 | def __len__(self): 148 | return len(self.data) 149 | 150 | def __getitem__(self, idx): 151 | img, target = self.data[idx], self.targets[idx] 152 | return { 153 | 'input': img, 154 | 'label': target, 155 | } 156 | 157 | def 
load_dataset(): 158 | 159 | train_data = np.load(f'data/router_train.npz',allow_pickle=True) 160 | train_dataset = RouterDataset( 161 | data = [ 162 | v 163 | for k in train_data.files 164 | for v in train_data[k] 165 | ], 166 | targets = [ 167 | nlg_data_id_map[k] 168 | for k in train_data.files 169 | for _ in range(len(train_data[k])) 170 | ] 171 | ) 172 | test_data = np.load(f'data/router_test.npz') 173 | test_dataset = RouterDataset( 174 | data = [ 175 | v 176 | for k in test_data.files 177 | for v in test_data[k] 178 | ], 179 | targets = [ 180 | nlg_data_id_map[k] 181 | for k in test_data.files 182 | for _ in range(len(test_data[k])) 183 | ] 184 | ) 185 | return { 186 | 'train': train_dataset, 187 | 'test': test_dataset, 188 | } 189 | 190 | def train_router( 191 | in_domain = None, # dict 192 | embed_dims = 5120, 193 | ): 194 | encoded_dataset = load_dataset() 195 | task_num = 4 196 | if in_domain is not None: 197 | raise Exception('Not Implemented yet') 198 | 199 | device_map = "auto" 200 | world_size = int(os.environ.get("WORLD_SIZE", 1)) 201 | ddp = world_size != 1 202 | device = int(os.environ.get("LOCAL_RANK") or 0) 203 | if ddp: 204 | device_map = {"": device} 205 | 206 | classifier = SimpleMLP( 207 | num_clients=task_num, embedding_dims=embed_dims, hidden_dim=2*embed_dims 208 | ).to(device) 209 | 210 | def compute_metrics(eval_pred): 211 | logits, labels = eval_pred 212 | logits, labels = torch.tensor(logits), torch.tensor(labels) 213 | predictions = torch.argmax((logits), dim=-1) 214 | 215 | total = len(labels) 216 | correct_list = [0] * task_num 217 | total_list = [0] * task_num 218 | 219 | # total acc 220 | correct = predictions.eq((labels)).sum().item() 221 | acc = correct / total * 100.0 222 | print( 223 | "@@ Final {}/{}, Accuracy: {:.2f}%".format( 224 | correct, total, acc 225 | ) 226 | ) 227 | # acc per class 228 | for i in range(4): 229 | correct_list[i] = ((labels == i) & (predictions == i)).sum().item() 230 | total_list[i] = (labels == i).sum().item() 231 | acc_prop = [correct_list[i] / total_list[i] * 100.0 if total_list[i] > 0 else 0 for i in range(task_num)] 232 | print("Correct list: ", correct_list) 233 | print("Accuracy proportion: ", acc_prop) 234 | return { 235 | "accuracy": correct / total, 236 | "accuracy_per_class": acc_prop 237 | } 238 | 239 | trainer = transformers.Trainer( 240 | model=classifier, 241 | args=transformers.TrainingArguments( 242 | output_dir="./data/router", 243 | evaluation_strategy="epoch", 244 | save_strategy='epoch', 245 | learning_rate=0.0005, 246 | per_device_train_batch_size=256, 247 | per_device_eval_batch_size=256, 248 | num_train_epochs=50, # 10 249 | # weight_decay=1e-4, 250 | logging_steps=20, 251 | save_total_limit=1, 252 | report_to=[], 253 | load_best_model_at_end=True, 254 | metric_for_best_model="accuracy", 255 | greater_is_better=True, 256 | ddp_find_unused_parameters=False 257 | ), 258 | train_dataset=encoded_dataset["train"], 259 | eval_dataset=encoded_dataset["test"], 260 | compute_metrics=compute_metrics, 261 | ) 262 | trainer.train() 263 | trainer.save_model("./data/router") 264 | prediction = trainer.predict(encoded_dataset["test"], metric_key_prefix='') 265 | 266 | new_datasets = load_nlg( 267 | tokenizer = AutoTokenizer.from_pretrained('Qwen/Qwen-14B', trust_remote_code=True), 268 | router_info=prediction.predictions 269 | ) 270 | json.dump(new_datasets, open('data/test_router.json','w'), ensure_ascii=False) 271 | 272 | 273 | @torch.inference_mode() 274 | def generate_router_datasets( 275 | base_model, 276 | 
shared_expert, 277 | mode='eval', 278 | ): 279 | 280 | rank = dist.get_rank() 281 | world_size = dist.get_world_size() 282 | device = rank % torch.cuda.device_count() 283 | torch.cuda.set_device(device) 284 | print(f"Starting rank={rank}, world_size={dist.get_world_size()}.") 285 | 286 | config = AutoConfig.from_pretrained(base_model, trust_remote_code=True) 287 | tokenizer = AutoTokenizer.from_pretrained(base_model, trust_remote_code=True) 288 | model = AutoModelForCausalLM.from_pretrained( 289 | base_model, 290 | torch_dtype=torch.float16, 291 | device_map=device_map, 292 | trust_remote_code=True 293 | ) 294 | model = PeftModel.from_pretrained( 295 | model, 296 | shared_expert, 297 | torch_dtype=torch.float16, 298 | ) 299 | 300 | datasets, data_num = get_ori_datasets(mode=mode) 301 | 302 | dist.barrier() 303 | 304 | ans = {} 305 | for data_id, (category_name, data_class) in enumerate(datasets.items()): 306 | data_item = data_class["input"] 307 | res = [] 308 | for index, data_input in tqdm( 309 | itertools.islice( 310 | enumerate(data_item), rank, len(data_item), world_size 311 | ), 312 | disable= device != 0, 313 | total = len(data_item) // world_size + 1, 314 | ): 315 | inputs = tokenizer(data_input, return_tensors="pt").to(device) 316 | content = model.generate( 317 | input_ids=inputs["input_ids"], 318 | num_beams=1, 319 | do_sample=True, 320 | return_dict_in_generate=True, 321 | output_scores=True, 322 | output_hidden_states=True, 323 | max_new_tokens=16 324 | ) 325 | res.append(( 326 | index, 327 | torch.mean(content.hidden_states[-1][-1][0, :, :], dim=0).to(torch.float16).cpu().numpy() 328 | )) 329 | 330 | dist.barrier() 331 | global_res = [None] * world_size 332 | dist.all_gather_object(global_res, res) 333 | 334 | if device == 0: 335 | global_res = sorted([rr for r in global_res for rr in r], key=lambda x: x[0]) 336 | ans[category_name] = [r[1] for r in global_res] 337 | 338 | if device == 0: 339 | np.savez(f'data/router_{mode}.npz', **ans) 340 | 341 | dist.barrier() 342 | 343 | 344 | def main( 345 | *, 346 | base_model: str = 'Qwen/Qwen-14B', 347 | shared_expert: str = None, 348 | seed: int = 0, 349 | train: bool = False 350 | ): 351 | 352 | fix_seed(seed) 353 | 354 | if not os.path.exists('data/test_data.json'): 355 | raise Exception('You should run gen_eval_data first to get test data') 356 | 357 | if not os.path.exists('data/router_test.npz'): 358 | assert shared_expert is not None 359 | if not dist.is_initialized(): 360 | dist.init_process_group("nccl") 361 | for mode in ['test']: 362 | generate_router_datasets( 363 | mode=mode, 364 | base_model=base_model, 365 | shared_expert=shared_expert 366 | ) 367 | dist.destroy_process_group() 368 | 369 | if train: 370 | config = AutoConfig.from_pretrained(base_model, trust_remote_code=True) 371 | train_router(embed_dims=config.hidden_size) 372 | print('Train Done') 373 | 374 | if __name__ == '__main__': 375 | import defopt 376 | defopt.run(main) -------------------------------------------------------------------------------- /generative/run_merge.py: -------------------------------------------------------------------------------- 1 | from email.mime import base 2 | import torch 3 | from collections import defaultdict, OrderedDict 4 | import tqdm 5 | import re 6 | import torch.nn as nn 7 | import copy 8 | import sparsify 9 | import utils 10 | import json 11 | import sys 12 | import transformers 13 | from transformers import AutoModelForCausalLM, AutoModelForSeq2SeqLM, AutoModelForSequenceClassification, AutoTokenizer 14 | import os 15 
| import functools 16 | from peft import LoraConfig,get_peft_model 17 | from model import load_causallm 18 | from collections import defaultdict, OrderedDict 19 | from param import param 20 | import torch.nn.functional as F 21 | import torch 22 | from collections import defaultdict 23 | import numpy as np 24 | from merge import MergingMethod,LoraMergingMethod 25 | import inspect 26 | import datasets 27 | import pandas as pd 28 | from safetensors.torch import load_file 29 | 30 | args = None 31 | DEVICE='cuda:0' 32 | 33 | 34 | @torch.inference_mode() 35 | def run_merge( 36 | args, 37 | ): 38 | if args.exclude_param and len(args.exclude_param): 39 | filter_func = lambda n,p : not any([ 40 | re.match(exclude_pattern, n) 41 | for exclude_pattern in args.exclude_param 42 | ]) 43 | # \theta_t 44 | models_finetuned = { 45 | name: load_causallm(name) for name in args.models_name 46 | } 47 | # \theta_* 48 | models_to_merge = [ 49 | models_finetuned[name] 50 | for name in args.src_merge 51 | ] 52 | base_model = load_causallm(args.base_model) 53 | # tokenizer = AutoTokenizer.from_pretrained(args.base_model) 54 | 55 | args.base_model = param(base_model) 56 | args.models_to_merge = [param(m) for m in models_to_merge] 57 | for model in args.models_to_merge: 58 | model.filter(filter_func) 59 | args.base_model.filter(filter_func) 60 | 61 | # 3. merge 62 | merger = MergingMethod(**args) 63 | merge_method = getattr(merger, args.merge_method) 64 | merged_param = merge_method(**args) 65 | 66 | for n, p in merged_param.param_dict.items(): 67 | utils.rsetattr(base_model, n, torch.nn.Parameter(p, requires_grad=False)) 68 | 69 | base_model.save_pretrained(args.outdir) 70 | 71 | @torch.inference_mode() 72 | def run_merge_lora( 73 | args, 74 | ): 75 | 76 | if args.exclude_param and len(args.exclude_param): 77 | filter_func = lambda n,p : not any([ 78 | re.match(exclude_pattern, n) 79 | for exclude_pattern in args.exclude_param 80 | ]) 81 | 82 | # one example for Qwen LoRA, feel free to custom change 83 | peft_config = LoraConfig(**json.load(open(args.lora))) 84 | 85 | def load(model_path): 86 | try: 87 | ans = torch.load( 88 | os.path.join(model_path, 'adapter_model.bin') 89 | ) 90 | except: 91 | ans = load_file(os.path.join(model_path, 'adapter_model.safetensors')) 92 | return ans 93 | 94 | # \theta_t 95 | models_finetuned = { 96 | name: load(name) for name in args.models_name 97 | } 98 | models_to_merge = [ 99 | models_finetuned[name] 100 | for name in args.src_merge 101 | ] 102 | 103 | base_model = load_causallm(args.base_model).to(DEVICE) 104 | base_model = get_peft_model(base_model, peft_config, adapter_name='merged') 105 | 106 | args.base_model = param(base_model) 107 | 108 | args.models_to_merge = [param(m).to(DEVICE) for m in models_to_merge] 109 | for model in args.models_to_merge: 110 | model.filter(filter_func) 111 | args.base_model.filter(filter_func) 112 | 113 | # 3. 
merge 114 | merger = LoraMergingMethod(**args) 115 | merge_method = getattr(merger, args.merge_method) 116 | merged_param = merge_method(**args) 117 | 118 | for n, p in merged_param.param_dict.items(): 119 | n = n.replace('lora_B', 'lora_B.merged') 120 | n = n.replace('lora_A', 'lora_A.merged') 121 | utils.rsetattr(base_model, n, torch.nn.Parameter(p, requires_grad=False)) 122 | 123 | base_model.merge_and_unload(progressbar=True) 124 | base_model.save_pretrained(args.outdir) 125 | 126 | 127 | def main( 128 | *, 129 | models_to_merge: list[str], 130 | models_name: list[str], 131 | src_merge: list[str], 132 | yaml_file: str = None, 133 | exclude_param: list[str] = None, 134 | data_path: str = None, 135 | seed: int=10, 136 | base_model: str = 'roberta-base', 137 | # for task-arithmetic_search: 138 | scaling: list[float] = None, 139 | # for dare-merge: 140 | mask_rate: float = None, 141 | mask_scale: float = None, 142 | mask_strategy: str = None, 143 | outdir: str = None, 144 | lora: str = None, 145 | ): 146 | 147 | global args 148 | keys, _, _, values = inspect.getargvalues(inspect.currentframe()) 149 | 150 | utils.fix_seed(seed) 151 | 152 | merge_config = utils.from_yaml(yaml_file) 153 | args = { 154 | k: values.get(k, merge_config.get(k)) 155 | for k in set(keys).union(merge_config) 156 | } 157 | args = { 158 | k: merge_config.get(k, None) 159 | if args[k] is None else args[k] 160 | for k in args.keys() 161 | } 162 | args = utils.SimpleNamespace(**args) 163 | 164 | print('>>> args\n', args) 165 | 166 | if args.scaling is not None and isinstance(args.scaling, list) and len(args.scaling) == 1: 167 | args.scaling = args.scaling[0] 168 | 169 | if args.lora: 170 | run_merge_lora(args) 171 | else: 172 | run_merge(args) 173 | 174 | 175 | if __name__ == '__main__': 176 | import defopt 177 | defopt.run(main) -------------------------------------------------------------------------------- /generative/scripts.sh: -------------------------------------------------------------------------------- 1 | set -e pipefail 2 | 3 | outdir=${outdir:="outs/qwen_merged"} 4 | mkdir -p ${outdir} 5 | 6 | models_to_merge=( 7 | ../qwen/qwen-mmlu 8 | ../qwen/qwen-truthfulqa 9 | ../qwen/qwen-bbq 10 | ../qwen/qwen-cnn 11 | ) 12 | 13 | function run_avg_merge(){ 14 | 15 | pos 16 | 17 | python run_merge.py \ 18 | --models-to-merge ${models_to_merge[@]} \ 19 | --models-name ${models_to_merge[@]} \ 20 | --src-merge ${models_to_merge[@]} \ 21 | --base-model "Qwen/Qwen-14B" \ 22 | --yaml-file config/average_merge.yml \ 23 | --outdir $outdir \ 24 | --lora 'qwen_lora.json' 25 | 26 | } 27 | 28 | function run_dare_task_arith(){ 29 | 30 | pos 31 | 32 | for i in 0.7 ; do 33 | 34 | python run_merge.py \ 35 | --models-to-merge ${models_to_merge[@]} \ 36 | --models-name ${models_to_merge[@]} \ 37 | --src-merge ${models_to_merge[@]} \ 38 | --base-model "Qwen/Qwen-14B" \ 39 | --yaml-file config/dare_merge.yml \ 40 | --mask-rate $i \ 41 | --outdir $outdir \ 42 | --lora 'qwen_lora.json' 43 | 44 | done 45 | 46 | } 47 | 48 | function run_task_arith(){ 49 | 50 | for j in 0.3; do 51 | 52 | python run_merge.py \ 53 | --models-to-merge ${models_to_merge[@]} \ 54 | --models-name ${models_to_merge[@]} \ 55 | --base-model "Qwen/Qwen-14B" \ 56 | --src-merge ${models_to_merge[@]} \ 57 | --yaml-file config/task_arithmetic.yml \ 58 | --scaling $j \ 59 | --outdir $outdir \ 60 | --lora 'qwen_lora.json' 61 | 62 | done 63 | 64 | } 65 | 66 | function run_tie(){ 67 | 68 | pos 69 | 70 | 71 | for i in 0.7; do 72 | for j in 0.3; do 73 | 74 | python run_merge.py \ 
75 | --models-to-merge ${models_to_merge[@]} \ 76 | --models-name ${models_to_merge[@]} \ 77 | --src-merge ${models_to_merge[@]} \ 78 | --base-model "Qwen/Qwen-14B" \ 79 | --yaml-file config/ties_merge.yml \ 80 | --mask-rate $i \ 81 | --scaling $j \ 82 | --outdir $outdir \ 83 | --lora 'qwen_lora.json' 84 | 85 | done 86 | done 87 | 88 | } 89 | 90 | -------------------------------------------------------------------------------- /generative/sparsify.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def magnitude( 5 | tensor: torch.Tensor, 6 | density: float, 7 | **kwargs, 8 | ) -> torch.Tensor: 9 | """Masks out the smallest values, retaining a proportion of `density`.""" 10 | if density >= 1: 11 | return tensor 12 | if len(tensor.shape) == 1: 13 | # rank=1 14 | return tensor 15 | 16 | k = int(density * tensor.view(-1).shape[0]) 17 | 18 | assert k > 0, "not gonna zero out the whole tensor buddy" 19 | mask = torch.zeros_like(tensor) 20 | w = tensor.abs().view(-1) 21 | if w.device.type == "cpu": 22 | w = w.float() 23 | topk = torch.topk(w, k=k, largest=True) 24 | mask.view(-1)[topk.indices] = 1 25 | 26 | return tensor * mask 27 | 28 | 29 | def bernoulli( 30 | tensor: torch.Tensor, 31 | density: float, # 1 - mask_rate (probability of drawing "1") 32 | rescale: bool = True 33 | ) -> torch.Tensor: 34 | if density >= 1: 35 | return tensor 36 | if density <= 0: 37 | return torch.zeros_like(tensor) 38 | if len(tensor.shape) == 1: 39 | # rank=1 40 | return tensor 41 | 42 | # mask = 1 - torch.bernoulli( 43 | # torch.full_like(input=tensor, fill_value=1 - density) 44 | # ) 45 | mask = torch.bernoulli( 46 | torch.full_like(input=tensor, fill_value=density).float() 47 | ) 48 | 49 | res = tensor * mask 50 | if rescale: 51 | res *= 1 / density 52 | return res 53 | 54 | def svd( 55 | tensor: torch.Tensor, 56 | density: float, 57 | **kwargs, 58 | ): 59 | if density >= 1: 60 | return tensor 61 | if density <= 0: 62 | return torch.zeros_like(tensor) 63 | if kwargs.get('new_rank', None) == 0: 64 | return torch.zeros_like(tensor) 65 | if len(tensor.shape) == 1: 66 | # rank=1 67 | return tensor 68 | 69 | # U, S, V = torch.svd(tensor) 70 | # S = (S >= S[int(len(S) * density)]) * S 71 | # res = U @ torch.diag(S) @ V.T 72 | 73 | # `torch.linalg.svd()`: good for dense matrix 74 | # `torch.svd()`: deprecated 75 | # `torch.svd_lowrank()`: good for huge sparse matrix 76 | driver = None 77 | if tensor.is_cuda: 78 | driver = 'gesvda' 79 | 80 | U, S, Vh = torch.linalg.svd(tensor, full_matrices=True, driver=driver) 81 | if 'new_rank' not in kwargs: 82 | new_rank = int(density * len(S)) 83 | else: 84 | new_rank = kwargs['new_rank'] 85 | U, S, Vh = U[:, :new_rank], S[:new_rank], Vh[:new_rank, :] 86 | res = U @ torch.diag(S) @ Vh 87 | return res -------------------------------------------------------------------------------- /method.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LZY-the-boys/Twin-Merging/f481c60826cdf54c70f75f879f73ec68d22429df/method.png -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch==2.0.1 --index-url https://download.pytorch.org/whl/cu118 2 | accelerate==0.25.0 3 | auto-gptq==0.4.2 4 | datasets==2.14.6 5 | deepspeed==0.11.1 6 | evaluate==0.4.0 7 | faiss-gpu 8 | fastapi 9 | fire==0.5.0 10 | flash-attn==2.3.3 11 | gradio 12 | 
huggingface-hub==0.20.1 13 | jieba==0.42.1 14 | jsonlines 15 | matplotlib 16 | nltk==3.8.1 17 | numpy 18 | nvitop 19 | openai 20 | -e git+https://github.com/LZY-the-boys/peft@953a014fb76211cfd5e2c9f0d1497731e101d4ad#egg=peft 21 | pydantic==1.10.13 22 | pydantic_core==2.10.1 23 | rouge 24 | rouge-score 25 | safetensors==0.4.0 26 | scikit-learn 27 | scipy==1.5.4 28 | seaborn==0.13.0 29 | sentence-transformers==2.2.2 30 | sentencepiece==0.1.99 31 | tokenizers==0.15.0 32 | torchaudio==2.0.2 33 | torchmetrics==0.9.0 34 | torchvision==0.15.2 35 | tqdm 36 | transformers==4.36.2 37 | uvicorn==0.23.2 38 | tabulate 39 | defopt 40 | openpyxl 41 | einops 42 | transformers_stream_generator 43 | tiktoken 44 | fastapi 45 | wandb --------------------------------------------------------------------------------
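
To make the moving parts above easier to follow, the snippet below is a minimal hand-written sketch (not part of the repository) of a plain task-arithmetic merge built directly on `generative/param.py` and `generative/model.py`. The checkpoint paths and output directory are placeholders, the 0.3 scaling factor mirrors `run_task_arith` in `generative/scripts.sh`, and the supported entry point for real runs remains `generative/run_merge.py` driven by a YAML config.

```python
# Hand-rolled task-arithmetic merge using the `param` wrapper.
# All paths are placeholders: this assumes fully fine-tuned causal-LM experts,
# whereas the shipped Qwen experts are LoRA adapters handled by run_merge_lora.
from model import load_causallm
from param import param

base_model = load_causallm('path/to/base-model')
expert_paths = ['path/to/expert-a', 'path/to/expert-b']   # placeholders

theta_base = param(base_model)
# task vector per expert: theta_t - theta_base (intersecting keys, MODE == 'drop')
task_vectors = [param(load_causallm(p)) - theta_base for p in expert_paths]

# merged = theta_base + lambda * sum_t(theta_t - theta_base), lambda as in scripts.sh
merged = theta_base + 0.3 * sum(task_vectors)

merged.assign(base_model)   # copy the merged tensors back into the base model
base_model.save_pretrained('outs/task_arithmetic_sketch')
```

Note that the real pipeline additionally filters excluded parameters (`--exclude-param`) and, for LoRA experts, merges adapter weights before calling `merge_and_unload`, as shown in `generative/run_merge.py`.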