├── .gitignore ├── README.md ├── distillation.py ├── download_glue_data.py ├── run_glue.py ├── run_glue_distillation.py └── utils_glue.py /.gitignore: -------------------------------------------------------------------------------- 1 | .idea 2 | __pycache__ 3 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # PyTorch-PKD-for-BERT-Compression 2 | 3 | PyTorch implementation of the distillation method described in the following paper: [**Patient Knowledge Distillation for BERT Model Compression**](https://arxiv.org/abs/1908.09355). This repository is heavily based on [Pytorch-Transformers](https://github.com/huggingface/pytorch-transformers) by Hugging Face. 4 | 5 | ## Steps to run the code 6 | ### 1. Download the GLUE data 7 | ``` 8 | $ python download_glue_data.py 9 | ``` 10 | 11 | ### 2. Fine-tune the teacher BERT model 12 | Run the following command to fine-tune the teacher model; the fine-tuned model is saved to `--output_dir`. `$GLUE_DIR` is the directory the GLUE data was downloaded to (`glue_data` by default). 13 | ``` 14 | python run_glue.py \ 15 | --model_type bert \ 16 | --model_name_or_path bert-base-uncased \ 17 | --task_name $TASK_NAME \ 18 | --do_train \ 19 | --do_eval \ 20 | --do_lower_case \ 21 | --data_dir $GLUE_DIR/$TASK_NAME \ 22 | --max_seq_length 128 \ 23 | --per_gpu_eval_batch_size=8 \ 24 | --per_gpu_train_batch_size=8 \ 25 | --learning_rate 2e-5 \ 26 | --num_train_epochs 3.0 \ 27 | --output_dir /tmp/$TASK_NAME/ 28 | ``` 29 | 30 | ### 3. Distill a student model from the teacher BERT 31 | $TEACHER_MODEL is the folder containing your fine-tuned teacher model (the `--output_dir` from step 2). Consider giving the student a different `--output_dir` so its checkpoints do not land in the teacher's folder.
32 | ``` 33 | python run_glue_distillation.py \ 34 | --model_type bert \ 35 | --teacher_model $TEACHER_MODEL \ 36 | --student_model bert-base-uncased \ 37 | --task_name $TASK_NAME \ 38 | --num_hidden_layers 6 \ 39 | --alpha 0.5 \ 40 | --beta 100.0 \ 41 | --do_train \ 42 | --do_eval \ 43 | --do_lower_case \ 44 | --data_dir $GLUE_DIR/$TASK_NAME \ 45 | --max_seq_length 128 \ 46 | --per_gpu_eval_batch_size=8 \ 47 | --per_gpu_train_batch_size=8 \ 48 | --learning_rate 2e-5 \ 49 | --num_train_epochs 4.0 \ 50 | --output_dir /tmp/$TASK_NAME/ 51 | ``` 52 | 53 | ## Experimental Results on dev set 54 | model | num_layers | SST-2 | MRPC-f1/acc | QQP-f1/acc | MNLI-m/mm | QNLI | RTE 55 | -- | -- | -- | -- | -- | -- | -- | -- 56 | base | 12 | 0.9232 | 0.89/0.8358 | 0.8818/0.9121 | 0.8432/0.8479 | 0.916 | 0.6751 57 | finetuned | 6 | 0.9002 | 0.8741/0.8186 | 0.8672/0.901 | 0.8051/0.8033 | 0.8662 | 0.6101 58 | distill | 6 | 0.9071 | 0.8885/0.8382 | 0.8704/0.9016 | 0.8153/0.821 | 0.8642 | 0.6318 59 | -------------------------------------------------------------------------------- /distillation.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from operator import itemgetter 5 | 6 | 7 | class PatientDistillation(nn.Module): 8 | def __init__(self, t_config, s_config): 9 | super(PatientDistillation, self).__init__() 10 | self.t_config = t_config 11 | self.s_config = s_config 12 | 13 | def forward(self, t_model, s_model, order, input_ids, token_type_ids, attention_mask, labels, args): 14 | with torch.no_grad(): 15 | t_outputs = t_model(input_ids=input_ids, 16 | token_type_ids=token_type_ids, 17 | attention_mask=attention_mask) 18 | 19 | s_outputs = s_model(input_ids=input_ids, 20 | token_type_ids=token_type_ids, 21 | attention_mask=attention_mask, 22 | labels=labels) 23 | 24 | t_logits, t_features = t_outputs[0], t_outputs[-1] 25 | train_loss, s_logits, s_features = 
s_outputs[0], s_outputs[1], s_outputs[-1] 26 | T = args.temperature 27 | soft_targets = F.softmax(t_logits / T, dim=-1) 28 | log_probs = F.log_softmax(s_logits / T, dim=-1) 29 | soft_loss = F.kl_div(log_probs, soft_targets.detach(), reduction='batchmean') * T * T 30 | 31 | t_features = torch.cat(t_features[1:-1], dim=0).view(self.t_config.num_hidden_layers - 1, 32 | -1, 33 | args.max_seq_length, 34 | self.t_config.hidden_size)[:, :, 0] 35 | 36 | s_features = torch.cat(s_features[1:-1], dim=0).view(self.s_config.num_hidden_layers - 1, 37 | -1, 38 | args.max_seq_length, 39 | self.s_config.hidden_size)[:, :, 0] 40 | 41 | t_features = itemgetter(order)(t_features) 42 | t_features = t_features / t_features.norm(dim=-1).unsqueeze(-1) 43 | s_features = s_features / s_features.norm(dim=-1).unsqueeze(-1) 44 | distill_loss = F.mse_loss(s_features, t_features.detach(), reduction="mean") 45 | return train_loss, soft_loss, distill_loss 46 | -------------------------------------------------------------------------------- /download_glue_data.py: -------------------------------------------------------------------------------- 1 | ''' Script for downloading all GLUE data. 2 | Note: for legal reasons, we are unable to host MRPC. 3 | You can either use the version hosted by the SentEval team, which is already tokenized, 4 | or you can download the original data from (https://download.microsoft.com/download/D/4/6/D46FF87A-F6B9-4252-AA8B-3604ED519838/MSRParaphraseCorpus.msi) and extract the data from it manually. 5 | For Windows users, you can run the .msi file. For Mac and Linux users, consider an external library such as 'cabextract' (see below for an example). 6 | You should then rename and place specific files in a folder (see below for an example). 
7 | mkdir MRPC 8 | cabextract MSRParaphraseCorpus.msi -d MRPC 9 | cat MRPC/_2DEC3DBE877E4DB192D17C0256E90F1D | tr -d $'\r' > MRPC/msr_paraphrase_train.txt 10 | cat MRPC/_D7B391F9EAFF4B1B8BCE8F21B20B1B61 | tr -d $'\r' > MRPC/msr_paraphrase_test.txt 11 | rm MRPC/_* 12 | rm MSRParaphraseCorpus.msi 13 | 1/30/19: It looks like SentEval is no longer hosting their extracted and tokenized MRPC data, so you'll need to download the data from the original source for now. 14 | 2/11/19: It looks like SentEval actually *is* hosting the extracted data. Hooray! 15 | ''' 16 | 17 | import os 18 | import sys 19 | import shutil 20 | import argparse 21 | import tempfile 22 | import urllib.request 23 | import zipfile 24 | 25 | TASKS = ["CoLA", "SST", "MRPC", "QQP", "STS", "MNLI", "SNLI", "QNLI", "RTE", "WNLI", "diagnostic"] 26 | TASK2PATH = { 27 | "CoLA": 'https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FCoLA.zip?alt=media&token=46d5e637-3411-4188-bc44-5809b5bfb5f4', 28 | "SST": 'https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FSST-2.zip?alt=media&token=aabc5f6b-e466-44a2-b9b4-cf6337f84ac8', 29 | "MRPC": 'https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2Fmrpc_dev_ids.tsv?alt=media&token=ec5c0836-31d5-48f4-b431-7480817f1adc', 30 | "QQP": 'https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FQQP.zip?alt=media&token=700c6acf-160d-4d89-81d1-de4191d02cb5', 31 | "STS": 'https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FSTS-B.zip?alt=media&token=bddb94a7-8706-4e0d-a694-1109e12273b5', 32 | "MNLI": 'https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FMNLI.zip?alt=media&token=50329ea1-e339-40e2-809c-10c40afff3ce', 33 | "SNLI": 
'https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FSNLI.zip?alt=media&token=4afcfbb2-ff0c-4b2d-a09a-dbf07926f4df', 34 | "QNLI": 'https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FQNLIv2.zip?alt=media&token=6fdcf570-0fc5-4631-8456-9505272d1601', 35 | "RTE": 'https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FRTE.zip?alt=media&token=5efa7e85-a0bb-4f19-8ea2-9e1840f077fb', 36 | "WNLI": 'https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FWNLI.zip?alt=media&token=068ad0a0-ded7-4bd7-99a5-5e00222e0faf', 37 | "diagnostic": 'https://storage.googleapis.com/mtl-sentence-representations.appspot.com/tsvsWithoutLabels%2FAX.tsv?GoogleAccessId=firebase-adminsdk-0khhl@mtl-sentence-representations.iam.gserviceaccount.com&Expires=2498860800&Signature=DuQ2CSPt2Yfre0C%2BiISrVYrIFaZH1Lc7hBVZDD4ZyR7fZYOMNOUGpi8QxBmTNOrNPjR3z1cggo7WXFfrgECP6FBJSsURv8Ybrue8Ypt%2FTPxbuJ0Xc2FhDi%2BarnecCBFO77RSbfuz%2Bs95hRrYhTnByqu3U%2FYZPaj3tZt5QdfpH2IUROY8LiBXoXS46LE%2FgOQc%2FKN%2BA9SoscRDYsnxHfG0IjXGwHN%2Bf88q6hOmAxeNPx6moDulUF6XMUAaXCSFU%2BnRO2RDL9CapWxj%2BDl7syNyHhB7987hZ80B%2FwFkQ3MEs8auvt5XW1%2Bd4aCU7ytgM69r8JDCwibfhZxpaa4gd50QXQ%3D%3D'} 38 | 39 | MRPC_TRAIN = 'https://dl.fbaipublicfiles.com/senteval/senteval_data/msr_paraphrase_train.txt' 40 | MRPC_TEST = 'https://dl.fbaipublicfiles.com/senteval/senteval_data/msr_paraphrase_test.txt' 41 | 42 | 43 | def download_and_extract(task, data_dir): 44 | print("Downloading and extracting %s..." 
% task) 45 | data_file = "%s.zip" % task 46 | urllib.request.urlretrieve(TASK2PATH[task], data_file) 47 | with zipfile.ZipFile(data_file) as zip_ref: 48 | zip_ref.extractall(data_dir) 49 | os.remove(data_file) 50 | print("\tCompleted!") 51 | 52 | 53 | def format_mrpc(data_dir, path_to_data): 54 | print("Processing MRPC...") 55 | mrpc_dir = os.path.join(data_dir, "MRPC") 56 | if not os.path.isdir(mrpc_dir): 57 | os.mkdir(mrpc_dir) 58 | if path_to_data: 59 | mrpc_train_file = os.path.join(path_to_data, "msr_paraphrase_train.txt") 60 | mrpc_test_file = os.path.join(path_to_data, "msr_paraphrase_test.txt") 61 | else: 62 | print("Local MRPC data not specified, downloading data from %s" % MRPC_TRAIN) 63 | mrpc_train_file = os.path.join(mrpc_dir, "msr_paraphrase_train.txt") 64 | mrpc_test_file = os.path.join(mrpc_dir, "msr_paraphrase_test.txt") 65 | urllib.request.urlretrieve(MRPC_TRAIN, mrpc_train_file) 66 | urllib.request.urlretrieve(MRPC_TEST, mrpc_test_file) 67 | assert os.path.isfile(mrpc_train_file), "Train data not found at %s" % mrpc_train_file 68 | assert os.path.isfile(mrpc_test_file), "Test data not found at %s" % mrpc_test_file 69 | urllib.request.urlretrieve(TASK2PATH["MRPC"], os.path.join(mrpc_dir, "dev_ids.tsv")) 70 | 71 | dev_ids = [] 72 | with open(os.path.join(mrpc_dir, "dev_ids.tsv"), encoding="utf8") as ids_fh: 73 | for row in ids_fh: 74 | dev_ids.append(row.strip().split('\t')) 75 | 76 | with open(mrpc_train_file, encoding="utf8") as data_fh, \ 77 | open(os.path.join(mrpc_dir, "train.tsv"), 'w', encoding="utf8") as train_fh, \ 78 | open(os.path.join(mrpc_dir, "dev.tsv"), 'w', encoding="utf8") as dev_fh: 79 | header = data_fh.readline() 80 | train_fh.write(header) 81 | dev_fh.write(header) 82 | for row in data_fh: 83 | label, id1, id2, s1, s2 = row.strip().split('\t') 84 | if [id1, id2] in dev_ids: 85 | dev_fh.write("%s\t%s\t%s\t%s\t%s\n" % (label, id1, id2, s1, s2)) 86 | else: 87 | train_fh.write("%s\t%s\t%s\t%s\t%s\n" % (label, id1, id2, s1, s2)) 88 
| 89 | with open(mrpc_test_file, encoding="utf8") as data_fh, \ 90 | open(os.path.join(mrpc_dir, "test.tsv"), 'w', encoding="utf8") as test_fh: 91 | header = data_fh.readline() 92 | test_fh.write("index\t#1 ID\t#2 ID\t#1 String\t#2 String\n") 93 | for idx, row in enumerate(data_fh): 94 | label, id1, id2, s1, s2 = row.strip().split('\t') 95 | test_fh.write("%d\t%s\t%s\t%s\t%s\n" % (idx, id1, id2, s1, s2)) 96 | print("\tCompleted!") 97 | 98 | 99 | def download_diagnostic(data_dir): 100 | print("Downloading and extracting diagnostic...") 101 | if not os.path.isdir(os.path.join(data_dir, "diagnostic")): 102 | os.mkdir(os.path.join(data_dir, "diagnostic")) 103 | data_file = os.path.join(data_dir, "diagnostic", "diagnostic.tsv") 104 | urllib.request.urlretrieve(TASK2PATH["diagnostic"], data_file) 105 | print("\tCompleted!") 106 | return 107 | 108 | 109 | def get_tasks(task_names): 110 | task_names = task_names.split(',') 111 | if "all" in task_names: 112 | tasks = TASKS 113 | else: 114 | tasks = [] 115 | for task_name in task_names: 116 | assert task_name in TASKS, "Task %s not found!" 
% task_name 117 | tasks.append(task_name) 118 | return tasks 119 | 120 | 121 | def main(arguments): 122 | parser = argparse.ArgumentParser() 123 | parser.add_argument('--data_dir', help='directory to save data to', type=str, default='glue_data') 124 | parser.add_argument('--tasks', help='tasks to download data for as a comma separated string', 125 | type=str, default='all') 126 | parser.add_argument('--path_to_mrpc', 127 | help='path to directory containing extracted MRPC data, msr_paraphrase_train.txt and msr_paraphrase_text.txt', 128 | type=str, default='') 129 | args = parser.parse_args(arguments) 130 | 131 | if not os.path.isdir(args.data_dir): 132 | os.mkdir(args.data_dir) 133 | tasks = get_tasks(args.tasks) 134 | 135 | for task in tasks: 136 | if task == 'MRPC': 137 | format_mrpc(args.data_dir, args.path_to_mrpc) 138 | elif task == 'diagnostic': 139 | download_diagnostic(args.data_dir) 140 | else: 141 | download_and_extract(task, args.data_dir) 142 | 143 | 144 | if __name__ == '__main__': 145 | sys.exit(main(sys.argv[1:])) 146 | -------------------------------------------------------------------------------- /run_glue.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. 3 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 
16 | """ Finetuning the library models for sequence classification on GLUE (Bert, XLM, XLNet, RoBERTa).""" 17 | 18 | from __future__ import absolute_import, division, print_function 19 | 20 | import argparse 21 | import glob 22 | import logging 23 | import os 24 | import random 25 | 26 | import numpy as np 27 | import torch 28 | from torch.utils.data import (DataLoader, RandomSampler, SequentialSampler, 29 | TensorDataset) 30 | from torch.utils.data.distributed import DistributedSampler 31 | from tensorboardX import SummaryWriter 32 | from tqdm import tqdm, trange 33 | 34 | from pytorch_transformers import (WEIGHTS_NAME, BertConfig, 35 | BertForSequenceClassification, BertTokenizer, 36 | RobertaConfig, 37 | RobertaForSequenceClassification, 38 | RobertaTokenizer, 39 | XLMConfig, XLMForSequenceClassification, 40 | XLMTokenizer, XLNetConfig, 41 | XLNetForSequenceClassification, 42 | XLNetTokenizer) 43 | 44 | from pytorch_transformers import AdamW, WarmupLinearSchedule 45 | 46 | from utils_glue import (compute_metrics, convert_examples_to_features, 47 | output_modes, processors) 48 | 49 | logger = logging.getLogger(__name__) 50 | 51 | ALL_MODELS = sum( 52 | (tuple(conf.pretrained_config_archive_map.keys()) for conf in (BertConfig, XLNetConfig, XLMConfig, RobertaConfig)), 53 | ()) 54 | 55 | MODEL_CLASSES = { 56 | 'bert': (BertConfig, BertForSequenceClassification, BertTokenizer), 57 | 'xlnet': (XLNetConfig, XLNetForSequenceClassification, XLNetTokenizer), 58 | 'xlm': (XLMConfig, XLMForSequenceClassification, XLMTokenizer), 59 | 'roberta': (RobertaConfig, RobertaForSequenceClassification, RobertaTokenizer), 60 | } 61 | 62 | 63 | def set_seed(args): 64 | random.seed(args.seed) 65 | np.random.seed(args.seed) 66 | torch.manual_seed(args.seed) 67 | if args.n_gpu > 0: 68 | torch.cuda.manual_seed_all(args.seed) 69 | 70 | 71 | def train(args, train_dataset, model, tokenizer): 72 | """ Train the model """ 73 | if args.local_rank in [-1, 0]: 74 | tb_writer = SummaryWriter() 75 | 
76 | args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu) 77 | train_sampler = RandomSampler(train_dataset) if args.local_rank == -1 else DistributedSampler(train_dataset) 78 | train_dataloader = DataLoader(train_dataset, sampler=train_sampler, batch_size=args.train_batch_size) 79 | 80 | if args.max_steps > 0: 81 | t_total = args.max_steps 82 | args.num_train_epochs = args.max_steps // (len(train_dataloader) // args.gradient_accumulation_steps) + 1 83 | else: 84 | t_total = len(train_dataloader) // args.gradient_accumulation_steps * args.num_train_epochs 85 | 86 | # Prepare optimizer and schedule (linear warmup and decay) 87 | no_decay = ['bias', 'LayerNorm.weight'] 88 | optimizer_grouped_parameters = [ 89 | {'params': [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)], 90 | 'weight_decay': args.weight_decay}, 91 | {'params': [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0} 92 | ] 93 | optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon) 94 | scheduler = WarmupLinearSchedule(optimizer, warmup_steps=args.warmup_steps, t_total=t_total) 95 | if args.fp16: 96 | try: 97 | from apex import amp 98 | except ImportError: 99 | raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.") 100 | model, optimizer = amp.initialize(model, optimizer, opt_level=args.fp16_opt_level) 101 | 102 | # multi-gpu training (should be after apex fp16 initialization) 103 | if args.n_gpu > 1: 104 | model = torch.nn.DataParallel(model) 105 | 106 | # Distributed training (should be after apex fp16 initialization) 107 | if args.local_rank != -1: 108 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank], 109 | output_device=args.local_rank, 110 | find_unused_parameters=True) 111 | 112 | # Train! 
113 | logger.info("***** Running training *****") 114 | logger.info(" Num examples = %d", len(train_dataset)) 115 | logger.info(" Num Epochs = %d", args.num_train_epochs) 116 | logger.info(" Instantaneous batch size per GPU = %d", args.per_gpu_train_batch_size) 117 | logger.info(" Total train batch size (w. parallel, distributed & accumulation) = %d", 118 | args.train_batch_size * args.gradient_accumulation_steps * ( 119 | torch.distributed.get_world_size() if args.local_rank != -1 else 1)) 120 | logger.info(" Gradient Accumulation steps = %d", args.gradient_accumulation_steps) 121 | logger.info(" Total optimization steps = %d", t_total) 122 | 123 | global_step = 0 124 | tr_loss, logging_loss = 0.0, 0.0 125 | model.zero_grad() 126 | train_iterator = trange(int(args.num_train_epochs), desc="Epoch", disable=args.local_rank not in [-1, 0]) 127 | set_seed(args) # Added here for reproductibility (even between python 2 and 3) 128 | for _ in train_iterator: 129 | epoch_iterator = tqdm(train_dataloader, desc="Iteration", disable=args.local_rank not in [-1, 0]) 130 | for step, batch in enumerate(epoch_iterator): 131 | model.train() 132 | batch = tuple(t.to(args.device) for t in batch) 133 | inputs = {'input_ids': batch[0], 134 | 'attention_mask': batch[1], 135 | 'token_type_ids': batch[2] if args.model_type in ['bert', 'xlnet'] else None, 136 | # XLM don't use segment_ids 137 | 'labels': batch[3]} 138 | outputs = model(**inputs) 139 | loss = outputs[0] # model outputs are always tuple in pytorch-transformers (see doc) 140 | 141 | if args.n_gpu > 1: 142 | loss = loss.mean() # mean() to average on multi-gpu parallel training 143 | if args.gradient_accumulation_steps > 1: 144 | loss = loss / args.gradient_accumulation_steps 145 | 146 | if args.fp16: 147 | with amp.scale_loss(loss, optimizer) as scaled_loss: 148 | scaled_loss.backward() 149 | torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), args.max_grad_norm) 150 | else: 151 | loss.backward() 152 | 
torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm) 153 | 154 | tr_loss += loss.item() 155 | if (step + 1) % args.gradient_accumulation_steps == 0: 156 | scheduler.step() # Update learning rate schedule 157 | optimizer.step() 158 | model.zero_grad() 159 | global_step += 1 160 | 161 | if args.local_rank in [-1, 0] and args.logging_steps > 0 and global_step % args.logging_steps == 0: 162 | # Log metrics 163 | if args.local_rank == -1 and args.evaluate_during_training: # Only evaluate when single GPU otherwise metrics may not average well 164 | results = evaluate(args, model, tokenizer) 165 | for key, value in results.items(): 166 | tb_writer.add_scalar('eval_{}'.format(key), value, global_step) 167 | tb_writer.add_scalar('lr', scheduler.get_lr()[0], global_step) 168 | tb_writer.add_scalar('loss', (tr_loss - logging_loss) / args.logging_steps, global_step) 169 | logging_loss = tr_loss 170 | 171 | if args.local_rank in [-1, 0] and args.save_steps > 0 and global_step % args.save_steps == 0: 172 | # Save model checkpoint 173 | output_dir = os.path.join(args.output_dir, 'checkpoint-{}'.format(global_step)) 174 | if not os.path.exists(output_dir): 175 | os.makedirs(output_dir) 176 | model_to_save = model.module if hasattr(model, 177 | 'module') else model # Take care of distributed/parallel training 178 | model_to_save.save_pretrained(output_dir) 179 | torch.save(args, os.path.join(output_dir, 'training_args.bin')) 180 | logger.info("Saving model checkpoint to %s", output_dir) 181 | 182 | if args.max_steps > 0 and global_step > args.max_steps: 183 | epoch_iterator.close() 184 | break 185 | if args.max_steps > 0 and global_step > args.max_steps: 186 | train_iterator.close() 187 | break 188 | 189 | if args.local_rank in [-1, 0]: 190 | tb_writer.close() 191 | 192 | return global_step, tr_loss / global_step 193 | 194 | 195 | def evaluate(args, model, tokenizer, prefix=""): 196 | # Loop to handle MNLI double evaluation (matched, mis-matched) 197 | 
eval_task_names = ("mnli", "mnli-mm") if args.task_name == "mnli" else (args.task_name,) 198 | eval_outputs_dirs = (args.output_dir, args.output_dir + '-MM') if args.task_name == "mnli" else (args.output_dir,) 199 | 200 | results = {} 201 | for eval_task, eval_output_dir in zip(eval_task_names, eval_outputs_dirs): 202 | eval_dataset = load_and_cache_examples(args, eval_task, tokenizer, evaluate=True) 203 | 204 | if not os.path.exists(eval_output_dir) and args.local_rank in [-1, 0]: 205 | os.makedirs(eval_output_dir) 206 | 207 | args.eval_batch_size = args.per_gpu_eval_batch_size * max(1, args.n_gpu) 208 | # Note that DistributedSampler samples randomly 209 | eval_sampler = SequentialSampler(eval_dataset) if args.local_rank == -1 else DistributedSampler(eval_dataset) 210 | eval_dataloader = DataLoader(eval_dataset, sampler=eval_sampler, batch_size=args.eval_batch_size) 211 | 212 | # Eval! 213 | logger.info("***** Running evaluation {} *****".format(prefix)) 214 | logger.info(" Num examples = %d", len(eval_dataset)) 215 | logger.info(" Batch size = %d", args.eval_batch_size) 216 | eval_loss = 0.0 217 | nb_eval_steps = 0 218 | preds = None 219 | out_label_ids = None 220 | for batch in tqdm(eval_dataloader, desc="Evaluating"): 221 | model.eval() 222 | batch = tuple(t.to(args.device) for t in batch) 223 | 224 | with torch.no_grad(): 225 | inputs = {'input_ids': batch[0], 226 | 'attention_mask': batch[1], 227 | 'token_type_ids': batch[2] if args.model_type in ['bert', 'xlnet'] else None, 228 | # XLM and RoBERTa don't use segment_ids 229 | 'labels': batch[3]} 230 | outputs = model(**inputs) 231 | tmp_eval_loss, logits = outputs[:2] 232 | 233 | eval_loss += tmp_eval_loss.mean().item() 234 | nb_eval_steps += 1 235 | if preds is None: 236 | preds = logits.detach().cpu().numpy() 237 | out_label_ids = inputs['labels'].detach().cpu().numpy() 238 | else: 239 | preds = np.append(preds, logits.detach().cpu().numpy(), axis=0) 240 | out_label_ids = np.append(out_label_ids, 
inputs['labels'].detach().cpu().numpy(), axis=0) 241 | 242 | eval_loss = eval_loss / nb_eval_steps 243 | if args.output_mode == "classification": 244 | preds = np.argmax(preds, axis=1) 245 | elif args.output_mode == "regression": 246 | preds = np.squeeze(preds) 247 | result = compute_metrics(eval_task, preds, out_label_ids) 248 | results.update(result) 249 | 250 | output_eval_file = os.path.join(eval_output_dir, "eval_results.txt") 251 | with open(output_eval_file, "w") as writer: 252 | logger.info("***** Eval results {} *****".format(prefix)) 253 | for key in sorted(result.keys()): 254 | logger.info(" %s = %s", key, str(result[key])) 255 | writer.write("%s = %s\n" % (key, str(result[key]))) 256 | 257 | return results 258 | 259 | 260 | def load_and_cache_examples(args, task, tokenizer, evaluate=False): 261 | if args.local_rank not in [-1, 0] and not evaluate: 262 | torch.distributed.barrier() # Make sure only the first process in distributed training process the dataset, and the others will use the cache 263 | 264 | processor = processors[task]() 265 | output_mode = output_modes[task] 266 | # Load data features from cache or dataset file 267 | cached_features_file = os.path.join(args.data_dir, 'cached_{}_{}_{}_{}'.format( 268 | 'dev' if evaluate else 'train', 269 | list(filter(None, args.model_name_or_path.split('/'))).pop(), 270 | str(args.max_seq_length), 271 | str(task))) 272 | if os.path.exists(cached_features_file): 273 | logger.info("Loading features from cached file %s", cached_features_file) 274 | features = torch.load(cached_features_file) 275 | else: 276 | logger.info("Creating features from dataset file at %s", args.data_dir) 277 | label_list = processor.get_labels() 278 | if task in ['mnli', 'mnli-mm'] and args.model_type in ['roberta']: 279 | # HACK(label indices are swapped in RoBERTa pretrained model) 280 | label_list[1], label_list[2] = label_list[2], label_list[1] 281 | examples = processor.get_dev_examples(args.data_dir) if evaluate else 
processor.get_train_examples( 282 | args.data_dir) 283 | features = convert_examples_to_features(examples, label_list, args.max_seq_length, tokenizer, output_mode, 284 | cls_token_at_end=bool(args.model_type in ['xlnet']), 285 | # xlnet has a cls token at the end 286 | cls_token=tokenizer.cls_token, 287 | cls_token_segment_id=2 if args.model_type in ['xlnet'] else 0, 288 | sep_token=tokenizer.sep_token, 289 | sep_token_extra=bool(args.model_type in ['roberta']), 290 | # roberta uses an extra separator b/w pairs of sentences, cf. github.com/pytorch/fairseq/commit/1684e166e3da03f5b600dbb7855cb98ddfcd0805 291 | pad_on_left=bool(args.model_type in ['xlnet']), 292 | # pad on the left for xlnet 293 | pad_token=tokenizer.convert_tokens_to_ids([tokenizer.pad_token])[0], 294 | pad_token_segment_id=4 if args.model_type in ['xlnet'] else 0, 295 | ) 296 | if args.local_rank in [-1, 0]: 297 | logger.info("Saving features into cached file %s", cached_features_file) 298 | torch.save(features, cached_features_file) 299 | 300 | if args.local_rank == 0 and not evaluate: 301 | torch.distributed.barrier() # Make sure only the first process in distributed training process the dataset, and the others will use the cache 302 | 303 | # Convert to Tensors and build dataset 304 | all_input_ids = torch.tensor([f.input_ids for f in features], dtype=torch.long) 305 | all_input_mask = torch.tensor([f.input_mask for f in features], dtype=torch.long) 306 | all_segment_ids = torch.tensor([f.segment_ids for f in features], dtype=torch.long) 307 | if output_mode == "classification": 308 | all_label_ids = torch.tensor([f.label_id for f in features], dtype=torch.long) 309 | elif output_mode == "regression": 310 | all_label_ids = torch.tensor([f.label_id for f in features], dtype=torch.float) 311 | 312 | dataset = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_label_ids) 313 | return dataset 314 | 315 | 316 | def main(): 317 | parser = argparse.ArgumentParser() 318 | 319 | ## Required 
parameters 320 | parser.add_argument("--data_dir", default=None, type=str, required=True, 321 | help="The input data dir. Should contain the .tsv files (or other data files) for the task.") 322 | parser.add_argument("--model_type", default=None, type=str, required=True, 323 | help="Model type selected in the list: " + ", ".join(MODEL_CLASSES.keys())) 324 | parser.add_argument("--model_name_or_path", default=None, type=str, required=True, 325 | help="Path to pre-trained model or shortcut name selected in the list: " + ", ".join( 326 | ALL_MODELS)) 327 | parser.add_argument("--task_name", default=None, type=str, required=True, 328 | help="The name of the task to train selected in the list: " + ", ".join(processors.keys())) 329 | parser.add_argument("--output_dir", default=None, type=str, required=True, 330 | help="The output directory where the model predictions and checkpoints will be written.") 331 | 332 | ## Other parameters 333 | parser.add_argument("--config_name", default="", type=str, 334 | help="Pretrained config name or path if not the same as model_name") 335 | parser.add_argument("--tokenizer_name", default="", type=str, 336 | help="Pretrained tokenizer name or path if not the same as model_name") 337 | parser.add_argument("--cache_dir", default="", type=str, 338 | help="Where do you want to store the pre-trained models downloaded from s3") 339 | parser.add_argument("--max_seq_length", default=128, type=int, 340 | help="The maximum total input sequence length after tokenization. 
Sequences longer " 341 | "than this will be truncated, sequences shorter will be padded.") 342 | parser.add_argument("--do_train", action='store_true', 343 | help="Whether to run training.") 344 | parser.add_argument("--do_eval", action='store_true', 345 | help="Whether to run eval on the dev set.") 346 | parser.add_argument("--evaluate_during_training", action='store_true', 347 | help="Rul evaluation during training at each logging step.") 348 | parser.add_argument("--do_lower_case", action='store_true', 349 | help="Set this flag if you are using an uncased model.") 350 | 351 | parser.add_argument("--per_gpu_train_batch_size", default=8, type=int, 352 | help="Batch size per GPU/CPU for training.") 353 | parser.add_argument("--per_gpu_eval_batch_size", default=8, type=int, 354 | help="Batch size per GPU/CPU for evaluation.") 355 | parser.add_argument('--gradient_accumulation_steps', type=int, default=1, 356 | help="Number of updates steps to accumulate before performing a backward/update pass.") 357 | parser.add_argument("--learning_rate", default=5e-5, type=float, 358 | help="The initial learning rate for Adam.") 359 | parser.add_argument("--weight_decay", default=0.0, type=float, 360 | help="Weight deay if we apply some.") 361 | parser.add_argument("--adam_epsilon", default=1e-8, type=float, 362 | help="Epsilon for Adam optimizer.") 363 | parser.add_argument("--max_grad_norm", default=1.0, type=float, 364 | help="Max gradient norm.") 365 | parser.add_argument("--num_train_epochs", default=3.0, type=float, 366 | help="Total number of training epochs to perform.") 367 | parser.add_argument("--max_steps", default=-1, type=int, 368 | help="If > 0: set total number of training steps to perform. 
Override num_train_epochs.") 369 | parser.add_argument("--warmup_steps", default=0, type=int, 370 | help="Linear warmup over warmup_steps.") 371 | 372 | parser.add_argument('--logging_steps', type=int, default=50, 373 | help="Log every X updates steps.") 374 | parser.add_argument('--save_steps', type=int, default=50, 375 | help="Save checkpoint every X updates steps.") 376 | parser.add_argument("--eval_all_checkpoints", action='store_true', 377 | help="Evaluate all checkpoints starting with the same prefix as model_name ending and ending with step number") 378 | parser.add_argument("--no_cuda", action='store_true', 379 | help="Avoid using CUDA when available") 380 | parser.add_argument('--overwrite_output_dir', action='store_true', 381 | help="Overwrite the content of the output directory") 382 | parser.add_argument('--overwrite_cache', action='store_true', 383 | help="Overwrite the cached training and evaluation sets") 384 | parser.add_argument('--seed', type=int, default=42, 385 | help="random seed for initialization") 386 | 387 | parser.add_argument('--fp16', action='store_true', 388 | help="Whether to use 16-bit (mixed) precision (through NVIDIA apex) instead of 32-bit") 389 | parser.add_argument('--fp16_opt_level', type=str, default='O1', 390 | help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']." 391 | "See details at https://nvidia.github.io/apex/amp.html") 392 | parser.add_argument("--local_rank", type=int, default=-1, 393 | help="For distributed training: local_rank") 394 | parser.add_argument('--server_ip', type=str, default='', help="For distant debugging.") 395 | parser.add_argument('--server_port', type=str, default='', help="For distant debugging.") 396 | args = parser.parse_args() 397 | 398 | if os.path.exists(args.output_dir) and os.listdir( 399 | args.output_dir) and args.do_train and not args.overwrite_output_dir: 400 | raise ValueError( 401 | "Output directory ({}) already exists and is not empty. 
Use --overwrite_output_dir to overcome.".format( 402 | args.output_dir)) 403 | 404 | # Setup distant debugging if needed 405 | if args.server_ip and args.server_port: 406 | # Distant debugging - see https://code.visualstudio.com/docs/python/debugging#_attach-to-a-local-script 407 | import ptvsd 408 | print("Waiting for debugger attach") 409 | ptvsd.enable_attach(address=(args.server_ip, args.server_port), redirect_output=True) 410 | ptvsd.wait_for_attach() 411 | 412 | # Setup CUDA, GPU & distributed training 413 | if args.local_rank == -1 or args.no_cuda: 414 | device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu") 415 | args.n_gpu = torch.cuda.device_count() 416 | else: # Initializes the distributed backend which will take care of sychronizing nodes/GPUs 417 | torch.cuda.set_device(args.local_rank) 418 | device = torch.device("cuda", args.local_rank) 419 | torch.distributed.init_process_group(backend='nccl') 420 | args.n_gpu = 1 421 | args.device = device 422 | 423 | # Setup logging 424 | logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s', 425 | datefmt='%m/%d/%Y %H:%M:%S', 426 | level=logging.INFO if args.local_rank in [-1, 0] else logging.WARN) 427 | logger.warning("Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s", 428 | args.local_rank, device, args.n_gpu, bool(args.local_rank != -1), args.fp16) 429 | 430 | # Set seed 431 | set_seed(args) 432 | 433 | # Prepare GLUE task 434 | args.task_name = args.task_name.lower() 435 | if args.task_name not in processors: 436 | raise ValueError("Task not found: %s" % (args.task_name)) 437 | processor = processors[args.task_name]() 438 | args.output_mode = output_modes[args.task_name] 439 | label_list = processor.get_labels() 440 | num_labels = len(label_list) 441 | 442 | # Load pretrained model and tokenizer 443 | if args.local_rank not in [-1, 0]: 444 | torch.distributed.barrier() # Make sure only the first process in 
distributed training will download model & vocab 445 | 446 | args.model_type = args.model_type.lower() 447 | config_class, model_class, tokenizer_class = MODEL_CLASSES[args.model_type] 448 | config = config_class.from_pretrained(args.config_name if args.config_name else args.model_name_or_path, 449 | num_labels=num_labels, finetuning_task=args.task_name) 450 | tokenizer = tokenizer_class.from_pretrained(args.tokenizer_name if args.tokenizer_name else args.model_name_or_path, 451 | do_lower_case=args.do_lower_case) 452 | model = model_class.from_pretrained(args.model_name_or_path, from_tf=bool('.ckpt' in args.model_name_or_path), 453 | config=config) 454 | 455 | if args.local_rank == 0: 456 | torch.distributed.barrier() # Make sure only the first process in distributed training will download model & vocab 457 | 458 | model.to(args.device) 459 | 460 | logger.info("Training/evaluation parameters %s", args) 461 | 462 | # Training 463 | if args.do_train: 464 | train_dataset = load_and_cache_examples(args, args.task_name, tokenizer, evaluate=False) 465 | global_step, tr_loss = train(args, train_dataset, model, tokenizer) 466 | logger.info(" global_step = %s, average loss = %s", global_step, tr_loss) 467 | 468 | # Saving best-practices: if you use defaults names for the model, you can reload it using from_pretrained() 469 | if args.do_train and (args.local_rank == -1 or torch.distributed.get_rank() == 0): 470 | # Create output directory if needed 471 | if not os.path.exists(args.output_dir) and args.local_rank in [-1, 0]: 472 | os.makedirs(args.output_dir) 473 | 474 | logger.info("Saving model checkpoint to %s", args.output_dir) 475 | # Save a trained model, configuration and tokenizer using `save_pretrained()`. 
476 | # They can then be reloaded using `from_pretrained()` 477 | model_to_save = model.module if hasattr(model, 478 | 'module') else model # Take care of distributed/parallel training 479 | model_to_save.save_pretrained(args.output_dir) 480 | tokenizer.save_pretrained(args.output_dir) 481 | 482 | # Good practice: save your training arguments together with the trained model 483 | torch.save(args, os.path.join(args.output_dir, 'training_args.bin')) 484 | 485 | # Load a trained model and vocabulary that you have fine-tuned 486 | model = model_class.from_pretrained(args.output_dir) 487 | tokenizer = tokenizer_class.from_pretrained(args.output_dir, do_lower_case=args.do_lower_case) 488 | model.to(args.device) 489 | 490 | # Evaluation 491 | results = {} 492 | if args.do_eval and args.local_rank in [-1, 0]: 493 | checkpoints = [args.output_dir] 494 | if args.eval_all_checkpoints: 495 | checkpoints = list( 496 | os.path.dirname(c) for c in sorted(glob.glob(args.output_dir + '/**/' + WEIGHTS_NAME, recursive=True))) 497 | logging.getLogger("pytorch_transformers.modeling_utils").setLevel(logging.WARN) # Reduce logging 498 | logger.info("Evaluate the following checkpoints: %s", checkpoints) 499 | for checkpoint in checkpoints: 500 | global_step = checkpoint.split('-')[-1] if len(checkpoints) > 1 else "" 501 | model = model_class.from_pretrained(checkpoint) 502 | model.to(args.device) 503 | result = evaluate(args, model, tokenizer, prefix=global_step) 504 | result = dict((k + '_{}'.format(global_step), v) for k, v in result.items()) 505 | results.update(result) 506 | 507 | return results 508 | 509 | 510 | if __name__ == "__main__": 511 | main() 512 | -------------------------------------------------------------------------------- /run_glue_distillation.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. 
3 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | """ Finetuning the library models for sequence classification on GLUE (Bert, XLM, XLNet).""" 17 | 18 | from __future__ import absolute_import, division, print_function 19 | 20 | import argparse 21 | import glob 22 | import logging 23 | import os 24 | import random 25 | import numpy as np 26 | import torch 27 | from torch.utils.data import (DataLoader, RandomSampler, SequentialSampler, 28 | TensorDataset) 29 | from torch.utils.data.distributed import DistributedSampler 30 | from tensorboardX import SummaryWriter 31 | 32 | from pytorch_transformers import (WEIGHTS_NAME, BertConfig, 33 | BertForSequenceClassification, BertTokenizer, 34 | RobertaConfig, 35 | RobertaForSequenceClassification, 36 | RobertaTokenizer, 37 | XLMConfig, XLMForSequenceClassification, 38 | XLMTokenizer, XLNetConfig, 39 | XLNetForSequenceClassification, 40 | XLNetTokenizer) 41 | 42 | from pytorch_transformers import AdamW, WarmupLinearSchedule 43 | 44 | from utils_glue import (compute_metrics, convert_examples_to_features, 45 | output_modes, processors) 46 | from distillation import PatientDistillation 47 | 48 | MODEL_CLASSES = { 49 | 'bert': (BertConfig, BertForSequenceClassification, BertTokenizer), 50 | 'xlnet': (XLNetConfig, XLNetForSequenceClassification, XLNetTokenizer), 51 | 'xlm': (XLMConfig, XLMForSequenceClassification, 
XLMTokenizer), 52 | 'roberta': (RobertaConfig, RobertaForSequenceClassification, RobertaTokenizer), 53 | } 54 | 55 | 56 | def clone_weights(first_module, second_module): 57 | for first_param, second_param in zip(first_module.parameters(), second_module.parameters()): 58 | first_param.data = torch.clone(second_param.data) 59 | 60 | 61 | logger = logging.getLogger(__name__) 62 | 63 | 64 | def set_seed(args): 65 | random.seed(args.seed) 66 | np.random.seed(args.seed) 67 | torch.manual_seed(args.seed) 68 | if args.n_gpu > 0: 69 | torch.cuda.manual_seed_all(args.seed) 70 | 71 | 72 | def train(args, train_dataset, t_model, s_model, order, d_criterion, tokenizer): 73 | """ Train the model """ 74 | if args.local_rank in [-1, 0]: 75 | tb_writer = SummaryWriter() 76 | 77 | args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu) 78 | train_sampler = RandomSampler(train_dataset) if args.local_rank == -1 else DistributedSampler(train_dataset) 79 | train_dataloader = DataLoader(train_dataset, sampler=train_sampler, batch_size=args.train_batch_size) 80 | 81 | if args.max_steps > 0: 82 | t_total = args.max_steps 83 | args.num_train_epochs = args.max_steps // (len(train_dataloader) // args.gradient_accumulation_steps) + 1 84 | else: 85 | t_total = len(train_dataloader) // args.gradient_accumulation_steps * args.num_train_epochs 86 | 87 | param_optimizer = list(s_model.named_parameters()) + list(d_criterion.named_parameters()) 88 | 89 | # Prepare optimizer and schedule (linear warmup and decay) 90 | no_decay = ['bias', 'LayerNorm.weight'] 91 | optimizer_grouped_parameters = [ 92 | {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 93 | 'weight_decay': args.weight_decay}, 94 | {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0} 95 | ] 96 | optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon) 97 | scheduler = WarmupLinearSchedule(optimizer, 
warmup_steps=args.warmup_steps, t_total=t_total) 98 | if args.fp16: 99 | try: 100 | from apex import amp 101 | except ImportError: 102 | raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.") 103 | model, optimizer = amp.initialize(s_model, optimizer, opt_level=args.fp16_opt_level) 104 | 105 | # Train! 106 | logger.info("***** Running training *****") 107 | logger.info(" Num examples = %d", len(train_dataset)) 108 | logger.info(" Num Epochs = %d", args.num_train_epochs) 109 | logger.info(" Instantaneous batch size per GPU = %d", args.per_gpu_train_batch_size) 110 | logger.info(" Total train batch size (w. parallel, distributed & accumulation) = %d", 111 | args.train_batch_size * args.gradient_accumulation_steps * ( 112 | torch.distributed.get_world_size() if args.local_rank != -1 else 1)) 113 | logger.info(" Gradient Accumulation steps = %d", args.gradient_accumulation_steps) 114 | logger.info(" Total optimization steps = %d", t_total) 115 | 116 | global_step = 0 117 | tr_loss = 0.0 118 | average_loss = 0.0 119 | train_avg_loss = 0.0 120 | soft_avg_loss = 0.0 121 | distill_avg_loss = 0.0 122 | 123 | s_model.zero_grad() 124 | train_iterator = range(int(args.num_train_epochs)) 125 | set_seed(args) # Added here for reproductibility (even between python 2 and 3) 126 | for epoch in train_iterator: 127 | for step, batch in enumerate(train_dataloader): 128 | s_model.train() 129 | t_model.eval() 130 | batch = tuple(t.to(args.device) for t in batch) 131 | input_ids, attention_mask, token_type_ids, labels = batch[0], batch[1], batch[2], batch[3] 132 | 133 | train_loss, soft_loss, distill_loss = d_criterion(t_model=t_model, 134 | s_model=s_model, 135 | order=order, 136 | input_ids=input_ids, 137 | token_type_ids=token_type_ids, 138 | attention_mask=attention_mask, 139 | labels=labels, 140 | args=args) 141 | 142 | loss = args.alpha * train_loss + (1 - args.alpha) * soft_loss + args.beta * distill_loss 143 | 144 | if args.n_gpu > 
1: 145 | loss = loss.mean() # mean() to average on multi-gpu parallel training 146 | train_loss = train_loss.mean() 147 | soft_loss = soft_loss.mean() 148 | distill_loss = distill_loss.mean() 149 | 150 | if args.gradient_accumulation_steps > 1: 151 | loss = loss / args.gradient_accumulation_steps 152 | train_loss = train_loss / args.gradient_accumulation_steps 153 | soft_loss = soft_loss / args.gradient_accumulation_steps 154 | distill_loss = distill_loss / args.gradient_accumulation_steps 155 | 156 | if args.fp16: 157 | with amp.scale_loss(loss, optimizer) as scaled_loss: 158 | scaled_loss.backward() 159 | torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), args.max_grad_norm) 160 | else: 161 | loss.backward() 162 | torch.nn.utils.clip_grad_norm_(s_model.parameters(), args.max_grad_norm) 163 | 164 | tr_loss += loss.item() 165 | average_loss += loss.item() 166 | train_avg_loss += train_loss.item() 167 | soft_avg_loss += soft_loss.item() 168 | distill_avg_loss += distill_loss.item() 169 | 170 | if (step + 1) % args.gradient_accumulation_steps == 0: 171 | optimizer.step() 172 | if args.schedule: 173 | scheduler.step() # Update learning rate schedule 174 | optimizer.zero_grad() 175 | global_step += 1 176 | 177 | if args.local_rank in [-1, 0] and args.logging_steps > 0 and global_step % args.logging_steps == 0: 178 | # Log metrics 179 | if args.local_rank == -1 and args.evaluate_during_training: # Only evaluate when single GPU otherwise metrics may not average well 180 | results = evaluate(args, s_model, tokenizer) 181 | for key, value in results.items(): 182 | tb_writer.add_scalar('eval_{}'.format(key), value, global_step) 183 | tb_writer.add_scalar('lr', scheduler.get_lr()[0], global_step) 184 | tb_writer.add_scalar('total loss', average_loss / args.logging_steps, global_step) 185 | tb_writer.add_scalar('train loss', train_avg_loss / args.logging_steps, global_step) 186 | tb_writer.add_scalar('soft loss', soft_avg_loss / args.logging_steps, global_step) 187 
| tb_writer.add_scalar('distill loss', distill_avg_loss / args.logging_steps, global_step) 188 | 189 | average_loss = 0.0 190 | train_avg_loss = 0.0 191 | soft_avg_loss = 0.0 192 | distill_avg_loss = 0.0 193 | 194 | # Save model checkpoint 195 | output_dir = os.path.join(args.output_dir, 'checkpoint-{}'.format(epoch + 1)) 196 | if not os.path.exists(output_dir): 197 | os.makedirs(output_dir) 198 | model_to_save = s_model.module if hasattr(s_model, 199 | 'module') else s_model # Take care of distributed/parallel training 200 | model_to_save.save_pretrained(output_dir) 201 | torch.save(args, os.path.join(output_dir, 'training_args.bin')) 202 | logger.info("Saving model checkpoint to %s", output_dir) 203 | 204 | return global_step, tr_loss / global_step 205 | 206 | 207 | def evaluate(args, model, tokenizer, prefix=""): 208 | # Loop to handle MNLI double evaluation (matched, mis-matched) 209 | eval_task_names = ("mnli", "mnli-mm") if args.task_name == "mnli" else (args.task_name,) 210 | eval_outputs_dirs = (args.output_dir, args.output_dir + '-MM') if args.task_name == "mnli" else (args.output_dir,) 211 | 212 | results = {} 213 | for eval_task, eval_output_dir in zip(eval_task_names, eval_outputs_dirs): 214 | eval_dataset = load_and_cache_examples(args, eval_task, tokenizer, evaluate=True) 215 | 216 | if not os.path.exists(eval_output_dir) and args.local_rank in [-1, 0]: 217 | os.makedirs(eval_output_dir) 218 | 219 | args.eval_batch_size = args.per_gpu_eval_batch_size * max(1, args.n_gpu) 220 | # Note that DistributedSampler samples randomly 221 | eval_sampler = SequentialSampler(eval_dataset) if args.local_rank == -1 else DistributedSampler(eval_dataset) 222 | eval_dataloader = DataLoader(eval_dataset, sampler=eval_sampler, batch_size=args.eval_batch_size) 223 | 224 | # Eval! 
225 | logger.info("***** Running evaluation {} *****".format(prefix)) 226 | logger.info(" Num examples = %d", len(eval_dataset)) 227 | logger.info(" Batch size = %d", args.eval_batch_size) 228 | eval_loss = 0.0 229 | nb_eval_steps = 0 230 | preds = None 231 | out_label_ids = None 232 | for batch in eval_dataloader: 233 | model.eval() 234 | batch = tuple(t.to(args.device) for t in batch) 235 | 236 | with torch.no_grad(): 237 | inputs = {'input_ids': batch[0], 238 | 'attention_mask': batch[1], 239 | 'token_type_ids': batch[2], 240 | 'labels': batch[3]} 241 | input_ids, attention_mask, token_type_ids, labels = batch[0], batch[1], batch[2], batch[3] 242 | outputs = model(**inputs) 243 | tmp_eval_loss, logits = outputs[:2] 244 | 245 | eval_loss += tmp_eval_loss.mean().item() 246 | nb_eval_steps += 1 247 | if preds is None: 248 | preds = logits.detach().cpu().numpy() 249 | out_label_ids = inputs['labels'].detach().cpu().numpy() 250 | else: 251 | preds = np.append(preds, logits.detach().cpu().numpy(), axis=0) 252 | out_label_ids = np.append(out_label_ids, inputs['labels'].detach().cpu().numpy(), axis=0) 253 | 254 | eval_loss = eval_loss / nb_eval_steps 255 | logging.info("eval_loss: %s", str(eval_loss)) 256 | if args.output_mode == "classification": 257 | preds = np.argmax(preds, axis=1) 258 | elif args.output_mode == "regression": 259 | preds = np.squeeze(preds) 260 | result = compute_metrics(eval_task, preds, out_label_ids) 261 | results.update(result) 262 | 263 | output_eval_file = os.path.join(eval_output_dir, "eval_results.txt") 264 | with open(output_eval_file, "w") as writer: 265 | logger.info("***** Eval results {} *****".format(prefix)) 266 | for key in sorted(result.keys()): 267 | logger.info(" %s = %s", key, str(result[key])) 268 | writer.write("%s = %s\n" % (key, str(result[key]))) 269 | 270 | return results 271 | 272 | 273 | def load_and_cache_examples(args, task, tokenizer, evaluate=False): 274 | processor = processors[task]() 275 | output_mode = 
output_modes[task] 276 | data_dir = args.data_dir 277 | 278 | logger.info("Creating features from dataset file at %s", data_dir) 279 | label_list = processor.get_labels() 280 | examples = processor.get_dev_examples(data_dir) if evaluate else processor.get_train_examples(data_dir) 281 | features = convert_examples_to_features(examples, label_list, args.max_seq_length, 282 | tokenizer, output_mode, 283 | cls_token_at_end=False, 284 | cls_token=tokenizer.cls_token, 285 | sep_token=tokenizer.sep_token, 286 | cls_token_segment_id=1, 287 | pad_on_left=False, 288 | pad_token_segment_id=0) 289 | 290 | # Convert to Tensors and build dataset 291 | all_input_ids = torch.tensor([f.input_ids for f in features], dtype=torch.long) 292 | all_input_mask = torch.tensor([f.input_mask for f in features], dtype=torch.long) 293 | all_segment_ids = torch.tensor([f.segment_ids for f in features], dtype=torch.long) 294 | if output_mode == "classification": 295 | all_label_ids = torch.tensor([f.label_id for f in features], dtype=torch.long) 296 | elif output_mode == "regression": 297 | all_label_ids = torch.tensor([f.label_id for f in features], dtype=torch.float) 298 | 299 | dataset = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_label_ids) 300 | return dataset 301 | 302 | 303 | def main(): 304 | parser = argparse.ArgumentParser() 305 | 306 | ## Required parameters 307 | parser.add_argument("--model_type", default='bert', type=str, 308 | help="Model type selected in the list: " + ", ".join(MODEL_CLASSES.keys())) 309 | parser.add_argument("--teacher_model", default="bert-base-uncased", type=str, 310 | help="Bert pre-trained model selected in the list: bert-base-uncased, " 311 | "bert-large-uncased, bert-base-cased, bert-base-multilingual, bert-base-chinese.") 312 | parser.add_argument("--student_model", default="bert-base-uncased", type=str, 313 | help="Bert pre-trained model selected in the list: bert-base-uncased, " 314 | "bert-large-uncased, bert-base-cased, 
bert-base-multilingual, bert-base-chinese.") 315 | parser.add_argument("--data_dir", default=None, type=str, required=True, 316 | help="The input data dir. Should contain the .tsv files (or other data files) for the task.") 317 | parser.add_argument("--log_dir", default='logs', type=str, help="The log data dir.") 318 | parser.add_argument("--task_name", default=None, type=str, required=True, 319 | help="The name of the task to train selected in the list: " + ", ".join(processors.keys())) 320 | parser.add_argument("--output_dir", default='tmp/', type=str, 321 | help="The output directory where the model predictions and checkpoints will be written.") 322 | parser.add_argument('--num_hidden_layers', default=6, type=int) 323 | parser.add_argument("--alpha", default=0.5, type=float, 324 | help="Train loss ratio.") 325 | parser.add_argument("--beta", default=100.0, type=float, 326 | help="Distillation loss ratio.") 327 | parser.add_argument("--temperature", default=5.0, type=float, 328 | help="Distillation temperature for soft target.") 329 | parser.add_argument("--select", default="skip", type=str) 330 | parser.add_argument("--schedule", default=False, action='store_true') 331 | 332 | ## Other parameters 333 | parser.add_argument("--tokenizer_name", default="", type=str, 334 | help="Pretrained tokenizer name or path if not the same as model_name") 335 | parser.add_argument("--cache_dir", default="", type=str, 336 | help="Where do you want to store the pre-trained models downloaded from s3") 337 | parser.add_argument("--max_seq_length", default=128, type=int, 338 | help="The maximum total input sequence length after tokenization. 
Sequences longer " 339 | "than this will be truncated, sequences shorter will be padded.") 340 | parser.add_argument("--do_train", action='store_true', 341 | help="Whether to run training.") 342 | parser.add_argument("--do_eval", action='store_true', 343 | help="Whether to run eval on the dev set.") 344 | parser.add_argument("--evaluate_during_training", action='store_true', 345 | help="Rul evaluation during training at each logging step.") 346 | parser.add_argument("--do_lower_case", action='store_true', 347 | help="Set this flag if you are using an uncased model.") 348 | 349 | parser.add_argument("--per_gpu_train_batch_size", default=8, type=int, 350 | help="Batch size per GPU/CPU for training.") 351 | parser.add_argument("--per_gpu_eval_batch_size", default=8, type=int, 352 | help="Batch size per GPU/CPU for evaluation.") 353 | parser.add_argument('--gradient_accumulation_steps', type=int, default=1, 354 | help="Number of updates steps to accumulate before performing a backward/update pass.") 355 | parser.add_argument("--learning_rate", default=5e-5, type=float, 356 | help="The initial learning rate for Adam.") 357 | parser.add_argument("--weight_decay", default=0.0, type=float, 358 | help="Weight deay if we apply some.") 359 | parser.add_argument("--adam_epsilon", default=1e-8, type=float, 360 | help="Epsilon for Adam optimizer.") 361 | parser.add_argument("--max_grad_norm", default=1.0, type=float, 362 | help="Max gradient norm.") 363 | parser.add_argument("--num_train_epochs", default=3.0, type=float, 364 | help="Total number of training epochs to perform.") 365 | parser.add_argument("--max_steps", default=-1, type=int, 366 | help="If > 0: set total number of training steps to perform. 
Override num_train_epochs.") 367 | parser.add_argument("--warmup_steps", default=0, type=int, 368 | help="Linear warmup over warmup_steps.") 369 | 370 | parser.add_argument('--logging_steps', type=int, default=50, 371 | help="Log every X updates steps.") 372 | parser.add_argument('--save_steps', type=int, default=1000, 373 | help="Save checkpoint every X updates steps.") 374 | parser.add_argument("--eval_all_checkpoints", action='store_true', 375 | help="Evaluate all checkpoints starting with the same prefix as model_name ending and ending with step number") 376 | parser.add_argument("--no_cuda", action='store_true', 377 | help="Avoid using CUDA when available") 378 | parser.add_argument('--overwrite_output_dir', action='store_true', 379 | help="Overwrite the content of the output directory") 380 | parser.add_argument('--overwrite_cache', action='store_true', 381 | help="Overwrite the cached training and evaluation sets") 382 | parser.add_argument('--seed', type=int, default=42, 383 | help="random seed for initialization") 384 | 385 | parser.add_argument('--fp16', action='store_true', 386 | help="Whether to use 16-bit (mixed) precision (through NVIDIA apex) instead of 32-bit") 387 | parser.add_argument('--fp16_opt_level', type=str, default='O1', 388 | help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']." 389 | "See details at https://nvidia.github.io/apex/amp.html") 390 | parser.add_argument("--local_rank", type=int, default=-1, 391 | help="For distributed training: local_rank") 392 | parser.add_argument('--server_ip', type=str, default='', help="For distant debugging.") 393 | parser.add_argument('--server_port', type=str, default='', help="For distant debugging.") 394 | args = parser.parse_args() 395 | 396 | if os.path.exists(args.output_dir) and os.listdir( 397 | args.output_dir) and args.do_train and not args.overwrite_output_dir: 398 | raise ValueError( 399 | "Output directory ({}) already exists and is not empty. 
Use --overwrite_output_dir to overcome.".format( 400 | args.output_dir)) 401 | 402 | # Setup distant debugging if needed 403 | if args.server_ip and args.server_port: 404 | # Distant debugging - see https://code.visualstudio.com/docs/python/debugging#_attach-to-a-local-script 405 | import ptvsd 406 | print("Waiting for debugger attach") 407 | ptvsd.enable_attach(address=(args.server_ip, args.server_port), redirect_output=True) 408 | ptvsd.wait_for_attach() 409 | 410 | # Setup CUDA, GPU & distributed training 411 | if args.local_rank == -1 or args.no_cuda: 412 | device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu") 413 | args.n_gpu = torch.cuda.device_count() 414 | else: # Initializes the distributed backend which will take care of sychronizing nodes/GPUs 415 | torch.cuda.set_device(args.local_rank) 416 | device = torch.device("cuda", args.local_rank) 417 | torch.distributed.init_process_group(backend='nccl') 418 | args.n_gpu = 1 419 | args.device = device 420 | 421 | # Setup logging 422 | logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s', 423 | datefmt='%m/%d/%Y %H:%M:%S', 424 | level=logging.INFO if args.local_rank in [-1, 0] else logging.WARN) 425 | logger.warning("Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s", 426 | args.local_rank, device, args.n_gpu, bool(args.local_rank != -1), args.fp16) 427 | 428 | # Set seed 429 | set_seed(args) 430 | 431 | # Prepare GLUE task 432 | args.task_name = args.task_name.lower() 433 | if args.task_name not in processors: 434 | raise ValueError("Task not found: %s" % (args.task_name)) 435 | processor = processors[args.task_name]() 436 | args.output_mode = output_modes[args.task_name] 437 | label_list = processor.get_labels() 438 | num_labels = len(label_list) 439 | 440 | # Load pretrained model and tokenizer 441 | if args.local_rank not in [-1, 0]: 442 | torch.distributed.barrier() # Make sure only the first process in 
distributed training will download model & vocab 443 | 444 | config_class, model_class, tokenizer_class = MODEL_CLASSES[args.model_type] 445 | tokenizer = tokenizer_class.from_pretrained(args.teacher_model, do_lower_case=args.do_lower_case) 446 | 447 | t_config = BertConfig.from_pretrained(args.teacher_model) 448 | t_config.num_labels = num_labels 449 | t_config.finetuning_task = args.task_name 450 | t_config.output_hidden_states = True 451 | t_model = model_class.from_pretrained(args.teacher_model, config=t_config) 452 | 453 | s_config = BertConfig.from_pretrained(args.student_model) 454 | s_config.num_hidden_layers = args.num_hidden_layers 455 | s_config.num_labels = num_labels 456 | s_config.finetuning_task = args.task_name 457 | s_config.output_hidden_states = True 458 | s_model = model_class.from_pretrained(args.student_model, config=s_config) 459 | 460 | d_criterion = PatientDistillation(t_config, s_config) 461 | 462 | if args.local_rank == 0: 463 | torch.distributed.barrier() # Make sure only the first process in distributed training will download model & vocab 464 | # Distributed and parallel training 465 | t_model.to(args.device) 466 | s_model.to(args.device) 467 | d_criterion.to(args.device) 468 | 469 | if args.local_rank != -1: 470 | t_model = torch.nn.parallel.DistributedDataParallel(t_model, device_ids=[args.local_rank], 471 | output_device=args.local_rank, 472 | find_unused_parameters=True) 473 | s_model = torch.nn.parallel.DistributedDataParallel(s_model, device_ids=[args.local_rank], 474 | output_device=args.local_rank, 475 | find_unused_parameters=True) 476 | 477 | elif args.n_gpu > 1: 478 | t_model = torch.nn.DataParallel(t_model) 479 | s_model = torch.nn.DataParallel(s_model) 480 | 481 | logger.info("Training/evaluation parameters %s", args) 482 | 483 | # Training 484 | if args.do_train: 485 | train_dataset = load_and_cache_examples(args, args.task_name, tokenizer, evaluate=False) 486 | 487 | if args.select == 'last': 488 | order = 
list(range(t_config.num_hidden_layers - 1)) 489 | order = torch.LongTensor(order[-(s_config.num_hidden_layers - 1):]) 490 | 491 | elif args.select == 'skip': 492 | order = list(range(t_config.num_hidden_layers - 1)) 493 | every_num = t_config.num_hidden_layers // s_config.num_hidden_layers 494 | order = torch.LongTensor(order[(every_num - 1)::every_num]) 495 | else: 496 | print('layer selection must be in [entropy, attn, dist, every]') 497 | order, _ = order[:(s_config.num_hidden_layers - 1)].sort() 498 | 499 | global_step, tr_loss = train(args, train_dataset, t_model, s_model, order, d_criterion, tokenizer) 500 | logger.info(" global_step = %s, average loss = %s", global_step, tr_loss) 501 | 502 | # Saving best-practices: if you use defaults names for the model, you can reload it using from_pretrained() 503 | if args.do_train and (args.local_rank == -1 or torch.distributed.get_rank() == 0): 504 | # Create output directory if needed 505 | if not os.path.exists(args.output_dir) and args.local_rank in [-1, 0]: 506 | os.makedirs(args.output_dir) 507 | 508 | logger.info("Saving model checkpoint to %s", args.output_dir) 509 | # Save a trained model, configuration and tokenizer using `save_pretrained()`. 
510 | # They can then be reloaded using `from_pretrained()` 511 | model_to_save = s_model.module if hasattr(s_model, 512 | 'module') else s_model # Take care of distributed/parallel training 513 | model_to_save.save_pretrained(args.output_dir) 514 | tokenizer.save_pretrained(args.output_dir) 515 | 516 | # Good practice: save your training arguments together with the trained model 517 | torch.save(args, os.path.join(args.output_dir, 'training_args.bin')) 518 | 519 | # Load a trained model and vocabulary that you have fine-tuned 520 | s_model = model_class.from_pretrained(args.output_dir) 521 | tokenizer = tokenizer_class.from_pretrained(args.output_dir) 522 | s_model.to(args.device) 523 | 524 | # Evaluation 525 | results = {} 526 | if args.do_eval and args.local_rank in [-1, 0]: 527 | checkpoints = [args.output_dir] 528 | if args.eval_all_checkpoints: 529 | checkpoints = list( 530 | os.path.dirname(c) for c in sorted(glob.glob(args.output_dir + '/**/' + WEIGHTS_NAME, recursive=True))) 531 | logging.getLogger("pytorch_transformers.modeling_utils").setLevel(logging.WARN) # Reduce logging 532 | logger.info("Evaluate the following checkpoints: %s", checkpoints) 533 | for checkpoint in checkpoints: 534 | epoch = checkpoint.split('-')[-1] if len(checkpoints) > 1 else "" 535 | s_model = model_class.from_pretrained(checkpoint) 536 | s_model.to(args.device) 537 | result = evaluate(args, s_model, tokenizer, prefix=epoch) 538 | result = dict((k + '_{}'.format(epoch), v) for k, v in result.items()) 539 | results.update(result) 540 | 541 | return results 542 | 543 | 544 | if __name__ == "__main__": 545 | main() 546 | -------------------------------------------------------------------------------- /utils_glue.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. 3 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. 
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" BERT classification fine-tuning: utilities to work with GLUE tasks """

