├── .gitignore ├── README.md ├── resource ├── MM4Long2Short.pdf ├── fig1.png ├── plot_radar_7b.png └── rank_plot_7b.png └── src ├── evaluation ├── data_process.py └── l2s_eval.sh ├── main_merging.py ├── mask_weights_utils.py ├── merging_methods.py ├── task_vector.py └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | ### Python template 2 | # Byte-compiled / optimized / DLL files 3 | __pycache__/ 4 | *.py[cod] 5 | *$py.class 6 | 7 | # C extensions 8 | *.so 9 | 10 | # Distribution / packaging 11 | .Python 12 | build/ 13 | develop-eggs/ 14 | dist/ 15 | downloads/ 16 | eggs/ 17 | .eggs/ 18 | lib/ 19 | lib64/ 20 | parts/ 21 | sdist/ 22 | var/ 23 | wheels/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | cover/ 54 | 55 | # Translations 56 | *.mo 57 | *.pot 58 | 59 | # Django stuff: 60 | *.log 61 | local_settings.py 62 | db.sqlite3 63 | db.sqlite3-journal 64 | 65 | # Flask stuff: 66 | instance/ 67 | .webassets-cache 68 | 69 | # Scrapy stuff: 70 | .scrapy 71 | 72 | # Sphinx documentation 73 | docs/_build/ 74 | 75 | # PyBuilder 76 | .pybuilder/ 77 | target/ 78 | 79 | # Jupyter Notebook 80 | .ipynb_checkpoints 81 | 82 | # IPython 83 | profile_default/ 84 | ipython_config.py 85 | 86 | # pyenv 87 | # For a library or package, you might want to ignore these files since the code is 88 | # intended to run in multiple environments; otherwise, check them in: 89 | # .python-version 90 | 91 | # pipenv 92 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 93 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 94 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 95 | # install all needed dependencies. 96 | #Pipfile.lock 97 | 98 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow
 99 | __pypackages__/
100 | 
101 | # Celery stuff
102 | celerybeat-schedule
103 | celerybeat.pid
104 | 
105 | # SageMath parsed files
106 | *.sage.py
107 | 
108 | # Environments
109 | .env
110 | .venv
111 | env/
112 | venv/
113 | ENV/
114 | env.bak/
115 | venv.bak/
116 | 
117 | # Spyder project settings
118 | .spyderproject
119 | .spyproject
120 | 
121 | # Rope project settings
122 | .ropeproject
123 | 
124 | # mkdocs documentation
125 | /site
126 | 
127 | # mypy
128 | .mypy_cache/
129 | .dmypy.json
130 | dmypy.json
131 | 
132 | # Pyre type checker
133 | .pyre/
134 | 
135 | # pytype static type analyzer
136 | .pytype/
137 | 
138 | # Cython debug symbols
139 | cython_debug/
140 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
 1 | # Unlocking Efficient Long-to-Short LLM Reasoning with Model Merging
 2 | 
 3 | [[ArXiv]](http://arxiv.org/abs/2503.20641) | [[Latest PDF]](resource/MM4Long2Short.pdf)
 4 | 
 5 | ![overall figures](resource/fig1.png)
 6 | 
 7 | ## News 🔔 🔔 🔔
 8 | 
 9 | - 🔥 [04 June, 2025] We observed fantastic performance from TIES-Merging in shortening the output length of DeepSeek-R1-0528-Qwen3-8B.
10 | > [!IMPORTANT]
11 | > **Around 50% length reduction with a 4-point performance improvement on AIME24, AND the no_think ability is preserved!**
12 | - 💡 [27 March, 2025] We release the code of Average Merging, Task Arithmetic, Ties-Merging and DARE. Special thanks to [MergeLM](https://github.com/yule-BUAA/MergeLM) for their great work. We use this repo as our codebase.
13 | - 📣 [26 March, 2025] Our work is available on [[ArXiv]](http://arxiv.org/abs/2503.20641).
14 | 
15 | ## 🔥 Early Test on DeepSeek-R1-0528-Qwen3-8B
16 | 
17 | We directly merge R1-0528-Qwen3-8B with Qwen3-8B (as the base model) using TIES-Merging with `k = 0.7, α = 0.7`. We sample 16 answers for each question and report the average score (generation parameters: `max_new_tokens = 32768, temperature = 0.6, top_p = 0.95, top_k = 20`).
18 | 
19 | | Model | R1-0528-Qwen3-8B | Merged Qwen3-8B |
20 | |------------------|----------------|--------------------|
21 | | AIME24 | 70.63 (11328.25) | 74.58 (6448.5) |
22 | 
23 | According to [Qwen3's guidelines](https://huggingface.co/Qwen/Qwen3-8B), there are two ways to switch between the /think and /no_think modes: setting `enable_thinking=False|True`, or appending `/think | /no_think` to the instruction.
24 | 
25 | - Extensive experiments on R1-0528-Qwen3-8B reveal that the no_think capability is completely lost in both modes, resulting in excessively lengthy responses.
26 | - Our merged model preserves the no_think ability according to our testing: the model directly generates the answer part when `enable_thinking=False` is set. However, the alternative switch triggered by the `/no_think` keyword appears to fail in most cases.
27 | 
28 | *Stay tuned for more new results!*
29 | 
30 | ## Summary of our findings 🔥🔥🔥:
31 | 
32 | - Model merging is a highly efficient approach for long-to-short reasoning, as it directly operates on model parameters **without requiring additional training**.
33 | 
34 | - Task-vector based merging methods, such as TA and Ties-Merging, can achieve long-to-short reasoning with around **50% length reduction** alongside **accuracy parity or even marginal gains** on 7B models (the shared task-vector formulation is sketched after this list).
35 | 
36 | - SVD-based merging methods exhibit limited effectiveness, delivering moderate performance and serving as viable alternatives only when task vectors inherently possess low-rank spectral characteristics.
37 | 
38 | - Activation-based merging is the future, as it demonstrates impressive performance in terms of both reasoning accuracy (+1.9) and response length compression ratio (-49.8%).
39 | 
40 | - Model merging methods applied to 1.5B-scale models remain effective on simple tasks; however, smaller models struggle to acquire long CoT reasoning ability through model merging.
41 | 
42 | - Merging large-scale models (14B and 32B) poses significant challenges in maintaining reasoning performance while substantially reducing response length.
43 | 
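All of the task-vector based methods above share the formulation implemented in `src/task_vector.py`: subtract the base model from each long-CoT model to obtain a task vector, then add a scaled combination of the task vectors back onto the base. For TA this is simply

```math
\tau_i = \theta_i - \theta_{\mathrm{base}}, \qquad \theta_{\mathrm{merged}} = \theta_{\mathrm{base}} + \alpha \sum_i \tau_i,
```

where α is the scaling coefficient (`--scaling_coefficient` below); Ties-Merging additionally trims and sign-filters each τ_i before combining them.
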
44 | ## Related Work 📑
45 | 
46 | - Average Merging: [Model soups: averaging weights of multiple fine-tuned models improves accuracy without increasing inference time](https://arxiv.org/abs/2203.05482)
47 | 
48 | - Task Arithmetic: [Editing Models with Task Arithmetic](https://arxiv.org/abs/2212.04089)
49 | 
50 | - Ties-Merging: [TIES-Merging: Resolving Interference When Merging Models](https://arxiv.org/abs/2306.01708)
51 | 
52 | - DARE: [Language Models are Super Mario: Absorbing Abilities from Homologous Models as a Free Lunch](https://arxiv.org/abs/2311.03099)
53 | 
54 | - LoRE-Merging: [LoRE-Merging: Exploring Low-Rank Estimation For Large Language Model Merging](https://arxiv.org/abs/2502.10749)
55 | 
56 | - Twin-Merging: [Twin-Merging: Dynamic Integration of Modular Expertise in Model Merging](https://arxiv.org/abs/2406.15479)
57 | 
58 | - AIM: [Activation-Informed Merging of Large Language Models](https://arxiv.org/abs/2502.02421)
59 | 
60 | - Sens-Merging: [Sens-Merging: Sensitivity-Guided Parameter Balancing for Merging Large Language Models](https://arxiv.org/abs/2502.12420)
61 | 
62 | ## Environments
63 | For merging methods:
64 | ```
65 | numpy==1.26.4
66 | torch==2.5.1
67 | transformers==4.48.2
68 | ```
69 | 
70 | For the evaluation, we recommend following the usage of the [[Qwen2.5-Math Eval Toolkit]](https://github.com/QwenLM/Qwen2.5-Math).
71 | 
72 | ## Implementation
73 | 
74 | Our implementation is adapted from [MergeLM](https://github.com/yule-BUAA/MergeLM). We optimized the code of some merging methods, such as TA and Ties-Merging, for computational efficiency.
75 | 
76 | ### Average merging
77 | 
78 | ```shell
79 | python src/main_merging.py --merge_method average_merging \
80 |     --output_dir DIR \
81 |     --base_model MODEL_PATH \
82 |     --models_to_merge MODEL_PATH1,MODEL_PATH2,...,MODEL_PATHn
83 | ```
84 | 
85 | ### Task Arithmetic
86 | ```shell
87 | python src/main_merging.py --merge_method task_arithmetic \
88 |     --output_dir DIR \
89 |     --base_model MODEL_PATH \
90 |     --models_to_merge MODEL_PATH1,MODEL_PATH2,...,MODEL_PATHn \
91 |     --scaling_coefficient α
92 | ```
93 | 
94 | ### Ties-Merging
95 | ```shell
96 | python src/main_merging.py --merge_method ties_merging/ties_merging_dare \
97 |     --output_dir DIR \
98 |     --base_model MODEL_PATH \
99 |     --models_to_merge MODEL_PATH1,MODEL_PATH2,...,MODEL_PATHn \
100 |     --scaling_coefficient α \
101 |     --param_value_mask_rate k
102 | ```
103 | 
104 | Note: You can use `ties_merging`, whose implementation we optimized for better computational efficiency, or `ties_merging_dare`, the original implementation from MergeLM. We have compared the two implementations and found their results comparable. A conceptual sketch of the procedure follows.
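
For intuition, here is a minimal, illustrative sketch of what Ties-Merging does to a single parameter tensor: trim the smallest-magnitude entries of each task vector, elect a per-element sign, then average only the sign-consistent values. The function and tensor names are ours for illustration; the actual per-tensor implementation is `ties_merging` in `src/merging_methods.py`.

```python
import torch

def ties_merge_tensor(deltas: list, trim_rate: float = 0.8) -> torch.Tensor:
    """Illustrative TIES step on one parameter: trim -> elect sign -> disjoint mean.

    `deltas` holds the task vectors (finetuned minus base) of one parameter,
    one per model to merge; trim_rate corresponds to --param_value_mask_rate k.
    """
    trimmed = []
    for delta in deltas:
        k = int(delta.numel() * trim_rate)
        if k > 0:
            # zero out the k smallest-magnitude entries of this task vector
            threshold = delta.abs().flatten().kthvalue(k).values
            delta = delta * (delta.abs() >= threshold)
        trimmed.append(delta)
    # elect an aggregate sign per element; exact-zero ties default to +1
    sign = torch.sign(sum(trimmed))
    sign[sign == 0] = 1.0
    # keep only entries that agree with the elected sign, then average them
    kept = [d * (torch.sign(d) == sign) for d in trimmed]
    count = sum((d != 0).float() for d in kept).clamp(min=1.0)
    return sum(kept) / count
```

The merged delta is then scaled by α and added back onto the base model, as `combine_with_pretrained_model` in `src/task_vector.py` does.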
105 | 106 | ### DARE 107 | ```shell 108 | python src/main_merging.py --merge_method mask_merging \ 109 | --output_dir DIR \ 110 | --base_model MODEL_PATH \ 111 | --models_to_merge MODEL_PATH1,MODEL_PATH2,...,MODEL_PATHn \ 112 | --scaling_coefficient α \ 113 | --param_value_mask_rate k \ 114 | --mask_apply_method [average_merging || task_arithmetic || ties_merging || ties_merging_dare] \ 115 | --weight_mask_rates p 116 | ``` 117 | 118 | ### AIM & Sens-Merging 119 | We will release the code of activation-based methods soon. Stay tuned! 120 | 121 | ### Evaluations 122 | 123 | ```shell 124 | git clone https://github.com/QwenLM/Qwen2.5-Math.git 125 | cd Qwen2.5-Math-main/evaluation 126 | ``` 127 | 128 | Move `src/evaluation/data_process.py` and `src/evaluation/l2s_eval.sh` as following: 129 | 130 | ```markdown 131 | Qwen2.5-Math-Main/evaluation 132 | ├── sh 133 | │ ├── l2s_eval.sh 134 | ├── data 135 | │ ├── math500 136 | ├── outputs 137 | │ ├── ... 138 | ├── data_process.py 139 | └── ... 140 | ``` 141 | 142 | Note: [MATH500](https://huggingface.co/datasets/HuggingFaceH4/MATH-500) are not in original database. You should manually add it to `Qwen2.5-Math-Main/evaluation/data`. 143 | 144 | Run the evaluation: 145 | ```shell 146 | CUDA_VISIBLE_DEVICES="0,1,2,3" bash sh/l2s_eval.sh [PROMPT_TYPE:qwen25-math-cot] [MODEL_PATH] [MAX_TOKEN:10240] [NUM_SHOTS:0] [DATASETS:aime24,math500,gsm8k,college_math,minerva_math,olympiadbench] 147 | ``` 148 | 149 | To make sure the reproducibility of the results, we set `temperature=0`, `top_p=1`. 150 | 151 | ## Configurations 152 | 153 | | Method | 1.5B | 7B | 14B | 32B | 154 | |------------------|----------------|--------------------|--------------------|--------------------| 155 | | Task Arithmetic | α = 0.7 | α = 0.7 | α = 0.7 | α = 0.7 | 156 | | Ties-Merging | k = 0.8, α = 1.0 | k = 0.8, α = 1.0 | k = 0.2, α = 0.5 | k = 0.25, α = 0.55 | 157 | | DARE | p = 0.3 | p = 0.3 | p = 0.4 | - | 158 | | AIM-Ties | ω = 0.4 | ω = 0.4 | ω = 0.4 | - | 159 | | Sens-Merging | α = 0.4, T = 3.0 | α = 0.7, T = 2.0 | α = 0.8, T = 6.0 | - | 160 | 161 | *Table: The hyper-parameters of various merging methods. α means the coefficient in TA merging. p means the drop rate in DARE. k denotes the trim ratio in Ties-Merging. ω means the balance factor in AIM. 
T is the temperature in Sens-Merging.* 162 | 163 | 164 | ## Citation 165 | ``` 166 | @article{wu2025unlockingefficientlongtoshortllm, 167 | title={Unlocking Efficient Long-to-Short LLM Reasoning with Model Merging}, 168 | author={Han Wu and Yuxuan Yao and Shuqi Liu and Zehua Liu and Xiaojin Fu and Xiongwei Han and Xing Li and Hui-Ling Zhen and Tao Zhong and Mingxuan Yuan}, 169 | year={2025}, 170 | eprint={2503.20641}, 171 | archivePrefix={arXiv}, 172 | primaryClass={cs.CL}, 173 | url={https://arxiv.org/abs/2503.20641}, 174 | } 175 | ``` 176 | -------------------------------------------------------------------------------- /resource/MM4Long2Short.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hahahawu/Long-to-Short-via-Model-Merging/6bc078e4be5da9e6e58254653bf24497d3958f60/resource/MM4Long2Short.pdf -------------------------------------------------------------------------------- /resource/fig1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hahahawu/Long-to-Short-via-Model-Merging/6bc078e4be5da9e6e58254653bf24497d3958f60/resource/fig1.png -------------------------------------------------------------------------------- /resource/plot_radar_7b.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hahahawu/Long-to-Short-via-Model-Merging/6bc078e4be5da9e6e58254653bf24497d3958f60/resource/plot_radar_7b.png -------------------------------------------------------------------------------- /resource/rank_plot_7b.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hahahawu/Long-to-Short-via-Model-Merging/6bc078e4be5da9e6e58254653bf24497d3958f60/resource/rank_plot_7b.png -------------------------------------------------------------------------------- /src/evaluation/data_process.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import json 4 | from pathlib import Path 5 | 6 | import numpy as np 7 | import transformers 8 | import re 9 | 10 | from transformers import AutoTokenizer 11 | 12 | transformers.utils.logging.set_verbosity_error() 13 | 14 | 15 | def get_avg_length(model_path, result_path): 16 | result_path = Path(result_path) 17 | tokenizer = AutoTokenizer.from_pretrained(model_path) 18 | keywords = ['wait', 're-examine', 'double-check', 'let me check', 'recap', 'let me just verify', 'let me just check'] 19 | pattern = r'|'.join(re.escape(keyword) for keyword in keywords) 20 | 21 | for dataset_name in os.listdir(result_path): 22 | print(dataset_name) 23 | dataset_dir = result_path / dataset_name 24 | for file_name in os.listdir(dataset_dir): 25 | if not file_name.endswith('.jsonl'): 26 | continue 27 | print(f'\t{file_name}', end='') 28 | length_list = [] 29 | kw_freq_list = [] 30 | level_length_map = {} 31 | level_acc_map = {} 32 | level_reflection_map = {} 33 | reflection_cnt = 0 34 | with open(dataset_dir / file_name) as f: 35 | for line_data in f: 36 | line_data = json.loads(line_data) 37 | length = len(tokenizer(line_data['code'][0])['input_ids']) 38 | length_list.append(length) 39 | keywords_match = re.findall(pattern, line_data["code"][0], re.IGNORECASE) 40 | if len(keywords_match) > 0: 41 | kw_freq_list.append(len(keywords_match)) 42 | reflection_cnt += 1 43 | if 'level' in line_data: 44 | level_length_map.setdefault(line_data['level'], 
[]).append(length) 45 | level_acc_map.setdefault(line_data['level'], []).append(int(line_data["score"][0])) 46 | level_reflection_map.setdefault(line_data['level'], []).append(int(len(keywords_match) > 0)) 47 | print(f'\t{round(np.mean(length_list), 2)}[{reflection_cnt}]; [{round(np.mean(kw_freq_list), 1) if kw_freq_list else 0}]') 48 | if level_length_map: 49 | for level, level_length_list in sorted(level_length_map.items()): 50 | print( 51 | f'\t\tlevel-{level}: {np.mean(level_length_list)};\t{round(np.mean(level_acc_map[level]) * 100, 1)};\t[{round(np.mean(level_reflection_map[level]), 3)}]') 52 | 53 | 54 | if __name__ == '__main__': 55 | _, _model_path, _result_path = sys.argv 56 | get_avg_length(_model_path, _result_path) 57 | -------------------------------------------------------------------------------- /src/evaluation/l2s_eval.sh: -------------------------------------------------------------------------------- 1 | set -ex 2 | 3 | PROMPT_TYPE=$1 4 | MODEL_NAME_OR_PATH=$2 5 | MAX_TOKEN=$3 6 | NUM_SHOTS=$4 7 | DATASETS=$5 8 | OUTPUT_DIR=${MODEL_NAME_OR_PATH}/math_eval 9 | 10 | SPLIT="test" 11 | NUM_TEST_SAMPLE=-1 12 | 13 | # English open datasets 14 | DATA_NAME=${DATASETS} 15 | TOKENIZERS_PARALLELISM=false \ 16 | python3 -u math_eval.py \ 17 | --model_name_or_path ${MODEL_NAME_OR_PATH} \ 18 | --data_name ${DATA_NAME} \ 19 | --output_dir ${OUTPUT_DIR} \ 20 | --split ${SPLIT} \ 21 | --prompt_type ${PROMPT_TYPE} \ 22 | --num_test_sample ${NUM_TEST_SAMPLE} \ 23 | --seed 0 \ 24 | --temperature 0 \ 25 | --n_sampling 1 \ 26 | --max_tokens_per_call ${MAX_TOKEN} \ 27 | --top_p 1 \ 28 | --start 0 \ 29 | --end -1 \ 30 | --use_vllm \ 31 | --num_shots ${NUM_SHOTS} \ 32 | --save_outputs \ 33 | --overwrite \ 34 | 35 | # get response statistics 36 | python data_process.py ${MODEL_NAME_OR_PATH} "outputs/"${OUTPUT_DIR} 37 | -------------------------------------------------------------------------------- /src/main_merging.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | 4 | import torch 5 | from transformers import AutoModelForCausalLM, AutoTokenizer 6 | from merging_methods import MergingMethod 7 | 8 | 9 | def main(): 10 | models_to_merge = args.models_to_merge.split(",") 11 | print(f"Base model is: {args.base_model}") 12 | print(f"Models to be merged are: {models_to_merge}") 13 | print(f"Scaling coefficient is {args.scaling_coefficient}") 14 | device = "cuda" if args.use_gpu else "cpu" 15 | print(f"Merging conducted on {device}") 16 | base_model = AutoModelForCausalLM.from_pretrained(args.base_model, torch_dtype=torch.bfloat16).to(device) 17 | tokenizer = AutoTokenizer.from_pretrained(args.base_model) 18 | candidate_models = [] 19 | for model_to_merge in models_to_merge: 20 | candidate_models.append(AutoModelForCausalLM.from_pretrained(model_to_merge, torch_dtype=torch.bfloat16).to(device)) 21 | merging_engine = MergingMethod(merging_method_name=args.merge_method) 22 | if args.weight_mask_rates is not None: 23 | weight_mask_rates = args.weight_mask_rates.split(",") 24 | weight_mask_rates = [float(_) for _ in weight_mask_rates] 25 | else: 26 | weight_mask_rates = None 27 | exclude_param_names_regex = args.exclude_param_names_regex 28 | if args.exclude_param_names_regex: 29 | exclude_param_names_regex = args.exclude_param_names_regex.split(",") 30 | print(f"Following params are excluded: {exclude_param_names_regex}") 31 | merged_model = merging_engine.get_merged_model( 32 | merged_model=base_model, 33 | 
models_to_merge=candidate_models, 34 | exclude_param_names_regex=exclude_param_names_regex, 35 | param_value_mask_rate=args.param_value_mask_rate, 36 | scaling_coefficient=args.scaling_coefficient, 37 | mask_apply_method=args.mask_apply_method, 38 | weight_mask_rates=weight_mask_rates 39 | ) 40 | print(f"Saving model to {args.output_dir}") 41 | if not os.path.exists(args.output_dir): 42 | os.makedirs(args.output_dir, exist_ok=True) 43 | 44 | merged_model = merged_model.to(torch.bfloat16) 45 | merged_model.save_pretrained(args.output_dir) 46 | tokenizer.save_pretrained(args.output_dir) 47 | 48 | 49 | if __name__ == '__main__': 50 | arg_parser = argparse.ArgumentParser() 51 | arg_parser.add_argument("--merge_method", type=str, required=True, default="average_merging") 52 | arg_parser.add_argument("--output_dir", type=str) 53 | arg_parser.add_argument('--base_model', type=str, help='base model') 54 | arg_parser.add_argument("--models_to_merge", type=str, required=True) 55 | arg_parser.add_argument("--exclude_param_names_regex", type=str, default=[]) 56 | arg_parser.add_argument("--scaling_coefficient", type=float, default=1.0) 57 | arg_parser.add_argument("--param_value_mask_rate", type=float, default=0.8) 58 | arg_parser.add_argument("--use_gpu", action='store_true', default=False) 59 | arg_parser.add_argument("--mask_apply_method", type=str, default="average_merging") 60 | arg_parser.add_argument("--weight_mask_rates", type=str, default=None) 61 | args = arg_parser.parse_args() 62 | main() 63 | -------------------------------------------------------------------------------- /src/mask_weights_utils.py: -------------------------------------------------------------------------------- 1 | from tqdm import tqdm 2 | import torch 3 | import torch.nn as nn 4 | 5 | from utils import get_param_names_to_merge 6 | from task_vector import TaskVector 7 | 8 | 9 | def mask_input_with_mask_rate(input_tensor: torch.Tensor, mask_rate: float, use_rescale: bool, mask_strategy: str): 10 | """ 11 | mask the input with mask rate 12 | :param input_tensor: Tensor, input tensor 13 | :param mask_rate: float, mask rate 14 | :param use_rescale: boolean, whether to rescale the input by 1 / (1 - mask_rate) 15 | :param mask_strategy: str, mask strategy, can be "random" and "magnitude" 16 | :return: 17 | """ 18 | assert 0.0 <= mask_rate <= 1.0, f"wrong range of mask_rate {mask_rate}, should be [0.0, 1.0]!" 19 | if mask_strategy == "random": 20 | mask = torch.bernoulli(torch.full_like(input=input_tensor, fill_value=mask_rate)).to(input_tensor.device) 21 | masked_input_tensor = input_tensor * (1 - mask) 22 | else: 23 | assert mask_strategy == "magnitude", f"wrong setting for mask_strategy {mask_strategy}!" 
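        # Magnitude masking keeps the largest-|value| entries deterministically:
        # the code below flattens the tensor, finds the mask_rate-quantile
        # threshold via kthvalue, and zeros every entry at or below it
        # (in contrast to the random Bernoulli drop used by DARE above).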
24 | original_shape = input_tensor.shape 25 | input_tensor = input_tensor.flatten() 26 | num_mask_params = int(len(input_tensor) * mask_rate) 27 | # Tensor, shape (1, ), find the num_mask_params-th smallest magnitude element of all the parameters in the model 28 | kth_values, _ = input_tensor.abs().kthvalue(k=num_mask_params, dim=0, keepdim=True) 29 | # Tensor, shape (num_total_params, ), where True is for parameters that we want to perform mask 30 | mask = input_tensor.abs() <= kth_values 31 | masked_input_tensor = input_tensor * (~mask) 32 | masked_input_tensor = masked_input_tensor.reshape(original_shape) 33 | if use_rescale and mask_rate != 1.0: 34 | masked_input_tensor = torch.div(input=masked_input_tensor, other=1 - mask_rate) 35 | return masked_input_tensor 36 | 37 | 38 | def mask_model_weights(finetuned_model: nn.Module, pretrained_model: nn.Module, exclude_param_names_regex: list, weight_format: str, 39 | weight_mask_rate: float, use_weight_rescale: bool, mask_strategy: str): 40 | """ 41 | mask model weights 42 | :param finetuned_model: nn.Module, the finetuned model 43 | :param pretrained_model: nn.Module, the pretrained model 44 | :param exclude_param_names_regex: list, regular expression of names of parameters that need to be excluded 45 | :param weight_format: str, the format of weights to be masked, can be "finetuned_weight" and "delta_weight" 46 | :param weight_mask_rate: float, weight mask rate 47 | :param use_weight_rescale: boolean, whether to rescale the weight by 1 / (1 - weight_mask_rate) 48 | :param mask_strategy: str, mask strategy, can be "random" and "magnitude" 49 | :return: 50 | """ 51 | # get weights that need to be masked 52 | if weight_format == "finetuned_weight": 53 | param_dict = {param_name: param_value for param_name, param_value in finetuned_model.named_parameters()} 54 | # exclude parameter whose name matches element in exclude_param_names_regex 55 | param_names_to_merge = get_param_names_to_merge(input_param_names=list(param_dict.keys()), exclude_param_names_regex=exclude_param_names_regex) 56 | model_param_dict = {param_name: param_dict[param_name] for param_name in param_names_to_merge} 57 | else: 58 | assert weight_format == "delta_weight", f"wrong setting for weight_format {weight_format}!" 
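        # "delta_weight" masks the task vector (finetuned minus pretrained)
        # instead of the raw finetuned weights; after masking (and optional
        # rescaling by 1 / (1 - mask_rate)), the masked delta is added back
        # onto the pretrained model below with a scaling coefficient of 1.0.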
59 |         task_vector = TaskVector(pretrained_model=pretrained_model, finetuned_model=finetuned_model, exclude_param_names_regex=exclude_param_names_regex)
60 |         model_param_dict = task_vector.task_vector_param_dict
61 | 
62 |     with torch.no_grad():
63 |         masked_param_dict = {}
64 |         for param_name, param_value in tqdm(model_param_dict.items()):
65 |             masked_param_dict[param_name] = mask_input_with_mask_rate(input_tensor=param_value, mask_rate=weight_mask_rate,
66 |                                                                       use_rescale=use_weight_rescale, mask_strategy=mask_strategy)
67 | 
68 |     if weight_format == "delta_weight":
69 |         new_task_vector = TaskVector(task_vector_param_dict=masked_param_dict)
70 |         # combine with parameters of the merged model based on scaling coefficient
71 |         masked_param_dict = new_task_vector.combine_with_pretrained_model(pretrained_model=pretrained_model, scaling_coefficient=1.0)
72 | 
73 |     return masked_param_dict
74 | 
--------------------------------------------------------------------------------
/src/merging_methods.py:
--------------------------------------------------------------------------------
 1 | from collections import defaultdict, OrderedDict
 2 | from tqdm import tqdm
 3 | import copy
 4 | import torch
 5 | import torch.nn as nn
 6 | import re
 7 | import gc
 8 | from task_vector import TaskVector
 9 | from utils import get_param_names_to_merge, get_modules_to_merge
10 | from mask_weights_utils import mask_model_weights
11 | 
12 | 
13 | class MergingMethod:
14 |     def __init__(self, merging_method_name: str):
15 |         """
16 |         Methods for model merging.
17 |         :param merging_method_name: str, name of the merging method, can be "average_merging", "task_arithmetic",
18 |         "ties_merging", "ties_merging_dare" or "mask_merging"
19 |         :return:
20 |         """
21 |         self.merging_method_name = merging_method_name
22 | 
23 |     def copy_params_to_model(self, params: dict, model: nn.Module):
24 |         """
25 |         copy parameters in "params" to the model
26 |         :param params: dict, dictionary of parameters
27 |         :param model: nn.Module, model that needs to copy parameters
28 |         :return:
29 |         """
30 |         for param_name, param_value in model.named_parameters():
31 |             if param_name in params:
32 |                 param_value.data.copy_(params[param_name])
33 | 
34 |     def average_merging(self, models_to_merge: list, exclude_param_names_regex: list):
35 |         """
36 |         average merging method
37 |         :param models_to_merge: list, individual models that need to be merged
38 |         :param exclude_param_names_regex: list, regular expression of names of parameters that need to be excluded
39 |         :return:
40 |         """
41 |         # dictionary of list, where key is the parameter name,
42 |         # value is a list of the corresponding parameters of all the models that need to be merged
43 |         models_to_merge_param_dict = defaultdict(list)
44 |         # iterate each individual model that needs to be merged
45 |         for model_to_merge in models_to_merge:
46 |             param_dict = {param_name: param_value for param_name, param_value in model_to_merge.named_parameters()}
47 |             # exclude parameter whose name matches element in exclude_param_names_regex
48 |             param_names_to_merge = get_param_names_to_merge(input_param_names=list(param_dict.keys()),
49 |                                                             exclude_param_names_regex=exclude_param_names_regex)
50 |             for param_name in param_names_to_merge:
51 |                 models_to_merge_param_dict[param_name].append(param_dict[param_name])
52 | 
53 |         with torch.no_grad():
54 |             # average merging of individual models' parameters
55 |             averaged_params = {param_name: torch.stack(model_to_merge_param, dim=0).mean(dim=0) for
56 |                                param_name, model_to_merge_param in models_to_merge_param_dict.items()}
57 | 
58
| return averaged_params 59 | 60 | def task_arithmetic(self, merged_model: nn.Module, models_to_merge: list, exclude_param_names_regex: list, 61 | scaling_coefficient: float = 1.0): 62 | """ 63 | task arithmetic method 64 | :param merged_model: nn.Module, the merged model 65 | :param models_to_merge: list, individual models that need to be merged 66 | :param exclude_param_names_regex: list, regular expression of names of parameters that need to be excluded 67 | :param scaling_coefficient: float, scaling coefficient to merge the task vectors 68 | :return: 69 | """ 70 | assert isinstance(scaling_coefficient, float), "wrong type of scaling_coefficient, should be float!" 71 | 72 | merged_task_vector = None 73 | while len(models_to_merge): 74 | if merged_task_vector is None: 75 | merged_task_vector = TaskVector(pretrained_model=merged_model, finetuned_model=models_to_merge.pop(0), 76 | exclude_param_names_regex=exclude_param_names_regex) 77 | else: 78 | merged_task_vector += TaskVector(pretrained_model=merged_model, finetuned_model=models_to_merge.pop(0), 79 | exclude_param_names_regex=exclude_param_names_regex) 80 | merged_params = merged_task_vector.combine_with_pretrained_model(pretrained_model=merged_model, 81 | scaling_coefficient=scaling_coefficient) 82 | 83 | return merged_params 84 | 85 | def ties_merging(self, merged_model: nn.Module, 86 | models_to_merge: list, 87 | exclude_param_names_regex: list, 88 | param_value_mask_rate: float = 0.8, 89 | scaling_coefficient: float = 1.0): 90 | """ 91 | Optimized ties merging method that minimizes memory usage by processing parameters individually. 92 | 93 | Instead of flattening all parameters across the entire model (which creates huge temporary tensors), 94 | this implementation iterates over each parameter tensor (key by key). For each parameter: 95 | 1. It masks the smallest (by absolute value) elements rel. to the local tensor (using a fraction 96 | param_value_mask_rate). Note that this is a per-tensor approximation versus the original global 97 | thresholding, but saves a lot of memory. 98 | 2. It computes an aggregated sign (using the sign of the sum of the masked tensors) and then 99 | preserves only the elements whose sign agrees with this aggregated sign. 100 | 3. It averages the preserved values from the different models. 101 | Finally, the merged task vector is combined with the pretrained model (merged_model) using the 102 | provided scaling_coefficient. 103 | 104 | :param merged_model: nn.Module, the baseline (pre-trained) model. 105 | :param models_to_merge: list of models (e.g. finetuned versions) whose task vectors will be merged. 106 | :param exclude_param_names_regex: list of regex strings to exclude certain parameter names. 107 | :param param_value_mask_rate: float, fraction of values (per parameter tensor) to mask (set to 0) based on their magnitude. 108 | :param scaling_coefficient: float, scaling coefficient used when combining the task vector with the pretrained model. 109 | :return: nn.Module, merged model 110 | """ 111 | # Create TaskVector objects for each model to merge. 112 | # (If the TaskVector implementation does deep copies, one might be able to optimize that separately.) 113 | task_vectors = [TaskVector(pretrained_model=merged_model, 114 | finetuned_model=model, 115 | exclude_param_names_regex=exclude_param_names_regex) 116 | for model in models_to_merge] 117 | 118 | # Use the parameter keys from the first TaskVector; we assume all task_vectors have the same keys. 
119 | sorted_keys = sorted(task_vectors[0].task_vector_param_dict.keys()) 120 | merged_task_vector_param_dict = OrderedDict() 121 | 122 | # Process each parameter key individually. 123 | for key in sorted_keys: 124 | # Get parameter tensors from each task vector (for each model) 125 | param_list = [tv.task_vector_param_dict[key] for tv in task_vectors] 126 | 127 | masked_params = [] 128 | # For each parameter tensor, perform local masking. 129 | for param in param_list: 130 | # Flatten the parameter tensor 131 | param_flat = param.view(-1) 132 | num_params = param_flat.numel() 133 | # Determine how many elements to mask in this tensor. 134 | k = int(num_params * param_value_mask_rate) 135 | if k > 0: 136 | # Compute kth smallest (by absolute value) threshold. 137 | # kthvalue returns a namedtuple (values, indices) 138 | kth_val = param_flat.abs().kthvalue(k=k).values 139 | # Create a mask: keep elements with absolute values >= kth_val. 140 | mask = param.abs() >= kth_val 141 | # Multiply (elementwise) to zero out the smallest ones. 142 | masked_param = param * mask.to(param.dtype) 143 | else: 144 | masked_param = param 145 | masked_params.append(masked_param) 146 | 147 | # Compute an aggregated sign per element. 148 | # This mirrors the behavior: aggregated_sign = sign(sum over models) 149 | summed = sum(masked_params) 150 | aggregated_sign = torch.sign(summed) 151 | # For any element where the sign is zero, default to positive (1.0) 152 | aggregated_sign[aggregated_sign == 0] = 1.0 153 | 154 | # For each model’s masked parameter, keep only elements whose sign 155 | # matches the aggregated sign. 156 | preserved_params = [] 157 | for mp in masked_params: 158 | # Create a boolean mask for matching signs. 159 | sign_mask = (((aggregated_sign > 0) & (mp > 0)) | 160 | ((aggregated_sign < 0) & (mp < 0))).to(mp.dtype) 161 | preserved = mp * sign_mask 162 | preserved_params.append(preserved) 163 | 164 | # Count how many models preserved a nonzero element for each coordinate. 165 | count_preserved = sum([(p != 0).float() for p in preserved_params]) 166 | # Compute the merged parameter (average the preserved contributions) 167 | merged_param = sum(preserved_params) / torch.clamp(count_preserved, min=1.0) 168 | merged_task_vector_param_dict[key] = merged_param 169 | 170 | # Build the merged task vector from the merged parameter dictionary. 171 | merged_task_vector = TaskVector(task_vector_param_dict=merged_task_vector_param_dict) 172 | # Combine with the base (pretrained) model using the provided scaling coefficient. 
173 | return merged_task_vector.combine_with_pretrained_model(pretrained_model=merged_model, 174 | scaling_coefficient=scaling_coefficient) 175 | 176 | def ties_merging_dare(self, merged_model: nn.Module, models_to_merge: list, exclude_param_names_regex: list, 177 | param_value_mask_rate: float = 0.8, scaling_coefficient: float = 1.0): 178 | """ 179 | ties merging method 180 | :param merged_model: nn.Module, the merged model 181 | :param models_to_merge: list, individual models that need to be merged 182 | :param exclude_param_names_regex: list, regular expression of names of parameters that need to be excluded 183 | :param param_value_mask_rate: float, mask rate of the smallest-magnitude parameter values 184 | :param scaling_coefficient: float, scaling coefficient to merge the task vectors 185 | :return: 186 | """ 187 | 188 | def task_vector_param_dict_to_single_vector(task_vector: TaskVector): 189 | """ 190 | convert parameter dictionary in task vector to a single vector 191 | :param task_vector: TaskVector, task vector 192 | :return: 193 | """ 194 | task_vector_param_dict = copy.deepcopy(task_vector.task_vector_param_dict) 195 | sorted_task_vector_param_dict = OrderedDict(sorted(task_vector_param_dict.items())) 196 | del task_vector_param_dict 197 | 198 | # Tensor, shape (num_total_params, ) 199 | return nn.utils.parameters_to_vector([param.flatten() for param in sorted_task_vector_param_dict.values()]) 200 | 201 | def single_vector_to_task_vector_param_dict(single_vector: torch.Tensor, task_vector: TaskVector): 202 | """ 203 | convert a single vector to parameter dictionary in task vector 204 | :param single_vector: Tensor, single vector that contain all parameters in task_vector.task_vector_param_dict 205 | :param task_vector: TaskVector, task vector 206 | :return: 207 | """ 208 | task_vector_param_dict = copy.deepcopy(task_vector.task_vector_param_dict) 209 | sorted_task_vector_param_dict = OrderedDict(sorted(task_vector_param_dict.items())) 210 | del task_vector_param_dict 211 | 212 | nn.utils.vector_to_parameters(single_vector, sorted_task_vector_param_dict.values()) 213 | 214 | return sorted_task_vector_param_dict 215 | 216 | def mask_smallest_magnitude_param_values(flattened_models_to_merge_param: torch.Tensor, 217 | param_value_mask_rate: float = 0.8): 218 | """ 219 | mask the smallest-magnitude parameter values (set to zeros) based on parameter value mask rate 220 | :param flattened_models_to_merge_param: Tensor, shape (num_models_to_merge, num_total_params), flattened parameters of individual models that need to be merged 221 | :param param_value_mask_rate: float, mask rate of the smallest-magnitude parameter values 222 | :return: 223 | """ 224 | # num_models_to_merge, num_total_params = flattened_models_to_merge_param.shape 225 | num_mask_params = int(flattened_models_to_merge_param.shape[1] * param_value_mask_rate) 226 | 227 | # Tensor, shape (num_models_to_merge, 1), find the num_mask_params-th smallest magnitude element of all the parameters in each individual model 228 | kth_values, _ = flattened_models_to_merge_param.abs().kthvalue(k=num_mask_params, dim=1, keepdim=True) 229 | # Tensor, shape (num_models_to_merge, num_total_params), where True is for parameters that we want to preserve 230 | mask = flattened_models_to_merge_param.abs() >= kth_values 231 | del kth_values 232 | flattened_models_to_merge_param = flattened_models_to_merge_param * mask 233 | del mask 234 | 235 | return flattened_models_to_merge_param 236 | 237 | def 
get_param_signs(flattened_models_to_merge_param: torch.Tensor): 238 | """ 239 | get the signs for each parameter in flattened_models_to_merge_param, computed over individual models that need to be merged 240 | :param flattened_models_to_merge_param: Tensor, shape (num_models_to_merge, num_total_params), flattened parameters of individual models that need to be merged 241 | :return: 242 | """ 243 | # Tensor, shape (num_total_params, ), the signs of parameters aggregated across individual models that need to be merged 244 | param_signs = torch.sign(flattened_models_to_merge_param.sum(dim=0)) 245 | # Tensor, shape (, ), a scalar, replace 0 in param_signs to the major sign in param_signs 246 | majority_sign = torch.sign(param_signs.sum(dim=0)) 247 | param_signs[param_signs == 0] = majority_sign 248 | return param_signs 249 | 250 | def disjoint_merge(flattened_models_to_merge_param: torch.Tensor, param_signs: torch.Tensor): 251 | """ 252 | disjoint merge that only keeps the parameter values in individual models whose signs are the same as the param_signs, and calculates the averaged parameters. 253 | :param flattened_models_to_merge_param: Tensor, shape (num_models_to_merge, num_total_params), flattened parameters of individual models that need to be merged 254 | :param param_signs: Tensor, shape (num_total_params, ), the signs of parameters aggregated across individual models that need to be merged 255 | :return: 256 | """ 257 | # Tensor, shape (num_models_to_merge, num_total_params), where True is for parameters that we want to preserve 258 | param_to_preserve_mask = ((param_signs.unsqueeze(dim=0) > 0) & (flattened_models_to_merge_param > 0)) | ( 259 | (param_signs.unsqueeze(dim=0) < 0) & (flattened_models_to_merge_param < 0)) 260 | # Tensor, shape (num_models_to_merge, num_total_params), the preserved parameters 261 | param_to_preserve = flattened_models_to_merge_param * param_to_preserve_mask 262 | del param_to_preserve_mask 263 | 264 | # Tensor, shape (num_total_params, ), the number of models whose parameters can be preserved 265 | num_models_param_preserved = (param_to_preserve != 0).sum(dim=0).float() 266 | # Tensor, shape (num_total_params, ), the averaged flattened parameters 267 | merged_flattened_param = torch.sum(param_to_preserve, dim=0) / torch.clamp(num_models_param_preserved, 268 | min=1.0) 269 | del param_to_preserve 270 | 271 | return merged_flattened_param 272 | 273 | assert isinstance(scaling_coefficient, float), "wrong type of scaling_coefficient, should be float!" 
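        # Unlike the per-tensor `ties_merging` above, this original MergeLM
        # implementation flattens all task vectors into one
        # (num_models_to_merge, num_total_params) matrix, so trimming and
        # sign election are computed globally across the whole model at the
        # cost of large temporary tensors.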
274 | 275 | models_to_merge_task_vectors = [TaskVector(pretrained_model=merged_model, finetuned_model=model_to_merge, 276 | exclude_param_names_regex=exclude_param_names_regex) for 277 | model_to_merge in models_to_merge] 278 | del models_to_merge 279 | 280 | flattened_models_to_merge_param = [task_vector_param_dict_to_single_vector(task_vector=task_vector) for 281 | task_vector in models_to_merge_task_vectors] 282 | models_to_merge_task_vectors = models_to_merge_task_vectors[0] 283 | # Tensor, shape (num_models_to_merge, num_total_params), flattened parameters of individual models that need to be merged 284 | flattened_models_to_merge_param = torch.vstack(flattened_models_to_merge_param) 285 | 286 | with torch.no_grad(): 287 | # Tensor, shape (num_models_to_merge, num_total_params), mask the smallest-magnitude parameter values using param_value_mask_rate 288 | flattened_models_to_merge_param = mask_smallest_magnitude_param_values( 289 | flattened_models_to_merge_param=flattened_models_to_merge_param, 290 | param_value_mask_rate=param_value_mask_rate) 291 | 292 | # Tensor, shape (num_total_params, ), get the signs for each parameter in flattened_models_to_merge_param 293 | param_signs = get_param_signs(flattened_models_to_merge_param=flattened_models_to_merge_param) 294 | 295 | # Tensor, shape (num_total_params, ), disjoint merge 296 | merged_flattened_param = disjoint_merge(flattened_models_to_merge_param=flattened_models_to_merge_param, 297 | param_signs=param_signs) 298 | del flattened_models_to_merge_param, param_signs 299 | 300 | # merged parameter dictionary 301 | merged_task_vector_param_dict = single_vector_to_task_vector_param_dict( 302 | single_vector=merged_flattened_param, task_vector=models_to_merge_task_vectors) 303 | merged_task_vector = TaskVector(task_vector_param_dict=merged_task_vector_param_dict) 304 | del merged_task_vector_param_dict 305 | # combine with parameters of the merged model based on scaling coefficient 306 | merged_model = merged_task_vector.combine_with_pretrained_model(pretrained_model=merged_model, 307 | scaling_coefficient=scaling_coefficient) 308 | 309 | return merged_model 310 | 311 | def merging_models(self, merged_model: nn.Module, models_to_merge: list, exclude_param_names_regex: list, 312 | scaling_coefficient: float = 1.0, 313 | param_value_mask_rate: float = 0.8, 314 | weight_format: str = "delta_weight", weight_mask_rates: list = None, 315 | use_weight_rescale: bool = True, mask_strategy: str = "random", 316 | mask_apply_method: str = "average_merging", models_use_deepcopy: bool = False): 317 | """ 318 | model merging methods 319 | :param merged_model: nn.Module, the merged model 320 | :param models_to_merge: list, individual models that need to be merged 321 | :param exclude_param_names_regex: list, regular expression of names of parameters that need to be excluded 322 | :param scaling_coefficient: float, scaling coefficient to merge the task vectors 323 | :param param_value_mask_rate: float, mask rate of the smallest-magnitude parameter values 324 | :param weight_format: str, the format of weights to be masked, can be "finetuned_weight" and "delta_weight" 325 | :param weight_mask_rates: list, list of weight mask rates 326 | :param use_weight_rescale: boolean, whether to rescale the weight by 1 / (1 - weight_mask_rate) 327 | :param mask_strategy: str, mask strategy, can be "random" and "magnitude" 328 | :param mask_apply_method: str, merging method that the mask strategy applies 329 | :param models_use_deepcopy: boolean, whether to deepcopy 
the models 330 | :return: 331 | """ 332 | if self.merging_method_name == "average_merging": 333 | merged_params = self.average_merging(models_to_merge=models_to_merge, 334 | exclude_param_names_regex=exclude_param_names_regex) 335 | elif self.merging_method_name == "task_arithmetic": 336 | merged_params = self.task_arithmetic(merged_model=merged_model, models_to_merge=models_to_merge, 337 | exclude_param_names_regex=exclude_param_names_regex, 338 | scaling_coefficient=scaling_coefficient) 339 | elif self.merging_method_name == "ties_merging": 340 | merged_params = self.ties_merging(merged_model=merged_model, models_to_merge=models_to_merge, 341 | exclude_param_names_regex=exclude_param_names_regex, 342 | param_value_mask_rate=param_value_mask_rate, 343 | scaling_coefficient=scaling_coefficient) 344 | elif self.merging_method_name == "ties_merging_dare": 345 | merged_params = self.ties_merging_dare(merged_model=merged_model, models_to_merge=models_to_merge, 346 | exclude_param_names_regex=exclude_param_names_regex, 347 | param_value_mask_rate=param_value_mask_rate, 348 | scaling_coefficient=scaling_coefficient) 349 | elif self.merging_method_name == "mask_merging": 350 | with torch.no_grad(): 351 | if models_use_deepcopy: 352 | new_models_to_merge = copy.deepcopy(models_to_merge) 353 | else: 354 | new_models_to_merge = models_to_merge 355 | for new_model_to_merge, weight_mask_rate in zip(new_models_to_merge, weight_mask_rates): 356 | # for each individual model, mask its weight 357 | masked_param_dict = mask_model_weights(finetuned_model=new_model_to_merge, 358 | pretrained_model=merged_model, 359 | exclude_param_names_regex=exclude_param_names_regex, 360 | weight_format=weight_format, 361 | weight_mask_rate=weight_mask_rate, 362 | use_weight_rescale=use_weight_rescale, 363 | mask_strategy=mask_strategy) 364 | self.copy_params_to_model(params=masked_param_dict, model=new_model_to_merge) 365 | if mask_apply_method == "average_merging": 366 | merged_params = self.average_merging(models_to_merge=new_models_to_merge, 367 | exclude_param_names_regex=exclude_param_names_regex) 368 | elif mask_apply_method == "task_arithmetic": 369 | merged_params = self.task_arithmetic(merged_model=merged_model, models_to_merge=new_models_to_merge, 370 | exclude_param_names_regex=exclude_param_names_regex, 371 | scaling_coefficient=scaling_coefficient) 372 | elif mask_apply_method == "ties_merging": 373 | merged_params = self.ties_merging(merged_model=merged_model, models_to_merge=new_models_to_merge, 374 | exclude_param_names_regex=exclude_param_names_regex, 375 | param_value_mask_rate=param_value_mask_rate, 376 | scaling_coefficient=scaling_coefficient) 377 | elif mask_apply_method == "ties_merging_dare": 378 | merged_params = self.ties_merging_dare(merged_model=merged_model, models_to_merge=new_models_to_merge, 379 | exclude_param_names_regex=exclude_param_names_regex, 380 | param_value_mask_rate=param_value_mask_rate, 381 | scaling_coefficient=scaling_coefficient) 382 | else: 383 | raise NotImplementedError(f"unsupported for mask_apply_method {mask_apply_method}!") 384 | else: 385 | raise NotImplementedError(f"unsupported for merging_method_name {self.merging_method_name}!") 386 | return merged_params 387 | 388 | def get_merged_model(self, merged_model: nn.Module, models_to_merge: list, exclude_param_names_regex: list, 389 | scaling_coefficient: float = 1.0, 390 | param_value_mask_rate: float = 0.8, 391 | weight_format: str = "delta_weight", weight_mask_rates: list = None, 392 | use_weight_rescale: bool = 
True, mask_strategy: str = "random", 393 | mask_apply_method: str = "average_merging", models_use_deepcopy: bool = False): 394 | """ 395 | merge the parameters of models_to_merge to merged_model 396 | :param merged_model: nn.Module, the merged model 397 | :param models_to_merge: list, individual models that need to be merged 398 | :param exclude_param_names_regex: list, regular expression of names of parameters that need to be excluded 399 | :param scaling_coefficient: float, scaling coefficient to merge the task vectors 400 | :param param_value_mask_rate: float, mask rate of the smallest-magnitude parameter values 401 | :param weight_format: str, the format of weights to be masked, can be "finetuned_weight" and "delta_weight" 402 | :param weight_mask_rates: list, list of weight mask rates 403 | :param use_weight_rescale: boolean, whether to rescale the weight by 1 / (1 - weight_mask_rate) 404 | :param mask_strategy: str, mask strategy, can be "random" and "magnitude" 405 | :param mask_apply_method: str, merging method that the mask strategy applies 406 | :param models_use_deepcopy: boolean, whether to deepcopy the models 407 | :return: 408 | """ 409 | # merged_params, dict of parameters 410 | merged_params = self.merging_models(merged_model=merged_model, models_to_merge=models_to_merge, 411 | exclude_param_names_regex=exclude_param_names_regex, 412 | scaling_coefficient=scaling_coefficient, 413 | param_value_mask_rate=param_value_mask_rate, 414 | weight_format=weight_format, weight_mask_rates=weight_mask_rates, 415 | use_weight_rescale=use_weight_rescale, mask_strategy=mask_strategy, 416 | mask_apply_method=mask_apply_method, 417 | models_use_deepcopy=models_use_deepcopy) 418 | self.copy_params_to_model(params=merged_params, model=merged_model) 419 | 420 | return merged_model 421 | -------------------------------------------------------------------------------- /src/task_vector.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from utils import get_param_names_to_merge 5 | 6 | 7 | class TaskVector: 8 | def __init__(self, pretrained_model: nn.Module = None, finetuned_model: nn.Module = None, exclude_param_names_regex: list = None, task_vector_param_dict: dict = None): 9 | """ 10 | Task vector. Initialize the task vector from a pretrained model and a finetuned model, or 11 | directly passing the task_vector_param_dict dictionary. 
12 | :param pretrained_model: nn.Module, pretrained model 13 | :param finetuned_model: nn.Module, finetuned model 14 | :param exclude_param_names_regex: list, regular expression of names of parameters that need to be excluded 15 | :param task_vector_param_dict: dict, task vector to initialize self.task_vector_param_dict 16 | """ 17 | if task_vector_param_dict is not None: 18 | self.task_vector_param_dict = task_vector_param_dict 19 | else: 20 | self.task_vector_param_dict = {} 21 | pretrained_param_dict = {param_name: param_value for param_name, param_value in pretrained_model.named_parameters()} 22 | finetuned_param_dict = {param_name: param_value for param_name, param_value in finetuned_model.named_parameters()} 23 | param_names_to_merge = get_param_names_to_merge(input_param_names=list(pretrained_param_dict.keys()), exclude_param_names_regex=exclude_param_names_regex) 24 | with torch.no_grad(): 25 | for param_name in param_names_to_merge: 26 | self.task_vector_param_dict[param_name] = finetuned_param_dict[param_name] - pretrained_param_dict[param_name] 27 | 28 | def __add__(self, other): 29 | """ 30 | add task vector 31 | :param other: TaskVector to add, at right side 32 | :return: 33 | """ 34 | assert isinstance(other, TaskVector), "addition of TaskVector can only be done with another TaskVector!" 35 | new_task_vector_param_dict = {} 36 | with torch.no_grad(): 37 | for param_name in self.task_vector_param_dict: 38 | assert param_name in other.task_vector_param_dict.keys(), f"param_name {param_name} is not contained in both task vectors!" 39 | new_task_vector_param_dict[param_name] = self.task_vector_param_dict[param_name] + other.task_vector_param_dict[param_name] 40 | return TaskVector(task_vector_param_dict=new_task_vector_param_dict) 41 | 42 | def __radd__(self, other): 43 | """ 44 | other + self = self + other 45 | :param other: TaskVector to add, at left side 46 | :return: 47 | """ 48 | return self.__add__(other) 49 | 50 | def combine_with_pretrained_model(self, pretrained_model: nn.Module, scaling_coefficient: float = 1.0): 51 | """ 52 | combine the task vector with pretrained model 53 | :param pretrained_model: nn.Module, pretrained model 54 | :param scaling_coefficient: float, scaling coefficient to merge the task vector 55 | :return: 56 | """ 57 | pretrained_param_dict = {param_name: param_value for param_name, param_value in pretrained_model.named_parameters()} 58 | 59 | with torch.no_grad(): 60 | merged_params = {} 61 | for param_name in self.task_vector_param_dict: 62 | merged_params[param_name] = pretrained_param_dict[param_name] + scaling_coefficient * self.task_vector_param_dict[param_name] 63 | 64 | return merged_params 65 | 66 | -------------------------------------------------------------------------------- /src/utils.py: -------------------------------------------------------------------------------- 1 | import re 2 | import os 3 | from typing import Dict 4 | import random 5 | import numpy as np 6 | import torch 7 | import torch.nn as nn 8 | import transformers 9 | from transformers import Trainer, TrainerState 10 | 11 | 12 | def set_random_seed(seed: int = 0): 13 | """ 14 | set random seed 15 | :param seed: int, random seed 16 | :return: 17 | """ 18 | random.seed(seed) 19 | np.random.seed(seed) 20 | torch.manual_seed(seed) 21 | if torch.cuda.is_available(): 22 | torch.cuda.manual_seed_all(seed) 23 | torch.backends.cudnn.deterministic = True 24 | torch.backends.cudnn.benchmark = False 25 | 26 | 27 | def save_state_and_model_for_hf_trainer(trainer: Trainer): 28 | """ 
29 | save the state and model for trainer 30 | :param trainer: transformers.Trainer to be saved 31 | :return: 32 | """ 33 | # save trainer state at trainer.args.output_dir path 34 | trainer.save_state() 35 | # save model at output_dir 36 | if trainer.args.should_save: 37 | # convert state_dict to cpu 38 | cpu_state_dict = {key: value.cpu() for key, value in trainer.model.state_dict().items()} 39 | trainer._save(trainer.args.output_dir, state_dict=cpu_state_dict) 40 | 41 | 42 | def load_state_and_model_for_hf_trainer(model: nn.Module, load_model_dir: str, map_location: str = None): 43 | """ 44 | load the state and model for trainer 45 | :param model: nn.Module, the model to be loaded 46 | :param load_model_dir: str, the path where the state and model to be loaded 47 | :param map_location: str, how to remap the storage locations 48 | :return: 49 | """ 50 | # load model and trainer state from load_model_dir 51 | model.load_state_dict(torch.load(os.path.join(load_model_dir, "pytorch_model.bin"), map_location=map_location)) 52 | # model = model.from_pretrained(load_model_dir) 53 | trainer_state = TrainerState.load_from_json(os.path.join(load_model_dir, "trainer_state.json")) 54 | return model, trainer_state 55 | 56 | 57 | def get_param_names_to_merge(input_param_names: list, exclude_param_names_regex: list): 58 | """ 59 | get the names of parameters that need to be merged 60 | :param input_param_names: list, names of input parameters 61 | :param exclude_param_names_regex: list, regular expression of names of parameters that need to be excluded 62 | :return: 63 | """ 64 | param_names_to_merge = [] 65 | for param_name in input_param_names: 66 | exclude = any([re.match(exclude_pattern, param_name) for exclude_pattern in exclude_param_names_regex]) 67 | if not exclude: 68 | param_names_to_merge.append(param_name) 69 | return param_names_to_merge 70 | 71 | 72 | def get_modules_to_merge(model: nn.Module, include_module_types: list): 73 | """ 74 | get the model modules that need to be merged, whose type is in include_module_types 75 | :param model: nn.Module, input model 76 | :param include_module_types: list, module types that want to include 77 | :return: 78 | """ 79 | modules_to_merge = {} 80 | for module_name, module in model.named_modules(): 81 | is_valid_type = not include_module_types or any([isinstance(module, include_module_type) for include_module_type in include_module_types]) 82 | if is_valid_type: 83 | modules_to_merge[module_name] = module 84 | return modules_to_merge 85 | 86 | 87 | def smart_tokenizer_and_embedding_resize( 88 | special_tokens_dict: Dict, 89 | tokenizer: transformers.PreTrainedTokenizer, 90 | model: transformers.PreTrainedModel, 91 | ): 92 | """Resize tokenizer and embedding. 93 | 94 | Note: This is the unoptimized version that may make your embedding size not be divisible by 64. 
95 | """ 96 | assert tokenizer.vocab_size == 32000 97 | num_new_tokens = tokenizer.add_special_tokens(special_tokens_dict) 98 | if num_new_tokens > 0: 99 | model.resize_token_embeddings(tokenizer.vocab_size + num_new_tokens) 100 | 101 | input_embeddings = model.get_input_embeddings().weight.data 102 | output_embeddings = model.get_output_embeddings().weight.data 103 | 104 | input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True) 105 | output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True) 106 | 107 | input_embeddings[-num_new_tokens:] = input_embeddings_avg 108 | output_embeddings[-num_new_tokens:] = output_embeddings_avg 109 | --------------------------------------------------------------------------------