├── .gitignore
├── LICENSE
├── README.md
├── data
│   └── download_dataset.sh
├── requirements.txt
├── run.py
├── run_fewshot.sh
├── src
│   ├── dataset.py
│   ├── kernel_solvers.py
│   ├── kernel_trainer.py
│   ├── linearhead_trainer.py
│   ├── models.py
│   ├── processors.py
│   └── trainer.py
└── tools
    ├── gather_result.py
    └── generate_k_shot_data.py

/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 | 
6 | # C extensions
7 | *.so
8 | 
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | share/python-wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 | MANIFEST
28 | 
29 | # PyInstaller
30 | # Usually these files are written by a python script from a template
31 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
32 | *.manifest
33 | *.spec
34 | 
35 | # Installer logs
36 | pip-log.txt
37 | pip-delete-this-directory.txt
38 | 
39 | # Unit test / coverage reports
40 | htmlcov/
41 | .tox/
42 | .nox/
43 | .coverage
44 | .coverage.*
45 | .cache
46 | nosetests.xml
47 | coverage.xml
48 | *.cover
49 | *.py,cover
50 | .hypothesis/
51 | .pytest_cache/
52 | cover/
53 | 
54 | # Translations
55 | *.mo
56 | *.pot
57 | 
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 | 
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 | 
68 | # Scrapy stuff:
69 | .scrapy
70 | 
71 | # Sphinx documentation
72 | docs/_build/
73 | 
74 | # PyBuilder
75 | .pybuilder/
76 | target/
77 | 
78 | # Jupyter Notebook
79 | .ipynb_checkpoints
80 | 
81 | # IPython
82 | profile_default/
83 | ipython_config.py
84 | 
85 | # pyenv
86 | # For a library or package, you might want to ignore these files since the code is
87 | # intended to run in multiple environments; otherwise, check them in:
88 | # .python-version
89 | 
90 | # pipenv
91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
94 | # install all needed dependencies.
95 | #Pipfile.lock
96 | 
97 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
98 | __pypackages__/
99 | 
100 | # Celery stuff
101 | celerybeat-schedule
102 | celerybeat.pid
103 | 
104 | # SageMath parsed files
105 | *.sage.py
106 | 
107 | # Environments
108 | .env
109 | .venv
110 | env/
111 | venv/
112 | ENV/
113 | env.bak/
114 | venv.bak/
115 | 
116 | # Spyder project settings
117 | .spyderproject
118 | .spyproject
119 | 
120 | # Rope project settings
121 | .ropeproject
122 | 
123 | # mkdocs documentation
124 | /site
125 | 
126 | # mypy
127 | .mypy_cache/
128 | .dmypy.json
129 | dmypy.json
130 | 
131 | # Pyre type checker
132 | .pyre/
133 | 
134 | # pytype static type analyzer
135 | .pytype/
136 | 
137 | # Cython debug symbols
138 | cython_debug/
139 | 
140 | *.sh
141 | .venv
142 | .vscode
143 | data
144 | log*
145 | runs
146 | result
147 | wandb
148 | ensemble_predict_results
149 | auto*
150 | my*
151 | slurm
152 | 
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 | 
3 | Copyright (c) 2021 Princeton Natural Language Processing
4 | 
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # A Kernel-Based View of Language Model Fine-Tuning (ICML'23)
2 | 
3 | This is the implementation for the paper [A Kernel-Based View of Language Model Fine-tuning](https://arxiv.org/abs/2210.05643)
4 | and can be used to compute kernel approximations for the fine-tuning of pre-trained language models.
5 | 
6 | We extend the [LM-BFF](https://github.com/princeton-nlp/LM-BFF) repository and
7 | add a new "kernel trainer" powered by [functorch](https://github.com/pytorch/functorch) to compute empirical-NTK kernel matrices using the SGD, SignGD or Asymmetric-SignGD kernel formulas.
8 | We also provide our pre-computed kernels for download to facilitate further analysis.
9 | 
10 | ## Installation
11 | Please install all the dependency packages by using the following command:
12 | ```
13 | pip install -r requirements.txt
14 | ```
15 | 
16 | We updated the LM-BFF code to work with a newer version of HuggingFace transformers and additionally require functorch.
17 | If you would like to run LoRA fine-tuning, install the LoRA version of the transformers library ([see here](https://github.com/microsoft/LoRA/tree/main/examples/NLU)) and add the flags `--apply_lora --lora_alpha ... --lora_r ...`.
18 | 
19 | **NOTE**: Different versions of some packages (`pytorch`, `numpy`, `transformers`) may cause minor variations in kernels and results.
20 | 
21 | ## Prepare the data
22 | Please run the following commands to download and prepare the data:
23 | 
24 | ```bash
25 | ( cd data; bash download_dataset.sh )
26 | 
27 | for K in 16 64 512; do
28 |     # Generate k-shot splits for seeds 13,21,42,87,100 with a maximum of 1k test examples in data/k-shot-1k-test,
29 |     # where k is the number of training/validation examples per label
30 |     python tools/generate_k_shot_data.py --mode k-shot-1k-test --k $K
31 | done
32 | ```
33 | 
34 | This follows LM-BFF, except that `download_dataset.sh` additionally rebalances the `cr` dataset and uses the GLUE version of the SST-2 dataset. The `k-shot-1k-test` mode also limits test sets to 1k examples for faster evaluation.
35 | 
36 | **NOTE**: During training, the model will generate/load cache files in the data folder. If your data have changed, make sure to clean all the cache files (starting with "cache").
37 | 
38 | ## Run the code
39 | To easily run our experiments, you can use `run_fewshot.sh`:
40 | 
41 | ```bash
42 | TAG=kernel-prompting TRAINER=kernel TASK=SST-2 SEED=42 MODEL=roberta-base bash run_fewshot.sh
43 | ```
44 | 
45 | The templates and label word mappings are already defined, so you only need to set the hyper-parameters and `TAG` (any tag works; it just makes the results easier to find in the log). See `run_fewshot.sh` for more options. You can also pass extra arguments directly:
46 | 
47 | ```bash
48 | NUM_GPU=4 TAG=kernel-prompting TRAINER=kernel TASK=SST-2 SEED=42 MODEL=roberta-base bash run_fewshot.sh \
49 |     --kernel_formula signgd --kernel_solver logistic --per_device_train_batch_size 2 --per_device_eval_batch_size 4
50 | ```
51 | This splits the kernel computation across 4 GPUs, uses the SignGD kernel formula with a logistic kernel solver (the default is least-squares regression), and uses batch sizes of 2 and 4 along the two axes of the kernel matrix, respectively.
52 | 
53 | For more advanced use cases, such as [how to aggregate results over multiple runs](https://github.com/princeton-nlp/LM-BFF#experiments-with-multiple-runs), [zero-shot experiments](https://github.com/princeton-nlp/LM-BFF#zero-shot-experiments) or [writing your own prompt formats](https://github.com/princeton-nlp/LM-BFF#how-to-design-your-own-templates), we refer to the README in the LM-BFF repo.
54 | Note that we deleted some tools for automatic prompt and label search that are unrelated to our paper.
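All runs append their final metrics to a shared `log` file (see `run.py`), so results can be aggregated across seeds and hyper-parameters with `tools/gather_result.py`. A minimal sketch, following the LM-BFF tooling (the exact condition fields are assumptions based on what `run.py` logs):

```bash
# Aggregate all logged runs that match the given tag and task
python tools/gather_result.py --condition "{'tag': 'kernel-prompting', 'task_name': 'sst-2'}"
```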
55 | 
56 | ## Download our pre-computed kernels
57 | Here are the links for downloading our pre-computed kernels:
58 | 
59 | |        | SGD | SignGD | Asymmetric-SignGD |
60 | |--------|-----|--------|-------------------|
61 | | 16-shot| [prompt](https://nlp.cs.princeton.edu/projects/LM-Kernel-FT/roberta-base/prompt-sgd-16-shot.zip) / [no-prompt](https://nlp.cs.princeton.edu/projects/LM-Kernel-FT/roberta-base/no_prompt-sgd-16-shot.zip) | [prompt](https://nlp.cs.princeton.edu/projects/LM-Kernel-FT/roberta-base/prompt-signgd-16-shot.zip) / [no-prompt](https://nlp.cs.princeton.edu/projects/LM-Kernel-FT/roberta-base/no_prompt-signgd-16-shot.zip) | [prompt](https://nlp.cs.princeton.edu/projects/LM-Kernel-FT/roberta-base/prompt-asymmetric_signgd-16-shot.zip) / [no-prompt](https://nlp.cs.princeton.edu/projects/LM-Kernel-FT/roberta-base/no_prompt-asymmetric_signgd-16-shot.zip) |
62 | | 64-shot| [prompt](https://nlp.cs.princeton.edu/projects/LM-Kernel-FT/roberta-base/prompt-sgd-64-shot.zip) / [no-prompt](https://nlp.cs.princeton.edu/projects/LM-Kernel-FT/roberta-base/no_prompt-sgd-64-shot.zip) | [prompt](https://nlp.cs.princeton.edu/projects/LM-Kernel-FT/roberta-base/prompt-signgd-64-shot.zip) / [no-prompt](https://nlp.cs.princeton.edu/projects/LM-Kernel-FT/roberta-base/no_prompt-signgd-64-shot.zip) | [prompt](https://nlp.cs.princeton.edu/projects/LM-Kernel-FT/roberta-base/prompt-asymmetric_signgd-64-shot.zip) / [no-prompt](https://nlp.cs.princeton.edu/projects/LM-Kernel-FT/roberta-base/no_prompt-asymmetric_signgd-64-shot.zip) |
63 | | 512-shot| [prompt](https://nlp.cs.princeton.edu/projects/LM-Kernel-FT/roberta-base/prompt-sgd-512-shot.zip) / [no-prompt](https://nlp.cs.princeton.edu/projects/LM-Kernel-FT/roberta-base/no_prompt-sgd-512-shot.zip) | | |
64 | 
65 | The provided kernels were computed with RoBERTa-base on 14 datasets (SST-2, SST-5, MR, CR, MPQA, Subj, TREC, AG News, MNLI, SNLI, QNLI, RTE, MRPC, QQP). The no-prompt kernels were obtained by initializing the [CLS] head with the logistic regression solution.
66 | 
67 | For each task and data split, we include separate files for the training, development, and test kernel matrices, as well as the pre-trained logits. Each file can be read using `torch.load` and contains a tuple of (kernel matrix, labels),
68 | and the kernel matrix has shape [training examples, training logits, *X* examples, *X* logits], where the *X* split is given by the file name (train, dev or test).
69 | 
70 | ## Bugs and questions?
71 | If you have any questions related to the code or the paper, feel free to email Alexander and Sadhika (`{awettig,smalladi}@cs.princeton.edu`). If you encounter a problem or bug when using the code, you can also open an issue.
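As a quick sanity check, each downloaded kernel file can be loaded directly with `torch.load`. A minimal sketch (the extracted directory layout and file name below are hypothetical; substitute the actual task, split, and file from the archive):

```python
import torch

# Each file stores a tuple of (kernel matrix, labels); the path here is hypothetical.
kernel, labels = torch.load("prompt-sgd-16-shot/SST-2/16-13/test_kernel.pt")

# The kernel matrix has shape [train examples, train logits, X examples, X logits],
# where X is the split named in the file (here: test).
print(kernel.shape, labels.shape)
```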
72 | 
73 | ## Citation
74 | 
75 | Please cite our work if you make use of our code or our pre-computed kernels in your work:
76 | 
77 | ```bibtex
78 | 
79 | @InProceedings{malladi2023kernel,
80 |   title = {A Kernel-Based View of Language Model Fine-Tuning},
81 |   author = {Malladi, Sadhika and Wettig, Alexander and Yu, Dingli and Chen, Danqi and Arora, Sanjeev},
82 |   booktitle = {Proceedings of the 40th International Conference on Machine Learning},
83 |   pages = {23610--23641},
84 |   year = {2023},
85 |   editor = {Krause, Andreas and Brunskill, Emma and Cho, Kyunghyun and Engelhardt, Barbara and Sabato, Sivan and Scarlett, Jonathan},
86 |   volume = {202},
87 |   series = {Proceedings of Machine Learning Research},
88 |   month = {23--29 Jul},
89 |   publisher = {PMLR},
90 |   pdf = {https://proceedings.mlr.press/v202/malladi23a/malladi23a.pdf},
91 |   url = {https://proceedings.mlr.press/v202/malladi23a.html}
92 | }
93 | ```
94 | 
--------------------------------------------------------------------------------
/data/download_dataset.sh:
--------------------------------------------------------------------------------
1 | wget https://nlp.cs.princeton.edu/projects/lm-bff/datasets.tar
2 | tar xvf datasets.tar
3 | 
4 | echo "*** Use GLUE-SST-2 as default SST-2 ***"
5 | mv original/SST-2 original/SST-2-original
6 | mv original/GLUE-SST-2 original/SST-2
7 | 
8 | echo "*** Modify 'cr' test and train splits ***"
9 | # Redistribute train and test examples in cr, using only 500 examples for testing.
10 | # This is necessary to have enough training and validation examples for 512-shot experiments.
11 | sed -i 's/1000/250/' original/cr/process.py
12 | ( cd original/cr; python process.py )
13 | 
14 | echo "*** Done ***"
15 | 
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | dataclasses
2 | filelock
3 | importlib-metadata
4 | regex
5 | tqdm
6 | packaging==20.8
7 | future==0.18.2
8 | six==1.15.0
9 | numpy==1.23.1
10 | pandas==1.1.5
11 | scikit-learn==0.24.0
12 | scipy==1.5.4
13 | tokenizers==0.10.3
14 | torch==1.12.1
15 | transformers==4.4.2
16 | functorch==0.2.1
17 | 
--------------------------------------------------------------------------------
/run.py:
--------------------------------------------------------------------------------
1 | """Finetuning the library models for sequence classification on GLUE."""
2 | 
3 | import dataclasses
4 | import logging
5 | import os
6 | import sys
7 | from dataclasses import dataclass, field
8 | from typing import Callable, Dict, Optional, Union
9 | import torch
10 | 
11 | import numpy as np
12 | 
13 | from transformers import AutoConfig, AutoModelForSequenceClassification, AutoTokenizer, EvalPrediction, PreTrainedTokenizerBase
14 | from transformers import GlueDataTrainingArguments as DataTrainingArguments
15 | from transformers import HfArgumentParser, TrainingArguments, set_seed
16 | 
17 | from src.linearhead_trainer import LinearHeadTrainer
18 | from src.kernel_trainer import KernelTrainerFunc
19 | from src.dataset import FewShotDataset
20 | from src.models import ModelForPromptFinetuning, resize_token_type_embeddings
21 | from src.trainer import Trainer
22 | from src.processors import processors_mapping, num_labels_mapping, output_modes_mapping, compute_metrics_mapping, bound_mapping
23 | 
24 | from filelock import FileLock
25 | from datetime import datetime
26 | 
27 | logger = logging.getLogger(__name__)
28 | 
29 | 
30 | @dataclass
31 | class ModelArguments:
32 | """ 33 | Arguments pertaining to which model/config/tokenizer we are going to fine-tune from. 34 | """ 35 | model_name_or_path: str = field( 36 | metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"} 37 | ) 38 | config_name: Optional[str] = field( 39 | default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"} 40 | ) 41 | tokenizer_name: Optional[str] = field( 42 | default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"} 43 | ) 44 | cache_dir: Optional[str] = field( 45 | default=None, metadata={"help": "Where do you want to store the pretrained models downloaded from s3"} 46 | ) 47 | # Few-shot type 48 | # - finetune: standard fine-tuning 49 | # - prompt: prompt-based fine-tuning 50 | # - prompt-demo: prompt-based fine-tuning with demonstrations 51 | few_shot_type: str = field( 52 | default='prompt-demo', 53 | metadata={"help": "Few-shot learning model type. Choice: finetune, prompt, prompt-demo"} 54 | ) 55 | 56 | # Only for BERT-type model 57 | random_segment: bool = field( 58 | default=False, 59 | metadata={"help": "Whether to reinitialize the token type embeddings (only for BERT)."} 60 | ) 61 | l2_loss: bool = field( 62 | default=False, 63 | metadata={"help": "Whether to use L2 loss (only makes a difference in standard FT)."} 64 | ) 65 | use_task_word: bool = field( 66 | default=False, 67 | metadata={'help': 'uses the task words MLM logit for kernel computation'} 68 | ) 69 | 70 | # LoRA arguments: only for BERT-type model 71 | apply_lora: bool = field( 72 | default=False, 73 | metadata={'help': 'use LoRA for finetuning'} 74 | ) 75 | lora_alpha: int = field( 76 | default=None, 77 | metadata={'help': 'initialization scale for one of the low rank matrices in lora'} 78 | ) 79 | lora_r: int = field( 80 | default=None, 81 | metadata={'help': 'inner rank for lora matrices'} 82 | ) 83 | 84 | @dataclass 85 | class DynamicDataTrainingArguments(DataTrainingArguments): 86 | """ 87 | Arguments for dynamic training. 88 | """ 89 | num_k: Optional[int] = field( 90 | default=16, 91 | metadata={"help": "Number of training instances per class"} 92 | ) 93 | 94 | num_sample: Optional[int] = field( 95 | default=16, 96 | metadata={"help": "Number of samples (for inference) in fine-tuning with demonstrations"} 97 | ) 98 | 99 | num_demo: Optional[int] = field( 100 | default=1, 101 | metadata={"help": "Number of demonstrations from each class"} 102 | ) 103 | 104 | auto_demo: bool = field( 105 | default=True, 106 | metadata={"help": "Automatically generate template for using demonstrations"} 107 | ) 108 | 109 | # For prompting 110 | template: str = field( 111 | default=None, 112 | metadata={"help": "Template"} 113 | ) 114 | 115 | mapping: str = field( 116 | default=None, 117 | metadata={"help": "Label word mapping"} 118 | ) 119 | 120 | template_path: str = field( 121 | default=None, 122 | metadata={"help": "Path to a txt file that stores all the templates, one per line. Do not set this when prompt_path is used"} 123 | ) 124 | 125 | mapping_path: str = field( 126 | default=None, 127 | metadata={"help": "Path to a txt file that stores all the label word mappings, one per line. 
Do not set this when prompt_path is used"} 128 | ) 129 | 130 | prompt_path: str = field( 131 | default=None, 132 | metadata={"help": "Path to a txt file that stores all the prompts (templates and mappings), one per line"} 133 | ) 134 | 135 | template_id: int = field( 136 | default=None, 137 | metadata={"help": "Template id if using template_path"} 138 | ) 139 | 140 | mapping_id: int = field( 141 | default=None, 142 | metadata={"help": "Mapping id if using template_path"} 143 | ) 144 | 145 | prompt_id: int = field( 146 | default=None, 147 | metadata={"help": "Prompt id if using prompt_path"} 148 | ) 149 | 150 | top_n_template: int = field( 151 | default=None, 152 | metadata={"help": "Use top-n template in the template path"} 153 | ) 154 | 155 | # For logging 156 | tag: str = field( 157 | default='', 158 | metadata={"help": "Set the tag and find the result easier in the log."} 159 | ) 160 | 161 | # For filtering when using demonstrations 162 | demo_filter: bool = field( 163 | default=False, 164 | metadata={"help": "Only use similar instances in demonstrations"} 165 | ) 166 | 167 | demo_filter_rate: float = field( 168 | default=0.5, 169 | metadata={"help": "Only use top-x\% similar instances in demonstrations"} 170 | ) 171 | 172 | demo_filter_model: str = field( 173 | default=None, 174 | metadata={"help": "Model name for demonstration filter embeddings. Will load embeddings based on the model name."} 175 | ) 176 | 177 | debug_mode: bool = field( 178 | default=False, 179 | metadata={"help": "Debug mode"} 180 | ) 181 | 182 | # For max length 183 | double_demo: bool = field( 184 | default=False, 185 | metadata={"help": "Use double length for using demonstrations"} 186 | ) 187 | 188 | first_sent_limit: int = field( 189 | default=None, 190 | metadata={"help": "Limit the length of the first sentence (i.e., sent_0)"} 191 | ) 192 | 193 | other_sent_limit: int = field( 194 | default=None, 195 | metadata={"help": "Limit the length of sentences other than the first sentence"} 196 | ) 197 | 198 | use_full_length: bool = field( 199 | default=None, 200 | metadata={"help": "Use the full length (512)"} 201 | ) 202 | 203 | # GPT-3's in-context learning 204 | gpt3_in_context_head: bool = field( 205 | default=False, 206 | metadata={"help": "GPT-3's in-context learning (context at the beginning)"} 207 | ) 208 | 209 | gpt3_in_context_tail: bool = field( 210 | default=False, 211 | metadata={"help": "GPT-3's in-context learning (context at the end)"} 212 | ) 213 | 214 | gpt3_in_context_num: int = field( 215 | default=32, 216 | metadata={"help": "Number of context examples"} 217 | ) 218 | 219 | truncate_head: bool = field( 220 | default=False, 221 | metadata={"help": "When exceeding the maximum length, truncate the head instead of the tail."} 222 | ) 223 | 224 | # Do not set up the following fields. They are set up automatically. 
225 |     prompt: bool = field(
226 |         default=False,
227 |         metadata={"help": "Whether to use prompt-based fine-tuning"}
228 |     )
229 |     template_list: list = field(
230 |         default=None,
231 |         metadata={"help": "(DO NOT SET) List of templates (only initialized after the program starts)."}
232 |     )
233 | 
234 | 
235 | @dataclass
236 | class DynamicTrainingArguments(TrainingArguments):
237 |     evaluate_during_training: bool = field(
238 |         default=False,
239 |         metadata={"help": "Whether to run evaluation during training or only at the end."}
240 |     )
241 | 
242 |     # For ensemble
243 |     array_id: int = field(
244 |         default=-1,
245 |         metadata={"help": "Array ID (contains seed and hyper-parameter search) to identify the model"}
246 |     )
247 | 
248 |     model_id: int = field(
249 |         default=-1,
250 |         metadata={"help": "Model ID (contains template information) to identify the model"}
251 |     )
252 | 
253 |     save_logit: bool = field(
254 |         default=False,
255 |         metadata={"help": "Save test file logit with name $TASK-$MODEL_ID-$ARRAY_ID.npy"}
256 |     )
257 | 
258 |     save_logit_dir: str = field(
259 |         default=None,
260 |         metadata={"help": "Where to save the prediction result"}
261 |     )
262 | 
263 |     # Regularization
264 |     fix_layers: int = field(
265 |         default=0,
266 |         metadata={"help": "Fix bottom-n layers when optimizing"}
267 |     )
268 | 
269 |     # Training
270 |     save_at_last: bool = field(
271 |         default=False,
272 |         metadata={"help": "Instead of saving the best (dev performance) checkpoint, save the last checkpoint"}
273 |     )
274 | 
275 |     # Turn off train/test
276 |     no_train: bool = field(
277 |         default=False,
278 |         metadata={"help": "No training"}
279 |     )
280 |     no_predict: bool = field(
281 |         default=False,
282 |         metadata={"help": "No test"}
283 |     )
284 |     optimizer: str = field(
285 |         default='adam',
286 |         metadata={'help': 'choose sgd or adam. default is adam'}
287 |     )
288 |     optimizer_variant: str = field(
289 |         default='',
290 |         metadata={'help': 'define variants on optimizer: signgd'}
291 |     )
292 | 
293 |     trainer: str = field(
294 |         default="standard",
295 |         metadata={"help": "Pick from {standard, kernel, linearhead}"}
296 |     )
297 |     from_linearhead: bool = field(
298 |         default=False,
299 |         metadata={"help": "Whether to initialize head with the linearhead solution.
Works for both normal and kernel trainer."} 300 | ) 301 | random_model_init: bool = field( 302 | default=False, 303 | metadata={'help': 'reinit the model randomly'} 304 | ) 305 | sweep: bool = field( 306 | default=False, 307 | metadata={'help': 'configures the output directories to be informative when running W&B sweep'} 308 | ) 309 | kernel_formula: str = field( 310 | default='sgd', 311 | metadata={"help": "choose kernel formula from {sgd, signgd, asymmetric_signgd}"} 312 | ) 313 | kernel_solver: str = field( 314 | default="logistic", 315 | metadata={"help": "choose kernel solver from {lstsq, logistic, svr, svc, asym (only for asymmetric_signgd)}"} 316 | ) 317 | load_kernels: str = field( 318 | default=None, 319 | metadata={'help': 'when specified, loads the kernels from the folder given here'} 320 | ) 321 | overwrite_kernels: bool = field( 322 | default=False, 323 | metadata={'help': 'when specified, overwrites the kernels in the output_dir and computes them from scratch'} 324 | ) 325 | 326 | exclude_embeddings: bool = field( 327 | default=False, 328 | metadata={"help": "Don't use embeddings for kernel computation "} 329 | ) 330 | exclude_head: bool = field( 331 | default=False, 332 | metadata={"help": "Don't use head for kernel computation "} 333 | ) 334 | only_biases: bool = field( 335 | default=False, 336 | metadata={"help": "Only use bias parameters for kernel computation for BitFit-style kernel"} 337 | ) 338 | 339 | kernel_regularization: float = field( 340 | default=0.0, 341 | metadata={"help": "Regularization constant for kernel"} 342 | ) 343 | kernel_gamma: float = field( 344 | default=1.0, 345 | metadata={"help": "Gamma for asymmetric kernel solver"} 346 | ) 347 | binary_classification: bool = field( 348 | default=False, 349 | metadata={"help": "If num_classes=2, convert two softmax logits to single sigmoid logit"} 350 | ) 351 | adjust_for_init: bool = field( 352 | default=False, 353 | metadata={'help': 'when on, trains kernel on y-f0 and adds f0 at test time'} 354 | ) 355 | f0_scaling: float = field( 356 | default=1.0, 357 | metadata={'help': 'adjust label scaling, might help with --adjust_for_init perf'} 358 | ) 359 | 360 | @dataclass 361 | class MyDataCollatorWithPadding: 362 | """ 363 | Implements padding for LM-BFF inputs. 364 | Args: 365 | tokenizer ([`PreTrainedTokenizer`] or [`PreTrainedTokenizerFast`]): 366 | The tokenizer used for encoding the data. 367 | padding (`bool`, `str` or [`~utils.PaddingStrategy`], *optional*, defaults to `True`): 368 | Select a strategy to pad the returned sequences (according to the model's padding side and padding index) 369 | among: 370 | - `True` or `'longest'` (default): Pad to the longest sequence in the batch (or no padding if only a single 371 | sequence is provided). 372 | - `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum 373 | acceptable input length for the model if that argument is not provided. 374 | - `False` or `'do_not_pad'`: No padding (i.e., can output a batch with sequences of different lengths). 375 | max_length (`int`, *optional*): 376 | Maximum length of the returned list and optionally padding length (see above). 377 | pad_to_multiple_of (`int`, *optional*): 378 | If set will pad the sequence to a multiple of the provided value. 379 | This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability >= 380 | 7.5 (Volta). 381 | return_tensors (`str`): 382 | The type of Tensor to return. Allowable values are "np", "pt" and "tf". 
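    Example (a sketch; assumes `features` is a list of OurInputFeatures produced by FewShotDataset,
    and `tokenizer` is the tokenizer created in main()):

        collator = MyDataCollatorWithPadding(tokenizer)
        batch = collator(features)  # dict with input_ids, attention_mask, labels and (for prompts) mask_pos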
383 | """ 384 | 385 | tokenizer: PreTrainedTokenizerBase 386 | padding: Union[bool, str] = True 387 | max_length: Optional[int] = None 388 | pad_to_multiple_of: Optional[int] = None 389 | return_tensors: str = "pt" 390 | 391 | def __call__(self, features): 392 | mask_pos = [] 393 | standard_features = [] 394 | for item in features: 395 | standard_item = {} 396 | for field in ["input_ids", "label", "attention_mask", "token_type_ids"]: 397 | if getattr(item, field) is not None: 398 | standard_item[field] = getattr(item, field) 399 | standard_features.append(standard_item) 400 | mask_pos.append(item.mask_pos) 401 | 402 | batch = self.tokenizer.pad( 403 | standard_features, 404 | padding=self.padding, 405 | max_length=self.max_length, 406 | pad_to_multiple_of=self.pad_to_multiple_of, 407 | return_tensors=self.return_tensors, 408 | ) 409 | 410 | if any(mask_pos): 411 | batch["mask_pos"] = torch.tensor(mask_pos) 412 | 413 | if "label" in batch: 414 | batch["labels"] = batch["label"] 415 | del batch["label"] 416 | if "label_ids" in batch: 417 | batch["labels"] = batch["label_ids"] 418 | del batch["label_ids"] 419 | return batch 420 | 421 | 422 | def main(): 423 | parser = HfArgumentParser((ModelArguments, DynamicDataTrainingArguments, DynamicTrainingArguments)) 424 | if len(sys.argv) == 2 and sys.argv[1].endswith(".json"): 425 | # If we pass only one argument to the script and it's the path to a json file, 426 | # let's parse it to get our arguments. 427 | model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1])) 428 | else: 429 | model_args, data_args, training_args = parser.parse_args_into_dataclasses() 430 | 431 | if training_args.sweep: 432 | now = datetime.now() 433 | dt_str = now.strftime('%m_%d_%H_%M_%S') 434 | training_args.output_dir = os.path.join(training_args.output_dir, dt_str) 435 | 436 | if model_args.apply_lora: 437 | assert 'roberta' in model_args.model_name_or_path, 'LoRA only implemented for RoBERTa models' 438 | 439 | if training_args.kernel_formula == 'asymmetric_signgd': 440 | assert training_args.binary_classification, 'asymmetric solver not implemented for multi-class setting, use --binary_classification' 441 | 442 | if training_args.optimizer_variant != '': 443 | assert training_args.optimizer == 'sgd', 'variants on optimizer are only implemented for SGD' 444 | 445 | if 'prompt' in model_args.few_shot_type: 446 | data_args.prompt = True 447 | 448 | if training_args.no_train: 449 | training_args.do_train = False 450 | if training_args.no_predict: 451 | training_args.do_predict = False 452 | 453 | # Setup logging 454 | logging.basicConfig( 455 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", 456 | datefmt="%m/%d/%Y %H:%M:%S", 457 | level=logging.INFO if training_args.local_rank in [-1, 0] else logging.WARN, 458 | ) 459 | 460 | # Load prompt/template/mapping file 461 | if data_args.prompt: 462 | if data_args.prompt_path is not None: 463 | assert data_args.prompt_id is not None 464 | prompt_list = [] 465 | with open(data_args.prompt_path) as f: 466 | for line in f: 467 | line = line.strip() 468 | template, mapping = line.split('\t') 469 | prompt_list.append((template, mapping)) 470 | 471 | data_args.template, data_args.mapping = prompt_list[data_args.prompt_id] 472 | logger.info("Specify load the %d-th prompt: %s | %s" % (data_args.prompt_id, data_args.template, data_args.mapping)) 473 | else: 474 | if data_args.template_path is not None: 475 | with open(data_args.template_path) as f: 476 | data_args.template_list = 
[] 477 | for line in f: 478 | line = line.strip() 479 | if len(line) > 0: 480 | data_args.template_list.append(line) 481 | 482 | # Load top-n templates 483 | if data_args.top_n_template is not None: 484 | data_args.template_list = data_args.template_list[:data_args.top_n_template] 485 | logger.info("Load top-%d templates from %s" % (len(data_args.template_list), data_args.template_path)) 486 | 487 | # ... or load i-th template 488 | if data_args.template_id is not None: 489 | data_args.template = data_args.template_list[data_args.template_id] 490 | data_args.template_list = None 491 | logger.info("Specify load the %d-th template: %s" % (data_args.template_id, data_args.template)) 492 | 493 | if data_args.mapping_path is not None: 494 | assert data_args.mapping_id is not None # Only can use one label word mapping 495 | with open(data_args.mapping_path) as f: 496 | mapping_list = [] 497 | for line in f: 498 | line = line.strip() 499 | mapping_list.append(line) 500 | 501 | data_args.mapping = mapping_list[data_args.mapping_id] 502 | logger.info("Specify using the %d-th mapping: %s" % (data_args.mapping_id, data_args.mapping)) 503 | 504 | # Check save path 505 | if ( 506 | os.path.exists(training_args.output_dir) 507 | and os.listdir(training_args.output_dir) 508 | and training_args.do_train 509 | and not training_args.overwrite_output_dir 510 | ): 511 | raise ValueError(f"Output directory ({training_args.output_dir}) already exists.") 512 | 513 | logger.warning( 514 | "Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s", 515 | training_args.local_rank, 516 | training_args.device, 517 | training_args.n_gpu, 518 | bool(training_args.local_rank != -1), 519 | training_args.fp16, 520 | ) 521 | logger.info("Training/evaluation parameters %s", training_args) 522 | 523 | # Set seed 524 | set_seed(training_args.seed) 525 | 526 | try: 527 | num_labels = num_labels_mapping[data_args.task_name] 528 | output_mode = output_modes_mapping[data_args.task_name] 529 | logger.info("Task name: {}, number of labels: {}, output mode: {}".format(data_args.task_name, num_labels, output_mode)) 530 | except KeyError: 531 | raise ValueError("Task not found: %s" % (data_args.task_name)) 532 | 533 | # Automatically generate template for using demonstrations 534 | if data_args.auto_demo and model_args.few_shot_type == 'prompt-demo': 535 | # GPT-3's in-context learning 536 | if data_args.gpt3_in_context_head or data_args.gpt3_in_context_tail: 537 | logger.info("Automatically convert the template to GPT-3's in-context learning.") 538 | assert data_args.template_list is None 539 | 540 | old_template = data_args.template 541 | new_template = old_template + '' 542 | old_template = old_template.replace('*cls*', '') 543 | # Single sentence or sentence pair? 
544 | sent_num = 1 545 | if "_1" in old_template: 546 | sent_num = 2 547 | for instance_id in range(data_args.gpt3_in_context_num): 548 | sub_template = old_template + '' 549 | # Replace sent_id 550 | for sent_id in range(sent_num): 551 | sub_template = sub_template.replace("_{}*".format(sent_id), "_{}*".format(sent_num + sent_num * instance_id + sent_id)) 552 | # Replace mask 553 | sub_template = sub_template.replace("*mask*", "*labelx_{}*".format(instance_id)) 554 | if data_args.gpt3_in_context_tail: 555 | new_template = new_template + sub_template # Put context at the end 556 | else: 557 | new_template = sub_template + new_template # Put context at the beginning 558 | logger.info("| {} => {}".format(data_args.template, new_template)) 559 | data_args.template = new_template 560 | else: 561 | logger.info("Automatically convert the template to using demonstrations.") 562 | if data_args.template_list is not None: 563 | for i in range(len(data_args.template_list)): 564 | old_template = data_args.template_list[i] 565 | new_template = old_template + '' 566 | old_template = old_template.replace('*cls*', '') 567 | # Single sentence or sentence pair? 568 | sent_num = 1 569 | if "_1" in old_template: 570 | sent_num = 2 571 | for label_id in range(num_labels): 572 | sub_template = old_template + '' 573 | # Replace sent id 574 | for sent_id in range(sent_num): 575 | sub_template = sub_template.replace("_{}*".format(sent_id), "_{}*".format(sent_num + sent_num * label_id + sent_id)) 576 | # Replace mask 577 | sub_template = sub_template.replace("*mask*", "*label_{}*".format(label_id)) 578 | new_template = new_template + sub_template 579 | logger.info("| {} => {}".format(data_args.template_list[i], new_template)) 580 | data_args.template_list[i] = new_template 581 | else: 582 | old_template = data_args.template 583 | new_template = old_template + '' 584 | old_template = old_template.replace('*cls*', '') 585 | # Single sentence or sentence pair? 
586 | sent_num = 1 587 | if "_1" in old_template: 588 | sent_num = 2 589 | for label_id in range(num_labels): 590 | sub_template = old_template + '' 591 | # Replace sent id 592 | for sent_id in range(sent_num): 593 | sub_template = sub_template.replace("_{}".format(sent_id), "_{}".format(sent_num + sent_num * label_id + sent_id)) 594 | # Replace mask 595 | sub_template = sub_template.replace("*mask*", "*label_{}*".format(label_id)) 596 | new_template = new_template + sub_template 597 | logger.info("| {} => {}".format(data_args.template, new_template)) 598 | data_args.template = new_template 599 | 600 | # Create config 601 | config_kwargs = {'apply_lora': model_args.apply_lora, 602 | 'lora_alpha': model_args.lora_alpha, 603 | 'lora_r': model_args.lora_r} 604 | config = AutoConfig.from_pretrained( 605 | model_args.config_name if model_args.config_name else model_args.model_name_or_path, 606 | num_labels=num_labels, 607 | finetuning_task=data_args.task_name, 608 | cache_dir=model_args.cache_dir, 609 | **config_kwargs 610 | ) 611 | 612 | if 'prompt' in model_args.few_shot_type: 613 | model_fn = ModelForPromptFinetuning 614 | elif model_args.few_shot_type == 'finetune': 615 | if training_args.from_linearhead: 616 | model_fn = ModelForPromptFinetuning 617 | else: 618 | model_fn = AutoModelForSequenceClassification 619 | else: 620 | raise NotImplementedError 621 | special_tokens = [] 622 | 623 | # Create tokenizer 624 | tokenizer = AutoTokenizer.from_pretrained( 625 | model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path, 626 | additional_special_tokens=special_tokens, 627 | cache_dir=model_args.cache_dir, 628 | ) 629 | 630 | # Get our special datasets. 631 | train_dataset = ( 632 | FewShotDataset(data_args, tokenizer=tokenizer, mode="train", use_demo=("demo" in model_args.few_shot_type)) 633 | ) 634 | eval_dataset = ( 635 | FewShotDataset(data_args, tokenizer=tokenizer, mode="dev", use_demo=("demo" in model_args.few_shot_type)) 636 | if training_args.do_eval 637 | else None 638 | ) 639 | test_dataset = ( 640 | FewShotDataset(data_args, tokenizer=tokenizer, mode="test", use_demo=("demo" in model_args.few_shot_type)) 641 | if training_args.do_predict 642 | else None 643 | ) 644 | 645 | set_seed(training_args.seed) 646 | 647 | model = model_fn.from_pretrained( 648 | model_args.model_name_or_path, 649 | from_tf=bool(".ckpt" in model_args.model_name_or_path), 650 | config=config, 651 | cache_dir=model_args.cache_dir, 652 | ) 653 | if training_args.random_model_init: 654 | model.init_weights() # reinit weights to random 655 | 656 | # For BERT, increase the size of the segment (token type) embeddings 657 | if config.model_type == 'bert': 658 | model.resize_token_embeddings(len(tokenizer)) 659 | resize_token_type_embeddings(model, new_num_types=10, random_segment=model_args.random_segment) 660 | 661 | # Pass dataset and argument information to the model 662 | if train_dataset.label_word_list is not None: 663 | model.label_word_list = torch.tensor(train_dataset.label_word_list).long().to(training_args.device) 664 | if output_modes_mapping[data_args.task_name] == 'regression': 665 | # lower / upper bounds 666 | model.lb, model.ub = bound_mapping[data_args.task_name] 667 | model.model_args = model_args 668 | model.data_args = data_args 669 | model.tokenizer = tokenizer 670 | 671 | if model_args.apply_lora: 672 | for name, param in model.named_parameters(): 673 | if name.startswith('roberta') and "lora" not in name: 674 | param.requires_grad_(False) 675 | 676 | # Build 
metric 677 | def build_compute_metrics_fn(task_name: str) -> Callable[[EvalPrediction], Dict]: 678 | def compute_metrics_fn(p: EvalPrediction): 679 | # Note: the eval dataloader is sequential, so the examples are in order. 680 | # We average the logits over each sample for using demonstrations. 681 | predictions = p.predictions 682 | num_logits = predictions.shape[-1] 683 | 684 | num_sample = test_dataset.num_sample if eval_dataset is None else eval_dataset.num_sample 685 | logits = predictions.reshape([num_sample, -1, num_logits]) 686 | logits = logits.mean(axis=0) 687 | 688 | if num_logits == 1: 689 | preds = np.squeeze(logits) 690 | else: 691 | preds = np.argmax(logits, axis=1) 692 | 693 | # Just for sanity, assert label ids are the same. 694 | label_ids = p.label_ids.reshape([num_sample, -1]) 695 | label_ids_avg = label_ids.mean(axis=0) 696 | label_ids_avg = label_ids_avg.astype(p.label_ids.dtype) 697 | assert (label_ids_avg - label_ids[0]).mean() < 1e-2 698 | label_ids = label_ids[0] 699 | 700 | return compute_metrics_mapping[task_name](task_name, preds, label_ids) 701 | 702 | return compute_metrics_fn 703 | 704 | # Initialize our Trainer 705 | trainer_classes = { 706 | "standard": Trainer, 707 | "linearhead": LinearHeadTrainer, 708 | "kernel": KernelTrainerFunc, 709 | } 710 | trainer_class = trainer_classes[training_args.trainer] 711 | trainer_kwargs = {} 712 | trainer = trainer_class( 713 | model=model, 714 | args=training_args, 715 | train_dataset=train_dataset, 716 | eval_dataset=eval_dataset, 717 | compute_metrics=build_compute_metrics_fn(data_args.task_name), 718 | data_collator=MyDataCollatorWithPadding(tokenizer), 719 | **trainer_kwargs 720 | ) 721 | 722 | # Training 723 | if training_args.do_train: 724 | trainer.train(model_path=model_args.model_name_or_path if os.path.isdir(model_args.model_name_or_path) else None) 725 | # Use the early stop, so do not save the model in the end (unless specify save_at_last) 726 | 727 | if training_args.trainer == "standard": 728 | if training_args.save_at_last: 729 | trainer.save_model(training_args.output_dir) 730 | 731 | if trainer.is_world_process_zero(): 732 | tokenizer.save_pretrained(training_args.output_dir) 733 | torch.save(model_args, os.path.join(training_args.output_dir, "model_args.bin")) 734 | torch.save(data_args, os.path.join(training_args.output_dir, "data_args.bin")) 735 | 736 | # Reload the best checkpoint (for eval) 737 | model = model_fn.from_pretrained(training_args.output_dir) 738 | model = model.to(training_args.device) 739 | trainer.model = model 740 | if train_dataset.label_word_list is not None: 741 | model.label_word_list = torch.tensor(train_dataset.label_word_list).long().to(training_args.device) 742 | if output_modes_mapping[data_args.task_name] == 'regression': 743 | # lower / upper bounds 744 | model.lb, model.ub = bound_mapping[data_args.task_name] 745 | model.model_args = model_args 746 | model.data_args = data_args 747 | model.tokenizer = tokenizer 748 | 749 | # Evaluation 750 | final_result = { 751 | 'time': str(datetime.today()), 752 | 'output_dir': training_args.output_dir 753 | } 754 | 755 | eval_results = {} 756 | if training_args.do_eval: 757 | logger.info("*** Validate ***") 758 | 759 | eval_datasets = [eval_dataset] 760 | 761 | for eval_dataset in eval_datasets: 762 | trainer.compute_metrics = build_compute_metrics_fn(eval_dataset.args.task_name) 763 | output = trainer.evaluate(eval_dataset=eval_dataset) 764 | eval_result = output.metrics 765 | 766 | output_eval_file = os.path.join( 767 | 
training_args.output_dir, f"eval_results_{eval_dataset.args.task_name}.txt" 768 | ) 769 | if trainer.is_world_process_zero(): 770 | with open(output_eval_file, "w") as writer: 771 | logger.info("***** Eval results {} *****".format(eval_dataset.args.task_name)) 772 | for key, value in eval_result.items(): 773 | logger.info(" %s = %s", key, value) 774 | writer.write("%s = %s\n" % (key, value)) 775 | final_result[eval_dataset.args.task_name + '_dev_' + key] = value 776 | eval_results.update(eval_result) 777 | 778 | test_results = {} 779 | if training_args.do_predict: 780 | logging.info("*** Test ***") 781 | test_datasets = [test_dataset] 782 | ### Don't evaluate on mnli-mm for our purposes 783 | # if data_args.task_name == "mnli": 784 | # mnli_mm_data_args = dataclasses.replace(data_args, task_name="mnli-mm") 785 | # test_datasets.append( 786 | # FewShotDataset(mnli_mm_data_args, tokenizer=tokenizer, mode="test", use_demo=('demo' in model_args.few_shot_type)) 787 | # ) 788 | 789 | for test_dataset in test_datasets: 790 | trainer.compute_metrics = build_compute_metrics_fn(test_dataset.args.task_name) 791 | output = trainer.evaluate(eval_dataset=test_dataset) 792 | test_result = output.metrics 793 | 794 | output_test_file = os.path.join( 795 | training_args.output_dir, f"test_results_{test_dataset.args.task_name}.txt" 796 | ) 797 | if trainer.is_world_process_zero(): 798 | with open(output_test_file, "w") as writer: 799 | logger.info("***** Test results {} *****".format(test_dataset.args.task_name)) 800 | for key, value in test_result.items(): 801 | logger.info(" %s = %s", key, value) 802 | writer.write("%s = %s\n" % (key, value)) 803 | final_result[test_dataset.args.task_name + '_test_' + key] = value 804 | 805 | if training_args.save_logit: 806 | predictions = output.predictions 807 | num_logits = predictions.shape[-1] 808 | logits = predictions.reshape([test_dataset.num_sample, -1, num_logits]).mean(axis=0) 809 | np.save(os.path.join(training_args.save_logit_dir, "{}-{}-{}.npy".format(test_dataset.task_name, training_args.model_id, training_args.array_id)), logits) 810 | 811 | test_results.update(test_result) 812 | 813 | 814 | if trainer.is_world_process_zero(): 815 | with FileLock('log.lock'): 816 | with open('log', 'a') as f: 817 | final_result.update(vars(model_args)) 818 | final_result.update(vars(training_args)) 819 | final_result.update(vars(data_args)) 820 | if 'evaluation_strategy' in final_result: 821 | final_result.pop('evaluation_strategy') 822 | f.write(str(final_result) + '\n') 823 | 824 | logger.info('****** Output Dir *******') 825 | logger.info(training_args.output_dir) 826 | 827 | return eval_results 828 | 829 | if __name__ == "__main__": 830 | main() 831 | -------------------------------------------------------------------------------- /run_fewshot.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Main settings with default values 4 | TASK=${TASK:-"SST-2"} # see all the options in the "cases" below 5 | SEED=${SEED:-13} # random seed and also data seed, by default the data split seeds are {13, 21, 42, 87, 100} 6 | K=${K:-16} # choose from {16, 64, 512} by default 7 | MODEL=${MODEL:-"roberta-base"} # pick a RoBERTa or BERT model 8 | TYPE=${TYPE:-"prompt"} # fine-tuning setting, choose from "finetune" and "prompt" 9 | TRAINER=${TRAINER:-"standard"} # choose from "standard", "kernel" and "linearhead" 10 | TAG=${TAG:-} # set a tag to distinguish and aggregate runs in the log 11 | NUM_GPU=${NUM_GPU:-1} # by default use 
1 GPU, set to 0 for CPU-only training 12 | 13 | 14 | TASK_EXTRA="" 15 | case $TASK in 16 | SST-2) 17 | TEMPLATE=*cls**sent_0*_It_was*mask*.*sep+* 18 | MAPPING="{'0':'terrible','1':'great'}" 19 | ;; 20 | QQP) 21 | TEMPLATE=*cls**sent_0**mask*,*+sentl_1**sep+* 22 | MAPPING="{'0':'No','1':'Yes'}" 23 | ;; 24 | QNLI) 25 | TEMPLATE=*cls**sent-_0*?*mask*,*+sentl_1**sep+* 26 | MAPPING="{'not_entailment':'No','entailment':'Yes'}" 27 | ;; 28 | MNLI) 29 | TEMPLATE=*cls**sent-_0*?*mask*,*+sentl_1**sep+* 30 | MAPPING="{'contradiction':'No','entailment':'Yes','neutral':'Maybe'}" 31 | TASK_EXTRA="--max_seq_len 256 --first_sent_limit 240" 32 | ;; 33 | SNLI) 34 | TEMPLATE=*cls**sent-_0*?*mask*,*+sentl_1**sep+* 35 | MAPPING="{'contradiction':'No','entailment':'Yes','neutral':'Maybe'}" 36 | TASK_EXTRA="--max_seq_len 256 --num_sample 4" 37 | ;; 38 | trec) 39 | TEMPLATE="*cls**mask*:*+sent_0**sep+*" 40 | MAPPING="{0:'Description',1:'Entity',2:'Expression',3:'Human',4:'Location',5:'Number'}" 41 | TASK_EXTRA="--first_sent_limit 110" 42 | ;; 43 | mr) 44 | TEMPLATE=*cls**sent_0*_It_was*mask*.*sep+* 45 | MAPPING="{0:'terrible',1:'great'}" 46 | TASK_EXTRA="--first_sent_limit 110 --other_sent_limit 50" 47 | ;; 48 | cr) 49 | TEMPLATE=*cls**sent_0*_It_was*mask*.*sep+* 50 | MAPPING="{0:'terrible',1:'great'}" 51 | TASK_EXTRA="--first_sent_limit 110 --other_sent_limit 50" 52 | ;; 53 | mpqa) 54 | TEMPLATE=*cls**sent_0*_It_was*mask*.*sep+* 55 | MAPPING="{0:'terrible',1:'great'}" 56 | TASK_EXTRA="--first_sent_limit 110" 57 | ;; 58 | CoLA) 59 | TEMPLATE=*cls**sent_0*_This_is*mask*.*sep+* 60 | MAPPING="{'0':'incorrect','1':'correct'}" 61 | ;; 62 | subj) 63 | TEMPLATE=*cls**sent_0*_This_is*mask*.*sep+* 64 | MAPPING="{0:'subjective',1:'objective'}" 65 | TASK_EXTRA="--first_sent_limit 110 --other_sent_limit 50" 66 | ;; 67 | MRPC) 68 | TEMPLATE=*cls**sent_0**mask*,*+sentl_1**sep+* 69 | MAPPING="{'0':'No','1':'Yes'}" 70 | ;; 71 | RTE) 72 | TEMPLATE=*cls**sent-_0*?*mask*,*+sentl_1**sep+* 73 | MAPPING="{'not_entailment':'No','entailment':'Yes'}" 74 | TASK_EXTRA="--max_seq_len 256 --first_sent_limit 240" 75 | ;; 76 | sst-5) 77 | TEMPLATE=*cls**sent_0*_It_was*mask*.*sep+* 78 | MAPPING="{0:'terrible',1:'bad',2:'okay',3:'good',4:'great'}" 79 | ;; 80 | ag_news) 81 | TEMPLATE=*cls**sent_0*_This_article_is_about*mask*_news.*sep+* 82 | MAPPING="{1:'world',2:'sports',3:'business',4:'tech'}" 83 | TASK_EXTRA="--max_seq_len 256 --first_sent_limit 240" 84 | ;; 85 | esac 86 | 87 | if [ ! -z "$LOAD_KERNELS_TAG" ]; then 88 | # Load pre-computed kernels from an existing directory 89 | LOAD_KERNELS="--load_kernels result/$TASK-$MODEL-$TYPE-$TRAINER-$LOAD_KERNELS_TAG/$K-$SEED" 90 | fi 91 | 92 | ALL_ARGS_TOGETHER=" 93 | --model_name_or_path $MODEL --few_shot_type $TYPE 94 | --task_name $TASK --template $TEMPLATE --mapping $MAPPING 95 | --data_dir data/k-shot-1k-test/$TASK/$K-$SEED 96 | --overwrite_output_dir --output_dir result/$TASK-$MODEL-$TYPE-$TRAINER-$TAG$GRID_TAG/$K-$SEED 97 | --num_k $K 98 | --tag $TAG 99 | --per_device_eval_batch_size 1 100 | --per_device_train_batch_size 1 101 | --max_seq_length 128 102 | --seed $SEED 103 | --do_eval --do_predict --do_train 104 | --trainer $TRAINER 105 | $TASK_EXTRA 106 | $LOAD_KERNELS 107 | $@ 108 | " 109 | 110 | if [[ $NUM_GPU > 0 ]]; then 111 | # Randomly set a port number 112 | # If you encounter "address already used" error, just run again or manually set an available port id. 
113 |     PORT_ID=$(expr $RANDOM + 1000)
114 | 
115 |     # Allow multiple threads
116 |     export OMP_NUM_THREADS=8
117 | 
118 |     python -m torch.distributed.launch --nproc_per_node $NUM_GPU --master_port $PORT_ID run.py \
119 |         $ALL_ARGS_TOGETHER
120 | else
121 |     python run.py \
122 |         $ALL_ARGS_TOGETHER
123 | fi
124 | 
--------------------------------------------------------------------------------
/src/dataset.py:
--------------------------------------------------------------------------------
1 | """Dataset utils for different data settings for GLUE."""
2 | 
3 | import os
4 | import logging
5 | import torch
6 | import numpy as np
7 | import time
8 | from filelock import FileLock
9 | import json
10 | from src.processors import processors_mapping, num_labels_mapping, output_modes_mapping, compute_metrics_mapping, median_mapping
11 | from transformers.data.processors.utils import InputFeatures
12 | import dataclasses
13 | from dataclasses import dataclass
14 | from typing import List, Optional, Union
15 | import pandas as pd
16 | 
17 | logger = logging.getLogger(__name__)
18 | 
19 | @dataclass(frozen=True)
20 | class OurInputFeatures(InputFeatures):
21 |     """
22 |     Inherit from Transformers' InputFeatures.
23 |     """
24 | 
25 |     input_ids: List[int]
26 |     attention_mask: Optional[List[int]] = None
27 |     token_type_ids: Optional[List[int]] = None
28 |     label: Optional[Union[int, float]] = None
29 |     mask_pos: Optional[List[int]] = None  # Position of the mask token
30 |     label_word_list: Optional[List[int]] = None  # Label word mapping (dynamic)
31 | 
32 |     def to_json_string(self):
33 |         """Serializes this instance to a JSON string."""
34 |         return json.dumps(dataclasses.asdict(self)) + "\n"
35 | 
36 | def input_example_to_string(example, sep_token):
37 |     if example.text_b is None:
38 |         return example.text_a
39 |     else:
40 |         # Warning: very simple hack here
41 |         return example.text_a + ' ' + sep_token + ' ' + example.text_b
42 | 
43 | def input_example_to_tuple(example):
44 |     if example.text_b is None:
45 |         if pd.isna(example.text_a) or example.text_a is None:
46 |             logger.warn("Empty input")
47 |             return ['']
48 |         else:
49 |             return [example.text_a]
50 |     else:
51 |         return [example.text_a, example.text_b]
52 | 
53 | def tokenize_multipart_input(
54 |     input_text_list,
55 |     max_length,
56 |     tokenizer,
57 |     task_name=None,
58 |     prompt=False,
59 |     template=None,
60 |     label_word_list=None,
61 |     first_sent_limit=None,
62 |     other_sent_limit=None,
63 |     gpt3=False,
64 |     truncate_head=False,
65 |     support_labels=None,
66 | ):
67 |     def enc(text):
68 |         return tokenizer.encode(text, add_special_tokens=False)
69 | 
70 |     input_ids = []
71 |     attention_mask = []
72 |     token_type_ids = []  # Only for BERT
73 |     mask_pos = None  # Position of the mask token
74 | 
75 |     if prompt:
76 |         """
77 |         Concatenate all sentences and prompts based on the provided template.
78 |         Template example: '*cls*It was*mask*.*sent_0***label_0:*sent_1****label_1*:*sent_2***'
79 |         *xx* represents a variable:
80 |             *cls*: cls_token
81 |             *mask*: mask_token
82 |             *sep*: sep_token
83 |             *sep+*: sep_token, also means +1 for segment id
84 |             *sent_i*: sentence i (input_text_list[i])
85 |             *sent-_i*: same as above, but delete the last token
86 |             *sentl_i*: same as above, but use lower case for the first word
87 |             *sentl-_i*: same as above, but use lower case for the first word and delete the last token
88 |             *+sent_i*: same as above, but add a space before the sentence
89 |             *+sentl_i*: same as above, but add a space before the sentence and use lower case for the first word
90 |             *label_i*: label_word_list[i]
91 |             *label_x*: label depends on the example id (support_labels needed). This is only used in GPT-3's in-context learning.
92 | 
93 |         Use "_" to replace space.
94 |         PAY ATTENTION TO SPACES!! DO NOT leave a space before variables, as this will lead to an extra space token.
95 |         """
96 |         assert template is not None
97 | 
98 |         special_token_mapping = {
99 |             'cls': tokenizer.cls_token_id, 'mask': tokenizer.mask_token_id, 'sep': tokenizer.sep_token_id, 'sep+': tokenizer.sep_token_id,
100 |         }
101 |         template_list = template.split('*')  # Get variable list in the template
102 |         segment_id = 0  # Current segment id. Segment id +1 if encountering sep+.
103 | 
104 |         for part_id, part in enumerate(template_list):
105 |             new_tokens = []
106 |             segment_plus_1_flag = False
107 |             if part in special_token_mapping:
108 |                 if part == 'cls' and 'T5' in type(tokenizer).__name__:
109 |                     # T5 does not have cls token
110 |                     continue
111 |                 new_tokens.append(special_token_mapping[part])
112 |                 if part == 'sep+':
113 |                     segment_plus_1_flag = True
114 |             elif part[:6] == 'label_':
115 |                 # Note that label_word_list already has the extra space, so do not add more space ahead of it.
116 | label_id = int(part.split('_')[1]) 117 | label_word = label_word_list[label_id] 118 | new_tokens.append(label_word) 119 | elif part[:7] == 'labelx_': 120 | instance_id = int(part.split('_')[1]) 121 | label_id = support_labels[instance_id] 122 | label_word = label_word_list[label_id] 123 | new_tokens.append(label_word) 124 | elif part[:5] == 'sent_': 125 | sent_id = int(part.split('_')[1]) 126 | new_tokens += enc(input_text_list[sent_id]) 127 | elif part[:6] == '+sent_': 128 | # Add space 129 | sent_id = int(part.split('_')[1]) 130 | new_tokens += enc(' ' + input_text_list[sent_id]) 131 | elif part[:6] == 'sent-_': 132 | # Delete the last token 133 | sent_id = int(part.split('_')[1]) 134 | new_tokens += enc(input_text_list[sent_id][:-1]) 135 | elif part[:6] == 'sentl_': 136 | # Lower case the first token 137 | sent_id = int(part.split('_')[1]) 138 | text = input_text_list[sent_id] 139 | text = text[:1].lower() + text[1:] 140 | new_tokens += enc(text) 141 | elif part[:7] == '+sentl_': 142 | # Lower case the first token and add space 143 | sent_id = int(part.split('_')[1]) 144 | text = input_text_list[sent_id] 145 | text = text[:1].lower() + text[1:] 146 | new_tokens += enc(' ' + text) 147 | elif part[:7] == 'sentl-_': 148 | # Lower case the first token and discard the last token 149 | sent_id = int(part.split('_')[1]) 150 | text = input_text_list[sent_id] 151 | text = text[:1].lower() + text[1:] 152 | new_tokens += enc(text[:-1]) 153 | elif part[:6] == 'sentu_': 154 | # Upper case the first token 155 | sent_id = int(part.split('_')[1]) 156 | text = input_text_list[sent_id] 157 | text = text[:1].upper() + text[1:] 158 | new_tokens += enc(text) 159 | elif part[:7] == '+sentu_': 160 | # Upper case the first token and add space 161 | sent_id = int(part.split('_')[1]) 162 | text = input_text_list[sent_id] 163 | text = text[:1].upper() + text[1:] 164 | new_tokens += enc(' ' + text) 165 | else: 166 | # Just natural language prompt 167 | part = part.replace('_', ' ') 168 | # handle special case when T5 tokenizer might add an extra space 169 | if len(part) == 1: 170 | new_tokens.append(tokenizer.convert_tokens_to_ids(part)) 171 | else: 172 | new_tokens += enc(part) 173 | 174 | if part[:4] == 'sent' or part[1:5] == 'sent': 175 | # If this part is the sentence, limit the sentence length 176 | sent_id = int(part.split('_')[1]) 177 | if sent_id == 0: 178 | if first_sent_limit is not None: 179 | new_tokens = new_tokens[:first_sent_limit] 180 | else: 181 | if other_sent_limit is not None: 182 | new_tokens = new_tokens[:other_sent_limit] 183 | 184 | input_ids += new_tokens 185 | attention_mask += [1 for i in range(len(new_tokens))] 186 | token_type_ids += [segment_id for i in range(len(new_tokens))] 187 | 188 | if segment_plus_1_flag: 189 | segment_id += 1 190 | else: 191 | input_ids = [tokenizer.cls_token_id] 192 | attention_mask = [1] 193 | token_type_ids = [0] 194 | 195 | for sent_id, input_text in enumerate(input_text_list): 196 | if input_text is None: 197 | # Do not have text_b 198 | continue 199 | if pd.isna(input_text) or input_text is None: 200 | # Empty input 201 | input_text = '' 202 | input_tokens = enc(input_text) + [tokenizer.sep_token_id] 203 | input_ids += input_tokens 204 | attention_mask += [1 for i in range(len(input_tokens))] 205 | token_type_ids += [sent_id for i in range(len(input_tokens))] 206 | 207 | if 'T5' in type(tokenizer).__name__: # T5 does not have CLS token 208 | input_ids = input_ids[1:] 209 | attention_mask = attention_mask[1:] 210 | token_type_ids = 
token_type_ids[1:] 
211 | 
212 | # Padding 
213 | if first_sent_limit is not None and len(input_ids) > max_length: 
214 | # If sentence limits are used but the total length still exceeds max_length, report a warning 
215 | logger.warning("Input exceeds max_length limit: {}".format(tokenizer.decode(input_ids))) 
216 | 
217 | ### Code below is commented out, because we use dynamic padding rather than static padding to max_length 
218 | # while len(input_ids) < max_length: 
219 | # input_ids.append(tokenizer.pad_token_id) 
220 | # attention_mask.append(0) 
221 | # token_type_ids.append(0) 
222 | 
223 | # Truncate 
224 | if len(input_ids) > max_length: 
225 | if truncate_head: 
226 | input_ids = input_ids[-max_length:] 
227 | attention_mask = attention_mask[-max_length:] 
228 | token_type_ids = token_type_ids[-max_length:] 
229 | else: 
230 | # Default is to truncate the tail 
231 | input_ids = input_ids[:max_length] 
232 | attention_mask = attention_mask[:max_length] 
233 | token_type_ids = token_type_ids[:max_length] 
234 | 
235 | # Find mask token 
236 | if prompt: 
237 | # Make sure that the masked position is inside the max_length 
238 | assert tokenizer.mask_token_id in input_ids, \ 
239 | "Mask token not found for input: {} {}".format(input_text_list, input_ids) 
240 | mask_pos = [input_ids.index(tokenizer.mask_token_id)] 
241 | assert mask_pos[0] < max_length 
242 | 
243 | result = {'input_ids': input_ids, 'attention_mask': attention_mask} 
244 | if 'BERT' in type(tokenizer).__name__: 
245 | # Only provide token type ids for BERT 
246 | result['token_type_ids'] = token_type_ids 
247 | 
248 | if prompt: 
249 | result['mask_pos'] = mask_pos 
250 | 
251 | return result 
252 | 
253 | 
254 | 
255 | class FewShotDataset(torch.utils.data.Dataset): 
256 | """Few-shot dataset.""" 
257 | 
258 | def __init__(self, args, tokenizer, cache_dir=None, mode="train", use_demo=False): 
259 | self.args = args 
260 | self.task_name = args.task_name 
261 | 
262 | self.processor = processors_mapping[args.task_name] 
263 | 
264 | self.tokenizer = tokenizer 
265 | self.mode = mode 
266 | 
267 | # Set use_demo=True to use demonstrations 
268 | self.use_demo = use_demo 
269 | if self.use_demo: 
270 | logger.info("Use demonstrations") 
271 | assert mode in ["train", "dev", "test"] 
272 | 
273 | # Get label list and (for prompt) label word list 
274 | self.label_list = self.processor.get_labels() 
275 | self.num_labels = len(self.label_list) 
276 | if args.prompt and args.mapping is not None: 
277 | self.label_to_word = eval(args.mapping) 
278 | 
279 | for key in self.label_to_word: 
280 | # For RoBERTa/BART/T5, tokenization also considers space, so we use space+word as label words. 
281 | if self.label_to_word[key][0] not in ['<', '[', '.', ',']: 
282 | # Make sure space+word is in the vocabulary 
283 | assert len(tokenizer.tokenize(' ' + self.label_to_word[key])) == 1 
284 | self.label_to_word[key] = tokenizer.convert_tokens_to_ids(tokenizer.tokenize(' ' + self.label_to_word[key])[0]) 
285 | else: 
286 | self.label_to_word[key] = tokenizer.convert_tokens_to_ids(self.label_to_word[key]) 
287 | logger.info("Label {} to word {} ({})".format(key, tokenizer.convert_ids_to_tokens(self.label_to_word[key]), self.label_to_word[key])) 
288 | 
289 | if len(self.label_list) > 1: 
290 | self.label_word_list = [self.label_to_word[label] for label in self.label_list] 
291 | else: 
292 | # Regression task 
293 | # '0' represents low polarity and '1' represents high polarity. 
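# (For regression tasks, e.g. STS-B, the '0'/'1' label words serve as pseudo-labels for
#  the low/high ends of the score range; see the lb/ub interpolation in src/models.py.)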
294 | self.label_word_list = [self.label_to_word[label] for label in ['0', '1']] 295 | else: 296 | self.label_to_word = None 297 | self.label_word_list = None 298 | 299 | # Multiple sampling: when using demonstrations, we sample different combinations of demonstrations during 300 | # inference and aggregate the results by averaging the logits. The number of different samples is num_sample. 301 | if (mode == "train") or not self.use_demo: 302 | # We do not do multiple sampling when not using demonstrations or when it's the training mode 303 | self.num_sample = 1 304 | else: 305 | self.num_sample = args.num_sample 306 | 307 | # If we use multiple templates, we also need to do multiple sampling during inference. 308 | if args.prompt and args.template_list is not None: 309 | logger.info("There are %d templates. Multiply num_sample by %d" % (len(args.template_list), len(args.template_list))) 310 | self.num_sample *= len(args.template_list) 311 | 312 | logger.info("Total num_sample for mode %s: %d" % (mode, self.num_sample)) 313 | 314 | # Load cache 315 | # Cache name distinguishes mode, task name, tokenizer, and length. So if you change anything beyond these elements, make sure to clear your cache. 316 | cached_features_file = os.path.join( 317 | cache_dir if cache_dir is not None else args.data_dir, 318 | "cached_{}_{}_{}_{}".format( 319 | mode, 320 | tokenizer.__class__.__name__, 321 | str(args.max_seq_length), 322 | args.task_name, 323 | ), 324 | ) 325 | 326 | logger.info(f"Creating/loading examples from dataset file at {args.data_dir}") 327 | 328 | lock_path = cached_features_file + ".lock" 329 | with FileLock(lock_path): 330 | 331 | if os.path.exists(cached_features_file) and not args.overwrite_cache: 332 | start = time.time() 333 | self.support_examples, self.query_examples = torch.load(cached_features_file) 334 | logger.info( 335 | f"Loading features from cached file {cached_features_file} [took %.3f s]", time.time() - start 336 | ) 337 | else: 338 | logger.info(f"Creating features from dataset file at {args.data_dir}") 339 | 340 | # The support examples are sourced from the training set. 341 | self.support_examples = self.processor.get_train_examples(args.data_dir) 342 | 343 | if mode == "dev": 344 | self.query_examples = self.processor.get_dev_examples(args.data_dir) 345 | elif mode == "test": 346 | self.query_examples = self.processor.get_test_examples(args.data_dir) 347 | else: 348 | self.query_examples = self.support_examples 349 | 350 | start = time.time() 351 | torch.save([self.support_examples, self.query_examples], cached_features_file) 352 | # ^ This seems to take a lot of time so I want to investigate why and how we can improve. 
353 | logger.info( 
354 | "Saving features into cached file %s [took %.3f s]", cached_features_file, time.time() - start 
355 | ) 
356 | 
357 | # When filtering demonstrations, load pre-computed embeddings 
358 | if self.use_demo and args.demo_filter: 
359 | split_name = '' 
360 | if mode == 'train': 
361 | split_name = 'train' 
362 | elif mode == 'dev': 
363 | if args.task_name == 'mnli': 
364 | split_name = 'dev_matched' 
365 | elif args.task_name == 'mnli-mm': 
366 | split_name = 'dev_mismatched' 
367 | else: 
368 | split_name = 'dev' 
369 | elif mode == 'test': 
370 | if args.task_name == 'mnli': 
371 | split_name = 'test_matched' 
372 | elif args.task_name == 'mnli-mm': 
373 | split_name = 'test_mismatched' 
374 | else: 
375 | split_name = 'test' 
376 | else: 
377 | raise NotImplementedError 
378 | 
379 | self.support_emb = np.load(os.path.join(args.data_dir, "train_{}.npy".format(args.demo_filter_model))) 
380 | self.query_emb = np.load(os.path.join(args.data_dir, "{}_{}.npy".format(split_name, args.demo_filter_model))) 
381 | logger.info("Load embeddings (for demonstration filtering) from {}".format(os.path.join(args.data_dir, "{}_{}.npy".format(split_name, args.demo_filter_model)))) 
382 | 
383 | assert len(self.support_emb) == len(self.support_examples) 
384 | assert len(self.query_emb) == len(self.query_examples) 
385 | 
386 | # Size is expanded by num_sample 
387 | self.size = len(self.query_examples) * self.num_sample 
388 | 
389 | # Prepare examples (especially for using demonstrations) 
390 | support_indices = list(range(len(self.support_examples))) 
391 | self.example_idx = [] 
392 | for sample_idx in range(self.num_sample): 
393 | for query_idx in range(len(self.query_examples)): 
394 | # If training, exclude the current example. Else keep all. 
395 | if self.use_demo and args.demo_filter: 
396 | # Demonstration filtering requires sentence_transformers, 
397 | # which is not in our requirements; see the original LM-BFF repo. 
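# (Assumed setup, not pinned in requirements.txt: `pip install sentence-transformers`.)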
398 | from sentence_transformers import SentenceTransformer, util 
399 | 
400 | # Demonstration filtering 
401 | candidate = [support_idx for support_idx in support_indices 
402 | if support_idx != query_idx or mode != "train"] 
403 | sim_score = [] 
404 | for support_idx in candidate: 
405 | sim_score.append((support_idx, util.pytorch_cos_sim(self.support_emb[support_idx], self.query_emb[query_idx]))) 
406 | sim_score.sort(key=lambda x: x[1], reverse=True) 
407 | if self.num_labels == 1: 
408 | # Regression task 
409 | limit_each_label = int(len(sim_score) // 2 * args.demo_filter_rate) 
410 | count_each_label = {'0': 0, '1': 0} 
411 | context_indices = [] 
412 | 
413 | if args.debug_mode: 
414 | print("Query %s: %s" % (self.query_examples[query_idx].label, self.query_examples[query_idx].text_a)) # debug 
415 | for support_idx, score in sim_score: 
416 | if count_each_label['0' if float(self.support_examples[support_idx].label) <= median_mapping[args.task_name] else '1'] < limit_each_label: 
417 | count_each_label['0' if float(self.support_examples[support_idx].label) <= median_mapping[args.task_name] else '1'] += 1 
418 | context_indices.append(support_idx) 
419 | if args.debug_mode: 
420 | print(" %.4f %s | %s" % (score, self.support_examples[support_idx].label, self.support_examples[support_idx].text_a)) # debug 
421 | else: 
422 | limit_each_label = int(len(sim_score) // self.num_labels * args.demo_filter_rate) 
423 | count_each_label = {label: 0 for label in self.label_list} 
424 | context_indices = [] 
425 | 
426 | if args.debug_mode: 
427 | print("Query %s: %s" % (self.query_examples[query_idx].label, self.query_examples[query_idx].text_a)) # debug 
428 | for support_idx, score in sim_score: 
429 | if count_each_label[self.support_examples[support_idx].label] < limit_each_label: 
430 | count_each_label[self.support_examples[support_idx].label] += 1 
431 | context_indices.append(support_idx) 
432 | if args.debug_mode: 
433 | print(" %.4f %s | %s" % (score, self.support_examples[support_idx].label, self.support_examples[support_idx].text_a)) # debug 
434 | else: 
435 | # Using demonstrations without filtering 
436 | context_indices = [support_idx for support_idx in support_indices 
437 | if support_idx != query_idx or mode != "train"] 
438 | 
439 | # We'll subsample context_indices further later. 
440 | self.example_idx.append((query_idx, context_indices, sample_idx)) 
441 | 
442 | # If it is not training, we pre-process the data; otherwise, we process the data online. 
443 | if mode != "train": 
444 | self.features = [] 
445 | _ = 0 
446 | for query_idx, context_indices, bootstrap_idx in self.example_idx: 
447 | # The input (query) example 
448 | example = self.query_examples[query_idx] 
449 | # The demonstrations 
450 | supports = self.select_context([self.support_examples[i] for i in context_indices]) 
451 | 
452 | if args.template_list is not None: 
453 | template = args.template_list[bootstrap_idx % len(args.template_list)] # Cycle through templates in order 
454 | else: 
455 | template = args.template 
456 | 
457 | self.features.append(self.convert_fn( 
458 | example=example, 
459 | supports=supports, 
460 | use_demo=self.use_demo, 
461 | label_list=self.label_list, 
462 | prompt=args.prompt, 
463 | template=template, 
464 | label_word_list=self.label_word_list, 
465 | verbose=True if _ == 0 else False, 
466 | )) 
467 | 
468 | _ += 1 
469 | else: 
470 | self.features = None 
471 | 
472 | def select_context(self, context_examples): 
473 | """ 
474 | Select demonstrations from provided examples. 
475 | """ 476 | max_demo_per_label = 1 477 | counts = {k: 0 for k in self.label_list} 478 | if len(self.label_list) == 1: 479 | # Regression 480 | counts = {'0': 0, '1': 0} 481 | selection = [] 482 | 483 | if self.args.gpt3_in_context_head or self.args.gpt3_in_context_tail: 484 | # For GPT-3's in-context learning, we sample gpt3_in_context_num demonstrations randomly. 485 | order = np.random.permutation(len(context_examples)) 486 | for i in range(min(self.args.gpt3_in_context_num, len(order))): 487 | selection.append(context_examples[order[i]]) 488 | else: 489 | # Our sampling strategy 490 | order = np.random.permutation(len(context_examples)) 491 | 492 | for i in order: 493 | label = context_examples[i].label 494 | if len(self.label_list) == 1: 495 | # Regression 496 | label = '0' if float(label) <= median_mapping[self.args.task_name] else '1' 497 | if counts[label] < max_demo_per_label: 498 | selection.append(context_examples[i]) 499 | counts[label] += 1 500 | if sum(counts.values()) == len(counts) * max_demo_per_label: 501 | break 502 | 503 | assert len(selection) > 0 504 | 505 | return selection 506 | 507 | def __len__(self): 508 | return self.size 509 | 510 | def __getitem__(self, i): 511 | if self.features is None: 512 | query_idx, context_indices, bootstrap_idx = self.example_idx[i] 513 | # The input (query) example 514 | example = self.query_examples[query_idx] 515 | # The demonstrations 516 | supports = self.select_context([self.support_examples[i] for i in context_indices]) 517 | 518 | if self.args.template_list is not None: 519 | template = self.args.template_list[sample_idx % len(self.args.template_list)] 520 | else: 521 | template = self.args.template 522 | 523 | features = self.convert_fn( 524 | example=example, 525 | supports=supports, 526 | use_demo=self.use_demo, 527 | label_list=self.label_list, 528 | prompt=self.args.prompt, 529 | template=template, 530 | label_word_list=self.label_word_list, 531 | verbose=False, 532 | ) 533 | else: 534 | features = self.features[i] 535 | 536 | return features 537 | 538 | def get_labels(self): 539 | return self.label_list 540 | 541 | 542 | def convert_fn( 543 | self, 544 | example, 545 | supports, 546 | use_demo=False, 547 | label_list=None, 548 | prompt=False, 549 | template=None, 550 | label_word_list=None, 551 | verbose=False 552 | ): 553 | """ 554 | Returns a list of processed "InputFeatures". 
555 | """ 556 | max_length = self.args.max_seq_length 557 | 558 | # Prepare labels 559 | label_map = {label: i for i, label in enumerate(label_list)} # Mapping the label names to label ids 560 | if len(label_list) == 1: 561 | # Regression 562 | label_map = {'0': 0, '1': 1} 563 | 564 | # Get example's label id (for training/inference) 565 | if example.label is None: 566 | example_label = None 567 | elif len(label_list) == 1: 568 | # Regerssion 569 | example_label = float(example.label) 570 | else: 571 | example_label = label_map[example.label] 572 | 573 | # Prepare other features 574 | if not use_demo: 575 | # No using demonstrations 576 | inputs = tokenize_multipart_input( 577 | input_text_list=input_example_to_tuple(example), 578 | max_length=max_length, 579 | tokenizer=self.tokenizer, 580 | task_name=self.args.task_name, 581 | prompt=prompt, 582 | template=template, 583 | label_word_list=label_word_list, 584 | first_sent_limit=self.args.first_sent_limit, 585 | other_sent_limit=self.args.other_sent_limit, 586 | ) 587 | features = OurInputFeatures(**inputs, label=example_label) 588 | 589 | else: 590 | # Using demonstrations 591 | 592 | # Max length 593 | if self.args.double_demo: 594 | # When using demonstrations, double the maximum length 595 | # Note that in this case, args.max_seq_length is the maximum length for a single sentence 596 | max_length = max_length * 2 597 | if self.args.gpt3_in_context_head or self.args.gpt3_in_context_tail: 598 | # When using GPT-3's in-context learning, take the maximum tokenization length of the model (512) 599 | max_length = 512 600 | 601 | # All input sentences, including the query and the demonstrations, are put into augmented_examples, 602 | # and are numbered based on the order (starting from 0). For single sentence tasks, the input (query) 603 | # is the sentence 0; for sentence-pair tasks, the input (query) is the sentence 0 and 1. 
Note that for GPT-3's 604 | # in-context learning, the input (query) might be at the end instead of the beginning (gpt3_in_context_head) 605 | augmented_example = [] 606 | query_text = input_example_to_tuple(example) # Input sentence list for query 607 | support_by_label = [[] for i in range(len(label_map))] 608 | 609 | if self.args.gpt3_in_context_head or self.args.gpt3_in_context_tail: 610 | support_labels = [] 611 | augmented_example = query_text 612 | for support_example in supports: 613 | augmented_example += input_example_to_tuple(support_example) 614 | current_label = support_example.label 615 | if len(label_list) == 1: 616 | current_label = '0' if float(current_label) <= median_mapping[self.args.task_name] else '1' # Regression 617 | support_labels.append(label_map[current_label]) 618 | else: 619 | # Group support examples by label 620 | for label_name, label_id in label_map.items(): 621 | if len(label_list) == 1: 622 | # Regression 623 | for support_example in filter(lambda s: ('0' if float(s.label) <= median_mapping[self.args.task_name] else '1') == label_name, supports): 624 | support_by_label[label_id] += input_example_to_tuple(support_example) 625 | else: 626 | for support_example in filter(lambda s: s.label == label_name, supports): 627 | support_by_label[label_id] += input_example_to_tuple(support_example) 628 | 629 | augmented_example = query_text 630 | for label_id in range(len(label_map)): 631 | augmented_example += support_by_label[label_id] 632 | 633 | # Tokenization (based on the template) 634 | inputs = tokenize_multipart_input( 635 | input_text_list=augmented_example, 636 | max_length=max_length, 637 | tokenizer=self.tokenizer, 638 | task_name=self.args.task_name, 639 | prompt=prompt, 640 | template=template, 641 | label_word_list=label_word_list, 642 | first_sent_limit=self.args.first_sent_limit, 643 | other_sent_limit=self.args.other_sent_limit, 644 | truncate_head=self.args.truncate_head, 645 | gpt3=self.args.gpt3_in_context_head or self.args.gpt3_in_context_tail, 646 | support_labels=None if not (self.args.gpt3_in_context_head or self.args.gpt3_in_context_tail) else support_labels 647 | ) 648 | features = OurInputFeatures(**inputs, label=example_label) 649 | 650 | if verbose: 651 | logger.info("*** Example ***") 652 | logger.info("guid: %s" % (example.guid)) 653 | logger.info("features: %s" % features) 654 | logger.info("text: %s" % self.tokenizer.decode(features.input_ids)) 655 | 656 | return features 657 | 658 | 659 | 660 | -------------------------------------------------------------------------------- /src/kernel_solvers.py: -------------------------------------------------------------------------------- 1 | from transformers.utils import logging 2 | import torch 3 | import torch.nn.functional as F 4 | from sklearn.svm import SVC, SVR 5 | from sklearn.linear_model import LogisticRegressionCV 6 | logger = logging.get_logger(__name__) 7 | 8 | class BaseKernelSolver: 9 | def __init__(self, args): 10 | self.args = args 11 | 12 | # initialized with fit() 13 | self.num_labels = None 14 | self.kernel_dtype = None 15 | 16 | def get_regularized_kernel(self, kernel): 17 | if self.args.kernel_regularization: 18 | _, S, _ = torch.svd(kernel) 19 | op_norm = S.max() 20 | 21 | reg = (self.args.kernel_regularization * op_norm).to(kernel.dtype) 22 | 23 | identity = torch.eye(kernel.shape[1], dtype=kernel.dtype).unsqueeze(0) 24 | return kernel + identity * reg 25 | else: 26 | return kernel 27 | 28 | def metrics(self): 29 | return {} 30 | 31 | def get_target_coords(self, 
targets): 
32 | if self.num_labels == 1: # Regression 
33 | return targets.flatten().to(self.kernel_dtype).unsqueeze(1) 
34 | else: 
35 | return torch.nn.functional.one_hot(targets.flatten(), self.num_labels).to(self.kernel_dtype) 
36 | 
37 | def loss(self, preds, targets): 
38 | targets_coords = self.get_target_coords(targets) 
39 | return 1/targets_coords.shape[0] * ((preds - targets_coords)**2).sum().item() 
40 | 
41 | def fit(self, train_kernel, train_targets, train_logits=None): 
42 | raise NotImplementedError("BaseKernelSolver is just the abstract base class") 
43 | 
44 | def predict(self, eval_kernel, eval_targets, eval_logits=None): 
45 | raise NotImplementedError("BaseKernelSolver is just the abstract base class") 
46 | 
47 | 
48 | class LstsqKernelSolver(BaseKernelSolver): 
49 | def __init__(self, args): 
50 | super().__init__(args) 
51 | 
52 | # initialized with fit() 
53 | self.kernel_solution = None 
54 | self.residuals = None 
55 | self.rank = None 
56 | 
57 | def metrics(self): 
58 | metrics_dict = super().metrics() 
59 | if self.rank is not None: 
60 | if self.rank.numel() > 1: 
61 | for i,r in enumerate(self.rank.tolist()): 
62 | metrics_dict["rank{}".format(i)] = r 
63 | else: 
64 | metrics_dict["rank0"] = self.rank.item() 
65 | return metrics_dict 
66 | 
67 | def fit(self, train_kernel, train_targets, train_logits=None): 
68 | self.num_labels = train_kernel.size(0) 
69 | self.kernel_dtype = train_kernel.dtype 
70 | 
71 | kernel = self.get_regularized_kernel(train_kernel) 
72 | train_targets_coords = self.get_target_coords(train_targets) 
73 | 
74 | if train_logits is not None and self.args.f0_scaling > 0: 
75 | train_targets_coords -= train_logits / self.args.f0_scaling 
76 | 
77 | self.kernel_solution, self.residuals, self.rank, _ = torch.linalg.lstsq(kernel, train_targets_coords.t()) 
78 | 
79 | def predict(self, eval_kernel, eval_targets, eval_logits=None, **unused_kwargs): 
80 | assert self.kernel_solution is not None, "Must call fit() before predict()" 
81 | assert eval_kernel.size(0) == self.num_labels, "Number of labels in eval_kernel must match fit()" 
82 | 
83 | eval_preds = torch.bmm( 
84 | eval_kernel.transpose(1, 2), 
85 | self.kernel_solution.unsqueeze(2) 
86 | ).squeeze(2).transpose(0, 1) # shape [#dataset_outer, #classes] 
87 | 
88 | if eval_logits is not None and self.args.f0_scaling > 0: 
89 | eval_preds += eval_logits / self.args.f0_scaling 
90 | 
91 | eval_loss = self.loss(eval_preds, eval_targets) 
92 | return eval_loss, eval_preds 
93 | 
94 | 
95 | class AsymmetricLstsqKernelSolver(LstsqKernelSolver): 
96 | def __init__(self, args): 
97 | super().__init__(args) 
98 | 
99 | self.N = None 
100 | self.train_targets = None 
101 | 
102 | def fit(self, train_kernel, train_targets, train_logits=None): 
103 | self.num_labels = train_kernel.size(0) 
104 | self.kernel_dtype = train_kernel.dtype 
105 | assert self.num_labels == 1, "AsymmetricLstsqKernelSolver only works for regression or binary classification tasks" 
106 | 
107 | kernel = self.get_regularized_kernel(train_kernel) 
108 | train_targets_coords = self.get_target_coords(train_targets) 
109 | 
110 | if train_logits is not None and self.args.f0_scaling > 0: 
111 | train_targets_coords -= train_logits / self.args.f0_scaling 
112 | 
113 | kernel = kernel.squeeze() 
114 | H = torch.zeros(kernel.shape) 
115 | Y = train_targets_coords.squeeze() 
116 | N = H.shape[0] 
117 | 
118 | for i in range(N): 
119 | for j in range(N): 
120 | H[i,j] = Y[i] * (kernel[i,j]* Y[j]) 
121 | 
122 | # # system with biases 
123 | # A = torch.zeros(2*N + 2, 2*N + 2, dtype=self.kernel_dtype) 
124 | # A[0, 2:2+N] = Y 
125 | # A[1, 
2+N:] = Y 
126 | # A[2:2+N, 0] = Y 
127 | # A[2+N:, 1] = Y 
128 | # A[2:2+N, 2:2+N] = torch.eye(N) / self.args.kernel_gamma # scale by 1/gamma later 
129 | # A[2:2+N, 2+N:] = H 
130 | # A[2+N:, 2:2+N] = H.T 
131 | # A[2+N:, 2+N:] = torch.eye(N) / self.args.kernel_gamma # scale by 1/gamma later 
132 | 
133 | # B = torch.zeros(2*N+2, dtype=self.kernel_dtype) 
134 | # B[2:] = 1 
135 | 
136 | # system without biases 
137 | A = torch.zeros(2*N, 2*N, dtype=self.kernel_dtype) 
138 | A[:N, :N] = torch.eye(N) / self.args.kernel_gamma # scale by 1/gamma later 
139 | A[:N, N:] = H 
140 | A[N:, :N] = H.T 
141 | A[N:, N:] = torch.eye(N) / self.args.kernel_gamma # scale by 1/gamma later 
142 | B = torch.ones(2*N, dtype=self.kernel_dtype) 
143 | 
144 | self.N = N 
145 | self.Y = Y 
146 | self.kernel_solution, self.residuals, self.rank, _ = torch.linalg.lstsq(A, B) 
147 | 
148 | def predict(self, eval_kernel, eval_targets, eval_logits=None, eval_kernel_flipped=None, **unused_kwargs): 
149 | assert self.kernel_solution is not None, "Must call fit() before predict()" 
150 | assert eval_kernel.size(0) == self.num_labels, "Number of labels in eval_kernel must match fit()" 
151 | 
152 | N = self.N 
153 | # beta_bias = self.kernel_solution[0] 
154 | # alpha_bias = self.kernel_solution[1] 
155 | # alpha = self.kernel_solution[2:2+N].unsqueeze(0) 
156 | # beta = self.kernel_solution[2+N:].unsqueeze(0) 
157 | alpha = self.kernel_solution[:N].unsqueeze(0) 
158 | beta = self.kernel_solution[N:].unsqueeze(0) 
159 | 
160 | omega = torch.bmm( 
161 | eval_kernel.transpose(1, 2), 
162 | (alpha*self.Y).unsqueeze(2) 
163 | ).squeeze(2).transpose(0, 1) #+ alpha_bias 
164 | nu = torch.bmm( 
165 | eval_kernel_flipped.transpose(1, 2), 
166 | (beta*self.Y).unsqueeze(2) 
167 | ).squeeze(2).transpose(0, 1) #+ beta_bias 
168 | 
169 | eval_preds = (self.args.kernel_lambda * omega + (1-self.args.kernel_lambda) * nu) 
170 | 
171 | if eval_logits is not None and self.args.f0_scaling > 0: 
172 | eval_preds += eval_logits / self.args.f0_scaling 
173 | 
174 | eval_loss = self.loss(eval_preds, eval_targets) 
175 | return eval_loss, eval_preds 
176 | 
177 | 
178 | class SVRKernelSolver(BaseKernelSolver): 
179 | def __init__(self, args): 
180 | super().__init__(args) 
181 | 
182 | self.svms = None 
183 | 
184 | def fit(self, train_kernel, train_targets, train_logits=None): 
185 | self.num_labels = train_kernel.size(0) 
186 | self.kernel_dtype = train_kernel.dtype 
187 | 
188 | kernel = self.get_regularized_kernel(train_kernel) 
189 | train_targets_coords = self.get_target_coords(train_targets) 
190 | 
191 | if train_logits is not None and self.args.f0_scaling > 0: 
192 | train_targets_coords -= train_logits / self.args.f0_scaling 
193 | 
194 | self.svms = [] 
195 | for k in range(self.num_labels): 
196 | svm = SVR(kernel='precomputed') 
197 | svm.fit(kernel[k].cpu().numpy(), train_targets_coords[:,k].cpu().t().numpy()) 
198 | self.svms.append(svm) 
199 | 
200 | def predict(self, eval_kernel, eval_targets, eval_logits=None, **unused_kwargs): 
201 | assert self.svms is not None, "Must call fit() before predict()" 
202 | assert eval_kernel.size(0) == self.num_labels, "Number of labels in eval_kernel must match fit()" 
203 | 
204 | eval_preds = [] 
205 | for k in range(self.num_labels): 
206 | predict_k = self.svms[k].predict(eval_kernel[k].cpu().t().numpy()) 
207 | eval_preds.append(torch.tensor(predict_k, dtype=self.kernel_dtype, device=eval_kernel.device)) 
208 | eval_preds = torch.stack(eval_preds, dim=1) 
209 | 
210 | 
211 | if eval_logits is not None and self.args.f0_scaling > 0: 
212 | 
eval_preds += eval_logits / self.args.f0_scaling 
213 | 
214 | eval_loss = self.loss(eval_preds, eval_targets) 
215 | return eval_loss, eval_preds 
216 | 
217 | 
218 | class SVCKernelSolver(BaseKernelSolver): 
219 | def __init__(self, args): 
220 | super().__init__(args) 
221 | 
222 | self.svms = None 
223 | 
224 | def fit(self, train_kernel, train_targets, train_logits=None): 
225 | self.num_labels = train_kernel.size(0) 
226 | self.kernel_dtype = train_kernel.dtype 
227 | assert self.num_labels == 1, "SVCKernelSolver only works for binary classification" 
228 | assert train_logits is None, "SVCKernelSolver does not support train_logits" 
229 | 
230 | kernel = self.get_regularized_kernel(train_kernel) 
231 | train_targets = ((train_targets + 1) / 2).int() # convert back from {-1, 1} to {0, 1} 
232 | 
233 | self.svms = [] 
234 | for k in range(self.num_labels): 
235 | svm = SVC(kernel='precomputed') 
236 | svm.fit(kernel[k].cpu().numpy(), train_targets.cpu().numpy()) 
237 | self.svms.append(svm) 
238 | 
239 | def predict(self, eval_kernel, eval_targets, eval_logits=None, **unused_kwargs): 
240 | assert self.svms is not None, "Must call fit() before predict()" 
241 | assert eval_kernel.size(0) == self.num_labels, "Number of labels in eval_kernel must match fit()" 
242 | assert eval_logits is None, "SVCKernelSolver does not support eval_logits" 
243 | 
244 | eval_preds = [] 
245 | for k in range(self.num_labels): 
246 | predict_k = self.svms[k].predict(eval_kernel[k].cpu().t().numpy()) 
247 | eval_preds.append(torch.tensor(predict_k, dtype=self.kernel_dtype, device=eval_kernel.device)) 
248 | eval_preds = torch.stack(eval_preds, dim=1) 
249 | 
250 | eval_preds = (eval_preds * 2 - 1) # convert back from {0, 1} to {-1, 1} 
251 | 
252 | eval_loss = self.loss(eval_preds, eval_targets) 
253 | return eval_loss, eval_preds 
254 | 
255 | 
256 | class LogisticKernelSolver(BaseKernelSolver): 
257 | def __init__(self, args): 
258 | super().__init__(args) 
259 | 
260 | self.logistic_model = None 
261 | 
262 | def fit(self, train_kernel, train_targets, train_logits=None): 
263 | self.num_labels = train_kernel.size(0) 
264 | self.kernel_dtype = train_kernel.dtype 
265 | assert self.num_labels == 1, "LogisticKernelSolver only works for binary classification" 
266 | 
267 | kernel = self.get_regularized_kernel(train_kernel).squeeze(0) 
268 | train_targets = ((train_targets + 1) / 2).int() # convert back from {-1, 1} to {0, 1} 
269 | 
270 | self.logistic_model = LogisticRegressionCV(max_iter=10000, random_state=0) 
271 | self.logistic_model.fit(kernel.cpu().numpy(), train_targets.cpu().numpy()) 
272 | 
273 | def predict(self, eval_kernel, eval_targets, eval_logits=None, **unused_kwargs): 
274 | assert self.logistic_model is not None, "Must call fit() before predict()" 
275 | assert eval_kernel.size(0) == self.num_labels, "Number of labels in eval_kernel must match fit()" 
276 | 
277 | log_proba = self.logistic_model.predict_log_proba(eval_kernel.cpu().squeeze().t().numpy()) 
278 | log_proba = torch.tensor(log_proba, dtype=self.kernel_dtype, device=eval_kernel.device) 
279 | 
280 | eval_loss = self.loss(log_proba, eval_targets) 
281 | 
282 | eval_preds = (log_proba[:,1] - log_proba[:,0]).unsqueeze(1) 
283 | 
284 | return eval_loss, eval_preds 
285 | 
286 | def loss(self, preds, targets): 
287 | targets = ((targets + 1) / 2).long() # convert back from {-1, 1} to {0, 1} 
288 | return F.cross_entropy(preds, targets).item() 
289 | 
290 | 
291 | 
292 | SOLVERS = { 
293 | "lstsq": LstsqKernelSolver, 
294 | "svr": SVRKernelSolver, 
295 | "svc": SVCKernelSolver, 
296 | "asym": 
AsymmetricLstsqKernelSolver, 
297 | "logistic": LogisticKernelSolver 
298 | } 
299 | 
-------------------------------------------------------------------------------- /src/kernel_trainer.py: -------------------------------------------------------------------------------- 
1 | ########## The following part was originally copied from Transformers' trainer (3.4.0) and then changed heavily to compute eNTKs. ########## 
2 | 
3 | # coding=utf-8 
4 | # Copyright 2020-present the HuggingFace Inc. team. 
5 | # 
6 | # Licensed under the Apache License, Version 2.0 (the "License"); 
7 | # you may not use this file except in compliance with the License. 
8 | # You may obtain a copy of the License at 
9 | # 
10 | # http://www.apache.org/licenses/LICENSE-2.0 
11 | # 
12 | # Unless required by applicable law or agreed to in writing, software 
13 | # distributed under the License is distributed on an "AS IS" BASIS, 
14 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
15 | # See the License for the specific language governing permissions and 
16 | # limitations under the License. 
17 | """ 
18 | The trainer for computing eNTKs 
19 | """ 
20 | 
21 | import os 
22 | from typing import Any, Callable, Dict, List, Optional, Tuple, Union 
23 | 
24 | import torch 
25 | import torch.nn as nn 
26 | from torch.utils.data.dataloader import DataLoader 
27 | from torch.utils.data.dataset import Dataset 
28 | from torch.utils.data.sampler import SequentialSampler 
29 | 
30 | from functorch import vmap, jvp, jacrev, make_functional_with_buffers 
31 | 
32 | import transformers 
33 | from transformers.data.data_collator import DataCollator 
34 | from transformers.file_utils import is_torch_tpu_available 
35 | 
36 | from transformers.modeling_utils import PreTrainedModel 
37 | from transformers.training_args import TrainingArguments 
38 | from transformers.trainer import SequentialDistributedSampler 
39 | from transformers.trainer_utils import PredictionOutput, EvalPrediction 
40 | from transformers.utils import logging 
41 | if is_torch_tpu_available(): 
42 | import torch_xla.core.xla_model as xm 
43 | import gc 
44 | from transformers.trainer_utils import TrainOutput 
45 | 
46 | from src.linearhead_trainer import varsize_tensor_all_gather, LinearHeadTrainer 
47 | 
48 | import numpy as np 
49 | from tqdm import tqdm 
50 | from src.kernel_solvers import SOLVERS 
51 | 
52 | logger = logging.get_logger(__name__) 
53 | 
54 | 
55 | class LogitModelWrapper(nn.Module): 
56 | def __init__(self, model, binary_classification): 
57 | super().__init__() 
58 | self.model = model 
59 | self.binary_classification = binary_classification 
60 | 
61 | def forward(self, input_ids, attention_mask, mask_pos): 
62 | logits = self.model(input_ids, attention_mask, mask_pos=mask_pos)[0] # don't provide labels 
63 | if self.binary_classification: 
64 | assert logits.size(1) == 2, "--binary_classification should have 2 logits" 
65 | logits = (logits[:,1] - logits[:,0]).unsqueeze(-1) 
66 | return logits 
67 | 
68 | 
69 | 
70 | def param_to_buffer(module, module_name, predicate): 
71 | """Recursively turns the parameters selected by `predicate` into buffers (so they are not differentiated).""" 
72 | modules = module.named_modules(prefix=str(module_name)) 
73 | next(modules) # Skip itself 
74 | 
75 | params = [] 
76 | for name, param in module.named_parameters(recurse=False, prefix=str(module_name)): 
77 | if predicate(name): 
78 | params.append((name.split(".")[-1], param)) 
79 | 
80 | for name, param in params: 
81 | delattr(module, name) # Unregister parameter 
82 | module.register_buffer(name, 
param) 
83 | for name, module in modules: 
84 | param_to_buffer(module, name, predicate) 
85 | 
86 | 
87 | class KernelTrainerFunc(LinearHeadTrainer): 
88 | """ 
89 | Adding some functions based on Transformers' Trainer class. 
90 | """ 
91 | 
92 | def __init__( 
93 | self, 
94 | model: PreTrainedModel, 
95 | args: TrainingArguments, 
96 | data_collator: Optional[DataCollator] = None, 
97 | train_dataset: Optional[Dataset] = None, 
98 | eval_dataset: Optional[Dataset] = None, 
99 | *posargs, 
100 | **kwargs 
101 | ): 
102 | super().__init__(model, args, data_collator, train_dataset, eval_dataset, *posargs, **kwargs) 
103 | 
104 | self.grad_dim = None 
105 | self.train_train_kernel = None 
106 | self.train_targets = None 
107 | self.num_labels = None 
108 | 
109 | self.kernel_formula = args.kernel_formula 
110 | 
111 | def convert_to_buffer(name): 
112 | if args.exclude_embeddings: 
113 | if "embed" in name: 
114 | logger.info("Excluding {}".format(name)) 
115 | return True 
116 | 
117 | if args.exclude_head: 
118 | if "head" in name: 
119 | logger.info("Excluding {}".format(name)) 
120 | return True 
121 | 
122 | if args.only_biases: 
123 | if "bias" not in name: 
124 | logger.info("Excluding {}".format(name)) 
125 | return True 
126 | 
127 | if model.model_args.apply_lora: 
128 | if name.startswith('roberta') and "lora" not in name: 
129 | logger.info("Excluding {}".format(name)) 
130 | return True 
131 | return False 
132 | 
133 | param_to_buffer(self.model, "", convert_to_buffer) 
134 | 
135 | 
136 | def get_unshuffled_dataloader(self, dataset: Optional[Dataset] = None, sharded: bool = False, batch_size: Optional[int] = -1): 
137 | if sharded and is_torch_tpu_available(): 
138 | sampler = SequentialDistributedSampler( 
139 | dataset, num_replicas=xm.xrt_world_size(), rank=xm.get_ordinal() 
140 | ) 
141 | elif sharded and self.args.local_rank != -1: 
142 | sampler = SequentialDistributedSampler(dataset) 
143 | else: 
144 | sampler = SequentialSampler(dataset) 
145 | 
146 | bs = self.args.per_device_eval_batch_size if batch_size == -1 else batch_size 
147 | data_loader = DataLoader( 
148 | dataset, 
149 | sampler=sampler, 
150 | batch_size=bs, 
151 | collate_fn=self.data_collator, 
152 | drop_last=self.args.dataloader_drop_last, 
153 | ) 
154 | 
155 | return data_loader 
156 | 
157 | def profile_memory(self): 
158 | import gc 
159 | for obj in gc.get_objects(): 
160 | try: 
161 | if torch.is_tensor(obj) or (hasattr(obj, 'data') and torch.is_tensor(obj.data)): 
162 | print(type(obj), obj.size()) 
163 | except Exception: 
164 | pass 
165 | 
166 | def compute_kernel_inner(self, curried_fn, curried_jacobian_fn, grads_outer, dataset_inner): 
167 | 
168 | # Use per_device_eval_batch_size for the inner loop (which iterates over the dataset being evaluated) 
169 | dataloader_inner = self.get_unshuffled_dataloader(dataset_inner, sharded=False, batch_size=self.args.per_device_eval_batch_size) 
170 | kernel_blocks = [] 
171 | targets_inner = [] 
172 | 
173 | for inputs_inner in tqdm(dataloader_inner, desc="Computing kernel inner"): 
174 | 
175 | for k, v in inputs_inner.items(): 
176 | if isinstance(v, torch.Tensor): 
177 | inputs_inner[k] = v.to(self.args.device) 
178 | 
179 | def get_ntk_slice(tangents): 
180 | _, jvps = curried_fn(inputs_inner.get("input_ids"), inputs_inner.get("attention_mask"), inputs_inner.get("mask_pos"), tangents) 
181 | return jvps 
182 | 
183 | if self.args.kernel_formula == "signgd": 
184 | grads_inner = curried_jacobian_fn(inputs_inner.get("input_ids"), inputs_inner.get("attention_mask"), inputs_inner.get("mask_pos")) 
185 | block = sum(torch.einsum('olw,ikw->olik', 
j1.sign().flatten(2).to(torch.float64), j2.sign().flatten(2).to(torch.float64)).cpu() for j1, j2 in zip(grads_outer, grads_inner)) 186 | else: 187 | block = vmap(vmap(get_ntk_slice))(grads_outer).to(torch.float64).cpu() # N_outer x C_outer x N_inner x C_inner 188 | 189 | kernel_blocks.append(block.detach()) 190 | label = inputs_inner.get("labels") 191 | if self.args.binary_classification: 192 | label = (label * 2 - 1).float() # convert from {0, 1} to {-1, 1} 193 | targets_inner.append(label) 194 | 195 | # del grads_inner 196 | del block 197 | del inputs_inner 198 | 199 | torch.cuda.empty_cache() 200 | gc.collect() 201 | 202 | return ( 203 | torch.cat(kernel_blocks, dim=2) if kernel_blocks else torch.tensor([]), 204 | torch.cat(targets_inner, dim=0) if targets_inner else torch.tensor([]) 205 | ) 206 | 207 | def compute_kernel_outer(self, dataset_outer, dataset_inner): 208 | # Use train_batch_size for outer loop (which is always the training dataset) 209 | dataloader_outer = self.get_unshuffled_dataloader(dataset_outer, sharded=True, batch_size=self.args.per_device_train_batch_size) 210 | 211 | model_wrapper = LogitModelWrapper(self.model, self.args.binary_classification) 212 | model_wrapper.eval() 213 | for param in model_wrapper.parameters(): 214 | param.requires_grad_(True) 215 | 216 | model_fn, params, buffers = make_functional_with_buffers(model_wrapper) 217 | 218 | jacobian_fn = jacrev(model_fn) 219 | 220 | def curried_jacobian_fn(input_ids, attention_mask, mask_pos): 221 | return jacobian_fn(params, buffers, input_ids, attention_mask, mask_pos) 222 | 223 | def curried_fn(input_ids, attention_mask, mask_pos, tangent): 224 | def curried_model_fn(params_): 225 | return model_fn(params_, buffers, input_ids, attention_mask, mask_pos) 226 | return jvp(curried_model_fn, (params,), (tangent,)) 227 | 228 | kernel_rows = [] 229 | 230 | inner_targets = None 231 | 232 | for inputs_outer in tqdm(dataloader_outer, desc="Computing kernel outer"): 233 | for k, v in inputs_outer.items(): 234 | if isinstance(v, torch.Tensor): 235 | inputs_outer[k] = v.to(self.args.device) 236 | 237 | grads_outer = curried_jacobian_fn(inputs_outer.get("input_ids"), inputs_outer.get("attention_mask"), inputs_outer.get("mask_pos")) 238 | if self.args.kernel_formula == 'asymmetric_signgd': 239 | grads_outer = tuple(g.sign() for g in grads_outer) 240 | 241 | # assert len(tuple(model_wrapper.model.named_parameters())) == len(grads_outer) 242 | 243 | # for (name,param), grad in zip(model_wrapper.named_parameters(), grads_outer): 244 | # print(name, grad[(grad != 0).all(-1)]) 245 | # assert param.shape == grad.shape[2:], f"{name} {param.shape} {grad[2:].shape}" 246 | # if (grad == 0).all(): 247 | # print(name, param.numel()) 248 | 249 | if self.grad_dim is None: 250 | self.grad_dim = sum(np.prod(x.shape[2:]) for x in grads_outer) 251 | # assert self.grad_dim == num_params, "gradient dim not constant: {} and {}".format(self.grad_dim, num_params) 252 | 253 | kernel_blocks, inner_targets = ( 254 | self.compute_kernel_inner(curried_fn, curried_jacobian_fn, grads_outer, dataset_inner)) 255 | 256 | kernel_rows.append(kernel_blocks) 257 | 258 | del grads_outer 259 | del inputs_outer 260 | del kernel_blocks 261 | 262 | torch.cuda.empty_cache() 263 | gc.collect() 264 | 265 | 266 | kernel = torch.cat(kernel_rows, dim=0) 267 | 268 | return ( 269 | kernel, 270 | inner_targets 271 | ) 272 | 273 | def compute_kernel_sharded(self, dataset_outer, dataset_inner): 274 | assert self.kernel_formula in ["sgd", "asymmetric_signgd", "signgd"], "only 
sgd, signgd, and asymmetric_signgd are supported by functorch for now" 
275 | 
276 | with torch.no_grad(): 
277 | kernel, inner_targets = self.compute_kernel_outer(dataset_outer, dataset_inner) 
278 | 
279 | if self.args.local_rank != -1: 
280 | logger.info("Starting to gather kernel across GPUs") 
281 | kernel = varsize_tensor_all_gather(kernel.to(self.args.device), torch.distributed.get_world_size()) 
282 | logger.info("Finished gathering kernel across GPUs") 
283 | 
284 | return kernel, inner_targets 
285 | 
286 | def compute_model_logits_cached(self, eval_dataset): 
287 | if self.args.load_kernels is not None: 
288 | output_dir = self.args.load_kernels 
289 | else: 
290 | output_dir = self.args.output_dir 
291 | logit_file_name = f"{eval_dataset.mode}_logits_{eval_dataset.task_name}.pt" 
292 | logit_path = os.path.join(output_dir, logit_file_name) 
293 | 
294 | if os.path.exists(logit_path) and not self.args.overwrite_kernels: 
295 | logger.info(f"Starting to load logits from {logit_path}.") 
296 | logits, targets = torch.load(logit_path) 
297 | logger.info(f"Finished loading logits from {logit_path}.") 
298 | else: 
299 | logger.info(f"Starting to compute the {eval_dataset.mode} logits.") 
300 | dataloader = self.get_unshuffled_dataloader(eval_dataset) 
301 | 
302 | model_wrapper = LogitModelWrapper(self.model, self.args.binary_classification) 
303 | model_wrapper.eval() 
304 | 
305 | logits = [] 
306 | targets = [] 
307 | with torch.no_grad(): 
308 | for inputs in dataloader: 
309 | for k, v in inputs.items(): 
310 | if isinstance(v, torch.Tensor): 
311 | inputs[k] = v.to(self.args.device) 
312 | 
313 | label = inputs.get("labels") 
314 | if self.args.binary_classification: 
315 | label = (label * 2 - 1).float() # convert from {0, 1} to {-1, 1} 
316 | 
317 | preds = model_wrapper(inputs.get("input_ids"), inputs.get("attention_mask"), inputs.get("mask_pos")) 
318 | logits.append(preds.detach().cpu()) 
319 | targets.append(label.cpu()) 
320 | 
321 | logits = torch.cat(logits, dim=0) 
322 | targets = torch.cat(targets, dim=0) 
323 | 
324 | logger.info(f"Finished computing the {eval_dataset.mode} logits.") 
325 | 
326 | if self.is_world_process_zero(): 
327 | torch.save((logits, targets), logit_path) 
328 | return logits, targets 
329 | 
330 | def reshape_kernel_and_targets(self, kernel, targets): 
331 | # reshape kernel to previous format 
332 | if self.num_labels is None: 
333 | self.num_labels = kernel.shape[1] 
334 | assert self.num_labels == kernel.shape[1], "label dim not constant: {} and {}".format(self.num_labels, kernel.shape[1]) 
335 | assert self.num_labels == kernel.shape[3], "label dim not constant: {} and {}".format(self.num_labels, kernel.shape[3]) 
336 | 
337 | if self.num_labels > 1: # multi logit 
338 | targets = torch.nn.functional.one_hot(targets.squeeze(), self.num_labels) 
339 | 
340 | size1 = kernel.shape[0] * kernel.shape[1] 
341 | size2 = kernel.shape[2] * kernel.shape[3] 
342 | # kernel = kernel.transpose(0, 1).transpose(2, 3) 
343 | return kernel.reshape(1, size1, size2), targets.reshape(-1) 
344 | 
345 | def compute_kernel_cached(self, eval_dataset): 
346 | kernel_file_name = f"{eval_dataset.mode}_kernels_{eval_dataset.task_name}.pt" 
347 | kernel_path = os.path.join(self.args.output_dir, kernel_file_name) 
348 | 
349 | if os.path.exists(kernel_path) and not self.args.overwrite_kernels: 
350 | logger.info(f"Starting to load kernels from {kernel_path}.") 
351 | (train_eval_kernel, eval_targets) = torch.load(kernel_path) 
352 | logger.info(f"Finished loading kernels from {kernel_path}.") 
353 | else: 
354 | logger.info(f"Starting to compute the 
train-{eval_dataset.mode} kernel.") 
355 | train_eval_kernel, eval_targets = self.compute_kernel_sharded( 
356 | self.train_dataset, eval_dataset, 
357 | ) 
358 | logger.info(f"Finished computing the train-{eval_dataset.mode} kernel.") 
359 | 
360 | train_eval_kernel = train_eval_kernel.cpu() 
361 | eval_targets = eval_targets.cpu() 
362 | 
363 | if self.args.kernel_formula == 'asymmetric_signgd': 
364 | logger.info(f"Starting to compute the flipped train-{eval_dataset.mode} kernel.") 
365 | if eval_dataset == self.train_dataset: 
366 | train_eval_kernel_flipped = train_eval_kernel 
367 | else: 
368 | train_eval_kernel_flipped, _ = self.compute_kernel_sharded( 
369 | eval_dataset, self.train_dataset, 
370 | ) 
371 | logger.info(f"Finished computing the flipped train-{eval_dataset.mode} kernel.") 
372 | 
373 | train_eval_kernel_flipped = train_eval_kernel_flipped.cpu() 
374 | train_eval_kernel_flipped = train_eval_kernel_flipped.permute(2, 3, 0, 1) 
375 | train_eval_kernel = torch.stack([train_eval_kernel, train_eval_kernel_flipped], dim=0) 
376 | 
377 | if self.is_world_process_zero(): 
378 | torch.save((train_eval_kernel, eval_targets), kernel_path) 
379 | return train_eval_kernel, eval_targets 
380 | 
381 | 
382 | def train(self, model_path=None, dev_objective=None): 
383 | if self.args.from_linearhead and model_path is None: 
384 | super().train(model_path, dev_objective) # Train output layer using LinearHeadTrainer 
385 | 
386 | if self.args.load_kernels is None: 
387 | eval_dataset = self.train_dataset 
388 | self.train_train_kernel, self.train_targets = self.compute_kernel_cached(eval_dataset) 
389 | 
390 | return TrainOutput(0, 0.0, {}), None 
391 | 
392 | def evaluate(self, eval_dataset: Optional[Dataset] = None) -> PredictionOutput: 
393 | if eval_dataset is None: 
394 | eval_dataset = self.eval_dataset 
395 | 
396 | if self.args.load_kernels is not None: 
397 | logger.info(f"Starting to load kernels from {self.args.load_kernels}.") 
398 | kernel_file_name = f"{eval_dataset.mode}_kernels_{eval_dataset.task_name}.pt" 
399 | load_kernel_path = os.path.join(self.args.load_kernels, kernel_file_name) 
400 | (train_eval_kernel, eval_targets) = torch.load(load_kernel_path) 
401 | 
402 | kernel_file_name = f"{self.train_dataset.mode}_kernels_{self.train_dataset.task_name}.pt" 
403 | load_kernel_path = os.path.join(self.args.load_kernels, kernel_file_name) 
404 | (self.train_train_kernel, self.train_targets) = torch.load(load_kernel_path) 
405 | logger.info(f"Finished loading kernels from {self.args.load_kernels}.") 
406 | else: 
407 | assert self.train_train_kernel is not None, "train_train_kernel is None, did you forget to call train()?" 
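# Shape note (a sketch; see reshape_kernel_and_targets above): kernels are computed as
# [N_train, C, N_eval, C] eNTK blocks (with the flipped kernel stacked along dim 0 for
# asymmetric_signgd) and are flattened to [1, N_train*C, N_eval*C] before being handed
# to the solver.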
408 | train_eval_kernel, eval_targets = self.compute_kernel_cached(eval_dataset) 409 | 410 | if self.args.kernel_formula == 'asymmetric_signgd': 411 | train_eval_kernel_flipped = train_eval_kernel[1] 412 | train_eval_kernel_flipped, _ = self.reshape_kernel_and_targets(train_eval_kernel_flipped, eval_targets) 413 | 414 | train_eval_kernel = train_eval_kernel[0] 415 | train_train_kernel = self.train_train_kernel[0] 416 | else: 417 | train_eval_kernel_flipped = None 418 | train_train_kernel = self.train_train_kernel 419 | 420 | train_train_kernel, train_targets = self.reshape_kernel_and_targets(train_train_kernel, self.train_targets) 421 | train_eval_kernel, eval_targets = self.reshape_kernel_and_targets(train_eval_kernel, eval_targets) 422 | 423 | # get train and test logits 424 | if self.args.adjust_for_init: 425 | train_logits, _ = self.compute_model_logits_cached(self.train_dataset) 426 | eval_logits, _ = self.compute_model_logits_cached(eval_dataset) 427 | train_logits = train_logits.reshape(-1, 1) 428 | eval_logits = eval_logits.reshape(-1, 1) 429 | else: 430 | train_logits, eval_logits = None, None 431 | 432 | metrics = {} 433 | 434 | solver = SOLVERS[self.args.kernel_solver](self.args) 435 | solver.fit(train_train_kernel, train_targets, train_logits) 436 | eval_error, eval_preds = solver.predict(train_eval_kernel, eval_targets, eval_logits, eval_kernel_flipped=train_eval_kernel_flipped) 437 | eval_preds = eval_preds.reshape(-1, self.num_labels) 438 | if self.num_labels > 1: 439 | eval_targets = eval_targets.reshape(-1, self.num_labels).argmax(-1) 440 | 441 | if self.args.binary_classification: # Make sure to compute loss before this transformation! 442 | eval_preds = torch.cat([-eval_preds, eval_preds], dim=-1) # convert back to two logits 443 | eval_targets = ((eval_targets + 1) / 2).long() # convert back from {-1, 1} to {0, 1} 444 | 445 | if self.compute_metrics is not None: 446 | metrics = self.compute_metrics(EvalPrediction(predictions=eval_preds.numpy(), label_ids=eval_targets.numpy())) 447 | 448 | # Prefix all keys with eval_ 449 | for key in list(metrics.keys()): 450 | if not key.startswith("eval_"): 451 | metrics[f"eval_{key}"] = metrics.pop(key) 452 | metrics["eval_loss"] = eval_error 453 | 454 | metrics.update(solver.metrics()) 455 | metrics["grad_dim"] = self.grad_dim 456 | 457 | output = PredictionOutput(predictions=eval_preds.numpy(), label_ids=eval_targets.numpy(), metrics=metrics) 458 | self.log(output.metrics) 459 | 460 | return output 461 | -------------------------------------------------------------------------------- /src/linearhead_trainer.py: -------------------------------------------------------------------------------- 1 | ########## The following part was originally copied from Transformers' trainer (3.4.0) and then changed heavily for linear head probing. ########## 2 | 3 | # coding=utf-8 4 | # Copyright 2020-present the HuggingFace Inc. team. 5 | # 6 | # Licensed under the Apache License, Version 2.0 (the "License"); 7 | # you may not use this file except in compliance with the License. 8 | # You may obtain a copy of the License at 9 | # 10 | # http://www.apache.org/licenses/LICENSE-2.0 11 | # 12 | # Unless required by applicable law or agreed to in writing, software 13 | # distributed under the License is distributed on an "AS IS" BASIS, 14 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 15 | # See the License for the specific language governing permissions and 16 | # limitations under the License. 
17 | """ 18 | A trainer for finding linear probing solutions 19 | """ 20 | 21 | import collections 22 | from src.models import ModelForPromptFinetuning 23 | import torch 24 | from typing import Any, Callable, Dict, List, Optional, Tuple, Union 25 | from sklearn.linear_model import LinearRegression, LogisticRegression, LogisticRegressionCV 26 | 27 | import transformers 28 | from torch.utils.data.dataset import Dataset 29 | from transformers.trainer_utils import TrainOutput 30 | from transformers.utils import logging 31 | 32 | logger = logging.get_logger(__name__) 33 | 34 | 35 | def tensor_all_gather(tensor: torch.Tensor, distributed_world_size: int): 36 | tensor_list = [torch.zeros_like(tensor) for _ in range(distributed_world_size)] 37 | torch.distributed.all_gather(tensor_list=tensor_list, tensor=tensor) 38 | return torch.cat(tensor_list, dim=0) 39 | 40 | 41 | def varsize_tensor_all_gather(tensor: torch.Tensor, distributed_world_size: int): 42 | tensor = tensor.contiguous() 43 | 44 | dim_tensor = torch.tensor([tensor.size(0)], dtype=torch.int64, device=tensor.device) 45 | dim_tensor = tensor_all_gather(dim_tensor, distributed_world_size).cpu() 46 | max_size = dim_tensor.max() 47 | 48 | padded = torch.empty(max_size, *tensor.shape[1:], 49 | dtype=tensor.dtype, 50 | device=tensor.device) 51 | padded[:tensor.shape[0]] = tensor 52 | 53 | ag = tensor_all_gather(padded, distributed_world_size) 54 | slices = [] 55 | for i, sz in enumerate(dim_tensor): 56 | start_idx = i * max_size 57 | end_idx = start_idx + sz.item() 58 | 59 | if end_idx > start_idx: 60 | slices.append(ag[start_idx:end_idx]) 61 | 62 | return torch.cat(slices, dim=0) 63 | 64 | def get_token_prediction_layer(model): 65 | if isinstance(model, ModelForPromptFinetuning): 66 | if model.label_word_list is not None: 67 | lm_head = model.get_lm_head_fn() 68 | if model.model_type == "roberta": 69 | return lm_head.decoder 70 | elif model.model_type == "bert": 71 | return lm_head.predictions.decoder 72 | else: 73 | return model.classifier 74 | elif isinstance(model, transformers.RobertaForSequenceClassification): 75 | return model.classifier.out_proj 76 | elif isinstance(model, transformers.BertForSequenceClassification): 77 | return model.classifier 78 | else: 79 | raise NotImplementedError(model.__class__) 80 | 81 | def extract_features(model, *args, **kwargs): 82 | """some magic for getting features pre last layer""" 83 | features = {} 84 | def hook(model_, input_, output_): 85 | features["features"] = input_[0].detach() 86 | 87 | get_token_prediction_layer(model).register_forward_hook(hook) 88 | model.forward(*args, **kwargs) 89 | return features["features"] 90 | 91 | 92 | class LinearHeadTrainer(transformers.Trainer): 93 | """ 94 | Adding some functions based on Transformers' Trainer class. 95 | """ 96 | 97 | def train(self, model_path=None, dev_objective=None): 98 | """ 99 | Main training entry point. 100 | 101 | The training logic is directly borrowed from transformers.Trainer (version 3.0.2). 102 | Add early stopping. 
103 | """ 104 | self.best_dir = None 105 | self.objective = -float("inf") 106 | 107 | model = self.model 108 | model.eval() 109 | 110 | train_dataloader = self.get_train_dataloader() 111 | targets = [] 112 | features = [] 113 | 114 | logger.info("Starting to get features for training dataset") 115 | with torch.no_grad(): 116 | for step, inputs in enumerate(train_dataloader): 117 | for k, v in inputs.items(): 118 | if isinstance(v, torch.Tensor): 119 | inputs[k] = v.to(self.args.device) 120 | features.append(extract_features(model, **inputs)) 121 | targets.append(inputs["labels"]) 122 | logger.info("Finished getting features for training dataset") 123 | 124 | features = torch.cat(features, dim=0) 125 | targets = torch.cat(targets, dim=0) 126 | 127 | if self.args.local_rank != -1: 128 | logger.info("Starting to gather features across workers") 129 | features = varsize_tensor_all_gather(features, torch.distributed.get_world_size()) 130 | targets = varsize_tensor_all_gather(targets, torch.distributed.get_world_size()) 131 | logger.info("Finished gathering features across workers") 132 | 133 | features = features.cpu() 134 | targets = targets.cpu() 135 | 136 | if model.num_labels == 1: # Regression 137 | targets_coords = targets.squeeze().unsqueeze(-1).float() 138 | reg = LinearRegression().fit(features.numpy(), targets_coords.numpy()) 139 | else: 140 | reg = LogisticRegressionCV(max_iter=5000, multi_class="multinomial", random_state=0).fit(features.numpy(), targets.numpy()) 141 | # targets_coords = torch.nn.functional.one_hot(targets.squeeze(), model.num_labels).float() 142 | 143 | logger.info("Fitting linear regression") 144 | 145 | logger.info("Assigning weights to model") 146 | # print(head.out_proj.weight.shape, head.out_proj.bias.shape) 147 | # print(reg.coef_.shape, reg.intercept_.shape) 148 | decoder = get_token_prediction_layer(model) 149 | coef_torch = torch.tensor(reg.coef_, device=decoder.weight.device, dtype=decoder.weight.dtype) 150 | bias_torch = torch.tensor(reg.intercept_, device=decoder.bias.device, dtype=decoder.bias.dtype) 151 | 152 | if model.num_labels == 2 and coef_torch.size(0) == 1: 153 | coef_torch = torch.cat([-coef_torch / 2, coef_torch / 2], dim=0) 154 | bias_torch = torch.cat([-bias_torch / 2, bias_torch / 2], dim=0) 155 | 156 | if decoder.weight.shape[0] == model.num_labels: 157 | decoder.weight.data = coef_torch 158 | decoder.bias.data = bias_torch 159 | else: 160 | decoder.weight.data[model.label_word_list,:] = coef_torch 161 | decoder.bias.data[model.label_word_list] = bias_torch 162 | 163 | if model.num_labels == 1: # Regression 164 | logits = torch.tensor(reg.predict(features.numpy())) 165 | train_loss = torch.nn.functional.mse_loss(logits, targets_coords, reduction="none") 166 | else: 167 | logits = torch.tensor(reg.predict_log_proba(features.numpy())) 168 | train_loss = torch.nn.functional.cross_entropy(logits, targets.squeeze(), reduction="none") 169 | 170 | return TrainOutput(0, train_loss, {}), self.objective 171 | 172 | 173 | 174 | """ 175 | Difference compared to original implementation: return output instead of output.metrics (so there is also the logits) 176 | """ 177 | def evaluate(self, eval_dataset: Optional[Dataset] = None) -> Dict[str, float]: 178 | """ 179 | Run evaluation and returns metrics. 180 | 181 | The calling script will be responsible for providing a method to compute metrics, as they are 182 | task-dependent (pass it to the init :obj:`compute_metrics` argument). 
183 | 184 | You can also subclass and override this method to inject custom behavior. 185 | 186 | Args: 187 | eval_dataset (:obj:`Dataset`, `optional`): 188 | Pass a dataset if you wish to override :obj:`self.eval_dataset`. If it is an :obj:`datasets.Dataset`, 189 | columns not accepted by the ``model.forward()`` method are automatically removed. It must implement 190 | the :obj:`__len__` method. 191 | 192 | Returns: 193 | A dictionary containing the evaluation loss and the potential metrics computed from the predictions. 194 | """ 195 | if eval_dataset is not None and not isinstance(eval_dataset, collections.abc.Sized): 196 | raise ValueError("eval_dataset must implement __len__") 197 | 198 | eval_dataloader = self.get_eval_dataloader(eval_dataset) 199 | 200 | output = self.prediction_loop(eval_dataloader, description="Evaluation") 201 | 202 | self.log(output.metrics) 203 | 204 | if self.args.tpu_metrics_debug or self.args.debug: 205 | # tpu-comment: Logging debug metrics for PyTorch/XLA (compile, execute times, ops, etc.) 206 | xm.master_print(met.metrics_report()) 207 | 208 | return output 209 | -------------------------------------------------------------------------------- /src/models.py: -------------------------------------------------------------------------------- 1 | """Custom models for few-shot learning specific operations.""" 2 | 3 | import torch 4 | import torch.nn as nn 5 | from transformers.models.bert.modeling_bert import BertPreTrainedModel, BertModel, BertOnlyMLMHead 6 | from transformers.models.roberta.modeling_roberta import RobertaModel, RobertaLMHead 7 | 8 | import logging 9 | logger = logging.getLogger(__name__) 10 | 11 | def resize_token_type_embeddings(model, new_num_types: int, random_segment: bool): 12 | """ 13 | Resize the segment (token type) embeddings for BERT 14 | """ 15 | if hasattr(model, 'bert'): 16 | old_token_type_embeddings = model.bert.embeddings.token_type_embeddings 17 | else: 18 | raise NotImplementedError 19 | new_token_type_embeddings = nn.Embedding(new_num_types, old_token_type_embeddings.weight.size(1)) 20 | if not random_segment: 21 | new_token_type_embeddings.weight.data[:old_token_type_embeddings.weight.size(0)] = old_token_type_embeddings.weight.data 22 | 23 | model.config.type_vocab_size = new_num_types 24 | if hasattr(model, 'bert'): 25 | model.bert.embeddings.token_type_embeddings = new_token_type_embeddings 26 | else: 27 | raise NotImplementedError 28 | 29 | class ModelForPromptFinetuning(BertPreTrainedModel): 30 | def __init__(self, config): 31 | super().__init__(config) 32 | self.num_labels = config.num_labels 33 | 34 | self.model_type = config.model_type 35 | if self.model_type == "roberta": 36 | self.roberta = RobertaModel(config) 37 | self.lm_head = RobertaLMHead(config) 38 | elif self.model_type == "bert": 39 | self.bert = BertModel(config) 40 | self.cls = BertOnlyMLMHead(config) 41 | else: 42 | raise NotImplementedError 43 | 44 | self.classifier = nn.Linear(config.hidden_size, self.num_labels) 45 | 46 | self.init_weights() 47 | 48 | # These attributes should be assigned once the model is initialized 49 | self.model_args = None 50 | self.data_args = None 51 | self.label_word_list = None 52 | 53 | # For regression 54 | self.lb = 0.0 55 | self.ub = 1.0 56 | 57 | # For auto label search. 
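# (When return_full_softmax is set, forward() returns logits over the whole vocabulary
#  instead of only the label words; see its use in forward() below.)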
58 | self.return_full_softmax = None 59 | 60 | def get_model_fn(self): 61 | return self.roberta if self.model_type == "roberta" else self.bert 62 | 63 | def get_lm_head_fn(self): 64 | return self.lm_head if self.model_type == "roberta" else self.cls 65 | 66 | def forward( 67 | self, 68 | input_ids=None, 69 | attention_mask=None, 70 | token_type_ids=None, 71 | mask_pos=None, 72 | labels=None, 73 | ): 74 | if mask_pos is not None: 75 | mask_pos = mask_pos.squeeze() 76 | 77 | model_fn = self.get_model_fn() 78 | # Encode everything 79 | outputs = model_fn( 80 | input_ids, 81 | attention_mask=attention_mask, 82 | token_type_ids=token_type_ids 83 | ) 84 | 85 | # Get token representation 86 | sequence_output, pooled_output = outputs[:2] 87 | if mask_pos is not None: 88 | sequence_mask_output = sequence_output[torch.arange(sequence_output.size(0)), mask_pos] 89 | else: 90 | sequence_mask_output = sequence_output[:,0] # representation 91 | # sequence_mask_output = sequence_output.mean(dim=1) # average representation 92 | 93 | if self.label_word_list is not None: 94 | # Logits over vocabulary tokens 95 | head_fn = self.get_lm_head_fn() 96 | prediction_mask_scores = head_fn(sequence_mask_output) 97 | 98 | # Exit early and only return mask logits. 99 | if self.return_full_softmax: 100 | if labels is not None: 101 | return torch.zeros(1, out=prediction_mask_scores.new()), prediction_mask_scores 102 | return prediction_mask_scores 103 | 104 | # Return logits for each label 105 | logits = [] 106 | # use MLM logit 107 | if self.model_args.use_task_word: 108 | vocab_logits = self.lm_head(sequence_mask_output) 109 | for _id in self.label_word_list: 110 | logits.append(vocab_logits[:, _id].unsqueeze(-1)) 111 | # use learned linear head logit on top of task word representation (standard LM-BFF) 112 | else: 113 | for label_id in range(len(self.label_word_list)): 114 | logits.append(prediction_mask_scores[:, self.label_word_list[label_id]].unsqueeze(-1)) 115 | logits = torch.cat(logits, -1) 116 | 117 | # Regression task 118 | if self.config.num_labels == 1: 119 | logsoftmax = nn.LogSoftmax(-1) 120 | logits = logsoftmax(logits) # Log prob of right polarity 121 | else: 122 | logits = self.classifier(sequence_mask_output) 123 | 124 | 125 | loss = None 126 | if labels is not None: 127 | if self.config.num_labels == 1: 128 | # Regression task 129 | if self.label_word_list is not None: 130 | labels = torch.stack([1 - (labels.view(-1) - self.lb) / (self.ub - self.lb), (labels.view(-1) - self.lb) / (self.ub - self.lb)], -1) 131 | loss = nn.KLDivLoss(log_target=True)(logits.view(-1, 2), labels) 132 | else: 133 | labels = (labels.float().view(-1) - self.lb) / (self.ub - self.lb) 134 | loss = nn.MSELoss()(logits.view(-1), labels) 135 | else: 136 | if self.model_args.l2_loss: 137 | coords = torch.nn.functional.one_hot(labels.squeeze(), self.config.num_labels).float() 138 | loss = nn.MSELoss()(logits.view(-1, logits.size(-1)), coords) 139 | else: 140 | loss = nn.CrossEntropyLoss()(logits.view(-1, logits.size(-1)), labels.view(-1)) 141 | 142 | output = (logits,) 143 | if self.model_args.use_task_word and self.num_labels == 1: 144 | # Regression output 145 | output = (torch.exp(logits[..., 1].unsqueeze(-1)) * (self.ub - self.lb) + self.lb,) 146 | return ((loss,) + output) if loss is not None else output 147 | -------------------------------------------------------------------------------- /src/processors.py: -------------------------------------------------------------------------------- 1 | """Dataset utils for different 
data settings for GLUE.""" 2 | 3 | import os 4 | import logging 5 | from transformers import DataProcessor, InputExample 6 | from transformers.data.processors.glue import * 7 | from transformers.data.metrics import glue_compute_metrics 8 | import pandas as pd 9 | import logging 10 | 11 | logger = logging.getLogger(__name__) 12 | 13 | class MrpcProcessor(DataProcessor): 14 | """Processor for the MRPC data set (GLUE version).""" 15 | 16 | def get_example_from_tensor_dict(self, tensor_dict): 17 | """See base class.""" 18 | return InputExample( 19 | tensor_dict["idx"].numpy(), 20 | tensor_dict["sentence1"].numpy().decode("utf-8"), 21 | tensor_dict["sentence2"].numpy().decode("utf-8"), 22 | str(tensor_dict["label"].numpy()), 23 | ) 24 | 25 | def get_train_examples(self, data_dir): 26 | """See base class.""" 27 | logger.info("LOOKING AT {}".format(os.path.join(data_dir, "train.tsv"))) 28 | return self._create_examples(self._read_tsv(os.path.join(data_dir, "train.tsv")), "train") 29 | 30 | def get_dev_examples(self, data_dir): 31 | """See base class.""" 32 | return self._create_examples(self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev") 33 | 34 | def get_test_examples(self, data_dir): 35 | """See base class.""" 36 | return self._create_examples(self._read_tsv(os.path.join(data_dir, "test.tsv")), "test") 37 | 38 | def get_labels(self): 39 | """See base class.""" 40 | return ["0", "1"] 41 | 42 | def _create_examples(self, lines, set_type): 43 | """Creates examples for the training, dev and test sets.""" 44 | examples = [] 45 | for (i, line) in enumerate(lines): 46 | if i == 0: 47 | continue 48 | guid = "%s-%s" % (set_type, i) 49 | text_a = line[3] 50 | text_b = line[4] 51 | label = line[0] 52 | examples.append(InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label)) 53 | return examples 54 | 55 | 56 | class MnliProcessor(DataProcessor): 57 | """Processor for the MultiNLI data set (GLUE version).""" 58 | 59 | def get_example_from_tensor_dict(self, tensor_dict): 60 | """See base class.""" 61 | return InputExample( 62 | tensor_dict["idx"].numpy(), 63 | tensor_dict["premise"].numpy().decode("utf-8"), 64 | tensor_dict["hypothesis"].numpy().decode("utf-8"), 65 | str(tensor_dict["label"].numpy()), 66 | ) 67 | 68 | def get_train_examples(self, data_dir): 69 | """See base class.""" 70 | return self._create_examples(self._read_tsv(os.path.join(data_dir, "train.tsv")), "train") 71 | 72 | def get_dev_examples(self, data_dir): 73 | """See base class.""" 74 | return self._create_examples(self._read_tsv(os.path.join(data_dir, "dev_matched.tsv")), "dev_matched") 75 | 76 | def get_test_examples(self, data_dir): 77 | """See base class.""" 78 | return self._create_examples(self._read_tsv(os.path.join(data_dir, "test_matched.tsv")), "test_matched") 79 | 80 | def get_labels(self): 81 | """See base class.""" 82 | return ["contradiction", "entailment", "neutral"] 83 | 84 | def _create_examples(self, lines, set_type): 85 | """Creates examples for the training, dev and test sets.""" 86 | examples = [] 87 | for (i, line) in enumerate(lines): 88 | if i == 0: 89 | continue 90 | guid = "%s-%s" % (set_type, line[0]) 91 | text_a = line[8] 92 | text_b = line[9] 93 | label = line[-1] 94 | examples.append(InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label)) 95 | return examples 96 | 97 | 98 | class MnliMismatchedProcessor(MnliProcessor): 99 | """Processor for the MultiNLI Mismatched data set (GLUE version).""" 100 | 101 | def get_dev_examples(self, data_dir): 102 | """See base class.""" 103 | return 
self._create_examples(self._read_tsv(os.path.join(data_dir, "dev_mismatched.tsv")), "dev_mismatched") 104 | 105 | def get_test_examples(self, data_dir): 106 | """See base class.""" 107 | return self._create_examples(self._read_tsv(os.path.join(data_dir, "test_mismatched.tsv")), "test_mismatched") 108 | 109 | 110 | class SnliProcessor(DataProcessor): 111 | """Processor for the SNLI data set.""" 112 | 113 | def get_example_from_tensor_dict(self, tensor_dict): 114 | """See base class.""" 115 | return InputExample( 116 | tensor_dict["idx"].numpy(), 117 | tensor_dict["premise"].numpy().decode("utf-8"), 118 | tensor_dict["hypothesis"].numpy().decode("utf-8"), 119 | str(tensor_dict["label"].numpy()), 120 | ) 121 | 122 | def get_train_examples(self, data_dir): 123 | """See base class.""" 124 | return self._create_examples(self._read_tsv(os.path.join(data_dir, "train.tsv")), "train") 125 | 126 | def get_dev_examples(self, data_dir): 127 | """See base class.""" 128 | return self._create_examples(self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev") 129 | 130 | def get_test_examples(self, data_dir): 131 | """See base class.""" 132 | return self._create_examples(self._read_tsv(os.path.join(data_dir, "test.tsv")), "test") 133 | 134 | def get_labels(self): 135 | """See base class.""" 136 | return ["contradiction", "entailment", "neutral"] 137 | 138 | def _create_examples(self, lines, set_type): 139 | """Creates examples for the training, dev and test sets.""" 140 | examples = [] 141 | for (i, line) in enumerate(lines): 142 | if i == 0: 143 | continue 144 | guid = "%s-%s" % (set_type, line[0]) 145 | text_a = line[7] 146 | text_b = line[8] 147 | label = line[-1] 148 | examples.append(InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label)) 149 | return examples 150 | 151 | 152 | class ColaProcessor(DataProcessor): 153 | """Processor for the CoLA data set (GLUE version).""" 154 | 155 | def get_example_from_tensor_dict(self, tensor_dict): 156 | """See base class.""" 157 | return InputExample( 158 | tensor_dict["idx"].numpy(), 159 | tensor_dict["sentence"].numpy().decode("utf-8"), 160 | None, 161 | str(tensor_dict["label"].numpy()), 162 | ) 163 | 164 | def get_train_examples(self, data_dir): 165 | """See base class.""" 166 | return self._create_examples(self._read_tsv(os.path.join(data_dir, "train.tsv")), "train") 167 | 168 | def get_dev_examples(self, data_dir): 169 | """See base class.""" 170 | return self._create_examples(self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev") 171 | 172 | def get_test_examples(self, data_dir): 173 | """See base class.""" 174 | return self._create_examples(self._read_tsv(os.path.join(data_dir, "test.tsv")), "test") 175 | 176 | def get_labels(self): 177 | """See base class.""" 178 | return ["0", "1"] 179 | 180 | def _create_examples(self, lines, set_type): 181 | """Creates examples for the training, dev and test sets.""" 182 | test_mode = set_type == "test" 183 | text_index = 3 184 | examples = [] 185 | for (i, line) in enumerate(lines): 186 | guid = "%s-%s" % (set_type, i) 187 | text_a = line[text_index] 188 | label = line[1] 189 | examples.append(InputExample(guid=guid, text_a=text_a, text_b=None, label=label)) 190 | return examples 191 | 192 | 193 | class Sst2Processor(DataProcessor): 194 | """Processor for the SST-2 data set (GLUE version).""" 195 | 196 | def get_example_from_tensor_dict(self, tensor_dict): 197 | """See base class.""" 198 | return InputExample( 199 | tensor_dict["idx"].numpy(), 200 |
tensor_dict["sentence"].numpy().decode("utf-8"), 201 | None, 202 | str(tensor_dict["label"].numpy()), 203 | ) 204 | 205 | def get_train_examples(self, data_dir): 206 | """See base class.""" 207 | return self._create_examples(self._read_tsv(os.path.join(data_dir, "train.tsv")), "train") 208 | 209 | def get_dev_examples(self, data_dir): 210 | """See base class.""" 211 | return self._create_examples(self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev") 212 | 213 | def get_test_examples(self, data_dir): 214 | """See base class.""" 215 | return self._create_examples(self._read_tsv(os.path.join(data_dir, "test.tsv")), "test") 216 | 217 | def get_labels(self): 218 | """See base class.""" 219 | return ["0", "1"] 220 | 221 | def _create_examples(self, lines, set_type): 222 | """Creates examples for the training, dev and test sets.""" 223 | examples = [] 224 | text_index = 0 225 | for (i, line) in enumerate(lines): 226 | if i == 0: 227 | continue 228 | guid = "%s-%s" % (set_type, i) 229 | text_a = line[text_index] 230 | label = line[1] 231 | examples.append(InputExample(guid=guid, text_a=text_a, text_b=None, label=label)) 232 | return examples 233 | 234 | 235 | class StsbProcessor(DataProcessor): 236 | """Processor for the STS-B data set (GLUE version).""" 237 | 238 | def get_example_from_tensor_dict(self, tensor_dict): 239 | """See base class.""" 240 | return InputExample( 241 | tensor_dict["idx"].numpy(), 242 | tensor_dict["sentence1"].numpy().decode("utf-8"), 243 | tensor_dict["sentence2"].numpy().decode("utf-8"), 244 | str(tensor_dict["label"].numpy()), 245 | ) 246 | 247 | def get_train_examples(self, data_dir): 248 | """See base class.""" 249 | return self._create_examples(self._read_tsv(os.path.join(data_dir, "train.tsv")), "train") 250 | 251 | def get_dev_examples(self, data_dir): 252 | """See base class.""" 253 | return self._create_examples(self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev") 254 | 255 | def get_test_examples(self, data_dir): 256 | """See base class.""" 257 | return self._create_examples(self._read_tsv(os.path.join(data_dir, "test.tsv")), "test") 258 | 259 | def get_labels(self): 260 | """See base class.""" 261 | return [None] 262 | 263 | def _create_examples(self, lines, set_type): 264 | """Creates examples for the training, dev and test sets.""" 265 | examples = [] 266 | for (i, line) in enumerate(lines): 267 | if i == 0: 268 | continue 269 | guid = "%s-%s" % (set_type, line[0]) 270 | text_a = line[7] 271 | text_b = line[8] 272 | label = line[-1] 273 | examples.append(InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label)) 274 | return examples 275 | 276 | 277 | class QqpProcessor(DataProcessor): 278 | """Processor for the QQP data set (GLUE version).""" 279 | 280 | def get_example_from_tensor_dict(self, tensor_dict): 281 | """See base class.""" 282 | return InputExample( 283 | tensor_dict["idx"].numpy(), 284 | tensor_dict["question1"].numpy().decode("utf-8"), 285 | tensor_dict["question2"].numpy().decode("utf-8"), 286 | str(tensor_dict["label"].numpy()), 287 | ) 288 | 289 | def get_train_examples(self, data_dir): 290 | """See base class.""" 291 | return self._create_examples(self._read_tsv(os.path.join(data_dir, "train.tsv")), "train") 292 | 293 | def get_dev_examples(self, data_dir): 294 | """See base class.""" 295 | return self._create_examples(self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev") 296 | 297 | def get_test_examples(self, data_dir): 298 | """See base class.""" 299 | return 
self._create_examples(self._read_tsv(os.path.join(data_dir, "test.tsv")), "test") 300 | 301 | def get_labels(self): 302 | """See base class.""" 303 | return ["0", "1"] 304 | 305 | def _create_examples(self, lines, set_type): 306 | """Creates examples for the training, dev and test sets.""" 307 | test_mode = set_type == "test" 308 | q1_index = 3 309 | q2_index = 4 310 | examples = [] 311 | for (i, line) in enumerate(lines): 312 | if i == 0: 313 | continue 314 | guid = "%s-%s" % (set_type, line[0]) 315 | try: 316 | text_a = line[q1_index] 317 | text_b = line[q2_index] 318 | label = line[5] 319 | except IndexError: 320 | continue 321 | examples.append(InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label)) 322 | return examples 323 | 324 | 325 | class QnliProcessor(DataProcessor): 326 | """Processor for the QNLI data set (GLUE version).""" 327 | 328 | def get_example_from_tensor_dict(self, tensor_dict): 329 | """See base class.""" 330 | return InputExample( 331 | tensor_dict["idx"].numpy(), 332 | tensor_dict["question"].numpy().decode("utf-8"), 333 | tensor_dict["sentence"].numpy().decode("utf-8"), 334 | str(tensor_dict["label"].numpy()), 335 | ) 336 | 337 | def get_train_examples(self, data_dir): 338 | """See base class.""" 339 | return self._create_examples(self._read_tsv(os.path.join(data_dir, "train.tsv")), "train") 340 | 341 | def get_dev_examples(self, data_dir): 342 | """See base class.""" 343 | return self._create_examples(self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev") 344 | 345 | def get_test_examples(self, data_dir): 346 | """See base class.""" 347 | return self._create_examples(self._read_tsv(os.path.join(data_dir, "test.tsv")), "test") 348 | 349 | def get_labels(self): 350 | """See base class.""" 351 | return ["entailment", "not_entailment"] 352 | 353 | def _create_examples(self, lines, set_type): 354 | """Creates examples for the training, dev and test sets.""" 355 | examples = [] 356 | for (i, line) in enumerate(lines): 357 | if i == 0: 358 | continue 359 | guid = "%s-%s" % (set_type, line[0]) 360 | text_a = line[1] 361 | text_b = line[2] 362 | label = line[-1] 363 | examples.append(InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label)) 364 | return examples 365 | 366 | 367 | class RteProcessor(DataProcessor): 368 | """Processor for the RTE data set (GLUE version).""" 369 | 370 | def get_example_from_tensor_dict(self, tensor_dict): 371 | """See base class.""" 372 | return InputExample( 373 | tensor_dict["idx"].numpy(), 374 | tensor_dict["sentence1"].numpy().decode("utf-8"), 375 | tensor_dict["sentence2"].numpy().decode("utf-8"), 376 | str(tensor_dict["label"].numpy()), 377 | ) 378 | 379 | def get_train_examples(self, data_dir): 380 | """See base class.""" 381 | return self._create_examples(self._read_tsv(os.path.join(data_dir, "train.tsv")), "train") 382 | 383 | def get_dev_examples(self, data_dir): 384 | """See base class.""" 385 | return self._create_examples(self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev") 386 | 387 | def get_test_examples(self, data_dir): 388 | """See base class.""" 389 | return self._create_examples(self._read_tsv(os.path.join(data_dir, "test.tsv")), "test") 390 | 391 | def get_labels(self): 392 | """See base class.""" 393 | return ["entailment", "not_entailment"] 394 | 395 | def _create_examples(self, lines, set_type): 396 | """Creates examples for the training, dev and test sets.""" 397 | examples = [] 398 | for (i, line) in enumerate(lines): 399 | if i == 0: 400 | continue 401 | guid = "%s-%s" % (set_type, 
line[0]) 402 | text_a = line[1] 403 | text_b = line[2] 404 | label = line[-1] 405 | examples.append(InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label)) 406 | return examples 407 | 408 | 409 | class WnliProcessor(DataProcessor): 410 | """Processor for the WNLI data set (GLUE version).""" 411 | 412 | def get_example_from_tensor_dict(self, tensor_dict): 413 | """See base class.""" 414 | return InputExample( 415 | tensor_dict["idx"].numpy(), 416 | tensor_dict["sentence1"].numpy().decode("utf-8"), 417 | tensor_dict["sentence2"].numpy().decode("utf-8"), 418 | str(tensor_dict["label"].numpy()), 419 | ) 420 | 421 | def get_train_examples(self, data_dir): 422 | """See base class.""" 423 | return self._create_examples(self._read_tsv(os.path.join(data_dir, "train.tsv")), "train") 424 | 425 | def get_dev_examples(self, data_dir): 426 | """See base class.""" 427 | return self._create_examples(self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev") 428 | 429 | def get_test_examples(self, data_dir): 430 | """See base class.""" 431 | return self._create_examples(self._read_tsv(os.path.join(data_dir, "test.tsv")), "test") 432 | 433 | def get_labels(self): 434 | """See base class.""" 435 | return ["0", "1"] 436 | 437 | def _create_examples(self, lines, set_type): 438 | """Creates examples for the training, dev and test sets.""" 439 | examples = [] 440 | for (i, line) in enumerate(lines): 441 | if i == 0: 442 | continue 443 | guid = "%s-%s" % (set_type, line[0]) 444 | text_a = line[1] 445 | text_b = line[2] 446 | label = line[-1] 447 | examples.append(InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label)) 448 | return examples 449 | 450 | class TextClassificationProcessor(DataProcessor): 451 | """ 452 | Data processor for text classification datasets (mr, sst-5, subj, trec, cr, mpqa). 
453 | """ 454 | 455 | def __init__(self, task_name): 456 | self.task_name = task_name 457 | 458 | def get_example_from_tensor_dict(self, tensor_dict): 459 | """See base class.""" 460 | return InputExample( 461 | tensor_dict["idx"].numpy(), 462 | tensor_dict["sentence"].numpy().decode("utf-8"), 463 | None, 464 | str(tensor_dict["label"].numpy()), 465 | ) 466 | 467 | def get_train_examples(self, data_dir): 468 | """See base class.""" 469 | return self._create_examples(pd.read_csv(os.path.join(data_dir, "train.csv"), header=None).values.tolist(), "train") 470 | 471 | def get_dev_examples(self, data_dir): 472 | """See base class.""" 473 | return self._create_examples(pd.read_csv(os.path.join(data_dir, "dev.csv"), header=None).values.tolist(), "dev") 474 | 475 | def get_test_examples(self, data_dir): 476 | """See base class.""" 477 | return self._create_examples(pd.read_csv(os.path.join(data_dir, "test.csv"), header=None).values.tolist(), "test") 478 | 479 | def get_labels(self): 480 | """See base class.""" 481 | if self.task_name == "ag_news": 482 | return list(range(1,5)) 483 | if self.task_name == "mr": 484 | return list(range(2)) 485 | elif self.task_name == "sst-5": 486 | return list(range(5)) 487 | elif self.task_name == "subj": 488 | return list(range(2)) 489 | elif self.task_name == "trec": 490 | return list(range(6)) 491 | elif self.task_name == "cr": 492 | return list(range(2)) 493 | elif self.task_name == "mpqa": 494 | return list(range(2)) 495 | else: 496 | raise Exception("task_name not supported.") 497 | 498 | def _create_examples(self, lines, set_type): 499 | """Creates examples for the training, dev and test sets.""" 500 | examples = [] 501 | for (i, line) in enumerate(lines): 502 | guid = "%s-%s" % (set_type, i) 503 | if self.task_name == "ag_news": 504 | examples.append(InputExample(guid=guid, text_a=line[2], label=line[0])) 505 | elif self.task_name == "yelp_review_full": 506 | examples.append(InputExample(guid=guid, text_a=line[1], short_text=line[1], label=line[0])) 507 | elif self.task_name == "yahoo_answers": 508 | text = line[1] 509 | if not pd.isna(line[2]): 510 | text += ' ' + line[2] 511 | if not pd.isna(line[3]): 512 | text += ' ' + line[3] 513 | examples.append(InputExample(guid=guid, text_a=text, short_text=line[1], label=line[0])) 514 | elif self.task_name in ['mr', 'sst-5', 'subj', 'trec', 'cr', 'mpqa']: 515 | examples.append(InputExample(guid=guid, text_a=line[1], label=line[0])) 516 | else: 517 | raise Exception("Task_name not supported.") 518 | 519 | return examples 520 | 521 | def text_classification_metrics(task_name, preds, labels): 522 | return {"acc": (preds == labels).mean()} 523 | 524 | # Add your task to the following mappings 525 | 526 | processors_mapping = { 527 | "cola": ColaProcessor(), 528 | "mnli": MnliProcessor(), 529 | "mnli-mm": MnliMismatchedProcessor(), 530 | "mrpc": MrpcProcessor(), 531 | "sst-2": Sst2Processor(), 532 | "sts-b": StsbProcessor(), 533 | "qqp": QqpProcessor(), 534 | "qnli": QnliProcessor(), 535 | "rte": RteProcessor(), 536 | "wnli": WnliProcessor(), 537 | "snli": SnliProcessor(), 538 | "mr": TextClassificationProcessor("mr"), 539 | "sst-5": TextClassificationProcessor("sst-5"), 540 | "subj": TextClassificationProcessor("subj"), 541 | "trec": TextClassificationProcessor("trec"), 542 | "cr": TextClassificationProcessor("cr"), 543 | "mpqa": TextClassificationProcessor("mpqa"), 544 | "ag_news": TextClassificationProcessor("ag_news") 545 | } 546 | 547 | num_labels_mapping = { 548 | "cola": 2, 549 | "mnli": 3, 550 | "mrpc": 2, 551 
| "sst-2": 2, 552 | "sts-b": 1, 553 | "qqp": 2, 554 | "qnli": 2, 555 | "rte": 2, 556 | "wnli": 2, 557 | "snli": 3, 558 | "mr": 2, 559 | "sst-5": 5, 560 | "subj": 2, 561 | "trec": 6, 562 | "cr": 2, 563 | "mpqa": 2, 564 | "ag_news": 4 565 | } 566 | 567 | output_modes_mapping = { 568 | "cola": "classification", 569 | "mnli": "classification", 570 | "mnli-mm": "classification", 571 | "mrpc": "classification", 572 | "sst-2": "classification", 573 | "sts-b": "regression", 574 | "qqp": "classification", 575 | "qnli": "classification", 576 | "rte": "classification", 577 | "wnli": "classification", 578 | "snli": "classification", 579 | "mr": "classification", 580 | "sst-5": "classification", 581 | "subj": "classification", 582 | "trec": "classification", 583 | "cr": "classification", 584 | "mpqa": "classification", 585 | "ag_news": "classification" 586 | } 587 | 588 | # Return a function that takes (task_name, preds, labels) as inputs 589 | compute_metrics_mapping = { 590 | "cola": glue_compute_metrics, 591 | "mnli": glue_compute_metrics, 592 | "mnli-mm": glue_compute_metrics, 593 | "mrpc": glue_compute_metrics, 594 | "sst-2": glue_compute_metrics, 595 | "sts-b": glue_compute_metrics, 596 | "qqp": glue_compute_metrics, 597 | "qnli": glue_compute_metrics, 598 | "rte": glue_compute_metrics, 599 | "wnli": glue_compute_metrics, 600 | "snli": text_classification_metrics, 601 | "mr": text_classification_metrics, 602 | "sst-5": text_classification_metrics, 603 | "subj": text_classification_metrics, 604 | "trec": text_classification_metrics, 605 | "cr": text_classification_metrics, 606 | "mpqa": text_classification_metrics, 607 | "ag_news": text_classification_metrics, 608 | } 609 | 610 | # For regression task only: median 611 | median_mapping = { 612 | "sts-b": 2.5 613 | } 614 | 615 | bound_mapping = { 616 | "sts-b": (0, 5) 617 | } 618 | -------------------------------------------------------------------------------- /src/trainer.py: -------------------------------------------------------------------------------- 1 | ########## The following part is copied from Transformers' trainer (3.4.0) and later ported to be compatible with v4.4.2 and to support initialization from linear head probing. ########## 2 | 3 | # coding=utf-8 4 | # Copyright 2020-present the HuggingFace Inc. team. 5 | # 6 | # Licensed under the Apache License, Version 2.0 (the "License"); 7 | # you may not use this file except in compliance with the License. 8 | # You may obtain a copy of the License at 9 | # 10 | # http://www.apache.org/licenses/LICENSE-2.0 11 | # 12 | # Unless required by applicable law or agreed to in writing, software 13 | # distributed under the License is distributed on an "AS IS" BASIS, 14 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 15 | # See the License for the specific language governing permissions and 16 | # limitations under the License. 17 | """ 18 | The Trainer class, to easily train a 🤗 Transformers from scratch or finetune it on a new task. 
19 | """ 20 | 21 | import collections 22 | import inspect 23 | import math 24 | import os 25 | import re 26 | import shutil 27 | import warnings 28 | from pathlib import Path 29 | from typing import Any, Callable, Dict, List, Optional, Tuple, Union 30 | 31 | import numpy as np 32 | import torch 33 | from packaging import version 34 | from torch import nn 35 | from torch.utils.data.dataloader import DataLoader 36 | from torch.utils.data.dataset import Dataset 37 | from torch.utils.data.distributed import DistributedSampler 38 | from torch.utils.data.sampler import RandomSampler, SequentialSampler 39 | 40 | import transformers 41 | from transformers.file_utils import is_datasets_available, is_in_notebook, is_torch_tpu_available 42 | from transformers.integrations import ( 43 | is_comet_available, 44 | is_optuna_available, 45 | is_ray_available, 46 | is_tensorboard_available, 47 | is_wandb_available, 48 | ) 49 | from transformers.optimization import AdamW, get_linear_schedule_with_warmup 50 | from transformers.trainer_callback import ( 51 | DefaultFlowCallback, 52 | ProgressCallback, 53 | ) 54 | from transformers.trainer_utils import ( 55 | default_compute_objective, 56 | ) 57 | from transformers.training_args import TrainingArguments 58 | from transformers.utils import logging 59 | from transformers.trainer_utils import TrainOutput 60 | 61 | from tqdm import tqdm, trange 62 | from torch.optim import SGD 63 | 64 | from src.linearhead_trainer import LinearHeadTrainer 65 | 66 | _use_native_amp = False 67 | _use_apex = False 68 | 69 | DEFAULT_CALLBACKS = [DefaultFlowCallback] 70 | DEFAULT_PROGRESS_CALLBACK = ProgressCallback 71 | 72 | if is_in_notebook(): 73 | from transformers.utils.notebook import NotebookProgressCallback 74 | 75 | DEFAULT_PROGRESS_CALLBACK = NotebookProgressCallback 76 | 77 | # Check if Pytorch version >= 1.6 to switch between Native AMP and Apex 78 | if version.parse(torch.__version__) < version.parse("1.6"): 79 | from transformers.file_utils import is_apex_available 80 | 81 | if is_apex_available(): 82 | from apex import amp 83 | _use_apex = True 84 | else: 85 | _use_native_amp = True 86 | from torch.cuda.amp import autocast 87 | 88 | if version.parse(torch.__version__) < version.parse("1.2"): 89 | _use_ddp_no_sync = False 90 | else: 91 | _use_ddp_no_sync = True 92 | 93 | if is_datasets_available(): 94 | import datasets 95 | 96 | if is_torch_tpu_available(): 97 | import torch_xla.core.xla_model as xm 98 | import torch_xla.debug.metrics as met 99 | import torch_xla.distributed.parallel_loader as pl 100 | 101 | if is_tensorboard_available(): 102 | from transformers.integrations import TensorBoardCallback 103 | 104 | DEFAULT_CALLBACKS.append(TensorBoardCallback) 105 | 106 | 107 | if is_wandb_available(): 108 | from transformers.integrations import WandbCallback 109 | 110 | DEFAULT_CALLBACKS.append(WandbCallback) 111 | 112 | if is_comet_available(): 113 | from transformers.integrations import CometCallback 114 | 115 | DEFAULT_CALLBACKS.append(CometCallback) 116 | 117 | if is_optuna_available(): 118 | import optuna 119 | 120 | if is_ray_available(): 121 | from ray import tune 122 | 123 | logger = logging.get_logger(__name__) 124 | 125 | ########## The above part is copied from Transformers' trainer (3.4.0) ########## 126 | 127 | def default_dev_objective(metrics): 128 | """ 129 | Objective used for picking the best model on development sets 130 | """ 131 | if "eval_mnli/acc" in metrics: 132 | return metrics["eval_mnli/acc"] 133 | elif "eval_mnli-mm/acc" in metrics: 134 | return 
metrics["eval_mnli-mm/acc"] 135 | elif "eval_f1" in metrics: 136 | return metrics["eval_f1"] 137 | elif "eval_mcc" in metrics: 138 | return metrics["eval_mcc"] 139 | elif "eval_pearson" in metrics: 140 | return metrics["eval_pearson"] 141 | elif "eval_acc" in metrics: 142 | return metrics["eval_acc"] 143 | 144 | raise Exception("No metric founded for {}".format(metrics)) 145 | 146 | class Trainer(LinearHeadTrainer): 147 | """ 148 | Adding some functions based on Transformers' Trainer class. 149 | """ 150 | 151 | def create_optimizer_and_scheduler(self, num_training_steps: int): 152 | """ 153 | Based on Transformers' default one, we add fixing layer option where the bottom n layers' parameters 154 | are fixed and only the top layers are further fine-tuned. 155 | """ 156 | if self.optimizer is None: 157 | params = {} 158 | for n, p in self.model.named_parameters(): 159 | if self.args.fix_layers > 0: 160 | if 'encoder.layer' in n: 161 | try: 162 | layer_num = int(n[n.find('encoder.layer') + 14:].split('.')[0]) 163 | except: 164 | print(n) 165 | raise Exception("") 166 | if layer_num >= self.args.fix_layers: 167 | print('yes', n) 168 | params[n] = p 169 | else: 170 | print('no ', n) 171 | elif 'embeddings' in n: 172 | print('no ', n) 173 | else: 174 | print('yes', n) 175 | params[n] = p 176 | else: 177 | params[n] = p 178 | no_decay = ["bias", "LayerNorm.weight"] 179 | optimizer_grouped_parameters = [ 180 | { 181 | "params": [p for n, p in params.items() if not any(nd in n for nd in no_decay)], 182 | "weight_decay": self.args.weight_decay, 183 | }, 184 | { 185 | "params": [p for n, p in params.items() if any(nd in n for nd in no_decay)], 186 | "weight_decay": 0.0, 187 | }, 188 | ] 189 | if self.args.optimizer == 'adam': 190 | self.optimizer = AdamW( 191 | optimizer_grouped_parameters, 192 | lr=self.args.learning_rate, 193 | betas=(self.args.adam_beta1, self.args.adam_beta2), 194 | eps=self.args.adam_epsilon, 195 | ) 196 | elif self.args.optimizer == 'sgd': 197 | self.optimizer = SGD( 198 | optimizer_grouped_parameters, 199 | lr=self.args.learning_rate 200 | ) 201 | else: 202 | raise NotImplementedError 203 | if self.lr_scheduler is None: 204 | self.lr_scheduler = get_linear_schedule_with_warmup( 205 | self.optimizer, num_warmup_steps=self.args.warmup_steps, num_training_steps=num_training_steps 206 | ) 207 | 208 | def train(self, model_path=None, dev_objective=None): 209 | """ 210 | Main training entry point. 211 | 212 | The training logic is directly borrowed from transformers.Trainer (version 3.0.2). 213 | Add early stopping. 214 | """ 215 | if self.args.from_linearhead and model_path is None: 216 | super().train(model_path, dev_objective) # Train output layer using LinearHeadTrainer 217 | 218 | self.best_dir = None 219 | self.objective = -float("inf") 220 | self.dev_objective = dev_objective if dev_objective is not None else default_dev_objective 221 | 222 | # Data loading. 
223 | train_dataloader = self.get_train_dataloader() 224 | num_update_steps_per_epoch = len(train_dataloader) // self.args.gradient_accumulation_steps 225 | if num_update_steps_per_epoch == 0: 226 | num_update_steps_per_epoch = 1 227 | if self.args.max_steps > 0: 228 | t_total = self.args.max_steps 229 | num_train_epochs = self.args.max_steps // num_update_steps_per_epoch + int( 230 | self.args.max_steps % num_update_steps_per_epoch > 0 231 | ) 232 | else: 233 | t_total = int(len(train_dataloader) // self.args.gradient_accumulation_steps * self.args.num_train_epochs) 234 | num_train_epochs = self.args.num_train_epochs 235 | 236 | self.create_optimizer_and_scheduler(num_training_steps=t_total) 237 | optimizer = self.optimizer 238 | scheduler = self.lr_scheduler 239 | 240 | # Check if saved optimizer or scheduler states exist 241 | if ( 242 | model_path is not None 243 | and os.path.isfile(os.path.join(model_path, "optimizer.pt")) 244 | and os.path.isfile(os.path.join(model_path, "scheduler.pt")) 245 | ): 246 | # Load in optimizer and scheduler states 247 | optimizer.load_state_dict( 248 | torch.load(os.path.join(model_path, "optimizer.pt"), map_location=self.args.device) 249 | ) 250 | scheduler.load_state_dict(torch.load(os.path.join(model_path, "scheduler.pt"))) 251 | 252 | model = self.model 253 | 254 | if self.args.fp16 and _use_apex: 255 | if not transformers.is_apex_available(): 256 | raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.") 257 | model, optimizer = amp.initialize(model, optimizer, opt_level=self.args.fp16_opt_level) 258 | 259 | # Multi-gpu training (should be after apex fp16 initialization) 260 | if self.args.n_gpu > 1: 261 | model = torch.nn.DataParallel(model) 262 | 263 | # Distributed training (should be after apex fp16 initialization) 264 | if self.args.local_rank != -1: 265 | model = torch.nn.parallel.DistributedDataParallel( 266 | model, 267 | device_ids=[self.args.local_rank], 268 | output_device=self.args.local_rank, 269 | find_unused_parameters=True, 270 | ) 271 | 272 | # Train 273 | if transformers.is_torch_tpu_available(): 274 | total_train_batch_size = self.args.train_batch_size * xm.xrt_world_size() 275 | else: 276 | total_train_batch_size = ( 277 | self.args.train_batch_size 278 | * self.args.gradient_accumulation_steps 279 | * (torch.distributed.get_world_size() if self.args.local_rank != -1 else 1) 280 | ) 281 | logger.info("***** Running training *****") 282 | logger.info(" Num examples = %d", self.num_examples(train_dataloader)) 283 | logger.info(" Num Epochs = %d", num_train_epochs) 284 | logger.info(" Instantaneous batch size per device = %d", self.args.per_device_train_batch_size) 285 | logger.info(" Total train batch size (w. 
parallel, distributed & accumulation) = %d", total_train_batch_size) 286 | logger.info(" Gradient Accumulation steps = %d", self.args.gradient_accumulation_steps) 287 | logger.info(" Total optimization steps = %d", t_total) 288 | 289 | self.global_step = 0 290 | self.epoch = 0 291 | epochs_trained = 0 292 | steps_trained_in_current_epoch = 0 293 | # Check if continuing training from a checkpoint 294 | if model_path is not None: 295 | # set global_step to global_step of last saved checkpoint from model path 296 | try: 297 | self.global_step = int(model_path.split("-")[-1].split("/")[0]) 298 | epochs_trained = self.global_step // (len(train_dataloader) // self.args.gradient_accumulation_steps) 299 | steps_trained_in_current_epoch = self.global_step % ( 300 | len(train_dataloader) // self.args.gradient_accumulation_steps 301 | ) 302 | 303 | logger.info(" Continuing training from checkpoint, will skip to saved global_step") 304 | logger.info(" Continuing training from epoch %d", epochs_trained) 305 | logger.info(" Continuing training from global step %d", self.global_step) 306 | logger.info(" Will skip the first %d steps in the first epoch", steps_trained_in_current_epoch) 307 | except ValueError: 308 | self.global_step = 0 309 | logger.info(" Starting fine-tuning.") 310 | 311 | tr_loss = torch.tensor(0.0).to(self.args.device) 312 | logging_loss_scalar = 0.0 313 | model.zero_grad() 314 | for epoch in range(epochs_trained, int(num_train_epochs)): 315 | if isinstance(train_dataloader, DataLoader) and isinstance(train_dataloader.sampler, DistributedSampler): 316 | train_dataloader.sampler.set_epoch(epoch) 317 | 318 | if transformers.is_torch_tpu_available(): 319 | parallel_loader = pl.ParallelLoader(train_dataloader, [self.args.device]).per_device_loader( 320 | self.args.device 321 | ) 322 | epoch_iterator = tqdm(parallel_loader, desc="Iteration", disable=not self.is_local_process_zero()) 323 | else: 324 | epoch_iterator = tqdm(train_dataloader, desc="Iteration", disable=True) 325 | 326 | # Reset the past mems state at the beginning of each epoch if necessary. 
327 | if self.args.past_index >= 0: 328 | self._past = None 329 | 330 | for step, inputs in enumerate(tqdm(epoch_iterator, desc=f'Steps in epoch {epoch}', disable=not self.is_local_process_zero())): 331 | 332 | # Skip past any already trained steps if resuming training 333 | if steps_trained_in_current_epoch > 0: 334 | steps_trained_in_current_epoch -= 1 335 | continue 336 | 337 | tr_loss += self.training_step(model, inputs) 338 | 339 | if (step + 1) % self.args.gradient_accumulation_steps == 0 or ( 340 | # last step in epoch but step is always smaller than gradient_accumulation_steps 341 | len(epoch_iterator) <= self.args.gradient_accumulation_steps 342 | and (step + 1) == len(epoch_iterator) 343 | ): 344 | if self.args.fp16 and _use_native_amp: 345 | self.scaler.unscale_(optimizer) 346 | norm = torch.nn.utils.clip_grad_norm_(model.parameters(), self.args.max_grad_norm) 347 | elif self.args.fp16: 348 | norm = torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), self.args.max_grad_norm) 349 | else: 350 | norm = torch.nn.utils.clip_grad_norm_(model.parameters(), self.args.max_grad_norm) 351 | 352 | 353 | if self.args.optimizer_variant == 'signgd': 354 | for n,p in model.named_parameters(): 355 | if p.grad is not None: 356 | p.grad = torch.sign(p.grad) 357 | 358 | if transformers.is_torch_tpu_available(): 359 | xm.optimizer_step(optimizer) 360 | elif self.args.fp16 and _use_native_amp: 361 | self.scaler.step(optimizer) 362 | self.scaler.update() 363 | else: 364 | optimizer.step() 365 | 366 | scheduler.step() 367 | model.zero_grad() 368 | self.global_step += 1 369 | self.epoch = epoch + (step + 1) / len(epoch_iterator) 370 | 371 | if (self.args.logging_steps > 0 and self.global_step % self.args.logging_steps == 0) or ( 372 | self.global_step == 1 and self.args.logging_first_step 373 | ): 374 | logs = {} 375 | tr_loss_scalar = tr_loss.item() 376 | logs["loss"] = (tr_loss_scalar - logging_loss_scalar) / self.args.logging_steps 377 | logs["norm"] = norm.item() 378 | # backward compatibility for pytorch schedulers 379 | logs["learning_rate"] = ( 380 | scheduler.get_last_lr()[0] 381 | if version.parse(torch.__version__) >= version.parse("1.4") 382 | else scheduler.get_lr()[0] 383 | ) 384 | logging_loss_scalar = tr_loss_scalar 385 | 386 | self.log(logs) 387 | 388 | if self.args.max_steps > 0 and self.global_step > self.args.max_steps: 389 | epoch_iterator.close() 390 | break 391 | if self.args.max_steps > 0 and self.global_step > self.args.max_steps: 392 | # train_iterator.close() 393 | break 394 | if self.args.tpu_metrics_debug or self.args.debug: 395 | # tpu-comment: Logging debug metrics for PyTorch/XLA (compile, execute times, ops, etc.) 396 | xm.master_print(met.metrics_report()) 397 | 398 | 399 | # ---------------------------------------------------------------------- 400 | # BEGIN CHANGES. 401 | # ---------------------------------------------------------------------- 402 | 403 | metrics = None 404 | if self.args.evaluate_during_training: #and self.global_step % self.args.eval_steps == 0: 405 | output = self.evaluate() 406 | metrics = output.metrics 407 | objective = self.dev_objective(metrics) 408 | if objective > self.objective: 409 | logger.info("Best dev result: {}".format(objective)) 410 | self.objective = objective 411 | self.save_model(self.args.output_dir) 412 | 413 | # ---------------------------------------------------------------------- 414 | # END CHANGES. 
415 | # ---------------------------------------------------------------------- 416 | 417 | 418 | if self.args.past_index and hasattr(self, "_past"): 419 | # Clean the state at the end of training 420 | delattr(self, "_past") 421 | 422 | logger.info("\n\nTraining completed. Do not forget to share your model on huggingface.co/models =)\n\n") 423 | return TrainOutput(self.global_step, tr_loss / self.global_step, metrics), self.objective 424 | 425 | 426 | """ 427 | Difference compared to the original implementation: return output instead of output.metrics (so the logits are also available) 428 | """ 429 | def evaluate(self, eval_dataset: Optional[Dataset] = None) -> Dict[str, float]: 430 | """ 431 | Runs evaluation and returns metrics. 432 | 433 | The calling script will be responsible for providing a method to compute metrics, as they are 434 | task-dependent (pass it to the init :obj:`compute_metrics` argument). 435 | 436 | You can also subclass and override this method to inject custom behavior. 437 | 438 | Args: 439 | eval_dataset (:obj:`Dataset`, `optional`): 440 | Pass a dataset if you wish to override :obj:`self.eval_dataset`. If it is an :obj:`datasets.Dataset`, 441 | columns not accepted by the ``model.forward()`` method are automatically removed. It must implement 442 | the :obj:`__len__` method. 443 | 444 | Returns: 445 | A dictionary containing the evaluation loss and the potential metrics computed from the predictions. 446 | """ 447 | if eval_dataset is not None and not isinstance(eval_dataset, collections.abc.Sized): 448 | raise ValueError("eval_dataset must implement __len__") 449 | 450 | eval_dataloader = self.get_eval_dataloader(eval_dataset) 451 | 452 | output = self.prediction_loop(eval_dataloader, description="Evaluation") 453 | 454 | self.log(output.metrics) 455 | 456 | if self.args.tpu_metrics_debug or self.args.debug: 457 | # tpu-comment: Logging debug metrics for PyTorch/XLA (compile, execute times, ops, etc.)
458 | xm.master_print(met.metrics_report()) 459 | 460 | return output 461 | -------------------------------------------------------------------------------- /tools/gather_result.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import numpy as np 4 | import torch 5 | from torch import device 6 | 7 | def main(): 8 | parser = argparse.ArgumentParser() 9 | parser.add_argument("--condition", type=str, help="A dictionary containing conditions that the experiment results need to fulfill (e.g., tag, task_name, few_shot_type)") 10 | 11 | # These options should be kept as their default values 12 | parser.add_argument("--log", type=str, default="log", help="Log path.") 13 | parser.add_argument("--key", type=str, default='', help="Validation metric name") 14 | parser.add_argument("--test_key", type=str, default="", help="Test metric name") 15 | parser.add_argument("--test_key2", type=str, default="", help="Second test metric name") 16 | 17 | args = parser.parse_args() 18 | 19 | condition = eval(args.condition) 20 | 21 | if len(args.key) == 0: 22 | if condition['task_name'] == 'cola': 23 | args.key = 'cola_dev_eval_mcc' 24 | args.test_key = 'cola_test_eval_mcc' 25 | elif condition['task_name'] == 'mrpc/acc': 26 | args.key = 'mrpc_dev_eval_acc' 27 | args.test_key = 'mrpc_test_eval_acc' 28 | args.test_key2 = 'mrpc_test_eval_f1' 29 | condition['task_name'] = 'mrpc' 30 | elif condition['task_name'] == 'mrpc/f1': 31 | args.key = 'mrpc_dev_eval_f1' 32 | args.test_key2 = 'mrpc_test_eval_acc' 33 | args.test_key = 'mrpc_test_eval_f1' 34 | condition['task_name'] = 'mrpc' 35 | elif condition['task_name'] == 'qqp/acc': 36 | args.key = 'qqp_dev_eval_acc' 37 | args.test_key = 'qqp_test_eval_acc' 38 | args.test_key2 = 'qqp_test_eval_f1' 39 | condition['task_name'] = 'qqp' 40 | elif condition['task_name'] == 'qqp/f1': 41 | args.key = 'qqp_dev_eval_f1' 42 | args.test_key2 = 'qqp_test_eval_acc' 43 | args.test_key = 'qqp_test_eval_f1' 44 | condition['task_name'] = 'qqp' 45 | elif condition['task_name'] == 'sts-b/pearson': 46 | args.key = 'sts-b_dev_eval_pearson' 47 | args.test_key = 'sts-b_test_eval_pearson' 48 | args.test_key2 = 'sts-b_test_eval_spearmanr' 49 | condition['task_name'] = 'sts-b' 50 | elif condition['task_name'] == 'sts-b/spearmanr': 51 | args.key = 'sts-b_dev_eval_spearmanr' 52 | args.test_key2 = 'sts-b_test_eval_pearson' 53 | args.test_key = 'sts-b_test_eval_spearmanr' 54 | condition['task_name'] = 'sts-b' 55 | elif condition['task_name'] == 'qnli': 56 | args.key = 'qnli_dev_eval_acc' 57 | args.test_key = 'qnli_test_eval_acc' 58 | elif condition['task_name'] == 'sst-2': 59 | args.key = 'sst-2_dev_eval_acc' 60 | args.test_key = 'sst-2_test_eval_acc' 61 | elif condition['task_name'] == 'snli': 62 | args.key = 'snli_dev_eval_acc' 63 | args.test_key = 'snli_test_eval_acc' 64 | elif condition['task_name'] == 'mnli': 65 | args.key = 'mnli_dev_eval_mnli/acc' 66 | args.test_key = 'mnli_test_eval_mnli/acc' 67 | elif condition['task_name'] == 'mnli-mm': 68 | condition['task_name'] = 'mnli' 69 | args.key = 'mnli_dev_eval_mnli/acc' 70 | args.test_key = 'mnli-mm_test_eval_mnli-mm/acc' 71 | elif condition['task_name'] == 'rte': 72 | args.key = 'rte_dev_eval_acc' 73 | args.test_key = 'rte_test_eval_acc' 74 | elif condition['task_name'] == 'ag_news': 75 | args.key = 'ag_news_dev_eval_acc' 76 | args.test_key = 'ag_news_test_eval_acc' 77 | elif condition['task_name'] == 'yahoo_answers': 78 | args.key = 'yahoo_answers_dev_eval_acc' 79 | args.test_key
= 'yahoo_answers_test_eval_acc' 80 | elif condition['task_name'] == 'yelp_review_full': 81 | args.key = 'yelp_review_full_dev_eval_acc' 82 | args.test_key = 'yelp_review_full_test_eval_acc' 83 | elif condition['task_name'] == 'mr': 84 | args.key = 'mr_dev_eval_acc' 85 | args.test_key = 'mr_test_eval_acc' 86 | elif condition['task_name'] == 'sst-5': 87 | args.key = 'sst-5_dev_eval_acc' 88 | args.test_key = 'sst-5_test_eval_acc' 89 | elif condition['task_name'] == 'subj': 90 | args.key = 'subj_dev_eval_acc' 91 | args.test_key = 'subj_test_eval_acc' 92 | elif condition['task_name'] == 'trec': 93 | args.key = 'trec_dev_eval_acc' 94 | args.test_key = 'trec_test_eval_acc' 95 | elif condition['task_name'] == 'cr': 96 | args.key = 'cr_dev_eval_acc' 97 | args.test_key = 'cr_test_eval_acc' 98 | elif condition['task_name'] == 'mpqa': 99 | args.key = 'mpqa_dev_eval_acc' 100 | args.test_key = 'mpqa_test_eval_acc' 101 | else: 102 | raise NotImplementedError 103 | 104 | with open(args.log) as f: 105 | result_list = [] 106 | for line in f: 107 | line = line.replace("<", "\"") 108 | line = line.replace(">", "\"") 109 | line = line.replace(" inf,", "float('inf'),") 110 | result_list.append(eval(line)) 111 | 112 | seed_result = {} 113 | 114 | for item in result_list: 115 | ok = True 116 | for cond in condition: 117 | if isinstance(condition[cond], list): 118 | if cond not in item or (item[cond] not in condition[cond]): 119 | ok = False 120 | break 121 | else: 122 | if cond not in item or (item[cond] != condition[cond]): 123 | ok = False 124 | break 125 | if ok and args.test_key in item and args.key in item: 126 | seed = item['data_dir'].split('-')[-1] + '-' + str(item['seed']) 127 | if seed not in seed_result: 128 | seed_result[seed] = [item] 129 | else: 130 | seed_result[seed].append(item) 131 | 132 | all_seed_result = seed_result 133 | all_tags = sorted(set(x['tag'] for x in sum(all_seed_result.values(), []))) 134 | all_k = sorted(set(x['num_k'] for x in sum(all_seed_result.values(), []))) 135 | 136 | for tag in all_tags: 137 | for k in all_k: 138 | print("Tag: {}, K: {}".format(tag, k)) 139 | seed_result_with_duplicates = { 140 | s: list(x for x in v if x['tag'] == tag and x['num_k'] == k) 141 | for s, v in all_seed_result.items() 142 | } 143 | seed_result = { 144 | s: list({x['output_dir']: x for x in v}.values()) 145 | for s, v in seed_result_with_duplicates.items() 146 | } 147 | seed_best = { 148 | k: max(sorted(v, key=lambda x: x['output_dir']), key=lambda x: x[args.key]) 149 | for k, v in seed_result.items() if v 150 | } 151 | 152 | final_result_dev = np.zeros((len(seed_best))) 153 | final_result_test = np.zeros((len(seed_best))) 154 | final_result_test2 = np.zeros((len(seed_best))) 155 | num_results = np.zeros((len(seed_best))) 156 | for i, seed in enumerate(seed_best): 157 | # for res in seed_result[seed]: 158 | # print(res) 159 | 160 | final_result_dev[i] = seed_best[seed][args.key] 161 | final_result_test[i] = seed_best[seed][args.test_key] 162 | num_results[i] = len(seed_result[seed]) 163 | if len(args.test_key2) > 0: 164 | final_result_test2[i] = seed_best[seed][args.test_key2] 165 | print("%s: best dev (%.5f) test (%.5f) %s | total trials: %d (ignored %d)" % ( 166 | seed, 167 | seed_best[seed][args.key], 168 | seed_best[seed][args.test_key], 169 | "test2 (%.5f)" % (seed_best[seed][args.test_key2]) if len(args.test_key2) > 0 else "", 170 | len(seed_result[seed]), 171 | len(seed_result_with_duplicates[seed]) - len(seed_result[seed]) 172 | )) 173 | s = '' 174 | hp_to_care_about = [ 175 | 
'per_device_train_batch_size', 176 | 'gradient_accumulation_steps', 177 | 'learning_rate', 178 | 'eval_steps', 179 | 'max_steps', 180 | 'f0_scaling', 181 | 'pre_projection_scale', 182 | 'kernel_regularization', 183 | 'kernel_gamma', 184 | 'kernel_lambda', 185 | 'output_dir', 186 | ] 187 | for k in hp_to_care_about: 188 | s += '| {}: {} '.format(k, seed_best[seed].get(k, "")) 189 | print(' ' + s) 190 | s = "mean +- std: " 191 | if len(final_result_test) > 0: 192 | s += "%.1f (%.1f) (#seeds %s) (#runs %s) (median %.1f)" % (final_result_test.mean() * 100, final_result_test.std() * 100, len(final_result_test), num_results.sum(), np.median(final_result_test) * 100,) 193 | if len(args.test_key2) > 0: 194 | s += " | second metric: %.1f (%.1f) (median %.1f)" % (final_result_test2.mean() * 100, final_result_test2.std() * 100, np.median(final_result_test2) * 100) 195 | print(s) 196 | print("") 197 | 198 | if __name__ == '__main__': 199 | main() 200 | -------------------------------------------------------------------------------- /tools/generate_k_shot_data.py: -------------------------------------------------------------------------------- 1 | """This script samples K examples randomly without replacement from the original data.""" 2 | 3 | import argparse 4 | import os 5 | import numpy as np 6 | import pandas as pd 7 | from pandas import DataFrame 8 | 9 | def get_label(task, line): 10 | if task in ["MNLI", "MRPC", "QNLI", "QQP", "RTE", "SNLI", "SST-2", "STS-B", "WNLI", "CoLA"]: 11 | # GLUE style 12 | line = line.strip().split('\t') 13 | if task == 'CoLA': 14 | return line[1] 15 | elif task == 'MNLI': 16 | return line[-1] 17 | elif task == 'MRPC': 18 | return line[0] 19 | elif task == 'QNLI': 20 | return line[-1] 21 | elif task == 'QQP': 22 | return line[-1] 23 | elif task == 'RTE': 24 | return line[-1] 25 | elif task == 'SNLI': 26 | return line[-1] 27 | elif task == 'SST-2': 28 | return line[-1] 29 | elif task == 'STS-B': 30 | return 0 if float(line[-1]) < 2.5 else 1 31 | elif task == 'WNLI': 32 | return line[-1] 33 | else: 34 | raise NotImplementedError 35 | else: 36 | return line[0] 37 | 38 | def load_datasets(data_dir, tasks): 39 | datasets = {} 40 | for task in tasks: 41 | if task in ["MNLI", "MRPC", "QNLI", "QQP", "RTE", "SNLI", "SST-2", "STS-B", "WNLI", "CoLA"]: 42 | # GLUE style (tsv) 43 | dataset = {} 44 | dirname = os.path.join(data_dir, task) 45 | if task == "MNLI": 46 | splits = ["train", "dev_matched", "dev_mismatched"] 47 | else: 48 | splits = ["train", "dev"] 49 | for split in splits: 50 | filename = os.path.join(dirname, f"{split}.tsv") 51 | with open(filename, "r") as f: 52 | lines = f.readlines() 53 | dataset[split] = lines 54 | datasets[task] = dataset 55 | else: 56 | # Other datasets (csv) 57 | dataset = {} 58 | dirname = os.path.join(data_dir, task) 59 | splits = ["train", "test"] 60 | for split in splits: 61 | filename = os.path.join(dirname, f"{split}.csv") 62 | dataset[split] = pd.read_csv(filename, header=None) 63 | datasets[task] = dataset 64 | return datasets 65 | 66 | def split_header(task, lines): 67 | """ 68 | Splits off the header line from the data lines, if the task file has one. Only for GLUE tasks.
69 | """ 70 | if task in ["CoLA"]: 71 | return [], lines 72 | elif task in ["MNLI", "MRPC", "QNLI", "QQP", "RTE", "SNLI", "SST-2", "STS-B", "WNLI"]: 73 | return lines[0:1], lines[1:] 74 | else: 75 | raise ValueError("Unknown GLUE task.") 76 | 77 | def main(): 78 | parser = argparse.ArgumentParser() 79 | parser.add_argument("--k", type=int, default=16, 80 | help="Training examples for each class.") 81 | parser.add_argument("--task", type=str, nargs="+", 82 | default=['SST-2', 'sst-5', 'mr', 'cr', 'mpqa', 'subj', 'trec', 'CoLA', 'MRPC', 'QQP', 'STS-B', 'MNLI', 'SNLI', 'QNLI', 'RTE'], 83 | help="Task names") 84 | parser.add_argument("--seed", type=int, nargs="+", 85 | default=[100, 13, 21, 42, 87], 86 | help="Random seeds") 87 | 88 | parser.add_argument("--data_dir", type=str, default="data/original", help="Path to original data") 89 | parser.add_argument("--output_dir", type=str, default="data", help="Output path") 90 | parser.add_argument("--mode", type=str, default='k-shot', choices=['k-shot', 'k-shot-10x', 'k-shot-1k-test'], help="k-shot or k-shot-10x (10x dev set)") 91 | 92 | args = parser.parse_args() 93 | args.output_dir = os.path.join(args.output_dir, args.mode) 94 | 95 | k = args.k 96 | print("K =", k) 97 | datasets = load_datasets(args.data_dir, args.task) 98 | 99 | for seed in args.seed: 100 | print("Seed = %d" % (seed)) 101 | for task, dataset in datasets.items(): 102 | # Set random seed 103 | np.random.seed(seed) 104 | 105 | # Shuffle the training set 106 | print("| Task = %s" % (task)) 107 | if task in ["MNLI", "MRPC", "QNLI", "QQP", "RTE", "SNLI", "SST-2", "STS-B", "WNLI", "CoLA"]: 108 | # GLUE style 109 | train_header, train_lines = split_header(task, dataset["train"]) 110 | np.random.shuffle(train_lines) 111 | else: 112 | # Other datasets 113 | train_lines = dataset['train'].values.tolist() 114 | np.random.shuffle(train_lines) 115 | 116 | # Set up dir 117 | task_dir = os.path.join(args.output_dir, task) 118 | setting_dir = os.path.join(task_dir, f"{k}-{seed}") 119 | os.makedirs(setting_dir, exist_ok=True) 120 | 121 | # Write test splits 122 | if task in ["MNLI", "MRPC", "QNLI", "QQP", "RTE", "SNLI", "SST-2", "STS-B", "WNLI", "CoLA"]: 123 | # GLUE style 124 | # Use the original development set as the test set (the original test sets are not publicly available) 125 | for split, lines in dataset.items(): 126 | if split.startswith("train"): 127 | continue 128 | split = split.replace('dev', 'test') 129 | 130 | test_header, test_lines = split_header(task, lines) 131 | if '1k-test' in args.mode and len(test_lines) > 1000: 132 | np.random.seed(42) 133 | np.random.shuffle(test_lines) 134 | test_lines = test_lines[:1000] 135 | with open(os.path.join(setting_dir, f"{split}.tsv"), "w") as f: 136 | for line in test_header: 137 | f.write(line) 138 | for line in test_lines: 139 | f.write(line) 140 | else: 141 | # Other datasets 142 | # Use the original test sets 143 | test_dataset = dataset['test'] 144 | if '1k-test' in args.mode and len(test_dataset.index) > 1000: 145 | test_dataset = test_dataset.sample(n=1000, random_state=42) 146 | test_dataset.to_csv(os.path.join(setting_dir, 'test.csv'), header=False, index=False) 147 | 148 | # Get label list for balanced sampling 149 | label_list = {} 150 | for line in train_lines: 151 | label = get_label(task, line) 152 | if label not in label_list: 153 | label_list[label] = [line] 154 | else: 155 | label_list[label].append(line) 156 | 157 | if task in ["MNLI", "MRPC", "QNLI", "QQP", "RTE", "SNLI", "SST-2", "STS-B", "WNLI", "CoLA"]: 158 | with 
open(os.path.join(setting_dir, "train.tsv"), "w") as f: 159 | for line in train_header: 160 | f.write(line) 161 | for label in label_list: 162 | for line in label_list[label][:k]: 163 | f.write(line) 164 | name = "dev.tsv" 165 | if task == 'MNLI': 166 | name = "dev_matched.tsv" 167 | with open(os.path.join(setting_dir, name), "w") as f: 168 | for line in train_header: 169 | f.write(line) 170 | for label in label_list: 171 | dev_rate = 11 if '10x' in args.mode else 2 172 | for line in label_list[label][k:k*dev_rate]: 173 | f.write(line) 174 | else: 175 | new_train = [] 176 | for label in label_list: 177 | for line in label_list[label][:k]: 178 | new_train.append(line) 179 | new_train = DataFrame(new_train) 180 | new_train.to_csv(os.path.join(setting_dir, 'train.csv'), header=False, index=False) 181 | 182 | new_dev = [] 183 | for label in label_list: 184 | dev_rate = 11 if '10x' in args.mode else 2 185 | for line in label_list[label][k:k*dev_rate]: 186 | new_dev.append(line) 187 | new_dev = DataFrame(new_dev) 188 | new_dev.to_csv(os.path.join(setting_dir, 'dev.csv'), header=False, index=False) 189 | 190 | 191 | if __name__ == "__main__": 192 | main() 193 | --------------------------------------------------------------------------------
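A quick usage sketch (editorial addition, not a file in the repository): with the default arguments, generate_k_shot_data.py writes balanced splits to `data/k-shot/<task>/<k>-<seed>/`. The snippet below assumes the defaults (`--mode k-shot`, `--k 16`) and the SST-2 task; `quoting=csv.QUOTE_NONE` roughly mirrors how transformers' `DataProcessor._read_tsv` reads GLUE files.

```python
# Generate the splits first, e.g.:
#   python tools/generate_k_shot_data.py --k 16 --task SST-2 --seed 42
import csv
import pandas as pd

split_dir = "data/k-shot/SST-2/16-42"  # <task>/<k>-<seed>, assuming default output_dir/mode
train = pd.read_csv(f"{split_dir}/train.tsv", sep="\t", quoting=csv.QUOTE_NONE)

# 16 examples per class x 2 classes = 32 rows (the original GLUE header line is preserved)
print(len(train), train["label"].value_counts().to_dict())
```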