├── .gitignore ├── utils.py ├── LICENSE.md ├── environment.yml ├── README.md ├── run_glue.py └── glue_evaluator.py /.gitignore: -------------------------------------------------------------------------------- 1 | .idea 2 | *.cache* 3 | *.ipynb_checkpoints* 4 | *__pycache__* 5 | test 6 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import sys 3 | 4 | 5 | def setup_logging(): 6 | logging.basicConfig(stream=sys.stdout, format='%(asctime)s - %(module)s - %(levelname)s - %(message)s', 7 | level=logging.INFO) 8 | -------------------------------------------------------------------------------- /LICENSE.md: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 benzakenelad 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: torch5 2 | channels: 3 | - defaults 4 | dependencies: 5 | - _libgcc_mutex=0.1=main 6 | - ca-certificates=2020.12.8=h06a4308_0 7 | - certifi=2020.12.5=py37h06a4308_0 8 | - ld_impl_linux-64=2.33.1=h53a641e_7 9 | - libedit=3.1.20191231=h14c3975_1 10 | - libffi=3.3=he6710b0_2 11 | - libgcc-ng=9.1.0=hdf63c60_0 12 | - libstdcxx-ng=9.1.0=hdf63c60_0 13 | - ncurses=6.2=he6710b0_1 14 | - openssl=1.1.1i=h27cfd23_0 15 | - pip=20.3.3=py37h06a4308_0 16 | - python=3.7.9=h7579374_0 17 | - readline=8.0=h7b6447c_0 18 | - setuptools=51.0.0=py37h06a4308_2 19 | - sqlite=3.33.0=h62c20be_0 20 | - tk=8.6.10=hbc83047_0 21 | - wheel=0.36.2=pyhd3eb1b0_0 22 | - xz=5.2.5=h7b6447c_0 23 | - zlib=1.2.11=h7b6447c_3 24 | - pip: 25 | - argon2-cffi==20.1.0 26 | - async-generator==1.10 27 | - attrs==20.3.0 28 | - backcall==0.2.0 29 | - bleach==3.2.1 30 | - cffi==1.14.4 31 | - chardet==4.0.0 32 | - click==7.1.2 33 | - cycler==0.10.0 34 | - datasets==1.1.2 35 | - decorator==4.4.2 36 | - defusedxml==0.6.0 37 | - dill==0.3.3 38 | - entrypoints==0.3 39 | - filelock==3.0.12 40 | - idna==2.10 41 | - importlib-metadata==3.3.0 42 | - ipykernel==5.4.2 43 | - ipython==7.19.0 44 | - ipython-genutils==0.2.0 45 | - ipywidgets==7.6.2 46 | - jedi==0.18.0 47 | - jinja2==2.11.2 48 | - joblib==1.0.0 49 | - jsonschema==3.2.0 50 | - jupyter==1.0.0 51 | - jupyter-client==6.1.7 52 | - jupyter-console==6.2.0 53 | - jupyter-core==4.7.0 54 | - jupyterlab-pygments==0.1.2 55 | - jupyterlab-widgets==1.0.0 56 | - kiwisolver==1.3.1 57 | - markupsafe==1.1.1 58 | - matplotlib==3.3.3 59 | - mistune==0.8.4 60 | - multiprocess==0.70.11.1 61 | - nbclient==0.5.1 62 | - nbconvert==6.0.7 63 | - nbformat==5.0.8 64 | - nest-asyncio==1.4.3 65 | - notebook==6.1.6 66 | - numpy==1.19.4 67 | - packaging==20.8 68 | - pandas==1.2.0 69 | - 
pandocfilters==1.4.3 70 | - parso==0.8.1 71 | - pexpect==4.8.0 72 | - pickleshare==0.7.5 73 | - pillow==8.0.1 74 | - prometheus-client==0.9.0 75 | - prompt-toolkit==3.0.8 76 | - protobuf==3.14.0 77 | - ptyprocess==0.7.0 78 | - pyarrow==2.0.0 79 | - pycparser==2.20 80 | - pygments==2.7.3 81 | - pyparsing==2.4.7 82 | - pyrsistent==0.17.3 83 | - python-dateutil==2.8.1 84 | - pytz==2020.5 85 | - pyzmq==20.0.0 86 | - qtconsole==5.0.1 87 | - qtpy==1.9.0 88 | - regex==2020.11.13 89 | - requests==2.25.1 90 | - sacremoses==0.0.43 91 | - scikit-learn==0.24.0 92 | - scipy==1.5.4 93 | - seaborn==0.11.1 94 | - send2trash==1.5.0 95 | - sentencepiece==0.1.94 96 | - six==1.15.0 97 | # - sklearn==0.0 (deprecated PyPI alias that now refuses to install; scikit-learn is pinned above) 98 | - terminado==0.9.1 99 | - testpath==0.4.4 100 | - threadpoolctl==2.1.0 101 | - tokenizers==0.9.4 102 | - torch==1.7.1 103 | - torchvision==0.8.2 104 | - tornado==6.1 105 | - tqdm==4.49.0 106 | - traitlets==5.0.5 107 | - transformers==4.2.1 108 | - typing-extensions==3.7.4.3 109 | - urllib3==1.26.2 110 | - wcwidth==0.2.5 111 | - webencodings==0.5.1 112 | - widgetsnbextension==3.5.1 113 | - xxhash==2.0.0 114 | - zipp==3.4.0 115 | prefix: /home/dsi/eladbz/anaconda3/envs/torch5 116 | 117 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # BitFit [(Paper)](https://arxiv.org/abs/2106.10199) 2 | Simple Parameter-efficient Fine-tuning for Transformer-based Masked Language-models 3 | 4 | # Abstract 5 | We introduce BitFit, a sparse-finetuning method where only the bias-terms of the model (or a subset of them) are being modified. We show that with small-to-medium training data, applying BitFit on pre-trained BERT models is competitive with (and sometimes better than) fine-tuning the entire model. For larger data, the method is competitive with other sparse fine-tuning methods.
6 | Besides their practical utility, these findings are relevant for the question of understanding the commonly-used process of finetuning: they support the hypothesis that finetuning is mainly about exposing knowledge induced by language-modeling training, rather than learning new task-specific linguistic knowledge. 7 | 8 | # Environment 9 | First, create an environment with all the dependencies: 10 | ``` 11 | $ conda env create -n bitfit_env -f environment.yml 12 | ``` 13 | Then activate it: 14 | ``` 15 | $ conda activate bitfit_env 16 | ``` 17 | 18 | # [GLUE Benchmark](https://arxiv.org/abs/1804.07461) evaluation examples: 19 | 20 | ``` 21 | python run_glue.py \ 22 | --output-path <output_path> \ 23 | --task-name <task_name> \ 24 | --model-name <model_name> \ 25 | --fine-tune-type <fine_tune_type> \ 26 | --bias-terms <bias_terms> \ 27 | --gpu-device <gpu_device> \ 28 | --learning-rate <learning_rate> \ 29 | --epochs <epochs> \ 30 | --batch-size <batch_size> \ 31 | --optimizer <optimizer> \ 32 | --save-evaluator \ 33 | --predict-test \ 34 | --verbose 35 | ``` 36 | For further information about the arguments run: 37 | ``` 38 | python run_glue.py -h 39 | ``` 40 | 41 | Example of executing full fine-tuning: 42 | ``` 43 | python run_glue.py \ 44 | --output-path <output_path> \ 45 | --task-name rte \ 46 | --model-name bert-base-cased \ 47 | --fine-tune-type full_ft \ 48 | --learning-rate 1e-5 49 | ``` 50 | 51 | Example of executing full BitFit (training all bias terms): 52 | ``` 53 | python run_glue.py \ 54 | --output-path <output_path> \ 55 | --task-name rte \ 56 | --model-name bert-base-cased \ 57 | --fine-tune-type bitfit \ 58 | --learning-rate 1e-3 59 | ``` 60 | 61 | Example of executing partial BitFit (training a subset of the bias terms): 62 | ``` 63 | python run_glue.py \ 64 | --output-path <output_path> \ 65 | --task-name rte \ 66 | --model-name bert-base-cased \ 67 | --fine-tune-type bitfit \ 68 | --bias-terms query intermediate \ 69 | --learning-rate 1e-3 70 | ``` 71 | 72 | Example of executing "frozen" training (i.e.
using the pre-trained transformer as a feature extractor): 73 | ``` 74 | python run_glue.py \ 75 | --output-path <output_path> \ 76 | --task-name rte \ 77 | --model-name bert-base-cased \ 78 | --fine-tune-type frozen \ 79 | --learning-rate 1e-3 80 | ``` 81 | 82 | Example of training uniformly chosen trainable parameters (similar to the "rand_100k" row in Table 3 of the BitFit paper): 83 | ``` 84 | python run_glue.py \ 85 | --output-path <output_path> \ 86 | --task-name rte \ 87 | --model-name bert-base-cased \ 88 | --fine-tune-type rand_uniform \ 89 | --learning-rate 1e-3 90 | ``` 91 | 92 | 101 | 102 | # MIT License 103 | 104 | Copyright (c) 2022 benzakenelad 105 | 106 | Permission is hereby granted, free of charge, to any person obtaining a copy 107 | of this software and associated documentation files (the "Software"), to deal 108 | in the Software without restriction, including without limitation the rights 109 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 110 | copies of the Software, and to permit persons to whom the Software is 111 | furnished to do so, subject to the following conditions: 112 | 113 | The above copyright notice and this permission notice shall be included in all 114 | copies or substantial portions of the Software. 115 | 116 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 117 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 118 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 119 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 120 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 121 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 122 | SOFTWARE.
123 | 124 | -------------------------------------------------------------------------------- /run_glue.py: -------------------------------------------------------------------------------- 1 | """This file contains a tool that wraps the GLUEvaluator API, the tool supports all the evaluations that were 2 | performed in BitFit paper (https://arxiv.org/abs/1804.07461), such as: 'full_ft', 'bitfit', 'frozen', 'rand_uniform' 3 | and 'rand_row_col'. 4 | 5 | For questions please reach: benzakenelad@gmail.com 6 | 7 | Author Elad Ben-Zaken 8 | """ 9 | import argparse 10 | import os 11 | import logging 12 | 13 | from utils import setup_logging 14 | from glue_evaluator import GLUEvaluator, set_seed 15 | 16 | setup_logging() 17 | LOGGER = logging.getLogger(__file__) 18 | 19 | PADDING = "max_length" 20 | MAX_SEQUENCE_LEN = 128 21 | 22 | RAND_UNIFORM_MASK_SIZE = {'bert-base-cased': 100000, 'bert-large-cased': 280000, 'roberta-base': 105000} 23 | 24 | 25 | def _parse_args(): 26 | parser = argparse.ArgumentParser(description='BitFit GLUE evaluation', 27 | formatter_class=argparse.ArgumentDefaultsHelpFormatter) 28 | 29 | parser.add_argument('--output-path', '-o', required=True, type=str, 30 | help='output directory path for evaluation products.') 31 | parser.add_argument('--task-name', '-t', required=True, type=str, help='GLUE task name for evaluation.', 32 | choices={'cola', 'mnli', 'mrpc', 'qnli', 'qqp', 'rte', 'sst2', 'stsb', 'wnli'}) 33 | parser.add_argument('--model-name', '-m', type=str, default='bert-base-cased', help='model-name to evaluate with.', 34 | choices={'bert-base-cased', 'bert-large-cased', 'roberta-base'}) 35 | 36 | parser.add_argument('--fine-tune-type', '-f', required=True, type=str, 37 | help='Which fine tuning process to perform, types are the types that were performed in BitFit paper.', 38 | choices={'full_ft', 'bitfit', 'frozen', 'rand_uniform', 'rand_row_col'}) 39 | parser.add_argument('--bias-terms', metavar='N', type=str, nargs='+', default=['all'], 40 | 
choices={'intermediate', 'key', 'query', 'value', 'output', 'output_layernorm', 41 | 'attention_layernorm', 'all'}, 42 | help='bias terms to BitFit, should be given in case --fine-tune-type is bitfit ' 43 | '(choose \'all\' for BitFit all bias terms)') 44 | 45 | parser.add_argument('--gpu-device', '-d', type=int, default=None, 46 | help='GPU id for BitFit, if not mentioned will train on CPU.') 47 | parser.add_argument('--seed', '-s', type=int, default=0, help='seed value to set.') 48 | parser.add_argument('--learning-rate', '-l', type=float, default=1e-3, help='learning rate for training.') 49 | parser.add_argument('--epochs', '-e', type=int, default=16, help='number of training epochs.') 50 | parser.add_argument('--batch-size', '-b', type=int, default=8, help='training and evaluation batch size.') 51 | parser.add_argument('--optimizer', type=str, default='adamw', choices={'adam', 'adamw'}) 52 | parser.add_argument('--save-evaluator', action='store_true', default=False, 53 | help='if given, will save the evaluator for later inference/examination.') 54 | parser.add_argument('--predict-test', action='store_true', default=False, 55 | help='if given, will infer on test set using the fine-tuned model (predictions file will be in ' 56 | 'GLUE benchmark test server format). 
Predictions will be saved to output_path.') 57 | parser.add_argument('--verbose', action='store_true', default=True, 58 | help='if given, will plot a list of trainable weights.') 59 | 60 | return parser.parse_args() 61 | 62 | 63 | def _validate_args(args): 64 | if not os.path.exists(args.output_path): 65 | os.makedirs(args.output_path) 66 | if not os.path.isdir(args.output_path): 67 | raise ValueError("--output_path must be a path to directory") 68 | if len(os.listdir(args.output_path)): 69 | raise ValueError("--output_path directory isn't empty, please supply an empty directory path.") 70 | if args.fine_tune_type == 'rand_uniform' and args.model_name not in RAND_UNIFORM_MASK_SIZE.keys(): 71 | raise ValueError(f'Currently the rand_uniform fine-tune type is not supported for {args.model_name}.') 72 | 73 | 74 | def _plot_training_details(args): 75 | [LOGGER.info('############################################################################################') for _ 76 | in range(3)] 77 | LOGGER.info('') 78 | 79 | LOGGER.info('Training Details: ') 80 | LOGGER.info('----------------------------------------------') 81 | LOGGER.info(f'Model Name: {args.model_name}') 82 | LOGGER.info(f'Task Name: {args.task_name}') 83 | LOGGER.info(f'Fine Tuning Type: {args.fine_tune_type}') 84 | LOGGER.info(f'Output Directory: {args.output_path}') 85 | 86 | if args.gpu_device is not None: 87 | LOGGER.info(f'Running on GPU #{args.gpu_device}') 88 | else: 89 | LOGGER.info(f'Running on CPU') 90 | 91 | if args.fine_tune_type == 'bitfit': 92 | LOGGER.info(f"Bias Trainable Terms: {'all bias terms' if 'all' in args.bias_terms else args.bias_terms}") 93 | 94 | LOGGER.info(f'Epochs: {args.epochs}') 95 | LOGGER.info(f'Learning Rate: {args.learning_rate}') 96 | LOGGER.info(f'Batch Size: {args.batch_size}') 97 | LOGGER.info(f"Optimizer: {'AdamW' if args.optimizer == 'adamw' else 'Adam'}") 98 | 99 | LOGGER.info('') 100 | 
[LOGGER.info('############################################################################################') for _ 101 | in range(3)] 102 | 103 | 104 | def _perform_training_preparations(evaluator, args, trainable_components): 105 | if args.fine_tune_type == 'frozen': 106 | trainable_components = [] 107 | 108 | if args.fine_tune_type == 'full_ft': 109 | evaluator.training_preparation(learning_rate=args.learning_rate, 110 | optimizer=args.optimizer, 111 | encoder_trainable=True, 112 | verbose=args.verbose) 113 | elif args.fine_tune_type in {'bitfit', 'frozen'}: 114 | evaluator.training_preparation(learning_rate=args.learning_rate, 115 | optimizer=args.optimizer, 116 | encoder_trainable=False, 117 | trainable_components=trainable_components, 118 | verbose=args.verbose) 119 | else: 120 | evaluator.training_preparation(learning_rate=args.learning_rate, 121 | optimizer=args.optimizer, 122 | encoder_trainable=True, 123 | verbose=False) 124 | 125 | # randomizing mask 126 | if args.fine_tune_type == 'rand_uniform': 127 | evaluator.set_uniform_mask(mask_size=RAND_UNIFORM_MASK_SIZE[args.model_name]) 128 | else: # args.fine_tune_type == 'rand_row_col' 129 | evaluator.set_row_and_column_random_mask() 130 | 131 | 132 | def main(): 133 | # args parsing 134 | args = _parse_args() 135 | _validate_args(args) 136 | _plot_training_details(args) 137 | 138 | # seed 139 | set_seed(args.seed) 140 | 141 | # evaluator creation 142 | evaluator = GLUEvaluator(args.task_name, args.model_name, args.gpu_device) 143 | 144 | # data preprocessing 145 | evaluator.preprocess_dataset(PADDING, MAX_SEQUENCE_LEN, args.batch_size) 146 | 147 | # training preparation 148 | trainable_components = GLUEvaluator.convert_to_actual_components(args.bias_terms) 149 | _perform_training_preparations(evaluator, args, trainable_components) 150 | 151 | # train and evaluate 152 | evaluator.train_and_evaluate(args.epochs, args.output_path) 153 | 154 | # saving artifacts 155 | if args.fine_tune_type == 'bitfit': 156 | 
evaluator.plot_terms_changes(os.path.join(args.output_path, 'bias_term_changes')) 157 | 158 | # save model 159 | if args.save_evaluator: 160 | evaluator.save(os.path.join(args.output_path, 'evaluator')) 161 | 162 | # export model test set predictions 163 | if args.predict_test: 164 | evaluator.export_model_test_set_predictions(args.output_path) 165 | 166 | 167 | if __name__ == '__main__': 168 | main() 169 | -------------------------------------------------------------------------------- /glue_evaluator.py: -------------------------------------------------------------------------------- 1 | """This file contains the GLUEvaluator class which exposes an API for all the evaluations that were performed in 2 | BitFit paper (https://arxiv.org/abs/1804.07461), such as: 'full_ft', 'bitfit', 'frozen', 'rand_uniform' and 3 | 'rand_row_col'. 4 | 5 | For questions please reach: benzakenelad@gmail.com 6 | 7 | Author Elad Ben-Zaken 8 | """ 9 | 10 | import os 11 | import re 12 | from functools import reduce 13 | 14 | import logging 15 | import pickle 16 | import numpy as np 17 | import matplotlib.pyplot as plt 18 | from seaborn import heatmap 19 | from scipy.stats import spearmanr, pearsonr 20 | from sklearn.metrics import f1_score, matthews_corrcoef, accuracy_score 21 | 22 | import torch 23 | from torch.optim import Adam 24 | from torch.utils.data import TensorDataset, DataLoader, RandomSampler, SequentialSampler 25 | from datasets import load_dataset 26 | from transformers.optimization import AdamW 27 | from transformers import AutoTokenizer, AutoModelForSequenceClassification, AutoConfig 28 | from datasets.arrow_dataset import Dataset 29 | 30 | from utils import setup_logging 31 | 32 | setup_logging() 33 | LOGGER = logging.getLogger(__file__) 34 | 35 | 36 | def set_seed(seed): 37 | torch.manual_seed(seed) 38 | np.random.seed(seed) 39 | 40 | 41 | TASK_TO_KEYS = { 42 | "cola": ("sentence", None), 43 | "mnli": ("premise", "hypothesis"), 44 | "mrpc": ("sentence1", "sentence2"), 45 
| "qnli": ("question", "sentence"), 46 | "qqp": ("question1", "question2"), 47 | "rte": ("sentence1", "sentence2"), 48 | "sst2": ("sentence", None), 49 | "stsb": ("sentence1", "sentence2"), 50 | "wnli": ("sentence1", "sentence2"), 51 | } 52 | 53 | TASK_TO_METRICS = { 54 | "cola": ["MCC"], 55 | "mnli": ["Accuracy"], 56 | "mrpc": ["Accuracy", "F1"], 57 | "qnli": ["Accuracy"], 58 | "qqp": ["Accuracy", "F1"], 59 | "rte": ["Accuracy"], 60 | "sst2": ["Accuracy"], 61 | "stsb": ["Spearman", "Pearson"], 62 | "wnli": ["Accuracy"], 63 | } 64 | 65 | METRIC_NAME_TO_FUNCTION = { 66 | "MCC": matthews_corrcoef, 67 | "Accuracy": accuracy_score, 68 | "F1": f1_score, 69 | "Spearman": spearmanr, 70 | "Pearson": pearsonr, 71 | } 72 | 73 | BIAS_TERMS_DICT = { 74 | 'intermediate': 'intermediate.dense.bias', 75 | 'key': 'attention.self.key.bias', 76 | 'query': 'attention.self.query.bias', 77 | 'value': 'attention.self.value.bias', 78 | 'output': 'output.dense.bias', 79 | 'output_layernorm': 'output.LayerNorm.bias', 80 | 'attention_layernorm': 'attention.output.LayerNorm.bias', 81 | 'all': 'bias', 82 | } 83 | 84 | TASK_NAME_TO_SUBMISSION_FILE_NAME = { 85 | "cola": "CoLA.tsv", 86 | "mnli": ("MNLI-m.tsv", "MNLI-mm.tsv"), 87 | "mrpc": "MRPC.tsv", 88 | "qnli": "QNLI.tsv", 89 | "qqp": "QQP.tsv", 90 | "rte": "RTE.tsv", 91 | "sst2": "SST-2.tsv", 92 | "stsb": "STS-B.tsv", 93 | "wnli": "WNLI.tsv", 94 | } 95 | 96 | TASK_IS_BINARY = { 97 | "cola": True, 98 | "mnli": False, 99 | "mrpc": True, 100 | "qnli": False, 101 | "qqp": True, 102 | "rte": False, 103 | "sst2": True, 104 | "stsb": True, 105 | "wnli": True, 106 | } 107 | 108 | BIAS_LAYER_NAME_TO_LATEX = { 109 | 'attention.self.query.bias': '$\mathbf{b}_{q}^{\ell}$', 110 | 'attention.self.key.bias': '$\mathbf{b}_{k}^{\ell}$', 111 | 'attention.self.value.bias': '$\mathbf{b}_{v}^{\ell}$', 112 | 'attention.output.dense.bias': '$\mathbf{b}_{m_1}^{\ell}$', 113 | 'attention.output.LayerNorm.bias': '$\mathbf{b}_{LN_1}^{\ell}$', 114 | 
'intermediate.dense.bias': '$\mathbf{b}_{m_2}^{\ell}$', 115 | 'output.dense.bias': '$\mathbf{b}_{m_3}^{\ell}$', 116 | 'output.LayerNorm.bias': '$\mathbf{b}_{LN_2}^{\ell}$', 117 | } 118 | 119 | 120 | class GLUEvaluator: 121 | """This class contains all the functionality for GLUE benchmark evaluations that were performed in BitFit paper. 122 | 123 | This class exposes an API for all the evaluations that were performed in BitFit paper 124 | (https://arxiv.org/abs/1804.07461), such as: 'full_ft', 'bitfit', 'frozen', 'rand_uniform' and 'rand_row_col'. 125 | """ 126 | 127 | def __init__(self, task_name, model_name, device): 128 | """ 129 | Args: 130 | task_name (str): task name, e.g. 'rte'. 131 | model_name (str): model name, e.g. 'bert-base-uncased'. 132 | device (int): GPU device to run on, if None will run on CPU. 133 | 134 | """ 135 | self.task_name = task_name 136 | self.model_name = model_name 137 | self.device = device 138 | 139 | # initialization 140 | self.is_regression = task_name == 'stsb' 141 | self.num_labels = None 142 | self.data_loaders = None 143 | self.batch_size = None 144 | self.model = None 145 | self.optimizer = None 146 | self.learning_rate = None 147 | self.evaluations = None 148 | self.encoder_trainable = None 149 | self.masks = None 150 | self.idx_to_label = None 151 | 152 | def preprocess_dataset(self, padding, max_sequence_len, batch_size, train_size=None): 153 | """Preprocess the train and validation datasets. 
154 | 155 | Args: 156 | padding (str): padding method (currently 'max_length' is the suggested method) 157 | max_sequence_len (int): the maximum sequence length 158 | batch_size (int): training and evaluating batch size 159 | train_size (int): clip the train dataset size, if None will use all available samples 160 | 161 | """ 162 | LOGGER.info(f'Downloading dataset: {self.task_name}') 163 | datasets = load_dataset('glue', self.task_name) 164 | 165 | self.batch_size = batch_size 166 | tokenizer = AutoTokenizer.from_pretrained(self.model_name) 167 | 168 | is_regression = self.task_name == "stsb" 169 | if not is_regression: 170 | label_list = datasets["train"].features["label"].names 171 | self.idx_to_label = {k: v for k, v in enumerate(datasets['train'].features['label'].__dict__['_int2str'])} 172 | self.num_labels = len(label_list) 173 | else: 174 | self.num_labels = 1 175 | 176 | sentence1_key, sentence2_key = TASK_TO_KEYS[self.task_name] 177 | 178 | def _preprocess_function(examples): 179 | # Tokenize the texts 180 | args = ( 181 | (examples[sentence1_key],) if sentence2_key is None else ( 182 | examples[sentence1_key], examples[sentence2_key]) 183 | ) 184 | result = tokenizer(*args, padding=padding, max_length=max_sequence_len, truncation=True) 185 | return result 186 | 187 | datasets = datasets.map(_preprocess_function, batched=True, load_from_cache_file=False) 188 | 189 | self.data_loaders = dict() 190 | 191 | if train_size: 192 | perm = np.random.permutation(len(datasets['train']))[:train_size] 193 | self.data_loaders['train'] = Dataset.from_dict(datasets['train'][perm]) 194 | else: 195 | self.data_loaders['train'] = datasets['train'] 196 | 197 | if self.task_name == 'mnli': 198 | self.data_loaders['validation_matched'] = datasets['validation_matched'] 199 | self.data_loaders['validation_mismatched'] = datasets['validation_mismatched'] 200 | self.data_loaders['test_matched'] = datasets['test_matched'] 201 | self.data_loaders['test_mismatched'] = 
datasets['test_mismatched'] 202 | else: 203 | self.data_loaders['validation'] = datasets['validation'] 204 | self.data_loaders['test'] = datasets['test'] 205 | 206 | for dataset_name, dataset in self.data_loaders.items(): 207 | self.data_loaders[dataset_name] = self._convert_dataset_to_data_loader(dataset=dataset, 208 | model_name=self.model_name, 209 | batch_size=self.batch_size, 210 | random_sampler=dataset_name == 'train', 211 | test='test' in dataset_name) 212 | 213 | def training_preparation(self, learning_rate, optimizer, encoder_trainable, trainable_components=None, 214 | verbose=True): 215 | """Performs training preparation. 216 | 217 | Perform training preparation including: model initialization, optimizer initialization, relevant 218 | gradients deactivation and plotting a list of all trainable params (if verbose is True). 219 | 220 | Args: 221 | learning_rate (float): learning_rate to train with. 222 | optimizer(str): optimizer to perform the training with, currently adam and adamw are supported. 223 | encoder_trainable (bool): if True will perform a Full-FT else will perform BitFit training preparation. 224 | trainable_components(Union[List[str], None]): list of trainable component.(subset of `BIAS_TERMS_DICT` keys) 225 | verbose: if True will plot a list of all trainable params 226 | 227 | """ 228 | if self.model: 229 | raise Exception('Training preparation was already completed.') 230 | 231 | if encoder_trainable and trainable_components: 232 | raise Exception( 233 | f"If encoder_trainable is True, you shouldn't supply trainable_components. 
" 234 | f"Got trainable_components: {trainable_components}") 235 | 236 | self.encoder_trainable = encoder_trainable 237 | # model declaration 238 | config = AutoConfig.from_pretrained(self.model_name, num_labels=self.num_labels, return_dict=True) 239 | self.model = AutoModelForSequenceClassification.from_pretrained(self.model_name, config=config) 240 | if not encoder_trainable: 241 | self._deactivate_relevant_gradients(trainable_components) 242 | 243 | # optimizer declaration 244 | if optimizer == 'adam': 245 | self.optimizer = Adam(self.model.parameters(), lr=learning_rate) 246 | elif optimizer == 'adamw': 247 | self.optimizer = AdamW(self.model.parameters(), lr=learning_rate, correct_bias=True) 248 | else: 249 | raise Exception(f"optimizer arg must be in ['adam', 'adamw'], got: {optimizer}") 250 | 251 | self.learning_rate = learning_rate 252 | 253 | if verbose: 254 | print('\n\nTrainable Components:\n----------------------------------------\n') 255 | total_trainable_params = 0 256 | for name, param in self.model.named_parameters(): 257 | if param.requires_grad: 258 | print(name, ' ---> ', param.shape) 259 | total_trainable_params += param.shape[0] if len(param.shape) == 1 else param.shape[0] * param.shape[ 260 | 1] 261 | print( 262 | f'\n----------------------------------------\nNumber of Trainable Parameters: {total_trainable_params}\n') 263 | 264 | self.evaluations = {k: {metric_name: [] for metric_name in TASK_TO_METRICS[self.task_name]} for k in 265 | self.data_loaders.keys()} 266 | 267 | def train_and_evaluate(self, num_epochs, output_path=None, evaluation_frequency=1): 268 | """Trains the encoder model and evaluate it on validation set. 269 | 270 | Learning curves will be saved to the output_path. 271 | 272 | Args: 273 | num_epochs (int): Number of epochs to perform. 274 | output_path (str): Directory path to save the learning curves too. 275 | evaluation_frequency (int): will evaluate every `evaluation_frequency` epochs. 
276 | 277 | """ 278 | 279 | # validations 280 | if not self.data_loaders: 281 | raise Exception('data loaders were not initialized, please run "preprocess_dataset" before training.') 282 | 283 | if not self.model: 284 | raise Exception('model was not initialized, please run "training_preparation" before training.') 285 | 286 | # moving model to the required device 287 | if self.device is not None: 288 | self.model.cuda(self.device) 289 | 290 | # train and evaluate 291 | for epoch in range(num_epochs): 292 | # training for a single epoch 293 | self._train(self.data_loaders['train'], epoch) 294 | 295 | # evaluation 296 | if not epoch % evaluation_frequency: 297 | for dataloader_type, dataloader in self.data_loaders.items(): 298 | if not ('test' in dataloader_type): 299 | results = self._evaluate(dataloader, dataloader_type.upper()) 300 | for metric_name, result in results.items(): 301 | self.evaluations[dataloader_type][metric_name].append(result) 302 | print('') 303 | 304 | # plot learning curves 305 | self.plot_learning_curves(output_path) 306 | 307 | def save(self, output_path): 308 | """Saves the evaluator to the output_path directory. 309 | 310 | Args: 311 | output_path (str): Directory to save to model to. 312 | 313 | """ 314 | LOGGER.info(f'Saving the model to: {output_path}') 315 | 316 | self.model.cpu() 317 | data = {'model': self.model, 'model_name': self.model_name, 'task_name': self.task_name, 318 | 'learning_rate': self.learning_rate, 'evaluations': self.evaluations, 319 | 'batch_size': self.batch_size, 'num_labels': self.num_labels, 320 | 'encoder_trainable': self.encoder_trainable} 321 | with open(output_path, 'wb') as file: 322 | pickle.dump(data, file) 323 | 324 | @staticmethod 325 | def load(path, gpu_device): 326 | """Loads the evaluator from `path`. 327 | 328 | Args: 329 | path (str): Directory to load to model from. 330 | gpu_device (int): GPU device ID. 
331 | 332 | Returns: 333 | (GLUEvaluator): the GLUEvaluator instance we loaded 334 | """ 335 | with open(path, 'rb') as file: 336 | data = pickle.load(file) 337 | evaluator = GLUEvaluator(data['task_name'], data['model_name'], gpu_device) 338 | evaluator.num_labels = data['num_labels'] 339 | evaluator.batch_size = data['batch_size'] 340 | evaluator.model = data['model'] 341 | evaluator.learning_rate = data['learning_rate'] 342 | evaluator.evaluations = data['evaluations'] 343 | evaluator.encoder_trainable = data.get('encoder_trainable', None) 344 | 345 | return evaluator 346 | 347 | def export_model_test_set_predictions(self, output_path): 348 | """Infers on test set and saves the predictions to output_path (predictions are in "GLUE test server" format). 349 | 350 | Args: 351 | output_path (str): Directory to save the predictions. 352 | 353 | """ 354 | # validations 355 | if not self.data_loaders: 356 | raise Exception( 357 | 'data loaders were not initialized, please run "preprocess_dataset" before test evaluation.') 358 | 359 | if not self.model: 360 | raise Exception('model was not initialized, please run "training_preparation" before test evaluation.') 361 | 362 | # move the model the required device 363 | if self.device is not None: 364 | self.model.cuda(self.device) 365 | 366 | LOGGER.info(f'Exporting model test set predictions to: {output_path}.') 367 | 368 | test_data_loaders = dict() 369 | if self.task_name == 'mnli': 370 | test_data_loaders["MNLI-m.tsv"] = self.data_loaders["test_matched"] 371 | test_data_loaders["MNLI-mm.tsv"] = self.data_loaders["test_mismatched"] 372 | else: 373 | test_data_loaders[TASK_NAME_TO_SUBMISSION_FILE_NAME[self.task_name]] = self.data_loaders["test"] 374 | 375 | # change to eval mode 376 | self.model.eval() 377 | 378 | for prediction_file_name, dataloader in test_data_loaders.items(): 379 | results = list() 380 | counter = 0 381 | num_samples = len(dataloader.dataset) 382 | for batch in dataloader: 383 | # move batch data to 
gpu 384 | if self.device is not None: 385 | batch = tuple(obj.cuda(self.device) for obj in batch) 386 | 387 | input_ids, attention_mask, token_type_ids = batch 388 | 389 | # forward pass 390 | with torch.no_grad(): 391 | outputs = self.model(input_ids=input_ids, attention_mask=attention_mask, 392 | token_type_ids=token_type_ids) 393 | outputs = outputs.logits 394 | 395 | # aggregate results 396 | if self.is_regression: 397 | outputs = outputs.view(-1) 398 | outputs = outputs.detach().cpu().numpy() 399 | results.extend(list(outputs)) 400 | else: 401 | outputs = outputs.view(-1, self.num_labels) 402 | outputs = outputs.detach().cpu().numpy() 403 | outputs = np.argmax(outputs, axis=-1) 404 | if TASK_IS_BINARY[self.task_name]: 405 | results.extend([int(pred) for pred in outputs]) 406 | else: 407 | results.extend([self.idx_to_label[pred] for pred in outputs]) 408 | 409 | counter += len(outputs) 410 | print(f'Test inference progress: {counter}/{num_samples}\r', end='') 411 | print('') 412 | 413 | # save the test set results (in "GLUE test server" format) 414 | with open(os.path.join(output_path, prediction_file_name), 'w') as file: 415 | file.write('index\tprediction\n') 416 | for idx, result in enumerate(results): 417 | file.write(f'{idx}\t{result}\n') 418 | 419 | LOGGER.info(f'Test set inference is done, inference artifacts are: {list(test_data_loaders.keys())}') 420 | 421 | def plot_learning_curves(self, output_path=None): 422 | """Plot the learning curves for each metric. 423 | 424 | Args: 425 | output_path (str): Directory path to save the learning curves too, if None will print the figure. 
426 | 427 | """ 428 | for metric_name in TASK_TO_METRICS[self.task_name]: 429 | for dataloader_type, results_mapper in self.evaluations.items(): 430 | if not ('test' in dataloader_type): 431 | label = f'{dataloader_type} (max is {round(max(results_mapper[metric_name]) * 100, 2)})' 432 | plt.plot(results_mapper[metric_name], label=label) 433 | plt.title(f'Learning Curves - {self.task_name}') 434 | plt.xlabel('Epoch') 435 | plt.ylabel(metric_name) 436 | plt.legend() 437 | if output_path: 438 | plt.savefig(os.path.join(output_path, f'learning_curves_{metric_name.lower()}')) 439 | plt.clf() 440 | else: 441 | plt.show() 442 | 443 | def plot_terms_changes(self, output_path=None): 444 | """Plot/save the terms changes (calculating explained below). 445 | 446 | We define the amount of change in a bias vector b to be (1/dim(b)) * |b_0 - b_f|_1 that is, the average 447 | absolute change, across its dimensions, between the initial LM values b_0 and its fine-tuned values b_f. 448 | 449 | Args: 450 | output_path (str): Directory path to save the terms changes heatmap too, if None will print the figure. 
451 | 452 | """ 453 | if self.encoder_trainable: 454 | raise ValueError('Can plot terms changes only when BitFit.') 455 | 456 | if output_path: 457 | LOGGER.info(f'Saving the BitFit bias terms changes to: {output_path}') 458 | 459 | if 'roberta' in self.model_name: 460 | base_model = AutoModelForSequenceClassification.from_pretrained(self.model_name, return_dict=True).roberta 461 | fine_tuned_model = self.model.cpu().roberta 462 | else: 463 | base_model = AutoModelForSequenceClassification.from_pretrained(self.model_name, return_dict=True).bert 464 | fine_tuned_model = self.model.cpu().bert 465 | 466 | num_layers = self.model.config.num_hidden_layers 467 | 468 | def _calc_mean_diff(ft_p, base_p): 469 | return np.mean(np.abs(np.array(ft_p.data - base_p.data))) 470 | 471 | changes = [] 472 | for ft_name, ft_param in fine_tuned_model.named_parameters(): 473 | if ft_param.requires_grad and 'layer' in ft_name: 474 | for base_name, base_param in base_model.named_parameters(): 475 | if ft_name == base_name: 476 | changes.append({'name': ft_name, 'value': _calc_mean_diff(ft_param, base_param)}) 477 | 478 | def _get_component_name(name): 479 | return re.split(r'.[0-9]+.', name)[1] 480 | 481 | def _get_component_layer(name): 482 | return int(name.split('.')[2]) 483 | 484 | keys = list(set(_get_component_name(c['name']) for c in changes)) 485 | keys_mapper = {k: i for i, k in enumerate(keys)} 486 | 487 | total_weights = np.zeros(len(keys)) 488 | for change in changes: 489 | total_weights[keys_mapper[_get_component_name(change['name'])]] += change['value'] 490 | 491 | keys = [keys[i] for i in np.argsort(-total_weights)] 492 | keys_mapper = {k: i for i, k in enumerate(keys)} 493 | 494 | avg_column = np.zeros(len(keys)) 495 | values_map = np.zeros((len(keys), num_layers + 1)) 496 | for change in changes: 497 | avg_column[keys_mapper[_get_component_name(change['name'])]] += change['value'] 498 | values_map[keys_mapper[_get_component_name(change['name'])], 
_get_component_layer(change['name'])] = change[ 499 | 'value'] 500 | avg_column /= num_layers 501 | values_map[:, -1] = avg_column 502 | 503 | fig, ax = plt.subplots(figsize=(num_layers, len(keys))) 504 | xticklabels = [f'layer {i + 1}' for i in range(num_layers)] 505 | xticklabels.append('Avg.') 506 | 507 | keys = [BIAS_LAYER_NAME_TO_LATEX[key] for key in keys] 508 | heatmap(values_map, cmap="Blues", ax=ax, yticklabels=keys, xticklabels=xticklabels) 509 | 510 | plt.xticks(rotation=45) 511 | plt.yticks(rotation=0, ha='left') 512 | 513 | # align the y-axis text to the left 514 | yax = ax.get_yaxis() 515 | pad = max(T.label.get_window_extent().width for T in yax.majorTicks) 516 | yax.set_tick_params(pad=pad) 517 | 518 | if output_path: 519 | plt.savefig(output_path) 520 | plt.clf() 521 | else: 522 | plt.show() 523 | 524 | if self.device is not None: 525 | self.model.cuda(self.device) 526 | 527 | def set_uniform_mask(self, mask_size): 528 | """Uniformly chooses `mask_size` parameters from the model and generates a boolean mask for every component. 529 | 530 | Uniformly sample `mask_size` parameters from the entire model parameters, and in fine-tuning process only them 531 | will be fine-tuned. 532 | 533 | Args: 534 | mask_size (int): number of non-masked parameters. 
535 | 536 | """ 537 | if not self.encoder_trainable: 538 | raise Exception('In order to train with a random mask the encoder must be trainable.') 539 | 540 | if 'roberta' in self.model_name: 541 | model = self.model.roberta 542 | else: 543 | model = self.model.bert 544 | 545 | total_params = 0 546 | self.masks, params_per_component = dict(), dict() 547 | for name, param in model.named_parameters(): 548 | self.masks[name] = torch.zeros(param.size(), dtype=torch.bool) 549 | component_params = reduce(lambda x, y: x * y, param.shape) 550 | params_per_component[name] = component_params 551 | total_params += component_params 552 | 553 | tunable_params_per_component = {k: int((v * mask_size) / total_params) for k, v in 554 | params_per_component.items()} 555 | 556 | LOGGER.info(f'Non-Masked params amount: {reduce(lambda x, y: x + y, tunable_params_per_component.values())}. ' 557 | f'Total params: {total_params}') 558 | 559 | for name, param in model.named_parameters(): 560 | component_mask_size = tunable_params_per_component[name] 561 | component_params = params_per_component[name] 562 | indices = np.random.randint(0, component_params, component_mask_size) 563 | mask = self.masks[name] 564 | for index in indices: 565 | if len(param.shape) == 1: 566 | mask[index] = True 567 | else: 568 | mask[int(index / param.shape[1]), index % param.shape[1]] = True 569 | 570 | def set_row_and_column_random_mask(self): 571 | """Initializes the mask by randomly choosing rows or a column from each weight 572 | 573 | Initializes the mask by randomly choosing rows or a column (column size is equal the bias size) from each 574 | weight, the amount of total non-masked parameters in each weight is equal to the matching bias param size. 
575 | 576 | """ 577 | if not self.encoder_trainable: 578 | raise Exception('In order to train with a random mask the encoder must be trainable.') 579 | 580 | if 'roberta' in self.model_name: 581 | model = self.model.roberta 582 | else: 583 | model = self.model.bert 584 | 585 | self.masks = dict() 586 | total_params = 0 587 | for name, param in model.named_parameters(): 588 | self.masks[name] = torch.zeros(param.size(), dtype=torch.bool) 589 | total_params += reduce(lambda x, y: x * y, param.shape) 590 | 591 | if ('encoder' not in name and 'pooler' not in name) or 'weight' not in name: 592 | continue 593 | 594 | if len(param.shape) == 1 and 'LayerNorm' in name: # in case it's a LayerNorm 595 | self.masks[name][:] = True 596 | continue 597 | 598 | if np.random.randint(0, 2) or param.shape[0] < param.shape[1]: # we randomly choose a column 599 | n_columns = int(param.shape[1]) 600 | column_index = np.random.randint(0, n_columns) 601 | self.masks[name][:, column_index] = True 602 | else: # we randomly choose rows 603 | bias_shape = int(param.shape[0]) 604 | row_size = int(param.shape[1]) 605 | n_rows_to_activate = int(bias_shape / row_size) 606 | row_indices = np.random.randint(0, bias_shape, n_rows_to_activate) 607 | self.masks[name][row_indices] = True 608 | 609 | LOGGER.info(f'Non-Masked params amount: {int(np.sum([np.sum(mask.numpy()) for mask in self.masks.values()]))}. ' 610 | f'Total params: {total_params}') 611 | 612 | def _deactivate_relevant_gradients(self, trainable_components): 613 | """Turns off the model parameters requires_grad except the trainable_components. 
614 | 615 | Args: 616 | trainable_components (List[str]): list of trainable components (the rest will be deactivated) 617 | 618 | """ 619 | for param in self.model.parameters(): 620 | param.requires_grad = False 621 | if trainable_components: 622 | trainable_components = trainable_components + ['pooler.dense.bias'] 623 | trainable_components = trainable_components + ['classifier'] 624 | for name, param in self.model.named_parameters(): 625 | for component in trainable_components: 626 | if component in name: 627 | param.requires_grad = True 628 | break 629 | 630 | @staticmethod 631 | def convert_to_actual_components(components): 632 | return [BIAS_TERMS_DICT[component] for component in components] 633 | 634 | def _train(self, train_dataloader, epoch, max_grad_norm=1.0): 635 | """Trains the model for a single epoch 636 | 637 | Args: 638 | train_dataloader (torch.utils.data.DataLoader): the train data loader 639 | epoch (int): the epoch number (for logging) 640 | max_grad_norm (float): the maximum gradient norm we allow. The norm is computed over all gradients together, 641 | as if they were concatenated into a single vector. 
642 | 643 | """ 644 | # move to train mode 645 | self.model.train() 646 | 647 | # loss initialization 648 | criteria = torch.nn.MSELoss() if self.is_regression else torch.nn.CrossEntropyLoss() 649 | 650 | n = len(train_dataloader.dataset) 651 | trained_samples = loss_sum = 0 652 | for step, batch in enumerate(train_dataloader): 653 | # move batch data to gpu 654 | if self.device is not None: 655 | batch = tuple(obj.cuda(self.device) for obj in batch) 656 | 657 | if 'roberta' in self.model_name: 658 | input_ids, attention_mask, labels = batch 659 | token_type_ids = None 660 | else: 661 | input_ids, attention_mask, token_type_ids, labels = batch 662 | 663 | # forward pass 664 | outputs = self.model(input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids) 665 | outputs = outputs.logits 666 | 667 | # loss calculation 668 | labels = labels.view(-1) 669 | outputs = outputs.view(-1) if self.is_regression else outputs.view(-1, self.num_labels) 670 | 671 | loss = criteria(outputs, labels) 672 | 673 | # backward pass (gradients calculation) 674 | loss.backward() 675 | 676 | # masking the relevant gradients (if needed) 677 | if self.masks: 678 | if 'roberta' in self.model_name: 679 | for name, param in self.model.roberta.named_parameters(): 680 | param.grad[~self.masks[name]] = 0 681 | else: 682 | for name, param in self.model.bert.named_parameters(): 683 | param.grad[~self.masks[name]] = 0 684 | 685 | # gradient clipping 686 | torch.nn.utils.clip_grad_norm_(parameters=self.model.parameters(), max_norm=max_grad_norm) 687 | 688 | # update parameters 689 | self.optimizer.step() 690 | self.model.zero_grad() 691 | 692 | # track train loss 693 | loss_sum += loss.item() 694 | trained_samples += len(labels) 695 | 696 | # printing training progress 697 | print(f'EPOCH: {epoch} TRAIN: {trained_samples}/{n} LOSS: {round(loss_sum / (step + 1), 3)}\r', end='') 698 | print('') 699 | 700 | def _evaluate(self, dataloader, dataloader_type): 701 | """Evaluates the 
model on the dataloader 702 | 703 | Args: 704 | dataloader (torch.utils.data.DataLoader): the data loader we evaluate the model on 705 | dataloader_type (str): the dataloader type (train/validation) 706 | 707 | Returns: 708 | (Dict[str, float]): dictionary that maps between metric_name and the metric result 709 | """ 710 | # move to eval mode 711 | self.model.eval() 712 | 713 | evaluated_samples = accuracy_sum = 0 714 | all_predictions, all_labels = [], [] 715 | for step, batch in enumerate(dataloader): 716 | # move batch data to gpu 717 | if self.device is not None: 718 | batch = tuple(obj.cuda(self.device) for obj in batch) 719 | 720 | if 'roberta' in self.model_name: 721 | input_ids, attention_mask, labels = batch 722 | token_type_ids = None 723 | else: 724 | input_ids, attention_mask, token_type_ids, labels = batch 725 | 726 | # forward pass 727 | with torch.no_grad(): 728 | outputs = self.model(input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids) 729 | outputs = outputs.logits 730 | 731 | # reshaping 732 | labels = labels.view(-1) 733 | outputs = outputs.view(-1) if self.is_regression else outputs.view(-1, self.num_labels) 734 | 735 | # moving tensor to cpu and detaching for aggregation 736 | outputs = outputs.detach().cpu().numpy() 737 | labels = labels.cpu().numpy() 738 | 739 | evaluated_samples += len(labels) 740 | 741 | # calculate the accuracy in the classification case 742 | if not self.is_regression: 743 | outputs = np.argmax(outputs, axis=1) 744 | # accuracy calculation 745 | accuracy_sum += accuracy_score(labels, outputs) * len(labels) 746 | print(f'{dataloader_type} ACC: {round(accuracy_sum / evaluated_samples, 5)}\r', end='') 747 | 748 | # aggregate predictions and labels 749 | all_predictions.extend(list(outputs)) 750 | all_labels.extend(list(labels)) 751 | print('') 752 | 753 | # calculate the required metrics 754 | results = {} 755 | for metric_name in TASK_TO_METRICS[self.task_name]: 756 | metric = 
METRIC_NAME_TO_FUNCTION[metric_name] 757 | result = metric(all_labels, all_predictions) 758 | result = result[0] if self.is_regression else result 759 | results[metric_name] = result 760 | 761 | return results 762 | 763 | @staticmethod 764 | def _convert_dataset_to_data_loader(dataset, model_name, batch_size, random_sampler, test=False): 765 | """converts a datasets.arrow_dataset.Dataset to torch.utils.data.DataLoader. 766 | 767 | Args: 768 | dataset (datasets.arrow_dataset.Dataset): the Dataset to convert to DataLoader. 769 | model_name (str): model name (e.g. bert-base-uncased). 770 | batch_size (int): batch size for training and evaluation. 771 | random_sampler (bool): if True, DataLoader will sample randomly else sequentially. 772 | test (bool): if True, dataset contains test samples. 773 | 774 | Returns: 775 | (torch.utils.data.DataLoader): the data loader 776 | """ 777 | if test: 778 | keys = ['input_ids', 'attention_mask', 'token_type_ids'] 779 | else: 780 | keys = ['input_ids', 'attention_mask', 'token_type_ids', 'label'] 781 | 782 | if 'roberta' in model_name: 783 | keys.remove('token_type_ids') 784 | 785 | data = {key: list() for key in keys} 786 | for sample in dataset: 787 | for key in keys: 788 | data[key].append(sample[key]) 789 | 790 | for k, v in data.items(): 791 | data[k] = torch.tensor(v) 792 | 793 | tensor_dataset = TensorDataset(*[data[key] for key in keys]) 794 | data_sampler = RandomSampler(tensor_dataset) if random_sampler else SequentialSampler(tensor_dataset) 795 | return DataLoader(tensor_dataset, sampler=data_sampler, batch_size=batch_size) 796 | --------------------------------------------------------------------------------