from __future__ import absolute_import, division, print_function

import csv
import logging
import os
import sys
from io import open

# The metric back-ends are only used by the metric helpers at the bottom of
# this module. Guarding them keeps the data processors and
# ``convert_examples_to_features`` importable without scipy / scikit-learn;
# the metric helpers raise a clear ImportError if a back-end is missing.
try:
    from scipy.stats import pearsonr, spearmanr
except ImportError:
    pearsonr = spearmanr = None

try:
    from sklearn.metrics import matthews_corrcoef, f1_score
except ImportError:
    matthews_corrcoef = f1_score = None

logger = logging.getLogger(__name__)


class InputExample(object):
    """A single training/test example for simple sequence classification."""

    def __init__(self, guid, text_a, text_b=None, label=None):
        """Constructs a InputExample.

        Args:
            guid: Unique id for the example.
            text_a: string. The untokenized text of the first sequence. For single
            sequence tasks, only this sequence must be specified.
            text_b: (Optional) string. The untokenized text of the second sequence.
            Only must be specified for sequence pair tasks.
            label: (Optional) string. The label of the example. This should be
            specified for train and dev examples, but not for test examples.
        """
        self.guid = guid
        self.text_a = text_a
        self.text_b = text_b
        self.label = label


class InputFeatures(object):
    """A single set of features of data.

    Built by ``convert_examples_to_features``:
      input_ids:   token ids, padded to ``max_seq_length``.
      input_mask:  attention mask, same length as ``input_ids``.
      segment_ids: token type ids, same length as ``input_ids``.
      label_id:    int index into the label list (classification) or a
                   float (regression).
    """

    def __init__(self, input_ids, input_mask, segment_ids, label_id):
        self.input_ids = input_ids
        self.input_mask = input_mask
        self.segment_ids = segment_ids
        self.label_id = label_id


