├── requirements.txt
├── scripts
│   ├── download_glue_dataset.sh
│   └── download_clue_dataset.sh
├── core
│   ├── gen_template
│   │   ├── __init__.py
│   │   └── LM_BFF.py
│   └── prompt_bert.py
├── data
│   └── config
│       └── lm_bff.json
├── LICENSE
├── tools
│   ├── trainer.py
│   ├── args.py
│   ├── data_loader.py
│   ├── dataset.py
│   ├── tools.py
│   └── glue_data_processor.py
├── README.md
├── run_gen_template.py
└── run_prompt_tuning.py
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | pandas==1.3.4
2 | numpy==1.19.5
3 | transformers==4.12.5
4 | torch==1.10.0
5 | tqdm==4.62.3
--------------------------------------------------------------------------------
/scripts/download_glue_dataset.sh:
--------------------------------------------------------------------------------
1 | wget https://nlp.cs.princeton.edu/projects/lm-bff/datasets.tar
2 | tar xvf datasets.tar
--------------------------------------------------------------------------------
/core/gen_template/__init__.py:
--------------------------------------------------------------------------------
1 | #! -*- coding: utf-8 -*-
2 | # Author: DengBoCong
3 | #
4 | # License: MIT License
5 | 
6 | from __future__ import absolute_import
7 | from __future__ import division
8 | from __future__ import print_function
9 | 
10 | import abc
11 | 
12 | 
13 | class TemplateGenerator(abc.ABC):
14 |     @abc.abstractmethod
15 |     def search_template(self, *args, **kwargs):
16 |         raise NotImplementedError
17 | 
18 | 
19 | from core.gen_template.LM_BFF import LMBFFTemplateGenerator
20 | 
21 | template_generator_map = {
22 |     "lm_bff": LMBFFTemplateGenerator
23 | }
24 | 
--------------------------------------------------------------------------------
/data/config/lm_bff.json:
--------------------------------------------------------------------------------
1 | {
2 |   "model_dir": "./models/t5-3b",
3 |   "end_token": "</s>",
4 |   "beam": 100,
5 |   "inspired_templates": ["*cls**sentu_0**<extra_id_0>**label**<extra_id_1>**sep+*", "*cls*.*<extra_id_0>**label**<extra_id_1>**+sentu_0**sep+*"],
6 |   "target_number": 2,
7 |   "batch_size": 32,
8 |   "gen_max_len": 20,
9 |   "truncates": ["head", "tail"],
10 |   "first_mask_token": "<extra_id_0>",
11 |   "forbid_tokens": [3, 19794, 22354],
12 |   "forbid_continuous_token": [5],
13 |   "replace_token_map_list": [{
14 |     "<extra_id_0>": "*cls**sent_0*",
15 |     "<extra_id_1>": "*mask*",
16 |     "<extra_id_2>": "*sep+*",
17 |     "</s>": "*sep+*",
18 |     "▁": "_"
19 |   }, {
20 |     "<extra_id_0>": "*cls*",
21 |     "<extra_id_1>": "*mask*",
22 |     "<extra_id_2>": "*+sent_0**sep+*",
23 |     "</s>": "*+sent_0**sep+*",
24 |     "▁": "_"
25 |   }]
26 | }
--------------------------------------------------------------------------------
/scripts/download_clue_dataset.sh:
--------------------------------------------------------------------------------
1 | wget https://storage.googleapis.com/cluebenchmark/tasks/afqmc_public.zip
2 | wget https://storage.googleapis.com/cluebenchmark/tasks/tnews_public.zip
3 | wget https://storage.googleapis.com/cluebenchmark/tasks/iflytek_public.zip
4 | wget https://storage.googleapis.com/cluebenchmark/tasks/ocnli_public.zip
5 | wget https://storage.googleapis.com/cluebenchmark/tasks/cmnli_public.zip
6 | wget https://storage.googleapis.com/cluebenchmark/tasks/cluewsc2020_public.zip
7 | wget https://storage.googleapis.com/cluebenchmark/tasks/csl_public.zip
8 | wget https://storage.googleapis.com/cluebenchmark/tasks/cmrc2018_public.zip
9 | wget https://storage.googleapis.com/cluebenchmark/tasks/drcd_public.zip
10 | wget https://storage.googleapis.com/cluebenchmark/tasks/chid_public.zip
11 | wget https://storage.googleapis.com/cluebenchmark/tasks/c3_public.zip
12 | wget https://storage.googleapis.com/cluebenchmark/tasks/clue_diagnostics_public.zip
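A note on `replace_token_map_list`: each map rewrites the T5 sentinel tokens of a raw generation into this project's template DSL. A minimal sketch (not part of the repo; it assumes it is run from the repo root, and the `generated` string below is only a plausible T5 output):

```python
import json

with open("data/config/lm_bff.json", "r", encoding="utf-8") as f:
    config = json.load(f)

# Pretend T5 filled the two sentinel slots of the first inspired template:
generated = "<extra_id_0>▁It▁was<extra_id_1>.</s>"

template = generated
for src, dst in config["replace_token_map_list"][0].items():
    template = template.replace(src, dst)

print(template)  # *cls**sent_0*_It_was*mask*.*sep+*
```

The result is exactly the `--template` format consumed by run_prompt_tuning.py.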
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 | 
3 | Copyright (c) 2022 DengBoCong
4 | 
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | 
--------------------------------------------------------------------------------
/tools/trainer.py:
--------------------------------------------------------------------------------
1 | #! -*- coding: utf-8 -*-
2 | # Author: DengBoCong
3 | #
4 | # License: MIT License
5 | 
6 | from __future__ import absolute_import
7 | from __future__ import division
8 | from __future__ import print_function
9 | 
10 | import torch
11 | from transformers.optimization import AdamW
12 | from transformers.optimization import get_linear_schedule_with_warmup
13 | from transformers.trainer import Trainer as TransformerTrainer
14 | 
15 | 
16 | class Trainer(TransformerTrainer):
17 |     def create_optimizer(self):
18 |         """Note: this implements the logic for BERT-like architectures; override this method for models that differ."""
19 |         if self.optimizer is None:
20 |             params = {}
21 |             for n, p in self.model.named_parameters():
22 |                 if self.args.fix_layers > 0:
23 |                     if "encoder.layer" in n:
24 |                         layer_num = int(n[n.find("encoder.layer") + 14:].split(".")[0])
25 |                         if layer_num >= self.args.fix_layers:
26 |                             params[n] = p
27 |                     elif "embeddings" not in n:
28 |                         params[n] = p
29 |                 else:
30 |                     params[n] = p
31 |             no_decay = ["bias", "LayerNorm.weight"]
32 |             optimizer_grouped_parameters = [{
33 |                 "params": [p for n, p in params.items() if not any(nd in n for nd in no_decay)],
34 |                 "weight_decay": self.args.weight_decay
35 |             }, {
36 |                 "params": [p for n, p in params.items() if any(nd in n for nd in no_decay)],
37 |                 "weight_decay": 0.0
38 |             }]
39 | 
40 |             self.optimizer = AdamW(
41 |                 optimizer_grouped_parameters,
42 |                 lr=self.args.learning_rate,
43 |                 betas=(self.args.adam_beta1, self.args.adam_beta2),
44 |                 eps=self.args.adam_epsilon
45 |             )
46 | 
47 |         return self.optimizer
48 | 
49 |     def create_scheduler(self, num_training_steps: int, optimizer: torch.optim.Optimizer = None):
50 |         if self.lr_scheduler is None:
51 |             self.lr_scheduler = get_linear_schedule_with_warmup(
52 |                 optimizer=self.optimizer if optimizer is None else optimizer,
53 |                 num_warmup_steps=self.args.warmup_steps,
54 |                 num_training_steps=num_training_steps
55 |             )
56 | 
57 |         return self.lr_scheduler
58 | 
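`create_optimizer` above splits the selected parameters into a weight-decayed group and a decay-free group by substring match on parameter names. A standalone sketch (toy module, not the repo's model) of that grouping rule:

```python
import torch.nn as nn

class Toy(nn.Module):
    # HF BERT modules literally name this attribute "LayerNorm", which is
    # what the "LayerNorm.weight" substring check relies on.
    def __init__(self):
        super().__init__()
        self.dense = nn.Linear(8, 8)
        self.LayerNorm = nn.LayerNorm(8)

no_decay = ["bias", "LayerNorm.weight"]
decay_names, no_decay_names = [], []
for name, _ in Toy().named_parameters():
    (no_decay_names if any(nd in name for nd in no_decay) else decay_names).append(name)

print(decay_names)     # ['dense.weight']
print(no_decay_names)  # ['dense.bias', 'LayerNorm.weight', 'LayerNorm.bias']
```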
-------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |

# Prompt-Tuning

2 | 
3 | + A pipeline for Prompt-tuning
4 | + Integrates mainstream Prompt-tuning methods and template-search strategies
5 | + Provides a complete execution pipeline for Prompt-tuning
6 | 
7 | # Requirements
8 | The dependencies of this project are listed in requirements.txt; they can also be installed directly with:
9 | ```shell
10 | pip install -r requirements.txt
11 | ```
12 | 
13 | # Usage
14 | + core contains the prompt-tuning models
15 | + core/gen_template contains the template-generation methods; the entry point is run_gen_template.py, for example:
16 | ```shell
17 | python3 run_gen_template.py \
18 |     --task_name CoLA \
19 |     --k 16 \
20 |     --dev_rate 1 \
21 |     --data_loader glue \
22 |     --template_generator lm_bff \
23 |     --data_dir data/original/CoLA \
24 |     --output_dir data/output \
25 |     --generator_config_path data/config/lm_bff.json
26 | ```
27 | + Model implementations live under core; the entry point is run_prompt_tuning.py, for example:
28 | ```shell
29 | python3 run_prompt_tuning.py \
30 |     --data_dir data/CoLA/ \
31 |     --do_train \
32 |     --do_eval \
33 |     --do_predict \
34 |     --model_name_or_path bert \
35 |     --num_k 16 \
36 |     --max_steps 1000 \
37 |     --eval_steps 100 \
38 |     --learning_rate 1e-5 \
39 |     --output_dir result/ \
40 |     --seed 16 \
41 |     --template "*cls**sent_0*_It_was*mask*.*sep+*" \
42 |     --mapping "{'0':'terrible','1':'great'}" \
43 |     --num_sample 16
44 | ```
45 | + data holds configs and datasets; since the datasets are large, download them with the scripts under scripts, for example:
46 | ```shell
47 | cd data
48 | sh ../scripts/download_clue_dataset.sh
49 | sh ../scripts/download_glue_dataset.sh
50 | ```
51 | + tools holds utility functions, dataset processors, and so on
52 | 
53 | # Paper
54 | More detailed paper analyses and reading notes ☞ [here](https://github.com/DengBoCong/nlp-paper)
55 | 
56 | + [Exploiting Cloze Questions for Few Shot Text Classification and Natural Language Inference](https://arxiv.org/pdf/2001.07676.pdf)
57 | + [AutoPrompt: Eliciting Knowledge from Language Models with Automatically Generated Prompts](https://arxiv.org/pdf/2010.15980.pdf)
58 | + [Making Pre-trained Language Models Better Few-shot Learners](https://arxiv.org/pdf/2012.15723.pdf)
59 | + [Prefix-Tuning: Optimizing Continuous Prompts for Generation](https://arxiv.org/pdf/2101.00190.pdf)
60 | + [GPT Understands, Too](https://arxiv.org/pdf/2103.10385.pdf)
61 | + [The Power of Scale for Parameter-Efficient Prompt Tuning](https://arxiv.org/pdf/2104.08691.pdf)
62 | + [Noisy Channel Language Model Prompting for Few-Shot Text Classification](https://arxiv.org/pdf/2108.04106.pdf)
63 | + [PPT: Pre-trained Prompt Tuning for Few-shot Learning](https://arxiv.org/pdf/2109.04332.pdf)
64 | + [SPoT: Better Frozen Model Adaptation through Soft Prompt Transfer](https://arxiv.org/pdf/2110.07904.pdf)
65 | 
66 | # Reference
67 | + https://github.com/princeton-nlp/LM-BFF
68 | + https://github.com/shmsw25/Channel-LM-Prompting
69 | 
70 | # Dataset
71 | + GLUE: https://nlp.cs.princeton.edu/projects/lm-bff/datasets.tar
72 | ```shell
73 | sh scripts/download_glue_dataset.sh
74 | ```
75 | + CLUE: https://github.com/CLUEbenchmark/CLUE
76 | ```shell
77 | sh scripts/download_clue_dataset.sh
78 | ```
79 | 
80 | 
81 | 
82 | 
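To make the `--template` DSL concrete, here is a rough, string-level expansion of the README's example template for one CoLA sentence. Illustration only: the real implementation (`tokenize_multipart_input_for_gen_en` in tools/tools.py) works on token ids, and the special-token spellings below assume a BERT tokenizer.

```python
template = "*cls**sent_0*_It_was*mask*.*sep+*"
sentence = "The book was read by him."

rendered = []
for part in template.split("*"):
    if part == "cls":
        rendered.append("[CLS]")
    elif part == "mask":
        rendered.append("[MASK]")
    elif part in ("sep", "sep+"):
        rendered.append("[SEP]")
    elif part == "sent_0":
        rendered.append(sentence)
    else:
        rendered.append(part.replace("_", " "))  # "_" stands for a literal space

print("".join(rendered))
# [CLS]The book was read by him. It was[MASK].[SEP]
```

The `--mapping` argument then ties the two labels to the words predicted at the `[MASK]` position ("terrible"/"great").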
--------------------------------------------------------------------------------
/run_gen_template.py:
--------------------------------------------------------------------------------
1 | #! -*- coding: utf-8 -*-
2 | # Author: DengBoCong
3 | #
4 | # License: MIT License
5 | 
6 | from __future__ import absolute_import
7 | from __future__ import division
8 | from __future__ import print_function
9 | 
10 | import os
11 | import json
12 | import torch
13 | import torch.nn as nn
14 | import argparse
15 | import numpy as np
16 | from tools.data_loader import loader_map
17 | from tools.glue_data_processor import label_of_mapping
18 | from core.gen_template import template_generator_map
19 | from transformers import T5ForConditionalGeneration
20 | from transformers import T5Tokenizer
21 | 
22 | 
23 | def main():
24 |     parser = argparse.ArgumentParser()
25 |     parser.add_argument("--seed", type=int, default=100, help="Random seed")
26 |     parser.add_argument("--task_name", type=str, default="", help="Task name")
27 |     parser.add_argument("--k", type=int, default=16, help="Training examples for each class.")
28 |     parser.add_argument("--data_dir", type=str, default="data/original", help="Path to original data")
29 |     parser.add_argument("--output_dir", type=str, default="data", help="Output path")
30 |     parser.add_argument("--dev_rate", type=int, default=1, help="dev:train scale")
31 |     parser.add_argument("--data_loader", type=str, default="glue", choices=["glue"], help="Data loader")
32 |     parser.add_argument("--template_generator", type=str, default="lm_bff",
33 |                         choices=["lm_bff"], help="Template generator")
34 |     parser.add_argument("--generator_config_path", type=str, default="data/config/lm_bff.json", help="Path to the generator config")
35 | 
36 |     args = parser.parse_args()
37 |     args.output_dir = os.path.join(args.output_dir, f"{args.dev_rate}.txt")
38 | 
39 |     # random seed
40 |     np.random.seed(args.seed)
41 |     device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
42 | 
43 |     if args.data_loader == "glue":
44 |         train_data, dev_data = loader_map[args.data_loader]().generate_k_shot(
45 |             k=args.k, data_dir=args.data_dir, task_name=args.task_name, dev_rate=args.dev_rate
46 |         )
47 |         datasets = loader_map[args.data_loader].gen_samples(task_name=args.task_name, sources=train_data)
48 |     else:
49 |         raise ValueError(f"DataLoader `{args.data_loader}` not found")
50 | 
51 |     with open(args.generator_config_path, "r", encoding="utf-8") as file:
52 |         generator_config = json.load(file)
53 | 
54 |     if args.template_generator == "lm_bff":
55 |         model = T5ForConditionalGeneration.from_pretrained(generator_config["model_dir"])
56 |         tokenizer = T5Tokenizer.from_pretrained(generator_config["model_dir"])
57 |         tokenizer.sep_token = generator_config["end_token"]
58 | 
59 |         # if torch.cuda.device_count() > 1:
60 |         #     model = nn.DataParallel(model, device_ids=[index for index in range(torch.cuda.device_count())])
61 |         model.to(device)
62 |         model.eval()
63 | 
64 |         template_generator = template_generator_map[args.template_generator](device=device)
65 |         res_templates = template_generator.search_template(
66 |             model, tokenizer, datasets, generator_config["beam"], label_of_mapping[args.task_name],
67 |             generator_config["inspired_templates"], generator_config["target_number"],
68 |             generator_config["batch_size"], generator_config["gen_max_len"],
69 |             truncates=generator_config["truncates"], end_token=generator_config["end_token"],
70 |             forbid_tokens=generator_config["forbid_tokens"],
71 |             forbid_continuous_token=generator_config["forbid_continuous_token"],
72 |             replace_token_map_list=generator_config["replace_token_map_list"]
73 |         )
74 | 
75 |         with open(args.output_dir, "w", encoding="utf-8") as save_file:
76 |             for text, score, _ in res_templates:
77 |                 save_file.write(f"{score}\t{text}\n")
78 |     else:
79 |         raise ValueError(f"TemplateGenerator `{args.template_generator}` not found")
80 | 
81 | 
82 | if __name__ == '__main__':
83 |     main()
84 | 
--------------------------------------------------------------------------------
/tools/args.py:
--------------------------------------------------------------------------------
1 | #! -*- coding: utf-8 -*-
2 | # Author: DengBoCong
3 | #
4 | # License: MIT License
5 | 
6 | from __future__ import absolute_import
7 | from __future__ import division
8 | from __future__ import print_function
9 | 
10 | from dataclasses import dataclass
11 | from dataclasses import field
12 | from transformers import TrainingArguments
13 | from typing import Optional
14 | 
15 | 
16 | @dataclass
17 | class ModelArguments:
18 |     """Model/config/tokenizer related arguments"""
19 |     model_name_or_path: str = field(
20 |         metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"}
21 |     )
22 |     config_name: Optional[str] = field(
23 |         default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
24 |     )
25 |     tokenizer_name: Optional[str] = field(
26 |         default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
27 |     )
28 |     random_segment: bool = field(
29 |         default=False,
30 |         metadata={"help": "Whether to reinitialize the token type embeddings (only for BERT)."}
31 |     )
32 |     cache_dir: Optional[str] = field(
33 |         default=None, metadata={"help": "Where do you want to store the pretrained models downloaded from s3"}
34 |     )
35 | 
36 | 
37 | @dataclass
38 | class DynamicDataTrainingArguments:
39 |     """Prompt-related control arguments"""
40 |     num_k: Optional[int] = field(
41 |         default=16,
42 |         metadata={"help": "Number of training instances per class"}
43 |     )
44 | 
45 |     num_sample: Optional[int] = field(
46 |         default=16,
47 |         metadata={"help": "Number of samples (for inference) in fine-tuning with demonstrations"}
48 |     )
49 | 
50 |     # For prompting
51 |     template: str = field(
52 |         default=None,
53 |         metadata={"help": "Template"}
54 |     )
55 | 
56 |     mapping: str = field(
57 |         default=None,
58 |         metadata={"help": "Label word mapping"}
59 |     )
60 | 
61 |     label_to_word: str = field(
62 |         default=None,
63 |         metadata={"help": "Label to word mapping"}
64 |     )
65 | 
66 |     processor: str = field(
67 |         default=None,
68 |         metadata={"help": "processor name"}
69 |     )
70 | 
71 |     max_seq_length: int = field(
72 |         default=None,
73 |         metadata={"help": "full length (512)"}
74 |     )
75 | 
76 |     first_sent_limit: int = field(
77 |         default=None,
78 |         metadata={"help": "Limit the length of the first sentence (i.e., sent_0)"}
79 |     )
80 | 
81 |     other_sent_limit: int = field(
82 |         default=None,
83 |         metadata={"help": "Limit the length of sentences other than the first sentence"}
84 |     )
85 | 
86 |     use_full_length: bool = field(
87 |         default=None,
88 |         metadata={"help": "Use the full length (512)"}
89 |     )
90 | 
91 |     truncate_head: bool = field(
92 |         default=False,
93 |         metadata={"help": "When exceeding the maximum length, truncate the head instead of the tail."}
94 |     )
95 | 
96 |     multipart_type: str = field(
97 |         default=None,
98 |         metadata={"help": "tokenize multipart input"}
99 |     )
100 | 
101 | 
102 | @dataclass
103 | class DynamicTrainingArguments(TrainingArguments):
104 |     # Regularization
105 |     fix_layers: int = field(
106 |         default=0,
107 |         metadata={"help": "Fix bottom-n layers when optimizing"}
108 |     )
109 | 
110 |     # Turn off train/test
111 |     no_train: bool = field(
112 |         default=False,
metadata={"help": "No training"} 114 | ) 115 | no_predict: bool = field( 116 | default=False, 117 | metadata={"help": "No test"} 118 | ) 119 | 120 | num_labels: int = field( 121 | default=None, 122 | metadata={"help": "task labels num"} 123 | ) 124 | 125 | output_mode: str = field( 126 | default=None, 127 | metadata={"help": "task mode"} 128 | ) 129 | 130 | data_dir: str = field( 131 | default=None, 132 | metadata={"help": "data dir, include train, dev, test"} 133 | ) 134 | -------------------------------------------------------------------------------- /run_prompt_tuning.py: -------------------------------------------------------------------------------- 1 | #! -*- coding: utf-8 -*- 2 | # Author: DengBoCong 3 | # 4 | # License: MIT License 5 | 6 | from __future__ import absolute_import 7 | from __future__ import division 8 | from __future__ import print_function 9 | 10 | import os 11 | import sys 12 | import torch 13 | from core.prompt_bert import BertForPromptTuning 14 | from tools.args import DynamicDataTrainingArguments 15 | from tools.args import DynamicTrainingArguments 16 | from tools.args import ModelArguments 17 | from tools.dataset import PromptDataset 18 | from tools.glue_data_processor import processors_mapping 19 | from tools.tools import multipart_map 20 | from tools.tools import resize_token_type_embeddings 21 | from tools.trainer import Trainer 22 | from transformers import AutoConfig 23 | from transformers import AutoTokenizer 24 | from transformers import HfArgumentParser 25 | from transformers import set_seed 26 | 27 | 28 | def main(): 29 | parser = HfArgumentParser((ModelArguments, DynamicDataTrainingArguments, DynamicTrainingArguments)) 30 | if len(sys.argv) == 2 and sys.argv[1].endswith(".json"): 31 | model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1])) 32 | else: 33 | model_args, data_args, training_args = parser.parse_args_into_dataclasses() 34 | 35 | if training_args.no_train: 36 | training_args.do_train = False 37 | if training_args.no_predict: 38 | training_args.do_predict = False 39 | 40 | if ( 41 | os.path.exists(training_args.output_dir) 42 | and os.listdir(training_args.output_dir) 43 | and training_args.do_train 44 | and not training_args.overwrite_output_dir 45 | ): 46 | raise ValueError(f"Output directory ({training_args.output_dir}) already exists.") 47 | 48 | set_seed(training_args.seed) 49 | 50 | config = AutoConfig.from_pretrained( 51 | model_args.config_name if model_args.config_name else model_args.model_name_or_path, 52 | num_labels=training_args.num_labels, 53 | cache_dir=model_args.cache_dir 54 | ) 55 | 56 | if config.model_type == "bert": 57 | model_fn = BertForPromptTuning 58 | else: 59 | raise NotImplementedError(f"`{config.model_type}` not impl") 60 | 61 | special_tokens = [] 62 | 63 | tokenizer = AutoTokenizer.from_pretrained( 64 | model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path, 65 | additional_special_tokens=special_tokens, 66 | cache_dir=model_args.cache_dir, 67 | ) 68 | 69 | # 这里可以融合多个processors_mapping 70 | processor = processors_mapping[data_args.processor] 71 | 72 | tokenize_multipart_input = multipart_map[data_args.multipart_type] 73 | 74 | train_dataset = PromptDataset(model_args.data_dir, data_args.label_to_word, tokenizer, 75 | processor, data_args.template, data_args.max_seq_length, 76 | tokenize_multipart_input, "train", data_args.num_sample) 77 | 78 | if training_args.do_eval: 79 | eval_dataset = PromptDataset(model_args.data_dir, 
--------------------------------------------------------------------------------
/run_prompt_tuning.py:
--------------------------------------------------------------------------------
1 | #! -*- coding: utf-8 -*-
2 | # Author: DengBoCong
3 | #
4 | # License: MIT License
5 | 
6 | from __future__ import absolute_import
7 | from __future__ import division
8 | from __future__ import print_function
9 | 
10 | import os
11 | import sys
12 | import torch
13 | from core.prompt_bert import BertForPromptTuning
14 | from tools.args import DynamicDataTrainingArguments
15 | from tools.args import DynamicTrainingArguments
16 | from tools.args import ModelArguments
17 | from tools.dataset import PromptDataset
18 | from tools.glue_data_processor import processors_mapping
19 | from tools.glue_data_processor import build_compute_metrics_fn  # assumed to be defined alongside processors_mapping (the file listing is truncated below)
20 | from tools.tools import multipart_map
21 | from tools.tools import resize_token_type_embeddings
22 | from tools.trainer import Trainer
23 | from transformers import AutoConfig
24 | from transformers import AutoTokenizer
25 | from transformers import HfArgumentParser
26 | from transformers import set_seed
27 | 
28 | 
29 | def main():
30 |     parser = HfArgumentParser((ModelArguments, DynamicDataTrainingArguments, DynamicTrainingArguments))
31 |     if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
32 |         model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
33 |     else:
34 |         model_args, data_args, training_args = parser.parse_args_into_dataclasses()
35 | 
36 |     if training_args.no_train:
37 |         training_args.do_train = False
38 |     if training_args.no_predict:
39 |         training_args.do_predict = False
40 | 
41 |     if (
42 |             os.path.exists(training_args.output_dir)
43 |             and os.listdir(training_args.output_dir)
44 |             and training_args.do_train
45 |             and not training_args.overwrite_output_dir
46 |     ):
47 |         raise ValueError(f"Output directory ({training_args.output_dir}) already exists.")
48 | 
49 |     set_seed(training_args.seed)
50 | 
51 |     config = AutoConfig.from_pretrained(
52 |         model_args.config_name if model_args.config_name else model_args.model_name_or_path,
53 |         num_labels=training_args.num_labels,
54 |         cache_dir=model_args.cache_dir
55 |     )
56 | 
57 |     if config.model_type == "bert":
58 |         model_fn = BertForPromptTuning
59 |     else:
60 |         raise NotImplementedError(f"`{config.model_type}` not impl")
61 | 
62 |     special_tokens = []
63 | 
64 |     tokenizer = AutoTokenizer.from_pretrained(
65 |         model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path,
66 |         additional_special_tokens=special_tokens,
67 |         cache_dir=model_args.cache_dir,
68 |     )
69 | 
70 |     # Multiple processors_mapping entries could be merged here
71 |     processor = processors_mapping[data_args.processor]
72 | 
73 |     tokenize_multipart_input = multipart_map[data_args.multipart_type]
74 | 
75 |     # data_dir lives on DynamicTrainingArguments, not ModelArguments
76 |     train_dataset = PromptDataset(training_args.data_dir, data_args.label_to_word, tokenizer,
77 |                                   processor, data_args.template, data_args.max_seq_length,
78 |                                   tokenize_multipart_input, "train", data_args.num_sample)
79 | 
80 |     eval_dataset = None
81 |     if training_args.do_eval:
82 |         eval_dataset = PromptDataset(training_args.data_dir, data_args.label_to_word, tokenizer,
83 |                                      processor, data_args.template, data_args.max_seq_length,
84 |                                      tokenize_multipart_input, "dev", data_args.num_sample)
85 | 
86 |     if training_args.do_predict:
87 |         test_dataset = PromptDataset(training_args.data_dir, data_args.label_to_word, tokenizer,
88 |                                      processor, data_args.template, data_args.max_seq_length,
89 |                                      tokenize_multipart_input, "test", data_args.num_sample)
90 | 
91 |     model = model_fn.from_pretrained(
92 |         model_args.model_name_or_path,
93 |         from_tf=bool(".ckpt" in model_args.model_name_or_path),
94 |         config=config,
95 |         cache_dir=model_args.cache_dir
96 |     )
97 | 
98 |     # For BERT, increase the size of the segment (token type) embeddings
99 |     if config.model_type == "bert":
100 |         model.resize_token_embeddings(len(tokenizer))
101 |         resize_token_type_embeddings(model, new_num_types=10, random_segment=model_args.random_segment)
102 | 
103 |     model.label_word_list = torch.tensor(train_dataset.label_word_list).long().to(training_args.device)
104 |     model.model_args = model_args
105 |     model.data_args = data_args
106 |     model.tokenizer = tokenizer
107 | 
108 |     trainer = Trainer(
109 |         model=model,
110 |         args=training_args,
111 |         train_dataset=train_dataset,
112 |         eval_dataset=eval_dataset,
113 |         # DynamicDataTrainingArguments has no task_name field, so use the processor name
114 |         compute_metrics=build_compute_metrics_fn(data_args.processor)
115 |     )
116 | 
117 |     trainer.train(model_path=model_args.model_name_or_path if os.path.isdir(model_args.model_name_or_path) else None)
118 | 
119 | 
120 | if __name__ == "__main__":
121 |     main()
122 | 
--------------------------------------------------------------------------------
/core/prompt_bert.py:
--------------------------------------------------------------------------------
1 | #! -*- coding: utf-8 -*-
2 | # Author: DengBoCong
3 | #
4 | # License: MIT License
5 | 
6 | from __future__ import absolute_import
7 | from __future__ import division
8 | from __future__ import print_function
9 | 
10 | import torch
11 | import torch.nn as nn
12 | from transformers.models.bert.modeling_bert import BertModel
13 | from transformers.models.bert.modeling_bert import BertOnlyMLMHead
14 | from transformers.models.bert.modeling_bert import BertPreTrainedModel
15 | from transformers import PretrainedConfig
16 | from typing import Optional
17 | 
18 | 
19 | class BertForPromptTuning(BertPreTrainedModel):
20 |     def __init__(self, config: PretrainedConfig, **kwarg):
21 |         super(BertForPromptTuning, self).__init__(config)
22 |         self.num_labels = config.num_labels
23 |         self.bert = BertModel(config)
24 |         self.cls = BertOnlyMLMHead(config)
25 |         self.init_weights()
26 | 
27 |         # Exit early and only return mask logits.
28 |         # For label search.
29 |         self.return_full_softmax = kwarg.get("return_full_softmax", None)
30 | 
31 |         self.label_word_list = kwarg.get("label_word_list", None)
32 | 
33 |         # Used for regression: when labels are passed and num_labels == 1
34 |         self.lower_bounds = kwarg.get("lower_bounds", None)
35 |         self.upper_bounds = kwarg.get("upper_bounds", None)
36 | 
37 |     def forward(self,
38 |                 input_ids: Optional[torch.Tensor] = None,
39 |                 attention_mask: Optional[torch.Tensor] = None,
40 |                 token_type_ids: Optional[torch.Tensor] = None,
41 |                 position_ids: Optional[torch.Tensor] = None,
42 |                 head_mask: Optional[torch.Tensor] = None,
43 |                 inputs_embeds: Optional[torch.Tensor] = None,
44 |                 encoder_hidden_states: Optional[torch.Tensor] = None,
45 |                 encoder_attention_mask: Optional[torch.Tensor] = None,
46 |                 mask_pos: Optional[torch.Tensor] = None,
47 |                 labels: Optional[torch.Tensor] = None,
48 |                 output_attentions: Optional[bool] = None,
49 |                 output_hidden_states: Optional[bool] = None,
50 |                 return_dict: Optional[bool] = None):
51 |         # "mask_pos" matches the key produced by tools.dataset.InputFeatures
52 |         if mask_pos is not None:
53 |             mask_pos = mask_pos.squeeze()
54 | 
55 |         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
56 | 
57 |         bert_outputs = self.bert(
58 |             input_ids,
59 |             attention_mask=attention_mask,
60 |             token_type_ids=token_type_ids,
61 |             position_ids=position_ids,
62 |             head_mask=head_mask,
63 |             inputs_embeds=inputs_embeds,
64 |             encoder_hidden_states=encoder_hidden_states,
65 |             encoder_attention_mask=encoder_attention_mask,
66 |             output_attentions=output_attentions,
67 |             output_hidden_states=output_hidden_states,
68 |             return_dict=return_dict,
69 |         )
70 | 
71 |         sequence_output = bert_outputs[0]
72 |         sequence_mask_output = sequence_output[torch.arange(sequence_output.size(0)), mask_pos]
73 |         prediction_mask_scores = self.cls(sequence_mask_output)
74 | 
75 |         if self.return_full_softmax:
76 |             if labels is not None:
77 |                 return torch.zeros(1, out=prediction_mask_scores.new()), prediction_mask_scores
78 |             return prediction_mask_scores
79 | 
80 |         # Return logits for each label
81 |         logits = None
82 |         if self.label_word_list is not None:
83 |             logits = []
84 |             for label_id in range(len(self.label_word_list)):
85 |                 logits.append(prediction_mask_scores[:, self.label_word_list[label_id]].unsqueeze(-1))
86 |             logits = torch.cat(logits, -1)
87 | 
88 |             # Regression task: turn the two label-word logits into log probabilities
89 |             if self.config.num_labels == 1:
90 |                 logits = nn.LogSoftmax(dim=-1)(logits)
91 | 
92 |         loss = None
93 |         if labels is not None:
94 |             if self.num_labels == 1:
95 |                 loss_fct = nn.KLDivLoss(log_target=True)
96 |                 labels = torch.stack(
97 |                     [1 - (labels.view(-1) - self.lower_bounds) / (self.upper_bounds - self.lower_bounds),
98 |                      (labels.view(-1) - self.lower_bounds) / (self.upper_bounds - self.lower_bounds)], -1
99 |                 )
100 |                 loss = loss_fct(logits.view(-1, 2), labels)
101 |             else:
102 |                 loss_fct = nn.CrossEntropyLoss()
103 |                 loss = loss_fct(logits.view(-1, logits.size(-1)), labels.view(-1))
104 | 
105 |         outputs = (logits,)
106 |         if self.num_labels == 1:
107 |             # Map the predicted probability of the upper label word back to the regression range
108 |             outputs = (
109 |                 torch.exp(logits[..., 1].unsqueeze(-1)) * (self.upper_bounds - self.lower_bounds) + self.lower_bounds,)
110 | 
111 |         return ((loss,) + outputs) if loss is not None else outputs
112 | 
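The core of the model above is the label-word trick: take the MLM distribution at the `[MASK]` position and keep only the scores of the label words, so the MLM head acts as a classifier. A toy-sized, self-contained sketch (random tensors; the vocab ids are hypothetical):

```python
import torch

batch, seq_len, vocab = 2, 8, 30522
prediction_scores = torch.randn(batch, seq_len, vocab)  # stand-in for self.cls(...)
mask_pos = torch.tensor([3, 5])                         # [MASK] position per example
label_word_list = [6659, 2307]                          # e.g. ids for "terrible"/"great"

mask_scores = prediction_scores[torch.arange(batch), mask_pos]  # (batch, vocab)
logits = torch.cat(
    [mask_scores[:, wid].unsqueeze(-1) for wid in label_word_list], dim=-1
)  # (batch, num_labels): a 2-way classifier over the MLM head
print(logits.shape)  # torch.Size([2, 2])
```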
--------------------------------------------------------------------------------
/tools/data_loader.py:
--------------------------------------------------------------------------------
1 | #! -*- coding: utf-8 -*-
2 | # Author: DengBoCong
3 | #
4 | # License: MIT License
5 | 
6 | from __future__ import absolute_import
7 | from __future__ import division
8 | from __future__ import print_function
9 | 
10 | import abc
11 | import os
12 | import numpy as np
13 | import pandas as pd
14 | from typing import List, Dict, Tuple, Any
15 | 
16 | 
17 | class DataLoader(abc.ABC):
18 |     """Dataset Loader"""
19 | 
20 |     @abc.abstractmethod
21 |     def generate_k_shot(self, **kwargs):
22 |         raise NotImplementedError
23 | 
24 | 
25 | class GLUEDataLoader(DataLoader):
26 |     """Support GLUE Dataset"""
27 | 
28 |     def __init__(self, **kwargs) -> None:
29 |         super(GLUEDataLoader, self).__init__()
30 | 
31 |     def generate_k_shot(self,
32 |                         k: int,
33 |                         data_dir: str,
34 |                         task_name: str,
35 |                         dev_rate: int = 1,
36 |                         **kwargs) -> Tuple[List[str], List[str]]:
37 |         """
38 |         :param k: number of sampled examples per class
39 |         :param data_dir: dataset path
40 |         :param task_name: task dataset name
41 |         :param dev_rate: dev:train ratio
42 |         """
43 |         dataset = self.load_std_dataset(data_dir, task_name)
44 | 
45 |         if task_name in ["MNLI", "MRPC", "QNLI", "QQP", "RTE", "SNLI", "SST-2", "STS-B", "WNLI", "CoLA"]:
46 |             # GLUE style
47 |             train_header, train_lines = self.split_header(task_name, dataset["train"])
48 |             np.random.shuffle(train_lines)
49 |         else:
50 |             # Other datasets, default DataFrame
51 |             train_lines = dataset["train"].values.tolist()
52 |             np.random.shuffle(train_lines)
53 | 
54 |         # Get label list for balanced sampling
55 |         label_list = {}
56 |         for line in train_lines:
57 |             line = line.strip().strip("\n")
58 |             label = self.get_label(task_name, line)
59 |             if label not in label_list:
60 |                 label_list[label] = [line]
61 |             else:
62 |                 label_list[label].append(line)
63 | 
64 |         train_data, dev_data = [], []
65 |         for label in label_list:
66 |             train_data.extend(label_list[label][:k])
67 |         for label in label_list:
68 |             dev_data.extend(label_list[label][k:k * (dev_rate + 1)])
69 | 
70 |         return train_data, dev_data
71 | 
72 |     @staticmethod
73 |     def split_header(task_name: str, lines: List[str]) -> Tuple[List[str], List[str]]:
74 |         """Return the file header
75 |         :param task_name: task dataset name
76 |         :param lines: all lines read from the data file
77 |         """
78 |         if task_name in ["CoLA"]:
79 |             return [], lines
80 |         elif task_name in ["MNLI", "MRPC", "QNLI", "QQP", "RTE", "SNLI", "SST-2", "STS-B", "WNLI"]:
81 |             return lines[0:1], lines[1:]
82 |         else:
83 |             raise ValueError("Unknown GLUE task.")
84 | 
85 |     @staticmethod
86 |     def load_std_dataset(data_dir: str, task_name: str, splits: List[str] = None) -> Dict[str, List[str]]:
87 |         """Load a preset standard dataset
88 |         :param data_dir: dataset path
89 |         :param task_name: task dataset name
90 |         :param splits: dataset splits
91 |         """
92 |         dataset = {}
93 |         if task_name in ["MNLI", "MRPC", "QNLI", "QQP", "RTE", "SNLI", "SST-2", "STS-B", "WNLI", "CoLA"]:
94 |             # GLUE style (tsv)
95 |             if not splits:
96 |                 if task_name == "MNLI":
97 |                     splits = ["train", "dev_matched", "dev_mismatched"]
98 |                 else:
99 |                     splits = ["train", "dev"]
100 |             for split in splits:
101 |                 filename = os.path.join(data_dir, f"{split}.tsv")
102 |                 with open(filename, "r") as file:
103 |                     lines = file.readlines()
104 |                 dataset[split] = lines
105 |         else:
106 |             # Other datasets (csv)
107 |             splits = splits if splits else ["train", "test"]
108 |             for split in splits:
109 |                 filename = os.path.join(data_dir, f"{split}.csv")
110 |                 dataset[split] = pd.read_csv(filename, header=None)
111 | 
112 |         return dataset
113 | 
task_name in ["MNLI", "MRPC", "QNLI", "QQP", "RTE", "SNLI", "SST-2", "STS-B", "WNLI", "CoLA"]: 121 | # GLUE style 122 | line = line.strip().split("\t") 123 | if task_name == "CoLA": 124 | return line[1] 125 | elif task_name == "MNLI": 126 | return line[-1] 127 | elif task_name == "MRPC": 128 | return line[0] 129 | elif task_name == "QNLI": 130 | return line[-1] 131 | elif task_name == "QQP": 132 | return line[-1] 133 | elif task_name == "RTE": 134 | return line[-1] 135 | elif task_name == "SNLI": 136 | return line[-1] 137 | elif task_name == "SST-2": 138 | return line[-1] 139 | elif task_name == "STS-B": 140 | return 0 if float(line[-1]) < 2.5 else 1 141 | elif task_name == "WNLI": 142 | return line[-1] 143 | else: 144 | raise NotImplementedError 145 | else: 146 | return line[0] 147 | 148 | @staticmethod 149 | def gen_samples(task_name: str, sources: Any) -> List[Dict[str, Any]]: 150 | """ 151 | :param task_name: 任务数据集名 152 | :param sources: 已读取数据文件的行 或 文件地址 153 | """ 154 | if isinstance(sources, str): 155 | samples = [] 156 | with open(sources, "r", encoding="utf-8") as file: 157 | for line in file: 158 | line = line.strip().strip("\n") 159 | samples.append(line) 160 | else: 161 | samples = sources 162 | 163 | if task_name != "CoLA" and isinstance(sources, str): 164 | samples = samples[1:] 165 | 166 | dataset = [] 167 | for sample in samples: 168 | sample = sample.strip().split("\t") 169 | if task_name == "CoLA": 170 | dataset.append({"label": sample[1], "text": [sample[-1]]}) 171 | elif task_name == "MNLI": 172 | dataset.append({"label": sample[-1], "text": [sample[8], sample[9]]}) 173 | elif task_name == "MRPC": 174 | dataset.append({"label": sample[0], "text": [sample[-2], sample[-1]]}) 175 | elif task_name == "QNLI": 176 | dataset.append({"label": sample[-1], "text": [sample[1], sample[2]]}) 177 | elif task_name == "QQP": 178 | dataset.append({"label": sample[-1], "text": [sample[3], sample[4]]}) 179 | elif task_name == "RTE": 180 | dataset.append({"label": sample[-1], "text": [sample[1], sample[2]]}) 181 | elif task_name == "SNLI": 182 | dataset.append({"label": sample[-1], "text": [sample[7], sample[8]]}) 183 | elif task_name == "SST-2": 184 | dataset.append({"label": sample[-1], "text": [sample[0]]}) 185 | elif task_name == "STS-B": 186 | dataset.append({"label": "0" if float(sample[-1]) < 2.5 else "1", "text": [sample[-3], sample[-2]]}) 187 | elif task_name == "WNLI": 188 | dataset.append({"label": sample[-1], "text": [sample[1], sample[2]]}) 189 | else: 190 | dataset.append({"label": sample[0], "text": [sample[1]]}) 191 | 192 | return dataset 193 | 194 | 195 | loader_map = { 196 | "glue": GLUEDataLoader 197 | } 198 | -------------------------------------------------------------------------------- /tools/dataset.py: -------------------------------------------------------------------------------- 1 | #! 
--------------------------------------------------------------------------------
/tools/dataset.py:
--------------------------------------------------------------------------------
1 | #! -*- coding: utf-8 -*-
2 | # Author: DengBoCong
3 | #
4 | # License: MIT License
5 | 
6 | from __future__ import absolute_import
7 | from __future__ import division
8 | from __future__ import print_function
9 | 
10 | import json
11 | import dataclasses
12 | import pandas as pd
13 | from dataclasses import dataclass
14 | from torch.utils.data import Dataset
15 | from transformers.data.processors import DataProcessor
16 | from transformers.data.processors import InputExample
17 | from typing import Any, Dict, List, Optional, Union, Callable
18 | 
19 | 
20 | @dataclass(frozen=True)
21 | class InputFeatures:
22 |     """see transformer InputFeatures"""
23 |     input_ids: List[int] = None
24 |     attention_mask: Optional[List[int]] = None
25 |     token_type_ids: Optional[List[int]] = None
26 |     label: Optional[Union[int, float]] = None
27 |     mask_pos: Optional[List[int]] = None  # Position of the mask token
28 |     label_word_list: Optional[List[int]] = None  # Label word mapping (dynamic)
29 | 
30 |     def to_json_string(self):
31 |         """Serializes this instance to a JSON string."""
32 |         return json.dumps(dataclasses.asdict(self)) + "\n"
33 | 
34 | 
35 | class PromptDataset(Dataset):
36 |     """Dataset for Prompt"""
37 | 
38 |     def __init__(self,
39 |                  data_dir: str,
40 |                  label_to_word: Dict[str, Any],
41 |                  tokenizer: Any,
42 |                  processor: DataProcessor,
43 |                  template: Any,
44 |                  max_seq_length: int,
45 |                  tokenize_multipart_input: Callable,
46 |                  mode: str = "train",
47 |                  num_sample: int = 16,
48 |                  special_tokens: List[str] = None,
49 |                  **kwargs) -> None:
50 |         """
51 |         :param data_dir: dataset directory
52 |         :param label_to_word: Label-Word mapping
53 |         :param tokenizer: tokenizer
54 |         :param processor: data processor
55 |         :param template: template, str/list
56 |         :param max_seq_length: max seq len
57 |         :param mode: current execution mode
58 |         :param num_sample: number of templates to sample
59 |         :param special_tokens: special tokens in labels
60 |         """
61 |         assert mode in ["train", "dev", "test"]
62 | 
63 |         self.data_dir = data_dir
64 |         self.processor = processor
65 |         self.label_to_word = label_to_word
66 |         self.tokenizer = tokenizer
67 |         self.template = template
68 |         self.max_seq_length = max_seq_length
69 |         self.mode = mode
70 |         self.num_sample = num_sample
71 |         self.special_tokens = special_tokens
72 |         # Keep an options hook here for the InputFeatures input-processing method
73 |         self.kwargs = kwargs
74 |         self.tokenize_multipart_input = tokenize_multipart_input
75 | 
76 |         if self.special_tokens is None:
77 |             self.special_tokens = ["<", "[", ".", ","]
78 | 
79 |         self.label_list = self.processor.get_labels()
80 |         self.num_labels = len(self.label_list)
81 |         self.label_word_list = []
82 | 
83 |         for key in self.label_to_word:
84 |             # For RoBERTa/BART/T5, tokenization also considers space, so we use space+word as label words.
85 |             if self.label_to_word[key][0] not in self.special_tokens:
86 |                 # Make sure space+word is in the vocabulary
87 |                 assert len(self.tokenizer.tokenize(f" {self.label_to_word[key]}")) == 1
88 |                 self.label_to_word[key] = self.tokenizer._convert_token_to_id(
89 |                     self.tokenizer.tokenize(f" {self.label_to_word[key]}")[0])
90 |             else:
91 |                 self.label_to_word[key] = self.tokenizer._convert_token_to_id(self.label_to_word[key])
92 | 
93 |         if len(self.label_list) > 1:
94 |             self.label_word_list = [self.label_to_word[label] for label in self.label_list]
95 |         else:
96 |             self.label_word_list = [self.label_to_word[label] for label in ["0", "1"]]
97 | 
98 |         if self.mode == "train":
99 |             # We do not do multiple sampling when it's the training mode
100 |             self.num_sample = 1
101 |         else:
102 |             self.num_sample = self.num_sample
103 | 
104 |         # At inference time, the templates need to be sampled multiple times
105 |         if isinstance(self.template, list):
106 |             self.num_sample *= len(self.template)
107 | 
108 |         # The support examples are sourced from the training set.
109 |         self.support_examples = self.processor.get_train_examples(self.data_dir)
110 | 
111 |         if mode == "dev":
112 |             self.query_examples = self.processor.get_dev_examples(self.data_dir)
113 |         elif mode == "test":
114 |             self.query_examples = self.processor.get_test_examples(self.data_dir)
115 |         else:
116 |             self.query_examples = self.support_examples
117 | 
118 |         self.size = len(self.query_examples) * self.num_sample
119 |         support_indices = list(range(len(self.support_examples)))
120 |         self.example_idx = []
121 |         for sample_idx in range(self.num_sample):
122 |             for query_idx in range(len(self.query_examples)):
123 |                 context_indices = [support_idx for support_idx in support_indices
124 |                                    if support_idx != query_idx or mode != "train"]
125 |                 self.example_idx.append((query_idx, context_indices, sample_idx))
126 | 
127 |         # If it is not training, we pre-process the data; otherwise, we process the data online.
128 | if mode != "train": 129 | self.features, count = [], 0 130 | for query_idx, context_indices, sample_idx in self.example_idx: 131 | example = self.query_examples[query_idx] 132 | if isinstance(self.template, list): 133 | template = self.template[sample_idx % len(self.template)] # Use template in order 134 | else: 135 | template = self.template 136 | 137 | self.features.append(self.convert_fn(example=example, template=template)) 138 | 139 | count += 1 140 | else: 141 | self.features = None 142 | 143 | def __len__(self): 144 | return self.size 145 | 146 | def __getitem__(self, i): 147 | if self.features is None: 148 | query_idx, context_indices, sample_idx = self.example_idx[i] 149 | example = self.query_examples[query_idx] 150 | if isinstance(self.template, list): 151 | template = self.template[sample_idx % len(self.template)] # Use template in order 152 | else: 153 | template = self.template 154 | 155 | features = self.convert_fn(example=example, template=template) 156 | else: 157 | features = self.features[i] 158 | 159 | return features 160 | 161 | def get_labels(self) -> List[str]: 162 | return self.label_list 163 | 164 | def convert_fn(self, example: InputExample, template: str) -> InputFeatures: 165 | """ 166 | :param example: input example 167 | :param template: template 168 | """ 169 | label_map = {label: i for i, label in enumerate(self.label_list)} # Mapping the label names to label ids 170 | if len(self.label_list) == 1: 171 | label_map = {"0": 0, "1": 1} 172 | 173 | if example.label is None: 174 | example_label = None 175 | elif len(self.label_list) == 1: 176 | example_label = float(example.label) 177 | else: 178 | example_label = label_map[example.label] 179 | 180 | assert not pd.isna(example.text_a) and example.text_a is not None 181 | 182 | inputs = self.tokenize_multipart_input( 183 | input_text_list=[example.text_a] if example.text_b is None else [example.text_a, example.text_b], 184 | max_length=self.max_seq_length, 185 | tokenizer=self.tokenizer, 186 | template=template, 187 | label_word_list=self.label_word_list, 188 | **self.kwargs 189 | ) 190 | 191 | return InputFeatures(**inputs, label=example_label) 192 | -------------------------------------------------------------------------------- /tools/tools.py: -------------------------------------------------------------------------------- 1 | #! 
--------------------------------------------------------------------------------
/tools/tools.py:
--------------------------------------------------------------------------------
1 | #! -*- coding: utf-8 -*-
2 | # Author: DengBoCong
3 | #
4 | # License: MIT License
5 | 
6 | from __future__ import absolute_import
7 | from __future__ import division
8 | from __future__ import print_function
9 | 
10 | import torch.nn as nn
11 | from typing import List, Any, Dict
12 | 
13 | 
14 | def resize_token_type_embeddings(model: Any, new_num_types: int, random_segment: bool):
15 |     """Resize the segment (token type) embeddings for BERT"""
16 |     if hasattr(model, "bert"):
17 |         old_token_type_embeddings = model.bert.embeddings.token_type_embeddings
18 |     else:
19 |         raise NotImplementedError
20 |     new_token_type_embeddings = nn.Embedding(new_num_types, old_token_type_embeddings.weight.size(1))
21 |     if not random_segment:
22 |         new_token_type_embeddings.weight.data[:old_token_type_embeddings.weight.size(0)] = old_token_type_embeddings.weight.data
23 | 
24 |     model.config.type_vocab_size = new_num_types
25 |     if hasattr(model, "bert"):
26 |         model.bert.embeddings.token_type_embeddings = new_token_type_embeddings
27 |     else:
28 |         raise NotImplementedError
29 | 
30 | 
31 | def tokenize_multipart_input_for_gen_en(input_text_list: List[str],
32 |                                         max_length: int,
33 |                                         tokenizer: Any,
34 |                                         template: str,
35 |                                         label_word_list: List[str],
36 |                                         **kwargs) -> Dict[str, Any]:
37 |     """Concatenate all sentences and prompts based on the provided template.
38 |     Template example: '*cls*It was*mask*.*sent_0***label_0:*sent_1****label_1*:*sent_2***'
39 |     *xx* represent variables:
40 |         *cls*: cls_token
41 |         *mask*: mask_token
42 |         *sep*: sep_token
43 |         *sep+*: sep_token, also means +1 for segment id
44 |         *sent_i*: sentence i (input_text_list[i])
45 |         *sent-_i*: same as above, but delete the last token
46 |         *sentl_i*: same as above, but use lower case for the first word
47 |         *sentl-_i*: same as above, but use lower case for the first word and delete the last token
48 |         *+sent_i*: same as above, but add a space before the sentence
49 |         *+sentl_i*: same as above, but add a space before the sentence and use lower case for the first word
50 |         *label_i*: label_word_list[i]
51 | 
52 |     Use "_" to replace space.
53 |     PAY ATTENTION TO SPACE!! DO NOT leave space before variables, for this will lead to extra space token.
54 | 
55 |     kwargs: first_sent_limit, other_sent_limit, truncate_head
56 |     :param input_text_list: list of input texts
57 |     :param max_length: maximum length
58 |     :param tokenizer: tokenizer
59 |     :param template: template
60 |     :param label_word_list: label word list
61 |     """
62 |     assert template is not None
63 | 
64 |     input_ids = []
65 |     attention_mask = []
66 |     token_type_ids = []
67 |     mask_pos = None
68 | 
69 |     special_token_mapping = {
70 |         "cls": tokenizer.cls_token_id, "mask": tokenizer.mask_token_id,
71 |         "sep": tokenizer.sep_token_id, "sep+": tokenizer.sep_token_id,
72 |     }
73 |     template_list = template.split("*")  # Get variable list in the template
74 |     segment_id = 0  # Current segment id. Segment id +1 if encountering sep+.
75 | 
76 |     for part_id, part in enumerate(template_list):
77 |         new_tokens, segment_plus_1_flag = [], False
78 |         if part in special_token_mapping:
79 |             if part == "cls" and "T5" in type(tokenizer).__name__:
80 |                 # T5 does not have cls token
81 |                 continue
82 |             new_tokens.append(special_token_mapping[part])
83 |             if part == "sep+":
84 |                 segment_plus_1_flag = True
85 |         elif part[:6] == "label_":
86 |             # Note that label_word_list already has extra space, so do not add more space ahead of it.
87 |             label_id = int(part.split("_")[1])
88 |             label_word = label_word_list[label_id]
89 |             new_tokens.append(label_word)
90 |         elif part[:5] == "sent_":
91 |             sent_id = int(part.split("_")[1])
92 |             new_tokens += tokenizer.encode(input_text_list[sent_id], add_special_tokens=False)
93 |         elif part[:6] == "+sent_":
94 |             # Add space
95 |             sent_id = int(part.split("_")[1])
96 |             new_tokens += tokenizer.encode(" " + input_text_list[sent_id], add_special_tokens=False)
97 |         elif part[:6] == "sent-_":
98 |             # Delete the last token
99 |             sent_id = int(part.split("_")[1])
100 |             new_tokens += tokenizer.encode(input_text_list[sent_id][:-1], add_special_tokens=False)
101 |         elif part[:6] == "sentl_":
102 |             # Lower case the first token
103 |             sent_id = int(part.split("_")[1])
104 |             text = input_text_list[sent_id]
105 |             text = text[:1].lower() + text[1:]
106 |             new_tokens += tokenizer.encode(text, add_special_tokens=False)
107 |         elif part[:7] == "+sentl_":
108 |             # Lower case the first token and add space
109 |             sent_id = int(part.split("_")[1])
110 |             text = input_text_list[sent_id]
111 |             text = text[:1].lower() + text[1:]
112 |             new_tokens += tokenizer.encode(" " + text, add_special_tokens=False)
113 |         elif part[:7] == "sentl-_":
114 |             # Lower case the first token and discard the last token
115 |             sent_id = int(part.split("_")[1])
116 |             text = input_text_list[sent_id]
117 |             text = text[:1].lower() + text[1:]
118 |             new_tokens += tokenizer.encode(text[:-1], add_special_tokens=False)
119 |         elif part[:6] == "sentu_":
120 |             # Upper case the first token
121 |             sent_id = int(part.split("_")[1])
122 |             text = input_text_list[sent_id]
123 |             text = text[:1].upper() + text[1:]
124 |             new_tokens += tokenizer.encode(text, add_special_tokens=False)
125 |         elif part[:7] == "+sentu_":
126 |             # Upper case the first token and add space
127 |             sent_id = int(part.split("_")[1])
128 |             text = input_text_list[sent_id]
129 |             text = text[:1].upper() + text[1:]
130 |             new_tokens += tokenizer.encode(" " + text, add_special_tokens=False)
131 |         else:
132 |             # Just natural language prompt
133 |             part = part.replace("_", " ")
134 |             # handle special case when T5 tokenizer might add an extra space
135 |             if len(part) == 1:
136 |                 new_tokens.append(tokenizer._convert_token_to_id(part))
137 |             else:
138 |                 new_tokens += tokenizer.encode(part, add_special_tokens=False)
139 | 
140 |         if part[:4] == "sent" or part[1:5] == "sent":
141 |             # If this part is the sentence, limit the sentence length
142 |             sent_id = int(part.split("_")[1])
143 |             if sent_id == 0:
144 |                 if kwargs.get("first_sent_limit", None) is not None:
145 |                     new_tokens = new_tokens[:kwargs["first_sent_limit"]]
146 |             else:
147 |                 if kwargs.get("other_sent_limit", None) is not None:
148 |                     new_tokens = new_tokens[:kwargs["other_sent_limit"]]
149 | 
150 |         input_ids += new_tokens
151 |         attention_mask += [1 for i in range(len(new_tokens))]
152 |         token_type_ids += [segment_id for i in range(len(new_tokens))]
153 | 
154 |         if segment_plus_1_flag:
155 |             segment_id += 1
156 | 
157 |     # Padding
158 |     if kwargs.get("first_sent_limit", None) is not None and len(input_ids) > max_length:
159 |         print(f"Input exceeds max_length limit: {tokenizer.decode(input_ids)}")
160 | 
161 |     while len(input_ids) < max_length:
162 |         input_ids.append(tokenizer.pad_token_id)
163 |         attention_mask.append(0)
164 |         token_type_ids.append(0)
165 | 
166 |     # Truncate
167 |     if len(input_ids) > max_length:
168 |         if kwargs.get("truncate_head", None):
169 |             input_ids = input_ids[-max_length:]
170 |             attention_mask = attention_mask[-max_length:]
171 |             token_type_ids = token_type_ids[-max_length:]
172 |         else:
173 |             # Default is to truncate the tail
174 |             input_ids = input_ids[:max_length]
175 |             attention_mask = attention_mask[:max_length]
176 |             token_type_ids = token_type_ids[:max_length]
177 | 
178 |     mask_pos = [input_ids.index(tokenizer.mask_token_id)]
179 |     # Make sure that the masked position is inside the max_length
180 |     assert mask_pos[0] < max_length
181 | 
182 |     result = {"input_ids": input_ids, "attention_mask": attention_mask}
183 |     if "Bert" in type(tokenizer).__name__:
184 |         # Only provide token type ids for BERT ("Bert" matches BertTokenizer/BertTokenizerFast)
185 |         result["token_type_ids"] = token_type_ids
186 | 
187 |     result["mask_pos"] = mask_pos
188 | 
189 |     return result
190 | 
191 | 
192 | # Registry of tokenization helpers
193 | multipart_map = {
194 |     "glue": tokenize_multipart_input_for_gen_en
195 | }
196 | 
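A hedged usage sketch of `tokenize_multipart_input_for_gen_en`, assuming a locally available BERT tokenizer: it returns padded ids plus the `[MASK]` position that BertForPromptTuning later indexes into.

```python
from transformers import AutoTokenizer
from tools.tools import tokenize_multipart_input_for_gen_en

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")  # hypothetical
out = tokenize_multipart_input_for_gen_en(
    input_text_list=["The book was read by him."],
    max_length=32,
    tokenizer=tokenizer,
    template="*cls**sent_0*_It_was*mask*.*sep+*",
    label_word_list=[],
)
print(out["mask_pos"])                     # e.g. [10]
print(tokenizer.decode(out["input_ids"]))  # "[CLS] the book ... it was [MASK]. [SEP] [PAD] ..."
```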
--------------------------------------------------------------------------------
/core/gen_template/LM_BFF.py:
--------------------------------------------------------------------------------
1 | #! -*- coding: utf-8 -*-
2 | # Author: DengBoCong
3 | #
4 | # License: MIT License
5 | 
6 | from __future__ import absolute_import
7 | from __future__ import division
8 | from __future__ import print_function
9 | 
10 | import torch
11 | from core.gen_template import TemplateGenerator
12 | from transformers import AutoModel
13 | from transformers import AutoTokenizer
14 | from tqdm import tqdm
15 | from typing import List, Dict, Any, Callable, Tuple
16 | 
17 | 
18 | class LMBFFTemplateGenerator(TemplateGenerator):
19 |     def __init__(self, device: torch.device) -> None:
20 |         super(LMBFFTemplateGenerator, self).__init__()
21 |         self.device = device
22 | 
23 |     def search_template(self,
24 |                         model: Any,
25 |                         tokenizer: Any,
26 |                         dataset: List[Dict[str, Any]],
27 |                         beam: int,
28 |                         label_mapping: Dict[Any, Any],
29 |                         inspired_templates: List[str],
30 |                         target_number: int,
31 |                         batch_size: int = 32,
32 |                         gen_max_len: int = 20,
33 |                         label: Any = None,
34 |                         truncates: List[str] = None,
35 |                         first_mask_token: str = "<extra_id_0>",
36 |                         end_token: str = "</s>",
37 |                         template_encoder: Callable[[str, List[str], Any, Any, Dict[Any, Any]], List[int]] = None,
38 |                         forbid_tokens: List[int] = None,
39 |                         forbid_continuous_token: List[int] = None,
40 |                         replace_token_map_list: List[Dict[str, str]] = None,
41 |                         *args, **kwargs) -> List[Tuple[str, float, List[str]]]:
42 |         """
43 |         :param model: model used to generate prompts
44 |         :param tokenizer: tokenizer
45 |         :param dataset: loaded data, typically of the form: {"label": line[-1], "text": [line[8], line[9]]}
46 |         :param beam: beam search size
47 |         :param label_mapping: label description mapping, e.g. mapping numeric labels to words: {0: 'terrible', 1: 'great'}
48 |         :param inspired_templates: inspiration input templates
49 |         :param target_number: number of target spans to generate
50 |         :param batch_size: batch size for T5 inference
51 |         :param gen_max_len: maximum length of generated content
52 |         :param label: if set, only texts with this label are used for template generation
53 |         :param truncates: truncation modes, customized per inspired template; counts must match
54 |         :param first_mask_token: first mask token used for the generation slots
55 |         :param end_token: end token
56 |         :param template_encoder: text-encoding method matching the template
57 |         :param forbid_tokens: skip certain tokens, e.g. "..."
58 |         :param forbid_continuous_token: skip tokens that must not be generated consecutively, e.g. punctuation
59 |         :param replace_token_map_list: maps used to replace tokens in the generated text, customized per inspired template; counts must match
60 |         """
61 |         if isinstance(model, str):
62 |             model = AutoModel.from_pretrained(model)
63 |         if isinstance(tokenizer, str):
64 |             tokenizer = AutoTokenizer.from_pretrained(tokenizer)
65 | 
66 |         res_templates = []
67 | 
68 |         assert len(truncates) == len(inspired_templates)
69 |         for inspired_template, truncate, replace_token_map in zip(inspired_templates, truncates, replace_token_map_list):
70 |             generate_text = self.generate(dataset, inspired_template, model, tokenizer, target_number, label_mapping,
71 |                                           beam, batch_size, gen_max_len, label, truncate, first_mask_token, end_token,
72 |                                           template_encoder, forbid_tokens, forbid_continuous_token)[:beam // 2]
73 | 
74 |             if replace_token_map:
75 |                 for text, score, text_id in generate_text:
76 |                     for ori_token, repl_token in replace_token_map.items():
77 |                         text = text.replace(ori_token, repl_token)
78 |                     res_templates.append((text, score, text_id))
79 |             else:
80 |                 res_templates.extend(generate_text)
81 | 
82 |         return res_templates
83 | 
84 |     def generate(self,
85 |                  dataset: List[Dict[str, Any]],
86 |                  inspired_template: str,
87 |                  model: Any,
88 |                  tokenizer: Any,
89 |                  target_number: int,
90 |                  label_mapping: Dict[Any, Any],
91 |                  beam: int,
92 |                  batch_size: int = 32,
93 |                  gen_max_len: int = 20,
94 |                  label: Any = None,
95 |                  truncate: str = None,
96 |                  first_mask_token: str = "<extra_id_0>",
97 |                  end_token: str = "</s>",
98 |                  template_encoder: Callable[[str, List[str], Any, Any, Dict[Any, Any]], List[int]] = None,
99 |                  forbid_tokens: List[int] = None,
100 |                  forbid_continuous_token: List[int] = None) -> List[Tuple[str, float, List[str]]]:
101 |         """
102 |         :param dataset: loaded data, typically of the form: {"label": line[-1], "text": [line[8], line[9]]}
103 |         :param inspired_template: inspiration input template
104 |         :param model: model used to generate prompts
105 |         :param tokenizer: tokenizer
106 |         :param target_number: number of target spans to generate
107 |         :param label_mapping: label description mapping, e.g. mapping numeric labels to words: {0: 'terrible', 1: 'great'}
108 |         :param beam: beam search size
109 |         :param batch_size: batch size for T5 inference
110 |         :param gen_max_len: maximum length of generated content
111 |         :param label: if set, only texts with this label are used for template generation
112 |         :param truncate: truncation mode
113 |         :param first_mask_token: first mask token used for the generation slots
114 |         :param end_token: end token
115 |         :param template_encoder: text-encoding method matching the template
116 |         :param forbid_tokens: skip certain tokens, e.g. "..."
117 |         :param forbid_continuous_token: skip tokens that must not be generated consecutively, e.g. punctuation
118 |         """
119 |         if template_encoder is None:
120 |             template_encoder = self.encode_text_by_template
121 | 
122 |         input_tensors, max_length = [], 0
123 | 
124 |         for item in dataset:
125 |             if label is None or item["label"] == label:
126 |                 input_text = template_encoder(inspired_template, item["text"], item["label"], tokenizer, label_mapping)
127 |                 if truncate is not None:
128 |                     if truncate == "head":
129 |                         input_text = input_text[-256:]
130 |                     elif truncate == "tail":
131 |                         input_text = input_text[:256]
132 |                     else:
133 |                         raise NotImplementedError
134 |                 input_ids = torch.tensor(input_text).long()
135 |                 max_length = max(max_length, input_ids.size(-1))
136 |                 input_tensors.append(input_ids)
137 | 
138 |         # Concatenate inputs as a batch
139 |         input_ids = torch.zeros((len(input_tensors), max_length)).long()
140 |         attention_mask = torch.zeros((len(input_tensors), max_length)).long()
141 |         for i in range(len(input_tensors)):
142 |             input_ids[i, :input_tensors[i].size(-1)] = input_tensors[i]
143 |             attention_mask[i, :input_tensors[i].size(-1)] = 1
144 | 
145 |         input_ids = input_ids.to(self.device)
146 |         attention_mask = attention_mask.to(self.device)
147 |         assert len(input_tensors) > 0
148 | 
149 |         start_mask = tokenizer._convert_token_to_id(first_mask_token)
150 |         ori_decoder_input_ids = torch.zeros((input_ids.size(0), gen_max_len)).long()
151 |         ori_decoder_input_ids[..., 0] = model.config.decoder_start_token_id
152 | 
153 |         current_output = [{"decoder_input_ids": ori_decoder_input_ids, "ll": 0, "output_id": 1, "output": []}]
154 |         for i in tqdm(range(gen_max_len - 2)):
155 |             new_current_output = []
156 |             for item in current_output:
157 |                 if item["output_id"] > target_number:
158 |                     # Enough contents
159 |                     new_current_output.append(item)
160 |                     continue
161 |                 decoder_input_ids = item["decoder_input_ids"]
162 |                 decoder_input_ids = decoder_input_ids.to(self.device)
163 | 
164 |                 # Forward
165 |                 turn = input_ids.size(0) // batch_size
166 |                 if input_ids.size(0) % batch_size != 0:
167 |                     turn += 1
168 |                 aggr_output = []
169 |                 for t in range(turn):
170 |                     start = t * batch_size
171 |                     end = min((t + 1) * batch_size, input_ids.size(0))
172 | 
173 |                     with torch.no_grad():
174 |                         aggr_output.append(model(input_ids[start:end], attention_mask=attention_mask[start:end],
175 |                                                  decoder_input_ids=decoder_input_ids[start:end])[0])
176 |                 aggr_output = torch.cat(aggr_output, 0)
177 | 
178 |                 # Gather results across all input sentences, and sort generated tokens by log likelihood
179 |                 aggr_output = aggr_output.mean(0)
180 |                 log_denominator = torch.logsumexp(aggr_output[i], -1).item()
181 |                 ids = list(range(model.config.vocab_size))
182 |                 ids.sort(key=lambda x: aggr_output[i][x].item(), reverse=True)
183 |                 ids = ids[:beam + 3]
184 | 
185 |                 for word_id in ids:
186 |                     output_id = item["output_id"]
187 | 
188 |                     check = True
189 |                     # random stop and finish one part
190 |                     if word_id == start_mask - output_id or word_id == tokenizer._convert_token_to_id(end_token):
191 |                         output_id += 1
192 | 
193 |                     output_text = item["output"] + [word_id]
194 |                     ll = item["ll"] + aggr_output[i][word_id] - log_denominator
195 |                     new_decoder_input_ids = decoder_input_ids.new_zeros(decoder_input_ids.size())
196 |                     new_decoder_input_ids[:] = decoder_input_ids
197 |                     new_decoder_input_ids[..., i + 1] = word_id
198 | 
199 |                     forbid_tokens = [3, 19794, 22354] if forbid_tokens is None else forbid_tokens
200 |                     if word_id in forbid_tokens:
201 |                         check = False
202 | 
203 |                     # Forbid continuous
204 |                     forbid_continuous_token = [5] if forbid_continuous_token is None else forbid_continuous_token
205 |                     if len(output_text) > 1 and output_text[-2] == output_text[-1] and \
206 |                             output_text[-1] in forbid_continuous_token:
207 |                         check = False
208 | 
209 |                     if check:
210 |                         # Add new results to beam search pool
211 |                         new_item = {"decoder_input_ids": new_decoder_input_ids,
212 |                                     "ll": ll, "output_id": output_id, "output": output_text}
213 |                         new_current_output.append(new_item)
214 | 
215 |             if len(new_current_output) == 0:
216 |                 break
217 | 
218 |             new_current_output.sort(key=lambda x: x["ll"], reverse=True)
219 |             new_current_output = new_current_output[:beam]
220 |             current_output = new_current_output
221 | 
222 |         result = []
223 |         for item in current_output:
224 |             generate_text = ""
225 |             for token in item["output"]:
226 |                 generate_text += tokenizer._convert_id_to_token(token)
227 | 
228 |             result.append((generate_text, item["ll"].item(), item["output"]))
229 | 
230 |         return result
231 | 
232 |     @staticmethod
233 |     def encode_text_by_template(inspired_template: str,
234 |                                 input_text_tuple: List[str],
235 |                                 label: Any, tokenizer: Any,
236 |                                 label_mapping: Dict[Any, Any]) -> List[int]:
237 |         """Encoding rules for the English T5 model; define your own for models with special requirements
238 |         :param inspired_template: inspiration input template
239 |         :param input_text_tuple: data texts
240 |         :param label: label
241 |         :param tokenizer: tokenizer
242 |         :param label_mapping: label description mapping, e.g. mapping numeric labels to words: {0: 'terrible', 1: 'great'}
243 |         :return:
244 |         """
245 | 
246 |         def enc(token: str):
247 |             return tokenizer.encode(token, add_special_tokens=False)
248 | 
249 |         special_token_mapping = {"cls": tokenizer.cls_token_id, "mask": tokenizer.mask_token_id,
250 |                                  "sep": tokenizer.sep_token_id, "sep+": tokenizer.sep_token_id}
251 |         for index in range(10):
252 |             special_token_mapping[f"<extra_id_{index}>"] = tokenizer._convert_token_to_id(f"<extra_id_{index}>")
253 |         inspired_template_list = inspired_template.split('*')
254 |         input_ids = []
255 |         for part in inspired_template_list:
256 |             new_tokens = []
257 |             if part in special_token_mapping:
258 |                 if part == "cls" and "T5" in type(tokenizer).__name__:
259 |                     # T5 does not have cls token
260 |                     continue
261 |                 new_tokens.append(special_token_mapping[part])
262 |             elif part[:5] == "label":
263 |                 new_tokens += enc(" " + label_mapping[label])
264 |             elif part[:5] == "sent_":
265 |                 sent_id = int(part.split("_")[1])
266 |                 new_tokens += enc(input_text_tuple[sent_id])
267 |             elif part[:6] == "+sent_":
268 |                 sent_id = int(part.split("_")[1])
269 |                 new_tokens += enc(" " + input_text_tuple[sent_id])  # add space
270 |             elif part[:6] == "sent-_":
271 |                 # Delete the last token
272 |                 sent_id = int(part.split("_")[1])
273 |                 new_tokens += enc(input_text_tuple[sent_id][:-1])
274 |             elif part[:7] == "+sentl_":
275 |                 # Lower case the first token
276 |                 sent_id = int(part.split("_")[1])
277 |                 text = input_text_tuple[sent_id]
278 |                 text = text[:1].lower() + text[1:]
279 |                 new_tokens += enc(" " + text)
280 |             elif part[:7] == "+sentu_":
281 |                 # Upper case the first token
282 |                 sent_id = int(part.split("_")[1])
283 |                 text = input_text_tuple[sent_id]
284 |                 text = text[:1].upper() + text[1:]
285 |                 new_tokens += enc(" " + text)
286 |             elif part[:6] == "sentl_":
287 |                 # Lower case the first token
288 |                 sent_id = int(part.split("_")[1])
289 |                 text = input_text_tuple[sent_id]
290 |                 text = text[:1].lower() + text[1:]
291 |                 new_tokens += enc(text)
292 |             elif part[:6] == "sentu_":
293 |                 # Upper case the first token
294 |                 sent_id = int(part.split("_")[1])
295 |                 text = input_text_tuple[sent_id]
296 |                 text = text[:1].upper() + text[1:]
297 |                 new_tokens += enc(text)
part[:7] == "sentl-_": 299 | # Lower case the first token 300 | sent_id = int(part.split("_")[1]) 301 | text = input_text_tuple[sent_id] 302 | text = text[:1].lower() + text[1:] 303 | new_tokens += enc(text[:-1]) 304 | else: 305 | part = part.replace("_", " ") # there cannot be space in command, so use "_" to replace space 306 | # handle special case when t5 tokenizer might add an extra space 307 | if len(part) == 1: 308 | new_tokens.append(tokenizer._convert_token_to_id(part)) 309 | else: 310 | new_tokens += enc(part) 311 | 312 | input_ids += new_tokens 313 | return input_ids 314 | -------------------------------------------------------------------------------- /tools/glue_data_processor.py: -------------------------------------------------------------------------------- 1 | #! -*- coding: utf-8 -*- 2 | # Author: DengBoCong 3 | # 4 | # License: MIT License 5 | 6 | from __future__ import absolute_import 7 | from __future__ import division 8 | from __future__ import print_function 9 | 10 | import os 11 | import pandas as pd 12 | from transformers.data.processors.glue import DataProcessor 13 | from transformers.data.processors.glue import InputExample 14 | from transformers.data.metrics import glue_compute_metrics 15 | from typing import Dict, Any, List 16 | 17 | 18 | class MrpcProcessor(DataProcessor): 19 | """Processor for the MRPC data set (GLUE version).""" 20 | 21 | def get_example_from_tensor_dict(self, tensor_dict: Dict[str, Any]) -> InputExample: 22 | """See base class.""" 23 | return InputExample( 24 | tensor_dict["idx"].numpy(), 25 | tensor_dict["sentence1"].numpy().decode("utf-8"), 26 | tensor_dict["sentence2"].numpy().decode("utf-8"), 27 | str(tensor_dict["label"].numpy()), 28 | ) 29 | 30 | def get_train_examples(self, data_dir: str) -> List[InputExample]: 31 | """See base class.""" 32 | return self.create_examples(self._read_tsv(os.path.join(data_dir, "train.tsv")), "train") 33 | 34 | def get_dev_examples(self, data_dir: str) -> List[InputExample]: 35 | """See base class.""" 36 | return self.create_examples(self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev") 37 | 38 | def get_test_examples(self, data_dir: str) -> List[InputExample]: 39 | """See base class.""" 40 | return self.create_examples(self._read_tsv(os.path.join(data_dir, "test.tsv")), "test") 41 | 42 | def get_labels(self) -> List[str]: 43 | """See base class.""" 44 | return ["0", "1"] 45 | 46 | @staticmethod 47 | def create_examples(lines: List[List[str]], set_type: str) -> List[InputExample]: 48 | """Creates examples for the training, dev and test sets.""" 49 | examples = [] 50 | for (i, line) in enumerate(lines): 51 | if i == 0: 52 | continue 53 | guid = "%s-%s" % (set_type, i) 54 | text_a = line[3] 55 | text_b = line[4] 56 | label = line[0] 57 | examples.append(InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label)) 58 | return examples 59 | 60 | 61 | class MnliProcessor(DataProcessor): 62 | """Processor for the MultiNLI data set (GLUE version).""" 63 | 64 | def get_example_from_tensor_dict(self, tensor_dict: Dict[str, Any]) -> InputExample: 65 | """See base class.""" 66 | return InputExample( 67 | tensor_dict["idx"].numpy(), 68 | tensor_dict["premise"].numpy().decode("utf-8"), 69 | tensor_dict["hypothesis"].numpy().decode("utf-8"), 70 | str(tensor_dict["label"].numpy()), 71 | ) 72 | 73 | def get_train_examples(self, data_dir: str) -> List[InputExample]: 74 | """See base class.""" 75 | return self.create_examples(self._read_tsv(os.path.join(data_dir, "train.tsv")), "train") 76 | 77 | def 
77 |     def get_dev_examples(self, data_dir: str) -> List[InputExample]:
78 |         """See base class."""
79 |         return self.create_examples(self._read_tsv(os.path.join(data_dir, "dev_matched.tsv")), "dev_matched")
80 | 
81 |     def get_test_examples(self, data_dir: str) -> List[InputExample]:
82 |         """See base class."""
83 |         return self.create_examples(self._read_tsv(os.path.join(data_dir, "test_matched.tsv")), "test_matched")
84 | 
85 |     def get_labels(self) -> List[str]:
86 |         """See base class."""
87 |         return ["contradiction", "entailment", "neutral"]
88 | 
89 |     @staticmethod
90 |     def create_examples(lines: List[List[str]], set_type: str) -> List[InputExample]:
91 |         """Creates examples for the training, dev and test sets."""
92 |         examples = []
93 |         for (i, line) in enumerate(lines):
94 |             if i == 0:
95 |                 continue
96 |             guid = "%s-%s" % (set_type, line[0])
97 |             text_a = line[8]
98 |             text_b = line[9]
99 |             label = line[-1]
100 |             examples.append(InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
101 |         return examples
102 | 
103 | 
104 | class MnliMismatchedProcessor(MnliProcessor):
105 |     """Processor for the MultiNLI Mismatched data set (GLUE version)."""
106 | 
107 |     def get_dev_examples(self, data_dir: str) -> List[InputExample]:
108 |         """See base class."""
109 |         return self.create_examples(self._read_tsv(os.path.join(data_dir, "dev_mismatched.tsv")), "dev_mismatched")
110 | 
111 |     def get_test_examples(self, data_dir: str) -> List[InputExample]:
112 |         """See base class."""
113 |         return self.create_examples(self._read_tsv(os.path.join(data_dir, "test_mismatched.tsv")), "test_mismatched")
114 | 
115 | 
116 | class SnliProcessor(DataProcessor):
117 |     """Processor for the SNLI data set."""
118 | 
119 |     def get_example_from_tensor_dict(self, tensor_dict: Dict[str, Any]) -> InputExample:
120 |         """See base class."""
121 |         return InputExample(
122 |             tensor_dict["idx"].numpy(),
123 |             tensor_dict["premise"].numpy().decode("utf-8"),
124 |             tensor_dict["hypothesis"].numpy().decode("utf-8"),
125 |             str(tensor_dict["label"].numpy()),
126 |         )
127 | 
128 |     def get_train_examples(self, data_dir: str) -> List[InputExample]:
129 |         """See base class."""
130 |         return self.create_examples(self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")
131 | 
132 |     def get_dev_examples(self, data_dir: str) -> List[InputExample]:
133 |         """See base class."""
134 |         return self.create_examples(self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")
135 | 
136 |     def get_test_examples(self, data_dir: str) -> List[InputExample]:
137 |         """See base class."""
138 |         return self.create_examples(self._read_tsv(os.path.join(data_dir, "test.tsv")), "test")
139 | 
140 |     def get_labels(self) -> List[str]:
141 |         """See base class."""
142 |         return ["contradiction", "entailment", "neutral"]
143 | 
144 |     @staticmethod
145 |     def create_examples(lines: List[List[str]], set_type: str) -> List[InputExample]:
146 |         """Creates examples for the training, dev and test sets."""
147 |         examples = []
148 |         for (i, line) in enumerate(lines):
149 |             if i == 0:
150 |                 continue
151 |             guid = "%s-%s" % (set_type, line[0])
152 |             text_a = line[7]
153 |             text_b = line[8]
154 |             label = line[-1]
155 |             examples.append(InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
156 |         return examples
157 | 
158 | 
159 | class ColaProcessor(DataProcessor):
160 |     """Processor for the CoLA data set (GLUE version)."""
161 | 
162 |     def get_example_from_tensor_dict(self, tensor_dict: Dict[str, Any]) -> InputExample:
163 |         """See base class."""
164 |         return InputExample(
165 |             tensor_dict["idx"].numpy(),
166 |             tensor_dict["sentence"].numpy().decode("utf-8"),
167 |             None,
168 |             str(tensor_dict["label"].numpy()),
169 |         )
170 | 
171 |     def get_train_examples(self, data_dir: str) -> List[InputExample]:
172 |         """See base class."""
173 |         return self.create_examples(self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")
174 | 
175 |     def get_dev_examples(self, data_dir: str) -> List[InputExample]:
176 |         """See base class."""
177 |         return self.create_examples(self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")
178 | 
179 |     def get_test_examples(self, data_dir: str) -> List[InputExample]:
180 |         """See base class."""
181 |         return self.create_examples(self._read_tsv(os.path.join(data_dir, "test.tsv")), "test")
182 | 
183 |     def get_labels(self) -> List[str]:
184 |         """See base class."""
185 |         return ["0", "1"]
186 | 
187 |     @staticmethod
188 |     def create_examples(lines: List[List[str]], set_type: str) -> List[InputExample]:
189 |         """Creates examples for the training, dev and test sets."""
190 |         test_mode = set_type == "test"
191 |         text_index = 3
192 |         examples = []
193 |         for (i, line) in enumerate(lines):
194 |             guid = "%s-%s" % (set_type, i)
195 |             text_a = line[text_index]
196 |             label = line[1]
197 |             examples.append(InputExample(guid=guid, text_a=text_a, text_b=None, label=label))
198 |         return examples
199 | 
200 | 
201 | class Sst2Processor(DataProcessor):
202 |     """Processor for the SST-2 data set (GLUE version)."""
203 | 
204 |     def get_example_from_tensor_dict(self, tensor_dict: Dict[str, Any]) -> InputExample:
205 |         """See base class."""
206 |         return InputExample(
207 |             tensor_dict["idx"].numpy(),
208 |             tensor_dict["sentence"].numpy().decode("utf-8"),
209 |             None,
210 |             str(tensor_dict["label"].numpy()),
211 |         )
212 | 
213 |     def get_train_examples(self, data_dir: str) -> List[InputExample]:
214 |         """See base class."""
215 |         return self.create_examples(self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")
216 | 
217 |     def get_dev_examples(self, data_dir: str) -> List[InputExample]:
218 |         """See base class."""
219 |         return self.create_examples(self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")
220 | 
221 |     def get_test_examples(self, data_dir: str) -> List[InputExample]:
222 |         """See base class."""
223 |         return self.create_examples(self._read_tsv(os.path.join(data_dir, "test.tsv")), "test")
224 | 
225 |     def get_labels(self) -> List[str]:
226 |         """See base class."""
227 |         return ["0", "1"]
228 | 
229 |     @staticmethod
230 |     def create_examples(lines: List[List[str]], set_type: str) -> List[InputExample]:
231 |         """Creates examples for the training, dev and test sets."""
232 |         examples = []
233 |         text_index = 0
234 |         for (i, line) in enumerate(lines):
235 |             if i == 0:
236 |                 continue
237 |             guid = "%s-%s" % (set_type, i)
238 |             text_a = line[text_index]
239 |             label = line[1]
240 |             examples.append(InputExample(guid=guid, text_a=text_a, text_b=None, label=label))
241 |         return examples
242 | 
243 | 
244 | class StsbProcessor(DataProcessor):
245 |     """Processor for the STS-B data set (GLUE version)."""
246 | 
247 |     def get_example_from_tensor_dict(self, tensor_dict: Dict[str, Any]) -> InputExample:
248 |         """See base class."""
249 |         return InputExample(
250 |             tensor_dict["idx"].numpy(),
251 |             tensor_dict["sentence1"].numpy().decode("utf-8"),
252 |             tensor_dict["sentence2"].numpy().decode("utf-8"),
253 |             str(tensor_dict["label"].numpy()),
254 |         )
255 | 
256 |     def get_train_examples(self, data_dir: str) -> List[InputExample]:
257 |         """See base class."""
258 |         return self.create_examples(self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")
259 | 
260 |     def get_dev_examples(self, data_dir: str) -> List[InputExample]:
261 |         """See base class."""
262 |         return self.create_examples(self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")
263 | 
264 |     def get_test_examples(self, data_dir: str) -> List[InputExample]:
265 |         """See base class."""
266 |         return self.create_examples(self._read_tsv(os.path.join(data_dir, "test.tsv")), "test")
267 | 
268 |     def get_labels(self) -> List[Any]:
269 |         """See base class."""
270 |         return [None]
271 | 
272 |     @staticmethod
273 |     def create_examples(lines: List[List[str]], set_type: str) -> List[InputExample]:
274 |         """Creates examples for the training, dev and test sets."""
275 |         examples = []
276 |         for (i, line) in enumerate(lines):
277 |             if i == 0:
278 |                 continue
279 |             guid = "%s-%s" % (set_type, line[0])
280 |             text_a = line[7]
281 |             text_b = line[8]
282 |             label = line[-1]
283 |             examples.append(InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
284 |         return examples
285 | 
286 | 
287 | class QqpProcessor(DataProcessor):
288 |     """Processor for the QQP data set (GLUE version)."""
289 | 
290 |     def get_example_from_tensor_dict(self, tensor_dict: Dict[str, Any]) -> InputExample:
291 |         """See base class."""
292 |         return InputExample(
293 |             tensor_dict["idx"].numpy(),
294 |             tensor_dict["question1"].numpy().decode("utf-8"),
295 |             tensor_dict["question2"].numpy().decode("utf-8"),
296 |             str(tensor_dict["label"].numpy()),
297 |         )
298 | 
299 |     def get_train_examples(self, data_dir: str) -> List[InputExample]:
300 |         """See base class."""
301 |         return self.create_examples(self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")
302 | 
303 |     def get_dev_examples(self, data_dir: str) -> List[InputExample]:
304 |         """See base class."""
305 |         return self.create_examples(self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")
306 | 
307 |     def get_test_examples(self, data_dir: str) -> List[InputExample]:
308 |         """See base class."""
309 |         return self.create_examples(self._read_tsv(os.path.join(data_dir, "test.tsv")), "test")
310 | 
311 |     def get_labels(self) -> List[str]:
312 |         """See base class."""
313 |         return ["0", "1"]
314 | 
315 |     @staticmethod
316 |     def create_examples(lines: List[List[str]], set_type: str) -> List[InputExample]:
317 |         """Creates examples for the training, dev and test sets."""
318 |         test_mode = set_type == "test"
319 |         q1_index = 3
320 |         q2_index = 4
321 |         examples = []
322 |         for (i, line) in enumerate(lines):
323 |             if i == 0:
324 |                 continue
325 |             guid = "%s-%s" % (set_type, line[0])
326 |             try:
327 |                 text_a = line[q1_index]
328 |                 text_b = line[q2_index]
329 |                 label = line[5]
330 |             except IndexError:
331 |                 continue
332 |             examples.append(InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
333 |         return examples
334 | 
335 | 
336 | class QnliProcessor(DataProcessor):
337 |     """Processor for the QNLI data set (GLUE version)."""
338 | 
339 |     def get_example_from_tensor_dict(self, tensor_dict: Dict[str, Any]) -> InputExample:
340 |         """See base class."""
341 |         return InputExample(
342 |             tensor_dict["idx"].numpy(),
343 |             tensor_dict["question"].numpy().decode("utf-8"),
344 |             tensor_dict["sentence"].numpy().decode("utf-8"),
345 |             str(tensor_dict["label"].numpy()),
346 |         )
347 | 
348 |     def get_train_examples(self, data_dir: str) -> List[InputExample]:
349 |         """See base class."""
350 |         return self.create_examples(self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")
351 | 
352 |     def get_dev_examples(self, data_dir: str) -> List[InputExample]:
353 |         """See base class."""
354 |         return self.create_examples(self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")
355 | 
356 |     def get_test_examples(self, data_dir: str) -> List[InputExample]:
357 |         """See base class."""
358 |         return self.create_examples(self._read_tsv(os.path.join(data_dir, "test.tsv")), "test")
359 | 
360 |     def get_labels(self) -> List[str]:
361 |         """See base class."""
362 |         return ["entailment", "not_entailment"]
363 | 
364 |     @staticmethod
365 |     def create_examples(lines: List[List[str]], set_type: str) -> List[InputExample]:
366 |         """Creates examples for the training, dev and test sets."""
367 |         examples = []
368 |         for (i, line) in enumerate(lines):
369 |             if i == 0:
370 |                 continue
371 |             guid = "%s-%s" % (set_type, line[0])
372 |             text_a = line[1]
373 |             text_b = line[2]
374 |             label = line[-1]
375 |             examples.append(InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
376 |         return examples
377 | 
378 | 
379 | class RteProcessor(DataProcessor):
380 |     """Processor for the RTE data set (GLUE version)."""
381 | 
382 |     def get_example_from_tensor_dict(self, tensor_dict: Dict[str, Any]) -> InputExample:
383 |         """See base class."""
384 |         return InputExample(
385 |             tensor_dict["idx"].numpy(),
386 |             tensor_dict["sentence1"].numpy().decode("utf-8"),
387 |             tensor_dict["sentence2"].numpy().decode("utf-8"),
388 |             str(tensor_dict["label"].numpy()),
389 |         )
390 | 
391 |     def get_train_examples(self, data_dir: str) -> List[InputExample]:
392 |         """See base class."""
393 |         return self.create_examples(self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")
394 | 
395 |     def get_dev_examples(self, data_dir: str) -> List[InputExample]:
396 |         """See base class."""
397 |         return self.create_examples(self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")
398 | 
399 |     def get_test_examples(self, data_dir: str) -> List[InputExample]:
400 |         """See base class."""
401 |         return self.create_examples(self._read_tsv(os.path.join(data_dir, "test.tsv")), "test")
402 | 
403 |     def get_labels(self) -> List[str]:
404 |         """See base class."""
405 |         return ["entailment", "not_entailment"]
406 | 
407 |     @staticmethod
408 |     def create_examples(lines: List[List[str]], set_type: str) -> List[InputExample]:
409 |         """Creates examples for the training, dev and test sets."""
410 |         examples = []
411 |         for (i, line) in enumerate(lines):
412 |             if i == 0:
413 |                 continue
414 |             guid = "%s-%s" % (set_type, line[0])
415 |             text_a = line[1]
416 |             text_b = line[2]
417 |             label = line[-1]
418 |             examples.append(InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
419 |         return examples
420 | 
421 | 
422 | class WnliProcessor(DataProcessor):
423 |     """Processor for the WNLI data set (GLUE version)."""
424 | 
425 |     def get_example_from_tensor_dict(self, tensor_dict: Dict[str, Any]) -> InputExample:
426 |         """See base class."""
427 |         return InputExample(
428 |             tensor_dict["idx"].numpy(),
429 |             tensor_dict["sentence1"].numpy().decode("utf-8"),
430 |             tensor_dict["sentence2"].numpy().decode("utf-8"),
431 |             str(tensor_dict["label"].numpy()),
432 |         )
433 | 
434 |     def get_train_examples(self, data_dir: str) -> List[InputExample]:
435 |         """See base class."""
436 |         return self.create_examples(self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")
437 | 
438 |     def get_dev_examples(self, data_dir: str) -> List[InputExample]:
439 |         """See base class."""
440 |         return self.create_examples(self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")
441 | 
442 |     def get_test_examples(self, data_dir: str) -> List[InputExample]:
443 |         """See base class."""
444 |         return self.create_examples(self._read_tsv(os.path.join(data_dir, "test.tsv")), "test")
445 | 
446 |     def get_labels(self) -> List[str]:
447 |         """See base class."""
448 |         return ["0", "1"]
449 | 
450 |     @staticmethod
451 |     def create_examples(lines: List[List[str]], set_type: str) -> List[InputExample]:
452 |         """Creates examples for the training, dev and test sets."""
453 |         examples = []
454 |         for (i, line) in enumerate(lines):
455 |             if i == 0:
456 |                 continue
457 |             guid = "%s-%s" % (set_type, line[0])
458 |             text_a = line[1]
459 |             text_b = line[2]
460 |             label = line[-1]
461 |             examples.append(InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
462 |         return examples
463 | 
464 | 
465 | class TextClassificationProcessor(DataProcessor):
466 |     """Data processor for text classification datasets (mr, sst-5, subj, trec, cr, mpqa)."""
467 | 
468 |     def __init__(self, task_name: str) -> None:
469 |         self.task_name = task_name
470 | 
471 |     def get_example_from_tensor_dict(self, tensor_dict: Dict[str, Any]) -> InputExample:
472 |         """See base class."""
473 |         return InputExample(
474 |             tensor_dict["idx"].numpy(),
475 |             tensor_dict["sentence"].numpy().decode("utf-8"),
476 |             None,
477 |             str(tensor_dict["label"].numpy()),
478 |         )
479 | 
480 |     def get_train_examples(self, data_dir: str) -> List[InputExample]:
481 |         """See base class."""
482 |         return self._create_examples(pd.read_csv(os.path.join(data_dir, "train.csv"),
483 |                                                  header=None).values.tolist(), "train")
484 | 
485 |     def get_dev_examples(self, data_dir: str) -> List[InputExample]:
486 |         """See base class."""
487 |         return self._create_examples(pd.read_csv(os.path.join(data_dir, "dev.csv"),
488 |                                                  header=None).values.tolist(), "dev")
489 | 
490 |     def get_test_examples(self, data_dir: str) -> List[InputExample]:
491 |         """See base class."""
492 |         return self._create_examples(pd.read_csv(os.path.join(data_dir, "test.csv"),
493 |                                                  header=None).values.tolist(), "test")
494 | 
495 |     def get_labels(self) -> List[int]:
496 |         """See base class."""
497 |         if self.task_name == "mr":
498 |             return list(range(2))
499 |         elif self.task_name == "sst-5":
500 |             return list(range(5))
501 |         elif self.task_name == "subj":
502 |             return list(range(2))
503 |         elif self.task_name == "trec":
504 |             return list(range(6))
505 |         elif self.task_name == "cr":
506 |             return list(range(2))
507 |         elif self.task_name == "mpqa":
508 |             return list(range(2))
509 |         else:
510 |             raise Exception("task_name not supported.")
511 | 
512 |     def _create_examples(self, lines: List[List[str]], set_type: str) -> List[InputExample]:
513 |         """Creates examples for the training, dev and test sets."""
514 |         examples = []
515 |         for (i, line) in enumerate(lines):
516 |             guid = "%s-%s" % (set_type, i)
517 |             if self.task_name == "ag_news":
518 |                 # transformers' InputExample accepts no short_text field, so only the full text is kept
519 |                 examples.append(InputExample(guid=guid, text_a=line[1] + ". " + line[2], label=line[0]))
520 |             elif self.task_name == "yelp_review_full":
521 |                 examples.append(InputExample(guid=guid, text_a=line[1], label=line[0]))
522 |             elif self.task_name == "yahoo_answers":
523 |                 text = line[1]
524 |                 if not pd.isna(line[2]):
525 |                     text += " " + line[2]
526 |                 if not pd.isna(line[3]):
527 |                     text += " " + line[3]
528 |                 examples.append(InputExample(guid=guid, text_a=text, label=line[0]))
529 |             elif self.task_name in ["mr", "sst-5", "subj", "trec", "cr", "mpqa"]:
530 |                 examples.append(InputExample(guid=guid, text_a=line[1], label=line[0]))
531 |             else:
532 |                 raise Exception("task_name not supported.")
533 | 
534 |         return examples
535 | 
536 | 
537 | def text_classification_metrics(task_name: str, preds: Any, labels: Any):
538 |     return {"acc": (preds == labels).mean()}
539 | 
540 | 
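# A minimal sanity-check sketch of the accuracy metric above (values are illustrative,
# not from the repo); it runs only when this module is executed directly.
if __name__ == "__main__":
    import numpy as np

    _preds = np.array([1, 0, 1, 1])
    _labels = np.array([1, 1, 1, 0])
    print(text_classification_metrics("mr", _preds, _labels))  # {'acc': 0.5}
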
" + line[2], 519 | short_text=line[1] + ".", label=line[0])) 520 | elif self.task_name == "yelp_review_full": 521 | examples.append(InputExample(guid=guid, text_a=line[1], short_text=line[1], label=line[0])) 522 | elif self.task_name == "yahoo_answers": 523 | text = line[1] 524 | if not pd.isna(line[2]): 525 | text += " " + line[2] 526 | if not pd.isna(line[3]): 527 | text += " " + line[3] 528 | examples.append(InputExample(guid=guid, text_a=text, short_text=line[1], label=line[0])) 529 | elif self.task_name in ["mr", "sst-5", "subj", "trec", "cr", "mpqa"]: 530 | examples.append(InputExample(guid=guid, text_a=line[1], label=line[0])) 531 | else: 532 | raise Exception("Task_name not supported.") 533 | 534 | return examples 535 | 536 | 537 | def text_classification_metrics(task_name: str, preds: Any, labels: Any): 538 | return {"acc": (preds == labels).mean()} 539 | 540 | 541 | processors_mapping = { 542 | "cola": ColaProcessor(), 543 | "mnli": MnliProcessor(), 544 | "mnli-mm": MnliMismatchedProcessor(), 545 | "mrpc": MrpcProcessor(), 546 | "sst-2": Sst2Processor(), 547 | "sts-b": StsbProcessor(), 548 | "qqp": QqpProcessor(), 549 | "qnli": QnliProcessor(), 550 | "rte": RteProcessor(), 551 | "wnli": WnliProcessor(), 552 | "snli": SnliProcessor(), 553 | "mr": TextClassificationProcessor("mr"), 554 | "sst-5": TextClassificationProcessor("sst-5"), 555 | "subj": TextClassificationProcessor("subj"), 556 | "trec": TextClassificationProcessor("trec"), 557 | "cr": TextClassificationProcessor("cr"), 558 | "mpqa": TextClassificationProcessor("mpqa") 559 | } 560 | 561 | num_labels_mapping = { 562 | "cola": 2, 563 | "mnli": 3, 564 | "mrpc": 2, 565 | "sst-2": 2, 566 | "sts-b": 1, 567 | "qqp": 2, 568 | "qnli": 2, 569 | "rte": 2, 570 | "wnli": 2, 571 | "snli": 3, 572 | "mr": 2, 573 | "sst-5": 5, 574 | "subj": 2, 575 | "trec": 6, 576 | "cr": 2, 577 | "mpqa": 2 578 | } 579 | 580 | output_modes_mapping = { 581 | "cola": "classification", 582 | "mnli": "classification", 583 | "mnli-mm": "classification", 584 | "mrpc": "classification", 585 | "sst-2": "classification", 586 | "sts-b": "regression", 587 | "qqp": "classification", 588 | "qnli": "classification", 589 | "rte": "classification", 590 | "wnli": "classification", 591 | "snli": "classification", 592 | "mr": "classification", 593 | "sst-5": "classification", 594 | "subj": "classification", 595 | "trec": "classification", 596 | "cr": "classification", 597 | "mpqa": "classification" 598 | } 599 | 600 | # Return a function that takes (task_name, preds, labels) as inputs 601 | compute_metrics_mapping = { 602 | "cola": glue_compute_metrics, 603 | "mnli": glue_compute_metrics, 604 | "mnli-mm": glue_compute_metrics, 605 | "mrpc": glue_compute_metrics, 606 | "sst-2": glue_compute_metrics, 607 | "sts-b": glue_compute_metrics, 608 | "qqp": glue_compute_metrics, 609 | "qnli": glue_compute_metrics, 610 | "rte": glue_compute_metrics, 611 | "wnli": glue_compute_metrics, 612 | "snli": text_classification_metrics, 613 | "mr": text_classification_metrics, 614 | "sst-5": text_classification_metrics, 615 | "subj": text_classification_metrics, 616 | "trec": text_classification_metrics, 617 | "cr": text_classification_metrics, 618 | "mpqa": text_classification_metrics, 619 | } 620 | 621 | # For regression task only: median 622 | median_mapping = { 623 | "sts-b": 2.5 624 | } 625 | 626 | bound_mapping = { 627 | "sts-b": (0, 5) 628 | } 629 | 630 | label_of_mapping = { 631 | "SST-2": {"0": "terrible", "1": "great"}, 632 | "sst-5": {0: "terrible", 1: "bad", 2: "okay", 3: "good", 4: 
"great"}, 633 | "mr": {0: "terrible", 1: "great"}, 634 | "cr": {0: "terrible", 1: "great"}, 635 | "subj": {0: "subjective", 1: "objective"}, 636 | "trec": {0: "Description", 1: "Entity", 2: "Expression", 3: "Human", 4: "Location", 5: "Number"}, 637 | "mpqa": {0: "terrible", 1: "great"}, 638 | "CoLA": {"0": "incorrect", "1": "correct"}, 639 | "MRPC": {"0": "No", "1": "Yes"}, 640 | "QQP": {"0": "No", "1": "Yes"}, 641 | "STS-B": {"0": "No", "1": "Yes"}, 642 | "MNLI": {"contradiction": "No", "entailment": "Yes", "neutral": "Maybe"}, 643 | "SNLI": {"contradiction": "No", "entailment": "Yes", "neutral": "Maybe"}, 644 | "QNLI": {"not_entailment": "No", "entailment": "Yes"}, 645 | "RTE": {"not_entailment": "No", "entailment": "Yes"} 646 | } 647 | --------------------------------------------------------------------------------