├── LICENSE
├── README.md
├── configs
│   ├── bert_books.py
│   └── bert_classifier.py
├── evaluation
│   ├── eval_glue.py
│   ├── graphing_glue.py
│   └── graphing_mlm.py
├── experiment
│   ├── merge_bert_classifiers.py
│   ├── run_component_experiments.sh
│   ├── run_data_ablation.sh
│   ├── run_glue_experiments.sh
│   ├── run_mha_experiments.sh
│   └── run_res_experiments.sh
├── graphs
│   ├── base_graph.py
│   └── transformer_enc_graph.py
├── matching_functions.py
├── metric_calculators.py
├── model_merger.py
├── my_datasets
│   ├── books.py
│   ├── configs.py
│   ├── glue.py
│   └── sample_books_corpus.py
├── overview.png
├── requirements.txt
├── training
│   ├── finetune_glue.sh
│   └── run_glue.py
└── utils.py

/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 | 
3 | Copyright (c) 2024 Neha Verma, Maha Elbayad
4 | 
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Merging Text Transformer Models from Different Initializations
2 | 
3 | This repository contains the code for the paper **[Merging Text Transformer Models from Different Initializations](https://arxiv.org/pdf/2403.00986.pdf)** by [Neha Verma](https://nverma1.github.io/) and [Maha Elbayad](https://elbayadm.github.io/).
4 | 
5 | ![Perm Figure](overview.png)
6 | 
7 | ## Getting Started
8 | 
9 | ### Dependencies
10 | 
11 | We recommend creating a new virtual environment and installing the following dependencies:
12 | 
13 | ```
14 | pip install -r requirements.txt
15 | ```
16 | 
17 | 
18 | ### Data
19 | 
20 | Our main masked language modeling experiments are run on a subset of the [BooksCorpus](https://www.cv-foundation.org/openaccess/content_iccv_2015/papers/Zhu_Aligning_Books_and_ICCV_2015_paper.pdf), which was part of the original BERT training data. To extract a subset of the data, run the following command:
21 | 
22 | ```
23 | cd my_datasets
24 | python sample_books_corpus.py $PATH_TO_BOOKS_CORPUS
25 | ```
26 | 
27 | ## Experiments
28 | 
29 | ### Obtaining Models
30 | 
31 | We use several models from the [MultiBERTs](https://openreview.net/pdf?id=K0E_F0gFDgA) reproductions, accessible on [HuggingFace](https://huggingface.co/google/multiberts-seed_1) via the ``google/multiberts-seed_i`` paths, where ``i`` is the seed number. We also fine-tune these models on the GLUE benchmark for additional experiments, using the HuggingFace library.
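For example, a pair of checkpoints can be loaded directly from the Hub (a minimal sketch; the masked language modeling experiments use ``BertForMaskedLM``, and any two seed indices can be substituted):

```
from transformers import BertForMaskedLM, BertTokenizer

tokenizer = BertTokenizer.from_pretrained("google/multiberts-seed_0")
model_a = BertForMaskedLM.from_pretrained("google/multiberts-seed_0")
model_b = BertForMaskedLM.from_pretrained("google/multiberts-seed_1")
```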
To fine-tune these models, run the following command:
32 | 
33 | ```
34 | bash training/finetune_glue.sh $seed_index $task
35 | ```
36 | where ``task`` is one of the GLUE tasks in ``[mnli, cola, sst2, stsb, qnli, qqp, rte, mrpc]``, and ``seed_index`` is the index of the seed in ``1-5``.
37 | 
38 | ### Merging Models
39 | 
40 | To merge models, we provide the following scripts, which produce the merged models from our paper:
41 | 
42 | ```
43 | # merge models by feed-forward and/or attention components
44 | bash experiment/run_component_experiments.sh
45 | 
46 | # merge models using different attention merging algorithms
47 | bash experiment/run_mha_experiments.sh
48 | 
49 | # merge models using different residual merging algorithms
50 | bash experiment/run_res_experiments.sh
51 | 
52 | # merge models using different data amounts
53 | bash experiment/run_data_ablation.sh
54 | 
55 | # merge GLUE models trained in the previous step
56 | bash experiment/run_glue_experiments.sh
57 | 
58 | ```
59 | 
60 | Finally, many other merging experiments can be run using ``experiment/merge_bert_classifiers.py``, potentially after adding new configs to the ``configs/`` directory.
61 | 
62 | ### Evaluation
63 | 
64 | To evaluate GLUE models before merging, run the following command:
65 | ```
66 | python evaluation/graphing_glue.py --task $TASK --originals --outfile vanilla_${TASK}
67 | ```
68 | where ``task`` is one of the GLUE tasks in ``[mnli, cola, sst2, stsb, qnli, qqp, rte, mrpc]``.
69 | 
70 | To evaluate GLUE models after merging, run the following:
71 | ```
72 | python evaluation/graphing_glue.py --task $TASK --path $PATH_TO_MERGED_MODELS --merge-type $MERGE_TYPE --outfile $OUTFILE
73 | ```
74 | where ``task`` is one of the GLUE tasks in ``[mnli, cola, sst2, stsb, qnli, qqp, rte, mrpc]`` and ``merge-type`` is one of the merge types in ``[res_only, attn_only, ff_only, ff+attn, res+attn, ff+res, all]``.
75 | 
76 | 
77 | To evaluate MLM models before merging, run the following command:
78 | ```
79 | python evaluation/graphing_mlm.py --originals --outfile vanilla_mlm
80 | ```
81 | To evaluate MLM models after merging, run the following:
82 | ```
83 | python evaluation/graphing_mlm.py --path $PATH_TO_MERGED_MODELS --merge-type $MERGE_TYPE --outfile $OUTFILE --train-frac $TRAIN_FRAC
84 | ```
85 | where ``merge-type`` is one of the merge types in ``[res_only, attn_only, ff_only, ff+attn, res+attn, ff+res, all]`` and ``train-frac`` is the fraction of the training data used in the MLM experiments.
86 | If a permutation was applied to the output projection, pass the relevant ``--unmerge`` flag to the command.
87 | 
88 | ## Citation
89 | 
90 | If you use this code, please cite our paper:
91 | 
92 | ```
93 | @article{verma2024merging,
94 |   title={Merging Text Transformer Models from Different Initializations},
95 |   author={Neha Verma and Maha Elbayad},
96 |   journal={arXiv preprint arXiv:2403.00986},
97 |   year={2024},
98 | }
99 | ```
100 | 
101 | ### Acknowledgements
102 | 
103 | We would like to acknowledge the authors of the [ZipIt!](https://github.com/gstoica27/ZipIt) codebase, which we use as a starting point for our repository. We also acknowledge the authors of the [BERT-similarity](https://github.com/twinkle0331/BERT-similarity) repository, which we used to help with our GLUE fine-tuning code.
104 | -------------------------------------------------------------------------------- /configs/bert_books.py: -------------------------------------------------------------------------------- 1 | config = { 2 | 'dataset': [ 3 | { 4 | 'name': 'books', 5 | 'shuffle_train': False, 6 | 'batch_size': 1, 7 | 'train_fraction': 0.001, 8 | 'sorted': False, 9 | 'tokenizer': 'bert', 10 | 'num':0 11 | } 12 | ], 13 | 'model': { 14 | 'name': 'bert', 15 | 'dir': 'google', 16 | 'bases': [] 17 | }, 18 | 'model_names': { 19 | 'model1':'multiberts-seed_0', 20 | 'model2':'multiberts-seed_1' 21 | }, 22 | 'parallel_data': False, 23 | 'merging_fn': 'match_tensors_permute', 24 | 'merging_metrics': ['covariance'], 25 | } 26 | 27 | 28 | -------------------------------------------------------------------------------- /configs/bert_classifier.py: -------------------------------------------------------------------------------- 1 | config = { 2 | 'dataset': [ 3 | { 4 | 'name': 'glue', 5 | 'batch_size': 1, 6 | 'train_fraction': 0.001, 7 | 'shuffle_train': False, 8 | } 9 | ], 10 | 'model': { 11 | 'name': 'bert', 12 | 'dir': 'models/trained/multiberts/', 13 | 'bases': [] 14 | }, 15 | 'model_names': { 16 | 'model1':'mnli/seed_0', 17 | 'model2':'mnli/seed_1' 18 | }, 19 | 'parallel_data': False, 20 | 'merging_fn': 'match_tensors_permute', 21 | 'merging_metrics': ['covariance'], 22 | } 23 | 24 | -------------------------------------------------------------------------------- /evaluation/eval_glue.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding=utf-8 3 | # Copyright 2020 The HuggingFace Inc. team. All rights reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | """ Finetuning the library models for sequence classification on GLUE.""" 17 | # You can also adapt this script on your own text classification task. Pointers for this are left as comments. 
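# Example invocations (hypothetical paths, for illustration; pass exactly one of --hf-model,
# --merged-model, or --merged-model-dict):
#   python evaluation/eval_glue.py --task rte --hf-model models/trained/multiberts/rte/seed_1 --cache-dir cache/
#   python evaluation/eval_glue.py --task rte --merged-model-dict path/to/merged_state_dict.pt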
18 | 19 | import argparse 20 | 21 | import numpy as np 22 | from datasets import load_dataset, load_metric 23 | 24 | from transformers import ( 25 | BertForSequenceClassification, 26 | BertTokenizer, 27 | EvalPrediction, 28 | Trainer, 29 | default_data_collator, 30 | ) 31 | import torch 32 | 33 | task_to_keys = { 34 | "cola": ("sentence", None), 35 | "mnli": ("premise", "hypothesis"), 36 | "mrpc": ("sentence1", "sentence2"), 37 | "qnli": ("question", "sentence"), 38 | "qqp": ("question1", "question2"), 39 | "rte": ("sentence1", "sentence2"), 40 | "sst2": ("sentence", None), 41 | "stsb": ("sentence1", "sentence2"), 42 | } 43 | 44 | task_to_outputs = { 45 | 'cola': 2, 46 | 'mnli': 3, 47 | 'mrpc': 2, 48 | 'qnli': 2, 49 | 'qqp': 2, 50 | 'rte': 2, 51 | 'sst2': 2, 52 | 'stsb': 1, 53 | } 54 | 55 | def get_metrics(task, model, cache_dir, log=False): 56 | 57 | tokenizer = BertTokenizer.from_pretrained( 58 | 'google/multiberts-seed_0', 59 | cache_dir=cache_dir, 60 | use_fast=True, 61 | ) 62 | 63 | # Preprocessing the raw_datasets 64 | if task == 'mnli': 65 | validation_name = 'validation_matched' 66 | else: 67 | validation_name = 'validation' 68 | 69 | raw_datasets = load_dataset("glue", task, cache_dir=cache_dir, split=validation_name) 70 | is_regression = (task == "stsb") 71 | 72 | sentence1_key, sentence2_key = task_to_keys[task] 73 | 74 | # Padding strategy 75 | max_seq_length=128 76 | padding = 'max_length' 77 | 78 | if max_seq_length > tokenizer.model_max_length: 79 | print( 80 | f"The max_seq_length passed ({max_seq_length}) is larger than the maximum length for the" 81 | f"model ({tokenizer.model_max_length}). Using max_seq_length={tokenizer.model_max_length}." 82 | ) 83 | max_seq_length = min(max_seq_length, tokenizer.model_max_length) 84 | 85 | def preprocess_function(examples): 86 | # Tokenize the texts 87 | args = ( 88 | (examples[sentence1_key],) if sentence2_key is None else (examples[sentence1_key], examples[sentence2_key]) 89 | ) 90 | result = tokenizer(*args, padding=padding, max_length=max_seq_length, truncation=True) 91 | 92 | return result 93 | 94 | raw_datasets = raw_datasets.map( 95 | preprocess_function, 96 | batched=True, 97 | desc="Running tokenizer on dataset", 98 | ) 99 | 100 | eval_dataset = raw_datasets 101 | 102 | # Get the metric function 103 | metric = load_metric("glue", task) 104 | 105 | # You can define your custom compute_metrics function. It takes an `EvalPrediction` object (a namedtuple with a 106 | # predictions and label_ids field) and has to return a dictionary string to float. 107 | def compute_metrics(p: EvalPrediction): 108 | preds = p.predictions[0] if isinstance(p.predictions, tuple) else p.predictions 109 | preds = np.squeeze(preds) if is_regression else np.argmax(preds, axis=1) 110 | # print(Counter(preds)) 111 | # print(Counter(p.label_ids)) 112 | if task is not None: 113 | result = metric.compute(predictions=preds, references=p.label_ids) 114 | if len(result) > 1: 115 | result["combined_score"] = np.mean(list(result.values())).item() 116 | return result 117 | elif is_regression: 118 | return {"mse": ((preds - p.label_ids) ** 2).mean().item()} 119 | else: 120 | return {"accuracy": (preds == p.label_ids).astype(np.float32).mean().item()} 121 | 122 | # Data collator will default to DataCollatorWithPadding when the tokenizer is passed to Trainer, so we change it if 123 | # we already did the padding. 
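    # default_data_collator simply stacks the already-padded, fixed-length features into batched
    # tensors, so no dynamic padding is needed at batch time.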
124 | if padding == 'max_length': 125 | data_collator = default_data_collator 126 | 127 | # Initialize our Trainer 128 | trainer = Trainer( 129 | model=model, 130 | train_dataset=None, 131 | eval_dataset=eval_dataset, 132 | compute_metrics=compute_metrics, 133 | tokenizer=tokenizer, 134 | data_collator=data_collator, 135 | ) 136 | 137 | # Evaluation 138 | print('evaluate') 139 | 140 | # Loop to handle MNLI double evaluation (matched, mis-matched) 141 | tasks = [task] 142 | eval_datasets = [eval_dataset] 143 | if task == "mnli": 144 | tasks.append("mnli-mm") 145 | mismatched_eval = load_dataset("glue", task, cache_dir=cache_dir, split="validation_mismatched") 146 | eval_datasets.append(mismatched_eval.map( 147 | preprocess_function, 148 | batched=True, 149 | desc="Running tokenizer on dataset", 150 | )) 151 | 152 | metric_list = [] 153 | for eval_dataset, task in zip(eval_datasets, tasks): 154 | metrics = trainer.evaluate(eval_dataset=eval_dataset) 155 | metrics["eval_samples"] = len(eval_dataset) 156 | if log: 157 | trainer.log_metrics("eval", metrics) 158 | metric_list.append(metrics) 159 | #trainer.save_metrics("eval", metrics) 160 | 161 | return metric_list 162 | 163 | def main(): 164 | 165 | parser = argparse.ArgumentParser() 166 | parser.add_argument('--task',required=True) 167 | parser.add_argument('--hf-model') 168 | parser.add_argument('--merged-model') 169 | parser.add_argument('--merged-model-dict') 170 | parser.add_argument('--cache-dir') 171 | parser.add_argument('--tokenizer-name', default='google/multiberts-seed_0') 172 | parser.add_argument('--loss', action='store_true', required=False) 173 | 174 | args = parser.parse_args() 175 | 176 | 177 | 178 | is_regression = (args.task == "stsb") 179 | if not is_regression: 180 | num_labels = task_to_outputs[args.task] 181 | else: 182 | num_labels = 1 183 | 184 | # load tokenizer and model 185 | 186 | if args.hf_model: 187 | model = BertForSequenceClassification.from_pretrained( 188 | args.hf_model, 189 | cache_dir=args.cache_dir, 190 | ) 191 | elif args.merged_model: 192 | model = torch.load(args.merged_model) 193 | elif args.merged_model_dict: 194 | model_dict = torch.load(args.merged_model_dict) 195 | model = BertForSequenceClassification.from_pretrained('google/multiberts-seed_0', num_labels=num_labels) 196 | model.load_state_dict(model_dict) 197 | 198 | 199 | metrics = get_metrics(args.task, model, args.cache_dir, log=True) 200 | 201 | 202 | 203 | if __name__ == "__main__": 204 | main() 205 | 206 | -------------------------------------------------------------------------------- /evaluation/graphing_glue.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import argparse 4 | 5 | import torch 6 | import numpy as np 7 | 8 | from tqdm import tqdm 9 | from copy import deepcopy 10 | from eval_glue import get_metrics 11 | from safetensors.torch import load 12 | from transformers import BertForSequenceClassification 13 | 14 | 15 | pairs_0 = [1,2,3,4,1,2,3,1,2,1] 16 | pairs_1 = [2,3,4,5,3,4,5,4,5,5] 17 | 18 | task_to_outputs = { 19 | 'cola': 2, 20 | 'mnli': 3, 21 | 'mrpc': 2, 22 | 'qnli': 2, 23 | 'qqp': 2, 24 | 'rte': 2, 25 | 'sst2': 2, 26 | 'stsb': 1, 27 | 'wnli': 2 28 | } 29 | 30 | lambdas = np.arange(21) / 20 31 | 32 | def get_merged_state_dict(state_dict_1, state_dict_2, w=0.5,): 33 | """ 34 | Post transformations, obtain state dictionary for merged model by linearly interpolating between 35 | transformed models in each graph. 
By default (w=0.5) all parameters are uniformly
36 |     averaged; otherwise they are weighted-averaged with the given weight.
37 |     - w (float): interpolation weight placed on state_dict_1's parameters; state_dict_2's
38 |       parameters are weighted by (1 - w). Parameters whose shapes differ between the two
39 |       models (e.g., classification heads with different label counts) are skipped.
40 |     Returns: state dict of merged model.
41 |     """
42 |     state_dict = {}
43 |     merged_state_dict = deepcopy(state_dict_1)
44 |     keys = list(state_dict_1.keys())
45 |     try:
46 |         for key in keys:
47 |             if key in merged_state_dict:
48 |                 param = state_dict_1[key]
49 |                 if param.shape == state_dict_2[key].shape:
50 |                     new_value = state_dict_1[key] * w + state_dict_2[key] * (1-w)
51 |                     state_dict[key] = new_value
52 |     except RuntimeError as e:
53 |         if 'size' not in str(e):
54 |             raise e
55 |     return state_dict
56 | 
57 | 
58 | def main():
59 |     parser = argparse.ArgumentParser()
60 |     parser.add_argument('--originals', action='store_true', required=False, default=False)
61 |     parser.add_argument('--task', required=True)
62 |     parser.add_argument('--path', required=False)
63 |     parser.add_argument('--merge-type', required=False)
64 |     parser.add_argument('--outfile')
65 |     parser.add_argument('--train-frac', required=False, default=1.0, type=float)
66 |     parser.add_argument('--cache-dir', required=False)
67 | 
68 |     args = parser.parse_args()
69 |     task = args.task
70 |     cache_dir = args.cache_dir
71 | 
72 |     ROOT_PATH = 'models/'
73 |     vanilla_lambda_lists = {}
74 |     placeholder_model = BertForSequenceClassification.from_pretrained('google/multiberts-seed_0', num_labels=task_to_outputs[task])
75 | 
76 |     # loop through 10 experiments
77 |     for i in range(10):
78 | 
79 |         # load models 0 and 1
80 |         if args.originals == True:
81 |             model0_path = f'{ROOT_PATH}/trained/multiberts_new/{task}/seed_{pairs_0[i]}/model.safetensors'
82 |             model1_path = f'{ROOT_PATH}/trained/multiberts_new/{task}/seed_{pairs_1[i]}/model.safetensors'
83 |             model0_load = open(model0_path, 'rb')
84 |             model1_load = open(model1_path, 'rb')
85 |             model0_statedict = load(model0_load.read())
86 |             model1_statedict = load(model1_load.read())
87 |         else:
88 |             new_path = os.path.join(ROOT_PATH, args.path, 'individual_models')
89 | 
90 |             model_file0 = f'match_tensors_permute_{args.merge_type}_0_{args.task}_seed_{pairs_0[i]}_b8_{args.task}{args.train_frac}_{pairs_0[i]}_{pairs_1[i]}.pt'
91 |             model_file1 = f'match_tensors_permute_{args.merge_type}_1_{args.task}_seed_{pairs_1[i]}_b8_{args.task}{args.train_frac}_{pairs_0[i]}_{pairs_1[i]}.pt'
92 |             model_path0 = os.path.join(new_path, model_file0)
93 |             model_path1 = os.path.join(new_path, model_file1)
94 |             model0_statedict = torch.load(model_path0)
95 |             model1_statedict = torch.load(model_path1)
96 | 
97 |         vanilla_lambda_lists[i] = {}
98 |         if task == 'mnli':
99 |             vanilla_lambda_lists[i][0] = []
100 |             vanilla_lambda_lists[i][1] = []
101 |         else:
102 |             vanilla_lambda_lists[i][0] = []
103 | 
104 |         # loop through interpolation lambdas and get loss for each
105 |         for l in tqdm(lambdas):
106 |             merged_statedict = get_merged_state_dict(model0_statedict, model1_statedict, w=l)
107 |             placeholder_model.load_state_dict(merged_statedict)
108 |             metrics = get_metrics(args.task, placeholder_model, cache_dir)
109 |             for j, metric in enumerate(metrics):
110 |                 vanilla_lambda_lists[i][j].append(metric['eval_loss'])
111 | 
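    # Output format (for reference): vanilla_lambda_lists maps pair index -> eval-set index -> a list
    # of eval losses, one per interpolation weight in lambdas = np.arange(21) / 20
    # (MNLI has two eval sets: matched and mismatched).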
112 |     if os.path.exists('results/glue') == False:
113 |         os.makedirs('results/glue')
114 | 
115 |     with open(f'results/glue/{args.outfile}.json', 'w+') as out:
116 |         json.dump(vanilla_lambda_lists, out)
117 | 
118 | 
119 | 
120 | if __name__ == "__main__":
121 |     main()
122 | 
--------------------------------------------------------------------------------
/evaluation/graphing_mlm.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | import argparse
4 | 
5 | import math
6 | import numpy as np
7 | import datasets
8 | import torch
9 | import torch.nn as nn
10 | import torch.nn.functional as F
11 | 
12 | from tqdm import tqdm
13 | from copy import deepcopy
14 | from datasets import Dataset
15 | from transformers import BertTokenizer, BertForMaskedLM
16 | from transformers import DataCollatorForLanguageModeling, Trainer
17 | 
18 | 
19 | pairs_0 = [1,2,3,4,1,2,3,1,2,1]
20 | pairs_1 = [2,3,4,5,3,4,5,4,5,5]
21 | 
22 | lambdas = np.arange(21) / 20
23 | 
24 | 
25 | class PermAllTrainer(Trainer):
26 |     def __init__(self, *args, **kwargs):
27 |         self.out_proj = kwargs.pop('out_proj')
28 |         super().__init__(*args, **kwargs)
29 | 
30 |     def compute_loss(self, model, inputs, return_outputs=False):
31 | 
32 |         # this circumvents issues with tied weights by scoring with the saved out_proj
33 |         def mini_forward(data):
34 |             x = model(**data, output_hidden_states=True)['hidden_states'][-1]
35 |             x = model.cls.predictions.transform.dense(x)
36 |             x = model.cls.predictions.transform.transform_act_fn(x)
37 |             x = model.cls.predictions.transform.LayerNorm(x)
38 |             device = model.cls.predictions.decoder.bias.device
39 |             x = F.linear(x, self.out_proj.to(device)) + model.cls.predictions.decoder.bias
40 |             return x
41 | 
42 |         # forward pass
43 |         logits = mini_forward(inputs)
44 |         labels = inputs.get('labels')
45 |         # compute the masked language modeling cross-entropy over the vocabulary
46 |         loss_fct = nn.CrossEntropyLoss()
47 |         mlm_loss = loss_fct(logits.view(-1, model.config.vocab_size), labels.view(-1))
48 |         return (mlm_loss, logits) if return_outputs else mlm_loss
49 | 
50 | def wikitext_ppl_all(tokenizer, model, out_proj):
51 | 
52 |     wikitext = datasets.load_dataset('wikitext','wikitext-103-raw-v1', split='validation')
53 |     block_size = 128
54 | 
55 |     def preprocess_function(examples):
56 |         return tokenizer([" ".join(x) for x in examples["text"]])
57 |     def group_texts(examples):
58 |         # Concatenate all texts.
59 |         concatenated_examples = {k: sum(examples[k], []) for k in examples.keys()}
60 |         total_length = len(concatenated_examples[list(examples.keys())[0]])
61 |         # We drop the small remainder; we could add padding instead if the model supported it. You can
62 |         # customize this part to your needs.
63 |         if total_length >= block_size:
64 |             total_length = (total_length // block_size) * block_size
65 |         # Split by chunks of block_size.
66 |         result = {
67 |             k: [t[i : i + block_size] for i in range(0, total_length, block_size)]
68 |             for k, t in concatenated_examples.items()
69 |         }
70 |         return result
71 | 
72 |     tokenized_wikitext = wikitext.map(
73 |         preprocess_function,
74 |         batched=True,
75 |         num_proc=4,
76 |         remove_columns=wikitext.column_names,
77 |     )
78 | 
79 |     lm_dataset = tokenized_wikitext.map(group_texts, batched=True, num_proc=4)
80 |     data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm_probability=0.15)
81 |     trainer = PermAllTrainer(
82 |         model=model,
83 |         eval_dataset=lm_dataset,
84 |         data_collator=data_collator,
85 |         out_proj=out_proj
86 |     )
87 |     eval_results = trainer.evaluate()
88 |     print(f"Perplexity: {math.exp(eval_results['eval_loss']):.2f}")
89 |     return math.exp(eval_results['eval_loss'])
90 | 
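# Example (for reference; mirrors main() below): when a permutation was applied to the output
# projection, the decoder weight is interpolated outside the model and passed in explicitly:
#   new_proj = l * modeldict1['cls.predictions.decoder.weight'] + (1-l) * modeldict2['cls.predictions.decoder.weight'] @ unmerge_mat
#   ppl = wikitext_ppl_all(placeholder_tokenizer, placeholder_model, new_proj)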
91 | 
92 | def wikitext_ppl(tokenizer, model):
93 | 
94 |     wikitext = datasets.load_dataset('wikitext','wikitext-103-raw-v1', split='validation')
95 |     block_size = 128
96 | 
97 |     def preprocess_function(examples):
98 |         return tokenizer([" ".join(x) for x in examples["text"]])
99 |     def group_texts(examples):
100 |         # Concatenate all texts.
101 |         concatenated_examples = {k: sum(examples[k], []) for k in examples.keys()}
102 |         total_length = len(concatenated_examples[list(examples.keys())[0]])
103 |         # We drop the small remainder; we could add padding instead if the model supported it. You can
104 |         # customize this part to your needs.
105 |         if total_length >= block_size:
106 |             total_length = (total_length // block_size) * block_size
107 |         # Split by chunks of block_size.
108 |         result = {
109 |             k: [t[i : i + block_size] for i in range(0, total_length, block_size)]
110 |             for k, t in concatenated_examples.items()
111 |         }
112 |         return result
113 | 
114 |     tokenized_wikitext = wikitext.map(
115 |         preprocess_function,
116 |         batched=True,
117 |         num_proc=4,
118 |         remove_columns=wikitext.column_names,
119 |     )
120 | 
121 |     lm_dataset = tokenized_wikitext.map(group_texts, batched=True, num_proc=4)
122 |     data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm_probability=0.15)
123 | 
124 |     trainer = Trainer(
125 |         model=model,
126 |         eval_dataset=lm_dataset,
127 |         data_collator=data_collator
128 |     )
129 |     eval_results = trainer.evaluate()
130 |     print(f"Perplexity: {math.exp(eval_results['eval_loss']):.2f}")
131 |     return math.exp(eval_results['eval_loss'])
132 | 
133 | 
134 | 
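# Example (for reference): an equal-weight merge of two MultiBERTs seeds, as performed
# per-lambda in main() below:
#   sd1 = BertForMaskedLM.from_pretrained('google/multiberts-seed_1').state_dict()
#   sd2 = BertForMaskedLM.from_pretrained('google/multiberts-seed_2').state_dict()
#   merged = get_merged_state_dict(sd1, sd2, w=0.5)   # merged[k] = 0.5 * sd1[k] + 0.5 * sd2[k]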
135 | def get_merged_state_dict(state_dict_1, state_dict_2, w=0.5):
136 |     """
137 |     Post transformations, obtain state dictionary for merged model by linearly interpolating between
138 |     transformed models in each graph. By default (w=0.5) all parameters are uniformly
139 |     averaged; otherwise they are weighted-averaged with the given weight.
140 |     - w (float): interpolation weight placed on state_dict_1's parameters; state_dict_2's
141 |       parameters are weighted by (1 - w). Parameters whose shapes differ between the two
142 |       models are skipped.
143 |     Returns: state dict of merged model.
144 |     """
145 |     state_dict = {}
146 |     merged_state_dict = deepcopy(state_dict_1)
147 |     keys = list(state_dict_1.keys())
148 |     try:
149 |         for key in keys:
150 |             if key in merged_state_dict:
151 |                 param = state_dict_1[key]
152 |                 if param.shape == state_dict_2[key].shape:
153 |                     new_value = state_dict_1[key] * w + state_dict_2[key] * (1-w)
154 |                     state_dict[key] = new_value
155 |     except RuntimeError as e:
156 |         # Only catch runtime errors about tensor sizes; we need to be able to add models with diff heads together
157 |         if 'size' not in str(e):
158 |             raise e
159 |     return state_dict
160 | 
161 | 
162 | def main():
163 | 
164 |     vanilla_lambda_lists = {}
165 | 
166 |     parser = argparse.ArgumentParser()
167 |     parser.add_argument('--originals', action='store_true', required=False, default=False)
168 |     parser.add_argument('--path')
169 |     parser.add_argument('--merge-type')
170 |     parser.add_argument('--outfile')
171 |     parser.add_argument('--train-frac', default=0.1, type=float)
172 |     parser.add_argument('--dataset', default='wikitext', type=str)
173 |     parser.add_argument('--unmerge', action='store_true', required=False, default=False)
174 | 
175 |     args = parser.parse_args()
176 | 
177 |     ROOT_PATH = 'models/'
178 |     placeholder_tokenizer = BertTokenizer.from_pretrained('google/multiberts-seed_0')
179 |     placeholder_model = BertForMaskedLM.from_pretrained('google/multiberts-seed_0')
180 | 
181 |     train_frac = args.train_frac
182 |     for i in range(10):
183 |         if args.originals == True:
184 |             modeldict1 = BertForMaskedLM.from_pretrained(f'google/multiberts-seed_{pairs_0[i]}').state_dict()
185 |             modeldict2 = BertForMaskedLM.from_pretrained(f'google/multiberts-seed_{pairs_1[i]}').state_dict()
186 |         else:
187 |             new_path = os.path.join(ROOT_PATH, args.path, 'individual_models')
188 | 
189 |             model_file1 = f'match_tensors_permute_{args.merge_type}_0_multiberts-seed_{pairs_0[i]}_b8_mlm{train_frac}_{pairs_0[i]}_{pairs_1[i]}.pt'
190 |             model_file2 = f'match_tensors_permute_{args.merge_type}_1_multiberts-seed_{pairs_1[i]}_b8_mlm{train_frac}_{pairs_0[i]}_{pairs_1[i]}.pt'
191 |             model_path1 = os.path.join(new_path, model_file1)
192 |             model_path2 = os.path.join(new_path, model_file2)
193 |             modeldict1 = torch.load(model_path1)
194 |             modeldict2 = torch.load(model_path2)
195 |             if args.unmerge == True:
196 |                 unmerge_file = os.path.join(ROOT_PATH, args.path, 'unmerge', f'unmerge_mat_{pairs_0[i]}_{pairs_1[i]}.pt')
197 |                 unmerge_mat = torch.load(unmerge_file)
198 | 
199 | 
200 |         vanilla_lambda_lists[i] = {}
201 |         vanilla_lambda_lists[i][0] = []
202 | 
203 |         for l in tqdm(lambdas):
204 |             merged_statedict = get_merged_state_dict(modeldict1, modeldict2, w=l)
205 |             placeholder_model.load_state_dict(merged_statedict)
206 |             if args.dataset == 'wikitext':
207 |                 if args.unmerge is True:
208 |                     new_proj = l * modeldict1['cls.predictions.decoder.weight'] + (1-l) * modeldict2['cls.predictions.decoder.weight'] @ unmerge_mat
209 |                     vanilla_lambda_lists[i][0].append(wikitext_ppl_all(placeholder_tokenizer, placeholder_model, new_proj))
210 |                 else:
211 |                     vanilla_lambda_lists[i][0].append(wikitext_ppl(placeholder_tokenizer, placeholder_model))
212 | 
213 |     if os.path.exists(f'results/mlm/{args.dataset}') == False:
214 |         os.makedirs(f'results/mlm/{args.dataset}')
215 | 
216 |     with open(f'results/mlm/{args.dataset}/{args.outfile}', 'w+') as out:
217 |         json.dump(vanilla_lambda_lists, out)
218 | 
219 | 
220 | if __name__ == "__main__":
221 |     main()
222 | 
223 | 
--------------------------------------------------------------------------------
/experiment/merge_bert_classifiers.py:
-------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import json 4 | import random 5 | import torch 6 | import numpy as np 7 | 8 | from copy import deepcopy 9 | from tqdm.auto import tqdm 10 | 11 | from utils import * 12 | from model_merger import ModelMerge 13 | 14 | torch.manual_seed(0) 15 | random.seed(0) 16 | np.random.seed(0) 17 | 18 | 19 | save_dir_root = 'models' 20 | 21 | def run_auxiliary_experiment(merging_fn, merge_type, experiment_config, pairs, device, args): 22 | for pair in tqdm(pairs, desc='Evaluating Pairs...'): 23 | experiment_config = inject_pair_language(experiment_config, pair) 24 | 25 | 26 | if args.task != 'mlm': 27 | config = prepare_lang_config(experiment_config, type='lang', classifier=True) 28 | else: 29 | config = prepare_lang_config(experiment_config, type='lang', classifier=False) 30 | train_loaders = [config['data'][i]['train']['full'] for i in range(len(config['data']))] 31 | base_models = config['models']['bases'] 32 | 33 | Grapher = config['graph'] 34 | if args.task != 'mlm': 35 | # set classifier to true 36 | graphs = [Grapher(deepcopy(base_model), merge_type=merge_type, qk=args.qk, classifier=True).graphify() for base_model in base_models] 37 | else: 38 | # no classifier head, lm head 39 | graphs = [Grapher(deepcopy(base_model), merge_type=merge_type, qk=args.qk, classifier=False).graphify() for base_model in base_models] 40 | 41 | bsz = experiment_config['dataset'][0]['batch_size'] 42 | train_frac = experiment_config['dataset'][0]['train_fraction'] 43 | model1 = raw_config['model_names']['model1'].replace('/', '_') 44 | model2 = raw_config['model_names']['model2'].replace('/', '_') 45 | 46 | Merge = ModelMerge(*graphs, device=device) 47 | 48 | unmerge, cost_dict = Merge.transform( 49 | deepcopy(config['models']['new']), 50 | train_loaders, 51 | sentence_level=args.sentence_level, 52 | special_toks=args.special_toks, 53 | transform_fn=get_merging_fn(merging_fn), 54 | metric_classes=config['metric_fns'], 55 | permute_heads=args.permute_heads, 56 | ignore_heads=args.ignore_heads, 57 | save_both=args.save_both, 58 | merge_cls=args.merge_cls, 59 | no_absval=args.no_absval, 60 | saved_features=args.saved_feature_path, 61 | res_type=args.res_type, 62 | ) 63 | 64 | if args.sentence_level: 65 | param_tail = f'b{bsz}_{args.task}{train_frac}_sent' 66 | else: 67 | param_tail = f'b{bsz}_{args.task}{train_frac}' 68 | 69 | param_tail+=f'_{args.seed0}_{args.seed1}' 70 | 71 | save_dir = os.path.join(save_dir_root, args.task, args.exp_name) 72 | if os.path.exists(save_dir) == False: 73 | os.makedirs(save_dir) 74 | if os.path.exists(os.path.join(save_dir, 'individual_models')) == False: 75 | os.makedirs(os.path.join(save_dir, 'individual_models')) 76 | 77 | 78 | with open(f'{save_dir}/test_{merging_fn}_{merge_type}_{model1}_{model2}_{param_tail}.args', 'w+') as f_args: 79 | f_args.write(str(vars(args)) + '\n') 80 | f_args.write(str(experiment_config)) 81 | 82 | if args.save_both: 83 | torch.save(Merge.merged_model1.state_dict(), 84 | f'{save_dir}/test_{merging_fn}_{merge_type}_{model1}_{model2}_0_{param_tail}.pt') 85 | torch.save(Merge.merged_model2.state_dict(), 86 | f'{save_dir}/test_{merging_fn}_{merge_type}_{model1}_{model2}_1_{param_tail}.pt') 87 | else: 88 | torch.save(Merge.merged_model, 89 | f'{save_dir}/test_{merging_fn}_{merge_type}_{model1}_{model2}_{param_tail}.pt') 90 | torch.save(Merge.graphs[0].model.state_dict(), 
f'{save_dir}/individual_models/{merging_fn}_{merge_type}_0_{model1}_{param_tail}.pt') 91 | torch.save(Merge.graphs[1].model.state_dict(), f'{save_dir}/individual_models/{merging_fn}_{merge_type}_1_{model2}_{param_tail}.pt') 92 | 93 | if unmerge != None: 94 | if os.path.exists(os.path.join(save_dir, 'unmerge')) == False: 95 | os.makedirs(os.path.join(save_dir, 'unmerge')) 96 | torch.save(unmerge, f'{save_dir}/unmerge/unmerge_mat_{args.seed0}_{args.seed1}.pt') 97 | 98 | if os.path.exists(os.path.join(save_dir, 'costs')) == False: 99 | os.makedirs(os.path.join(save_dir, 'costs')) 100 | with open(f'{save_dir}/costs/costs_{args.seed0}_{args.seed1}.pt', 'w+') as costs_out: 101 | json.dump(cost_dict, costs_out) 102 | 103 | 104 | 105 | 106 | if __name__ == "__main__": 107 | 108 | parser = argparse.ArgumentParser() 109 | parser.add_argument( 110 | "--cfg",required=False,default='bert_books' 111 | ) 112 | parser.add_argument( 113 | '--task',default='mlm' 114 | ) 115 | parser.add_argument( 116 | "--seed0",default=0,type=int 117 | ) 118 | parser.add_argument( 119 | "--seed1",default=1,type=int 120 | ) 121 | parser.add_argument( 122 | "--train-frac",type=float, required=False 123 | ) 124 | parser.add_argument( 125 | "--bsz",type=int, required=False,default=8 126 | ) 127 | parser.add_argument( 128 | "--special-toks",required=False,action='store_true' 129 | ) 130 | parser.add_argument( 131 | "--permute-heads",required=False,action='store_true' 132 | ) 133 | parser.add_argument( 134 | "--ignore-heads",required=False,action='store_true' 135 | ) 136 | parser.add_argument( 137 | "--exp-name", 138 | ) 139 | parser.add_argument( 140 | "--merge-type", # one of ff_only, res_only, ff+res, ff+attn, attn_only, res+attn, all 141 | ) 142 | parser.add_argument( 143 | '--qk',required=False,action='store_true' 144 | ) 145 | parser.add_argument( 146 | '--save-both',action='store_true' 147 | ) 148 | parser.add_argument( 149 | '--merging-fn',default='match_tensors_permute' 150 | ) 151 | parser.add_argument( 152 | '--merge-cls',action='store_true' 153 | ) 154 | parser.add_argument( 155 | '--sentence-level',required=False 156 | ) 157 | parser.add_argument( 158 | '--no-absval',action='store_true' 159 | ) 160 | parser.add_argument( 161 | '--res-type',required=False,default='first' # one of first, last, sep, all 162 | ) 163 | parser.add_argument( 164 | '--saved-feature-path',required=False 165 | ) 166 | args = parser.parse_args() 167 | 168 | 169 | device = torch.device(f'cuda' if torch.cuda.is_available() else 'cpu') 170 | 171 | raw_config = get_config_from_name(args.cfg, device=device) 172 | model_dir = raw_config['model']['dir'] 173 | model_name = raw_config['model']['name'] 174 | 175 | if args.task != 'mlm': 176 | raw_config['dataset'][0]['task'] = args.task 177 | raw_config['model_names']['model1'] = f'{args.task}/seed_{args.seed0}' 178 | raw_config['model_names']['model2'] = f'{args.task}/seed_{args.seed1}' 179 | else: 180 | raw_config['model_names']['model1'] = f'multiberts-seed_{args.seed0}' 181 | raw_config['model_names']['model2'] = f'multiberts-seed_{args.seed1}' 182 | 183 | 184 | if args.bsz: 185 | for i in range(len(raw_config['dataset'])): 186 | raw_config['dataset'][i]['batch_size'] = args.bsz 187 | if args.train_frac: 188 | raw_config['dataset'][0]['train_fraction'] = args.train_frac 189 | 190 | 191 | run_pairs = [(raw_config['model_names']['model1'], raw_config['model_names']['model2'])] 192 | 193 | 194 | print(raw_config['model_names']['model1']) 195 | 196 | 197 | with torch.no_grad(): 198 | node_results = 
run_auxiliary_experiment( 199 | merging_fn=args.merging_fn, 200 | merge_type=args.merge_type, 201 | experiment_config=raw_config, 202 | pairs=run_pairs, 203 | device=device, 204 | args=args 205 | ) -------------------------------------------------------------------------------- /experiment/run_component_experiments.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # this script is to run the experiments by components merged. This corresponds to figure 3 and figure 4 in our paper. 4 | 5 | # all pairs 6 | pairs_0=(1 2 3 4 1 2 3 1 2 1) 7 | pairs_1=(2 3 4 5 3 4 5 4 5 5) 8 | 9 | cd ../ 10 | 11 | 12 | # identity permutation = vanilla merging. 13 | # we use a very small train fraction to speed up the experiments as no data is needed here 14 | for i in 0 1 2 3 4 5 6 7 8 9; do 15 | seed0=${pairs_0[$i]} 16 | seed1=${pairs_1[$i]} 17 | python -m experiment.merge_bert_classifiers --cfg bert_books --task mlm --seed0 $seed0 --seed1 $seed1 --train-frac 0.0001 --bsz 8 --special-toks --permute-heads --no-absval --merge-type ff+attn --merging-fn match_tensors_identity --exp-name vanilla_merge 18 | done 19 | 20 | # by component 21 | train_frac=0.1 22 | for type in ff_only attn_only ff+attn; do 23 | # pairs 24 | for i in 0 1 2 3 4 5 6 7 8 9; do 25 | seed0=${pairs_0[$i]} 26 | seed1=${pairs_1[$i]} 27 | python -m experiment.merge_bert_classifiers --cfg bert_books --task mlm --seed0 $seed0 --seed1 $seed1 --train-frac ${train_frac} --bsz 8 --special-toks --permute-heads --no-absval --merge-type $type --merging-fn match_tensors_permute --exp-name perm_${type}_frac${train_frac} 28 | done 29 | done -------------------------------------------------------------------------------- /experiment/run_data_ablation.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # this script is to run the data ablation experiments. It runs through all 10 pairs, and all levels of data fractions we consider. This corresponds to figure 6 in our paper 4 | 5 | # all pairs 6 | pairs_0=(1 2 3 4 1 2 3 1 2 1) 7 | pairs_1=(2 3 4 5 3 4 5 4 5 5) 8 | 9 | cd ../ 10 | 11 | # data fractions 12 | for train_frac in 0.001 0.005 0.01 0.05 0.1 0.5 1.0; do 13 | # pairs 14 | for i in 0 1 2 3 4 5 6 7 8 9; do 15 | seed0=${pairs_0[$i]} 16 | seed1=${pairs_1[$i]} 17 | python -m experiment.merge_bert_classifiers --cfg bert_books --task mlm --seed0 $seed0 --seed1 $seed1 --train-frac ${train_frac} --bsz 8 --special-toks --permute-heads --no-absval --merge-type ff+attn --merging-fn match_tensors_permute --exp-name perm_ff+attn_frac${train_frac}_recheck 18 | done 19 | done -------------------------------------------------------------------------------- /experiment/run_glue_experiments.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # this script is to run the glue experiments. This requires fine-tuning the glue models first. 
This corresponds to Table 3 in our paper.
4 | 
5 | # all pairs
6 | pairs_0=(1 2 3 4 1 2 3 1 2 1)
7 | pairs_1=(2 3 4 5 3 4 5 4 5 5)
8 | 
9 | cd ../
10 | 
11 | train_frac=1
12 | type=all
13 | 
14 | 
15 | # by task
16 | for task in mnli qqp qnli sst2 cola stsb mrpc rte; do
17 |     for i in 0 1 2 3 4 5 6 7 8 9; do
18 |         seed0=${pairs_0[$i]}
19 |         seed1=${pairs_1[$i]}
20 | 
21 |         # small data because this is just vanilla merging
22 |         python -m experiment.merge_bert_classifiers --cfg bert_classifier --task $task --seed0 $seed0 --seed1 $seed1 --train-frac 0.001 --bsz 8 --special-toks --permute-heads --no-absval --merge-type $type --merging-fn match_tensors_identity --exp-name ${task}_vanilla_merge
23 | 
24 |         python -m experiment.merge_bert_classifiers --cfg bert_classifier --task $task --seed0 $seed0 --seed1 $seed1 --train-frac ${train_frac} --bsz 8 --special-toks --no-absval --permute-heads --exp-name ${task}_permute_${type}_frac${train_frac} --merge-type $type --merging-fn match_tensors_permute
25 |     done
26 | done
27 | 
--------------------------------------------------------------------------------
/experiment/run_mha_experiments.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | 
3 | # This script runs the experiments by multi-head attention merging type. This corresponds to Table 1 in our paper.
4 | 
5 | # all pairs
6 | pairs_0=(1 2 3 4 1 2 3 1 2 1)
7 | pairs_1=(2 3 4 5 3 4 5 4 5 5)
8 | 
9 | cd ../
10 | 
11 | train_frac=0.1
12 | 
13 | for i in 0 1 2 3 4 5 6 7 8 9; do
14 |     seed0=${pairs_0[$i]}
15 |     seed1=${pairs_1[$i]}
16 | 
17 |     # identity MHA permutation = ff_only
18 |     python -m experiment.merge_bert_classifiers --cfg bert_books --task mlm --seed0 $seed0 --seed1 $seed1 --train-frac $train_frac --bsz 8 --special-toks --no-absval --merge-type ff_only --merging-fn match_tensors_identity --exp-name perm_ff_only_frac${train_frac}
19 | 
20 |     # Monotonic head alignment = --permute-heads off
21 |     python -m experiment.merge_bert_classifiers --cfg bert_books --task mlm --seed0 $seed0 --seed1 $seed1 --train-frac ${train_frac} --bsz 8 --special-toks --no-absval --merge-type ff+attn --merging-fn match_tensors_permute --exp-name perm_ff+attn_frac${train_frac}_noperm
22 | 
23 |     # Ignore Heads = --ignore-heads on
24 |     python -m experiment.merge_bert_classifiers --cfg bert_books --task mlm --seed0 $seed0 --seed1 $seed1 --train-frac ${train_frac} --bsz 8 --special-toks --ignore-heads --no-absval --merge-type ff+attn --merging-fn match_tensors_permute --exp-name perm_ff+attn_frac${train_frac}_ignoreperm
25 | 
26 |     # Permute Heads = --permute-heads on
27 |     python -m experiment.merge_bert_classifiers --cfg bert_books --task mlm --seed0 $seed0 --seed1 $seed1 --train-frac ${train_frac} --bsz 8 --special-toks --permute-heads --no-absval --merge-type ff+attn --merging-fn match_tensors_permute --exp-name perm_ff+attn_frac${train_frac}
28 | done
29 | 
--------------------------------------------------------------------------------
/experiment/run_res_experiments.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | 
3 | # This script runs the experiments by residual merging type. This corresponds to Table 2 in our paper.
4 | 
5 | # all pairs
6 | pairs_0=(1 2 3 4 1 2 3 1 2 1)
7 | pairs_1=(2 3 4 5 3 4 5 4 5 5)
8 | 
9 | cd ../
10 | 
11 | train_frac=0.1
12 | 
13 | # identity residual permutation = vanilla merging.
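# (match_tensors_identity requires no data for matching, so the tiny --train-frac below is only
#  to speed up data loading; run_component_experiments.sh uses the same trick)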
14 | for i in 0 1 2 3 4 5 6 7 8 9; do
15 |     seed0=${pairs_0[$i]}
16 |     seed1=${pairs_1[$i]}
17 |     python -m experiment.merge_bert_classifiers --cfg bert_books --task mlm --seed0 $seed0 --seed1 $seed1 --train-frac 0.0001 --bsz 8 --special-toks --permute-heads --no-absval --merge-type res_only --merging-fn match_tensors_identity --exp-name vanilla_merge
18 | done
19 | 
20 | # by residual type
21 | for res_type in first last all sep; do
22 |     for i in 0 1 2 3 4 5 6 7 8 9; do
23 |         seed0=${pairs_0[$i]}
24 |         seed1=${pairs_1[$i]}
25 | 
26 |         python -m experiment.merge_bert_classifiers --cfg bert_books --task mlm --seed0 $seed0 --seed1 $seed1 --train-frac ${train_frac} --bsz 8 --special-toks --no-absval --res-type ${res_type} --permute-heads --exp-name res_only_frac${train_frac}_res_${res_type} --merge-type res_only --merging-fn match_tensors_permute
27 |     done
28 | done
29 | 
--------------------------------------------------------------------------------
/graphs/base_graph.py:
--------------------------------------------------------------------------------
1 | '''
2 | This file is adapted from the https://github.com/gstoica27/ZipIt repository.
3 | '''
4 | 
5 | import torch
6 | import networkx as nx
7 | from enum import Enum
8 | from abc import ABC, abstractmethod
9 | import matplotlib.pyplot as plt
10 | 
11 | 
12 | class FeatureReshapeHandler:
13 |     """ Instructions to reshape layer intermediates for alignment metric computation. """
14 |     def handle_conv2d(self, x):
15 |         # reshapes conv2d representation from [B, C, H, W] to [C, -1]
16 |         B, C, H, W = x.shape
17 |         return x.permute(1, 0, 2, 3).reshape(C, -1)
18 | 
19 |     def handle_linear(self, x):
20 |         # x is shape [seq_len, batch_size, C]. Want [C, -1]
21 |         x = x.flatten(0, len(x.shape)-2).transpose(1, 0).contiguous()
22 |         return x
23 | 
24 |     def __init__(self, class_name, info):
25 |         self.handler = {
26 |             'BatchNorm2d': self.handle_conv2d,
27 |             'LayerNorm': self.handle_linear,
28 |             'Conv2d': self.handle_conv2d,
29 |             'Linear': self.handle_linear,
30 |             'GELU': self.handle_linear,
31 |             'AdaptiveAvgPool2d': self.handle_conv2d,
32 |             'LeakyReLU': self.handle_conv2d,
33 |             'ReLU': self.handle_conv2d,
34 |             'Tanh': self.handle_conv2d,
35 |             'MaxPool2d': self.handle_conv2d,
36 |             'AvgPool2d': self.handle_conv2d,
37 |             'SpaceInterceptor': self.handle_conv2d,
38 |             'Identity': self.handle_linear,
39 | 
40 |         }[class_name]
41 |         self.info = info
42 | 
43 |     def reshape(self, x):
44 |         x = self.handler(x)
45 | 
46 |         # Handle modules that we only want a piece of
47 |         if self.info['chunk'] is not None:
48 |             idx, num_chunks = self.info['chunk']
49 |             x = x.chunk(num_chunks, dim=0)[idx]
50 | 
51 |         return x
52 | 
53 | 
54 | class NodeType(Enum):
55 |     MODULE = 0     # node is a torch module
56 |     PREFIX = 1     # node is a PREFIX (i.e., we want to hook inputs to the child node)
57 |     POSTFIX = 2    # node is a POSTFIX (i.e., we want to hook outputs of the parent node)
58 |     SUM = 3        # node is a SUM (e.g., point where residual connections are added)
59 |     CONCAT = 4     # node is a CONCATENATION (e.g., point where residual connections are concatenated)
60 |     INPUT = 5      # node is an INPUT (graph starting point)
61 |     OUTPUT = 6     # node is an OUTPUT (graph output point)
62 |     EMBEDDING = 7  # node is an embedding module (these can only be merged)
63 | 
64 | 
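# For intuition (illustration only): a residual sublayer y = LayerNorm(x + W(x)) graphifies as
#   ... -> MODULE(W) -> SUM -> MODULE(LayerNorm) -> ...
# with a second, skip-connection edge from x's node into the SUM node; PREFIX/POSTFIX nodes are
# then placed wherever features should be captured to compute an alignment.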
""" 68 | self.reset_graph() 69 | self.named_modules = dict(model.named_modules()) 70 | self.named_params = dict(model.named_parameters()) 71 | self.model = model 72 | self.intermediates = {} 73 | self.hooks = [] 74 | 75 | # working info about nodes for the merging algorithms 76 | # clear after use! 77 | self.working_info = {} 78 | 79 | self.unmerged = set() 80 | self.merged = set() 81 | 82 | def reset_graph(self): 83 | """ Create New Graph. """ 84 | self.G = nx.DiGraph() 85 | 86 | def preds(self, node): 87 | """ Get predessors from a node (layer). """ 88 | return list(self.G.pred[node]) 89 | 90 | def succs(self, node): 91 | """ Get successors from a node (layer). """ 92 | return list(self.G.succ[node]) 93 | 94 | def get_node_info(self, node_name): 95 | """ Get attribute dict from node. """ 96 | return self.G.nodes()[node_name] 97 | 98 | def get_module_from_node(self, node_name): 99 | """ Get pytorch module associated with node. """ 100 | info = self.get_node_info(node_name) 101 | if info['type'] == NodeType.MODULE: 102 | return self.named_modules[info["layer"]] 103 | else: 104 | raise ValueError(f"Tried to get module from {node_name} of type {info['type']}.") 105 | 106 | def get_module(self, module_name): 107 | """ Get module parameters. """ 108 | return self.named_modules[module_name] 109 | 110 | def get_parameter(self, param_name): 111 | """ Get parameter from name. """ 112 | return self.named_params[param_name] 113 | 114 | def get_node_str(self, node_name): 115 | """ Get node type name. """ 116 | info = self.get_node_info(node_name) 117 | 118 | if info['type'] == NodeType.MODULE: 119 | return self.get_module_from_node(node_name).__class__.__name__ 120 | else: 121 | return info['type'].name 122 | 123 | def create_node_name(self): 124 | """woo magic. A robust id generator """ 125 | return len(self.G) 126 | 127 | def create_node(self, 128 | node_name=None, 129 | layer_name=None, 130 | param_name=None, 131 | node_type=NodeType.MODULE, 132 | chunk=None, 133 | special_merge=None): 134 | """ 135 | Create node to be added to graph. All arguments are optional, but 136 | specify different kinds of node properties. 137 | Arguments: 138 | - node_name: unique identifier for a node. If None, a unique id will be generated 139 | via complex hashing function. 140 | - layer_name: name of pytorch module node represents. layer_name MUST match the module name. 141 | - node_type: type of node created. By default it is a MODULE (pytorch module), but can also 142 | commonly be POSTFIX or PREFIX. These latter specify the node to a place where an alignment 143 | between models will be computed and applied. 144 | - chunk: Whether node represents a disjoint part of a module which other nodes also are a part of. 145 | Chunk is (i, total). 146 | - special_merge: Whether to apply a specific merge/unmerge operation on at this node specially. 147 | If none, the transform_fn from model_merger will be applied. 148 | """ 149 | if node_name is None: 150 | node_name = self.create_node_name() 151 | self.G.add_nodes_from([(node_name, { 152 | 'layer': layer_name, 153 | 'type': node_type, 154 | 'param': param_name, 155 | 'chunk': chunk, 156 | 'special_merge': special_merge 157 | })]) 158 | return node_name 159 | 160 | def add_directed_edge(self, source, target, **kwargs): 161 | """ Add an edge from source node to target node. 
""" 162 | self.G.add_edge(source, target, **kwargs) 163 | 164 | def add_nodes_from_sequence(self, name_prefix, list_of_names, input_node, sep='.'): 165 | """ 166 | Add multiple nodes in sequence by creating them and adding edges between each. 167 | Args: 168 | - name_prefix: Least common ancestor module name string all nodes share. 169 | Usually this is the name of a nn.Sequential layer 170 | - list_of_names: list of module names. Can be the ordered module names in an nn.Sequential. 171 | - input_node: source node the sequence is attached to. 172 | Returns: 173 | - output sequence node. 174 | """ 175 | source_node = input_node 176 | for name in list_of_names: 177 | if isinstance(name, str): 178 | if name_prefix != '': 179 | temp_node = self.create_node(layer_name=name_prefix + f'{sep}{name}') 180 | else: 181 | temp_node = self.create_node(layer_name= f'{name}') 182 | else: 183 | temp_node = self.create_node(node_type=name) 184 | self.add_directed_edge(source_node, temp_node) 185 | source_node = temp_node 186 | return source_node 187 | 188 | def print_prefix(self): 189 | """ Print (POST/PRE)FIX node inputs and outputs. """ 190 | for node in self.G: 191 | info = self.get_node_info(node) 192 | if info['type'] in (NodeType.PREFIX, NodeType.POSTFIX): 193 | print(f'{node:3} in={len(self.preds(node))}, out={len(self.succs(node))}') 194 | 195 | def draw(self, nodes=None, save_path=None): 196 | """ 197 | Visualize DAG. By default all nodes are colored gray, but if parts of module have already been 198 | transformed, they will be colored according to the kinds of transformations applied on the nodes. 199 | Color Rubric: 200 | - Gray: Not merged or unmerged (output space is not aligned and input space is not aligned) 201 | - Blue: merged but not unmerged (output space is aligned, but input space is not aligned) 202 | - Red: Not merged but unmerged (output space is not aligned, but input space is aligned) 203 | - Pink: Merged and Unmerged (output space is aligned, and input space is aligned) 204 | 205 | Args: 206 | - nodes (optional): list of indices of nodes to visualize. If None, all nodes will be drawn. 207 | - save_path (optional): path in which to save graph. 208 | """ 209 | G = self.G 210 | if nodes is not None: 211 | G = nx.subgraph(G, list(nodes)) 212 | 213 | labels = {i: f'[{i}] ' + self.get_node_str(i) for i in G} 214 | pos = nx.nx_agraph.graphviz_layout(G, prog='neato') 215 | node_size = [len(labels[i])**2 * 60 for i in G] 216 | 217 | colors = { 218 | (False, False): (180, 181, 184), 219 | (True, False): (41, 94, 255), 220 | (False, True): (255, 41, 91), 221 | (True, True): (223, 41, 255), 222 | } 223 | 224 | for k, v in colors.items(): 225 | colors[k] = tuple(map(lambda x: x / 255., v)) 226 | 227 | node_color = [colors[(node in self.merged, node in self.unmerged)] for node in G] 228 | plt.figure(figsize=(120, 160)) 229 | nx.draw_networkx(G, pos=pos, labels=labels, node_size=node_size, node_color=node_color) 230 | if save_path is not None: 231 | plt.savefig(save_path) 232 | plt.show() 233 | 234 | def add_hooks(self, device=0): 235 | """ Propogates PREFIX and postfix POSTFIX. 
""" 236 | self.clear_hooks() 237 | 238 | for node in self.G: 239 | info = self.get_node_info(node) 240 | 241 | if info['type'] == NodeType.PREFIX: 242 | for succ_node in self.G.succ[node]: 243 | succ_info = self.get_node_info(succ_node) 244 | if succ_info['type'] == NodeType.MODULE: 245 | 246 | def prehook(m, x, this_node=node, this_info=succ_info): 247 | self.intermediates[this_node] = \ 248 | FeatureReshapeHandler(m.__class__.__name__, this_info).reshape( 249 | x[0].detach().to(device) 250 | ) 251 | return None 252 | 253 | module = self.get_module(succ_info['layer']) 254 | self.hooks.append(module.register_forward_pre_hook(prehook)) 255 | break 256 | elif succ_info['type'] == NodeType.EMBEDDING: 257 | def prehook(m, x, this_node=node, this_info=succ_info): 258 | tensor = self.get_parameter(this_info['param']).data 259 | tensor = tensor.flatten(0, len(x[0].shape)-2).transpose(1, 0).contiguous() 260 | self.intermediates[this_node] = tensor 261 | return None 262 | 263 | module = self.get_module(succ_info['layer']) 264 | self.hooks.append(module.register_forward_pre_hook(prehook)) 265 | break 266 | else: 267 | raise RuntimeError(f"PREFIX node {node} had no module to attach to.") 268 | 269 | elif info['type'] == NodeType.POSTFIX: 270 | 271 | for pred_node in self.G.pred[node]: 272 | pred_info = self.get_node_info(pred_node) 273 | 274 | if pred_info['type'] == NodeType.MODULE: 275 | 276 | def posthook(m, x, y, this_node=node, this_info=pred_info): 277 | self.intermediates[this_node] = \ 278 | FeatureReshapeHandler(m.__class__.__name__, this_info).reshape( 279 | y.detach().to(device) 280 | ) 281 | return None 282 | 283 | module = self.get_module(pred_info['layer']) 284 | self.hooks.append(module.register_forward_hook(posthook)) 285 | break 286 | 287 | elif pred_info['type'] == NodeType.EMBEDDING: 288 | # If this is an embedding, we need to populate intermediates with the corresponding 289 | # parameter every time the network is executed. 290 | 291 | def prehook(m, x, this_node=node, this_info=pred_info): 292 | tensor = self.get_parameter(this_info['param']).data 293 | tensor = tensor.flatten(0, len(x[0].shape)-2).transpose(1, 0).contiguous() 294 | self.intermediates[this_node] = tensor 295 | return None 296 | 297 | module = self.get_module(pred_info['layer']) 298 | self.hooks.append(module.register_forward_pre_hook(prehook)) 299 | 300 | else: 301 | raise RuntimeError(f"POSTFIX node {node} had no module to attach to.") 302 | 303 | 304 | def clear_hooks(self): 305 | """ Clear graph hooks. """ 306 | for hook in self.hooks: 307 | hook.remove() 308 | self.hooks = [] 309 | 310 | 311 | def compute_intermediates(self, x, attn_mask=None): 312 | """ Computes all intermediates in a graph network. Takes in a torch tensor (e.g., a batch). """ 313 | self.model = self.model.eval() 314 | # this uses the hooks added in add_hooks() 315 | with torch.no_grad(), torch.cuda.amp.autocast(): 316 | self.intermediates = {} 317 | if attn_mask != None: 318 | self.model(x, attention_mask=attn_mask) 319 | else: 320 | self.model(x) 321 | return self.intermediates 322 | 323 | 324 | @abstractmethod 325 | def graphify(self): 326 | """ 327 | Abstract method. This function is implemented by your architecture graph file, and is what actually 328 | creates the graph for your model. 
329 | """ 330 | return NotImplemented 331 | 332 | -------------------------------------------------------------------------------- /graphs/transformer_enc_graph.py: -------------------------------------------------------------------------------- 1 | from graphs.base_graph import BIGGraph, NodeType 2 | import torch 3 | 4 | class TransformerEncoderGraph(BIGGraph): 5 | 6 | def __init__(self, model, 7 | modules, 8 | layer_name='', # for transformer 9 | enc_prefix='encoder', 10 | merge_type='ff_only', 11 | num_layers=12, 12 | num_heads=8, 13 | qk=False, 14 | name='bert', 15 | classifier=False): 16 | super().__init__(model) 17 | 18 | self.layer_name = layer_name 19 | self.enc_prefix = enc_prefix 20 | self.merge_type = merge_type 21 | self.num_layers = num_layers 22 | self.num_heads = num_heads 23 | self.modules = modules 24 | self.qk = qk 25 | self.name = name 26 | self.classifier = classifier 27 | 28 | 29 | def add_layerblock_nodes(self, name_prefix, input_node, merge_type): 30 | # first half 31 | modules = self.modules 32 | # do attention block here 33 | residual = input_node 34 | value_node = self.add_nodes_from_sequence(name_prefix, [modules['v']], residual) 35 | if self.qk: 36 | key_node = self.add_nodes_from_sequence(name_prefix, [modules['k'], NodeType.POSTFIX], residual) 37 | input_node = self.add_nodes_from_sequence(name_prefix, [modules['q'], NodeType.POSTFIX, NodeType.SUM], residual) 38 | else: 39 | key_node = self.add_nodes_from_sequence(name_prefix, [modules['k']], residual) 40 | input_node = self.add_nodes_from_sequence(name_prefix, [modules['q'], NodeType.SUM], residual) 41 | self.add_directed_edge(key_node, input_node) # add key to "SUM" - it is really just a product but same handler 42 | input_node = self.add_nodes_from_sequence(name_prefix, [NodeType.SUM], input_node) #sum (mult)node to outproj 43 | self.add_directed_edge(value_node, input_node) #value node to sum (mult) 44 | 45 | if merge_type == 'ff_only': 46 | # add self attn out proj to dot prod, layer norm, sum residual 47 | input_node = self.add_nodes_from_sequence(name_prefix, 48 | [modules['lin_attn'], NodeType.SUM], 49 | input_node) 50 | # add & norm 51 | self.add_directed_edge(residual, input_node) 52 | input_node = self.add_nodes_from_sequence(name_prefix, [modules['attn_ln']], input_node=input_node) 53 | 54 | # do second half with residual too 55 | residual = input_node 56 | input_node = self.add_nodes_from_sequence(name_prefix, 57 | [modules['fc1'], NodeType.PREFIX, modules['fc2'], NodeType.SUM], 58 | input_node=input_node) 59 | self.add_directed_edge(residual, input_node) 60 | 61 | if merge_type == 'res_only': 62 | # add self attn out proj to dot prod, layer norm, sum residual 63 | input_node = self.add_nodes_from_sequence(name_prefix, 64 | [modules['lin_attn'], NodeType.SUM], 65 | input_node) 66 | # add & norm 67 | self.add_directed_edge(residual, input_node) 68 | input_node = self.add_nodes_from_sequence(name_prefix, [modules['attn_ln'], NodeType.POSTFIX], input_node=input_node) 69 | 70 | # do second half with residual too 71 | residual = input_node 72 | input_node = self.add_nodes_from_sequence(name_prefix, 73 | [modules['fc1'], modules['fc2'], NodeType.SUM], 74 | input_node=input_node) 75 | self.add_directed_edge(residual, input_node) 76 | 77 | elif merge_type == 'ff+res': 78 | # add self attn out proj to dot prod, layer norm, sum residual 79 | # get first residual vector from after self attn layer norm 80 | input_node = self.add_nodes_from_sequence(name_prefix, 81 | [modules['lin_attn'], NodeType.SUM], 82 | 
input_node) 83 | # add & norm 84 | self.add_directed_edge(residual, input_node) 85 | input_node = self.add_nodes_from_sequence(name_prefix, [modules['attn_ln'], NodeType.POSTFIX], input_node=input_node) 86 | 87 | # do second half with residual too 88 | residual = input_node 89 | input_node = self.add_nodes_from_sequence(name_prefix, 90 | [modules['fc1'], NodeType.PREFIX, modules['fc2'], NodeType.SUM], 91 | input_node=input_node) 92 | self.add_directed_edge(residual, input_node) 93 | 94 | elif merge_type == 'ff+attn': 95 | # add self attn out proj to dot prod, layer norm, sum residual 96 | # get intermeds between attn and self attn out proj 97 | input_node = self.add_nodes_from_sequence(name_prefix, 98 | [NodeType.PREFIX, modules['lin_attn'], NodeType.SUM], 99 | input_node) 100 | # add & norm 101 | self.add_directed_edge(residual, input_node) 102 | input_node = self.add_nodes_from_sequence(name_prefix, [modules['attn_ln']], input_node=input_node) 103 | 104 | # do second half with residual too 105 | residual = input_node 106 | input_node = self.add_nodes_from_sequence(name_prefix, 107 | [modules['fc1'], NodeType.PREFIX, modules['fc2'], NodeType.SUM], 108 | input_node=input_node) 109 | self.add_directed_edge(residual, input_node) 110 | 111 | elif merge_type == 'attn_only': 112 | # add self attn out proj to dot prod, layer norm, sum residual 113 | # get intermeds between attn and self attn out proj 114 | input_node = self.add_nodes_from_sequence(name_prefix, 115 | [NodeType.PREFIX, modules['lin_attn'], NodeType.SUM], 116 | input_node) 117 | # add & norm 118 | self.add_directed_edge(residual, input_node) 119 | input_node = self.add_nodes_from_sequence(name_prefix, [modules['attn_ln']], input_node=input_node) 120 | 121 | # do second half with residual too 122 | residual = input_node 123 | input_node = self.add_nodes_from_sequence(name_prefix, 124 | [modules['fc1'], modules['fc2'], NodeType.SUM], 125 | input_node=input_node) 126 | self.add_directed_edge(residual, input_node) 127 | 128 | elif merge_type == 'res+attn': 129 | # add self attn out proj to dot prod, layer norm, sum residual 130 | # get intermeds between attn and self attn out proj 131 | input_node = self.add_nodes_from_sequence(name_prefix, 132 | [NodeType.PREFIX, modules['lin_attn'], NodeType.SUM], 133 | input_node) 134 | # add & norm 135 | self.add_directed_edge(residual, input_node) 136 | input_node = self.add_nodes_from_sequence(name_prefix, [modules['attn_ln'], NodeType.POSTFIX], input_node=input_node) 137 | 138 | # do second half with residual too 139 | residual = input_node 140 | input_node = self.add_nodes_from_sequence(name_prefix, 141 | [modules['fc1'], modules['fc2'], NodeType.SUM], 142 | input_node=input_node) 143 | self.add_directed_edge(residual, input_node) 144 | 145 | 146 | elif merge_type == 'all': 147 | # add self attn out proj to dot prod, layer norm, sum residual 148 | # get intermeds between attn and self attn out proj 149 | # get first residual vector from after self attn layer norm 150 | input_node = self.add_nodes_from_sequence(name_prefix, 151 | [NodeType.PREFIX, modules['lin_attn'], NodeType.SUM], 152 | input_node) 153 | # add & norm 154 | self.add_directed_edge(residual, input_node) 155 | input_node = self.add_nodes_from_sequence(name_prefix, [modules['attn_ln'], NodeType.POSTFIX], input_node=input_node) 156 | 157 | # do second half with residual too 158 | residual = input_node 159 | input_node = self.add_nodes_from_sequence(name_prefix, 160 | [modules['fc1'], NodeType.PREFIX, modules['fc2'], NodeType.SUM], 
161 | input_node=input_node) 162 | self.add_directed_edge(residual, input_node) 163 | 164 | if merge_type in ['all', 'ff+res', 'res_only', 'res+attn']: 165 | input_node = self.add_nodes_from_sequence(name_prefix, [modules['final_ln'], NodeType.POSTFIX], input_node=input_node) 166 | else: 167 | input_node = self.add_nodes_from_sequence(name_prefix, [modules['final_ln']], input_node=input_node) 168 | return input_node 169 | 170 | def add_layer_nodes(self, layer_prefix, input_node, merge_type): 171 | source_node = input_node 172 | 173 | for layer_index in range(self.num_layers): # for graph visualization 174 | #for layer_index, layerblock in enumerate(self.get_module(name_prefix)): 175 | source_node = self.add_layerblock_nodes(layer_prefix+f'.{layer_index}', source_node, merge_type) 176 | return source_node 177 | 178 | def graphify(self): 179 | modules = self.modules 180 | # keep input node 181 | input_node = self.create_node(node_type=NodeType.INPUT) 182 | # input_node -> emb_tok 183 | emb_name = modules['emb'] 184 | emb_node = self.create_node(node_type=NodeType.EMBEDDING, 185 | layer_name=f'{self.enc_prefix}.{emb_name}'.strip('.'), 186 | param_name=f'{self.enc_prefix}.{emb_name}.weight'.strip('.')) 187 | self.add_directed_edge(input_node, emb_node) 188 | 189 | # removing emb_pos node for now... 190 | input_node = self.add_nodes_from_sequence(self.enc_prefix, [modules['emb_ln']], emb_node) 191 | 192 | if self.merge_type in ['all', 'ff+res', 'res_only']: 193 | #adding postfix to emb_ln, before xformer layers 194 | input_node = self.add_nodes_from_sequence(self.enc_prefix, [NodeType.POSTFIX], input_node) 195 | 196 | # layernorm_embedding -> xformer layers 197 | input_node = self.add_layer_nodes(f'{self.layer_name}', input_node, self.merge_type) 198 | 199 | # xformer layers -> dense -> layernorm -> output 200 | if self.name == 'bert' and self.classifier == False: 201 | dense_node = self.add_nodes_from_sequence(modules['head_pref'], ['transform.dense', 'transform.LayerNorm', NodeType.PREFIX, 'decoder'], input_node) 202 | output_node = self.create_node(node_type=NodeType.OUTPUT) 203 | self.add_directed_edge(dense_node, output_node) 204 | elif self.name == 'bert' and self.classifier == True: 205 | pool_node = self.add_nodes_from_sequence(self.enc_prefix, [modules['pooler']], input_node) 206 | class_node = self.add_nodes_from_sequence('', [NodeType.PREFIX, modules['classifier']], pool_node) 207 | output_node = self.create_node(node_type=NodeType.OUTPUT) 208 | self.add_directed_edge(class_node, output_node) 209 | elif self.name == 'roberta': 210 | #dense_node = self.add_nodes_from_sequence(modules['head_pref'], ['dense', NodeType.PREFIX, 'out_proj'], input_node) 211 | output_node = self.create_node(node_type=NodeType.OUTPUT) 212 | self.add_directed_edge(input_node, output_node) 213 | 214 | return self 215 | 216 | 217 | def bert(model, merge_type='ff_only', qk=False, classifier=False): 218 | modules = {'emb': 'embeddings.word_embeddings', 219 | 'emb_pos': 'embeddings.position_embeddings', 220 | 'emb_tok_type': 'embeddings.token_type_embeddings', 221 | 'emb_ln': 'embeddings.LayerNorm', 222 | 'q': 'attention.self.query', 223 | 'k': 'attention.self.key', 224 | 'v': 'attention.self.value', 225 | 'lin_attn': 'attention.output.dense', 226 | 'attn_ln': 'attention.output.LayerNorm', 227 | 'fc1': 'intermediate.dense', 228 | 'fc2': 'output.dense', 229 | 'final_ln': 'output.LayerNorm', 230 | 'head_pref': 'cls.predictions', 231 | 'pooler': 'pooler.dense', 232 | 'classifier': 'classifier'} 233 | return 
TransformerEncoderGraph(model,
234 |                                    modules,
235 |                                    layer_name='bert.encoder.layer',
236 |                                    enc_prefix='bert',
237 |                                    merge_type=merge_type,
238 |                                    num_layers=12,
239 |                                    num_heads=12,
240 |                                    qk=qk,
241 |                                    name='bert',
242 |                                    classifier=classifier)
243 | 
244 | 
245 | 
246 | '''
247 | Checks whether two state_dicts are the same. Used for debugging purposes.
248 | reference: https://gist.github.com/rohan-varma/a0a75e9a0fbe9ccc7420b04bff4a7212
249 | '''
250 | def validate_state_dicts(model_state_dict_1, model_state_dict_2):
251 |     if len(model_state_dict_1) != len(model_state_dict_2):
252 |         print(
253 |             f"Length mismatch: {len(model_state_dict_1)}, {len(model_state_dict_2)}"
254 |         )
255 |         return False
256 | 
257 |     # Replicated modules have "module" prefixed to their keys, so strip these off when comparing to the local model.
258 |     if next(iter(model_state_dict_1.keys())).startswith("module"):
259 |         model_state_dict_1 = {
260 |             k[len("module") + 1 :]: v for k, v in model_state_dict_1.items()
261 |         }
262 | 
263 |     if next(iter(model_state_dict_2.keys())).startswith("module"):
264 |         model_state_dict_2 = {
265 |             k[len("module") + 1 :]: v for k, v in model_state_dict_2.items()
266 |         }
267 | 
268 |     for ((k_1, v_1), (k_2, v_2)) in zip(
269 |         model_state_dict_1.items(), model_state_dict_2.items()
270 |     ):
271 |         if k_1 != k_2:
272 |             print(f"Key mismatch: {k_1} vs {k_2}")
273 |             return False
274 |         # move both tensors to the same device before comparing
275 |         device = "cuda:0" if torch.cuda.is_available() else "cpu"
276 |         v_1 = v_1.to(device)
277 |         v_2 = v_2.to(device)
278 | 
279 |         if not torch.allclose(v_1, v_2, atol=1e-03):
280 |             print(k_1)
281 |             print(f"Tensor mismatch: {v_1} vs {v_2}")
282 |             return False
283 | 
284 |     return True
285 | 
-------------------------------------------------------------------------------- /matching_functions.py: --------------------------------------------------------------------------------
1 | import torch
2 | import scipy.optimize
3 | import sys
4 | import numpy as np
5 | import pdb
6 | from collections import defaultdict
7 | 
8 | #####################################################################################################################################
9 | ############################################################## HELPERS ##############################################################
10 | #####################################################################################################################################
11 | 
12 | def remove_col(x, idx):
13 |     return torch.cat([x[:, :idx], x[:, idx+1:]], dim=-1)
14 | 
15 | def compute_correlation(covariance, eps=1e-7):
16 |     covariance = torch.nan_to_num(covariance)
17 |     std = torch.diagonal(covariance).sqrt()  # there can be some infs in the covariance matrix
18 |     covariance = covariance / torch.clamp(torch.nan_to_num(torch.outer(std, std)), min=eps)
19 |     return covariance
20 | 
21 | 
22 | #####################################################################################################################################
23 | #################################################### MATCHING/ALIGNMENT FUNCTIONS ###################################################
24 | #####################################################################################################################################
25 | 
26 | 
27 | def match_tensors_permute(r=.5, get_merge_value=False,
28 |                           print_costs=False, no_absval=False,
29 |                           correlation_matrix=None):
30 |     """
31 |     This function is adapted from ZipIt!
(https://github.com/gstoica27/ZipIt) 32 | 33 | Matches arbitrary models by permuting all to the spaces of the first in your graph list. 34 | Mimics Rebasin methods. 35 | """ 36 | 37 | correlation = correlation_matrix 38 | 39 | O = correlation.shape[0] 40 | N = int(1/(1 - r) + 0.5) 41 | Om = O // N 42 | device = correlation.device 43 | 44 | mats = [torch.eye(Om, device=device)] 45 | cost = 0 46 | for i in range(1, N): 47 | try: 48 | corr_matrix = correlation[:Om, Om*i:Om*(i+1)].cpu().numpy() 49 | if no_absval == False: 50 | corr_matrix = np.absolute(corr_matrix) 51 | row_ind, col_ind = scipy.optimize.linear_sum_assignment( 52 | corr_matrix, maximize=True) 53 | cost = corr_matrix[row_ind, col_ind].sum() 54 | # correlation subset is is [0:4096, 4096:8192] 55 | # correlation between the first graph's and second graph's features 56 | except: 57 | pdb.set_trace() 58 | 59 | new_mat = torch.eye(Om, device=device)[torch.tensor(col_ind).long().to(device)] 60 | mats.append(new_mat.T) 61 | 62 | unmerge_mats = mats 63 | 64 | unmerge = torch.cat(unmerge_mats, dim=0) 65 | merge = torch.cat(mats, dim=0) 66 | merge = merge / (merge.sum(dim=0, keepdim=True) + 1e-5) 67 | if get_merge_value: 68 | merge_value = correlation[:Om, Om*i:Om*(i+1)].cpu().numpy()[row_ind, col_ind].mean() 69 | return merge.T, unmerge, merge_value 70 | if print_costs: 71 | cost = cost / merge.shape[0] 72 | print(f'cost: {cost}') 73 | 74 | return merge.T, unmerge, None, cost / merge.shape[0] 75 | 76 | def match_tensors_permute_MHA(n_heads, permute_heads=False, 77 | head_assignments=[], r=.5, get_merge_value=False, 78 | print_costs=False, no_absval=False, 79 | correlation_matrix=None): 80 | """ 81 | Handles different head permutations in attention 82 | """ 83 | correlation = correlation_matrix 84 | 85 | O = correlation.shape[0] 86 | 87 | N = int(1/(1 - r) + 0.5) # num models 88 | Om = O // N # matrix dimension 89 | device = correlation.device 90 | query_size = Om // n_heads 91 | 92 | mats = [torch.eye(Om, device=device)] 93 | head_perms = [] 94 | 95 | # compute head perms in order 96 | if permute_heads == False: 97 | cost = 0 98 | for i in range(1, N): #just once if 2 models] 99 | for j in range(n_heads): 100 | try: 101 | # by head 102 | corr_submatrix = correlation[query_size * j:query_size * (j+1), Om*i + query_size*j:Om*i + query_size*(j+1)].cpu().numpy() 103 | if no_absval == False: 104 | corr_submatrix = np.absolute(corr_submatrix) 105 | row_ind, col_ind = scipy.optimize.linear_sum_assignment(corr_submatrix, maximize=True) 106 | 107 | 108 | head_perms.append(torch.tensor(col_ind + j*query_size)) 109 | cost += corr_submatrix[row_ind, col_ind].sum() 110 | 111 | # for whole model correlation subset is is [0:4096, 4096:8192] 112 | # correlation between the first graph's and second graph's features 113 | except: 114 | pdb.set_trace() 115 | outer_col_ind = np.arange(n_heads) 116 | # compute head perms out of order according to predefined ordering or find our own 117 | elif permute_heads == True: 118 | cost = 0 119 | col_inds_storage = defaultdict(lambda: defaultdict(int)) 120 | if head_assignments != []: 121 | outer_row_ind = np.arange(n_heads) 122 | outer_col_ind = head_assignments 123 | for i in range(n_heads): 124 | head1_idx = [query_size * outer_row_ind[i], query_size * (outer_row_ind[i] + 1)] 125 | head2_idx = [Om + query_size * outer_col_ind[i], Om + query_size * (outer_col_ind[i] + 1)] 126 | # take abs value of submatrix of correlations 127 | corr_submatrix = correlation[head1_idx[0]:head1_idx[1], 
head2_idx[0]:head2_idx[1]].cpu().numpy() 128 | if no_absval == False: 129 | corr_submatrix = np.absolute(corr_submatrix) 130 | # compute perm for head j & head k 131 | row_ind, col_ind = scipy.optimize.linear_sum_assignment(corr_submatrix, maximize=True) 132 | 133 | cost += corr_submatrix[row_ind, col_ind].sum() 134 | col_inds_storage[outer_row_ind[i]][outer_col_ind[i]] = col_ind 135 | 136 | else: 137 | costs = np.ones((n_heads, n_heads)) * -sys.maxsize # cost matrix for hungarian algo steps 138 | for i in range(1, N): #just once if 2 models 139 | for j in range(n_heads): # outer loop through all heads 140 | for k in range(n_heads): # inner loop through heads >= current head j 141 | head1_idx = [query_size * j, query_size * (j+1)] 142 | head2_idx = [Om * i + query_size * k, Om * i + query_size * (k+1)] 143 | 144 | # take abs value of submatrix of correlations 145 | corr_submatrix = correlation[head1_idx[0]:head1_idx[1], head2_idx[0]:head2_idx[1]].cpu().numpy() 146 | if no_absval == False: 147 | corr_submatrix = np.absolute(corr_submatrix) 148 | 149 | # compute perm for head j & head k 150 | row_ind, col_ind = scipy.optimize.linear_sum_assignment(corr_submatrix, maximize=True) 151 | 152 | # store cost (cost is maximized here) 153 | costs[j,k] = corr_submatrix[row_ind, col_ind].sum() 154 | #costs[k,j] = costs[j,k] # make symmetric 155 | 156 | # store perm so we don't have to recompute it later 157 | col_inds_storage[j][k] = col_ind 158 | 159 | 160 | outer_row_ind, outer_col_ind = scipy.optimize.linear_sum_assignment(costs, maximize=True) # get assignment with lowest cost 161 | cost += costs[outer_row_ind, outer_col_ind].sum() 162 | 163 | for j in range(n_heads): 164 | head_1 = outer_row_ind[j] # these are in order, outer_row_ind[j] = j 165 | head_2 = outer_col_ind[j] 166 | 167 | head_perm = col_inds_storage[head_1][head_2] 168 | head_perms.append(torch.tensor(head_perm + query_size*head_2)) 169 | 170 | new_mat = torch.eye(Om, device=device)[torch.tensor(torch.cat(head_perms)).long().to(device)] 171 | mats.append(new_mat.T) 172 | 173 | unmerge_mats = mats 174 | 175 | unmerge = torch.cat(unmerge_mats, dim=0) 176 | merge = torch.cat(mats, dim=0) 177 | merge = merge / (merge.sum(dim=0, keepdim=True) + 1e-5) 178 | if print_costs: 179 | cost = cost / merge.shape[0] 180 | print(f'cost: {cost}') 181 | if get_merge_value: 182 | merge_value = correlation[:Om, Om*i:Om*(i+1)].cpu().numpy()[row_ind, col_ind].mean() 183 | return merge.T, unmerge, merge_value 184 | return merge.T, unmerge, outer_col_ind, cost / merge.shape[0] 185 | 186 | def match_tensors_identity(r=.5, correlation_matrix=None, **kwargs): 187 | # weight averaging. 
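    # What this reduces to: every per-model transform below is the identity, so the
    # normalized merge matrix simply averages the models' features, i.e. vanilla
    # parameter averaging with no permutation applied. The returned cost is the
    # normalized trace of the cross-model correlation block: the correlation that
    # plain averaging attains without any alignment.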
188 | 189 | correlation = correlation_matrix 190 | O = correlation.shape[0] 191 | 192 | N = int(1/(1 - r) + 0.5) 193 | Om = O // N 194 | device = correlation.device 195 | corr_matrix = correlation[:Om, Om:Om*2].cpu().numpy() 196 | cost = corr_matrix.trace() 197 | 198 | mats = [torch.eye(Om, device=device) for _ in range(N)] 199 | 200 | unmerge_mats = mats 201 | 202 | unmerge = torch.cat(unmerge_mats, dim=0) 203 | merge = torch.cat(mats, dim=0) 204 | merge = merge / (merge.sum(dim=0, keepdim=True) + 1e-5) 205 | cost = cost / merge.shape[0] 206 | return merge.T, unmerge, None, cost 207 | 208 | -------------------------------------------------------------------------------- /metric_calculators.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from abc import ABC, abstractmethod 3 | import pdb 4 | 5 | class MetricCalculator(ABC): 6 | 7 | @abstractmethod 8 | def update(self, batch_size, dx, *feats, **aux_params): return NotImplemented 9 | 10 | @abstractmethod 11 | def finalize(self): return NotImplemented 12 | 13 | def compute_correlation(covariance, eps=1e-7): 14 | std = torch.diagonal(covariance).sqrt() 15 | covariance = covariance / (torch.clamp(torch.outer(std, std), min=eps)) 16 | return covariance 17 | 18 | class CovarianceMetric(MetricCalculator): 19 | name = 'covariance' 20 | 21 | def __init__(self): 22 | self.std = None 23 | self.mean = None 24 | self.outer = None 25 | self.sos = None 26 | self.num_updates = 0 27 | self.bsz = None 28 | 29 | def update(self, batch_size, *feats, **aux_params): 30 | feats = torch.cat(feats, dim=0) #[dim, num_elts] 31 | feats = torch.nan_to_num(feats, 0, 0, 0) 32 | 33 | mean = feats.sum(dim=1) 34 | sos = (feats**2).sum(dim=1) # sum of squares 35 | outer = (feats @ feats.T) 36 | 37 | if self.bsz == None: 38 | self.bsz = batch_size 39 | 40 | if self.mean is None: self.mean = torch.zeros_like( mean, dtype=torch.float64) 41 | if self.outer is None: self.outer = torch.zeros_like(outer, dtype=torch.float64) 42 | if self.sos is None: self.sos = torch.zeros_like(sos, dtype=torch.float64) 43 | 44 | self.mean += mean 45 | self.outer += outer 46 | self.sos += sos 47 | 48 | # debugging 49 | self.num_updates +=1 50 | 51 | def finalize(self, numel, eps=1e-4, dot_prod=False, pca=False, scale_cov=False, normalize=False, print_featnorms=False): 52 | self.outer = self.outer.div(numel) 53 | self.mean = self.mean.div(numel) 54 | self.sos = torch.sqrt(self.sos) 55 | #scaling_factor = 1.0 / self.bsz 56 | 57 | if dot_prod: 58 | # this is equivalent to E_ab from git rebasin 59 | cov = self.outer #* scaling_factor 60 | else: 61 | cov = self.outer - torch.outer(self.mean, self.mean) 62 | if scale_cov: 63 | cov *= 1.0 / self.bsz 64 | if pca: 65 | new_val = int(0.95 * cov.shape[1]) 66 | U,S,V = torch.pca_lowrank(cov, q=new_val) 67 | cov = U[:,:new_val] @ torch.diag(S[:new_val]) @ V.T 68 | if normalize: 69 | cov = cov / (torch.outer(self.sos, self.sos) + eps) 70 | 71 | if print_featnorms: 72 | len_feats = len(self.sos) // 2 73 | mean1 = torch.mean(self.sos[:len_feats]).item() 74 | std1 = torch.std(self.sos[:len_feats]).item() 75 | mean2 = torch.mean(self.sos[len_feats:]).item() 76 | std2 = torch.std(self.sos[len_feats:]).item() 77 | print(mean1, std1, mean2, std2) 78 | if torch.isnan(cov).any(): 79 | breakpoint() 80 | if (torch.diagonal(cov) < 0).sum(): 81 | pdb.set_trace() 82 | return cov 83 | 84 | class MeanMetric(MetricCalculator): 85 | name = 'mean' 86 | 87 | def __init__(self): 88 | self.mean = None 89 | 90 | def 
update(self, batch_size, *feats, **aux_params): 91 | feats = torch.cat(feats, dim=0) 92 | mean = feats.abs().mean(dim=1) 93 | if self.mean is None: 94 | self.mean = torch.zeros_like(mean) 95 | self.mean += mean * batch_size 96 | 97 | def finalize(self, numel, eps=1e-4, print_featnorms=False): 98 | return self.mean / numel 99 | 100 | 101 | def get_metric_fns(names): 102 | metrics = {} 103 | for name in names: 104 | if name == 'mean': 105 | metrics[name] = MeanMetric 106 | elif name == 'covariance': 107 | metrics[name] = CovarianceMetric 108 | else: 109 | raise NotImplementedError(name) 110 | return metrics -------------------------------------------------------------------------------- /model_merger.py: -------------------------------------------------------------------------------- 1 | import os 2 | import glob 3 | import numpy as np 4 | import torch 5 | 6 | from time import time 7 | from tqdm import tqdm 8 | from collections import defaultdict 9 | from copy import deepcopy 10 | from torch import nn 11 | 12 | from torchnlp.utils import lengths_to_mask 13 | from graphs.base_graph import NodeType 14 | from metric_calculators import CovarianceMetric, MeanMetric 15 | from matching_functions import match_tensors_permute 16 | from matching_functions import compute_correlation 17 | from utils import get_merging_fn, contains_name 18 | 19 | 20 | class MergeHandler: 21 | def __init__(self, graph, merge, unmerge, orig): 22 | self.graph = graph 23 | # just store merge and unmerge matrices 24 | self.orig = orig 25 | self.merge = merge 26 | self.unmerge = unmerge 27 | 28 | class ModelMerge(nn.Module): 29 | 30 | def __init__(self, *graphs, device=0): 31 | super().__init__() 32 | 33 | self.hooks = [] 34 | self.init(graphs, device) 35 | 36 | def init(self, graphs, device): 37 | 38 | # move all graph models to eval 39 | for g in graphs: 40 | g.model.to(device).eval() 41 | 42 | self.graphs = graphs 43 | self.device = device 44 | self.merged_model = None 45 | count = 0 46 | for graph in self.graphs: 47 | print(count) 48 | count+=1 49 | graph.add_hooks(device=device) 50 | 51 | # helper function to collect hiddens. 
Do not recommend using for large FractionalDataloader 52 | def get_hiddens(self, dataloaders): 53 | data_stores = [defaultdict(lambda: None) for g in self.graphs] 54 | 55 | with torch.no_grad(): 56 | for dataloader in dataloaders: 57 | for x, _ in tqdm(dataloader, desc="Forward Pass to Compute Merge Metrics: "): 58 | x = x.to(self.device) 59 | intermediates = [g.compute_intermediates(x) for g in self.graphs] # shape [feat_dim, num_tokens] 60 | nodes = list(intermediates[0].keys()) 61 | for node in nodes: 62 | intermeds_float = [i[node][:,1:-1].float().detach() for i in intermediates] # len = num_graphs 63 | if data_stores[0][node] == None: 64 | for i in range(len(self.graphs)): 65 | data_stores[i][node] = intermeds_float[i] 66 | else: 67 | for i in range(len(self.graphs)): 68 | data_stores[i][node] = torch.cat((data_stores[i][node], intermeds_float[i]), 1) 69 | return data_stores 70 | 71 | # get average variance across features for each node 72 | def compute_variances(self, dataloaders): 73 | data_stores = self.get_hiddens(dataloaders) 74 | nodes = list(data_stores[0].keys()) 75 | 76 | for node in nodes: 77 | for i in range(len(self.graphs)): 78 | data_stores[i][node] = torch.mean(torch.var(data_stores[i][node], dim=1)) 79 | return data_stores 80 | 81 | # for investigating representations b/w two models 82 | def compute_rep_distances(self, dataloaders): 83 | node_dists = [] 84 | data_stores = self.get_hiddens(dataloaders) 85 | nodes = list(data_stores[0].keys()) 86 | 87 | for node in nodes: 88 | x = data_stores[0][node] # shape [feat_dim, num_tokens] 89 | y = data_stores[1][node] # shape [feat_dim, num_tokens] 90 | dists = (x - y).pow(2).sum(0).sqrt() 91 | node_dists.append(torch.mean(dists)) 92 | return node_dists 93 | 94 | def remove_pads(self, intermediates, input, lens, sentence_level, special_toks): 95 | pad_len = input.shape[1] 96 | bsz = input.shape[0] 97 | 98 | for g_idx in range(len(self.graphs)): 99 | for node in list(intermediates[0].keys()): 100 | # if this is the final node with cls vectors only. 
101 | # in this case, the size is just bsz, aka one cls vector per item 102 | if intermediates[g_idx][node].shape[1] == bsz: 103 | # do nothing: 104 | continue 105 | # if not cls vectors, we need to remove padding tokens 106 | else: 107 | tensor_to_edit = intermediates[g_idx][node] # shape [feat_dim, num_tokens] 108 | # plus and minus one are to account for bos/eos tokens 109 | if sentence_level == 'cls': 110 | list_of_tensors = [tensor_to_edit[:,(i*pad_len+1):(i*pad_len+2)] for i in range(bsz)] 111 | else: 112 | if special_toks == False: 113 | list_of_tensors = [tensor_to_edit[:,(i*pad_len+1):(i*pad_len+lens[i]-1)] for i in range(bsz)] 114 | else: 115 | list_of_tensors = [tensor_to_edit[:,(i*pad_len):(i*pad_len+lens[i])] for i in range(bsz)] 116 | new_tensor = torch.cat(list_of_tensors, dim=1) 117 | intermediates[g_idx][node] = new_tensor 118 | return intermediates 119 | 120 | def load_toks(self, saved_path): 121 | filenames = glob.glob(os.path.join(saved_path, 'toks', '*.pt')) 122 | num_files = len(filenames) 123 | tok_ids_all = [] 124 | for i in range(num_files): 125 | tok_ids_all.append(torch.load(f'{saved_path}/toks/{i}.pt')) 126 | return torch.cat(tok_ids_all) 127 | 128 | def sent_rep(self, intermediates, node, sentence_level, lens, special_toks=False): 129 | # already shape [feat_dim, bsz] 130 | if sentence_level == 'cls': 131 | return [intermediates[i][node] for i in range(len(self.graphs))] 132 | bsz = len(lens) 133 | intermeds_float = [] 134 | 135 | if intermediates[0][node].shape[-1] == len(lens): #bsz 136 | intermeds_float = [intermediates[0][node], intermediates[1][node]] 137 | return intermeds_float 138 | for g_idx in range(len(self.graphs)): 139 | sent_levels = [] 140 | last_idx = 0 141 | for senlen in lens: 142 | actual_len = senlen 143 | if special_toks == False: 144 | actual_len = senlen - 2 145 | sent_levels.append(intermediates[g_idx][node][:,last_idx:last_idx + actual_len]) 146 | last_idx += actual_len 147 | if sentence_level == 'maxpool': 148 | try: 149 | sent_avgs = [torch.amax(sent_levels[i].float(), 1).unsqueeze(1) for i in range(bsz)] 150 | except: 151 | breakpoint() 152 | elif sentence_level == 'avgpool': 153 | sent_avgs = [torch.mean(sent_levels[i].float(), 1).unsqueeze(1) for i in range(bsz)] 154 | intermeds_float.append(torch.hstack(sent_avgs)) # list of [[dim, bsz], [dim, bsz]] 155 | return intermeds_float 156 | 157 | def compute_metrics(self, dataloader, metric_classes, sentence_level=None, special_toks=False, 158 | print_featnorms=False): 159 | 160 | self.metrics = None 161 | if not isinstance(dataloader, list): 162 | dataloader_list = [dataloader] 163 | else: 164 | dataloader_list = dataloader 165 | 166 | numel = 0 167 | for dataloader in dataloader_list: 168 | for x, lens in tqdm(dataloader, desc="Forward Pass to Compute Merge Metrics: "): 169 | 170 | # load batch & track number of elements 171 | x = x.to(self.device) 172 | if sentence_level != None: 173 | numel_local = x.shape[0] 174 | else: 175 | numel_local = sum(lens) 176 | if type(numel_local) != int: 177 | numel_local = numel_local.item() 178 | if special_toks == False: 179 | numel_local -= 2*x.shape[0] # num tokens - BOS/EOS toks 180 | numel += numel_local 181 | 182 | # get intermediates and remove padding idxs 183 | if 'Bert' in type(self.graphs[0].model).__name__: 184 | attn_mask = lengths_to_mask(list(lens)) 185 | intermediates = [g.compute_intermediates(x, attn_mask=attn_mask.long().to(self.device)) for g in self.graphs] # shape [feat_dim, num_tokens] 186 | else: 187 | intermediates = 
[g.compute_intermediates(x) for g in self.graphs] # shape [feat_dim, num_tokens] 188 | intermediates = self.remove_pads(intermediates, x, lens, sentence_level, special_toks) 189 | nodes = list(intermediates[0].keys()) 190 | 191 | # if qk flag is on, add qk node placeholders for each layer 192 | qk_flag = False 193 | if self.graphs[0].qk == True: 194 | for i in range(self.graphs[0].num_layers): 195 | nodes.append(f'qk{i}') 196 | qk_flag = True 197 | 198 | # populate metrics list 199 | if self.metrics is None: 200 | self.metrics = {n: {k: v() for k, v in metric_classes.items()} for n in nodes} 201 | 202 | # special cases nodes (just q, k) 203 | special_cases_names = ['q', 'k'] 204 | special_cases_nodes = [self.graphs[0].modules[name] for name in special_cases_names] 205 | qk_nodes = [self.graphs[0].modules[name] for name in ['q', 'k']] 206 | 207 | for node, node_metrics in self.metrics.items(): 208 | if isinstance(node, int): 209 | prev_node_layer = self.graphs[0].get_node_info(node-1)['layer'] 210 | if prev_node_layer == None or not contains_name(prev_node_layer,special_cases_nodes): 211 | for metric in node_metrics.values(): 212 | if sentence_level != None: 213 | intermeds_float = self.sent_rep(intermediates, node, sentence_level, lens, special_toks) 214 | else: 215 | intermeds_float = [i[node].float().detach() for i in intermediates] # len = num_graphs 216 | metric.update(x.shape[0] , *intermeds_float) 217 | elif contains_name(prev_node_layer, qk_nodes): 218 | layer_no = [int(i) for i in self.graphs[0].get_node_info(node-1)['layer'].split('.') if i.isdigit()][0] 219 | if qk_flag: 220 | qk_metric = self.metrics[f'qk{layer_no}'] 221 | else: 222 | qk_metric = node_metrics 223 | for metric in qk_metric.values(): 224 | if sentence_level != None: 225 | intermeds_float = self.sent_rep(intermediates, node, sentence_level, lens, special_toks) 226 | else: 227 | intermeds_float = [i[node].float().detach() for i in intermediates] # len = num_graphs 228 | metric.update(x.shape[0], *intermeds_float) 229 | 230 | for node, node_metrics in self.metrics.items(): 231 | if isinstance(node, int): 232 | prev_node_layer = self.graphs[0].get_node_info(node-1)['layer'] 233 | if prev_node_layer == None or not contains_name(prev_node_layer,special_cases_nodes): 234 | for metric_name, metric in node_metrics.items(): 235 | self.metrics[node][metric_name] = metric.finalize(numel, print_featnorms=print_featnorms) 236 | if self.graphs[0].qk == True: 237 | for i in range(self.graphs[0].num_layers): 238 | for metric_name, metric in self.metrics[f'qk{i}'].items(): 239 | self.metrics[f'qk{i}'][metric_name] = metric.finalize(numel * 2, print_featnorms=print_featnorms) 240 | 241 | return self.metrics, None 242 | 243 | 244 | def save_features(self, dataloader, sentence_level=False, special_toks=False, 245 | save_feats=False, save_dir=None): 246 | 247 | self.metrics = None 248 | if not isinstance(dataloader, list): 249 | dataloader_list = [dataloader] 250 | else: 251 | dataloader_list = dataloader 252 | 253 | numel = 0 254 | if save_feats: 255 | tok_indices = [] 256 | feats = [defaultdict(list),defaultdict(list)] 257 | 258 | for dataloader in dataloader_list: 259 | batch_count = 0 260 | for x, lens in tqdm(dataloader, desc="Forward Pass to Compute Merge Metrics: "): 261 | 262 | # load batch & track element numbers 263 | x = x.to(self.device) 264 | if sentence_level != None: 265 | numel_local = x.shape[0] 266 | else: 267 | numel_local = sum(lens) 268 | if special_toks == False: 269 | numel_local -= 2*x.shape[0] # num tokens - 
BOS/EOS toks 270 | numel += numel_local 271 | 272 | # get intermediates and remove padding idxs 273 | if 'Bert' in type(self.graphs[0].model).__name__: 274 | attn_mask = lengths_to_mask(list(lens)) 275 | intermediates = [g.compute_intermediates(x, attn_mask=attn_mask.long().to(self.device)) for g in self.graphs] # shape [feat_dim, num_tokens] 276 | else: 277 | intermediates = [g.compute_intermediates(x) for g in self.graphs] # shape [feat_dim, num_tokens] 278 | intermediates = self.remove_pads(intermediates, x, lens, sentence_level, special_toks) 279 | 280 | # store intermediates 281 | nodes = list(intermediates[0].keys()) 282 | if save_feats: 283 | batch_tok_indices = x.flatten()[torch.argwhere(x.flatten() != 0)].squeeze().detach().cpu() 284 | tok_indices.append(batch_tok_indices) 285 | for node in nodes: 286 | feats[0][node].append(intermediates[0][node].detach().cpu()) 287 | feats[1][node].append(intermediates[1][node].detach().cpu()) 288 | 289 | # if big enough accumulation, save to file 290 | batch_count += 1 291 | if batch_count % 1000 == 0: 292 | num = batch_count // 1000 293 | if batch_count * 8 > 100000: 294 | return 295 | print(f'saving features {num}') 296 | with open(f'{save_dir}/toks/{num}.pt', 'wb+') as tok_out: 297 | torch.save(torch.cat(tok_indices), tok_out) #write toks 298 | tok_indices = [] # release memory 299 | for model_no in [0, 1]: 300 | with open(f'{save_dir}/feats_{model_no}/{num}.pt', 'wb+') as model_out: 301 | for node in feats[model_no].keys(): 302 | feats[model_no][node] = torch.cat(feats[model_no][node], dim=1) 303 | torch.save(feats[model_no], model_out) #write feats 304 | feats[model_no] = defaultdict(list) # release memory 305 | 306 | if save_feats: 307 | print('saving last batch') 308 | num = batch_count // 1000 + 1 309 | for model_no in [0, 1]: 310 | with open(f'{save_dir}/feats_{model_no}/{num}.pt', 'wb+') as model_out: 311 | for node in feats[model_no].keys(): 312 | feats[model_no][node] = torch.cat(feats[model_no][node], dim=1) 313 | torch.save(feats[model_no], model_out) 314 | feats[model_no] = defaultdict(list) 315 | with open(f'{save_dir}/toks/{num}.pt', 'wb+') as toks_out: 316 | tok_indices = torch.cat(tok_indices) 317 | torch.save(tok_indices, toks_out) 318 | print('finished saving features to file') 319 | return None, None 320 | 321 | ### HELPER FUNCTIONS FOR CORRELATIONS ### 322 | 323 | def compute_np_corr(self, X,Y): 324 | feats_concat = torch.cat((X.to('cpu'),Y.to('cpu'))).type(torch.float32) 325 | corr = np.corrcoef(feats_concat) 326 | corr = np.nan_to_num(corr) 327 | corr = (corr + corr.T) / 2 328 | np.fill_diagonal(corr, 1) 329 | return corr 330 | 331 | def compute_np_cov(self, X, Y): 332 | feats_concat = torch.cat((X.to('cpu'),Y.to('cpu'))).type(torch.float32) 333 | feats_concat = feats_concat - feats_concat.mean(dim=1)[:,None] 334 | cov = (feats_concat @ feats_concat.T).div(feats_concat.shape[1]) 335 | return cov 336 | 337 | def cov_to_corr(self, cov, no_corr=False): 338 | if no_corr == True: 339 | return cov 340 | std = torch.diagonal(cov).sqrt() 341 | corr = cov / (torch.clamp(torch.nan_to_num(torch.outer(std, std)),min=1e-7)) 342 | return corr 343 | 344 | def separate_res_nodes(self, nodes): 345 | resnodes = [] 346 | non_resnodes = [] 347 | for node in nodes: 348 | if self.graphs[0].get_node_info(node)['type'] == NodeType.POSTFIX: 349 | prev_node_info = self.graphs[0].get_node_info(node-1)['layer'] 350 | if ((self.graphs[0].modules['q'] in prev_node_info) or 351 | (self.graphs[0].modules['k'] in prev_node_info)): 352 | 
#non_resnodes.append(node) # this is a qk node 353 | continue 354 | else: 355 | resnodes.append(node) # all res keys are postfixes by design 356 | else: 357 | non_resnodes.append(node) 358 | return resnodes, non_resnodes 359 | 360 | 361 | # load certain number of saved feats 362 | def load_features(self, saved_path, num, res='first', total_num=10): 363 | filenames = glob.glob(os.path.join(saved_path, f'feats_{num}', '*.pt')) 364 | filenames = filenames[:total_num] 365 | feats_final = {} 366 | 367 | print('loading feats') 368 | for filename in tqdm(filenames): 369 | try: 370 | feats = torch.load(filename) 371 | except RuntimeError: 372 | continue 373 | 374 | # sort nodes by res or non-res 375 | resnodes, non_resnodes = self.separate_res_nodes(list(feats.keys())) 376 | 377 | # keep resnodes of interest only 378 | if res == 'first': 379 | res_keys_used = [resnodes[0]] 380 | elif res == 'last': 381 | res_keys_used = [resnodes[-1]] 382 | elif res == 'all': 383 | res_keys_used = resnodes 384 | elif res == 'sep': 385 | res_keys_used = resnodes 386 | elif res == 'none': 387 | res_keys_used = [] 388 | 389 | # go through non resnodes, and get features ready 390 | for node in non_resnodes: 391 | if node not in feats_final: 392 | feats_final[node] = feats[node] 393 | else: 394 | feats_final[node] = torch.cat([feats_final[node], feats[node]], dim=1) 395 | 396 | for node in res_keys_used: 397 | if node not in feats_final: 398 | feats_final[node] = feats[node] 399 | else: 400 | feats_final[node] = torch.cat([feats_final[node], feats[node]], dim=1) 401 | 402 | for key in resnodes: 403 | if key not in feats_final: 404 | feats_final[key] = [] 405 | return feats_final 406 | 407 | def compute_corrs(self, nodes, feats_0, feats_1, res='first'): 408 | corrs = {} 409 | 410 | resnodes, non_resnodes = self.separate_res_nodes(nodes) 411 | 412 | for node in tqdm(non_resnodes): 413 | if feats_0[node] != []: 414 | corrs[node] = torch.Tensor(self.compute_np_corr(feats_0[node], feats_1[node])) 415 | 416 | if res == 'first': 417 | resnode = resnodes[0] 418 | corrs['res'] = torch.Tensor(self.compute_np_corr(feats_0[resnode], feats_1[resnode])) 419 | elif res == 'last': 420 | resnode = resnodes[-1] 421 | corrs['res'] = torch.Tensor(self.compute_np_corr(feats_0[resnode], feats_1[resnode])) 422 | elif res == 'all': 423 | node = resnodes[0] 424 | cov = torch.Tensor(self.compute_np_cov(feats_0[node], feats_1[node])) 425 | for node in resnodes[1:]: 426 | cov += torch.Tensor(self.compute_np_cov(feats_0[node], feats_1[node])) 427 | cov /= len(resnodes) 428 | corrs['res'] = torch.Tensor(self.cov_to_corr(cov)) 429 | elif res == 'sep': 430 | for node in resnodes: 431 | corrs[node] = torch.Tensor(self.compute_np_corr(feats_0[node], feats_1[node])) 432 | # not handling 'none' case for now 433 | 434 | return corrs 435 | 436 | def compute_metric_corrs(self, nodes, res='first', no_corr=False, qk=False): 437 | corrs = {} 438 | resnodes, non_resnodes = self.separate_res_nodes(nodes) 439 | 440 | for node in tqdm(non_resnodes): 441 | corrs[node] = self.cov_to_corr(self.metrics[node]['covariance'], no_corr) 442 | 443 | if resnodes == []: 444 | return corrs 445 | if res == 'first': 446 | resnode = resnodes[0] 447 | corrs['res'] = self.cov_to_corr(self.metrics[resnode]['covariance'], no_corr=no_corr) 448 | elif res == 'last': 449 | resnode = resnodes[-1] 450 | corrs['res'] = self.cov_to_corr(self.metrics[resnode]['covariance'], no_corr=no_corr) 451 | elif res == 'all': 452 | node = resnodes[0] 453 | cov = self.metrics[node]['covariance'] 454 | 
for node in resnodes[1:]: 455 | cov += self.metrics[node]['covariance'] 456 | cov /= len(resnodes) 457 | corrs['res'] =self.cov_to_corr(cov, no_corr=no_corr) 458 | elif res == 'sep': 459 | for node in resnodes: 460 | corrs[node] = self.cov_to_corr(self.metrics[node]['covariance'], no_corr=no_corr) 461 | 462 | return corrs 463 | 464 | ### END HELPER FUNCTIONS FOR CORRELATIONS ### 465 | 466 | 467 | def compute_transformations(self, transform_fn, reduce_ratio=.5, permute_heads=False, 468 | ignore_heads=False, print_costs=False, no_absval=False, 469 | saved_features=None, res='first', 470 | no_corr=False,**kwargs): 471 | 472 | start_time = time() 473 | self.merges = {} 474 | self.unmerges = {} 475 | 476 | 477 | global_res_merge= None 478 | global_res_unmerge = None 479 | 480 | special_cases_names = ['final_ln', 'attn_ln', 'emb_ln', 'q', 'k'] 481 | special_cases_nodes = [self.graphs[0].modules[name] for name in special_cases_names] 482 | qk_nodes = [self.graphs[0].modules[name] for name in ['q', 'k']] 483 | 484 | cost_dict = {} 485 | 486 | if saved_features: 487 | feats_0 = self.load_features(saved_features, 0, res=res) 488 | feats_1 = self.load_features(saved_features, 1, res=res) 489 | nodes = list(feats_0.keys()) 490 | nodes.sort() 491 | print('computing corrs') 492 | corrs = self.compute_corrs(nodes, feats_0, feats_1, res=res) 493 | else: 494 | nodes = list(self.metrics.keys()) 495 | qk_flag = False 496 | if self.graphs[0].qk == True: 497 | qk_flag = True 498 | for i in range(self.graphs[0].num_layers): 499 | nodes.remove(f'qk{i}') 500 | nodes.sort() 501 | print('computing corrs') 502 | corrs = self.compute_metric_corrs(nodes, res=res, no_corr=no_corr, qk=qk_flag) 503 | 504 | # save all corrs to file to look at them. 505 | # breakpoint() 506 | # with open(f'corrs.pt', 'wb+') as corrs_out: 507 | # torch.save(corrs, corrs_out) 508 | 509 | # corrs has all nonres nodes & the one res node. 
Unless this is sep, then it has all nodes 510 | 511 | last_node = nodes[-1] 512 | for node in tqdm(nodes, desc="Computing transformations: "): 513 | prev_node_layer = self.graphs[0].get_node_info(node-1)['layer'] 514 | # skip metrics associated with residuals and qk if qk is true 515 | correlation_matrix = None 516 | if prev_node_layer == None or not contains_name(prev_node_layer,special_cases_nodes): 517 | if node in corrs: 518 | correlation_matrix = corrs[node] 519 | 520 | info = self.graphs[0].get_node_info(node) 521 | print(info) 522 | next_node_info = self.graphs[0].get_node_info(node+1)['layer'] 523 | 524 | # Handle Attention Merging 525 | if next_node_info != None and (self.graphs[0].modules['lin_attn'] in next_node_info): 526 | layer_no = [int(i) for i in self.graphs[0].get_node_info(node+1)['layer'].split('.') if i.isdigit()][0] 527 | if transform_fn.__name__ in ['match_tensors_permute'] and ignore_heads == False: 528 | n_heads = self.graphs[0].num_heads 529 | mha_transform_fn = transform_fn.__name__ + '_MHA' 530 | merge, unmerge, attn_head_perm, cost = get_merging_fn(mha_transform_fn)(n_heads, r=reduce_ratio, 531 | permute_heads=permute_heads, print_costs=print_costs, 532 | no_absval=no_absval, correlation_matrix=correlation_matrix, 533 | **kwargs) 534 | merge = merge * len(self.graphs) 535 | self.merges[node] = merge.chunk(len(self.graphs), dim=1) 536 | self.unmerges[node] = unmerge.chunk(len(self.graphs), dim=0) 537 | if qk_flag == True: 538 | metric = self.metrics[f'qk{layer_no}'] 539 | correlation_matrix = self.cov_to_corr(metric['covariance']) 540 | qk_merge, qk_unmerge, _, cost = get_merging_fn(mha_transform_fn)(n_heads, r=reduce_ratio, 541 | permute_heads=permute_heads, head_assignments=attn_head_perm, 542 | print_costs=print_costs, no_absval=no_absval, 543 | correlation_matrix=correlation_matrix, **kwargs) 544 | qk_merge = qk_merge * len(self.graphs) 545 | self.merges[f'qk{layer_no}'] = qk_merge.chunk(len(self.graphs), dim=1) 546 | self.unmerges[f'qk{layer_no}'] = qk_unmerge.chunk(len(self.graphs), dim=0) 547 | else: 548 | # if ignoring heads or non-mha merge matrix 549 | merge, unmerge, _, cost = transform_fn(reduce_ratio, correlation_matrix=correlation_matrix, 550 | no_absval=no_absval, **kwargs) 551 | merge = merge * len(self.graphs) 552 | self.merges[node] = merge.chunk(len(self.graphs), dim=1) 553 | self.unmerges[node] = unmerge.chunk(len(self.graphs), dim=0) 554 | 555 | if qk_flag: 556 | metric = self.metrics[f'qk{layer_no}'] 557 | qk_merge, qk_unmerge, _, cost = transform_fn(reduce_ratio, print_costs=print_costs, no_absval=no_absval, 558 | correlation_matrix=correlation_matrix, **kwargs) 559 | # add qk_merges to dict here so that attn merge can get added at end of block 560 | qk_merge = qk_merge * len(self.graphs) 561 | self.merges[f'qk{layer_no}'] = qk_merge.chunk(len(self.graphs), dim=1) 562 | self.unmerges[f'qk{layer_no}'] = qk_unmerge.chunk(len(self.graphs), dim=0) 563 | 564 | # Handle FF 565 | else: 566 | # returns merge and unmerge matrixs 567 | merge, unmerge, _, cost = transform_fn(reduce_ratio, print_costs=print_costs, no_absval=no_absval, 568 | correlation_matrix=correlation_matrix,**kwargs) 569 | merge = merge * len(self.graphs) 570 | self.merges[node] = merge.chunk(len(self.graphs), dim=1) 571 | self.unmerges[node] = unmerge.chunk(len(self.graphs), dim=0) 572 | 573 | elif contains_name(prev_node_layer, qk_nodes): 574 | continue 575 | # continuing because this is already handled in attention block 576 | 577 | # handle metrics associated with residuals 
here, other special cases 578 | else: 579 | info = self.graphs[0].get_node_info(node) 580 | print('res') 581 | print(info) 582 | if res == 'sep': 583 | correlation_matrix = corrs[node] 584 | merge, unmerge, _, cost = transform_fn(reduce_ratio, correlation_matrix=correlation_matrix, 585 | no_absval=no_absval,**kwargs) 586 | merge = merge * len(self.graphs) 587 | self.merges[node] = merge.chunk(len(self.graphs), dim=1) 588 | self.unmerges[node] = unmerge.chunk(len(self.graphs), dim=0) 589 | else: 590 | # res is first, last, or all: 591 | if global_res_merge == None: 592 | correlation_matrix = corrs['res'] 593 | global_res_merge, global_res_unmerge, _, cost = transform_fn(reduce_ratio, 594 | correlation_matrix=correlation_matrix, 595 | no_absval=no_absval, **kwargs) 596 | global_res_merge = global_res_merge * len(self.graphs) 597 | self.merges[node] = global_res_merge.chunk(len(self.graphs), dim=1) 598 | self.unmerges[node] = global_res_unmerge.chunk(len(self.graphs), dim=0) 599 | else: # merge was already learned 600 | self.merges[node] = global_res_merge.chunk(len(self.graphs), dim=1) 601 | self.unmerges[node] = global_res_unmerge.chunk(len(self.graphs), dim=0) 602 | cost_dict[node] = cost 603 | if qk_flag == True: 604 | for node in nodes: 605 | prev_node_layer = self.graphs[0].get_node_info(node-1)['layer'] 606 | if prev_node_layer != None and contains_name(prev_node_layer, qk_nodes): 607 | layer_no = [int(i) for i in self.graphs[0].get_node_info(node-1)['layer'].split('.') if i.isdigit()][0] 608 | self.merges[node] = self.merges[f'qk{layer_no}'] 609 | self.unmerges[node] = self.unmerges[f'qk{layer_no}'] 610 | for i in range(self.graphs[0].num_layers): 611 | self.merges.pop(f'qk{i}') 612 | self.unmerges.pop(f'qk{i}') 613 | 614 | self.compute_transform_time = time() - start_time 615 | return self.merges, self.unmerges, cost_dict 616 | 617 | 618 | 619 | def merge_node(self, node, merger): 620 | info = merger.graph.get_node_info(node) 621 | module = merger.graph.get_module(info['layer']) 622 | module.weight.data = merger.merge @ module.weight.data 623 | if hasattr(module, 'bias') and module.bias is not None: 624 | module.bias.data = merger.merge @ module.bias.data 625 | 626 | def unmerge_node(self, node, merger): 627 | info = merger.graph.get_node_info(node) 628 | module = merger.graph.get_module(info['layer']) 629 | module.weight.data = module.weight @ merger.unmerge 630 | 631 | # adding custom transformations here, for more control 632 | def apply_transformations_custom(self, merge_cls=False): 633 | qk_flag = False 634 | if self.graphs[0].qk == True: 635 | qk_flag = True 636 | qk_nodes = [self.graphs[0].modules[name] for name in ['q', 'k']] 637 | 638 | emb_suff_0 = self.graphs[0].modules['emb'] 639 | emb_copy_0 = self.graphs[0].get_module(f'{self.graphs[0].enc_prefix}.{emb_suff_0}').weight.data 640 | emb_copy_0 = torch.clone(emb_copy_0) 641 | 642 | emb_suff_1 = self.graphs[1].modules['emb'] 643 | emb_copy_1= self.graphs[1].get_module(f'{self.graphs[1].enc_prefix}.{emb_suff_1}').weight.data 644 | emb_copy_1 = torch.clone(emb_copy_1) 645 | 646 | final_merger = None 647 | graph_device = emb_copy_0.device 648 | 649 | for node in self.merges: 650 | merges = self.merges[node] 651 | unmerges = self.unmerges[node] 652 | count = 0 653 | for merge, unmerge, graph in zip(merges, unmerges, self.graphs): 654 | merger = MergeHandler(graph, merge, unmerge, node) 655 | merger.merge = merger.merge.to(graph_device) 656 | merger.unmerge = merger.unmerge.to(graph_device) 657 | preds = merger.graph.preds(node) 
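                # For each node, the learned merge M is applied to the rows of the modules
                # that produce its features (weight <- M @ weight, bias <- M @ bias), and
                # the matching unmerge U to the columns of the modules that consume them
                # (weight <- weight @ U), preserving each model's function up to the fitted
                # permutation. The branches below locate those producer/consumer modules
                # for each node type: attention, layernorm/residual junctions, feed-forward
                # blocks, and the LM/classification heads.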
658 | info = merger.graph.get_node_info(preds[0]) 659 | # self attention merging, and self attention out unmerging 660 | if info['type'] == NodeType.SUM: 661 | print('merging MHA') 662 | # apply merges to k,q,v matrices 663 | sum_preds = merger.graph.preds(preds[0]) 664 | # check if q,k junction or v matrix 665 | for sum_pred in sum_preds: 666 | info = merger.graph.get_node_info(sum_pred) 667 | if info['type'] == NodeType.SUM: 668 | if qk_flag == False: 669 | second_sum_preds = merger.graph.preds(sum_pred) 670 | # merge q & k 671 | for second_sum_pred in second_sum_preds: 672 | self.merge_node(second_sum_pred, merger) 673 | elif 'v_proj' in info['layer'] or 'value' in info['layer']: 674 | # merge v 675 | self.merge_node(sum_pred, merger) 676 | 677 | # unmerge self-attn.out 678 | succ = merger.graph.succs(node)[0] 679 | self.unmerge_node(succ, merger) 680 | elif contains_name(info['layer'], qk_nodes) and qk_flag == True: 681 | print('merging qk') 682 | self.merge_node(preds[0], merger) 683 | 684 | elif 'self_attn_layer_norm' in info['layer'] or 'attention.output.LayerNorm' in info['layer']: 685 | print('merging self-attn res') 686 | # apply merge to ln 687 | module = merger.graph.get_module(info['layer']) 688 | parameter_names = ['weight', 'bias'] 689 | for parameter_name in parameter_names: 690 | parameter = getattr(module, parameter_name) 691 | parameter.data = merger.merge @ parameter 692 | 693 | # apply merges to the self.attn out proj 694 | sum = merger.graph.preds(preds[0])[0] 695 | out_proj = merger.graph.preds(sum)[0] 696 | self.merge_node(out_proj, merger) 697 | 698 | # unmerge the ff1 module 699 | ff1 = merger.graph.succs(node)[0] 700 | self.unmerge_node(ff1, merger) 701 | 702 | elif 'final_layer_norm' in info['layer'] or 'layernorm_embedding' in info['layer'] or 'output.LayerNorm' in info['layer'] or 'embeddings.LayerNorm' in info['layer']: 703 | print('merging final res') 704 | # apply merge to ln 705 | module = merger.graph.get_module(info['layer']) 706 | parameter_names = ['weight', 'bias'] 707 | for parameter_name in parameter_names: 708 | parameter = getattr(module, parameter_name) 709 | parameter.data = merger.merge @ parameter 710 | 711 | sum = merger.graph.preds(preds[0])[0] 712 | info = merger.graph.get_node_info(sum) 713 | if info['type'] == NodeType.SUM: 714 | ff2 = merger.graph.preds(sum)[0] 715 | self.merge_node(ff2, merger) 716 | else: 717 | # this is emb node then 718 | if final_merger == None and count == 1: 719 | final_merger = merger 720 | if merger.graph.enc_prefix == 'bert': 721 | # bert has special token type embedding that must be merged too 722 | emb_tok_suff = merger.graph.modules['emb_tok_type'] 723 | emb_tok_name = f'{merger.graph.enc_prefix}.{emb_tok_suff}' 724 | emb_tok_mod = merger.graph.get_module(emb_tok_name) 725 | emb_tok_mod.weight.data = (merger.merge @ (emb_tok_mod.weight).T).T 726 | 727 | # grabbing naming vars 728 | emb_suff = merger.graph.modules['emb'] 729 | emb_pos_suff = merger.graph.modules['emb_pos'] 730 | emb_name = f'{merger.graph.enc_prefix}.{emb_suff}' 731 | emb_pos_name = f'{merger.graph.enc_prefix}.{emb_pos_suff}' 732 | 733 | # merger emb & emb_pos 734 | emb = merger.graph.get_module(emb_name) 735 | emb_pos = merger.graph.get_module(emb_pos_name) 736 | emb.weight.data = (merger.merge @ (emb.weight).T).T 737 | emb_pos.weight.data = (merger.merge @ (emb_pos.weight).T).T 738 | 739 | # this unmerges w_k, w_q, w_v 740 | succs = merger.graph.succs(node) 741 | if len(succs) > 1: 742 | for succ in succs: 743 | info = 
merger.graph.get_node_info(succ) 744 | if info['type'] != NodeType.SUM: 745 | self.unmerge_node(succ, merger) 746 | else: 747 | # in this case, we have the second to last node 748 | # separate case for mnli & camembert due to head names 749 | # first we check if model is bert and unmerge the lm head 750 | if 'cls.predictions.transform.dense' in merger.graph.named_modules: 751 | module = merger.graph.get_module('cls.predictions.transform.dense') 752 | module.weight.data = module.weight @ merger.unmerge 753 | 754 | elif 'bert.pooler.dense' in merger.graph.named_modules: 755 | module = merger.graph.get_module('bert.pooler.dense') 756 | module.weight.data = module.weight @ merger.unmerge 757 | elif len(merger.graph.model.classification_heads.keys()) != 0: 758 | if 'classification_heads.mnli.dense' in merger.graph.named_modules: 759 | module = merger.graph.get_module('classification_heads.mnli.dense') 760 | module.weight.data = module.weight @ merger.unmerge 761 | elif 'classification_heads.sentence_classification_head.dense' in merger.graph.named_modules: 762 | module = merger.graph.get_module('classification_heads.sentence_classification_head.dense') 763 | module.weight.data = module.weight @ merger.unmerge 764 | # if has no classification heads, it uses lm heads instead, and is a roberta model 765 | # unmerge this, but in the actual eval of wsc, need to fix forward pass, but this is the minimum needed to 766 | # store the correct weights 767 | else: 768 | module = merger.graph.get_module('encoder.lm_head.dense') 769 | module.weight.data = module.weight @ merger.unmerge 770 | 771 | # apply merge to fc1 & unmerge fc2 772 | elif 'fc1' in info['layer'] or 'intermediate.dense' in info['layer']: 773 | print('merging ff') 774 | # apply merges to the fc1 layer 775 | module = merger.graph.get_module(info['layer']) 776 | self.merge_node(preds[0], merger) 777 | 778 | # apply unmerge to fc2 layer 779 | succ = merger.graph.succs(node)[0] 780 | self.unmerge_node(succ, merger) 781 | 782 | elif 'transform.LayerNorm' in info['layer'] and merge_cls: 783 | if final_merger == None and count == 1: # count ensures this is 2nd model merger being saved 784 | final_merger = merger 785 | 786 | print('merging lm head') 787 | # apply merge to layernorm 788 | module = merger.graph.get_module(info['layer']) 789 | parameter_names = ['weight', 'bias'] 790 | for parameter_name in parameter_names: 791 | parameter = getattr(module, parameter_name) 792 | parameter.data = merger.merge @ parameter 793 | 794 | # merge dense 795 | pred = merger.graph.preds(preds[0])[0] 796 | self.merge_node(pred, merger) 797 | 798 | elif 'pooler' in info['layer'] and merge_cls: 799 | print('merging class head') 800 | # merge pooler weight 801 | self.merge_node(preds[0], merger) 802 | # get cls node & unmerge 803 | succ = merger.graph.succs(node)[0] 804 | self.unmerge_node(succ, merger) 805 | count += 1 806 | 807 | return final_merger 808 | 809 | def get_merged_state_dict(self, interp_w=None, save_both=False): 810 | """ 811 | Post transformations, obtain state dictionary for merged model by linearly interpolating between 812 | transformed models in each graph. By default all parameters are averaged, but if given an interp_w 813 | weight, will be weightedly averaged instead. 814 | - interp_w (Optional): If None, all parameters of each model is averaged for merge. Otherwise, 815 | interp_w is a list of len(num_models_to_merge), with weights bearing the importance of incorporating 816 | features from each model into the merged result. 
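          For example, interp_w=[0.7, 0.3] weights the first model's parameters by 0.7
          and the second's by 0.3; the weights are applied as given, not re-normalized.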
817 |     Returns: state dict of merged model.
818 |     """
819 |     if save_both:
820 | 
821 |         # If the models are BERT, the architectures are the same, but we do not want to
822 |         # average anything after the dense layer in the MLM head, so we define an exclusion list.
823 |         if self.graphs[0].enc_prefix == 'bert':
824 |             excluded = ['cls.predictions.transform.LayerNorm.weight',
825 |                         'cls.predictions.transform.LayerNorm.bias',
826 |                         'cls.predictions.decoder.weight',
827 |                         'cls.predictions.decoder.bias',
828 |                         'bert.pooler.dense.weight',
829 |                         'bert.pooler.dense.bias',
830 |                         'classifier.weight',
831 |                         'classifier.bias']
832 |             #excluded = []
833 |         else:
834 |             excluded = []
835 |         state_dict = {}
836 |         merged_state_dict1 = self.graphs[0].model.state_dict().copy()
837 |         keys1 = list(self.graphs[0].model.state_dict().keys())
838 |         merged_state_dict2 = self.graphs[1].model.state_dict().copy()
839 |         keys2 = list(self.graphs[1].model.state_dict().keys())
840 |         for key in keys1:
841 |             param = self.graphs[0].model.state_dict()[key]
842 |             if key in keys2 and param.shape == merged_state_dict2[key].shape and key not in excluded:
843 |                 merged_state_dict1[key] = sum(graph.model.state_dict()[key] for graph in self.graphs) / len(self.graphs)
844 |             else:
845 |                 # modified models
846 |                 merged_state_dict1[key] = self.graphs[0].model.state_dict()[key]
847 | 
848 |         for key in keys2:
849 |             param = self.graphs[1].model.state_dict()[key]
850 |             if key in keys1 and param.shape == merged_state_dict1[key].shape and key not in excluded:
851 |                 merged_state_dict2[key] = sum(graph.model.state_dict()[key] for graph in self.graphs) / len(self.graphs)
852 |             else:
853 |                 merged_state_dict2[key] = self.graphs[1].model.state_dict()[key]
854 |         return [merged_state_dict1, merged_state_dict2]
855 |     else:
856 |         state_dict = {}
857 |         merged_state_dict = self.merged_model.state_dict()
858 |         keys = list(self.graphs[0].model.state_dict().keys())
859 |         try:
860 |             for key in keys:
861 |                 if key in merged_state_dict:
862 |                     param = self.graphs[0].model.state_dict()[key]
863 |                     if interp_w is not None and param.shape == merged_state_dict[key].shape:
864 |                         new_value = sum(graph.model.state_dict()[key] * w for graph, w in zip(self.graphs, interp_w))
865 |                     else:
866 |                         new_value = sum(graph.model.state_dict()[key] for graph in self.graphs) / len(self.graphs)
867 |                     state_dict[key] = new_value
868 |         except RuntimeError as e:
869 |             # Only swallow runtime errors about tensor sizes; we need to be able to add models with different heads together.
870 |             if 'size' not in str(e):
871 |                 raise e
872 |         return state_dict
873 | 
874 | 
875 | 
876 |     def clear_hooks(self):
877 |         """ Clears all hooks from graphs. """
878 |         for g in self.graphs:
879 |             g.clear_hooks()
880 |         for hook in self.hooks:
881 |             hook.remove()
882 |         self.hooks = []
883 | 
884 | 
885 |     def transform(self, model,
886 |                   dataloader,
887 |                   sentence_level=None,
888 |                   special_toks=False,
889 |                   transform_fn=match_tensors_permute,
890 |                   metric_classes=(CovarianceMetric, MeanMetric),
891 |                   save_both=False,
892 |                   permute_heads=False,
893 |                   ignore_heads=False,
894 |                   no_absval=False,
895 |                   merge_cls=False,
896 |                   saved_features=None,
897 |                   res_type='none',
898 |                   **transform_kwargs
899 |                   ):
900 |         """ Note: this consumes the models given to the graphs. Do not modify the models you give this.
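        A rough usage sketch (assumes graphs built with the bert() helper from
        graphs/transformer_enc_graph.py; `model_a`, `model_b`, and `base` are
        placeholder models, not names defined in this repo):

            graphs = [bert(model_a).graphify(), bert(model_b).graphify()]
            merger = ModelMerge(*graphs, device=0)
            unmerge, costs = merger.transform(base, dataloader,
                                              transform_fn=match_tensors_permute,
                                              res_type='first')
            merged = merger.merged_model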
""" 901 | if save_both: 902 | self.merged_model1 = deepcopy(self.graphs[0].model).to(self.device) 903 | self.merged_model2 = deepcopy(self.graphs[1].model).to(self.device) 904 | else: 905 | self.merged_model = model.to(self.device).eval() # same arch as graph models 906 | 907 | if not isinstance(metric_classes, dict): 908 | metric_classes = { x.name: x for x in metric_classes } 909 | 910 | self.metric_classes = metric_classes 911 | self.transform_fn = transform_fn 912 | 913 | # if we did not pre-save features, compute them here: 914 | if saved_features == None: 915 | _, vars = self.compute_metrics(dataloader, 916 | metric_classes=metric_classes, 917 | sentence_level=sentence_level, 918 | special_toks=special_toks) 919 | 920 | _, _, cost_dict = self.compute_transformations(transform_fn, reduce_ratio=1 - 1. / len(self.graphs), 921 | permute_heads=permute_heads, 922 | ignore_heads=ignore_heads, 923 | no_absval=no_absval, 924 | saved_features=saved_features, 925 | res=res_type, 926 | **transform_kwargs 927 | ) 928 | 929 | final_merger = self.apply_transformations_custom(merge_cls=merge_cls) 930 | 931 | if save_both: 932 | merged_dicts = self.get_merged_state_dict(save_both=True) 933 | self.merged_model1.load_state_dict(merged_dicts[0]) 934 | self.merged_model2.load_state_dict(merged_dicts[1]) 935 | else: 936 | self.merged_model.load_state_dict(self.get_merged_state_dict(save_both=False), strict=False) 937 | self.add_hooks() 938 | 939 | if final_merger == None: 940 | unmerge = None 941 | else: 942 | unmerge = final_merger.unmerge 943 | 944 | return unmerge, cost_dict 945 | 946 | def add_hooks(self): 947 | """ Add hooks at zip start or stop at locations for merged model and base models. """ 948 | # Remove the hooks from the models to add or own 949 | self.clear_hooks() 950 | 951 | 952 | -------------------------------------------------------------------------------- /my_datasets/books.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.data import Dataset 3 | from torch.nn.utils.rnn import pad_sequence 4 | from transformers import BertTokenizer 5 | from functools import partial 6 | 7 | 8 | CONTEXT_LEN=512 9 | class Books(Dataset): 10 | def __init__(self, 11 | root=None, 12 | train=True, 13 | sort=False, 14 | bpe='bert', 15 | num=0): # do bpe 16 | 17 | if root: 18 | self.root = root 19 | else: 20 | self.root = f'my_datasets/sample{num}.txt' 21 | self.name = 'books' 22 | self.train = train 23 | self.sort = sort 24 | self.bpe = bpe 25 | 26 | split = self.load_split() # loads split & labels 27 | 28 | self.dataset = split 29 | if bpe == 'bert': 30 | self.bert_path = 'google/multiberts-seed_0' 31 | self.bert_encoder = BertTokenizer.from_pretrained(self.bert_path) 32 | 33 | 34 | def __len__(self): 35 | return len(self.dataset) 36 | 37 | def __getitem__(self, index): 38 | """ 39 | Args: 40 | index (int): Index 41 | Returns: 42 | tuple: (encoded, len) 43 | """ 44 | 45 | sentence = self.dataset[index] 46 | if self.bpe == 'bert': 47 | tokens = self.bert_encoder(sentence, return_tensors='pt')['input_ids'][0] 48 | # automatically truncating 49 | if len(tokens) > CONTEXT_LEN: 50 | return tokens[:CONTEXT_LEN], CONTEXT_LEN 51 | 52 | return tokens, len(tokens) # sequence of tokens, None label as placeholder 53 | 54 | def load_split(self): 55 | 56 | if self.train: 57 | # skip first header line 58 | train_lines = open(self.root).readlines() 59 | cleaned_lines = [line.strip() for line in train_lines] 60 | 61 | if self.sort: 62 | 
62 |                 cleaned_lines_sorted = sorted(cleaned_lines, key=len)
63 |                 return cleaned_lines_sorted
64 |             return cleaned_lines
65 | 
66 | def pad_collate(batch, pad_tok):
67 |     (xx, lens) = zip(*batch)
68 |     xx_pad = pad_sequence(xx, batch_first=True, padding_value=pad_tok)
69 |     return xx_pad, lens
70 | 
71 | def prepare_train_loaders(config):
72 |     if config['tokenizer'] == 'bert':
73 |         pad_tok = 0
74 |     elif config['tokenizer'] == 'roberta':
75 |         pad_tok = 1
76 |     return {
77 |         'full': torch.utils.data.DataLoader(
78 |             Books(train=True,
79 |                   num=config['num'],
80 |                   sort=config['sorted'],
81 |                   bpe=config['tokenizer']),
82 |             batch_size=config['batch_size'],
83 |             shuffle=True, collate_fn=partial(pad_collate, pad_tok=pad_tok)
84 |         )
85 |     }
86 | 
87 | def prepare_test_loaders(config):
88 |     if config['tokenizer'] == 'bert':
89 |         pad_tok = 0
90 |     elif config['tokenizer'] == 'roberta':
91 |         pad_tok = 1
92 |     return {
93 |         'full': torch.utils.data.DataLoader(
94 |             Books(train=False, num=0),  # num selects the sample file and belongs to Books, not the DataLoader
95 |             batch_size=config['batch_size'],
96 |             shuffle=True,
97 |             num_workers=config['num_workers'], collate_fn=partial(pad_collate, pad_tok=pad_tok)
98 |         )
99 |     }
100 | 
101 | if __name__ == "__main__":
102 |     config = {'batch_size': 4, 'num_workers': 4, 'shuffle_train': True, 'sorted': False, 'tokenizer': 'bert', 'num': 0}
103 |     train_loader = prepare_train_loaders(config)
104 |     x = next(iter(train_loader['full']))
105 |     print(x)
--------------------------------------------------------------------------------
/my_datasets/configs.py:
--------------------------------------------------------------------------------
1 | from .books import Books
2 | from .glue import Glue
3 | 
4 | books = {
5 |     'wrapper': Books,
6 |     'batch_size': 8,
7 |     'type': 'books',
8 |     'num_workers': 8,
9 |     'shuffle_train': False,
10 |     'shuffle_test': False,
11 |     'sorted': False,
12 | }
13 | 
14 | 
15 | glue = {
16 |     'wrapper': Glue,
17 |     'batch_size': 8,
18 |     'type': 'glue',
19 |     'num_workers': 8,
20 |     'shuffle_train': False,
21 |     'shuffle_test': False,
22 |     'task': 'mnli'
23 | }
24 | 
--------------------------------------------------------------------------------
/my_datasets/glue.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.utils.data import Dataset
3 | from torch.nn.utils.rnn import pad_sequence
4 | from datasets import load_dataset
5 | from transformers import BertTokenizer
6 | from functools import partial
7 | 
8 | task_to_keys = {
9 |     "cola": ("sentence", None),
10 |     "mnli": ("premise", "hypothesis"),
11 |     "mrpc": ("sentence1", "sentence2"),
12 |     "qnli": ("question", "sentence"),
13 |     "qqp": ("question1", "question2"),
14 |     "rte": ("sentence1", "sentence2"),
15 |     "sst2": ("sentence", None),
16 |     "stsb": ("sentence1", "sentence2"),
17 |     "wnli": ("sentence1", "sentence2"),
18 | }
19 | 
20 | CONTEXT_LEN = 512
21 | class Glue(Dataset):
22 |     def __init__(self,
23 |                  task,
24 |                  train=True,
25 |                  num=0):  # num is unused here; kept for interface parity with Books
26 | 
27 |         self.name = 'glue'
28 |         self.train = train
29 |         self.task = task
30 | 
31 |         split = self.load_split()  # loads split & labels
32 |         # encoder only used for tokenization
33 |         self.encoder = BertTokenizer.from_pretrained('google/multiberts-seed_0')
34 |         self.dataset = split
35 | 
36 | 
37 |     def __len__(self):
38 |         return len(self.dataset)
39 | 
40 |     def __getitem__(self, index):
41 |         """
42 |         Args:
43 |             index (int): Index
44 |         Returns:
45 |             tuple: (encoded, len)
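            For pair tasks (e.g. mnli), the dataset item is a (text_a, text_b) tuple and
            both texts are encoded as a single sequence with separator tokens.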
46 | """ 47 | 48 | sentence = self.dataset[index] 49 | if self.task in ['cola', 'sst2']: 50 | tokens = self.encoder(sentence, return_tensors='pt')['input_ids'][0] 51 | elif self.task in ['mnli', 'mrpc', 'qnli', 'qqp', 'rte', 'stsb']: 52 | tokens = self.encoder(*sentence, return_tensors='pt')['input_ids'][0] 53 | # automatically truncating 54 | if len(tokens) > CONTEXT_LEN: 55 | return tokens[:CONTEXT_LEN], CONTEXT_LEN 56 | 57 | return tokens, len(tokens) # sequence of tokens, None label as placeholder 58 | 59 | def load_split(self): 60 | 61 | if self.train: 62 | dataset = load_dataset('glue', self.task)['train'] 63 | if self.task in ['cola', 'sst2']: 64 | # skip first header line 65 | train_lines = [dataset[i]['sentence'] for i in range(len(dataset))] 66 | else: 67 | name0 = task_to_keys[self.task][0] 68 | name1 = task_to_keys[self.task][1] 69 | train_lines = [(dataset[i][name0], dataset[i][name1]) 70 | for i in range(len(dataset))] 71 | return train_lines 72 | 73 | def pad_collate(batch, pad_tok): 74 | (xx, lens) = zip(*batch) 75 | xx_pad = pad_sequence(xx, batch_first=True, padding_value=pad_tok) 76 | return xx_pad, lens 77 | 78 | def prepare_train_loaders(config): 79 | pad_tok = 0 80 | return { 81 | 'full': torch.utils.data.DataLoader( 82 | Glue(train=True, 83 | task=config['task']), 84 | batch_size=config['batch_size'], 85 | shuffle=config['shuffle_train'], 86 | collate_fn=partial(pad_collate, pad_tok=pad_tok) 87 | ) 88 | } 89 | 90 | def prepare_test_loaders(config): 91 | pad_tok = 0 92 | return { 93 | 'full': torch.utils.data.DataLoader( 94 | Glue(train=False, 95 | task=config['task']), 96 | batch_size=config['batch_size'], 97 | shuffle=True, 98 | collate_fn=partial(pad_collate, pad_tok=pad_tok) 99 | ) 100 | } 101 | 102 | 103 | if __name__ == "__main__": 104 | for task in ['cola', 'mrpc', 'qnli', 'qqp', 'rte', 'sst2', 'stsb']: 105 | config = {'batch_size': 4, 'num_workers': 4, 'shuffle_train': True, 'task': task} 106 | train_loader = prepare_train_loaders(config) 107 | x = next(iter(train_loader['full'])) 108 | print(x) -------------------------------------------------------------------------------- /my_datasets/sample_books_corpus.py: -------------------------------------------------------------------------------- 1 | import random 2 | from glob import glob 3 | from tqdm import tqdm 4 | import sys 5 | 6 | # pass path to bookscorpus 7 | BOOKSROOT = sys.argv[1] 8 | OUTPUT = 'sample0.txt' 9 | 10 | # sample N_LINES_PER_FILE lines from each book 11 | N_LINES_PER_FILE = 100 12 | 13 | 14 | with open(OUTPUT, 'w+') as f_out: 15 | # books path should have structure BOOKSROOT/{genre}/{book}.txt 16 | for filename in tqdm(glob(BOOKSROOT + '*/*.txt')): 17 | with open(filename, 'r', encoding='utf-8',errors='ignore') as f: 18 | all_lines = list(filter(lambda x: x != '\n',f.readlines())) 19 | n_total_lines = len(all_lines) 20 | if n_total_lines >= N_LINES_PER_FILE: 21 | lines = random.sample(all_lines, N_LINES_PER_FILE) 22 | f_out.write(''.join(lines)) 23 | -------------------------------------------------------------------------------- /overview.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nverma1/merging-text-transformers/0a03fa0336f35790481b843896f8ec32b860d069/overview.png -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | tqdm 2 | transformers 3 | datasets 4 | torch 5 | numpy 6 | networkx 7 
| matplotlib 8 | scipy 9 | pytorch-nlp -------------------------------------------------------------------------------- /training/finetune_glue.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | index=$1 4 | 5 | task=$2 6 | 7 | if [[ $task == "mrpc" ]]; then 8 | epoch=5 9 | else 10 | epoch=3 11 | fi 12 | 13 | python3 run_glue.py --model_name_or_path google/multiberts-seed_${index} \ 14 | --task_name $task \ 15 | --do_train \ 16 | --do_eval \ 17 | --learning_rate 2e-5 \ 18 | --max_seq_length 128 \ 19 | --num_train_epochs $epoch \ 20 | --output_dir models/trained/multiberts/$task/seed_${index} \ 21 | --load_best_model_at_end \ 22 | --save_total_limit 1 \ 23 | --save_strategy "no" \ 24 | --fp16 > models/trained/multiberts/log/$task/${task}_${index}.txt 2>&1 25 | -------------------------------------------------------------------------------- /training/run_glue.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding=utf-8 3 | # Copyright 2020 The HuggingFace Inc. team. All rights reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | """ Finetuning the library models for sequence classification on GLUE.""" 17 | # You can also adapt this script on your own text classification task. Pointers for this are left as comments. 18 | 19 | # This file is adapted from https://github.com/twinkle0331/BERT-similarity 20 | 21 | import logging 22 | import os 23 | import random 24 | import sys 25 | from dataclasses import dataclass, field 26 | from typing import Optional 27 | import torch 28 | 29 | import datasets 30 | import numpy as np 31 | from datasets import load_dataset, load_metric 32 | 33 | import transformers 34 | from transformers import ( 35 | BertConfig, 36 | BertForSequenceClassification, 37 | BertForMaskedLM, 38 | BertTokenizer, 39 | DataCollatorWithPadding, 40 | EvalPrediction, 41 | HfArgumentParser, 42 | PretrainedConfig, 43 | Trainer, 44 | TrainingArguments, 45 | default_data_collator, 46 | set_seed, 47 | ) 48 | from copy import deepcopy 49 | from transformers.trainer_utils import get_last_checkpoint 50 | from transformers.utils import check_min_version 51 | from transformers.utils.versions import require_version 52 | 53 | 54 | # Will error if the minimal version of Transformers is not installed. Remove at your own risks. 55 | check_min_version("4.17.0") 56 | 57 | require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/text-classification/requirements.txt") 58 | 59 | task_to_keys = { 60 | "cola": ("sentence", None), 61 | "mnli": ("premise", "hypothesis"), 62 | "mrpc": ("sentence1", "sentence2"), 63 | "qnli": ("question", "sentence"), 64 | "qqp": ("question1", "question2"), 65 | "rte": ("sentence1", "sentence2"), 66 | "sst2": ("sentence", None), 67 | "stsb": ("sentence1", "sentence2"), 68 | "wnli": ("sentence1", "sentence2"), 69 | } 70 | 71 | logger = logging.getLogger(__name__) 72 | 73 | # added. 
merges two state dictionaries with interpolation weight w
74 | def get_merged_state_dict(state_dict_1, state_dict_2, w=0.5, save_both=False, unmerge=None):
75 |     """
76 |     Obtain the state dictionary for a merged model by linearly interpolating between two
77 |     transformed models' state dicts. Parameters of matching shape are combined as
78 |     w * state_dict_1 + (1 - w) * state_dict_2; keys whose shapes differ (e.g. task-specific
79 |     heads) are skipped.
80 |     - w (Optional): interpolation weight on state_dict_1; defaults to 0.5, i.e. a plain average.
81 |     (save_both and unmerge are accepted for interface compatibility but unused here.)
82 |     Returns: state dict of merged model.
83 |     """
84 |     state_dict = {}
85 |     merged_state_dict = deepcopy(state_dict_1)
86 |     keys = list(state_dict_1.keys())
87 |     try:
88 |         for key in keys:
89 |             if key in merged_state_dict:
90 |                 param = state_dict_1[key]
91 |                 # compare against the second model so mismatched (e.g. head) keys are skipped
92 |                 if key in state_dict_2 and param.shape == state_dict_2[key].shape:
93 |                     state_dict[key] = state_dict_1[key] * w + state_dict_2[key] * (1 - w)
94 |     except RuntimeError as e:
95 |         # Only catch runtime errors about tensor sizes; we need to be able to add models with diff heads together
96 |         if 'size' not in str(e):
97 |             raise e
98 |     return state_dict
99 | 
100 | @dataclass
101 | class DataTrainingArguments:
102 |     """
103 |     Arguments pertaining to what data we are going to input our model for training and eval.
104 | 
105 |     Using `HfArgumentParser` we can turn this class
106 |     into argparse arguments to be able to specify them on
107 |     the command line.
108 |     """
109 | 
110 |     task_name: Optional[str] = field(
111 |         default=None,
112 |         metadata={"help": "The name of the task to train on: " + ", ".join(task_to_keys.keys())},
113 |     )
114 |     dataset_name: Optional[str] = field(
115 |         default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."}
116 |     )
117 |     dataset_config_name: Optional[str] = field(
118 |         default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
119 |     )
120 |     max_seq_length: int = field(
121 |         default=128,
122 |         metadata={
123 |             "help": "The maximum total input sequence length after tokenization. Sequences longer "
124 |             "than this will be truncated, sequences shorter will be padded."
125 |         },
126 |     )
127 |     overwrite_cache: bool = field(
128 |         default=False, metadata={"help": "Overwrite the cached preprocessed datasets or not."}
129 |     )
130 |     pad_to_max_length: bool = field(
131 |         default=True,
132 |         metadata={
133 |             "help": "Whether to pad all samples to `max_seq_length`. "
134 |             "If False, will pad the samples dynamically when batching to the maximum length in the batch."
135 |         },
136 |     )
137 |     max_train_samples: Optional[int] = field(
138 |         default=None,
139 |         metadata={
140 |             "help": "For debugging purposes or quicker training, truncate the number of training examples to this "
141 |             "value if set."
142 |         },
143 |     )
144 |     max_eval_samples: Optional[int] = field(
145 |         default=None,
146 |         metadata={
147 |             "help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
148 |             "value if set."
149 |         },
150 |     )
151 |     max_predict_samples: Optional[int] = field(
152 |         default=None,
153 |         metadata={
154 |             "help": "For debugging purposes or quicker training, truncate the number of prediction examples to this "
155 |             "value if set."
156 | }, 157 | ) 158 | train_file: Optional[str] = field( 159 | default=None, metadata={"help": "A csv or a json file containing the training data."} 160 | ) 161 | validation_file: Optional[str] = field( 162 | default=None, metadata={"help": "A csv or a json file containing the validation data."} 163 | ) 164 | test_file: Optional[str] = field(default=None, metadata={"help": "A csv or a json file containing the test data."}) 165 | 166 | def __post_init__(self): 167 | if self.task_name is not None: 168 | self.task_name = self.task_name.lower() 169 | if self.task_name not in task_to_keys.keys(): 170 | raise ValueError("Unknown task, you should pick one in " + ",".join(task_to_keys.keys())) 171 | elif self.dataset_name is not None: 172 | pass 173 | elif self.train_file is None or self.validation_file is None: 174 | raise ValueError("Need either a GLUE task, a training/validation file or a dataset name.") 175 | else: 176 | train_extension = self.train_file.split(".")[-1] 177 | assert train_extension in ["csv", "json"], "`train_file` should be a csv or a json file." 178 | validation_extension = self.validation_file.split(".")[-1] 179 | assert ( 180 | validation_extension == train_extension 181 | ), "`validation_file` should have the same extension (csv or json) as `train_file`." 182 | 183 | 184 | @dataclass 185 | class ModelArguments: 186 | """ 187 | Arguments pertaining to which model/config/tokenizer we are going to fine-tune from. 188 | """ 189 | 190 | model_name_or_path: Optional[str] = field( 191 | metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"} 192 | ) 193 | not_state_dict: bool = field( 194 | default=False, metadata={"help": "is path statedict or not" 195 | } 196 | ) 197 | merged: bool = field( 198 | default=False, metadata={"help": "finetune from a merged model" 199 | } 200 | ) 201 | model1_path: Optional[str] = field( 202 | default=None, metadata={"help": "statedict path 1, if needing to ft from merged model" 203 | } 204 | ) 205 | model2_path: Optional[str] = field( 206 | default=None, metadata={"help": "statedict path 2, if needing to ft from merged model" 207 | } 208 | ) 209 | config_name: Optional[str] = field( 210 | default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}, 211 | ) 212 | tokenizer_name: Optional[str] = field( 213 | default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}, 214 | ) 215 | cache_dir: Optional[str] = field( 216 | default=None, 217 | metadata={"help": "Where do you want to store the pretrained models downloaded from huggingface.co"}, 218 | ) 219 | use_fast_tokenizer: bool = field( 220 | default=True, 221 | metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."}, 222 | ) 223 | head_only: bool = field( 224 | default=False, 225 | metadata={"help": "finetune only classification head params" 226 | }, 227 | ) 228 | model_revision: str = field( 229 | default="main", 230 | metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."}, 231 | ) 232 | use_auth_token: bool = field( 233 | default=False, 234 | metadata={ 235 | "help": "Will use the token generated when running `transformers-cli login` (necessary to use this script " 236 | "with private models)." 237 | }, 238 | ) 239 | 240 | 241 | 242 | def main(): 243 | # See all possible arguments in src/transformers/training_args.py 244 | # or by passing the --help flag to this script. 
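    # Illustrative invocation for fine-tuning from a merged model (the paths and file
    # names below are placeholders, not outputs of any script in this repo):
    #
    #   python training/run_glue.py \
    #       --model_name_or_path google/multiberts-seed_1 \
    #       --merged --model1_path merged/state_dict1.pt --model2_path merged/state_dict2.pt \
    #       --task_name rte --do_train --do_eval --output_dir out/rte_merged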
245 | # We now keep distinct sets of args, for a cleaner separation of concerns. 246 | 247 | parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments)) 248 | 249 | if len(sys.argv) == 2 and sys.argv[1].endswith(".json"): 250 | # If we pass only one argument to the script and it's the path to a json file, 251 | # let's parse it to get our arguments. 252 | model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1])) 253 | else: 254 | model_args, data_args, training_args = parser.parse_args_into_dataclasses() 255 | 256 | # Setup logging 257 | logging.basicConfig( 258 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", 259 | datefmt="%m/%d/%Y %H:%M:%S", 260 | handlers=[logging.StreamHandler(sys.stdout)], 261 | ) 262 | 263 | log_level = training_args.get_process_log_level() 264 | logger.setLevel(log_level) 265 | datasets.utils.logging.set_verbosity(log_level) 266 | transformers.utils.logging.set_verbosity(log_level) 267 | transformers.utils.logging.enable_default_handler() 268 | transformers.utils.logging.enable_explicit_format() 269 | 270 | # Log on each process the small summary: 271 | logger.warning( 272 | f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}" 273 | + f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}" 274 | ) 275 | logger.info(f"Training/evaluation parameters {training_args}") 276 | 277 | # Detecting last checkpoint. 278 | last_checkpoint = None 279 | if os.path.isdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir: 280 | last_checkpoint = get_last_checkpoint(training_args.output_dir) 281 | if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0: 282 | raise ValueError( 283 | f"Output directory ({training_args.output_dir}) already exists and is not empty. " 284 | "Use --overwrite_output_dir to overcome." 285 | ) 286 | elif last_checkpoint is not None and training_args.resume_from_checkpoint is None: 287 | logger.info( 288 | f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change " 289 | "the `--output_dir` or add `--overwrite_output_dir` to train from scratch." 290 | ) 291 | 292 | # Set seed before initializing model. 293 | set_seed(training_args.seed) 294 | 295 | # Get the datasets: you can either provide your own CSV/JSON training and evaluation files (see below) 296 | # or specify a GLUE benchmark task (the dataset will be downloaded automatically from the datasets Hub). 297 | # 298 | # For CSV/JSON files, this script will use as labels the column called 'label' and as pair of sentences the 299 | # sentences in columns called 'sentence1' and 'sentence2' if such column exists or the first two columns not named 300 | # label if at least two columns are provided. 301 | # 302 | # If the CSVs/JSONs contain only one non-label column, the script does single sentence classification on this 303 | # single column. You can easily tweak this behavior (see below) 304 | # 305 | # In distributed training, the load_dataset function guarantee that only one local process can concurrently 306 | # download the dataset. 307 | if data_args.task_name is not None: 308 | # Downloading and loading a dataset from the hub. 
309 | raw_datasets = load_dataset("glue", data_args.task_name, cache_dir=model_args.cache_dir) 310 | elif data_args.dataset_name is not None: 311 | # Downloading and loading a dataset from the hub. 312 | raw_datasets = load_dataset( 313 | data_args.dataset_name, data_args.dataset_config_name, cache_dir=model_args.cache_dir 314 | ) 315 | else: 316 | # Loading a dataset from your local files. 317 | # CSV/JSON training and evaluation files are needed. 318 | data_files = {"train": data_args.train_file, "validation": data_args.validation_file} 319 | 320 | # Get the test dataset: you can provide your own CSV/JSON test file (see below) 321 | # when you use `do_predict` without specifying a GLUE benchmark task. 322 | if training_args.do_predict: 323 | if data_args.test_file is not None: 324 | train_extension = data_args.train_file.split(".")[-1] 325 | test_extension = data_args.test_file.split(".")[-1] 326 | assert ( 327 | test_extension == train_extension 328 | ), "`test_file` should have the same extension (csv or json) as `train_file`." 329 | data_files["test"] = data_args.test_file 330 | else: 331 | raise ValueError("Need either a GLUE task or a test file for `do_predict`.") 332 | 333 | for key in data_files.keys(): 334 | logger.info(f"load a local file for {key}: {data_files[key]}") 335 | 336 | if data_args.train_file.endswith(".csv"): 337 | # Loading a dataset from local csv files 338 | raw_datasets = load_dataset("csv", data_files=data_files, cache_dir=model_args.cache_dir) 339 | else: 340 | # Loading a dataset from local json files 341 | raw_datasets = load_dataset("json", data_files=data_files, cache_dir=model_args.cache_dir) 342 | # See more about loading any type of standard or custom dataset at 343 | # https://huggingface.co/docs/datasets/loading_datasets.html. 344 | 345 | # Labels 346 | if data_args.task_name is not None: 347 | is_regression = data_args.task_name == "stsb" 348 | if not is_regression: 349 | label_list = raw_datasets["train"].features["label"].names 350 | num_labels = len(label_list) 351 | else: 352 | num_labels = 1 353 | else: 354 | # Trying to have good defaults here, don't hesitate to tweak to your needs. 355 | is_regression = raw_datasets["train"].features["label"].dtype in ["float32", "float64"] 356 | if is_regression: 357 | num_labels = 1 358 | else: 359 | # A useful fast method: 360 | # https://huggingface.co/docs/datasets/package_reference/main_classes.html#datasets.Dataset.unique 361 | label_list = raw_datasets["train"].unique("label") 362 | label_list.sort() # Let's sort it for determinism 363 | num_labels = len(label_list) 364 | 365 | # Load pretrained model and tokenizer 366 | # 367 | # In distributed training, the .from_pretrained methods guarantee that only one local process can concurrently 368 | # download model & vocab. 
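    # The model-loading code below supports three modes:
    #   (1) --not_state_dict: torch.load a fully pickled model object from model_name_or_path,
    #   (2) --merged: average two saved state dicts (see get_merged_state_dict above) into a
    #       fresh MultiBERTs encoder, then graft that encoder under a new classification head,
    #   (3) default: load a pretrained BertForSequenceClassification from model_name_or_path.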
369 | config = BertConfig.from_pretrained( 370 | model_args.config_name if model_args.config_name else model_args.model_name_or_path, 371 | num_labels=num_labels, 372 | finetuning_task=data_args.task_name, 373 | cache_dir=model_args.cache_dir, 374 | revision=model_args.model_revision, 375 | use_auth_token=True if model_args.use_auth_token else None, 376 | ) 377 | tokenizer = BertTokenizer.from_pretrained( 378 | model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path, 379 | cache_dir=model_args.cache_dir, 380 | use_fast=model_args.use_fast_tokenizer, 381 | revision=model_args.model_revision, 382 | use_auth_token=True if model_args.use_auth_token else None, 383 | ) 384 | if model_args.not_state_dict: 385 | model = torch.load(model_args.model_name_or_path) 386 | elif model_args.merged: 387 | model1 = torch.load(model_args.model1_path) 388 | model2 = torch.load(model_args.model2_path) 389 | model_placeholder = BertForMaskedLM.from_pretrained('google/multiberts-seed_0') 390 | model_placeholder.load_state_dict(get_merged_state_dict(model1,model2)) 391 | model = BertForSequenceClassification.from_pretrained( 392 | 'google/multiberts-seed_0', 393 | from_tf=bool(".ckpt" in model_args.model_name_or_path), 394 | config=config, 395 | cache_dir=model_args.cache_dir, 396 | revision=model_args.model_revision, 397 | use_auth_token=True if model_args.use_auth_token else None, 398 | ) 399 | model.bert.embeddings = model_placeholder.bert.embeddings 400 | model.bert.encoder = model_placeholder.bert.encoder 401 | else: 402 | model = BertForSequenceClassification.from_pretrained( 403 | model_args.model_name_or_path, 404 | from_tf=bool(".ckpt" in model_args.model_name_or_path), 405 | config=config, 406 | cache_dir=model_args.cache_dir, 407 | revision=model_args.model_revision, 408 | use_auth_token=True if model_args.use_auth_token else None, 409 | ) 410 | # Preprocessing the raw_datasets 411 | if data_args.task_name is not None: 412 | sentence1_key, sentence2_key = task_to_keys[data_args.task_name] 413 | else: 414 | # Again, we try to have some nice defaults but don't hesitate to tweak to your use case. 415 | non_label_column_names = [name for name in raw_datasets["train"].column_names if name != "label"] 416 | if "sentence1" in non_label_column_names and "sentence2" in non_label_column_names: 417 | sentence1_key, sentence2_key = "sentence1", "sentence2" 418 | else: 419 | if len(non_label_column_names) >= 2: 420 | sentence1_key, sentence2_key = non_label_column_names[:2] 421 | else: 422 | sentence1_key, sentence2_key = non_label_column_names[0], None 423 | 424 | # Padding strategy 425 | if data_args.pad_to_max_length: 426 | padding = "max_length" 427 | else: 428 | # We will pad later, dynamically at batch creation, to the max sequence length in each batch 429 | padding = False 430 | 431 | # Some models have set the order of the labels to use, so let's make sure we do use it. 432 | label_to_id = None 433 | if ( 434 | model.config.label2id != PretrainedConfig(num_labels=num_labels).label2id 435 | and data_args.task_name is not None 436 | and not is_regression 437 | ): 438 | # Some have all caps in their config, some don't. 
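    # e.g. a checkpoint whose config stores label2id as {'ENTAILMENT': 0, ...} (illustrative
    # values) is lower-cased below so it can be matched against the dataset's ['entailment', ...]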
439 |         label_name_to_id = {k.lower(): v for k, v in model.config.label2id.items()}
440 |         if list(sorted(label_name_to_id.keys())) == list(sorted(label_list)):
441 |             label_to_id = {i: int(label_name_to_id[label_list[i]]) for i in range(num_labels)}
442 |         else:
443 |             logger.warning(
444 |                 "Your model seems to have been trained with labels, but they don't match the dataset: "
445 |                 f"model labels: {list(sorted(label_name_to_id.keys()))}, dataset labels: {list(sorted(label_list))}."
446 |                 "\nIgnoring the model labels as a result.",
447 |             )
448 |     elif data_args.task_name is None and not is_regression:
449 |         label_to_id = {v: i for i, v in enumerate(label_list)}
450 | 
451 |     if label_to_id is not None:
452 |         model.config.label2id = label_to_id
453 |         model.config.id2label = {id: label for label, id in config.label2id.items()}
454 |     elif data_args.task_name is not None and not is_regression:
455 |         model.config.label2id = {l: i for i, l in enumerate(label_list)}
456 |         model.config.id2label = {id: label for label, id in config.label2id.items()}
457 | 
458 |     if data_args.max_seq_length > tokenizer.model_max_length:
459 |         logger.warning(
460 |             f"The max_seq_length passed ({data_args.max_seq_length}) is larger than the maximum length for the "
461 |             f"model ({tokenizer.model_max_length}). Using max_seq_length={tokenizer.model_max_length}."
462 |         )
463 |     max_seq_length = min(data_args.max_seq_length, tokenizer.model_max_length)
464 | 
465 |     def preprocess_function(examples):
466 |         # Tokenize the texts
467 |         args = (
468 |             (examples[sentence1_key],) if sentence2_key is None else (examples[sentence1_key], examples[sentence2_key])
469 |         )
470 |         result = tokenizer(*args, padding=padding, max_length=max_seq_length, truncation=True)
471 | 
472 |         # Map labels to IDs (not necessary for GLUE tasks)
473 |         if label_to_id is not None and "label" in examples:
474 |             result["label"] = [(label_to_id[l] if l != -1 else -1) for l in examples["label"]]
475 |         return result
476 | 
477 |     with training_args.main_process_first(desc="dataset map pre-processing"):
478 |         raw_datasets = raw_datasets.map(
479 |             preprocess_function,
480 |             batched=True,
481 |             load_from_cache_file=not data_args.overwrite_cache,
482 |             desc="Running tokenizer on dataset",
483 |         )
484 |     if training_args.do_train:
485 |         if "train" not in raw_datasets:
486 |             raise ValueError("--do_train requires a train dataset")
487 |         train_dataset = raw_datasets["train"]
488 |         if data_args.max_train_samples is not None:
489 |             train_dataset = train_dataset.select(range(data_args.max_train_samples))
490 | 
491 |     if training_args.do_eval:
492 |         if "validation" not in raw_datasets and "validation_matched" not in raw_datasets:
493 |             raise ValueError("--do_eval requires a validation dataset")
494 |         eval_dataset = raw_datasets["validation_matched" if data_args.task_name == "mnli" else "validation"]
495 |         if data_args.max_eval_samples is not None:
496 |             eval_dataset = eval_dataset.select(range(data_args.max_eval_samples))
497 | 
498 |     if training_args.do_predict or data_args.task_name is not None or data_args.test_file is not None:
499 |         if "test" not in raw_datasets and "test_matched" not in raw_datasets:
500 |             raise ValueError("--do_predict requires a test dataset")
501 |         predict_dataset = raw_datasets["test_matched" if data_args.task_name == "mnli" else "test"]
502 |         if data_args.max_predict_samples is not None:
503 |             predict_dataset = predict_dataset.select(range(data_args.max_predict_samples))
504 | 
505 |     # Log a few random samples from the training set:
506 |     if
training_args.do_train: 507 | for index in random.sample(range(len(train_dataset)), 3): 508 | logger.info(f"Sample {index} of the training set: {train_dataset[index]}.") 509 | 510 | # Get the metric function 511 | if data_args.task_name is not None: 512 | metric = load_metric("glue", data_args.task_name) 513 | else: 514 | metric = load_metric("accuracy") 515 | 516 | # You can define your custom compute_metrics function. It takes an `EvalPrediction` object (a namedtuple with a 517 | # predictions and label_ids field) and has to return a dictionary string to float. 518 | def compute_metrics(p: EvalPrediction): 519 | preds = p.predictions[0] if isinstance(p.predictions, tuple) else p.predictions 520 | preds = np.squeeze(preds) if is_regression else np.argmax(preds, axis=1) 521 | if data_args.task_name is not None: 522 | result = metric.compute(predictions=preds, references=p.label_ids) 523 | if len(result) > 1: 524 | result["combined_score"] = np.mean(list(result.values())).item() 525 | return result 526 | elif is_regression: 527 | return {"mse": ((preds - p.label_ids) ** 2).mean().item()} 528 | else: 529 | return {"accuracy": (preds == p.label_ids).astype(np.float32).mean().item()} 530 | 531 | # Data collator will default to DataCollatorWithPadding when the tokenizer is passed to Trainer, so we change it if 532 | # we already did the padding. 533 | if data_args.pad_to_max_length: 534 | data_collator = default_data_collator 535 | elif training_args.fp16: 536 | data_collator = DataCollatorWithPadding(tokenizer, pad_to_multiple_of=8) 537 | else: 538 | data_collator = None 539 | 540 | if model_args.head_only: 541 | for name, param in model.named_parameters(): 542 | if 'classifier' not in name: 543 | param.requires_grad = False 544 | 545 | # Initialize our Trainer 546 | trainer = Trainer( 547 | model=model, 548 | args=training_args, 549 | train_dataset=train_dataset if training_args.do_train else None, 550 | eval_dataset=eval_dataset if training_args.do_eval else None, 551 | compute_metrics=compute_metrics, 552 | tokenizer=tokenizer, 553 | data_collator=data_collator, 554 | ) 555 | 556 | # Training 557 | if training_args.do_train: 558 | checkpoint = None 559 | if training_args.resume_from_checkpoint is not None: 560 | checkpoint = training_args.resume_from_checkpoint 561 | elif last_checkpoint is not None: 562 | checkpoint = last_checkpoint 563 | train_result = trainer.train(resume_from_checkpoint=checkpoint) 564 | metrics = train_result.metrics 565 | max_train_samples = ( 566 | data_args.max_train_samples if data_args.max_train_samples is not None else len(train_dataset) 567 | ) 568 | metrics["train_samples"] = min(max_train_samples, len(train_dataset)) 569 | 570 | trainer.save_model() # Saves the tokenizer too for easy upload 571 | 572 | trainer.log_metrics("train", metrics) 573 | trainer.save_metrics("train", metrics) 574 | trainer.save_state() 575 | 576 | # Evaluation 577 | if training_args.do_eval: 578 | logger.info("*** Evaluate ***") 579 | 580 | # Loop to handle MNLI double evaluation (matched, mis-matched) 581 | tasks = [data_args.task_name] 582 | eval_datasets = [eval_dataset] 583 | if data_args.task_name == "mnli": 584 | tasks.append("mnli-mm") 585 | eval_datasets.append(raw_datasets["validation_mismatched"]) 586 | 587 | for eval_dataset, task in zip(eval_datasets, tasks): 588 | metrics = trainer.evaluate(eval_dataset=eval_dataset) 589 | 590 | max_eval_samples = ( 591 | data_args.max_eval_samples if data_args.max_eval_samples is not None else len(eval_dataset) 592 | ) 593 | 
metrics["eval_samples"] = min(max_eval_samples, len(eval_dataset)) 594 | 595 | trainer.log_metrics("eval", metrics) 596 | trainer.save_metrics("eval", metrics) 597 | 598 | for eval_dataset, task in zip(eval_datasets, tasks): 599 | predictions = trainer.predict(eval_dataset, metric_key_prefix="predict").predictions 600 | predictions = np.squeeze(predictions) if is_regression else np.argmax(predictions, axis=1) 601 | output_predict_file = os.path.join(training_args.output_dir,f"predict_results_{task}.npy") 602 | if trainer.is_world_process_zero(): 603 | logger.info(f"***** Predict results {task} *****") 604 | np.save(output_predict_file,predictions) 605 | 606 | output_predict_file = os.path.join(training_args.output_dir, f"predict_results_{task}.txt") 607 | if trainer.is_world_process_zero(): 608 | with open(output_predict_file, "w") as writer: 609 | logger.info(f"***** Predict results {task} *****") 610 | writer.write("index\tprediction\n") 611 | for index, item in enumerate(predictions): 612 | if is_regression: 613 | writer.write(f"{index}\t{item:3.3f}\n") 614 | else: 615 | item = label_list[item] 616 | writer.write(f"{index}\t{item}\n") 617 | 618 | if training_args.do_predict: 619 | logger.info("*** Predict ***") 620 | 621 | # Loop to handle MNLI double evaluation (matched, mis-matched) 622 | tasks = [data_args.task_name] 623 | predict_datasets = [predict_dataset] 624 | if data_args.task_name == "mnli": 625 | tasks.append("mnli-mm") 626 | predict_datasets.append(raw_datasets["test_mismatched"]) 627 | 628 | for predict_dataset, task in zip(predict_datasets, tasks): 629 | # Removing the `label` columns because it contains -1 and Trainer won't like that. 630 | predict_dataset = predict_dataset.remove_columns("label") 631 | predictions = trainer.predict(predict_dataset, metric_key_prefix="predict").predictions 632 | predictions = np.squeeze(predictions) if is_regression else np.argmax(predictions, axis=1) 633 | 634 | output_predict_file = os.path.join(training_args.output_dir, f"predict_results_{task}.txt") 635 | if trainer.is_world_process_zero(): 636 | with open(output_predict_file, "w") as writer: 637 | logger.info(f"***** Predict results {task} *****") 638 | writer.write("index\tprediction\n") 639 | for index, item in enumerate(predictions): 640 | if is_regression: 641 | writer.write(f"{index}\t{item:3.3f}\n") 642 | else: 643 | item = label_list[item] 644 | writer.write(f"{index}\t{item}\n") 645 | 646 | kwargs = {"finetuned_from": model_args.model_name_or_path, "tasks": "text-classification"} 647 | if data_args.task_name is not None: 648 | kwargs["language"] = "en" 649 | kwargs["dataset_tags"] = "glue" 650 | kwargs["dataset_args"] = data_args.task_name 651 | kwargs["dataset"] = f"GLUE {data_args.task_name.upper()}" 652 | 653 | if training_args.push_to_hub: 654 | trainer.push_to_hub(**kwargs) 655 | else: 656 | trainer.create_model_card(**kwargs) 657 | 658 | 659 | def _mp_fn(index): 660 | # For xla_spawn (TPUs) 661 | main() 662 | 663 | 664 | if __name__ == "__main__": 665 | main() -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import math 3 | import torch 4 | import numpy as np 5 | from tqdm.auto import tqdm 6 | from copy import deepcopy 7 | from inspect import getmembers, isfunction 8 | from metric_calculators import get_metric_fns 9 | import random 10 | 11 | 12 | class FractionalDataloader: 13 | def __init__(self, dataloader, fraction, seed=None): 
14 | self.dataloader_numel = len(dataloader.dataset) 15 | self.numel = int(fraction * self.dataloader_numel) 16 | 17 | self.batch_size = self.dataloader_numel / len(dataloader) 18 | self.num_batches = int(math.ceil(self.numel / self.batch_size)) 19 | self.dataloader = dataloader 20 | self.dataset = self.dataloader.dataset 21 | self.seed = seed 22 | 23 | def __iter__(self): 24 | cur_elems = 0 25 | if self.seed is not None: 26 | self.dataloader.dataset.set_seed(self.seed) 27 | torch.manual_seed(self.seed) 28 | random.seed(self.seed) 29 | np.random.seed(self.seed) 30 | it = iter(self.dataloader) 31 | while cur_elems < self.numel: 32 | try: 33 | x, y = next(it) 34 | cur_elems += x.shape[0] 35 | yield x, y 36 | except StopIteration: 37 | it = iter(self.dataloader) 38 | 39 | 40 | def __len__(self): 41 | return self.num_batches 42 | 43 | 44 | def prepare_data(config, device='cuda'): 45 | """ Load all dataloaders required for experiment. """ 46 | if isinstance(config, list): 47 | return [prepare_data(c, device) for c in config] 48 | 49 | dataset_name = config['name'] 50 | 51 | import my_datasets.configs as config_module 52 | data_config = deepcopy(getattr(config_module, dataset_name)) 53 | data_config.update(config) 54 | data_config['device'] = device 55 | 56 | if data_config['type'] == 'books': 57 | from my_datasets.books import prepare_train_loaders 58 | train_loaders = prepare_train_loaders(data_config) 59 | test_loaders = None 60 | elif data_config['type'] == 'glue': 61 | from my_datasets.glue import prepare_train_loaders 62 | train_loaders = prepare_train_loaders(data_config) 63 | test_loaders = None 64 | else: 65 | raise NotImplementedError(config['type']) 66 | 67 | if 'train_fraction' in data_config: 68 | for k, v in dict(train_loaders.items()).items(): 69 | if k == 'splits': 70 | train_loaders[k] = [FractionalDataloader(x, data_config['train_fraction']) for x in v] 71 | elif not isinstance(v, list) and not isinstance(v, torch.Tensor): 72 | train_loaders[k] = FractionalDataloader(v, data_config['train_fraction']) 73 | 74 | return { 75 | 'train': train_loaders, 76 | 'test': test_loaders 77 | } 78 | 79 | def prepare_bert(config, device, repair=False, classifier=False): 80 | from transformers import BertForMaskedLM, BertForSequenceClassification 81 | bases = [] 82 | config_example = None 83 | for i, base_path in tqdm(enumerate(config['bases']), desc="Preparing Models"): 84 | if classifier: 85 | base_model = BertForSequenceClassification.from_pretrained(base_path) 86 | else: 87 | base_model = BertForMaskedLM.from_pretrained(base_path) 88 | config_example = base_model.config 89 | bases.append(base_model) 90 | if repair != False: 91 | from models.bert_bn import BertBNForMaskedLM 92 | new_model = BertBNForMaskedLM(config_example, rescale=False) 93 | elif classifier == True: 94 | new_model = deepcopy(base_model) 95 | else: 96 | new_model = BertForMaskedLM(config_example) 97 | return {'bases': bases, 'new': new_model} 98 | 99 | def prepare_models(config, device='cuda', repair=False, classifier=False): 100 | """ Load all pretrained models in config. """ 101 | if config['name'].startswith('bert'): 102 | return prepare_bert(config, device, repair=repair, classifier=classifier) 103 | else: 104 | # can add more models here 105 | raise NotImplementedError(config['name']) 106 | 107 | 108 | def prepare_graph(config, classifier=False): 109 | """ Get graph class of experiment models in config. 
""" 110 | if config['name'].startswith('bert'): 111 | import graphs.transformer_enc_graph as graph_module 112 | model_name = 'bert' 113 | graph = getattr(graph_module, model_name) 114 | else: 115 | raise NotImplementedError(config['name']) 116 | return graph 117 | 118 | 119 | def get_merging_fn(name): 120 | """ Get alignment function from name. """ 121 | import matching_functions 122 | matching_fns = dict([(k, v) for (k, v) in getmembers(matching_functions, isfunction) if 'match_tensors' in k]) 123 | return matching_fns[name] 124 | 125 | 126 | def prepare_experiment_config(config, type='vis'): 127 | """ Load all functions/classes/models requested in config to experiment config dict. """ 128 | 129 | data = prepare_data(config['dataset'], device=config['device']) 130 | if config['eval_type'] == 'logits': 131 | config['model']['output_dim'] = len(data['test']['class_names']) 132 | else: 133 | config['model']['output_dim'] = 512 134 | new_config = { 135 | 'graph': prepare_graph(config['model']), 136 | 'data': data, 137 | 'models': prepare_models(config['model'], device=config['device']), 138 | 'merging_fn': get_merging_fn(config['merging_fn']), 139 | 'metric_fns': get_metric_fns(config['merging_metrics']), 140 | } 141 | # Add outstanding elements 142 | for key in config: 143 | if key not in new_config: 144 | new_config[key] = config[key] 145 | return new_config 146 | 147 | def prepare_lang_config(config, type='vis', repair=False, classifier=False): 148 | """ Load all functions/classes/models requested in config to experiment config dict. """ 149 | data = prepare_data(config['dataset'], device=config['device']) 150 | new_config = { 151 | 'graph': prepare_graph(config['model'], classifier=classifier), 152 | 'data': data, 153 | 'models': prepare_models(config['model'], device=config['device'], repair=repair, classifier=classifier), 154 | 'merging_fn': get_merging_fn(config['merging_fn']), 155 | 'metric_fns': get_metric_fns(config['merging_metrics']), 156 | } 157 | # Add outstanding elements 158 | for key in config: 159 | if key not in new_config: 160 | new_config[key] = config[key] 161 | return new_config 162 | 163 | 164 | def get_config_from_name(name, device=None): 165 | """ Load config based on its name. """ 166 | out = deepcopy(getattr(__import__('configs.' 
+ name), name).config) 167 | if device is None and 'device' not in out: 168 | out['device'] = 'cuda' 169 | elif device is not None: 170 | out['device'] = device 171 | return out 172 | 173 | 174 | def set_seed(seed): 175 | torch.manual_seed(seed) 176 | random.seed(seed) 177 | np.random.seed(seed) 178 | 179 | def contains_name(layer_name, node_list): 180 | for node in node_list: 181 | if node in layer_name: 182 | return True 183 | return False 184 | 185 | def find_pairs(str_splits): 186 | pairs = [] 187 | for i, str_split_i in enumerate(str_splits): 188 | try: 189 | split_i = set([int(k) for k in str_split_i.split('_')]) 190 | except: 191 | continue 192 | for str_split_j in str_splits[i+1:]: 193 | try: 194 | split_j = set([int(k) for k in str_split_j.split('_')]) 195 | except: 196 | continue 197 | if len(split_i.intersection(split_j)) == 0: 198 | pairs.append((str_split_i, str_split_j)) 199 | return pairs 200 | 201 | 202 | def split_str_to_ints(split): 203 | return [int(i) for i in split.split('_')] 204 | 205 | 206 | def is_valid_pair(model_dir, pair, model_type): 207 | paths = os.listdir(os.path.join(model_dir, pair[0])) 208 | flag = True 209 | for path in paths: 210 | if f'{model_type}_v0.pth.tar' not in path: 211 | flag = False 212 | return flag 213 | 214 | 215 | def inject_pair_language(config, pair): 216 | config['model']['bases'] = [os.path.join(config['model']['dir'], pair_item) for pair_item in pair] 217 | return config 218 | --------------------------------------------------------------------------------