class DataProcessor(object):
    """Base class for data converters for sequence classification data sets."""

    def get_train_examples(self, data_dir):
        """Gets a collection of `InputExample`s for the train set."""
        raise NotImplementedError()

    def get_dev_examples(self, data_dir):
        """Gets a collection of `InputExample`s for the dev set."""
        raise NotImplementedError()

    def get_labels(self):
        """Gets the list of labels for this data set."""
        raise NotImplementedError()

    @classmethod
    def _read_tsv(cls, input_file, quotechar=None):
        """Reads a tab separated value file into a list of rows.

        Each row is a list of cell strings. The file is decoded as
        ``utf-8-sig`` so a leading byte-order mark is dropped. With the
        default ``quotechar=None`` quote characters are not treated
        specially by the csv reader.
        """
        with open(input_file, "r", encoding="utf-8-sig") as f:
            reader = csv.reader(f, delimiter="\t", quotechar=quotechar)
            lines = []
            for line in reader:
                if sys.version_info[0] == 2:
                    # Python 2 only: decode byte cells to unicode.
                    line = list(unicode(cell, 'utf-8') for cell in line)  # noqa: F821
                lines.append(line)
            return lines


class MrpcProcessor(DataProcessor):
    """Processor for the MRPC data set (GLUE version)."""

    def get_train_examples(self, data_dir):
        """See base class."""
        logger.info("LOOKING AT {}".format(os.path.join(data_dir, "train.tsv")))
        return self._create_examples(
            self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")

    def get_dev_examples(self, data_dir):
        """See base class."""
        return self._create_examples(
            self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")

    def get_labels(self):
        """See base class."""
        return ["0", "1"]

    def _create_examples(self, lines, set_type):
        """Creates examples for the training and dev sets.

        Skips the header row; label is column 0, the sentences are
        columns 3 and 4.
        """
        examples = []
        for (i, line) in enumerate(lines):
            if i == 0:
                continue
            guid = "%s-%s" % (set_type, i)
            text_a = line[3]
            text_b = line[4]
            label = line[0]
            examples.append(
                InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
        return examples


class MnliProcessor(DataProcessor):
    """Processor for the MultiNLI data set (GLUE version)."""

    def get_train_examples(self, data_dir):
        """See base class."""
        return self._create_examples(
            self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")

    def get_dev_examples(self, data_dir):
        """See base class."""
        return self._create_examples(
            self._read_tsv(os.path.join(data_dir, "dev_matched.tsv")),
            "dev_matched")

    def get_labels(self):
        """See base class."""
        return ["contradiction", "entailment", "neutral"]

    def _create_examples(self, lines, set_type):
        """Creates examples for the training and dev sets.

        Skips the header row; guid uses column 0, sentences are columns 8
        and 9, and the label is the last column.
        """
        examples = []
        for (i, line) in enumerate(lines):
            if i == 0:
                continue
            guid = "%s-%s" % (set_type, line[0])
            text_a = line[8]
            text_b = line[9]
            label = line[-1]
            examples.append(
                InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
        return examples


class MnliMismatchedProcessor(MnliProcessor):
    """Processor for the MultiNLI Mismatched data set (GLUE version)."""

    def get_dev_examples(self, data_dir):
        """See base class."""
        # Tag guids with the split actually read (was mislabeled "dev_matched").
        return self._create_examples(
            self._read_tsv(os.path.join(data_dir, "dev_mismatched.tsv")),
            "dev_mismatched")


class ColaProcessor(DataProcessor):
    """Processor for the CoLA data set (GLUE version)."""

    def get_train_examples(self, data_dir):
        """See base class."""
        return self._create_examples(
            self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")

    def get_dev_examples(self, data_dir):
        """See base class."""
        return self._create_examples(
            self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")

    def get_labels(self):
        """See base class."""
        return ["0", "1"]

    def _create_examples(self, lines, set_type):
        """Creates examples for the training and dev sets.

        Unlike the other processors, no header row is skipped here: every
        row is treated as data. Sentence is column 3, label column 1.
        """
        examples = []
        for (i, line) in enumerate(lines):
            guid = "%s-%s" % (set_type, i)
            text_a = line[3]
            label = line[1]
            examples.append(
                InputExample(guid=guid, text_a=text_a, text_b=None, label=label))
        return examples


class Sst2Processor(DataProcessor):
    """Processor for the SST-2 data set (GLUE version)."""

    def get_train_examples(self, data_dir):
        """See base class."""
        return self._create_examples(
            self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")

    def get_dev_examples(self, data_dir):
        """See base class."""
        return self._create_examples(
            self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")

    def get_labels(self):
        """See base class."""
        return ["0", "1"]

    def _create_examples(self, lines, set_type):
        """Creates examples for the training and dev sets.

        Skips the header row; sentence is column 0, label column 1.
        """
        examples = []
        for (i, line) in enumerate(lines):
            if i == 0:
                continue
            guid = "%s-%s" % (set_type, i)
            text_a = line[0]
            label = line[1]
            examples.append(
                InputExample(guid=guid, text_a=text_a, text_b=None, label=label))
        return examples


class StsbProcessor(DataProcessor):
    """Processor for the STS-B data set (GLUE version)."""

    def get_train_examples(self, data_dir):
        """See base class."""
        return self._create_examples(
            self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")

    def get_dev_examples(self, data_dir):
        """See base class."""
        return self._create_examples(
            self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")

    def get_labels(self):
        """See base class."""
        # Regression task: a single placeholder label.
        return [None]

    def _create_examples(self, lines, set_type):
        """Creates examples for the training and dev sets.

        Skips the header row; guid uses column 0, sentences are columns 7
        and 8, and the (string) score is the last column.
        """
        examples = []
        for (i, line) in enumerate(lines):
            if i == 0:
                continue
            guid = "%s-%s" % (set_type, line[0])
            text_a = line[7]
            text_b = line[8]
            label = line[-1]
            examples.append(
                InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
        return examples


class QqpProcessor(DataProcessor):
    """Processor for the QQP data set (GLUE version)."""

    def get_train_examples(self, data_dir):
        """See base class."""
        return self._create_examples(
            self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")

    def get_dev_examples(self, data_dir):
        """See base class."""
        return self._create_examples(
            self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")

    def get_labels(self):
        """See base class."""
        return ["0", "1"]

    def _create_examples(self, lines, set_type):
        """Creates examples for the training and dev sets.

        Skips the header row and any row that is too short to hold id,
        both questions and the label (including blank rows, which the csv
        reader yields as empty lists).
        """
        examples = []
        for (i, line) in enumerate(lines):
            if i == 0:
                continue
            try:
                # line[0] is read inside the try so blank rows are skipped
                # like other malformed rows instead of raising IndexError.
                guid = "%s-%s" % (set_type, line[0])
                text_a = line[3]
                text_b = line[4]
                label = line[5]
            except IndexError:
                continue
            examples.append(
                InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
        return examples


class QnliProcessor(DataProcessor):
    """Processor for the QNLI data set (GLUE version)."""

    def get_train_examples(self, data_dir):
        """See base class."""
        return self._create_examples(
            self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")

    def get_dev_examples(self, data_dir):
        """See base class."""
        # QNLI has a single dev split (was mislabeled "dev_matched").
        return self._create_examples(
            self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")

    def get_labels(self):
        """See base class."""
        return ["entailment", "not_entailment"]

    def _create_examples(self, lines, set_type):
        """Creates examples for the training and dev sets.

        Skips the header row; guid uses column 0, texts are columns 1 and
        2, and the label is the last column.
        """
        examples = []
        for (i, line) in enumerate(lines):
            if i == 0:
                continue
            guid = "%s-%s" % (set_type, line[0])
            text_a = line[1]
            text_b = line[2]
            label = line[-1]
            examples.append(
                InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
        return examples


class RteProcessor(DataProcessor):
    """Processor for the RTE data set (GLUE version)."""

    def get_train_examples(self, data_dir):
        """See base class."""
        return self._create_examples(
            self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")

    def get_dev_examples(self, data_dir):
        """See base class."""
        return self._create_examples(
            self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")

    def get_labels(self):
        """See base class."""
        return ["entailment", "not_entailment"]

    def _create_examples(self, lines, set_type):
        """Creates examples for the training and dev sets.

        Skips the header row; guid uses column 0, texts are columns 1 and
        2, and the label is the last column.
        """
        examples = []
        for (i, line) in enumerate(lines):
            if i == 0:
                continue
            guid = "%s-%s" % (set_type, line[0])
            text_a = line[1]
            text_b = line[2]
            label = line[-1]
            examples.append(
                InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
        return examples


class WnliProcessor(DataProcessor):
    """Processor for the WNLI data set (GLUE version)."""

    def get_train_examples(self, data_dir):
        """See base class."""
        return self._create_examples(
            self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")

    def get_dev_examples(self, data_dir):
        """See base class."""
        return self._create_examples(
            self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")

    def get_labels(self):
        """See base class."""
        return ["0", "1"]

    def _create_examples(self, lines, set_type):
        """Creates examples for the training and dev sets.

        Skips the header row; guid uses column 0, texts are columns 1 and
        2, and the label is the last column.
        """
        examples = []
        for (i, line) in enumerate(lines):
            if i == 0:
                continue
            guid = "%s-%s" % (set_type, line[0])
            text_a = line[1]
            text_b = line[2]
            label = line[-1]
            examples.append(
                InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
        return examples


def convert_examples_to_features(examples, label_list, max_seq_length,
                                 tokenizer, output_mode,
                                 cls_token_at_end=False,
                                 cls_token='[CLS]',
                                 cls_token_segment_id=1,
                                 sep_token='[SEP]',
                                 sep_token_extra=False,
                                 pad_on_left=False,
                                 pad_token=0,
                                 pad_token_segment_id=0,
                                 sequence_a_segment_id=0,
                                 sequence_b_segment_id=1,
                                 mask_padding_with_zero=True):
    """Convert ``InputExample``s into a list of ``InputFeatures``.

    Args:
        examples: iterable of ``InputExample`` (``len()`` is used for logging).
        label_list: labels of the task; a label's position is its id.
        max_seq_length: every feature is truncated/padded to this length.
        tokenizer: object providing ``tokenize(text)`` and
            ``convert_tokens_to_ids(tokens)``.
        output_mode: ``"classification"`` (label -> index in ``label_list``)
            or ``"regression"`` (label -> ``float``).
        cls_token_at_end: location of the CLS token:
            - False (default, BERT/XLM pattern): [CLS] + A + [SEP] + B + [SEP]
            - True (XLNet/GPT pattern): A + [SEP] + B + [SEP] + [CLS]
        cls_token_segment_id: segment id of the CLS token (0 for BERT, 2 for
            XLNet). NOTE: the default is 1, so BERT callers must pass 0.
        sep_token_extra: add a second separator after sequence A (RoBERTa
            style); one more slot is reserved for it when truncating.
        pad_on_left: pad at the start instead of the end of the sequence.
        mask_padding_with_zero: if True the mask is 1 for real tokens and 0
            for padding; if False the values are inverted.

    Returns:
        list of ``InputFeatures`` whose id/mask/segment lists all have length
        ``max_seq_length``.

    Raises:
        KeyError: for an unknown ``output_mode`` or a classification label
            that is not in ``label_list``.
    """

    label_map = {label: i for i, label in enumerate(label_list)}

    features = []
    for (ex_index, example) in enumerate(examples):
        if ex_index % 10000 == 0:
            logger.info("Writing example %d of %d", ex_index, len(examples))

        tokens_a = tokenizer.tokenize(example.text_a)

        tokens_b = None
        if example.text_b:
            tokens_b = tokenizer.tokenize(example.text_b)
            # Modifies `tokens_a` and `tokens_b` in place so that the total
            # length is less than the specified length.
            # Account for [CLS], [SEP], [SEP] with "- 3". " -4" for RoBERTa.
            special_tokens_count = 4 if sep_token_extra else 3
            _truncate_seq_pair(tokens_a, tokens_b, max_seq_length - special_tokens_count)
        else:
            # Account for [CLS] and [SEP] with "- 2" and with "- 3" for RoBERTa.
            special_tokens_count = 3 if sep_token_extra else 2
            if len(tokens_a) > max_seq_length - special_tokens_count:
                tokens_a = tokens_a[:(max_seq_length - special_tokens_count)]

        # The convention in BERT is:
        # (a) For sequence pairs:
        #  tokens:   [CLS] is this jack ##son ##ville ? [SEP] no it is not . [SEP]
        #  type_ids:   0   0  0    0    0     0       0   0   1  1  1  1   1   1
        # (b) For single sequences:
        #  tokens:   [CLS] the dog is hairy . [SEP]
        #  type_ids:   0   0   0   0  0     0   0
        #
        # "type_ids" indicate whether a token belongs to the first or the
        # second sequence. For classification tasks, the first vector
        # (corresponding to [CLS]) is used as the "sentence vector".
        tokens = tokens_a + [sep_token]
        if sep_token_extra:
            # roberta uses an extra separator b/w pairs of sentences
            tokens += [sep_token]
        segment_ids = [sequence_a_segment_id] * len(tokens)

        if tokens_b:
            tokens += tokens_b + [sep_token]
            segment_ids += [sequence_b_segment_id] * (len(tokens_b) + 1)

        if cls_token_at_end:
            tokens = tokens + [cls_token]
            segment_ids = segment_ids + [cls_token_segment_id]
        else:
            tokens = [cls_token] + tokens
            segment_ids = [cls_token_segment_id] + segment_ids

        input_ids = tokenizer.convert_tokens_to_ids(tokens)

        # The mask marks real tokens (1 when mask_padding_with_zero) versus
        # padding; only real tokens are attended to.
        input_mask = [1 if mask_padding_with_zero else 0] * len(input_ids)

        # Pad up to the sequence length.
        padding_length = max_seq_length - len(input_ids)
        if pad_on_left:
            input_ids = ([pad_token] * padding_length) + input_ids
            input_mask = ([0 if mask_padding_with_zero else 1] * padding_length) + input_mask
            segment_ids = ([pad_token_segment_id] * padding_length) + segment_ids
        else:
            input_ids = input_ids + ([pad_token] * padding_length)
            input_mask = input_mask + ([0 if mask_padding_with_zero else 1] * padding_length)
            segment_ids = segment_ids + ([pad_token_segment_id] * padding_length)

        assert len(input_ids) == max_seq_length
        assert len(input_mask) == max_seq_length
        assert len(segment_ids) == max_seq_length

        if output_mode == "classification":
            label_id = label_map[example.label]
        elif output_mode == "regression":
            label_id = float(example.label)
        else:
            raise KeyError(output_mode)

        if ex_index < 5:
            logger.info("*** Example ***")
            logger.info("guid: %s", example.guid)
            logger.info("tokens: %s", " ".join([str(x) for x in tokens]))
            logger.info("input_ids: %s", " ".join([str(x) for x in input_ids]))
            logger.info("input_mask: %s", " ".join([str(x) for x in input_mask]))
            logger.info("segment_ids: %s", " ".join([str(x) for x in segment_ids]))
            # %s (not %d) so regression labels are not truncated to ints.
            logger.info("label: %s (id = %s)", example.label, label_id)

        features.append(
            InputFeatures(input_ids=input_ids,
                          input_mask=input_mask,
                          segment_ids=segment_ids,
                          label_id=label_id))
    return features


def _truncate_seq_pair(tokens_a, tokens_b, max_length):
    """Truncates a sequence pair in place to the maximum length."""

    # This is a simple heuristic which will always truncate the longer sequence
    # one token at a time. This makes more sense than truncating an equal percent
    # of tokens from each, since if one sequence is very short then each token
    # that's truncated likely contains more information than a longer sequence.
    while True:
        total_length = len(tokens_a) + len(tokens_b)
        if total_length <= max_length:
            break
        if len(tokens_a) > len(tokens_b):
            tokens_a.pop()
        else:
            tokens_b.pop()


def _require(func, package):
    """Return ``func``, or raise ImportError if its ``package`` is missing."""
    if func is None:
        raise ImportError(
            "This GLUE metric requires `%s`; please install it." % package)
    return func


def simple_accuracy(preds, labels):
    """Fraction of equal entries; expects array-likes supporting ``==``/``.mean()``
    (e.g. numpy arrays)."""
    return (preds == labels).mean()


def acc_and_f1(preds, labels):
    """Accuracy, binary F1 (scikit-learn) and their mean."""
    acc = simple_accuracy(preds, labels)
    f1 = _require(f1_score, "scikit-learn")(y_true=labels, y_pred=preds)
    return {
        "acc": acc,
        "f1": f1,
        "acc_and_f1": (acc + f1) / 2,
    }


def pearson_and_spearman(preds, labels):
    """Pearson and Spearman correlations (scipy) and their mean."""
    pearson_corr = _require(pearsonr, "scipy")(preds, labels)[0]
    spearman_corr = _require(spearmanr, "scipy")(preds, labels)[0]
    return {
        "pearson": pearson_corr,
        "spearmanr": spearman_corr,
        "corr": (pearson_corr + spearman_corr) / 2,
    }


def compute_metrics(task_name, preds, labels):
    """Return the GLUE metric dict for ``task_name``.

    Raises:
        KeyError: if ``task_name`` is not a known GLUE task.
        ImportError: if the task's metric back-end (scipy / scikit-learn)
            is not installed.
    """
    assert len(preds) == len(labels)
    if task_name == "cola":
        mcc = _require(matthews_corrcoef, "scikit-learn")
        return {"mcc": mcc(labels, preds)}
    if task_name in ("mrpc", "qqp"):
        return acc_and_f1(preds, labels)
    if task_name == "sts-b":
        return pearson_and_spearman(preds, labels)
    if task_name in ("sst-2", "mnli", "mnli-mm", "qnli", "rte", "wnli"):
        return {"acc": simple_accuracy(preds, labels)}
    raise KeyError(task_name)


processors = {
    "cola": ColaProcessor,
    "mnli": MnliProcessor,
    "mnli-mm": MnliMismatchedProcessor,
    "mrpc": MrpcProcessor,
    "sst-2": Sst2Processor,
    "sts-b": StsbProcessor,
    "qqp": QqpProcessor,
    "qnli": QnliProcessor,
    "rte": RteProcessor,
    "wnli": WnliProcessor,
}

output_modes = {
    "cola": "classification",
    "mnli": "classification",
    "mnli-mm": "classification",
    "mrpc": "classification",
    "sst-2": "classification",
    "sts-b": "regression",
    "qqp": "classification",
    "qnli": "classification",
    "rte": "classification",
    "wnli": "classification",
}

GLUE_TASKS_NUM_LABELS = {
    "cola": 2,
    "mnli": 3,
    "mrpc": 2,
    "sst-2": 2,
    "sts-b": 1,
    "qqp": 2,
    "qnli": 2,
    "rte": 2,
    "wnli": 2,
}