├── LICENCE
├── README.md
├── cli.py
├── data
│   └── FewGLUE_dev32
│       └── readme.txt
├── genaug
│   ├── confidence_filter.py
│   ├── gen_aug_T5.py
│   ├── readme.txt
│   ├── total_gen_aug.py
│   └── utils.py
├── img
│   ├── model.png
│   └── readme.txt
├── log.py
├── modified_models
│   ├── modeling_deberta_v2.py
│   └── readme.txt
├── pet
│   ├── __init__.py
│   ├── evaluate_record.py
│   ├── modeling.py
│   ├── preprocessor.py
│   ├── pvp.py
│   ├── task_helpers.py
│   ├── tasks.py
│   ├── utils.py
│   └── wrapper.py
├── petal.py
├── requirements.txt
└── scripts
    ├── gen_augdata_commands.txt
    ├── run_deberta_pet.sh
    └── run_pet.sh

--------------------------------------------------------------------------------
/LICENCE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2023 zhouj8553

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# FlipDA

This repository contains the official code for FlipDA.

We provide an automatic data augmentation method that combines T5-based generation and self-training with label flipping. We evaluate it on FewGLUE and improve its performance.

## Model

![image](https://github.com/zhouj8553/FlipDA/blob/main-v2/img/model.png)

## Setup

Install and set up the environment with requirements.txt.
All our experiments are conducted on a single V100 GPU with 32GB of video memory.
The FewGLUE dataset is placed in the folder _data_ by default. The data can be downloaded from https://drive.google.com/file/d/1ibM0YikAza0R7v1HGDYIFO1WO-78432A/view?usp=sharing.

## Step1: Run Baseline

First, you should run the baseline. `<task_name>` could be "boolq", "rte", "cb", "copa", "wsc", "wic", and "multirc". `<gpu_id>` could be 0, 1, 2, ..., depending on which GPU you want to use.

```bash
bash scripts/run_pet.sh <task_name> <gpu_id> baseline
```

For example, to reproduce the baseline, you can run the commands as follows:

```bash
bash scripts/run_pet.sh boolq 0 baseline
bash scripts/run_pet.sh rte 1 baseline
bash scripts/run_pet.sh cb 2 baseline
bash scripts/run_pet.sh multirc 3 baseline
bash scripts/run_pet.sh copa 4 baseline
bash scripts/run_pet.sh wsc 5 baseline
bash scripts/run_pet.sh wic 6 baseline
bash scripts/run_pet.sh record 7 baseline
```

If you run the command and shell with the defaults, the results will be in _results/baseline/pet/\<task_name\>\_albert\_model/result_test.txt_.
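Before moving on, you can sanity-check the data setup from a Python shell. This is a minimal sketch, assuming FewGLUE was unpacked to _data/FewGLUE_dev32_ with one folder per task (adjust the path to whatever _scripts/run_pet.sh_ actually passes as the data dir):

```python
from pet.tasks import PROCESSORS, load_examples, TRAIN_SET

# 'rte' is one of the task names accepted by cli.py / run_pet.sh
processor = PROCESSORS['rte']()
print(processor.get_labels())  # the label list used for verbalization

# num_examples=-1 loads all examples (cf. the --train_examples flag in cli.py);
# the data path below is an assumption about how FewGLUE was unpacked
train_data = load_examples('rte', 'data/FewGLUE_dev32/RTE', TRAIN_SET, num_examples=-1)
print(len(train_data))  # FewGLUE provides 32 labeled training examples per task
```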

## Step2: Produce augmented files

The code to generate augmented examples with the T5 model is in _genaug/total_gen_aug.py_.

You could use the command as follows to generate augmented examples. `<task_name>` could be "BoolQ", "RTE", "CB", "COPA", "WSC", "WiC", "MultiRC", and "ReCoRD". `<mask_ratio>` could be an arbitrary floating-point number between 0 and 1; in our experiments, we only tried 0.3, 0.5, and 0.8. `<aug_type>` could be "default" or "rand_iter_%d", where %d could be any integer. `<label_type>` could be "flip" or "keep". "do_sample" and "num_beams" control the generation style (sample/greedy/beam search). `<aug_num>` denotes the number of augmented samples to be generated for each sample; in our experiments, we choose 10.

```bash
CUDA_VISIBLE_DEVICES=0 python -m genaug.total_gen_aug --task_name <task_name> --mask_ratio <mask_ratio> --aug_type <aug_type> --label_type <label_type> --do_sample <do_sample> --num_beams <num_beams> --aug_num <aug_num>
```

For example, to generate the augmented data we use in RTE, you could use the following commands.

```bash
CUDA_VISIBLE_DEVICES=0 python -m genaug.total_gen_aug --task_name RTE --mask_ratio 0.5 --aug_type 'default' --label_type 'flip' --do_sample --num_beams 1 --aug_num 10
CUDA_VISIBLE_DEVICES=1 python -m genaug.total_gen_aug --task_name RTE --mask_ratio 0.5 --aug_type 'default' --label_type 'keep' --do_sample --num_beams 1 --aug_num 10
```

## Alternative: Run baselines with augmented files without classifier

If you want to add all the augmented data in the augmented files to the model (i.e., you do not want to change the labels according to the trained classifier or filter the augmented examples), you could run the command as follows, where `<augmented_file_name>` is the augmented file name, such as "t5_flip_0.5_default_sample0_beam1_augnum10".

```bash
bash scripts/run_pet.sh boolq 0 <augmented_file_name>
```

## Step3: Run FlipDA with augmented files

If the `<augmented_file_name>` has a corresponding version with the other label type, we will load it as well. For example, if the filename is "t5_flip_0.5_default_sample0_beam1_augnum10" and we find "t5_keep_0.5_default_sample0_beam1_augnum10", we will load them both.

If you allow the classifier to correct the labels of the augmented data, run the command as follows, where `<augmented_file_name>` is the augmented file name, such as "t5_flip_0.5_default_sample0_beam1_augnum10".

```bash
bash scripts/run_pet.sh boolq 0 genaug_<augmented_file_name>_filter_max_eachla
```

If you do not allow the classifier to correct the labels of the augmented data, run the command as follows, with the same `<augmented_file_name>` convention.

```bash
bash scripts/run_pet.sh boolq 0 genaug_<augmented_file_name>_filter_max_eachla_sep
```

Which command to choose depends on the relative strength of the augmentation model and the classification model: if the augmentation model is accurate enough, the command with "sep" will work better; otherwise, choose the first one. If you are not sure, just try them both.
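Internally, the part after `_filter_` selects a `recover_type` in _genaug/confidence_filter.py_, which re-scores the augmented examples with the baseline checkpoint from Step 1. The sketch below mirrors how _cli.py_ drives this code path; the checkpoint path, the augmented-file path, and the `EvalConfig` arguments are illustrative assumptions:

```python
import pet
from pet.tasks import PROCESSORS
from genaug import confidence_filter

# Assumed: a baseline checkpoint from Step 1 (pattern 0, iteration 0)
ckpt_dir = 'results/baseline/pet/rte_albert_model/p0-i0'
myfilter = confidence_filter.Confidence_Filter(pattern_iter_output_dir=ckpt_dir)

# Assumed location of an augmented file produced in Step 2
examples = PROCESSORS['rte']()._create_examples(
    'data/augmented/RTE/t5_flip_0.5_default_sample1_beam1_augnum10_train.jsonl', 'train')

# Assumed evaluation settings; cli.py builds this from its own arguments
eval_config = pet.EvalConfig(device='cuda', n_gpu=1, metrics=['acc'],
                             per_gpu_eval_batch_size=8)

# 'max_eachla' lets the classifier pick (and possibly flip) each label;
# 'max_eachla_sep' keeps the labels produced by T5 and only filters
new_examples, filtered_num = myfilter.recover_labels(
    myfilter.wrapper, examples, eval_config, recover_type='max_eachla')
myfilter.del_finetuned_model()  # free GPU memory before PET training starts
```

The returned `filtered_num` dictionary counts label transitions of the form `'<orig_label> -> <new_label>'`, and is what later gets written to the result CSVs.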

To reproduce our results, the commands for ALBERT-xxlarge-v2 with FlipDA are:

```bash
bash scripts/run_pet.sh rte 0 genaug_t5_flip_0.5_default_sample1_beam1_augnum10_filter_max_eachla_sep
bash scripts/run_pet.sh boolq 0 genaug_t5_flip_0.3_default_sample1_beam1_augnum10_filter_max_eachla
bash scripts/run_pet.sh cb 0 genaug_t5_flip_0.5_default_sample1_beam1_augnum10_filter_max_eachla
bash scripts/run_pet.sh copa 0 genaug_t5_flip_0.8_default_sample0_beam10_augnum10_filter_max_eachla_sep
bash scripts/run_pet.sh wic 0 genaug_t5_flip_0.8_default_sample1_beam1_augnum10_filter_max_eachla_sep
bash scripts/run_pet.sh wsc 0 genaug_t5_keep_0.3_default_sample0_beam1_augnum10wscaugtype_extra_filter_max_prevla
bash scripts/run_pet.sh multirc 0 genaug_t5_flip_0.5_rand_iter_10_sample1_beam1_augnum10_filter_max_eachla_sep
bash scripts/run_pet.sh record 0 genaug_t5_flip_0.3_rand_iter_10_sample0_beam10_augnum10_filter_max_eachla
```

Note that we did not try all hyperparameter combinations, both to save time and to avoid overfitting; our main contribution is showing that label flipping is useful for few-shot data augmentation. We would not be surprised if you got a better result with more careful implementation and hyperparameter selection.

## Citation

This paper will appear at ACL 2022 (Main Conference). Please cite us if FlipDA is useful in your work:

```
@inproceedings{DBLP:conf/acl/ZhouZTJY22,
  author    = {Jing Zhou and
               Yanan Zheng and
               Jie Tang and
               Li Jian and
               Zhilin Yang},
  title     = {FlipDA: Effective and Robust Data Augmentation for Few-Shot Learning},
  booktitle = {{ACL} {(1)}},
  pages     = {8646--8665},
  publisher = {Association for Computational Linguistics},
  year      = {2022}
}
```

--------------------------------------------------------------------------------
/cli.py:
--------------------------------------------------------------------------------
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
This script can be used to train and evaluate either a regular supervised model or a PET/iPET model on
one of the supported tasks and datasets.
"""

import argparse
import copy
import os
from typing import Tuple
import shutil
import torch
import ast

from pet.tasks import PROCESSORS, load_examples, UNLABELED_SET, TRAIN_SET, DEV_SET, TEST_SET, METRICS, DEFAULT_METRICS
from pet.utils import eq_div, set_seed
from pet.wrapper import WRAPPER_TYPES, MODEL_CLASSES, SEQUENCE_CLASSIFIER_WRAPPER, WrapperConfig
import pet
import log

logger = log.get_logger('root')

import numpy as np

def load_pet_configs(args) -> Tuple[WrapperConfig, pet.TrainConfig, pet.EvalConfig]:
    """
    Load the model, training and evaluation configs for PET from the given command line arguments.
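    Returns a tuple of (model config, training config, evaluation config); in main() these are passed on to pet.train_pet.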
39 | """ 40 | model_cfg = WrapperConfig(model_type=args.model_type, model_name_or_path=args.model_name_or_path, 41 | wrapper_type=args.wrapper_type, task_name=args.task_name, label_list=args.label_list, 42 | max_seq_length=args.pet_max_seq_length, verbalizer_file=args.verbalizer_file, 43 | cache_dir=args.cache_dir, 44 | use_noisy_student=args.use_noisy_student, drop_prob=args.drop_prob, 45 | fix_deberta=args.fix_deberta, 46 | mixup=args.mixup,mixup_alpha=args.mixup_alpha) 47 | 48 | train_cfg = pet.TrainConfig(device=args.device, per_gpu_train_batch_size=args.pet_per_gpu_train_batch_size, 49 | per_gpu_unlabeled_batch_size=args.pet_per_gpu_unlabeled_batch_size, n_gpu=args.n_gpu, 50 | num_train_epochs=args.pet_num_train_epochs, max_steps=args.pet_max_steps, 51 | gradient_accumulation_steps=args.pet_gradient_accumulation_steps, 52 | weight_decay=args.weight_decay, learning_rate=args.learning_rate, 53 | adam_epsilon=args.adam_epsilon, warmup_steps=args.warmup_steps, 54 | max_grad_norm=args.max_grad_norm, lm_training=args.lm_training, alpha=args.alpha) 55 | 56 | eval_cfg = pet.EvalConfig(device=args.device, n_gpu=args.n_gpu, metrics=args.metrics, 57 | per_gpu_eval_batch_size=args.pet_per_gpu_eval_batch_size, 58 | decoding_strategy=args.decoding_strategy, priming=args.priming) 59 | 60 | return model_cfg, train_cfg, eval_cfg 61 | 62 | 63 | def load_sequence_classifier_configs(args) -> Tuple[WrapperConfig, pet.TrainConfig, pet.EvalConfig]: 64 | """ 65 | Load the model, training and evaluation configs for a regular sequence classifier from the given command line 66 | arguments. This classifier can either be used as a standalone model or as the final classifier for PET/iPET. 67 | """ 68 | model_cfg = WrapperConfig(model_type=args.model_type, model_name_or_path=args.model_name_or_path, 69 | wrapper_type=SEQUENCE_CLASSIFIER_WRAPPER, task_name=args.task_name, 70 | label_list=args.label_list, max_seq_length=args.sc_max_seq_length, 71 | verbalizer_file=args.verbalizer_file, cache_dir=args.cache_dir, 72 | use_noisy_student=args.use_noisy_student, drop_prob=args.drop_prob, 73 | fix_deberta=args.fix_deberta) 74 | 75 | train_cfg = pet.TrainConfig(device=args.device, per_gpu_train_batch_size=args.sc_per_gpu_train_batch_size, 76 | per_gpu_unlabeled_batch_size=args.sc_per_gpu_unlabeled_batch_size, n_gpu=args.n_gpu, 77 | num_train_epochs=args.sc_num_train_epochs, max_steps=args.sc_max_steps, 78 | temperature=args.temperature, 79 | gradient_accumulation_steps=args.sc_gradient_accumulation_steps, 80 | weight_decay=args.weight_decay, learning_rate=args.learning_rate, 81 | adam_epsilon=args.adam_epsilon, warmup_steps=args.warmup_steps, 82 | max_grad_norm=args.max_grad_norm, use_logits=args.method != 'sequence_classifier') 83 | 84 | eval_cfg = pet.EvalConfig(device=args.device, n_gpu=args.n_gpu, metrics=args.metrics, 85 | per_gpu_eval_batch_size=args.sc_per_gpu_eval_batch_size) 86 | 87 | return model_cfg, train_cfg, eval_cfg 88 | 89 | 90 | def load_ipet_config(args) -> pet.IPetConfig: 91 | """ 92 | Load the iPET config from the given command line arguments. 
93 | """ 94 | ipet_cfg = pet.IPetConfig(generations=args.ipet_generations, logits_percentage=args.ipet_logits_percentage, 95 | scale_factor=args.ipet_scale_factor, n_most_likely=args.ipet_n_most_likely) 96 | return ipet_cfg 97 | 98 | 99 | def main(): 100 | parser = argparse.ArgumentParser(description="Command line interface for PET/iPET") 101 | 102 | # Required parameters 103 | parser.add_argument("--method", required=True, choices=['pet', 'ipet', 'sequence_classifier', 'noisy_student'], 104 | help="The training method to use. Either regular sequence classification, PET or iPET.") 105 | parser.add_argument("--data_dir", default=None, type=str, required=True, 106 | help="The input data dir. Should contain the data files for the task.") 107 | parser.add_argument("--model_type", default=None, type=str, required=True, choices=MODEL_CLASSES.keys(), 108 | help="The type of the pretrained language model to use") 109 | parser.add_argument("--model_name_or_path", default=None, type=str, required=True, 110 | help="Path to the pre-trained model or shortcut name") 111 | parser.add_argument("--task_name", default=None, type=str, required=True, choices=PROCESSORS.keys(), 112 | help="The name of the task to train/evaluate on") 113 | parser.add_argument("--output_dir", default=None, type=str, required=True, 114 | help="The output directory where the model predictions and checkpoints will be written") 115 | 116 | # PET-specific optional parameters 117 | parser.add_argument("--wrapper_type", default="mlm", choices=WRAPPER_TYPES, 118 | help="The wrapper type. Set this to 'mlm' for a masked language model like BERT or to 'plm' " 119 | "for a permuted language model like XLNet (only for PET)") 120 | parser.add_argument("--pattern_ids", default=[0], type=int, nargs='+', 121 | help="The ids of the PVPs to be used (only for PET)") 122 | parser.add_argument("--lm_training", action='store_true', 123 | help="Whether to use language modeling as auxiliary task (only for PET)") 124 | parser.add_argument("--alpha", default=0.9999, type=float, 125 | help="Weighting term for the auxiliary language modeling task (only for PET)") 126 | parser.add_argument("--temperature", default=2, type=float, 127 | help="Temperature used for combining PVPs (only for PET)") 128 | parser.add_argument("--verbalizer_file", default=None, 129 | help="The path to a file to override default verbalizers (only for PET)") 130 | parser.add_argument("--reduction", default='mean', choices=['wmean', 'mean'], 131 | help="Reduction strategy for merging predictions from multiple PET models. Select either " 132 | "uniform weighting (mean) or weighting based on train set accuracy (wmean)") 133 | parser.add_argument("--decoding_strategy", default='default', choices=['default', 'ltr', 'parallel'], 134 | help="The decoding strategy for PET with multiple masks (only for PET)") 135 | parser.add_argument("--no_distillation", action='store_true', 136 | help="If set to true, no distillation is performed (only for PET)") 137 | parser.add_argument("--pet_repetitions", default=3, type=int, 138 | help="The number of times to repeat PET training and testing with different seeds.") 139 | parser.add_argument("--pet_max_seq_length", default=256, type=int, 140 | help="The maximum total input sequence length after tokenization for PET. 
Sequences longer " 141 | "than this will be truncated, sequences shorter will be padded.") 142 | parser.add_argument("--pet_per_gpu_train_batch_size", default=4, type=int, 143 | help="Batch size per GPU/CPU for PET training.") 144 | parser.add_argument("--pet_per_gpu_eval_batch_size", default=8, type=int, 145 | help="Batch size per GPU/CPU for PET evaluation.") 146 | parser.add_argument("--pet_per_gpu_unlabeled_batch_size", default=4, type=int, 147 | help="Batch size per GPU/CPU for auxiliary language modeling examples in PET.") 148 | parser.add_argument('--pet_gradient_accumulation_steps', type=int, default=1, 149 | help="Number of updates steps to accumulate before performing a backward/update pass in PET.") 150 | parser.add_argument("--pet_num_train_epochs", default=3, type=float, 151 | help="Total number of training epochs to perform in PET.") 152 | parser.add_argument("--pet_max_steps", default=-1, type=int, 153 | help="If > 0: set total number of training steps to perform in PET. Override num_train_epochs.") 154 | 155 | # SequenceClassifier-specific optional parameters (also used for the final PET classifier) 156 | parser.add_argument("--sc_repetitions", default=1, type=int, 157 | help="The number of times to repeat seq. classifier training and testing with different seeds.") 158 | parser.add_argument("--sc_max_seq_length", default=256, type=int, 159 | help="The maximum total input sequence length after tokenization for sequence classification. " 160 | "Sequences longer than this will be truncated, sequences shorter will be padded.") 161 | parser.add_argument("--sc_per_gpu_train_batch_size", default=4, type=int, 162 | help="Batch size per GPU/CPU for sequence classifier training.") 163 | parser.add_argument("--sc_per_gpu_eval_batch_size", default=8, type=int, 164 | help="Batch size per GPU/CPU for sequence classifier evaluation.") 165 | parser.add_argument("--sc_per_gpu_unlabeled_batch_size", default=4, type=int, 166 | help="Batch size per GPU/CPU for unlabeled examples used for distillation.") 167 | parser.add_argument('--sc_gradient_accumulation_steps', type=int, default=1, 168 | help="Number of updates steps to accumulate before performing a backward/update pass for " 169 | "sequence classifier training.") 170 | parser.add_argument("--sc_num_train_epochs", default=3, type=float, 171 | help="Total number of training epochs to perform for sequence classifier training.") 172 | parser.add_argument("--sc_max_steps", default=-1, type=int, 173 | help="If > 0: set total number of training steps to perform for sequence classifier training. 
" 174 | "Override num_train_epochs.") 175 | 176 | # iPET-specific optional parameters 177 | parser.add_argument("--ipet_generations", default=3, type=int, 178 | help="The number of generations to train (only for iPET)") 179 | parser.add_argument("--ipet_logits_percentage", default=0.25, type=float, 180 | help="The percentage of models to choose for annotating new training sets (only for iPET)") 181 | parser.add_argument("--ipet_scale_factor", default=5, type=float, 182 | help="The factor by which to increase the training set size per generation (only for iPET)") 183 | parser.add_argument("--ipet_n_most_likely", default=-1, type=int, 184 | help="If >0, in the first generation the n_most_likely examples per label are chosen even " 185 | "if their predicted label is different (only for iPET)") 186 | 187 | # Other optional parameters 188 | parser.add_argument("--train_examples", default=-1, type=int, 189 | help="The total number of train examples to use, where -1 equals all examples.") 190 | parser.add_argument("--test_examples", default=-1, type=int, 191 | help="The total number of test examples to use, where -1 equals all examples.") 192 | parser.add_argument("--unlabeled_examples", default=-1, type=int, 193 | help="The total number of unlabeled examples to use, where -1 equals all examples") 194 | parser.add_argument("--split_examples_evenly", action='store_true', 195 | help="If true, train examples are not chosen randomly, but split evenly across all labels.") 196 | parser.add_argument("--cache_dir", default="", type=str, 197 | help="Where to store the pre-trained models downloaded from S3.") 198 | parser.add_argument("--learning_rate", default=1e-5, type=float, 199 | help="The initial learning rate for Adam.") 200 | parser.add_argument("--weight_decay", default=0.01, type=float, 201 | help="Weight decay if we apply some.") 202 | parser.add_argument("--adam_epsilon", default=1e-8, type=float, 203 | help="Epsilon for Adam optimizer.") 204 | parser.add_argument("--max_grad_norm", default=1.0, type=float, 205 | help="Max gradient norm.") 206 | parser.add_argument("--warmup_steps", default=0, type=int, 207 | help="Linear warmup over warmup_steps.") 208 | parser.add_argument('--logging_steps', type=int, default=50, 209 | help="Log every X updates steps.") 210 | parser.add_argument("--no_cuda", action='store_true', 211 | help="Avoid using CUDA when available") 212 | parser.add_argument('--overwrite_output_dir', action='store_true', 213 | help="Overwrite the content of the output directory") 214 | parser.add_argument('--seed', type=int, default=42, 215 | help="random seed for initialization") 216 | parser.add_argument('--do_train', action='store_true', 217 | help="Whether to perform training") 218 | parser.add_argument('--do_eval', action='store_true', 219 | help="Whether to perform evaluation") 220 | parser.add_argument('--priming', action='store_true', 221 | help="Whether to use priming for evaluation") 222 | parser.add_argument("--eval_set", choices=['dev', 'test'], default='dev', 223 | help="Whether to perform evaluation on the dev set or the test set") 224 | 225 | parser.add_argument("--search_type",default='',type=str) 226 | parser.add_argument("--aug_ids", default=[0,2], type=int, nargs='+',) 227 | parser.add_argument("--filter_pattern",default=-1,type=int) 228 | parser.add_argument("--fixla_ratio",default='[[-1,-1],[-1,-1]]',type=str) 229 | parser.add_argument("--fixla_num",default='[[14,14],[18,18]]',type=str) 230 | parser.add_argument("--rmdup_num",default=1,type=int) 231 | 232 | 233 | 
# TODO: noisystudent 234 | parser.add_argument("--use_noisy_student", action="store_true", help="Whether to use noisy student.") 235 | parser.add_argument("--drop_prob", default=1.0, type=float, help="Dropout probability for noising input data.") 236 | parser.add_argument("--t5_augment_file_path", help="t5_augment data as unlabeled data for noisy student.") 237 | # t5_flip_0.5_rand_iter_10_sample1_beam1_augnum10_train.jsonl 238 | 239 | parser.add_argument("--sampler_seeds", type=list, default=[10, 20, 30]) 240 | parser.add_argument("--fix_deberta", action="store_true") 241 | 242 | 243 | args = parser.parse_args() 244 | logger.info("Parameters: {}".format(args)) 245 | 246 | assert not ('flip' in args.search_type and 'max_prevla' in args.search_type) and not ('keep' in args.search_type and 'max_otherla' in args.search_type) 247 | set_seed(args.seed) 248 | if 'topk' in args.search_type: 249 | args.output_dir=os.path.join(args.output_dir,args.fixla_num) 250 | if 'topp' in args.search_type: 251 | args.output_dir=os.path.join(args.output_dir,args.fixla_ratio) 252 | if 'rmdup' in args.search_type: 253 | args.output_dir=os.path.join(args.output_dir,'rmdup{}'.format(args.rmdup_num)) 254 | 255 | if os.path.exists(args.output_dir) and os.listdir(args.output_dir) \ 256 | and args.do_train and not args.overwrite_output_dir: 257 | raise ValueError("Output directory ({}) already exists and is not empty.".format(args.output_dir)) 258 | # shutil.rmtree(args.output_dir) 259 | 260 | 261 | if args.search_type.startswith('mixup'): 262 | args.mixup=True 263 | if '_' in args.search_type: 264 | args.mixup_alpha=float(args.search_type.split('_')[1]) 265 | else: 266 | args.mixup_alpha=0.5 267 | else: 268 | args.mixup=False; args.mixup_alpha=-1 269 | 270 | # Setup CUDA, GPU & distributed training 271 | args.device = "cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu" 272 | args.n_gpu = torch.cuda.device_count() 273 | 274 | # Prepare task 275 | args.task_name = args.task_name.lower() 276 | if args.task_name not in PROCESSORS: 277 | raise ValueError("Task '{}' not found".format(args.task_name)) 278 | processor = PROCESSORS[args.task_name]() 279 | args.label_list = processor.get_labels() 280 | 281 | train_ex_per_label, test_ex_per_label = None, None 282 | train_ex, test_ex = args.train_examples, args.test_examples 283 | if args.split_examples_evenly: 284 | train_ex_per_label = eq_div(args.train_examples, len(args.label_list)) if args.train_examples != -1 else -1 285 | test_ex_per_label = eq_div(args.test_examples, len(args.label_list)) if args.test_examples != -1 else -1 286 | train_ex, test_ex = None, None 287 | 288 | eval_set = TEST_SET if args.eval_set == 'test' else DEV_SET 289 | 290 | train_data = load_examples( 291 | args.task_name, args.data_dir, TRAIN_SET, num_examples=train_ex, num_examples_per_label=train_ex_per_label) 292 | eval_data = load_examples( 293 | args.task_name, args.data_dir, eval_set, num_examples=test_ex, num_examples_per_label=test_ex_per_label) 294 | 295 | # TODO: noisy student 296 | 297 | if args.use_noisy_student: 298 | taskname_map = {'copa': 'COPA', 'rte': 'RTE', 'boolq': "BoolQ", 'multirc': "MultiRC", 'wic': 'WiC', 'wsc': "WSC", 'cb': 'CB'} 299 | t5_augment_file_paths=eval(args.t5_augment_file_path) 300 | unlabeled_data=[] 301 | for t5_augment_file_path in t5_augment_file_paths: 302 | unlabeled_data += processor._create_examples( 303 | os.path.join('/'.join(args.data_dir.split('/')[:-1]), "augmented/{}/{}".format(taskname_map[args.task_name], 
t5_augment_file_path+'_train.jsonl')), "t5_augment_as_unlabeled") 304 | logger.info("t5_augment_as_unlabeled num_examples: {}".format(len(unlabeled_data))) 305 | 306 | else: 307 | unlabeled_data = load_examples( 308 | args.task_name, args.data_dir, UNLABELED_SET, num_examples=args.unlabeled_examples) 309 | 310 | args.metrics = METRICS.get(args.task_name, DEFAULT_METRICS) 311 | 312 | pet_model_cfg, pet_train_cfg, pet_eval_cfg = load_pet_configs(args) 313 | sc_model_cfg, sc_train_cfg, sc_eval_cfg = load_sequence_classifier_configs(args) 314 | ipet_cfg = load_ipet_config(args) 315 | 316 | 317 | if args.search_type=='baseline' or args.search_type.startswith('mixup'): 318 | pass 319 | elif args.search_type.startswith('genaug'): 320 | taskname_map={'copa':'COPA','rte':'RTE','boolq':"BoolQ",'multirc':"MultiRC",'wic':'WiC','wsc':"WSC",'cb':'CB'} 321 | new_examples=processor._create_examples(os.path.join('/'.join(args.data_dir.split('/')[:-1]), "augmented/{}/{}_train.jsonl".format(taskname_map[args.task_name],'_'.join(''.join(args.search_type.split('_filter_')[0]).split('_')[1:]))),"train") 322 | else: 323 | taskname_map={'rte':'RTE','boolq':'BoolQ','cb':'CB','multirc':'MultiRC','copa':'COPA','wic':'WiC','wsc':'WSC'} 324 | new_examples=processor._create_examples(os.path.join('/'.join(args.data_dir.split('/')[:-1]), "augmented/{}/{}_train.jsonl".format(taskname_map[args.task_name],args.search_type.split('_filter_')[0])), "train") 325 | 326 | if 'filter' in args.search_type: 327 | pattern_iter_output_dir='results/baseline/pet/{}_{}_model/'.format(args.task_name.lower(), args.model_type) 328 | from genaug import confidence_filter 329 | # if args.filter_pattern==-1: 330 | # # best_pattern_map={'boolq':0,'rte':3,'cb':3,'multirc':1,'wic':1,'wsc':2,'copa':0} 331 | # # best_iter_map={'boolq':0,'rte':1,'cb':1,'multirc':2,'wic':2,'wsc':2,'copa':0} 332 | # best_pattern_map={'boolq':1,'rte':3,'cb':2,'multirc':0,'wic':1,'wsc':2,'copa':1} 333 | # best_iter_map={'boolq':2,'rte':1,'cb':2,'multirc':0,'wic':2,'wsc':0,'copa':2} 334 | # args.filter_pattern=best_pattern_map[args.task_name.lower()] 335 | # args.best_iter_pattern=best_iter_map[args.task_name.lower()] 336 | if args.filter_pattern==-1: 337 | subdirs=next(os.walk(pattern_iter_output_dir))[1] 338 | best_score=-100 339 | for subdir in subdirs: 340 | results_file=os.path.join(pattern_iter_output_dir,subdir,'results.json') 341 | with open(results_file,'r') as fh: 342 | results=ast.literal_eval(fh.read().lower().replace('nan','100')) 343 | score=np.mean([y for (x,y) in results['test_set_after_training'].items()]) 344 | if score>best_score: 345 | best_score=score 346 | new_pattern_iter_output_dir=os.path.join(pattern_iter_output_dir,subdir) 347 | pattern_iter_output_dir=new_pattern_iter_output_dir 348 | else: 349 | pattern_iter_output_dir=os.path.join(pattern_iter_output_dir,'p{}-i{}'.format(args.filter_pattern,0)) 350 | myfilter=confidence_filter.Confidence_Filter(pattern_iter_output_dir=pattern_iter_output_dir) 351 | if 'flip' in args.search_type and 'max_otherla' not in args.search_type and args.task_name.lower()!='wsc': 352 | keep_path=os.path.join('/'.join(args.data_dir.split('/')[:-1]), "augmented/{}/{}_train.jsonl".format(taskname_map[args.task_name],'_'.join(''.join(args.search_type.replace('flip','keep').split('_filter_')[0]).split('_')[1:]))) 353 | if os.path.exists(keep_path)==True: 354 | keep_examples=processor._create_examples(keep_path,"train") 355 | examples=new_examples+keep_examples 356 | else: 357 | examples=new_examples 358 | elif 'keep' in 
args.search_type and 'max_prevla' not in args.search_type and args.task_name.lower()!='wsc': 359 | keep_path=os.path.join('/'.join(args.data_dir.split('/')[:-1]), "augmented/{}/{}_train.jsonl".format(taskname_map[args.task_name],'_'.join(''.join(args.search_type.replace('keep','flip').split('_filter_')[0]).split('_')[1:]))) 360 | if os.path.exists(keep_path)==True: 361 | keep_examples=processor._create_examples(keep_path,"train") 362 | examples=new_examples+keep_examples 363 | else: 364 | examples=new_examples 365 | else: 366 | examples=new_examples 367 | 368 | new_examples,filtered_num=myfilter.recover_labels(myfilter.wrapper,examples,pet_eval_cfg,recover_type=args.search_type.split('_filter_')[1],fixla_ratio=eval(args.fixla_ratio),fixla_num=eval(args.fixla_num),rmdup_num=args.rmdup_num) 369 | myfilter.del_finetuned_model() 370 | 371 | if args.search_type=='baseline' or args.search_type.startswith('mixup'): 372 | pass 373 | elif (args.search_type.startswith('eda') or args.search_type.startswith('bt')) and "filter" not in args.search_type: 374 | train_data = train_data + new_examples 375 | else: 376 | if 'max_eachla' in args.search_type: 377 | train_data=train_data+new_examples 378 | else: 379 | train_data=train_data*max(1,int(len(new_examples)//len(train_data)))+new_examples 380 | 381 | if args.drop_prob!=1: 382 | pet_model_cfg.use_noisy_student=True 383 | 384 | if args.method == 'pet': 385 | results=pet.train_pet(pet_model_cfg, pet_train_cfg, pet_eval_cfg, sc_model_cfg, sc_train_cfg, sc_eval_cfg, 386 | pattern_ids=args.pattern_ids, output_dir=args.output_dir, 387 | ensemble_repetitions=args.pet_repetitions, final_repetitions=args.sc_repetitions, 388 | reduction=args.reduction, train_data=train_data, unlabeled_data=unlabeled_data, 389 | eval_data=eval_data, do_train=args.do_train, do_eval=args.do_eval, 390 | no_distillation=args.no_distillation, seed=args.seed, sampler_seeds=args.sampler_seeds) 391 | 392 | mean_result={} 393 | for result in results: 394 | for metric,res in result.items(): 395 | if metric not in mean_result: mean_result[metric]=[] 396 | mean_result[metric].append(res) 397 | for (x,y) in mean_result.items(): 398 | mean_result[x]=np.mean(y) 399 | 400 | template_name=['task_name','search_type'] 401 | template_values=[args.task_name,args.search_type] 402 | if 'global_topk' in args.search_type: 403 | template_name.append('fixla_num') 404 | template_values.append(args.fixla_num) 405 | elif 'global_topp' in args.search_type: 406 | template_name.append('fixla_ratio') 407 | template_values.append(args.fixla_ratio) 408 | if 'rmdup' in args.search_type: 409 | template_name.append('rmdup_num') 410 | template_values.append(args.rmdup_num) 411 | if 'filter' in args.search_type: 412 | template_name.append('filtered_num') 413 | template_values.append(filtered_num) 414 | template_name.append('result') 415 | template_name.append('mean_result') 416 | if args.pet_repetitions!=1: 417 | if args.search_type.startswith('genaug'): 418 | writer=open(os.path.join('results/','pet_total_genaug_rep{}_{}.csv'.format(args.pet_repetitions,args.task_name)),'a+') 419 | else: 420 | writer=open(os.path.join('results/','pet_total_rep{}_{}.csv'.format(args.pet_repetitions,args.task_name)),'a+') 421 | else: 422 | if args.search_type.startswith('genaug'): 423 | writer=open(os.path.join('results/','pet_total_genaug_{}.csv'.format(args.task_name)),'a+') 424 | else: 425 | writer=open(os.path.join('results/','pet_total_{}.csv'.format(args.task_name)),'a+') 426 | writer.write((': {}, '.join(template_name)+': 
{}\n').format(*template_values+[results]+[mean_result])) 427 | writer.close() 428 | 429 | elif args.method.startswith('noisy_student'): 430 | results=pet.train_noisy_student(pet_model_cfg, pet_train_cfg, pet_eval_cfg, ipet_cfg, sc_model_cfg, sc_train_cfg, sc_eval_cfg, 431 | pattern_ids=args.pattern_ids, output_dir=args.output_dir, 432 | ensemble_repetitions=args.pet_repetitions, final_repetitions=args.sc_repetitions, 433 | reduction=args.reduction, train_data=train_data, unlabeled_data=unlabeled_data, 434 | eval_data=eval_data, do_train=args.do_train, do_eval=args.do_eval, seed=args.seed, sampler_seeds=args.sampler_seeds, 435 | fixla_ratio=args.fixla_ratio) 436 | 437 | mean_result={} 438 | for result in results: 439 | for metric,res in result.items(): 440 | if metric not in mean_result: mean_result[metric]=[] 441 | mean_result[metric].append(res) 442 | for (x,y) in mean_result.items(): 443 | mean_result[x]=np.mean(y) 444 | 445 | template_name=['task_name','search_type','augmented_file','drop_prob','fixla_ratio','filtered_num','result','mean_result'] 446 | template_values=[args.task_name,args.search_type,args.t5_augment_file_path,args.drop_prob,args.fixla_ratio,filtered_num] 447 | writer=open(os.path.join('results/','pet_total_noisy_{}.csv'.format(args.task_name)),'a+') 448 | writer.write((': {}, '.join(template_name)+': {}\n').format(*template_values+[results]+[mean_result])) 449 | writer.close() 450 | elif args.method == 'ipet': 451 | pet.train_ipet(pet_model_cfg, pet_train_cfg, pet_eval_cfg, ipet_cfg, sc_model_cfg, sc_train_cfg, sc_eval_cfg, 452 | pattern_ids=args.pattern_ids, output_dir=args.output_dir, 453 | ensemble_repetitions=args.pet_repetitions, final_repetitions=args.sc_repetitions, 454 | reduction=args.reduction, train_data=train_data, unlabeled_data=unlabeled_data, 455 | eval_data=eval_data, do_train=args.do_train, do_eval=args.do_eval, seed=args.seed, sampler_seeds=args.sampler_seeds) 456 | 457 | elif args.method == 'sequence_classifier': 458 | pet.train_classifier(sc_model_cfg, sc_train_cfg, sc_eval_cfg, output_dir=args.output_dir, 459 | repetitions=args.sc_repetitions, train_data=train_data, unlabeled_data=unlabeled_data, 460 | eval_data=eval_data, do_train=args.do_train, do_eval=args.do_eval, seed=args.seed, sampler_seeds=args.sampler_seeds) 461 | 462 | else: 463 | raise ValueError(f"Training method '{args.method}' not implemented") 464 | 465 | 466 | if __name__ == "__main__": 467 | main() 468 | -------------------------------------------------------------------------------- /data/FewGLUE_dev32/readme.txt: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /genaug/confidence_filter.py: -------------------------------------------------------------------------------- 1 | from tqdm import trange,tqdm 2 | import torch 3 | import transformers 4 | 5 | from transformers import AlbertForMaskedLM 6 | from pet.wrapper import TransformerModelWrapper, SEQUENCE_CLASSIFIER_WRAPPER, WrapperConfig 7 | import itertools 8 | import numpy as np 9 | import copy 10 | import string 11 | import re 12 | 13 | from scipy.special import softmax 14 | from collections import defaultdict 15 | 16 | class Normal_Filter(object): 17 | ''' 18 | from genaug import confidence_filter 19 | normalfilter=confidence_filter.Normal_Filter() 20 | filter_funcs=normalfilter.set_sequential_funcs(remove_final_punc=True,remove_question=True,keep_first_sentence=True) 21 | normalfilter.apply_filter('I am a 
student. and you?') 22 | ''' 23 | def __init__(self): 24 | pass 25 | 26 | def remove_final_punc(self,text): 27 | if text is None: return text 28 | return text.rstrip(string.punctuation) 29 | 30 | def remove_question(self,text): 31 | if text is None: return text 32 | new_text=text.lower() 33 | if new_text.startswith('did') or new_text.startswith('are') or new_text.startswith('is'): 34 | return None 35 | else: 36 | return text 37 | 38 | def keep_first_sentence(self,text): 39 | if text is None: return text 40 | return text[:len(re.split('\?|\.',text)[0])+1] 41 | 42 | def set_sequential_funcs(self,remove_final_punc=False,remove_question=False,keep_first_sentence=False): 43 | funcs=[] 44 | if remove_final_punc==True: 45 | funcs.append(self.remove_final_punc) 46 | if remove_question==True: 47 | funcs.append(self.remove_question) 48 | if keep_first_sentence==True: 49 | funcs.append(self.keep_first_sentence) 50 | self.funcs=funcs 51 | return funcs 52 | 53 | def apply_filter(self,text,funcs=None): 54 | if funcs is None: 55 | funcs=self.funcs 56 | for func in funcs: 57 | text=func(text) 58 | return text 59 | 60 | class Confidence_Filter(object): 61 | def __init__(self,pattern_iter_output_dir=None,wrapper=None): 62 | assert pattern_iter_output_dir is None or wrapper is None 63 | self.wrappers=None 64 | self.wrapper=None 65 | if pattern_iter_output_dir is not None: 66 | self.wrapper=TransformerModelWrapper.from_pretrained(pattern_iter_output_dir) 67 | if wrapper is not None: 68 | self.wrapper=wrapper 69 | 70 | def reload_wrapper(self,wrapper=None,pattern_iter_output_dir=None): 71 | if wrapper is not None: 72 | self.wrapper=wrapper 73 | else: 74 | if isinstance(pattern_iter_output_dir,list): 75 | self.wrappers=[] 76 | for path in pattern_iter_output_dir: 77 | self.wrappers.append(TransformerModelWrapper.from_pretrained(path)) 78 | else: 79 | self.wrapper=TransformerModelWrapper.from_pretrained(pattern_iter_output_dir) 80 | 81 | def validate(self,wrapper,eval_data,eval_config): 82 | if isinstance(wrapper,list): 83 | def merge_outputs(outputs,output): 84 | if outputs is None: 85 | outputs=output 86 | else: 87 | outputs['logits']=outputs['logits']+output['logits'] 88 | return outputs 89 | total_data=list(itertools.chain.from_iterable(eval_data)) 90 | outputs=None 91 | for wrp in wrapper: 92 | wrp.model.to(eval_config.device) 93 | output=wrp.eval(total_data,eval_config.device,per_gpu_eval_batch_size=eval_config.per_gpu_eval_batch_size, 94 | n_gpu=eval_config.n_gpu,decoding_strategy=eval_config.decoding_strategy,priming=eval_config.priming) 95 | outputs=merge_outputs(outputs,output) 96 | wrp.model.to('cpu') 97 | torch.cuda.empty_cache() 98 | outputs['logits']=outputs['logits']/len(wrapper) 99 | return outputs 100 | else: 101 | wrapper.model.to(eval_config.device) 102 | if isinstance(eval_data[0],list): 103 | total_data=list(itertools.chain.from_iterable(eval_data)) 104 | else: 105 | total_data=eval_data 106 | output=wrapper.eval(total_data,eval_config.device,per_gpu_eval_batch_size=eval_config.per_gpu_eval_batch_size, 107 | n_gpu=eval_config.n_gpu,decoding_strategy=eval_config.decoding_strategy,priming=eval_config.priming) 108 | torch.cuda.empty_cache() 109 | return output 110 | 111 | def rearrange_examples(self,examples): 112 | guids=[] 113 | for e in examples: 114 | if e.guid not in guids: guids.append(e.guid) 115 | guid_map={y:x for (x,y) in enumerate(guids)} 116 | new_examples=[[] for _ in range(len(guids))] 117 | for e in examples: 118 | new_examples[guid_map[e.guid]].append(e) 119 | return 
guids,new_examples


    def recover_labels(self,wrapper,eval_data,eval_config,recover_type='max_prevla',fixla_num=[[14,18],[14,18]],fixla_ratio=[[0.9,0.9],[0.9,0.9]],rmdup_num=1):
        # eval_data: [all_aug_examples]
        # recover_type:
        #     'max_prevla': for each example, choose the most likely one whose label is preserved
        #     'max_eachla': for each example, choose the most likely one for each label if possible
        #     'max_otherla': for each example, choose the most likely one whose label is flipped
        #     'global_topk': choose examples that are among the top-k most confident
        #     'global_topp': choose examples whose confidence > topp
        example_num=len(eval_data)
        label_map=wrapper.preprocessor.label_map
        inverse_label_map={x:y for (y,x) in label_map.items()}
        label_num=len(label_map)
        return_examples=[];filtered_num=dict()
        guids,rearranged_examples=self.rearrange_examples(eval_data)
        if recover_type==('max_prevla'):
            for aug_examples in rearranged_examples:
                examples=[e for e in aug_examples if e.label==e.orig_label]
                if len(examples)==0: continue
                orig_la=label_map[examples[0].orig_label]
                la=orig_la
                logits=self.validate(wrapper,examples,eval_config)['logits']
                logits=softmax(logits/10,axis=1)
                # max_idx=np.argmax(logits[:,orig_la])
                max_idx=-1
                for (idx,logit) in enumerate(logits):
                    if np.argmax(logit)==la and (max_idx==-1 or logit[la]>logits[max_idx,la]):
                        max_idx=idx
                if max_idx!=-1:
                    return_examples.append(examples[max_idx])
                    label_trans='{} -> {}'.format(examples[max_idx].orig_label,examples[max_idx].label)
                    filtered_num.setdefault(label_trans,0)
                    filtered_num[label_trans]+=1
        elif recover_type==('max_prevla_comb'):
            for aug_examples in rearranged_examples:
                examples=aug_examples
                if len(examples)==0: continue
                orig_la=label_map[examples[0].orig_label]
                la=orig_la
                logits=self.validate(wrapper,examples,eval_config)['logits']
                logits=softmax(logits/10,axis=1)
                # max_idx=np.argmax(logits[:,orig_la])
                max_idx=-1
                for (idx,logit) in enumerate(logits):
                    if np.argmax(logit)==la and (max_idx==-1 or logit[la]>logits[max_idx,la]):
                        max_idx=idx
                if max_idx!=-1:
                    new_example=copy.deepcopy(examples[max_idx])
                    new_example.label=inverse_label_map[la]
                    return_examples.append(new_example)
                    label_trans='{} -> {}'.format(examples[max_idx].orig_label,examples[max_idx].label)
                    filtered_num.setdefault(label_trans,0)
                    filtered_num[label_trans]+=1
        elif recover_type==('max_otherla'):
            for aug_examples in rearranged_examples:
                orig_la=label_map[aug_examples[0].orig_label]
                for la in range(label_num):
                    if la==orig_la: continue
                    examples=[e for e in aug_examples if label_map[e.label]==la]
                    if len(examples)==0: continue
                    logits=self.validate(wrapper,examples,eval_config)['logits']
                    logits=softmax(logits/10,axis=1)
                    max_idx=-1
                    for (idx,logit) in enumerate(logits):
                        if np.argmax(logit)==la and (max_idx==-1 or logit[la]>logits[max_idx,la]):
                            max_idx=idx
                    if max_idx!=-1:
                        return_examples.append(examples[max_idx])
                        label_trans='{} -> {}'.format(examples[0].orig_label,inverse_label_map[la])
                        filtered_num.setdefault(label_trans,0)
                        filtered_num[label_trans]+=1
        elif recover_type==('max_otherla_comb'):
            for aug_examples in rearranged_examples:
                
orig_la=label_map[aug_examples[0].orig_label] 195 | examples=aug_examples 196 | if len(examples)==0: continue 197 | logits=self.validate(wrapper,examples,eval_config)['logits'] 198 | logits=softmax(logits/10,axis=1) 199 | for la in range(label_num): 200 | if la==orig_la: continue 201 | max_idx=-1 202 | for (idx,logit) in enumerate(logits): 203 | if np.argmax(logit)==la and (max_idx==-1 or logit[la]>logits[max_idx,la]): 204 | max_idx=idx 205 | if max_idx!=-1: 206 | new_example=copy.deepcopy(examples[max_idx]) 207 | new_example.label=inverse_label_map[la] 208 | return_examples.append(new_example) 209 | label_trans='{} -> {}'.format(examples[0].orig_label,inverse_label_map[la]) 210 | filtered_num.setdefault(label_trans,0) 211 | filtered_num[label_trans]+=1 212 | elif recover_type==('max_eachla'): # We may flip the label according to the filter 213 | for examples in rearranged_examples: 214 | # import pdb 215 | # pdb.set_trace() 216 | logits=self.validate(wrapper,examples,eval_config)['logits'] 217 | logits=softmax(logits/10,axis=1) 218 | for la in range(label_num): 219 | if (wrapper.config.task_name=='record' or wrapper.config.task_name=='wsc') and la==0: continue 220 | max_idx=-1 221 | for (idx,logit) in enumerate(logits): 222 | if np.argmax(logit)==la and (max_idx==-1 or logit[la]>logits[max_idx,la]): 223 | max_idx=idx 224 | if max_idx!=-1: 225 | new_example=copy.deepcopy(examples[max_idx]) 226 | new_example.label=inverse_label_map[la] 227 | return_examples.append(new_example) 228 | label_trans='{} -> {}'.format(examples[0].orig_label,inverse_label_map[la]) 229 | filtered_num.setdefault(label_trans,0) 230 | filtered_num[label_trans]+=1 231 | elif recover_type==('max_eachla_sep'): 232 | for aug_examples in rearranged_examples: 233 | for la in range(label_num): 234 | if (wrapper.config.task_name=='record' or wrapper.config.task_name=='wsc') and la==0: continue 235 | examples=[e for e in aug_examples if label_map[e.label]==la] 236 | if len(examples)==0: continue 237 | logits=self.validate(wrapper,examples,eval_config)['logits'] 238 | logits=softmax(logits/10,axis=1) 239 | max_idx=-1 240 | for (idx,logit) in enumerate(logits): 241 | if np.argmax(logit)==la and (max_idx==-1 or logit[la]>logits[max_idx,la]): 242 | max_idx=idx 243 | if max_idx!=-1: 244 | return_examples.append(examples[max_idx]) 245 | label_trans='{} -> {}'.format(examples[0].orig_label,inverse_label_map[la]) 246 | filtered_num.setdefault(label_trans,0) 247 | filtered_num[label_trans]+=1 248 | elif recover_type.startswith('global_topk'): 249 | for orig_la in range(label_num): 250 | if 'sep' not in recover_type: 251 | examples=[e for e in eval_data if (label_map[e.orig_label]==orig_la)] 252 | if len(examples)==0: continue 253 | logits=self.validate(wrapper,examples,eval_config)['logits'] 254 | logits=softmax(logits/10,axis=1) 255 | for new_la in range(label_num): 256 | record_guids=defaultdict(int) 257 | if 'sep' in recover_type: 258 | examples=[e for e in eval_data if (label_map[e.orig_label]==orig_la and label_map[e.label]==new_la)] 259 | if len(examples)==0: continue 260 | logits=self.validate(wrapper,examples,eval_config)['logits'] 261 | logits=softmax(logits/10,axis=1) 262 | aug_num=fixla_num[orig_la][new_la] 263 | sortedindexs=np.argsort(logits[:,new_la])[::-1] 264 | for k in range(aug_num): 265 | if 'rmdup' in recover_type and record_guids[examples[sortedindexs[k]].guid]>=rmdup_num: 266 | continue 267 | new_example=copy.deepcopy(examples[sortedindexs[k]]) 268 | new_example.label=inverse_label_map[new_la] 269 | 
return_examples.append(new_example) 270 | label_trans='{} -> {}'.format(inverse_label_map[orig_la],inverse_label_map[new_la]) 271 | filtered_num.setdefault(label_trans,0) 272 | filtered_num[label_trans]+=1 273 | record_guids[new_example.guid]+=1 274 | elif recover_type.startswith('global_topp'): 275 | for orig_la in range(label_num): 276 | if 'sep' not in recover_type: 277 | examples=[e for e in eval_data if (label_map[e.orig_label]==orig_la)] 278 | if len(examples)==0: continue 279 | logits=self.validate(wrapper,examples,eval_config)['logits'] 280 | logits=softmax(logits,axis=1) 281 | for new_la in range(label_num): 282 | record_guids=defaultdict(int) 283 | if 'sep' in recover_type: 284 | examples=[e for e in eval_data if (label_map[e.orig_label]==orig_la and label_map[e.label]==new_la)] 285 | if len(examples)==0: continue 286 | logits=self.validate(wrapper,examples,eval_config)['logits'] 287 | logits=softmax(logits,axis=1) 288 | for (e,logit) in zip(examples,logits): 289 | if 'rmdup' in recover_type and record_guids[e.guid]>=rmdup_num: 290 | continue 291 | if logit[new_la]>=fixla_ratio[orig_la][new_la]: 292 | new_example=copy.deepcopy(e) 293 | new_example.label=inverse_label_map[new_la] 294 | return_examples.append(new_example) 295 | # return_examples.append(e) 296 | label_trans='{} -> {}'.format(inverse_label_map[orig_la],inverse_label_map[new_la]) 297 | filtered_num.setdefault(label_trans,0) 298 | filtered_num[label_trans]+=1 299 | record_guids[e.guid]+=1 300 | elif recover_type==('believe_cls'): 301 | logits=self.validate(wrapper,eval_data,eval_config)['logits'] 302 | for (e,logit) in zip(eval_data,logits): 303 | orig_la=label_map[e.orig_label] 304 | new_la=np.argmax(logit) 305 | new_example=copy.deepcopy(e) 306 | new_example.label=inverse_label_map[new_la] 307 | return_examples.append(new_example) 308 | # return_examples.append(e) 309 | label_trans='{} -> {}'.format(inverse_label_map[orig_la],inverse_label_map[new_la]) 310 | filtered_num.setdefault(label_trans,0) 311 | filtered_num[label_trans]+=1 312 | elif recover_type.startswith('deterministic_topk'): 313 | for orig_la in range(label_num): 314 | if 'sep' not in recover_type: 315 | examples=[e for e in eval_data if (label_map[e.orig_label]==orig_la)] 316 | if len(examples)==0: continue 317 | logits=self.validate(wrapper,examples,eval_config)['logits'] 318 | logits=softmax(logits/10,axis=1) 319 | for new_la in range(label_num): 320 | if 'sep' in recover_type: 321 | examples=[e for e in eval_data if (label_map[e.orig_label]==orig_la and label_map[e.label]==new_la)] 322 | if len(examples)==0: continue 323 | logits=self.validate(wrapper,examples,eval_config)['logits'] 324 | logits=softmax(logits/10,axis=1) 325 | aug_num=fixla_num[orig_la][new_la] 326 | # prepare sorted grouped list 327 | guids=[] 328 | for e in examples: 329 | if e.guid not in guids: guids.append(e.guid) 330 | guid_map={y:x for (x,y) in enumerate(guids)} 331 | new_examples=[[] for _ in range(len(guids))] 332 | for (e,score) in zip(examples,logits[:,new_la]): 333 | new_examples[guid_map[e.guid]].append((e,score)) 334 | for i in range(len(new_examples)): 335 | new_examples[i]=sorted(new_examples[i],key=lambda x:x[1])[::-1] 336 | # prepare sorted ungrouped list 337 | sorted_ungrouped_examples=[] 338 | for j in range(len(new_examples[0])): 339 | tmp_examples=[] 340 | for i in range(len(new_examples)): 341 | tmp_examples.append(new_examples[i][j]) 342 | tmp_examples=sorted(tmp_examples,key=lambda x:x[1])[::-1] 343 | sorted_ungrouped_examples+=tmp_examples 344 | for 
(e,score) in sorted_ungrouped_examples[:aug_num]:
                        new_example=copy.deepcopy(e)
                        new_example.label=inverse_label_map[new_la]
                        return_examples.append(new_example)
                        # return_examples.append(e)
                        label_trans='{} -> {}'.format(inverse_label_map[orig_la],inverse_label_map[new_la])
                        filtered_num.setdefault(label_trans,0)
                        filtered_num[label_trans]+=1
        return return_examples,filtered_num

    def del_finetuned_model(self):
        if self.wrappers is not None:
            for i in range(len(self.wrappers)):
                self.wrappers[i].model.cpu()
                self.wrappers[i].model=None
                self.wrappers[i]=None
            torch.cuda.empty_cache()
        else:
            self.wrapper.model.cpu()
            self.wrapper.model = None
            self.wrapper = None
            torch.cuda.empty_cache()


--------------------------------------------------------------------------------
/genaug/gen_aug_T5.py:
--------------------------------------------------------------------------------
import torch
import numpy as np
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union
from transformers import T5ForConditionalGeneration, T5Tokenizer

from transformers.generation_logits_process import (
    EncoderNoRepeatNGramLogitsProcessor,
    ForcedBOSTokenLogitsProcessor,
    ForcedEOSTokenLogitsProcessor,
    HammingDiversityLogitsProcessor,
    InfNanRemoveLogitsProcessor,
    LogitsProcessor,
    LogitsProcessorList,
    MinLengthLogitsProcessor,
    NoBadWordsLogitsProcessor,
    NoRepeatNGramLogitsProcessor,
    PrefixConstrainedLogitsProcessor,
    RepetitionPenaltyLogitsProcessor,
    # TemperatureLogitsWarper,
    # TopKLogitsWarper,
    # TopPLogitsWarper,
)

from transformers.generation_stopping_criteria import (
    MaxLengthCriteria,
    MaxTimeCriteria,
    StoppingCriteriaList,
    StoppingCriteria,
    # validate_stopping_criteria,
)

T5_start_mask_token=32000
T5_end_mask_token=32099

class ForcedNoEOSTokenLogitsProcessor(LogitsProcessor):
    def __init__(self,encoder_input_ids: torch.LongTensor, eos_token_id: int): # (batch_size,...)
        self.target_blank_num=((encoder_input_ids>=T5_start_mask_token)&(encoder_input_ids<=T5_end_mask_token)).sum(dim=1)
        self.starts_with_extraid=((encoder_input_ids[:,0]>=T5_start_mask_token)&(encoder_input_ids[:,0]<=T5_end_mask_token)).int()
        self.batch_size = encoder_input_ids.shape[0]
        self.eos_token_id=eos_token_id
        self.pad_token_id=0

    def __call__(self,input_ids: torch.LongTensor, scores: torch.FloatTensor): #(batch_size*num_beams,...)
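        # scores has shape (batch_size*num_beams, vocab_size): every encoder input in the
        # batch expands to num_beams decoder hypotheses, so hypo_idx//num_beams below maps a
        # hypothesis back to its batch entry. already_blank_num counts how many sentinel
        # (<extra_id_*>) tokens a hypothesis has produced so far; it is compared against
        # target_blank_num, the number of blanks the corresponding encoder input asks to fill.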
        num_hypos=scores.shape[0]
        num_beams=num_hypos//self.batch_size
        if input_ids.shape[1]<=1: return scores
        already_blank_num=(((input_ids>=T5_start_mask_token)&(input_ids<=T5_end_mask_token))).sum(dim=1)
        generated_extraid_first=((input_ids[:,1]>=T5_start_mask_token)&(input_ids[:,1]<=T5_end_mask_token)).int()
        for hypo_idx in range(num_hypos):
            batch_idx=hypo_idx//num_beams
            beam_idx=hypo_idx%num_beams
            if already_blank_num[hypo_idx]-generated_extraid_first[hypo_idx]+1<self.target_blank_num[batch_idx]:
                # some blanks are still missing: forbid emitting another sentinel right
                # after a sentinel (each blank must get at least one content token), and
                # never allow eos/pad before all target blanks have been produced
                if (input_ids[hypo_idx,-1]>=T5_start_mask_token and input_ids[hypo_idx,-1]<=T5_end_mask_token):
                    scores[hypo_idx,T5_start_mask_token:T5_end_mask_token+1]=-float("inf")
                scores[hypo_idx,self.eos_token_id]=-float("inf")
                scores[hypo_idx,self.pad_token_id]=-float("inf")
        return scores

class ForcedNoExtraTokenLogitsProcessor(LogitsProcessor):
    def __init__(self,encoder_input_ids: torch.LongTensor, eos_token_id: int):
        self.target_blank_num=((encoder_input_ids>=T5_start_mask_token)&(encoder_input_ids<=T5_end_mask_token)).sum(dim=1)
        self.starts_with_extraid=((encoder_input_ids[:,0]>=T5_start_mask_token)&(encoder_input_ids[:,0]<=T5_end_mask_token)).int()
        self.batch_size = encoder_input_ids.shape[0]
        self.eos_token_id=eos_token_id

    def __call__(self,input_ids: torch.LongTensor, scores: torch.FloatTensor):
        num_hypos=scores.shape[0]
        num_beams=num_hypos//self.batch_size
        if input_ids.shape[1]<=1: return scores
        already_blank_num=(((input_ids>=T5_start_mask_token)&(input_ids<=T5_end_mask_token))).sum(dim=1)
        generated_extraid_first=((input_ids[:,1]>=T5_start_mask_token)&(input_ids[:,1]<=T5_end_mask_token)).int()
        for hypo_idx in range(num_hypos):
            batch_idx=hypo_idx//num_beams
            beam_idx=hypo_idx%num_beams
            if already_blank_num[hypo_idx]-generated_extraid_first[hypo_idx]+1>=self.target_blank_num[batch_idx]:
                scores[hypo_idx,self.eos_token_id]=max(scores[hypo_idx,self.eos_token_id],scores[hypo_idx,T5_start_mask_token:T5_end_mask_token+1].max())
                scores[hypo_idx,T5_start_mask_token:T5_end_mask_token+1]=-float("inf")
        return scores

class ForcedStartTokenLogitsProcessor(LogitsProcessor):
    def __init__(self,encoder_input_ids: torch.LongTensor, eos_token_id: int):
        self.starts_with_extraid=((encoder_input_ids[:,0]>=T5_start_mask_token)&(encoder_input_ids[:,0]<=T5_end_mask_token))
        self.batch_size = encoder_input_ids.shape[0]
        self.eos_token_id=eos_token_id

    def __call__(self,input_ids: torch.LongTensor, scores: torch.FloatTensor):
        if input_ids.shape[1]!=1: return scores
        num_hypos=scores.shape[0]
        num_beams=num_hypos//self.batch_size
        num_tokens = scores.shape[1]
        for hypo_idx in range(num_hypos):
            batch_idx=hypo_idx//num_beams
            beam_idx=hypo_idx%num_beams
            # if True:
            if self.starts_with_extraid[batch_idx]==False:
                scores[hypo_idx,:]=-float("inf")
                scores[hypo_idx,T5_start_mask_token:T5_end_mask_token]=0
            else:
                scores[hypo_idx,T5_start_mask_token:T5_end_mask_token+1]=-float("inf")
        return scores

class EosCriteria(StoppingCriteria):
    def __init__(self,):
        self.eos_token_id=1

    def __call__(self,input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        return (input_ids==self.eos_token_id).sum()==len(input_ids)

class T5_Blank(T5ForConditionalGeneration):
    def _get_stopping_criteria(
        self,
        max_length: Optional[int],
        max_time: 
Optional[float], 132 | ) -> StoppingCriteriaList: 133 | stopping_criteria = StoppingCriteriaList() 134 | if max_length is not None: 135 | stopping_criteria.append(MaxLengthCriteria(max_length=max_length)) 136 | if max_time is not None: 137 | stopping_criteria.append(MaxTimeCriteria(max_time=max_time)) 138 | stopping_criteria.append(EosCriteria()) 139 | return stopping_criteria 140 | 141 | def _get_logits_processor( 142 | self, 143 | repetition_penalty: float, 144 | no_repeat_ngram_size: int, 145 | encoder_no_repeat_ngram_size: int, 146 | encoder_input_ids: torch.LongTensor, 147 | bad_words_ids: List[List[int]], 148 | min_length: int, 149 | max_length: int, 150 | eos_token_id: int, 151 | forced_bos_token_id: int, 152 | forced_eos_token_id: int, 153 | prefix_allowed_tokens_fn: Callable[[int, torch.Tensor], List[int]], 154 | num_beams: int, 155 | num_beam_groups: int, 156 | diversity_penalty: float, 157 | remove_invalid_values: bool, 158 | ) -> LogitsProcessorList: 159 | """ 160 | This class returns a :obj:`~transformers.LogitsProcessorList` list object that contains all relevant 161 | :obj:`~transformers.LogitsProcessor` instances used to modify the scores of the language model head. 162 | """ 163 | 164 | # init warp parameters 165 | repetition_penalty = repetition_penalty if repetition_penalty is not None else self.config.repetition_penalty 166 | no_repeat_ngram_size = ( 167 | no_repeat_ngram_size if no_repeat_ngram_size is not None else self.config.no_repeat_ngram_size 168 | ) 169 | encoder_no_repeat_ngram_size = ( 170 | encoder_no_repeat_ngram_size 171 | if encoder_no_repeat_ngram_size is not None 172 | else self.config.encoder_no_repeat_ngram_size 173 | ) 174 | bad_words_ids = bad_words_ids if bad_words_ids is not None else self.config.bad_words_ids 175 | min_length = min_length if min_length is not None else self.config.min_length 176 | eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id 177 | diversity_penalty = diversity_penalty if diversity_penalty is not None else self.config.diversity_penalty 178 | forced_bos_token_id = ( 179 | forced_bos_token_id if forced_bos_token_id is not None else self.config.forced_bos_token_id 180 | ) 181 | forced_eos_token_id = ( 182 | forced_eos_token_id if forced_eos_token_id is not None else self.config.forced_eos_token_id 183 | ) 184 | remove_invalid_values = ( 185 | remove_invalid_values if remove_invalid_values is not None else self.config.remove_invalid_values 186 | ) 187 | # instantiate processors list 188 | processors = LogitsProcessorList() 189 | 190 | # the following idea is largely copied from this PR: https://github.com/huggingface/transformers/pull/5420/files 191 | # all samplers can be found in `generation_utils_samplers.py` 192 | if diversity_penalty is not None and diversity_penalty > 0.0: 193 | processors.append( 194 | HammingDiversityLogitsProcessor( 195 | diversity_penalty=diversity_penalty, num_beams=num_beams, num_beam_groups=num_beam_groups 196 | ) 197 | ) 198 | if repetition_penalty is not None and repetition_penalty != 1.0: 199 | processors.append(RepetitionPenaltyLogitsProcessor(penalty=repetition_penalty)) 200 | if no_repeat_ngram_size is not None and no_repeat_ngram_size > 0: 201 | processors.append(NoRepeatNGramLogitsProcessor(no_repeat_ngram_size)) 202 | if encoder_no_repeat_ngram_size is not None and encoder_no_repeat_ngram_size > 0: 203 | if self.config.is_encoder_decoder: 204 | processors.append(EncoderNoRepeatNGramLogitsProcessor(encoder_no_repeat_ngram_size, encoder_input_ids)) 205 | 
else: 206 | raise ValueError( 207 | "It's impossible to use `encoder_no_repeat_ngram_size` with decoder-only architecture" 208 | ) 209 | if bad_words_ids is not None: 210 | processors.append(NoBadWordsLogitsProcessor(bad_words_ids, eos_token_id)) 211 | processors.append(ForcedNoEOSTokenLogitsProcessor(encoder_input_ids,eos_token_id)) 212 | processors.append(ForcedNoExtraTokenLogitsProcessor(encoder_input_ids,eos_token_id)) 213 | if min_length is not None and eos_token_id is not None and min_length > -1: 214 | processors.append(MinLengthLogitsProcessor(min_length, eos_token_id)) 215 | if prefix_allowed_tokens_fn is not None: 216 | processors.append(PrefixConstrainedLogitsProcessor(prefix_allowed_tokens_fn, num_beams // num_beam_groups)) 217 | if forced_bos_token_id is not None: 218 | processors.append(ForcedBOSTokenLogitsProcessor(forced_bos_token_id)) 219 | if forced_eos_token_id is not None: 220 | processors.append(ForcedEOSTokenLogitsProcessor(max_length, forced_eos_token_id)) 221 | if remove_invalid_values is True: 222 | processors.append(InfNanRemoveLogitsProcessor()) 223 | return processors 224 | 225 | def init_model(model_name_or_path='t5-large'): 226 | tokenizer = T5Tokenizer.from_pretrained(model_name_or_path) 227 | tokenizer.sep_token = '' 228 | model = T5_Blank.from_pretrained(model_name_or_path) 229 | model=model.cuda() 230 | model.eval() 231 | return tokenizer,model 232 | 233 | class T5Aug(): 234 | def __init__(self,model_path='t5-large',tokenizer=None,model=None): 235 | if tokenizer is not None and model is not None: 236 | self.tokenizer=tokenizer;self.model=model 237 | elif model_path is not None: 238 | self.tokenizer,self.model=init_model(model_path) 239 | 240 | def generate_blanks(self,strings_to_be_generated, 241 | max_length: Optional[int] = 512, 242 | min_length: Optional[int] = None, 243 | do_sample: Optional[bool] = None, 244 | early_stopping: Optional[bool] = None, 245 | num_beams: Optional[int] = None, 246 | temperature: Optional[float] = 1.0, 247 | top_k: Optional[int] = 15, 248 | top_p: Optional[float] = 0.5, 249 | repetition_penalty: Optional[float] = 2.5, 250 | bad_words_ids: Optional[Iterable[int]] = [[3], [19794], [22354]], 251 | bos_token_id: Optional[int] = None, 252 | pad_token_id: Optional[int] = 0, 253 | eos_token_id: Optional[int] = 1, 254 | length_penalty: Optional[float] = 0.0, 255 | no_repeat_ngram_size: Optional[int] = 3, 256 | encoder_no_repeat_ngram_size: Optional[int] = None, 257 | num_return_sequences: Optional[int] = 1, 258 | max_time: Optional[float] = None, 259 | decoder_start_token_id: Optional[int] = None, 260 | use_cache: Optional[bool] = False, 261 | num_beam_groups: Optional[int] = None, 262 | diversity_penalty: Optional[float] = None, 263 | prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None, 264 | output_attentions: Optional[bool] = None, 265 | output_hidden_states: Optional[bool] = None, 266 | output_scores: Optional[bool] = None, 267 | return_dict_in_generate: Optional[bool] = None, 268 | forced_bos_token_id: Optional[int] = None, 269 | forced_eos_token_id: Optional[int] = None, 270 | remove_invalid_values: Optional[bool] = None, 271 | **model_kwargs, 272 | ): 273 | pred_blanks=[];pred_texts=[] 274 | tokenizer=self.tokenizer 275 | eos_token_id = tokenizer._convert_token_to_id('</s>') 276 | pad_token_id = tokenizer._convert_token_to_id('<pad>') 277 | start_mask_token=tokenizer._convert_token_to_id('<extra_id_99>') 278 | end_mask_token=tokenizer._convert_token_to_id('<extra_id_0>') 279 | batch_size=10 280 | for batch_idx in range(int(np.ceil(len(strings_to_be_generated)/batch_size))):
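# Prompts are fed through generate() in fixed chunks of batch_size (10 sentences at a time) so that a long list of augmentation inputs does not exhaust GPU memory in a single call.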
281 | sentences=strings_to_be_generated[batch_idx*batch_size:(batch_idx+1)*batch_size] 282 | input_ids=tokenizer(sentences,return_tensors='pt',padding=True).input_ids.cuda() 283 | outputs=self.model.generate(input_ids,max_length=max_length,min_length=min_length,do_sample=do_sample,early_stopping=early_stopping,\ 284 | num_beams=num_beams,temperature=temperature,top_k=top_k,top_p=top_p,repetition_penalty=repetition_penalty,bad_words_ids=bad_words_ids,\ 285 | bos_token_id=bos_token_id,pad_token_id=pad_token_id,eos_token_id=eos_token_id,length_penalty=length_penalty,no_repeat_ngram_size=no_repeat_ngram_size, 286 | encoder_no_repeat_ngram_size=encoder_no_repeat_ngram_size,num_return_sequences=num_return_sequences,decoder_start_token_id=decoder_start_token_id, 287 | use_cache=use_cache,num_beam_groups=num_beam_groups,diversity_penalty=diversity_penalty) 288 | for (b_id,input_id) in enumerate(input_ids): 289 | pred_text=[];result = [] 290 | for item in outputs[b_id*num_return_sequences:(b_id+1)*num_return_sequences]: 291 | result.append([]);blanks=[] 292 | for token_id in item[1:]: 293 | token_id=token_id.item() 294 | if (token_id>=start_mask_token and token_id<=end_mask_token) or token_id==eos_token_id or token_id==pad_token_id: 295 | blanks.append([]) 296 | else: 297 | if len(blanks)==0: blanks.append([]) 298 | blanks[-1].append(token_id) 299 | for blank in blanks: 300 | result[-1].append(tokenizer.decode(blank)) 301 | 302 | current_blank=0;output_tokens=[] 303 | for token in input_id: 304 | token=token.item() 305 | if token>=start_mask_token and token<=end_mask_token: 306 | if current_blank< -------------------------------------------------------------------------------- /genaug/utils.py: -------------------------------------------------------------------------------- 1 | import re 11 | punc = '~·!@#¥%……&*()——+-=“:’;、。,?》《{}' 12 | reg = "[^0-9A-Za-z\u4e00-\u9fa5]" 13 | 14 | def removePunctuation(text): 15 | # text = re.sub(r'[{}]+'.format(punctuation),'',text) 16 | # text = (re.sub(punc, "",text)).replace('[','').replace(']','').replace(' ','') 17 | text=re.sub(reg, '', text) 18 | return text.strip() 19 | 20 | import string 21 | stop_words = ['i', 'me', 'my', 'myself', 'we', 'our', 22 | 'ours', 'ourselves', 'you', 'your', 'yours', 23 | 'yourself', 'yourselves', 'he', 'him', 'his', 24 | 'himself', 'she', 'her', 'hers', 'herself', 25 | 'it', 'its', 'itself', 'they', 'them', 'their', 26 | 'theirs', 'themselves', 'what', 'which', 'who', 27 | 'whom', 'this', 'that', 'these', 'those', 'am', 28 | 'is', 'are', 'was', 'were', 'be', 'been', 'being', 29 | 'have', 'has', 'had', 'having', 'do', 'does', 'did', 30 | 'doing', 'a', 'an', 'the', 'and', 'but', 'if', 'or', 31 | 'because', 'as', 'until', 'while', 'of', 'at', 32 | 'by', 'for', 'with', 'about', 'against', 'between', 33 | 'into', 'through', 'during', 'before', 'after', 34 | 'above', 'below', 'to', 'from', 'up', 'down', 'in', 35 | 'out', 'on', 'off', 'over', 'under', 'again', 36 | 'further', 'then', 'once', 'here', 'there', 'when', 37 | 'where', 'why', 'how', 'all', 'any', 'both', 'each', 38 | 'few', 'more', 'most', 'other', 'some', 'such', 'no', 39 | 'nor', 'not', 'only', 'own', 'same', 'so', 'than', 'too', 40 | 'very', 's', 't', 'can', 'will', 'just', 'don', 41 | 'should', 'now']+[x for x in string.punctuation] 42 | 43 | def find_all_nouns(lines,parallel=40): 44 | import spacy 45 | nlp = spacy.load("en") 46 | if isinstance(lines,list): 47 | ans=[] 48 | for line in lines: 49 | doc = nlp(line) 50 | for np in doc.noun_chunks: 51 | if len(np.text.split())==1 and np.text in stop_words: 52 | continue 53 | ans.append(np.text) 54 | else: 55 | doc = nlp(lines) 56 | ans=[] 57 | for np in doc.noun_chunks:
58 | ans.append(np.text) 59 | return ans 60 | 61 | 62 | def get_pps(doc): 63 | # import spacy 64 | # nlp=spacy.load('en') 65 | # doc=nlp('A short man in blue jeans is working in the kitchen.') 66 | # print(get_pps(doc)) 67 | pps = [];pos=[] 68 | for (idx,token) in enumerate(doc): 69 | # Try this with other parts of speech for different subtrees. 70 | if token.pos_ == 'ADP': 71 | # import pdb 72 | # pdb.set_trace() 73 | pp = ' '.join([tok.orth_ for tok in token.subtree]) 74 | for (i,tok) in enumerate(token.subtree): 75 | if tok.orth_==str(token): start_pos=idx-i;break 76 | pps.append(pp) 77 | pos.append((start_pos,start_pos+len(pp.split()))) 78 | return pps,pos 79 | 80 | 81 | from sklearn.metrics.pairwise import cosine_similarity 82 | import numpy as np 83 | # class SBERT(): 84 | # def __init__(self,): 85 | # from sentence_transformers import SentenceTransformer 86 | # self.model = SentenceTransformer('paraphrase-distilroberta-base-v1') 87 | 88 | # def get_embeddings(self,sentences): 89 | # return self.model.encode(sentences) 90 | 91 | # def get_cos_similarity(self,sent_embs=[[0,1,2],[2,3,4]],tgt_emb=[5,6,7]): 92 | # scores=[] 93 | # for sent_emb in sent_embs: 94 | # scores.append(cosine_similarity(np.array([sent_emb,tgt_emb]))[0,1]) 95 | # return scores 96 | 97 | def softmax(x,temp=1): 98 | x=np.array(x) 99 | x_max=np.max(x) 100 | x=x-x_max 101 | x_exp=np.exp(x/temp) 102 | x_exp_sum=x_exp.sum() 103 | softmax=x_exp/x_exp_sum 104 | return softmax 105 | 106 | 107 | 108 | SPECIAL_TOKENS={ 109 | 'bert':{ 110 | 'unk_token': '[UNK]', 111 | 'sep_token': '[SEP]', 112 | 'pad_token': '[PAD]', 113 | 'cls_token': '[CLS]', 114 | 'mask_token': '[MASK]' 115 | }, 116 | 'albert':{ 117 | 'bos_token': '[CLS]', 118 | 'eos_token': '[SEP]', 119 | 'unk_token': '<unk>', 120 | 'sep_token': '[SEP]', 121 | 'pad_token': '<pad>', 122 | 'cls_token': '[CLS]', 123 | 'mask_token': '[MASK]' 124 | } 125 | } 126 | 127 | import json 128 | 129 | def load_dict(filename): 130 | '''load dict from json file''' 131 | with open(filename,"r") as f: 132 | dic = f.read() 133 | dic=eval(dic) 134 | return dic 135 | 136 | def save_dict(dict,filename): 137 | with open(filename,'w') as f: 138 | f.write(str(dict)) 139 | return 140 | 141 | ''' 142 | from baseline_aug import utils 143 | imp.reload(utils) 144 | file_path='tmp/candidate_nouns.jsonl' 145 | utils.save_dict(candidate_nouns,file_path) 146 | nouns=utils.load_dict(file_path) 147 | ''' 148 | -------------------------------------------------------------------------------- /img/model.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhouj8553/FlipDA/f220cef78cc8d79b6707128b7b81afa7c561f8a8/img/model.png -------------------------------------------------------------------------------- /img/readme.txt: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /log.py: -------------------------------------------------------------------------------- 1 | # Licensed under the Apache License, Version 2.0 (the "License"); 2 | # you may not use this file except in compliance with the License.
3 | # You may obtain a copy of the License at 4 | # 5 | # http://www.apache.org/licenses/LICENSE-2.0 6 | # 7 | # Unless required by applicable law or agreed to in writing, software 8 | # distributed under the License is distributed on an "AS IS" BASIS, 9 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 10 | # See the License for the specific language governing permissions and 11 | # limitations under the License. 12 | 13 | """ 14 | This file contains basic logging logic. 15 | """ 16 | import logging 17 | 18 | names = set() 19 | 20 | 21 | def __setup_custom_logger(name: str) -> logging.Logger: 22 | root_logger = logging.getLogger() 23 | root_logger.handlers.clear() 24 | 25 | formatter = logging.Formatter(fmt='%(asctime)s - %(levelname)s - %(module)s - %(message)s') 26 | 27 | names.add(name) 28 | 29 | handler = logging.StreamHandler() 30 | handler.setFormatter(formatter) 31 | 32 | logger = logging.getLogger(name) 33 | logger.setLevel(logging.INFO) 34 | logger.addHandler(handler) 35 | return logger 36 | 37 | 38 | def get_logger(name: str) -> logging.Logger: 39 | if name in names: 40 | return logging.getLogger(name) 41 | else: 42 | return __setup_custom_logger(name) 43 | -------------------------------------------------------------------------------- /modified_models/readme.txt: -------------------------------------------------------------------------------- 1 | Auxiliary files for DeBERTa. In order to run it on a V100, we freeze the bottom 1/3 of the parameters to save device memory. 2 | -------------------------------------------------------------------------------- /pet/__init__.py: -------------------------------------------------------------------------------- 1 | from pet.modeling import * 2 | -------------------------------------------------------------------------------- /pet/evaluate_record.py: -------------------------------------------------------------------------------- 1 | import collections 2 | import string 3 | import re 4 | from pet import tasks 5 | 6 | def normalize_answer(s): 7 | """Lower text and remove punctuation, articles and extra whitespace.""" 8 | def remove_articles(text): 9 | regex = re.compile(r'\b(a|an|the)\b', re.UNICODE) 10 | return re.sub(regex, ' ', text) 11 | def white_space_fix(text): 12 | return ' '.join(text.split()) 13 | def remove_punc(text): 14 | exclude = set(string.punctuation) 15 | return ''.join(ch for ch in text if ch not in exclude) 16 | def lower(text): 17 | return text.lower() 18 | return white_space_fix(remove_articles(remove_punc(lower(s)))) 19 | 20 | def get_tokens(s): 21 | if not s: return [] 22 | return normalize_answer(s).split() 23 | 24 | def compute_f1(a_gold, a_pred): 25 | gold_toks = get_tokens(a_gold) 26 | pred_toks = get_tokens(a_pred) 27 | common = collections.Counter(gold_toks) & collections.Counter(pred_toks) 28 | num_same = sum(common.values()) 29 | if len(gold_toks) == 0 or len(pred_toks) == 0: 30 | # If either is no-answer, then F1 is 1 if they agree, 0 otherwise 31 | return int(gold_toks == pred_toks) 32 | if num_same == 0: 33 | return 0 34 | precision = 1.0 * num_same / len(pred_toks) 35 | recall = 1.0 * num_same / len(gold_toks) 36 | f1 = (2 * precision * recall) / (precision + recall) 37 | return f1 38 | 39 | def read_file(filename): 40 | f= open(filename,"r") 41 | lines=f.readlines() 42 | ans=[] 43 | for line in lines: 44 | ans.append(eval(line)) 45 | f.close() 46 | return ans 47 | 48 | def get_max_ans(ans): 49 | aans={} 50 | for line in ans: 51 | max_val=-1000000 52 | text='' 53 | # print(line) 54 | for
(choice,v) in line['choices'].items(): 55 | if v>max_val: 56 | max_val=v 57 | text=choice 58 | aans[line['idx']]=text 59 | return aans 60 | 61 | def get_evaluate_examples(filename='../../../data/FewGLUE/ReCoRD'): 62 | myprocessor=tasks.RecordProcessor() 63 | eval_examples=myprocessor.get_dev_examples(filename) 64 | return eval_examples 65 | 66 | 67 | def cal_f1(ans,eval_examples): 68 | aans=get_max_ans(ans) 69 | f1s=[] 70 | for example in eval_examples: 71 | qid=example.meta['question_idx'] 72 | targets=example.meta['answers'] 73 | f_tmp=max([compute_f1(target,aans[qid]) for target in targets]) 74 | f1s.append(f_tmp) 75 | return sum(f1s)/len(f1s) 76 | 77 | def get_f1_from_file(target_filename,pred_filename): 78 | ans=read_file(pred_filename) 79 | eval_examples=get_evaluate_examples(target_filename) 80 | return cal_f1(ans,eval_examples) 81 | 82 | # if __name__ == '__main__': 83 | # f11=get_f1_from_file('../../../data/FewGLUE/ReCoRD','results/pet/record_32_model/p0-i0/predictions.jsonl') 84 | # print('f1',f11) 85 | # f12=get_f1_from_file('../../../data/FewGLUE/ReCoRD','results/pet/record_32_model/p0-i1/predictions.jsonl') 86 | # print('f1',f12) 87 | # f13=get_f1_from_file('../../../data/FewGLUE/ReCoRD','results/pet/record_32_model/p0-i2/predictions.jsonl') 88 | # print('f1',f13) 89 | # print((f11+f12+f13)/3) -------------------------------------------------------------------------------- /pet/preprocessor.py: -------------------------------------------------------------------------------- 1 | # Licensed under the Apache License, Version 2.0 (the "License"); 2 | # you may not use this file except in compliance with the License. 3 | # You may obtain a copy of the License at 4 | # 5 | # http://www.apache.org/licenses/LICENSE-2.0 6 | # 7 | # Unless required by applicable law or agreed to in writing, software 8 | # distributed under the License is distributed on an "AS IS" BASIS, 9 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 10 | # See the License for the specific language governing permissions and 11 | # limitations under the License. 12 | 13 | from abc import ABC, abstractmethod 14 | from typing import List 15 | 16 | import numpy as np 17 | 18 | from pet.utils import InputFeatures, InputExample, PLMInputFeatures 19 | from pet.pvp import PVP, PVPS 20 | 21 | 22 | class Preprocessor(ABC): 23 | """ 24 | A preprocessor that transforms an :class:`InputExample` into a :class:`InputFeatures` object so that it can be 25 | processed by the model being used. 26 | """ 27 | 28 | def __init__(self, wrapper, task_name, pattern_id: int = 0, verbalizer_file: str = None): 29 | """ 30 | Create a new preprocessor. 
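In other words, the constructor binds the task's PVP (looked up in PVPS by task name) and the wrapper's label list, so that subclasses only have to implement get_input_features.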
31 | 32 | :param wrapper: the wrapper for the language model to use 33 | :param task_name: the name of the task 34 | :param pattern_id: the id of the PVP to be used 35 | :param verbalizer_file: path to a file containing a verbalizer that overrides the default verbalizer 36 | """ 37 | self.wrapper = wrapper 38 | self.pvp = PVPS[task_name](self.wrapper, pattern_id, verbalizer_file) # type: PVP 39 | self.label_map = {label: i for i, label in enumerate(self.wrapper.config.label_list)} 40 | 41 | @abstractmethod 42 | def get_input_features(self, example: InputExample, labelled: bool, priming: bool = False, 43 | **kwargs) -> InputFeatures: 44 | """Convert the given example into a set of input features""" 45 | pass 46 | 47 | 48 | class MLMPreprocessor(Preprocessor): 49 | """Preprocessor for models pretrained using a masked language modeling objective (e.g., BERT).""" 50 | 51 | def get_input_features(self, example: InputExample, labelled: bool, priming: bool = False, 52 | **kwargs) -> InputFeatures: 53 | 54 | if priming: 55 | input_ids, token_type_ids = self.pvp.encode(example, priming=True) 56 | priming_data = example.meta['priming_data'] # type: List[InputExample] 57 | 58 | priming_input_ids = [] 59 | for priming_example in priming_data: 60 | pe_input_ids, _ = self.pvp.encode(priming_example, priming=True, labeled=True) 61 | priming_input_ids += pe_input_ids 62 | 63 | input_ids = priming_input_ids + input_ids 64 | token_type_ids = self.wrapper.tokenizer.create_token_type_ids_from_sequences(input_ids) 65 | input_ids = self.wrapper.tokenizer.build_inputs_with_special_tokens(input_ids) 66 | else: 67 | input_ids, token_type_ids = self.pvp.encode(example) 68 | 69 | attention_mask = [1] * len(input_ids) 70 | padding_length = self.wrapper.config.max_seq_length - len(input_ids) 71 | 72 | if padding_length < 0: 73 | raise ValueError(f"Maximum sequence length is too small, got {len(input_ids)} input ids") 74 | 75 | input_ids = input_ids + ([self.wrapper.tokenizer.pad_token_id] * padding_length) 76 | attention_mask = attention_mask + ([0] * padding_length) 77 | token_type_ids = token_type_ids + ([0] * padding_length) 78 | 79 | assert len(input_ids) == self.wrapper.config.max_seq_length 80 | assert len(attention_mask) == self.wrapper.config.max_seq_length 81 | assert len(token_type_ids) == self.wrapper.config.max_seq_length 82 | 83 | label = self.label_map[example.label] if example.label is not None else -100 84 | logits = example.logits if example.logits else [-1] 85 | 86 | if labelled: 87 | mlm_labels = self.pvp.get_mask_positions(input_ids) 88 | if self.wrapper.config.model_type == 'gpt2': 89 | # shift labels to the left by one 90 | mlm_labels.append(mlm_labels.pop(0)) 91 | else: 92 | mlm_labels = [-1] * self.wrapper.config.max_seq_length 93 | 94 | return InputFeatures(input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids, 95 | label=label, mlm_labels=mlm_labels, logits=logits, idx=example.idx) 96 | 97 | 98 | class PLMPreprocessor(MLMPreprocessor): 99 | """Preprocessor for models pretrained using a permuted language modeling objective (e.g., XLNet).""" 100 | 101 | def get_input_features(self, example: InputExample, labelled: bool, priming: bool = False, 102 | **kwargs) -> PLMInputFeatures: 103 | input_features = super().get_input_features(example, labelled, priming, **kwargs) 104 | input_ids = input_features.input_ids 105 | 106 | num_masks = 1 # currently, PLMPreprocessor supports only replacements that require exactly one mask 107 | 108 | perm_mask = np.zeros((len(input_ids), 
len(input_ids)), dtype=np.float) 109 | label_idx = input_ids.index(self.pvp.mask_id) 110 | perm_mask[:, label_idx] = 1 # the masked token is not seen by any other token 111 | 112 | target_mapping = np.zeros((num_masks, len(input_ids)), dtype=np.float) 113 | target_mapping[0, label_idx] = 1.0 114 | 115 | return PLMInputFeatures(perm_mask=perm_mask, target_mapping=target_mapping, **input_features.__dict__) 116 | 117 | 118 | class SequenceClassifierPreprocessor(Preprocessor): 119 | """Preprocessor for a regular sequence classification model.""" 120 | 121 | def get_input_features(self, example: InputExample, **kwargs) -> InputFeatures: 122 | inputs = self.wrapper.task_helper.get_sequence_classifier_inputs(example) if self.wrapper.task_helper else None 123 | if inputs is None: 124 | inputs = self.wrapper.tokenizer.encode_plus( 125 | example.text_a if example.text_a else None, 126 | example.text_b if example.text_b else None, 127 | add_special_tokens=True, 128 | max_length=self.wrapper.config.max_seq_length, 129 | ) 130 | input_ids, token_type_ids = inputs["input_ids"], inputs.get("token_type_ids") 131 | 132 | attention_mask = [1] * len(input_ids) 133 | padding_length = self.wrapper.config.max_seq_length - len(input_ids) 134 | 135 | input_ids = input_ids + ([self.wrapper.tokenizer.pad_token_id] * padding_length) 136 | attention_mask = attention_mask + ([0] * padding_length) 137 | if not token_type_ids: 138 | token_type_ids = [0] * self.wrapper.config.max_seq_length 139 | else: 140 | token_type_ids = token_type_ids + ([0] * padding_length) 141 | mlm_labels = [-1] * len(input_ids) 142 | 143 | assert len(input_ids) == self.wrapper.config.max_seq_length 144 | assert len(attention_mask) == self.wrapper.config.max_seq_length 145 | assert len(token_type_ids) == self.wrapper.config.max_seq_length 146 | 147 | label = self.label_map[example.label] if example.label is not None else -100 148 | logits = example.logits if example.logits else [-1] 149 | 150 | return InputFeatures(input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids, 151 | label=label, mlm_labels=mlm_labels, logits=logits, idx=example.idx) 152 | -------------------------------------------------------------------------------- /pet/pvp.py: -------------------------------------------------------------------------------- 1 | # Licensed under the Apache License, Version 2.0 (the "License"); 2 | # you may not use this file except in compliance with the License. 3 | # You may obtain a copy of the License at 4 | # 5 | # http://www.apache.org/licenses/LICENSE-2.0 6 | # 7 | # Unless required by applicable law or agreed to in writing, software 8 | # distributed under the License is distributed on an "AS IS" BASIS, 9 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 10 | # See the License for the specific language governing permissions and 11 | # limitations under the License. 12 | 13 | """ 14 | This file contains the pattern-verbalizer pairs (PVPs) for all tasks. 
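For example, BoolQPVP with pattern_id 0 renders an example roughly as '{passage}. Question: {question}? Answer: [MASK].' and verbalizes the labels True/False as 'Yes'/'No'.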
15 | """ 16 | import random 17 | import string 18 | from abc import ABC, abstractmethod 19 | from collections import defaultdict 20 | from typing import Tuple, List, Union, Dict 21 | 22 | import torch 23 | from transformers import PreTrainedTokenizer, GPT2Tokenizer 24 | 25 | from pet.task_helpers import MultiMaskTaskHelper 26 | from pet.tasks import TASK_HELPERS 27 | from pet.utils import InputExample, get_verbalization_ids 28 | 29 | import log 30 | from pet import wrapper as wrp 31 | 32 | logger = log.get_logger('root') 33 | 34 | FilledPattern = Tuple[List[Union[str, Tuple[str, bool]]], List[Union[str, Tuple[str, bool]]]] 35 | 36 | 37 | class PVP(ABC): 38 | """ 39 | This class contains functions to apply patterns and verbalizers as required by PET. Each task requires its own 40 | custom implementation of a PVP. 41 | """ 42 | 43 | def __init__(self, wrapper, pattern_id: int = 0, verbalizer_file: str = None, seed: int = 42): 44 | """ 45 | Create a new PVP. 46 | 47 | :param wrapper: the wrapper for the underlying language model 48 | :param pattern_id: the pattern id to use 49 | :param verbalizer_file: an optional file that contains the verbalizer to be used 50 | :param seed: a seed to be used for generating random numbers if necessary 51 | """ 52 | self.wrapper = wrapper 53 | self.pattern_id = pattern_id 54 | self.rng = random.Random(seed) 55 | 56 | if verbalizer_file: 57 | self.verbalize = PVP._load_verbalizer_from_file(verbalizer_file, self.pattern_id) 58 | 59 | use_multimask = (self.wrapper.config.task_name in TASK_HELPERS) and ( 60 | issubclass(TASK_HELPERS[self.wrapper.config.task_name], MultiMaskTaskHelper) 61 | ) 62 | if not use_multimask and self.wrapper.config.wrapper_type in [wrp.MLM_WRAPPER, wrp.PLM_WRAPPER]: 63 | self.mlm_logits_to_cls_logits_tensor = self._build_mlm_logits_to_cls_logits_tensor() 64 | 65 | def _build_mlm_logits_to_cls_logits_tensor(self): 66 | label_list = self.wrapper.config.label_list 67 | m2c_tensor = torch.ones([len(label_list), self.max_num_verbalizers], dtype=torch.long) * -1 68 | 69 | for label_idx, label in enumerate(label_list): 70 | verbalizers = self.verbalize(label) 71 | for verbalizer_idx, verbalizer in enumerate(verbalizers): 72 | verbalizer_id = get_verbalization_ids(verbalizer, self.wrapper.tokenizer, force_single_token=True) 73 | assert verbalizer_id != self.wrapper.tokenizer.unk_token_id, "verbalization was tokenized as <UNK>" 74 | m2c_tensor[label_idx, verbalizer_idx] = verbalizer_id 75 | return m2c_tensor 76 | 77 | @property 78 | def mask(self) -> str: 79 | """Return the underlying LM's mask token""" 80 | return self.wrapper.tokenizer.mask_token 81 | 82 | @property 83 | def mask_id(self) -> int: 84 | """Return the underlying LM's mask id""" 85 | return self.wrapper.tokenizer.mask_token_id 86 | 87 | @property 88 | def max_num_verbalizers(self) -> int: 89 | """Return the maximum number of verbalizers across all labels""" 90 | return max(len(self.verbalize(label)) for label in self.wrapper.config.label_list) 91 | 92 | @staticmethod 93 | def shortenable(s): 94 | """Return an instance of this string that is marked as shortenable""" 95 | return s, True 96 | 97 | @staticmethod 98 | def remove_final_punc(s: Union[str, Tuple[str, bool]]): 99 | """Remove the final punctuation mark""" 100 | if isinstance(s, tuple): 101 | return PVP.remove_final_punc(s[0]), s[1] 102 | return s.rstrip(string.punctuation) 103 | 104 | @staticmethod 105 | def lowercase_first(s: Union[str, Tuple[str, bool]]): 106 | """Lowercase the first character""" 107 | if isinstance(s, tuple):
108 | return PVP.lowercase_first(s[0]), s[1] 109 | return s[0].lower() + s[1:] 110 | 111 | def encode(self, example: InputExample, priming: bool = False, labeled: bool = False) \ 112 | -> Tuple[List[int], List[int]]: 113 | """ 114 | Encode an input example using this pattern-verbalizer pair. 115 | 116 | :param example: the input example to encode 117 | :param priming: whether to use this example for priming 118 | :param labeled: if ``priming=True``, whether the label should be appended to this example 119 | :return: A tuple, consisting of a list of input ids and a list of token type ids 120 | """ 121 | 122 | if not priming: 123 | assert not labeled, "'labeled' can only be set to true if 'priming' is also set to true" 124 | 125 | tokenizer = self.wrapper.tokenizer # type: PreTrainedTokenizer 126 | parts_a, parts_b = self.get_parts(example) 127 | 128 | kwargs = {'add_prefix_space': True} if isinstance(tokenizer, GPT2Tokenizer) else {} 129 | 130 | parts_a = [x if isinstance(x, tuple) else (x, False) for x in parts_a] 131 | parts_a = [(tokenizer.encode(x, add_special_tokens=False, **kwargs), s) for x, s in parts_a if x] 132 | 133 | if parts_b: 134 | parts_b = [x if isinstance(x, tuple) else (x, False) for x in parts_b] 135 | parts_b = [(tokenizer.encode(x, add_special_tokens=False, **kwargs), s) for x, s in parts_b if x] 136 | 137 | self.truncate(parts_a, parts_b, max_length=self.wrapper.config.max_seq_length) 138 | 139 | tokens_a = [token_id for part, _ in parts_a for token_id in part] 140 | tokens_b = [token_id for part, _ in parts_b for token_id in part] if parts_b else None 141 | 142 | if priming: 143 | input_ids = tokens_a 144 | if tokens_b: 145 | input_ids += tokens_b 146 | if labeled: 147 | mask_idx = input_ids.index(self.mask_id) 148 | assert mask_idx >= 0, 'sequence of input_ids must contain a mask token' 149 | assert len(self.verbalize(example.label)) == 1, 'priming only supports one verbalization per label' 150 | verbalizer = self.verbalize(example.label)[0] 151 | verbalizer_id = get_verbalization_ids(verbalizer, self.wrapper.tokenizer, force_single_token=True) 152 | input_ids[mask_idx] = verbalizer_id 153 | return input_ids, [] 154 | 155 | input_ids = tokenizer.build_inputs_with_special_tokens(tokens_a, tokens_b) 156 | token_type_ids = tokenizer.create_token_type_ids_from_sequences(tokens_a, tokens_b) 157 | 158 | return input_ids, token_type_ids 159 | 160 | @staticmethod 161 | def _seq_length(parts: List[Tuple[str, bool]], only_shortenable: bool = False): 162 | return sum([len(x) for x, shortenable in parts if not only_shortenable or shortenable]) if parts else 0 163 | 164 | @staticmethod 165 | def _remove_last(parts: List[Tuple[str, bool]]): 166 | last_idx = max(idx for idx, (seq, shortenable) in enumerate(parts) if shortenable and seq) 167 | parts[last_idx] = (parts[last_idx][0][:-1], parts[last_idx][1]) 168 | 169 | def truncate(self, parts_a: List[Tuple[str, bool]], parts_b: List[Tuple[str, bool]], max_length: int): 170 | """Truncate two sequences of text to a predefined total maximum length""" 171 | total_len = self._seq_length(parts_a) + self._seq_length(parts_b) 172 | total_len += self.wrapper.tokenizer.num_special_tokens_to_add(bool(parts_b)) 173 | num_tokens_to_remove = total_len - max_length 174 | 175 | if num_tokens_to_remove <= 0: 176 | return parts_a, parts_b 177 | 178 | for _ in range(num_tokens_to_remove): 179 | if self._seq_length(parts_a, only_shortenable=True) > self._seq_length(parts_b, only_shortenable=True): 180 | self._remove_last(parts_a) 181 | else: 182 | 
self._remove_last(parts_b) 183 | 184 | @abstractmethod 185 | def get_parts(self, example: InputExample) -> FilledPattern: 186 | """ 187 | Given an input example, apply a pattern to obtain two text sequences (text_a and text_b) containing exactly one 188 | mask token (or one consecutive sequence of mask tokens for PET with multiple masks). If a task requires only a 189 | single sequence of text, the second sequence should be an empty list. 190 | 191 | :param example: the input example to process 192 | :return: Two sequences of text. All text segments can optionally be marked as being shortenable. 193 | """ 194 | pass 195 | 196 | @abstractmethod 197 | def verbalize(self, label) -> List[str]: 198 | """ 199 | Return all verbalizations for a given label. 200 | 201 | :param label: the label 202 | :return: the list of verbalizations 203 | """ 204 | pass 205 | 206 | def get_mask_positions(self, input_ids: List[int]) -> List[int]: 207 | label_idx = input_ids.index(self.mask_id) 208 | labels = [-1] * len(input_ids) 209 | labels[label_idx] = 1 210 | return labels 211 | 212 | def convert_mlm_logits_to_cls_logits(self, mlm_labels: torch.Tensor, logits: torch.Tensor) -> torch.Tensor: 213 | masked_logits = logits[mlm_labels >= 0] 214 | cls_logits = torch.stack([self._convert_single_mlm_logits_to_cls_logits(ml) for ml in masked_logits]) 215 | return cls_logits 216 | 217 | def _convert_single_mlm_logits_to_cls_logits(self, logits: torch.Tensor) -> torch.Tensor: 218 | m2c = self.mlm_logits_to_cls_logits_tensor.to(logits.device) 219 | # filler_len.shape() == max_fillers 220 | filler_len = torch.tensor([len(self.verbalize(label)) for label in self.wrapper.config.label_list], 221 | dtype=torch.float) 222 | filler_len = filler_len.to(logits.device) 223 | 224 | # cls_logits.shape() == num_labels x max_fillers (and 0 when there are not as many fillers). 
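# m2c has one row per label, padded with -1 wherever a label has fewer verbalizers than max_num_verbalizers; torch.max with a zero tensor clamps that -1 padding to index 0 so the gather below stays in bounds, and the (m2c > 0) mask then zeroes the padded positions before the per-label sum is divided by the true verbalizer count in filler_len.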
225 | cls_logits = logits[torch.max(torch.zeros_like(m2c), m2c)] 226 | cls_logits = cls_logits * (m2c > 0).float() 227 | 228 | # cls_logits.shape() == num_labels 229 | cls_logits = cls_logits.sum(axis=1) / filler_len 230 | return cls_logits 231 | 232 | def convert_plm_logits_to_cls_logits(self, logits: torch.Tensor) -> torch.Tensor: 233 | assert logits.shape[1] == 1 234 | logits = torch.squeeze(logits, 1) # remove second dimension as we always have exactly one per example 235 | cls_logits = torch.stack([self._convert_single_mlm_logits_to_cls_logits(lgt) for lgt in logits]) 236 | return cls_logits 237 | 238 | @staticmethod 239 | def _load_verbalizer_from_file(path: str, pattern_id: int): 240 | 241 | verbalizers = defaultdict(dict) # type: Dict[int, Dict[str, List[str]]] 242 | current_pattern_id = None 243 | 244 | with open(path, 'r') as fh: 245 | for line in fh.read().splitlines(): 246 | if line.isdigit(): 247 | current_pattern_id = int(line) 248 | elif line: 249 | label, *realizations = line.split() 250 | verbalizers[current_pattern_id][label] = realizations 251 | 252 | logger.info("Automatically loaded the following verbalizer: \n {}".format(verbalizers[pattern_id])) 253 | 254 | def verbalize(label) -> List[str]: 255 | return verbalizers[pattern_id][label] 256 | 257 | return verbalize 258 | 259 | 260 | class AgnewsPVP(PVP): 261 | VERBALIZER = { 262 | "1": ["World"], 263 | "2": ["Sports"], 264 | "3": ["Business"], 265 | "4": ["Tech"] 266 | } 267 | 268 | def get_parts(self, example: InputExample) -> FilledPattern: 269 | 270 | text_a = self.shortenable(example.text_a) 271 | text_b = self.shortenable(example.text_b) 272 | 273 | if self.pattern_id == 0: 274 | return [self.mask, ':', text_a, text_b], [] 275 | elif self.pattern_id == 1: 276 | return [self.mask, 'News:', text_a, text_b], [] 277 | elif self.pattern_id == 2: 278 | return [text_a, '(', self.mask, ')', text_b], [] 279 | elif self.pattern_id == 3: 280 | return [text_a, text_b, '(', self.mask, ')'], [] 281 | elif self.pattern_id == 4: 282 | return ['[ Category:', self.mask, ']', text_a, text_b], [] 283 | elif self.pattern_id == 5: 284 | return [self.mask, '-', text_a, text_b], [] 285 | else: 286 | raise ValueError("No pattern implemented for id {}".format(self.pattern_id)) 287 | 288 | def verbalize(self, label) -> List[str]: 289 | return AgnewsPVP.VERBALIZER[label] 290 | 291 | 292 | class YahooPVP(PVP): 293 | VERBALIZER = { 294 | "1": ["Society"], 295 | "2": ["Science"], 296 | "3": ["Health"], 297 | "4": ["Education"], 298 | "5": ["Computer"], 299 | "6": ["Sports"], 300 | "7": ["Business"], 301 | "8": ["Entertainment"], 302 | "9": ["Relationship"], 303 | "10": ["Politics"], 304 | } 305 | 306 | def get_parts(self, example: InputExample) -> FilledPattern: 307 | 308 | text_a = self.shortenable(example.text_a) 309 | text_b = self.shortenable(example.text_b) 310 | 311 | if self.pattern_id == 0: 312 | return [self.mask, ':', text_a, text_b], [] 313 | elif self.pattern_id == 1: 314 | return [self.mask, 'Question:', text_a, text_b], [] 315 | elif self.pattern_id == 2: 316 | return [text_a, '(', self.mask, ')', text_b], [] 317 | elif self.pattern_id == 3: 318 | return [text_a, text_b, '(', self.mask, ')'], [] 319 | elif self.pattern_id == 4: 320 | return ['[ Category:', self.mask, ']', text_a, text_b], [] 321 | elif self.pattern_id == 5: 322 | return [self.mask, '-', text_a, text_b], [] 323 | else: 324 | raise ValueError("No pattern implemented for id {}".format(self.pattern_id)) 325 | 326 | def verbalize(self, label) -> List[str]: 327 | return 
YahooPVP.VERBALIZER[label] 328 | 329 | 330 | class MnliPVP(PVP): 331 | VERBALIZER_A = { 332 | "contradiction": ["Wrong"], 333 | "entailment": ["Right"], 334 | "neutral": ["Maybe"] 335 | } 336 | VERBALIZER_B = { 337 | "contradiction": ["No"], 338 | "entailment": ["Yes"], 339 | "neutral": ["Maybe"] 340 | } 341 | 342 | def get_parts(self, example: InputExample) -> FilledPattern: 343 | text_a = self.shortenable(self.remove_final_punc(example.text_a)) 344 | text_b = self.shortenable(example.text_b) 345 | 346 | if self.pattern_id == 0 or self.pattern_id == 2: 347 | return ['"', text_a, '" ?'], [self.mask, ', "', text_b, '"'] 348 | elif self.pattern_id == 1 or self.pattern_id == 3: 349 | return [text_a, '?'], [self.mask, ',', text_b] 350 | 351 | def verbalize(self, label) -> List[str]: 352 | if self.pattern_id == 0 or self.pattern_id == 1: 353 | return MnliPVP.VERBALIZER_A[label] 354 | return MnliPVP.VERBALIZER_B[label] 355 | 356 | 357 | class YelpPolarityPVP(PVP): 358 | VERBALIZER = { 359 | "1": ["bad"], 360 | "2": ["good"] 361 | } 362 | 363 | def get_parts(self, example: InputExample) -> FilledPattern: 364 | text = self.shortenable(example.text_a) 365 | 366 | if self.pattern_id == 0: 367 | return ['It was', self.mask, '.', text], [] 368 | elif self.pattern_id == 1: 369 | return [text, '. All in all, it was', self.mask, '.'], [] 370 | elif self.pattern_id == 2: 371 | return ['Just', self.mask, "!"], [text] 372 | elif self.pattern_id == 3: 373 | return [text], ['In summary, the restaurant is', self.mask, '.'] 374 | else: 375 | raise ValueError("No pattern implemented for id {}".format(self.pattern_id)) 376 | 377 | def verbalize(self, label) -> List[str]: 378 | return YelpPolarityPVP.VERBALIZER[label] 379 | 380 | 381 | class YelpFullPVP(YelpPolarityPVP): 382 | VERBALIZER = { 383 | "1": ["terrible"], 384 | "2": ["bad"], 385 | "3": ["okay"], 386 | "4": ["good"], 387 | "5": ["great"] 388 | } 389 | 390 | def verbalize(self, label) -> List[str]: 391 | return YelpFullPVP.VERBALIZER[label] 392 | 393 | 394 | class XStancePVP(PVP): 395 | VERBALIZERS = { 396 | 'en': {"FAVOR": ["Yes"], "AGAINST": ["No"]}, 397 | 'de': {"FAVOR": ["Ja"], "AGAINST": ["Nein"]}, 398 | 'fr': {"FAVOR": ["Oui"], "AGAINST": ["Non"]} 399 | } 400 | 401 | def get_parts(self, example: InputExample) -> FilledPattern: 402 | 403 | text_a = self.shortenable(example.text_a) 404 | text_b = self.shortenable(example.text_b) 405 | 406 | if self.pattern_id == 0 or self.pattern_id == 2 or self.pattern_id == 4: 407 | return ['"', text_a, '"'], [self.mask, '. 
"', text_b, '"'] 408 | elif self.pattern_id == 1 or self.pattern_id == 3 or self.pattern_id == 5: 409 | return [text_a], [self.mask, '.', text_b] 410 | 411 | def verbalize(self, label) -> List[str]: 412 | lang = 'de' if self.pattern_id < 2 else 'en' if self.pattern_id < 4 else 'fr' 413 | return XStancePVP.VERBALIZERS[lang][label] 414 | 415 | 416 | class RtePVP(PVP): 417 | VERBALIZER = { 418 | "not_entailment": ["No"], 419 | "entailment": ["Yes"] 420 | } 421 | 422 | def get_parts(self, example: InputExample) -> FilledPattern: 423 | # switch text_a and text_b to get the correct order 424 | text_a = self.shortenable(example.text_a) 425 | text_b = self.shortenable(example.text_b.rstrip(string.punctuation)) 426 | 427 | if self.pattern_id == -1: 428 | return [text_a, 'Question:', text_b, "?", "Answer:", self.mask, "."], [] 429 | 430 | if self.pattern_id == 0: 431 | return ['"', text_b, '" ?'], [self.mask, ', "', text_a, '"'] 432 | elif self.pattern_id == 1: 433 | return [text_b, '?'], [self.mask, ',', text_a] 434 | if self.pattern_id == 2: 435 | return ['"', text_b, '" ?'], [self.mask, '. "', text_a, '"'] 436 | elif self.pattern_id == 3: 437 | return [text_b, '?'], [self.mask, '.', text_a] 438 | elif self.pattern_id == 4: 439 | return [text_a, ' question: ', self.shortenable(example.text_b), ' True or False? answer:', self.mask], [] 440 | 441 | def verbalize(self, label) -> List[str]: 442 | if self.pattern_id == 4: 443 | return ['true'] if label == 'entailment' else ['false'] 444 | return RtePVP.VERBALIZER[label] 445 | 446 | 447 | class CbPVP(RtePVP): 448 | VERBALIZER = { 449 | "contradiction": ["No"], 450 | "entailment": ["Yes"], 451 | "neutral": ["Maybe"] 452 | } 453 | 454 | def get_parts(self, example: InputExample) -> FilledPattern: 455 | if self.pattern_id == 4: 456 | text_a = self.shortenable(example.text_a) 457 | text_b = self.shortenable(example.text_b) 458 | return [text_a, ' question: ', text_b, ' true, false or neither? 
answer:', self.mask], [] 459 | return super().get_parts(example) 460 | 461 | def verbalize(self, label) -> List[str]: 462 | if self.pattern_id == 4: 463 | return ['true'] if label == 'entailment' else ['false'] if label == 'contradiction' else ['neither'] 464 | return CbPVP.VERBALIZER[label] 465 | 466 | 467 | class CopaPVP(PVP): 468 | 469 | def get_parts(self, example: InputExample) -> FilledPattern: 470 | 471 | premise = self.remove_final_punc(self.shortenable(example.text_a)) 472 | choice1 = self.remove_final_punc(self.lowercase_first(example.meta['choice1'])) 473 | choice2 = self.remove_final_punc(self.lowercase_first(example.meta['choice2'])) 474 | 475 | question = example.meta['question'] 476 | assert question in ['cause', 'effect'] 477 | 478 | example.meta['choice1'], example.meta['choice2'] = choice1, choice2 479 | num_masks = max(len(get_verbalization_ids(c, self.wrapper.tokenizer, False)) for c in [choice1, choice2]) 480 | 481 | if question == 'cause': 482 | if self.pattern_id == 0: 483 | return ['"', choice1, '" or "', choice2, '"?', premise, 'because', self.mask * num_masks, '.'], [] 484 | elif self.pattern_id == 1: 485 | return [choice1, 'or', choice2, '?', premise, 'because', self.mask * num_masks, '.'], [] 486 | else: 487 | if self.pattern_id == 0: 488 | return ['"', choice1, '" or "', choice2, '"?', premise, ', so', self.mask * num_masks, '.'], [] 489 | elif self.pattern_id == 1: 490 | return [choice1, 'or', choice2, '?', premise, ', so', self.mask * num_masks, '.'], [] 491 | 492 | def verbalize(self, label) -> List[str]: 493 | return [] 494 | 495 | 496 | class WscPVP(PVP): 497 | 498 | def get_parts(self, example: InputExample) -> FilledPattern: 499 | pronoun = example.meta['span2_text'] 500 | target = example.meta['span1_text'] 501 | pronoun_idx = example.meta['span2_index'] 502 | 503 | words_a = example.text_a.split() 504 | words_a[pronoun_idx] = '*' + words_a[pronoun_idx] + '*' 505 | text_a = ' '.join(words_a) 506 | text_a = self.shortenable(text_a) 507 | 508 | num_pad = self.rng.randint(0, 3) if 'train' in example.guid else 1 509 | num_masks = len(get_verbalization_ids(target, self.wrapper.tokenizer, force_single_token=False)) + num_pad 510 | masks = self.mask * num_masks 511 | 512 | if self.pattern_id == 0: 513 | return [text_a, "The pronoun '*" + pronoun + "*' refers to", masks + '.'], [] 514 | elif self.pattern_id == 1: 515 | return [text_a, "In the previous sentence, the pronoun '*" + pronoun + "*' refers to", masks + '.'], [] 516 | elif self.pattern_id == 2: 517 | return [text_a, 518 | "Question: In the passage above, what does the pronoun '*" + pronoun + "*' refer to? Answer: ", 519 | masks + '.'], [] 520 | 521 | def verbalize(self, label) -> List[str]: 522 | return [] 523 | 524 | 525 | class BoolQPVP(PVP): 526 | VERBALIZER_A = { 527 | "False": ["No"], 528 | "True": ["Yes"] 529 | } 530 | 531 | VERBALIZER_B = { 532 | "False": ["false"], 533 | "True": ["true"] 534 | } 535 | 536 | def get_parts(self, example: InputExample) -> FilledPattern: 537 | passage = self.shortenable(example.text_a) 538 | question = self.shortenable(example.text_b) 539 | 540 | if self.pattern_id < 2: 541 | return [passage, '. Question: ', question, '? Answer: ', self.mask, '.'], [] 542 | elif self.pattern_id < 4: 543 | return [passage, '. 
Based on the previous passage, ', question, '?', self.mask, '.'], [] 544 | else: 545 | return ['Based on the following passage, ', question, '?', self.mask, '.', passage], [] 546 | 547 | def verbalize(self, label) -> List[str]: 548 | if self.pattern_id == 0 or self.pattern_id == 2 or self.pattern_id == 4: 549 | return BoolQPVP.VERBALIZER_A[label] 550 | else: 551 | return BoolQPVP.VERBALIZER_B[label] 552 | 553 | 554 | class MultiRcPVP(PVP): 555 | VERBALIZER = { 556 | "0": ["No"], 557 | "1": ["Yes"] 558 | } 559 | 560 | def get_parts(self, example: InputExample) -> FilledPattern: 561 | passage = self.shortenable(example.text_a) 562 | question = example.text_b 563 | answer = example.meta['answer'] 564 | 565 | if self.pattern_id == 0: 566 | return [passage, '. Question: ', question, '? Is it ', answer, '?', self.mask, '.'], [] 567 | if self.pattern_id == 1: 568 | return [passage, '. Question: ', question, '? Is the correct answer "', answer, '"?', self.mask, '.'], [] 569 | if self.pattern_id == 2: 570 | return [passage, '. Based on the previous passage, ', question, '? Is "', answer, '" a correct answer?', 571 | self.mask, '.'], [] 572 | if self.pattern_id == 3: 573 | return [passage, question, '- [', self.mask, ']', answer], [] 574 | 575 | def verbalize(self, label) -> List[str]: 576 | if self.pattern_id == 3: 577 | return ['False'] if label == "0" else ['True'] 578 | return MultiRcPVP.VERBALIZER[label] 579 | 580 | 581 | class WicPVP(PVP): 582 | VERBALIZER_A = { 583 | "F": ["No"], 584 | "T": ["Yes"] 585 | } 586 | VERBALIZER_B = { 587 | "F": ["2"], 588 | "T": ["b"] 589 | } 590 | 591 | def get_parts(self, example: InputExample) -> FilledPattern: 592 | text_a = self.shortenable(example.text_a) 593 | text_b = self.shortenable(example.text_b) 594 | word = example.meta['word'] 595 | 596 | if self.pattern_id == 0: 597 | return ['"', text_a, '" / "', text_b, '" Similar sense of "' + word + '"?', self.mask, '.'], [] 598 | if self.pattern_id == 1: 599 | return [text_a, text_b, 'Does ' + word + ' have the same meaning in both sentences?', self.mask], [] 600 | if self.pattern_id == 2: 601 | return [word, ' . 
Sense (1) (a) "', text_a, '" (', self.mask, ') "', text_b, '"'], [] 602 | 603 | def verbalize(self, label) -> List[str]: 604 | if self.pattern_id == 2: 605 | return WicPVP.VERBALIZER_B[label] 606 | return WicPVP.VERBALIZER_A[label] 607 | 608 | 609 | class RecordPVP(PVP): 610 | 611 | def get_parts(self, example: InputExample) -> FilledPattern: 612 | premise = self.shortenable(example.text_a) 613 | choices = example.meta['candidates'] 614 | 615 | assert '@placeholder' in example.text_b, f'question "{example.text_b}" does not contain a @placeholder token' 616 | num_masks = max(len(get_verbalization_ids(c, self.wrapper.tokenizer, False)) for c in choices) 617 | question = example.text_b.replace('@placeholder', self.mask * num_masks) 618 | return [premise, question], [] 619 | 620 | def verbalize(self, label) -> List[str]: 621 | return [] 622 | 623 | 624 | PVPS = { 625 | 'agnews': AgnewsPVP, 626 | 'mnli': MnliPVP, 627 | 'yelp-polarity': YelpPolarityPVP, 628 | 'yelp-full': YelpFullPVP, 629 | 'yahoo': YahooPVP, 630 | 'xstance': XStancePVP, 631 | 'xstance-de': XStancePVP, 632 | 'xstance-fr': XStancePVP, 633 | 'rte': RtePVP, 634 | 'wic': WicPVP, 635 | 'cb': CbPVP, 636 | 'wsc': WscPVP, 637 | 'boolq': BoolQPVP, 638 | 'copa': CopaPVP, 639 | 'multirc': MultiRcPVP, 640 | 'record': RecordPVP, 641 | 'ax-b': RtePVP, 642 | 'ax-g': RtePVP, 643 | } 644 | -------------------------------------------------------------------------------- /pet/tasks.py: -------------------------------------------------------------------------------- 1 | # Licensed under the Apache License, Version 2.0 (the "License"); 2 | # you may not use this file except in compliance with the License. 3 | # You may obtain a copy of the License at 4 | # 5 | # http://www.apache.org/licenses/LICENSE-2.0 6 | # 7 | # Unless required by applicable law or agreed to in writing, software 8 | # distributed under the License is distributed on an "AS IS" BASIS, 9 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 10 | # See the License for the specific language governing permissions and 11 | # limitations under the License. 12 | 13 | """ 14 | This file contains the logic for loading training and test data for all tasks. 15 | """ 16 | 17 | import csv 18 | import json 19 | import os 20 | import random 21 | from abc import ABC, abstractmethod 22 | from collections import defaultdict, Counter 23 | from typing import List, Dict, Callable 24 | 25 | import log 26 | from pet import task_helpers 27 | from pet.utils import InputExample 28 | 29 | logger = log.get_logger('root') 30 | 31 | 32 | def _shuffle_and_restrict(examples: List[InputExample], num_examples: int, seed: int = 42) -> List[InputExample]: 33 | """ 34 | Shuffle a list of examples and restrict it to a given maximum size. 35 | 36 | :param examples: the examples to shuffle and restrict 37 | :param num_examples: the maximum number of examples 38 | :param seed: the random seed for shuffling 39 | :return: the first ``num_examples`` elements of the shuffled list 40 | """ 41 | if 0 < num_examples < len(examples): 42 | random.Random(seed).shuffle(examples) 43 | examples = examples[:num_examples] 44 | return examples 45 | 46 | 47 | class LimitedExampleList: 48 | def __init__(self, labels: List[str], max_examples=-1): 49 | """ 50 | Implementation of a list that stores only a limited amount of examples per label. 51 | 52 | :param labels: the set of all possible labels 53 | :param max_examples: the maximum number of examples per label. 
This can either be a fixed number, 54 | in which case `max_examples` examples are loaded for every label, or a list with the same size as 55 | `labels`, in which case at most `max_examples[i]` examples are loaded for label `labels[i]`. 56 | """ 57 | self._labels = labels 58 | self._examples = [] 59 | self._examples_per_label = defaultdict(int) 60 | 61 | if isinstance(max_examples, list): 62 | self._max_examples = dict(zip(self._labels, max_examples)) 63 | else: 64 | self._max_examples = {label: max_examples for label in self._labels} 65 | 66 | def is_full(self): 67 | """Return `true` iff no more examples can be added to this list""" 68 | for label in self._labels: 69 | if self._examples_per_label[label] < self._max_examples[label] or self._max_examples[label] < 0: 70 | return False 71 | return True 72 | 73 | def add(self, example: InputExample) -> bool: 74 | """ 75 | Add a new input example to this list. 76 | 77 | :param example: the example to add 78 | :returns: `true` iff the example was actually added to the list 79 | """ 80 | label = example.label 81 | if self._examples_per_label[label] < self._max_examples[label] or self._max_examples[label] < 0: 82 | self._examples_per_label[label] += 1 83 | self._examples.append(example) 84 | return True 85 | return False 86 | 87 | def to_list(self): 88 | return self._examples 89 | 90 | 91 | class DataProcessor(ABC): 92 | """ 93 | Abstract class that provides methods for loading training, testing, development and unlabeled examples for a given 94 | task 95 | """ 96 | 97 | @abstractmethod 98 | def get_train_examples(self, data_dir) -> List[InputExample]: 99 | """Get a collection of `InputExample`s for the train set.""" 100 | pass 101 | 102 | @abstractmethod 103 | def get_dev_examples(self, data_dir) -> List[InputExample]: 104 | """Get a collection of `InputExample`s for the dev set.""" 105 | pass 106 | 107 | @abstractmethod 108 | def get_test_examples(self, data_dir) -> List[InputExample]: 109 | """Get a collection of `InputExample`s for the test set.""" 110 | pass 111 | 112 | @abstractmethod 113 | def get_unlabeled_examples(self, data_dir) -> List[InputExample]: 114 | """Get a collection of `InputExample`s for the unlabeled set.""" 115 | pass 116 | 117 | @abstractmethod 118 | def get_labels(self) -> List[str]: 119 | """Get the list of labels for this data set.""" 120 | pass 121 | 122 | 123 | class MnliProcessor(DataProcessor): 124 | """Processor for the MultiNLI data set (GLUE version).""" 125 | 126 | def get_train_examples(self, data_dir): 127 | return self._create_examples(MnliProcessor._read_tsv(os.path.join(data_dir, "train.tsv")), "train") 128 | 129 | def get_dev_examples(self, data_dir): 130 | return self._create_examples(MnliProcessor._read_tsv(os.path.join(data_dir, "dev_matched.tsv")), "dev_matched") 131 | 132 | def get_test_examples(self, data_dir) -> List[InputExample]: 133 | raise NotImplementedError() 134 | 135 | def get_unlabeled_examples(self, data_dir) -> List[InputExample]: 136 | return self.get_train_examples(data_dir) 137 | 138 | def get_labels(self): 139 | return ["contradiction", "entailment", "neutral"] 140 | 141 | @staticmethod 142 | def _create_examples(lines: List[List[str]], set_type: str) -> List[InputExample]: 143 | examples = [] 144 | 145 | for (i, line) in enumerate(lines): 146 | if i == 0: 147 | continue 148 | guid = "%s-%s" % (set_type, line[0]) 149 | text_a = line[8] 150 | text_b = line[9] 151 | label = line[-1] 152 | 153 | example = InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label) 154 | 
examples.append(example) 155 | 156 | return examples 157 | 158 | @staticmethod 159 | def _read_tsv(input_file, quotechar=None): 160 | with open(input_file, "r", encoding="utf-8-sig") as f: 161 | reader = csv.reader(f, delimiter="\t", quotechar=quotechar) 162 | lines = [] 163 | for line in reader: 164 | lines.append(line) 165 | return lines 166 | 167 | 168 | class MnliMismatchedProcessor(MnliProcessor): 169 | """Processor for the MultiNLI mismatched data set (GLUE version).""" 170 | 171 | def get_dev_examples(self, data_dir): 172 | return self._create_examples(self._read_tsv(os.path.join(data_dir, "dev_mismatched.tsv")), "dev_mismatched") 173 | 174 | def get_test_examples(self, data_dir) -> List[InputExample]: 175 | raise NotImplementedError() 176 | 177 | 178 | class AgnewsProcessor(DataProcessor): 179 | """Processor for the AG news data set.""" 180 | 181 | def get_train_examples(self, data_dir): 182 | return self._create_examples(os.path.join(data_dir, "train.csv"), "train") 183 | 184 | def get_dev_examples(self, data_dir): 185 | return self._create_examples(os.path.join(data_dir, "test.csv"), "dev") 186 | 187 | def get_test_examples(self, data_dir) -> List[InputExample]: 188 | raise NotImplementedError() 189 | 190 | def get_unlabeled_examples(self, data_dir) -> List[InputExample]: 191 | return self.get_train_examples(data_dir) 192 | 193 | def get_labels(self): 194 | return ["1", "2", "3", "4"] 195 | 196 | @staticmethod 197 | def _create_examples(path: str, set_type: str) -> List[InputExample]: 198 | examples = [] 199 | 200 | with open(path) as f: 201 | reader = csv.reader(f, delimiter=',') 202 | for idx, row in enumerate(reader): 203 | label, headline, body = row 204 | guid = "%s-%s" % (set_type, idx) 205 | text_a = headline.replace('\\', ' ') 206 | text_b = body.replace('\\', ' ') 207 | 208 | example = InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label) 209 | examples.append(example) 210 | 211 | return examples 212 | 213 | 214 | class YahooAnswersProcessor(DataProcessor): 215 | """Processor for the Yahoo Answers data set.""" 216 | 217 | def get_train_examples(self, data_dir): 218 | return self._create_examples(os.path.join(data_dir, "train.csv"), "train") 219 | 220 | def get_dev_examples(self, data_dir): 221 | return self._create_examples(os.path.join(data_dir, "test.csv"), "dev") 222 | 223 | def get_test_examples(self, data_dir) -> List[InputExample]: 224 | raise NotImplementedError() 225 | 226 | def get_unlabeled_examples(self, data_dir) -> List[InputExample]: 227 | return self.get_train_examples(data_dir) 228 | 229 | def get_labels(self): 230 | return ["1", "2", "3", "4", "5", "6", "7", "8", "9", "10"] 231 | 232 | @staticmethod 233 | def _create_examples(path: str, set_type: str) -> List[InputExample]: 234 | examples = [] 235 | 236 | with open(path, encoding='utf8') as f: 237 | reader = csv.reader(f, delimiter=',') 238 | for idx, row in enumerate(reader): 239 | label, question_title, question_body, answer = row 240 | guid = "%s-%s" % (set_type, idx) 241 | text_a = ' '.join([question_title.replace('\\n', ' ').replace('\\', ' '), 242 | question_body.replace('\\n', ' ').replace('\\', ' ')]) 243 | text_b = answer.replace('\\n', ' ').replace('\\', ' ') 244 | 245 | example = InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label) 246 | examples.append(example) 247 | 248 | return examples 249 | 250 | 251 | class YelpPolarityProcessor(DataProcessor): 252 | """Processor for the YELP binary classification set.""" 253 | 254 | def get_train_examples(self, data_dir): 255 
| return self._create_examples(os.path.join(data_dir, "train.csv"), "train") 256 | 257 | def get_dev_examples(self, data_dir): 258 | return self._create_examples(os.path.join(data_dir, "test.csv"), "dev") 259 | 260 | def get_test_examples(self, data_dir) -> List[InputExample]: 261 | raise NotImplementedError() 262 | 263 | def get_unlabeled_examples(self, data_dir) -> List[InputExample]: 264 | return self.get_train_examples(data_dir) 265 | 266 | def get_labels(self): 267 | return ["1", "2"] 268 | 269 | @staticmethod 270 | def _create_examples(path: str, set_type: str) -> List[InputExample]: 271 | examples = [] 272 | 273 | with open(path) as f: 274 | reader = csv.reader(f, delimiter=',') 275 | for idx, row in enumerate(reader): 276 | label, body = row 277 | guid = "%s-%s" % (set_type, idx) 278 | text_a = body.replace('\\n', ' ').replace('\\', ' ') 279 | 280 | example = InputExample(guid=guid, text_a=text_a, label=label) 281 | examples.append(example) 282 | 283 | return examples 284 | 285 | 286 | class YelpFullProcessor(YelpPolarityProcessor): 287 | """Processor for the YELP full classification set.""" 288 | 289 | def get_test_examples(self, data_dir) -> List[InputExample]: 290 | raise NotImplementedError() 291 | 292 | def get_labels(self): 293 | return ["1", "2", "3", "4", "5"] 294 | 295 | 296 | class XStanceProcessor(DataProcessor): 297 | """Processor for the X-Stance data set.""" 298 | 299 | def __init__(self, language: str = None): 300 | if language is not None: 301 | assert language in ['de', 'fr'] 302 | self.language = language 303 | 304 | def get_train_examples(self, data_dir): 305 | return self._create_examples(os.path.join(data_dir, "train.jsonl")) 306 | 307 | def get_dev_examples(self, data_dir): 308 | return self._create_examples(os.path.join(data_dir, "test.jsonl")) 309 | 310 | def get_test_examples(self, data_dir) -> List[InputExample]: 311 | raise NotImplementedError() 312 | 313 | def get_unlabeled_examples(self, data_dir) -> List[InputExample]: 314 | return self.get_train_examples(data_dir) 315 | 316 | def get_labels(self): 317 | return ["FAVOR", "AGAINST"] 318 | 319 | def _create_examples(self, path: str) -> List[InputExample]: 320 | examples = [] 321 | 322 | with open(path, encoding='utf8') as f: 323 | for line in f: 324 | example_json = json.loads(line) 325 | label = example_json['label'] 326 | id_ = example_json['id'] 327 | text_a = example_json['question'] 328 | text_b = example_json['comment'] 329 | language = example_json['language'] 330 | 331 | if self.language is not None and language != self.language: 332 | continue 333 | 334 | example = InputExample(guid=id_, text_a=text_a, text_b=text_b, label=label) 335 | examples.append(example) 336 | 337 | return examples 338 | 339 | 340 | class RteProcessor(DataProcessor): 341 | """Processor for the RTE data set.""" 342 | 343 | def __init__(self): 344 | self.mnli_processor = MnliProcessor() 345 | 346 | def get_train_examples(self, data_dir): 347 | return self._create_examples(os.path.join(data_dir, "train.jsonl"), "train") 348 | 349 | def get_dev_examples(self, data_dir): 350 | return self._create_examples(os.path.join(data_dir, "val.jsonl"), "dev") 351 | 352 | def get_test_examples(self, data_dir): 353 | return self._create_examples(os.path.join(data_dir, "test.jsonl"), "test") 354 | 355 | def get_unlabeled_examples(self, data_dir): 356 | return self._create_examples(os.path.join(data_dir, "unlabeled.jsonl"), "unlabeled") 357 | 358 | def get_labels(self): 359 | return ["entailment", "not_entailment"] 360 | 361 | def 
_create_examples(self, path: str, set_type: str, hypothesis_name: str = "hypothesis", 362 | premise_name: str = "premise") -> List[InputExample]: 363 | examples = [] 364 | 365 | with open(path, encoding='utf8') as f: 366 | for line_idx, line in enumerate(f): 367 | example_json = json.loads(line) 368 | idx = example_json['idx'] 369 | if isinstance(idx, str): 370 | try: 371 | idx = int(idx) 372 | except ValueError: 373 | idx = line_idx 374 | label = example_json.get('label') 375 | orig_label = example_json.get('orig_label',label) 376 | guid = "%s-%s" % (set_type, idx) 377 | text_a = example_json[premise_name] 378 | text_b = example_json[hypothesis_name] 379 | 380 | example = InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label, idx=idx, orig_label=orig_label) 381 | examples.append(example) 382 | 383 | return examples 384 | 385 | 386 | class AxGProcessor(RteProcessor): 387 | """Processor for the AX-G diagnostic data set.""" 388 | 389 | def get_train_examples(self, data_dir): 390 | return self._create_examples(os.path.join(data_dir, "AX-g.jsonl"), "train") 391 | 392 | def get_test_examples(self, data_dir): 393 | return self._create_examples(os.path.join(data_dir, "AX-g.jsonl"), "test") 394 | 395 | 396 | class AxBProcessor(RteProcessor): 397 | """Processor for the AX-B diagnostic data set.""" 398 | 399 | def get_train_examples(self, data_dir): 400 | return self._create_examples(os.path.join(data_dir, "AX-b.jsonl"), "train") 401 | 402 | def get_test_examples(self, data_dir): 403 | return self._create_examples(os.path.join(data_dir, "AX-b.jsonl"), "test") 404 | 405 | def _create_examples(self, path, set_type, hypothesis_name="sentence2", premise_name="sentence1"): 406 | return super()._create_examples(path, set_type, hypothesis_name, premise_name) 407 | 408 | 409 | class CbProcessor(RteProcessor): 410 | """Processor for the CB data set.""" 411 | 412 | def get_labels(self): 413 | return ["entailment", "contradiction", "neutral"] 414 | 415 | 416 | class WicProcessor(DataProcessor): 417 | """Processor for the WiC data set.""" 418 | 419 | def get_train_examples(self, data_dir): 420 | return self._create_examples(os.path.join(data_dir, "train.jsonl"), "train") 421 | 422 | def get_dev_examples(self, data_dir): 423 | return self._create_examples(os.path.join(data_dir, "val.jsonl"), "dev") 424 | 425 | def get_test_examples(self, data_dir): 426 | return self._create_examples(os.path.join(data_dir, "test.jsonl"), "test") 427 | 428 | def get_unlabeled_examples(self, data_dir): 429 | return self._create_examples(os.path.join(data_dir, "unlabeled.jsonl"), "unlabeled") 430 | 431 | def get_labels(self): 432 | return ["F", "T"] 433 | 434 | @staticmethod 435 | def _create_examples(path: str, set_type: str) -> List[InputExample]: 436 | examples = [] 437 | with open(path, encoding='utf8') as f: 438 | for line in f: 439 | example_json = json.loads(line) 440 | idx = example_json['idx'] 441 | if isinstance(idx, str): 442 | idx = int(idx) 443 | label = "T" if example_json.get('label') else "F" 444 | if example_json.get('orig_label') is None: 445 | orig_label=label 446 | else: 447 | orig_label = "T" if example_json.get('orig_label') else "F" 448 | guid = "%s-%s" % (set_type, idx) 449 | text_a = example_json['sentence1'] 450 | text_b = example_json['sentence2'] 451 | meta = {'word': example_json['word']} 452 | example = InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label, idx=idx, meta=meta, orig_label=orig_label) 453 | examples.append(example) 454 | return examples 455 | 456 | 457 | class 
WscProcessor(DataProcessor): 458 | """Processor for the WSC data set.""" 459 | 460 | def get_train_examples(self, data_dir): 461 | return self._create_examples(os.path.join(data_dir, "train.jsonl"), "train") 462 | 463 | def get_dev_examples(self, data_dir): 464 | return self._create_examples(os.path.join(data_dir, "val.jsonl"), "dev") 465 | 466 | def get_test_examples(self, data_dir): 467 | return self._create_examples(os.path.join(data_dir, "test.jsonl"), "test") 468 | 469 | def get_unlabeled_examples(self, data_dir): 470 | return self._create_examples(os.path.join(data_dir, "unlabeled.jsonl"), "unlabeled") 471 | 472 | def get_labels(self): 473 | return ["False", "True"] 474 | 475 | @staticmethod 476 | def _create_examples(path: str, set_type: str) -> List[InputExample]: 477 | examples = [] 478 | 479 | with open(path, encoding='utf8') as f: 480 | for line in f: 481 | example_json = json.loads(line) 482 | idx = example_json['idx'] 483 | label = str(example_json['label']) if 'label' in example_json else None 484 | orig_label = str(example_json['orig_label']) if 'orig_label' in example_json else label 485 | guid = "%s-%s" % (set_type, idx) 486 | text_a = example_json['text'] 487 | meta = { 488 | 'span1_text': example_json['target']['span1_text'], 489 | 'span2_text': example_json['target']['span2_text'], 490 | 'span1_index': example_json['target']['span1_index'], 491 | 'span2_index': example_json['target']['span2_index'] 492 | } 493 | 494 | # the indices in the dataset are wrong for some examples, so we manually fix them 495 | span1_index, span1_text = meta['span1_index'], meta['span1_text'] 496 | span2_index, span2_text = meta['span2_index'], meta['span2_text'] 497 | words_a = text_a.split() 498 | words_a_lower = text_a.lower().split() 499 | words_span1_text = span1_text.lower().split() 500 | span1_len = len(words_span1_text) 501 | 502 | if words_a_lower[span1_index:span1_index + span1_len] != words_span1_text: 503 | for offset in [-1, +1]: 504 | if words_a_lower[span1_index + offset:span1_index + span1_len + offset] == words_span1_text: 505 | span1_index += offset 506 | 507 | if words_a_lower[span1_index:span1_index + span1_len] != words_span1_text: 508 | logger.warning(f"Got '{words_a_lower[span1_index:span1_index + span1_len]}' but expected " 509 | f"'{words_span1_text}' at index {span1_index} for '{words_a}'") 510 | 511 | if words_a[span2_index] != span2_text: 512 | for offset in [-1, +1]: 513 | if words_a[span2_index + offset] == span2_text: 514 | span2_index += offset 515 | 516 | if words_a[span2_index] != span2_text and words_a[span2_index].startswith(span2_text): 517 | words_a = words_a[:span2_index] \ 518 | + [words_a[span2_index][:len(span2_text)], words_a[span2_index][len(span2_text):]] \ 519 | + words_a[span2_index + 1:] 520 | 521 | assert words_a[span2_index] == span2_text, \ 522 | f"Got '{words_a[span2_index]}' but expected '{span2_text}' at index {span2_index} for '{words_a}'" 523 | 524 | text_a = ' '.join(words_a) 525 | meta['span1_index'], meta['span2_index'] = span1_index, span2_index 526 | 527 | example = InputExample(guid=guid, text_a=text_a, label=label, meta=meta, idx=idx, orig_label=orig_label) 528 | if set_type == 'train' and label != 'True': 529 | continue 530 | examples.append(example) 531 | 532 | return examples 533 | 534 | 535 | class BoolQProcessor(DataProcessor): 536 | """Processor for the BoolQ data set.""" 537 | 538 | def get_train_examples(self, data_dir): 539 | return self._create_examples(os.path.join(data_dir, "train.jsonl"), "train") 540 | 541 | def 
get_dev_examples(self, data_dir): 542 | return self._create_examples(os.path.join(data_dir, "val.jsonl"), "dev") 543 | 544 | def get_test_examples(self, data_dir): 545 | return self._create_examples(os.path.join(data_dir, "test.jsonl"), "test") 546 | 547 | def get_unlabeled_examples(self, data_dir): 548 | return self._create_examples(os.path.join(data_dir, "unlabeled.jsonl"), "unlabeled") 549 | 550 | def get_labels(self): 551 | return ["False", "True"] 552 | 553 | @staticmethod 554 | def _create_examples(path: str, set_type: str) -> List[InputExample]: 555 | examples = [] 556 | 557 | with open(path, encoding='utf8') as f: 558 | for line in f: 559 | example_json = json.loads(line) 560 | idx = example_json['idx'] 561 | label = str(example_json['label']) if 'label' in example_json else None 562 | orig_label = str(example_json['orig_label']) if 'orig_label' in example_json else label 563 | guid = "%s-%s" % (set_type, idx) 564 | text_a = example_json['passage'] 565 | text_b = example_json['question'] 566 | example = InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label, idx=idx, orig_label=orig_label) 567 | examples.append(example) 568 | 569 | return examples 570 | 571 | 572 | class CopaProcessor(DataProcessor): 573 | """Processor for the COPA data set.""" 574 | 575 | def get_train_examples(self, data_dir): 576 | return self._create_examples(os.path.join(data_dir, "train.jsonl"), "train") 577 | 578 | def get_dev_examples(self, data_dir): 579 | return self._create_examples(os.path.join(data_dir, "val.jsonl"), "dev") 580 | 581 | def get_test_examples(self, data_dir): 582 | return self._create_examples(os.path.join(data_dir, "test.jsonl"), "test") 583 | 584 | def get_unlabeled_examples(self, data_dir): 585 | return self._create_examples(os.path.join(data_dir, "unlabeled.jsonl"), "unlabeled") 586 | 587 | def get_labels(self): 588 | return ["0", "1"] 589 | 590 | @staticmethod 591 | def _create_examples(path: str, set_type: str) -> List[InputExample]: 592 | examples = [] 593 | 594 | with open(path, encoding='utf8') as f: 595 | for line in f: 596 | example_json = json.loads(line) 597 | label = str(example_json['label']) if 'label' in example_json else None 598 | orig_label = str(example_json['orig_label']) if 'orig_label' in example_json else label 599 | idx = example_json['idx'] 600 | guid = "%s-%s" % (set_type, idx) 601 | text_a = example_json['premise'] 602 | meta = { 603 | 'choice1': example_json['choice1'], 604 | 'choice2': example_json['choice2'], 605 | 'question': example_json['question'] 606 | } 607 | example = InputExample(guid=guid, text_a=text_a, label=label, meta=meta, idx=idx, orig_label=orig_label) 608 | examples.append(example) 609 | 610 | if set_type == 'train' or set_type == 'unlabeled': 611 | mirror_examples = [] 612 | for ex in examples: 613 | label = "1" if ex.label == "0" else "0" 614 | orig_label = "1" if ex.orig_label == "0" else "0" 615 | meta = { 616 | 'choice1': ex.meta['choice2'], 617 | 'choice2': ex.meta['choice1'], 618 | 'question': ex.meta['question'] 619 | } 620 | mirror_example = InputExample(guid=ex.guid + 'm', text_a=ex.text_a, label=label, meta=meta, orig_label=orig_label) 621 | mirror_examples.append(mirror_example) 622 | examples += mirror_examples 623 | logger.info(f"Added {len(mirror_examples)} mirror examples, total size is {len(examples)}...") 624 | return examples 625 | 626 | 627 | class MultiRcProcessor(DataProcessor): 628 | """Processor for the MultiRC data set.""" 629 | 630 | def get_train_examples(self, data_dir): 631 | return 
self._create_examples(os.path.join(data_dir, "train.jsonl"), "train") 632 | 633 | def get_dev_examples(self, data_dir): 634 | return self._create_examples(os.path.join(data_dir, "val.jsonl"), "dev") 635 | 636 | def get_test_examples(self, data_dir): 637 | return self._create_examples(os.path.join(data_dir, "test.jsonl"), "test") 638 | 639 | def get_unlabeled_examples(self, data_dir): 640 | return self._create_examples(os.path.join(data_dir, "unlabeled.jsonl"), "unlabeled") 641 | 642 | def get_labels(self): 643 | return ["0", "1"] 644 | 645 | @staticmethod 646 | def _create_examples(path: str, set_type: str) -> List[InputExample]: 647 | examples = [] 648 | 649 | with open(path, encoding='utf8') as f: 650 | for line in f: 651 | example_json = json.loads(line) 652 | 653 | passage_idx = example_json['idx'] 654 | text = example_json['passage']['text'] 655 | questions = example_json['passage']['questions'] 656 | for question_json in questions: 657 | question = question_json["question"] 658 | question_idx = question_json['idx'] 659 | answers = question_json["answers"] 660 | for answer_json in answers: 661 | label = str(answer_json["label"]) if 'label' in answer_json else None 662 | orig_label =str(answer_json["orig_label"]) if "orig_label" in answer_json else label 663 | answer_idx = answer_json["idx"] 664 | guid = f'{set_type}-p{passage_idx}-q{question_idx}-a{answer_idx}' 665 | meta = { 666 | 'passage_idx': passage_idx, 667 | 'question_idx': question_idx, 668 | 'answer_idx': answer_idx, 669 | 'answer': answer_json["text"] 670 | } 671 | idx = [passage_idx, question_idx, answer_idx] 672 | example = InputExample(guid=guid, text_a=text, text_b=question, label=label, meta=meta, idx=idx, orig_label=orig_label) 673 | examples.append(example) 674 | 675 | question_indices = list(set(example.meta['question_idx'] for example in examples)) 676 | label_distribution = Counter(example.label for example in examples) 677 | logger.info(f"Returning {len(examples)} examples corresponding to {len(question_indices)} questions with label " 678 | f"distribution {list(label_distribution.items())}") 679 | return examples 680 | 681 | 682 | class RecordProcessor(DataProcessor): 683 | """Processor for the ReCoRD data set.""" 684 | 685 | def get_train_examples(self, data_dir): 686 | return self._create_examples(os.path.join(data_dir, "train.jsonl"), "train") 687 | 688 | def get_dev_examples(self, data_dir): 689 | return self._create_examples(os.path.join(data_dir, "val.jsonl"), "dev") 690 | 691 | def get_test_examples(self, data_dir): 692 | return self._create_examples(os.path.join(data_dir, "test.jsonl"), "test") 693 | 694 | def get_unlabeled_examples(self, data_dir): 695 | return self._create_examples(os.path.join(data_dir, "unlabeled.jsonl"), "unlabeled") 696 | 697 | def get_labels(self): 698 | return ["0", "1"] 699 | 700 | @staticmethod 701 | def _create_examples(path, set_type, seed=42, max_train_candidates_per_question: int = 10) -> List[InputExample]: 702 | examples = [] 703 | 704 | entity_shuffler = random.Random(seed) 705 | 706 | with open(path, encoding='utf8') as f: 707 | for idx, line in enumerate(f): 708 | example_json = json.loads(line) 709 | 710 | idx = example_json['idx'] 711 | text = example_json['passage']['text'] 712 | entities = set() 713 | if 'entity_names' in example_json['passage']: 714 | for entity_name in example_json['passage']['entity_names']: 715 | entities.add(entity_name) 716 | else: 717 | for entity_json in example_json['passage']['entities']: 718 | start = entity_json['start'] 719 | end = 
entity_json['end'] 720 | entity = text[start:end + 1] 721 | entities.add(entity) 722 | 723 | entities = list(entities) 724 | 725 | text = text.replace("@highlight\n", "- ") # we follow the GPT-3 paper wrt @highlight annotations 726 | questions = example_json['qas'] 727 | 728 | for question_json in questions: 729 | question = question_json['query'] 730 | question_idx = question_json['idx'] 731 | answers = set() 732 | 733 | for answer_json in question_json.get('answers', []): 734 | answer = answer_json['text'] 735 | answers.add(answer) 736 | 737 | answers = list(answers) 738 | 739 | if set_type == 'train': 740 | # create a single example per *correct* answer 741 | for answer_idx, answer in enumerate(answers): 742 | candidates = [ent for ent in entities if ent not in answers] 743 | if len(candidates) > max_train_candidates_per_question - 1: 744 | entity_shuffler.shuffle(candidates) 745 | candidates = candidates[:max_train_candidates_per_question - 1] 746 | 747 | guid = f'{set_type}-p{idx}-q{question_idx}-a{answer_idx}' 748 | meta = { 749 | 'passage_idx': idx, 750 | 'question_idx': question_idx, 751 | 'candidates': [answer] + candidates, 752 | 'answers': [answer] 753 | } 754 | ex_idx = [idx, question_idx, answer_idx] 755 | example = InputExample(guid=guid, text_a=text, text_b=question, label="1", meta=meta, 756 | idx=ex_idx) 757 | examples.append(example) 758 | 759 | else: 760 | # create just one example with *all* correct answers and *all* answer candidates 761 | guid = f'{set_type}-p{idx}-q{question_idx}' 762 | meta = { 763 | 'passage_idx': idx, 764 | 'question_idx': question_idx, 765 | 'candidates': entities, 766 | 'answers': answers 767 | } 768 | example = InputExample(guid=guid, text_a=text, text_b=question, label="1", meta=meta) 769 | examples.append(example) 770 | 771 | question_indices = list(set(example.meta['question_idx'] for example in examples)) 772 | label_distribution = Counter(example.label for example in examples) 773 | logger.info(f"Returning {len(examples)} examples corresponding to {len(question_indices)} questions with label " 774 | f"distribution {list(label_distribution.items())}") 775 | return examples 776 | 777 | 778 | PROCESSORS = { 779 | "mnli": MnliProcessor, 780 | "mnli-mm": MnliMismatchedProcessor, 781 | "agnews": AgnewsProcessor, 782 | "yahoo": YahooAnswersProcessor, 783 | "yelp-polarity": YelpPolarityProcessor, 784 | "yelp-full": YelpFullProcessor, 785 | "xstance-de": lambda: XStanceProcessor("de"), 786 | "xstance-fr": lambda: XStanceProcessor("fr"), 787 | "xstance": XStanceProcessor, 788 | "wic": WicProcessor, 789 | "rte": RteProcessor, 790 | "cb": CbProcessor, 791 | "wsc": WscProcessor, 792 | "boolq": BoolQProcessor, 793 | "copa": CopaProcessor, 794 | "multirc": MultiRcProcessor, 795 | "record": RecordProcessor, 796 | "ax-g": AxGProcessor, 797 | "ax-b": AxBProcessor, 798 | } # type: Dict[str,Callable[[],DataProcessor]] 799 | 800 | TASK_HELPERS = { 801 | "wsc": task_helpers.WscTaskHelper, 802 | "multirc": task_helpers.MultiRcTaskHelper, 803 | "copa": task_helpers.CopaTaskHelper, 804 | "record": task_helpers.RecordTaskHelper, 805 | } 806 | 807 | METRICS = { 808 | "cb": ["acc", "f1-macro"], 809 | "multirc": ["acc", "f1", "em"] 810 | } 811 | 812 | DEFAULT_METRICS = ["acc"] 813 | 814 | TRAIN_SET = "train" 815 | DEV_SET = "dev" 816 | TEST_SET = "test" 817 | UNLABELED_SET = "unlabeled" 818 | 819 | SET_TYPES = [TRAIN_SET, DEV_SET, TEST_SET, UNLABELED_SET] 820 | 821 | 822 | def load_examples(task, data_dir: str, set_type: str, *_, num_examples: int = None, 823 | 
num_examples_per_label: int = None, seed: int = 42) -> List[InputExample]: 824 | """Load examples for a given task.""" 825 | assert (num_examples is not None) ^ (num_examples_per_label is not None), \ 826 | "Exactly one of 'num_examples' and 'num_examples_per_label' must be set." 827 | assert (set_type != UNLABELED_SET) or (num_examples is not None), \ 828 | "For unlabeled data, 'num_examples_per_label' is not allowed" 829 | 830 | processor = PROCESSORS[task]() 831 | 832 | ex_str = f"num_examples={num_examples}" if num_examples is not None \ 833 | else f"num_examples_per_label={num_examples_per_label}" 834 | logger.info( 835 | f"Creating features from dataset file at {data_dir} ({ex_str}, set_type={set_type})" 836 | ) 837 | 838 | if set_type == DEV_SET: 839 | examples = processor.get_dev_examples(data_dir) 840 | elif set_type == TEST_SET: 841 | examples = processor.get_test_examples(data_dir) 842 | elif set_type == TRAIN_SET: 843 | examples = processor.get_train_examples(data_dir) 844 | elif set_type == UNLABELED_SET: 845 | examples = processor.get_unlabeled_examples(data_dir) 846 | for example in examples: 847 | example.label = processor.get_labels()[0] 848 | else: 849 | raise ValueError(f"'set_type' must be one of {SET_TYPES}, got '{set_type}' instead") 850 | 851 | if num_examples is not None: 852 | examples = _shuffle_and_restrict(examples, num_examples, seed) 853 | 854 | elif num_examples_per_label is not None: 855 | limited_examples = LimitedExampleList(processor.get_labels(), num_examples_per_label) 856 | for example in examples: 857 | limited_examples.add(example) 858 | examples = limited_examples.to_list() 859 | 860 | label_distribution = Counter(example.label for example in examples) 861 | logger.info(f"Returning {len(examples)} {set_type} examples with label dist.: {list(label_distribution.items())}") 862 | 863 | return examples 864 |
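# Usage sketch (added for exposition; not part of the original file): how
# load_examples is typically called. The path assumes the FewGLUE data from
# the README has been placed under data/; up to 16 examples per label are
# drawn from the 32-example FewGLUE training set.
def _demo_load_examples():
    train_data = load_examples('rte', 'data/FewGLUE_dev32/RTE', TRAIN_SET,
                               num_examples_per_label=16)
    # for the unlabeled set, only 'num_examples' may be used
    unlabeled_data = load_examples('rte', 'data/FewGLUE_dev32/RTE', UNLABELED_SET,
                                   num_examples=500)
    return train_data, unlabeled_data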
-------------------------------------------------------------------------------- /pet/utils.py: -------------------------------------------------------------------------------- 1 | # Licensed under the Apache License, Version 2.0 (the "License"); 2 | # you may not use this file except in compliance with the License. 3 | # You may obtain a copy of the License at 4 | # 5 | # http://www.apache.org/licenses/LICENSE-2.0 6 | # 7 | # Unless required by applicable law or agreed to in writing, software 8 | # distributed under the License is distributed on an "AS IS" BASIS, 9 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 10 | # See the License for the specific language governing permissions and 11 | # limitations under the License. 12 | 13 | import copy 14 | import json 15 | import pickle 16 | import random 17 | import string 18 | from collections import defaultdict 19 | from typing import Dict, List, Optional, Union 20 | 21 | import numpy as np 22 | import torch 23 | from torch.nn import functional as F 24 | from torch.utils.data import Dataset 25 | from transformers import PreTrainedTokenizer, GPT2Tokenizer 26 | 27 | 28 | class LogitsList: 29 | """A list of logits obtained from a finetuned PET model""" 30 | 31 | def __init__(self, score: float, logits: List[List[float]]): 32 | """ 33 | Create a new LogitsList. 34 | 35 | :param score: the corresponding PET model's score on the training set 36 | :param logits: the list of logits, where ``logits[i][j]`` is the score for label ``j`` at example ``i`` 37 | """ 38 | self.score = score 39 | self.logits = logits 40 | 41 | def __repr__(self): 42 | return 'LogitsList(score={}, logits[:2]={})'.format(self.score, self.logits[:2]) 43 | 44 | def save(self, path: str) -> None: 45 | """Save this list to a file.""" 46 | with open(path, 'w') as fh: 47 | fh.write(str(self.score) + '\n') 48 | for example_logits in self.logits: 49 | fh.write(' '.join(str(logit) for logit in example_logits) + '\n') 50 | 51 | @staticmethod 52 | def load(path: str, with_score: bool = True) -> 'LogitsList': 53 | """Load a list from a file""" 54 | score = -1 55 | logits = [] 56 | with open(path, 'r') as fh: 57 | for line_idx, line in enumerate(fh.readlines()): 58 | line = line.rstrip('\n') 59 | if line_idx == 0 and with_score: 60 | score = float(line) 61 | else: 62 | logits.append([float(x) for x in line.split()]) 63 | return LogitsList(score=score, logits=logits) 64 | 65 | 66 | class InputExample(object): 67 | """A raw input example consisting of one or two segments of text and a label""" 68 | 69 | def __init__(self, guid, text_a, text_b=None, label=None, logits=None, meta: Optional[Dict] = None, idx=-1, orig_label=None): 70 | """ 71 | Create a new InputExample. 72 | 73 | :param guid: a unique textual identifier 74 | :param text_a: the sequence of text 75 | :param text_b: an optional, second sequence of text 76 | :param label: an optional label 77 | :param logits: an optional list of per-class logits 78 | :param meta: an optional dictionary to store arbitrary meta information 79 | :param idx: an optional numeric index 80 | """ 81 | self.guid = guid 82 | self.text_a = text_a 83 | self.text_b = text_b 84 | self.label = label 85 | self.logits = logits 86 | self.idx = idx 87 | self.meta = meta if meta else {} 88 | self.orig_label = orig_label  # the original label, e.g. before an augmented example is relabeled 89 | 90 | def __repr__(self): 91 | return str(self.to_json_string()) 92 | 93 | def to_dict(self): 94 | """Serialize this instance to a Python dictionary.""" 95 | output = copy.deepcopy(self.__dict__) 96 | return output 97 | 98 | def to_json_string(self): 99 | """Serialize this instance to a JSON string.""" 100 | return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n" 101 | 102 | @staticmethod 103 | def load_examples(path: str) -> List['InputExample']: 104 | """Load a set of input examples from a file""" 105 | with open(path, 'rb') as fh: 106 | return pickle.load(fh) 107 | 108 | @staticmethod 109 | def save_examples(examples: List['InputExample'], path: str) -> None: 110 | """Save a set of input examples to a file""" 111 | with open(path, 'wb') as fh: 112 | pickle.dump(examples, fh) 113 | 114 | 115 | class InputFeatures(object): 116 | """A set of numeric features obtained from an :class:`InputExample`""" 117 | 118 | def __init__(self, input_ids, attention_mask, token_type_ids, label, mlm_labels=None, logits=None, 119 | meta: Optional[Dict] = None, idx=-1): 120 | """ 121 | Create new InputFeatures.
122 | 123 | :param input_ids: the input ids corresponding to the original text or text sequence 124 | :param attention_mask: an attention mask, with 0 = no attention, 1 = attention 125 | :param token_type_ids: segment ids as used by BERT 126 | :param label: the label 127 | :param mlm_labels: an optional sequence of labels used for auxiliary language modeling 128 | :param logits: an optional sequence of per-class logits 129 | :param meta: an optional dictionary to store arbitrary meta information 130 | :param idx: an optional numeric index 131 | """ 132 | self.input_ids = input_ids 133 | self.attention_mask = attention_mask 134 | self.token_type_ids = token_type_ids 135 | self.label = label 136 | self.mlm_labels = mlm_labels 137 | self.logits = logits 138 | self.idx = idx 139 | self.meta = meta if meta else {} 140 | 141 | def __repr__(self): 142 | return str(self.to_json_string()) 143 | 144 | def pretty_print(self, tokenizer): 145 | return f'input_ids = {tokenizer.convert_ids_to_tokens(self.input_ids)}\n' + \ 146 | f'attention_mask = {self.attention_mask}\n' + \ 147 | f'token_type_ids = {self.token_type_ids}\n' + \ 148 | f'mlm_labels = {self.mlm_labels}\n' + \ 149 | f'logits = {self.logits}\n' + \ 150 | f'label = {self.label}' 151 | 152 | def to_dict(self): 153 | """Serialize this instance to a Python dictionary.""" 154 | output = copy.deepcopy(self.__dict__) 155 | return output 156 | 157 | def to_json_string(self): 158 | """Serialize this instance to a JSON string.""" 159 | return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n" 160 | 161 | 162 | class PLMInputFeatures(InputFeatures): 163 | """A set of numeric input features for a model pretrained with a permuted language modeling objective.""" 164 | 165 | def __init__(self, *_, perm_mask, target_mapping, **kwargs): 166 | super().__init__(**kwargs) 167 | self.perm_mask = perm_mask 168 | self.target_mapping = target_mapping 169 | 170 | def pretty_print(self, tokenizer): 171 | return super().pretty_print(tokenizer) + '\n' + \ 172 | f'perm_mask = {self.perm_mask}\n' + \ 173 | f'target_mapping = {self.target_mapping}' 174 | 175 | 176 | class DictDataset(Dataset): 177 | """A dataset of tensors that uses a dictionary for key-value mappings""" 178 | 179 | def __init__(self, **tensors): 180 | 181 | 182 | assert all(next(iter(tensors.values())).size(0) == tensor.size(0) for tensor in tensors.values()) 183 | self.tensors = tensors 184 | 185 | def __getitem__(self, index): 186 | return {key: tensor[index] for key, tensor in self.tensors.items()} 187 | 188 | def __len__(self): 189 | return next(iter(self.tensors.values())).size(0) 190 | 191 |
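# Usage sketch (added for exposition; not part of the original file):
# DictDataset stores equally-sized tensors under named keys and yields
# dict-shaped items, which is how batches are fed to the models here.
def _demo_dict_dataset():
    data = DictDataset(input_ids=torch.zeros(4, 8, dtype=torch.long),
                       labels=torch.tensor([0, 1, 0, 1]))
    assert len(data) == 4
    assert data[2]['input_ids'].shape == (8,)
    return data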
""" 203 | return [] if i <= 0 else [N // i + 1] * (N % i) + [N // i] * (i - N % i) 204 | 205 | 206 | def chunks(lst, n): 207 | """Yield successive n-sized chunks from lst.""" 208 | for i in range(0, len(lst), n): 209 | yield lst[i:i + n] 210 | 211 | 212 | def remove_final_punc(s: str): 213 | """Remove the last character from a string if it is some form of punctuation""" 214 | return s.rstrip(string.punctuation) 215 | 216 | 217 | def lowercase_first(s: str): 218 | """Lowercase the first letter of a string""" 219 | return s[0].lower() + s[1:] 220 | 221 | 222 | def save_logits(path: str, logits: np.ndarray): 223 | """Save an array of logits to a file""" 224 | with open(path, 'w') as fh: 225 | for example_logits in logits: 226 | fh.write(' '.join(str(logit) for logit in example_logits) + '\n') 227 | pass 228 | 229 | 230 | def save_predictions(path: str, wrapper, results: Dict): 231 | """Save a sequence of predictions to a file""" 232 | predictions_with_idx = [] 233 | 234 | if wrapper.task_helper and wrapper.task_helper.output: 235 | predictions_with_idx = wrapper.task_helper.output 236 | else: 237 | inv_label_map = {idx: label for label, idx in wrapper.preprocessor.label_map.items()} 238 | for idx, prediction_idx in zip(results['indices'], results['predictions']): 239 | prediction = inv_label_map[prediction_idx] 240 | idx = idx.tolist() if isinstance(idx, np.ndarray) else int(idx) 241 | predictions_with_idx.append({'idx': idx, 'label': prediction}) 242 | 243 | with open(path, 'w', encoding='utf8') as fh: 244 | for line in predictions_with_idx: 245 | fh.write(json.dumps(line) + '\n') 246 | 247 | 248 | def softmax(x, temperature=1.0, axis=None): 249 | """Custom softmax implementation""" 250 | y = np.atleast_2d(x) 251 | 252 | if axis is None: 253 | axis = next(j[0] for j in enumerate(y.shape) if j[1] > 1) 254 | 255 | y = y * float(temperature) 256 | y = y - np.expand_dims(np.max(y, axis=axis), axis) 257 | y = np.exp(y) 258 | 259 | ax_sum = np.expand_dims(np.sum(y, axis=axis), axis) 260 | p = y / ax_sum 261 | 262 | if len(x.shape) == 1: 263 | p = p.flatten() 264 | return p 265 | 266 | 267 | def get_verbalization_ids(word: str, tokenizer: PreTrainedTokenizer, force_single_token: bool) -> Union[int, List[int]]: 268 | """ 269 | Get the token ids corresponding to a verbalization 270 | 271 | :param word: the verbalization 272 | :param tokenizer: the tokenizer to use 273 | :param force_single_token: whether it should be enforced that the verbalization corresponds to a single token. 274 | If set to true, this method returns a single int instead of a list and throws an error if the word 275 | corresponds to multiple tokens. 
267 | def get_verbalization_ids(word: str, tokenizer: PreTrainedTokenizer, force_single_token: bool) -> Union[int, List[int]]: 268 | """ 269 | Get the token ids corresponding to a verbalization 270 | 271 | :param word: the verbalization 272 | :param tokenizer: the tokenizer to use 273 | :param force_single_token: whether it should be enforced that the verbalization corresponds to a single token. 274 | If set to true, this method returns a single int instead of a list and throws an error if the word 275 | corresponds to multiple tokens. 276 | :return: either the list of token ids or the single token id corresponding to this word 277 | """ 278 | kwargs = {'add_prefix_space': True} if isinstance(tokenizer, GPT2Tokenizer) else {} 279 | ids = tokenizer.encode(word, add_special_tokens=False, **kwargs) 280 | if not force_single_token: 281 | return ids 282 | assert len(ids) == 1, \ 283 | f'Verbalization "{word}" does not correspond to a single token, got {tokenizer.convert_ids_to_tokens(ids)}' 284 | verbalization_id = ids[0] 285 | assert verbalization_id not in tokenizer.all_special_ids, \ 286 | f'Verbalization {word} is mapped to a special token {tokenizer.convert_ids_to_tokens(verbalization_id)}' 287 | return verbalization_id 288 | 289 | 290 | def trim_input_ids(input_ids: torch.tensor, pad_token_id, mask_token_id, num_masks: int): 291 | """ 292 | Trim a sequence of input ids by removing all padding tokens and keeping at most a specific number of mask tokens. 293 | 294 | :param input_ids: the sequence of input token ids 295 | :param pad_token_id: the id of the pad token 296 | :param mask_token_id: the id of the mask tokens 297 | :param num_masks: the number of masks to keep 298 | :return: the trimmed sequence of input ids 299 | """ 300 | assert input_ids.shape[0] == 1 301 | input_ids_without_pad = [x for x in input_ids[0] if x != pad_token_id] 302 | 303 | trimmed_input_ids = [] 304 | mask_count = 0 305 | for input_id in input_ids_without_pad: 306 | if input_id == mask_token_id: 307 | if mask_count >= num_masks: 308 | continue 309 | mask_count += 1 310 | trimmed_input_ids.append(input_id) 311 | 312 | return torch.tensor([trimmed_input_ids], dtype=torch.long, device=input_ids.device) 313 | 314 | 315 | def exact_match(predictions: np.ndarray, actuals: np.ndarray, question_ids: np.ndarray): 316 | """Compute the exact match (EM) for a sequence of predictions and actual labels""" 317 | unique_questions = set(question_ids) 318 | 319 | q_actuals = list(zip(question_ids, actuals)) 320 | q_predictions = list(zip(question_ids, predictions)) 321 | 322 | actuals_per_question = defaultdict(list) 323 | predictions_per_question = defaultdict(list) 324 | 325 | for qid, val in q_actuals: 326 | actuals_per_question[qid].append(val) 327 | for qid, val in q_predictions: 328 | predictions_per_question[qid].append(val) 329 | 330 | em = 0 331 | for qid in unique_questions: 332 | if actuals_per_question[qid] == predictions_per_question[qid]: 333 | em += 1 334 | em /= len(unique_questions) 335 | 336 | return em 337 | 338 | 339 | def distillation_loss(predictions, targets, temperature): 340 | """Compute the distillation loss (KL divergence between predictions and targets) as described in the PET paper""" 341 | p = F.log_softmax(predictions / temperature, dim=1) 342 | q = F.softmax(targets / temperature, dim=1) 343 | return F.kl_div(p, q, reduction='sum') * (temperature ** 2) / predictions.shape[0] 344 |
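# Usage sketch (added for exposition; not part of the original file): the
# loss is the temperature-scaled KL divergence between student predictions
# and soft teacher targets, averaged over the batch.
def _demo_distillation_loss():
    student = torch.tensor([[2.0, 0.5], [0.1, 1.2]])
    teacher = torch.tensor([[1.5, 0.7], [0.3, 1.0]])
    loss = distillation_loss(student, teacher, temperature=2.0)
    assert loss.item() >= 0.0  # KL divergence is non-negative
    return loss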
-------------------------------------------------------------------------------- /petal.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import json 4 | from collections import Counter 5 | from typing import Dict, List 6 | 7 | import numpy as np 8 | import random 9 | import torch 10 | from transformers import PreTrainedTokenizer, RobertaTokenizer 11 | 12 | from pet.tasks import PROCESSORS, load_examples, TRAIN_SET 13 | from pet.utils import InputExample, eq_div 14 | from pet.wrapper import TransformerModelWrapper, MODEL_CLASSES, WrapperConfig 15 | import log 16 | 17 | logger = log.get_logger('root') 18 | 19 | 20 | def filter_words(tokens: List[str], word_counts=None, max_words: int = -1): 21 | """ 22 | Given a list of tokens, return a reduced list that contains only tokens from the list that correspond 23 | to actual words and occur a given number of times. 24 | :param tokens: the list of tokens to filter 25 | :param word_counts: a dictionary mapping words to their number of occurrences 26 | :param max_words: if set to a value >0, only the `max_words` most frequent words according to `word_counts` are kept 27 | :return: the filtered list of tokens 28 | """ 29 | tokens = (word for word in tokens if word[0] == 'Ġ' and len([char for char in word[1:] if char.isalpha()]) >= 2) 30 | if word_counts and max_words > 0: 31 | tokens = sorted(tokens, key=lambda word: word_counts[word[1:]], reverse=True)[:max_words] 32 | return tokens 33 | 34 | 35 | def get_word_to_id_map(tokenizer: PreTrainedTokenizer, word_counts=None, max_words: int = -1): 36 | """ 37 | Return a mapping from all tokens to their internal ids for a given tokenizer 38 | :param tokenizer: the tokenizer 39 | :param word_counts: a dictionary mapping words to their number of occurrences 40 | :param max_words: if set to a value >0, only the `max_words` most frequent words according to `word_counts` are kept 41 | :return: 42 | """ 43 | if not isinstance(tokenizer, RobertaTokenizer): 44 | raise ValueError("this function currently only supports instances of 'RobertaTokenizer'") 45 | 46 | words = filter_words(tokenizer.encoder.keys(), word_counts, max_words) 47 | word2id = {word[1:]: tokenizer.convert_tokens_to_ids(word) for word in words} 48 | logger.info(f"There are {len(word2id)} words left after filtering non-word tokens") 49 | return word2id 50 | 51 | 52 | class AutomaticVerbalizerSearch: 53 | 54 | def __init__(self, word2idx: Dict[str, int], labels: List[str], logits_list: List[np.ndarray], 55 | expected: Dict[str, np.ndarray]): 56 | self.word2idx = word2idx 57 | self.labels = labels 58 | self.expected = expected 59 | 60 | logits_list = [np.exp(logits) for logits in logits_list] 61 | self.probs_list = [logits / np.expand_dims(np.sum(logits, axis=1), axis=1) for logits in logits_list] 62 | 63 | def _get_candidates(self, num_candidates: int) -> Dict[str, List[str]]: 64 | if num_candidates <= 0: 65 | return {label: list(self.word2idx.keys()) for label in self.labels} 66 | 67 | scores = {label: Counter() for label in self.labels} 68 | 69 | for label in self.labels: 70 | for probs in self.probs_list: 71 | for word, idx in self.word2idx.items(): 72 | score = np.sum(np.log(probs[:, idx]) * self.expected[label]) 73 | scores[label][word] += score 74 | 75 | return {label: [w for w, _ in scores[label].most_common(num_candidates)] for label in self.labels} 76 | 77 | def _get_top_words(self, candidates: Dict[str, List[str]], normalize: bool = True, words_per_label: int = 10, 78 | score_fct: str = 'llr') -> Dict[str, List[str]]: 79 | 80 | scores = {label: Counter() for label in self.labels} 81 | 82 | for label in self.labels: 83 | for probs in self.probs_list: 84 | for word in candidates[label]: 85 | idx = self.word2idx[word] 86 | if score_fct == 'llr': 87 | scores[label][word] += self.log_likelihood_ratio(probs[:, idx], self.expected[label], normalize) 88 | elif score_fct == 'ce': 89 | scores[label][word] += self.cross_entropy(probs[:, idx], self.expected[label], normalize) 90 | else: 91 | raise ValueError(f"Score function '{score_fct}' not implemented") 92 | 93 | return {label: [w for w, _ in
scores[label].most_common(words_per_label)] for label in self.labels} 94 | 95 | @staticmethod 96 | def log_likelihood_ratio(predictions: np.ndarray, expected: np.ndarray, normalize: bool) -> float: 97 | scale_factor = sum(1 - expected) / sum(expected) if normalize else 1 98 | pos_score = scale_factor * (np.sum(np.log(predictions) * expected) - np.sum(np.log(1 - predictions) * expected)) 99 | neg_score = np.sum(np.log(1 - predictions) * (1 - expected)) - np.sum(np.log(predictions) * (1 - expected)) 100 | return pos_score + neg_score 101 | 102 | @staticmethod 103 | def cross_entropy(predictions: np.ndarray, expected: np.ndarray, normalize: bool) -> float: 104 | scale_factor = sum(1 - expected) / sum(expected) if normalize else 1 105 | pos_score = scale_factor * np.sum(np.log(predictions) * expected) 106 | neg_score = np.sum(np.log(1 - predictions) * (1 - expected)) 107 | return pos_score + neg_score 108 | 109 | def find_verbalizer(self, words_per_label: int = 10, num_candidates: int = 1000, normalize: bool = True, 110 | score_fct: str = 'llr'): 111 | if score_fct == 'random': 112 | return {label: random.sample(list(self.word2idx.keys()), words_per_label) for label in self.labels} 113 | 114 | candidates = self._get_candidates(num_candidates=num_candidates) 115 | return self._get_top_words(candidates=candidates, normalize=normalize, words_per_label=words_per_label, 116 | score_fct=score_fct) 117 | 118 |
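# Worked sketch (added for exposition; not part of the original file): a tiny
# end-to-end search over two fake candidate tokens. Example 1 carries label
# 'yes' and puts its probability mass on 'good'; example 2 carries 'no' and
# prefers 'bad', so the default log-likelihood-ratio scoring should pick
# 'good' and 'bad' as the respective verbalizers.
def _demo_verbalizer_search():
    word2idx = {'good': 0, 'bad': 1}
    logits_list = [np.array([[2.0, 0.1], [0.1, 2.0]])]  # 2 examples x 2 tokens
    expected = {'yes': np.array([1, 0]), 'no': np.array([0, 1])}
    avs = AutomaticVerbalizerSearch(word2idx, ['yes', 'no'], logits_list, expected)
    verbalizer = avs.find_verbalizer(words_per_label=1, num_candidates=2)
    assert verbalizer == {'yes': ['good'], 'no': ['bad']}
    return verbalizer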
119 | def main(): 120 | parser = argparse.ArgumentParser() 121 | 122 | # required parameters 123 | parser.add_argument("--output_dir", default=None, type=str, required=True, 124 | help="The output directory. The verbalizers are written to a file 'verbalizers.json' in this directory.") 125 | parser.add_argument("--data_dir", default=None, type=str, required=True, 126 | help="The input data dir. Should contain the data files for the task.") 127 | parser.add_argument("--model_type", default=None, type=str, required=True, 128 | help="The model type") 129 | parser.add_argument("--model_name_or_path", default=None, type=str, required=True, 130 | help="Path to pre-trained model or shortcut name") 131 | parser.add_argument("--task_name", default=None, type=str, required=True, 132 | help="The name of the task to train selected in the list: " + ", ".join(PROCESSORS.keys())) 133 | 134 | # verbalizer search hyperparameters 135 | parser.add_argument("--normalize", action='store_true', 136 | help="Whether to normalize the loss as proposed in the paper. It is recommended to set this to 'true'.") 137 | parser.add_argument("--combine_patterns", action='store_true', 138 | help="If set to true, a single joint verbalizer is searched for all patterns") 139 | parser.add_argument("--num_candidates", default=1000, type=int, 140 | help="The number of candidate tokens to consider as verbalizers (see Section 4.1 of the paper)") 141 | parser.add_argument("--words_per_label", default=10, type=int, 142 | help="The number of verbalizer tokens to assign to each label") 143 | parser.add_argument("--score_fct", default='llr', choices=['llr', 'ce', 'random'], 144 | help="The function used to score verbalizers. Choices are: the log-likelihood ratio loss proposed in the paper " 145 | "('llr'), cross-entropy loss ('ce') and 'random', which assigns random tokens to each label.") 146 | 147 | # other optional parameters 148 | parser.add_argument("--train_examples", default=50, type=int, 149 | help="The total number of train examples to use, where -1 equals all examples.") 150 | parser.add_argument("--pattern_ids", default=[0], type=int, nargs='+', 151 | help="The ids of the PVPs to be used") 152 | parser.add_argument("--max_seq_length", default=256, type=int, 153 | help="The maximum total input sequence length after tokenization. Sequences longer " 154 | "than this will be truncated, sequences shorter will be padded.") 155 | parser.add_argument("--per_gpu_eval_batch_size", default=8, type=int, 156 | help="Batch size per GPU/CPU for evaluation.") 157 | parser.add_argument("--words_file", default=None, type=str, 158 | help="Path to a file containing (unlabeled) texts from the task's domain. This text is used to compute " 159 | "verbalization candidates by selecting the most frequent words.") 160 | parser.add_argument("--max_words", default=10000, type=int, 161 | help="Only the 10,000 tokens that occur most frequently in the task’s unlabeled data (see --words_file) are " 162 | "considered as verbalization candidates") 163 | parser.add_argument("--additional_input_examples", type=str, 164 | help="An optional path to an additional set of input examples (e.g., obtained using iPET)") 165 | parser.add_argument("--seed", default=42, type=int, 166 | help="random seed for initialization") 167 | 168 | args = parser.parse_args() 169 | random.seed(args.seed) 170 | 171 | if not os.path.exists(args.output_dir): 172 | os.makedirs(args.output_dir) 173 | 174 | with open(os.path.join(args.output_dir, 'config.txt'), 'w', encoding='utf8') as fh: 175 | json.dump(args.__dict__, fh, indent=2) 176 | 177 | # setup gpu/cpu 178 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 179 | args.n_gpu = torch.cuda.device_count() 180 | 181 | # prepare task 182 | args.task_name = args.task_name.lower() 183 | if args.task_name not in PROCESSORS: 184 | raise ValueError("Task not found: {}".format(args.task_name)) 185 | processor = PROCESSORS[args.task_name]() 186 | args.label_list = processor.get_labels() 187 | args.cache_dir = "" 188 | args.do_lower_case = False 189 | args.verbalizer_file = None 190 | args.wrapper_type = 'mlm' 191 | 192 | # get training data 193 | train_examples_per_label = eq_div(args.train_examples, len(args.label_list)) if args.train_examples != -1 else -1 194 | train_data = load_examples(args.task_name, args.data_dir, set_type=TRAIN_SET, num_examples_per_label=train_examples_per_label) 195 | if args.additional_input_examples: 196 | additional_data = InputExample.load_examples(args.additional_input_examples) 197 | train_data += additional_data 198 | logger.info(f"Loaded {len(additional_data)} additional examples from {args.additional_input_examples}, total " 199 | f"training set size is now {len(train_data)}") 200 | 201 | expected = {label: np.array([1 if x.label == label else 0 for x in train_data]) for label in args.label_list} 202 | 203 | if args.words_file: 204 | with open(args.words_file, 'r', encoding='utf8') as fh: 205 | word_counts = Counter(fh.read().split()) 206 | else: 207 | word_counts = None 208 | 209 | tokenizer_class = MODEL_CLASSES[args.model_type]['tokenizer'] 210 | tokenizer = tokenizer_class.from_pretrained(args.model_name_or_path) 211 | word2idx = get_word_to_id_map(tokenizer, 
word_counts=word_counts, max_words=args.max_words) 212 | 213 | logits = [] 214 | 215 | for pattern_id in args.pattern_ids: 216 | logger.info(f"Processing examples with pattern id {pattern_id}...") 217 | args.pattern_id = pattern_id 218 | 219 | config = WrapperConfig(model_type=args.model_type, model_name_or_path=args.model_name_or_path, wrapper_type='mlm', 220 | task_name=args.task_name, max_seq_length=args.max_seq_length, label_list=args.label_list, 221 | pattern_id=args.pattern_id) 222 | 223 | wrapper = TransformerModelWrapper(config) 224 | wrapper.model.to(device) 225 | # modify all patterns so that they return a single text segment instead of two segments 226 | get_parts = wrapper.preprocessor.pvp.get_parts 227 | wrapper.preprocessor.pvp.get_parts = lambda example: (get_parts(example)[0] + get_parts(example)[1], []) 228 | wrapper.preprocessor.pvp.convert_mlm_logits_to_cls_logits = lambda mask, x, _=None: x[mask >= 0] 229 | 230 | pattern_logits = wrapper.eval(train_data, device, per_gpu_eval_batch_size=args.per_gpu_eval_batch_size, n_gpu=args.n_gpu)['logits'] 231 | pattern_logits = pattern_logits - np.expand_dims(np.max(pattern_logits, axis=1), axis=1) 232 | logits.append(pattern_logits) 233 | 234 | logger.info("Starting verbalizer search...") 235 | 236 | if args.combine_patterns: 237 | avs = AutomaticVerbalizerSearch(word2idx, args.label_list, logits, expected) 238 | verbalizer = avs.find_verbalizer( 239 | num_candidates=args.num_candidates, 240 | words_per_label=args.words_per_label, 241 | normalize=args.normalize, 242 | score_fct=args.score_fct 243 | ) 244 | verbalizers = {pattern_id: verbalizer for pattern_id in args.pattern_ids} 245 | 246 | else: 247 | verbalizers = {} 248 | for idx, pattern_id in enumerate(args.pattern_ids): 249 | avs = AutomaticVerbalizerSearch(word2idx, args.label_list, [logits[idx]], expected) 250 | verbalizers[pattern_id] = avs.find_verbalizer( 251 | num_candidates=args.num_candidates, 252 | words_per_label=args.words_per_label, 253 | normalize=args.normalize, 254 | score_fct=args.score_fct 255 | ) 256 | 257 | print(json.dumps(verbalizers, indent=2)) 258 | logger.info("Verbalizer search complete, writing output...") 259 | 260 | with open(os.path.join(args.output_dir, 'verbalizers.json'), 'w', encoding='utf8') as fh: 261 | json.dump(verbalizers, fh, indent=2) 262 | 263 | logger.info("Done") 264 | 265 | 266 | if __name__ == "__main__": 267 | main() 268 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | numpy==1.20.2 2 | nltk==3.6.2 3 | transformers==4.5.1 4 | torch==1.5.0 5 | tqdm==4.49.0 6 | scipy==1.6.3 7 | jsonpickle==2.0.0 8 | scikit_learn==0.24.2 9 | spacy==3.1.1 10 | -------------------------------------------------------------------------------- /scripts/gen_augdata_commands.txt: -------------------------------------------------------------------------------- 1 | CUDA_VISIBLE_DEVICES=0 python -m genaug.total_gen_aug --task_name BoolQ --mask_ratio 0.3 --aug_type 'default' --label_type 'flip' --do_sample --num_beams 1 --aug_num 10 2 | CUDA_VISIBLE_DEVICES=0 python -m genaug.total_gen_aug --task_name BoolQ --mask_ratio 0.3 --aug_type 'default' --label_type 'keep' --do_sample --num_beams 1 --aug_num 10 3 | 4 | CUDA_VISIBLE_DEVICES=0 python -m genaug.total_gen_aug --task_name RTE --mask_ratio 0.5 --aug_type 'default' --label_type 'flip' --do_sample --num_beams 1 --aug_num 10 5 | CUDA_VISIBLE_DEVICES=0 python -m 
genaug.total_gen_aug --task_name RTE --mask_ratio 0.5 --aug_type 'default' --label_type 'keep' --do_sample --num_beams 1 --aug_num 10 6 | 7 | CUDA_VISIBLE_DEVICES=0 python -m genaug.total_gen_aug --task_name CB --mask_ratio 0.5 --aug_type 'default' --label_type 'flip' --do_sample --num_beams 1 --aug_num 10 8 | CUDA_VISIBLE_DEVICES=0 python -m genaug.total_gen_aug --task_name CB --mask_ratio 0.5 --aug_type 'default' --label_type 'keep' --do_sample --num_beams 1 --aug_num 10 9 | 10 | CUDA_VISIBLE_DEVICES=0 python -m genaug.total_gen_aug --task_name COPA --mask_ratio 0.8 --aug_type 'default' --label_type 'flip' --num_beams 10 --aug_num 10 11 | CUDA_VISIBLE_DEVICES=0 python -m genaug.total_gen_aug --task_name COPA --mask_ratio 0.8 --aug_type 'default' --label_type 'keep' --num_beams 10 --aug_num 10 12 | 13 | CUDA_VISIBLE_DEVICES=0 python -m genaug.total_gen_aug --task_name MultiRC --mask_ratio 0.5 --aug_type 'rand_iter_10' --label_type 'flip' --num_beams 1 --do_sample --aug_num 10 14 | CUDA_VISIBLE_DEVICES=0 python -m genaug.total_gen_aug --task_name MultiRC --mask_ratio 0.5 --aug_type 'rand_iter_10' --label_type 'keep' --num_beams 1 --do_sample --aug_num 10 15 | 16 | CUDA_VISIBLE_DEVICES=0 python -m genaug.total_gen_aug --task_name WiC --mask_ratio 0.8 --aug_type 'default' --label_type 'flip' --num_beams 1 --do_sample --aug_num 10 17 | CUDA_VISIBLE_DEVICES=0 python -m genaug.total_gen_aug --task_name WiC --mask_ratio 0.8 --aug_type 'default' --label_type 'keep' --num_beams 1 --do_sample --aug_num 10 18 | 19 | CUDA_VISIBLE_DEVICES=0 python -m genaug.total_gen_aug --task_name ReCoRD --mask_ratio 0.3 --aug_type 'rand_iter_10' --label_type 'flip' --num_beams 10 --aug_num 10 20 | CUDA_VISIBLE_DEVICES=0 python -m genaug.total_gen_aug --task_name ReCoRD --mask_ratio 0.3 --aug_type 'rand_iter_10' --label_type 'keep' --num_beams 10 --aug_num 10 21 | 22 | CUDA_VISIBLE_DEVICES=0 python -m genaug.total_gen_aug --task_name WSC --mask_ratio 0.3 --aug_type 'default' --label_type 'keep' --wsc_aug_type 'extra' --num_beams 1 --aug_num 10 23 | -------------------------------------------------------------------------------- /scripts/run_deberta_pet.sh: -------------------------------------------------------------------------------- 1 | TASK=$1 2 | device=$2 3 | search_type=$3 4 | # fixla_ratio=$4 5 | 6 | 7 | METHOD='pet' 8 | MODEL_TYPE='deberta' 9 | 10 | DATA_ROOT='data/FewGLUE_dev32/' 11 | MODEL_NAME_OR_PATH="microsoft/deberta-v2-xxlarge" 12 | 13 | 14 | echo Running PET with the following parameters: 15 | echo ------------------------------ 16 | echo TASK = "$TASK" 17 | echo METHOD = "$METHOD" 18 | echo MODEL_TYPE = "$MODEL_TYPE" 19 | echo device = "$device" 20 | # echo fixla_ratio = "$fixla_ratio" 21 | echo DATA_ROOT = "$DATA_ROOT" 22 | echo MODEL_NAME_OR_PATH = "$MODEL_NAME_OR_PATH" 23 | echo ------------------------------ 24 | 25 | 26 | 27 | 28 | OUTPUT_DIR=results/${search_type}/${METHOD}/${TASK}_${MODEL_TYPE}_model 29 | 30 | TRAIN_BATCH_SIZE=2 31 | ACCU=8 32 | SEQ_LENGTH=256 33 | LR=1e-5 34 | 35 | 36 | if [ $TASK = "wic" ]; then 37 | PATTERN_IDS="0 1 2" 38 | DATA_DIR=${DATA_ROOT}WiC 39 | LR=5e-6 40 | 41 | elif [ $TASK = "rte" ]; then 42 | # PATTERN_IDS="0 1 2 3 4" 43 | PATTERN_IDS="-1" 44 | DATA_DIR=${DATA_ROOT}RTE 45 | LR=5e-6 46 | 47 | elif [ $TASK = "cb" ]; then 48 | # PATTERN_IDS="0 1 2 3 4" 49 | PATTERN_IDS="0 1 2 3" 50 | DATA_DIR=${DATA_ROOT}CB 51 | 52 | elif [ $TASK = "wsc" ]; then 53 | PATTERN_IDS="0 1 2" 54 | DATA_DIR=${DATA_ROOT}WSC 55 | SEQ_LENGTH=128 56 | 57 | elif [ $TASK = "boolq" ]; then 58 |
PATTERN_IDS="0 1 2 3 4 5" 59 | DATA_DIR=${DATA_ROOT}BoolQ 60 | 61 | elif [ $TASK = "copa" ]; then 62 | PATTERN_IDS="0 1" 63 | DATA_DIR=${DATA_ROOT}COPA 64 | SEQ_LENGTH=96 65 | 66 | elif [ $TASK = "multirc" ]; then 67 | # PATTERN_IDS="0 1 2 3" 68 | PATTERN_IDS="0 1 2" 69 | DATA_DIR=${DATA_ROOT}MultiRC 70 | SEQ_LENGTH=512 71 | TRAIN_BATCH_SIZE=1 72 | ACCU=16 73 | 74 | elif [ $TASK = "record" ]; then 75 | PATTERN_IDS="0" 76 | DATA_DIR=${DATA_ROOT}ReCoRD 77 | TRAIN_BATCH_SIZE=1 78 | ACCU=16 79 | SEQ_LENGTH=512 80 | 81 | else 82 | echo "Task " $TASK " is not supported by this script" 1>&2 83 | exit 1 84 | fi 85 | 86 | 87 | if [[ $TASK = "record" || $TASK = "wsc" || $TASK = "copa" ]]; then 88 | echo "type1" $TASK 89 | CUDA_VISIBLE_DEVICES=$device nohup python3 cli.py \ 90 | --method $METHOD \ 91 | --pattern_ids $PATTERN_IDS \ 92 | --data_dir $DATA_DIR \ 93 | --model_type $MODEL_TYPE \ 94 | --model_name_or_path $MODEL_NAME_OR_PATH \ 95 | --task_name $TASK \ 96 | --output_dir $OUTPUT_DIR \ 97 | --do_train \ 98 | --do_eval \ 99 | --pet_per_gpu_eval_batch_size 1 \ 100 | --pet_per_gpu_train_batch_size $TRAIN_BATCH_SIZE \ 101 | --pet_gradient_accumulation_steps $ACCU \ 102 | --pet_max_steps 250 \ 103 | --pet_max_seq_length $SEQ_LENGTH \ 104 | --pet_repetitions 3 \ 105 | --no_distillation \ 106 | --search_type $search_type \ 107 | --fix_deberta >myout_${METHOD}_${MODEL_TYPE}_${TASK}_${search_type}.file 2>&1 & 108 | 109 | elif [[ $TASK = "wic" || $TASK = "rte" || $TASK = "cb" || $TASK = 'boolq' || $TASK = 'multirc' ]]; then 110 | echo "type2" $TASK 111 | CUDA_VISIBLE_DEVICES=$device nohup python3 cli.py \ 112 | --method $METHOD \ 113 | --pattern_ids $PATTERN_IDS \ 114 | --data_dir $DATA_DIR \ 115 | --model_type $MODEL_TYPE \ 116 | --model_name_or_path $MODEL_NAME_OR_PATH \ 117 | --task_name $TASK \ 118 | --output_dir $OUTPUT_DIR \ 119 | --do_train \ 120 | --do_eval \ 121 | --learning_rate $LR \ 122 | --pet_per_gpu_eval_batch_size 16 \ 123 | --pet_per_gpu_train_batch_size $TRAIN_BATCH_SIZE \ 124 | --pet_gradient_accumulation_steps $ACCU \ 125 | --pet_max_steps 250 \ 126 | --pet_max_seq_length $SEQ_LENGTH \ 127 | --pet_repetitions 3 \ 128 | --no_distillation \ 129 | --search_type $search_type \ 130 | --fix_deberta >myout_${METHOD}_${MODEL_TYPE}_${TASK}_${search_type}.file 2>&1 & 131 | fi 132 | 133 | -------------------------------------------------------------------------------- /scripts/run_pet.sh: -------------------------------------------------------------------------------- 1 | TASK=$1 2 | device=$2 3 | search_type=$3 4 | 5 | 6 | METHOD='pet' 7 | MODEL_TYPE='albert' 8 | 9 | DATA_ROOT='data/FewGLUE_dev32/' 10 | MODEL_NAME_OR_PATH="albert-xxlarge-v2" 11 | 12 | 13 | echo Running iPET with the following parameters: 14 | echo ------------------------------ 15 | echo TASK = "$TASK" 16 | echo METHOD = "$METHOD" 17 | echo MODEL_TYPE = "$MODEL_TYPE" 18 | echo device = "$device" 19 | echo DATA_ROOT = "$DATA_ROOT" 20 | echo MODEL_NAME_OR_PATH = "$MODEL_NAME_OR_PATH" 21 | echo ------------------------------ 22 | 23 | 24 | 25 | 26 | OUTPUT_DIR=results/${search_type}/${METHOD}/${TASK}_${MODEL_TYPE}_model 27 | 28 | TRAIN_BATCH_SIZE=8 29 | ACCU=2 30 | SEQ_LENGTH=256 31 | 32 | 33 | if [ $TASK = "wic" ]; then 34 | PATTERN_IDS="0 1 2" 35 | DATA_DIR=${DATA_ROOT}WiC 36 | elif [ $TASK = "rte" ]; then 37 | # PATTERN_IDS="0 1 2 3 4" 38 | PATTERN_IDS="0 1 2 3" 39 | DATA_DIR=${DATA_ROOT}RTE 40 | elif [ $TASK = "cb" ]; then 41 | # PATTERN_IDS="0 1 2 3 4" 42 | PATTERN_IDS="0 1 2 3" 43 | DATA_DIR=${DATA_ROOT}CB 44 | elif [ $TASK 
= "wsc" ]; then 45 | PATTERN_IDS="0 1 2" 46 | DATA_DIR=${DATA_ROOT}WSC 47 | TRAIN_BATCH_SIZE=4 48 | ACCU=4 49 | SEQ_LENGTH=128 50 | elif [ $TASK = "boolq" ]; then 51 | PATTERN_IDS="0 1 2 3 4 5" 52 | DATA_DIR=${DATA_ROOT}BoolQ 53 | elif [ $TASK = "copa" ]; then 54 | PATTERN_IDS="0 1" 55 | DATA_DIR=${DATA_ROOT}COPA 56 | TRAIN_BATCH_SIZE=4 57 | ACCU=4 58 | SEQ_LENGTH=96 59 | elif [ $TASK = "multirc" ]; then 60 | # PATTERN_IDS="0 1 2 3" 61 | PATTERN_IDS="0 1 2" 62 | DATA_DIR=${DATA_ROOT}MultiRC 63 | TRAIN_BATCH_SIZE=4 64 | ACCU=4 65 | SEQ_LENGTH=512 66 | elif [ $TASK = "record" ]; then 67 | PATTERN_IDS="0" 68 | DATA_DIR=${DATA_ROOT}ReCoRD 69 | TRAIN_BATCH_SIZE=1 70 | ACCU=16 71 | SEQ_LENGTH=512 72 | else 73 | echo "Task " $TASK " is not supported by this script" 1>&2 74 | exit 1 75 | fi 76 | 77 | 78 | if [[ $TASK = "record" || $TASK = "wsc" || $TASK = "copa" ]]; then 79 | echo "type1" $TASK 80 | CUDA_VISIBLE_DEVICES=$device nohup python3 cli.py \ 81 | --method $METHOD \ 82 | --pattern_ids $PATTERN_IDS \ 83 | --data_dir $DATA_DIR \ 84 | --model_type $MODEL_TYPE \ 85 | --model_name_or_path $MODEL_NAME_OR_PATH \ 86 | --task_name $TASK \ 87 | --output_dir $OUTPUT_DIR \ 88 | --do_train \ 89 | --do_eval \ 90 | --pet_per_gpu_eval_batch_size 1 \ 91 | --pet_per_gpu_train_batch_size $TRAIN_BATCH_SIZE \ 92 | --pet_gradient_accumulation_steps $ACCU \ 93 | --pet_max_steps 250 \ 94 | --pet_max_seq_length $SEQ_LENGTH \ 95 | --pet_repetitions 3 \ 96 | --no_distillation \ 97 | --search_type $search_type >myout_${METHOD}_${MODEL_TYPE}_${TASK}_${search_type}.file 2>&1 & 98 | elif [[ $TASK = "rte" || $TASK = "cb" || $TASK = 'boolq' || $TASK = 'wic' || $TASK = 'multirc' ]]; then 99 | echo "type2" $TASK 100 | CUDA_VISIBLE_DEVICES=$device nohup python3 cli.py \ 101 | --method $METHOD \ 102 | --pattern_ids $PATTERN_IDS \ 103 | --data_dir $DATA_DIR \ 104 | --model_type $MODEL_TYPE \ 105 | --model_name_or_path $MODEL_NAME_OR_PATH \ 106 | --task_name $TASK \ 107 | --output_dir $OUTPUT_DIR \ 108 | --do_train \ 109 | --do_eval \ 110 | --pet_per_gpu_eval_batch_size 32 \ 111 | --pet_per_gpu_train_batch_size $TRAIN_BATCH_SIZE \ 112 | --pet_gradient_accumulation_steps $ACCU \ 113 | --pet_max_steps 250 \ 114 | --pet_max_seq_length $SEQ_LENGTH \ 115 | --pet_repetitions 3 \ 116 | --no_distillation \ 117 | --search_type $search_type >myout_${METHOD}_${MODEL_TYPE}_${TASK}_${search_type}.file 2>&1 & 118 | fi 119 | 120 | --------------------------------------------------------------------------------