├── README.md
├── data
│   └── ubuntu_data_example
│       ├── test.txt
│       ├── train.txt
│       └── valid.txt
├── nup_lm_finetuning
│   ├── finetune_on_pregenerated.py
│   ├── make_lm_data.py
│   └── pregenerate_training_data_NUP.py
├── pytorch_pretrained_bert
│   ├── __init__.py
│   ├── __pycache__
│   │   ├── __init__.cpython-36.pyc
│   │   ├── file_utils.cpython-36.pyc
│   │   ├── modeling.cpython-36.pyc
│   │   ├── optimization.cpython-36.pyc
│   │   └── tokenization.cpython-36.pyc
│   ├── file_utils.py
│   ├── modeling.py
│   ├── optimization.py
│   └── tokenization.py
├── run_FE_DAtt_RNN.py
├── run_IE_CoAtt_CNN_DCM.py
├── run_IE_DAtt_CNN.py
├── run_IE_DAtt_RNN.py
├── run_IE_MHAtt_CNN.py
└── run_bert.py

/README.md:
--------------------------------------------------------------------------------
1 | 
2 | 
3 | Code for the paper **Deep Context Modeling for Multi-turn Response Selection in Dialogue Systems**
4 | 
5 | ### Instructions
6 | 
7 | #### Dataset
8 | 
9 | Since the datasets are quite large and exceed the GitHub file size limit, we only upload part of the data as examples. Do not forget to change the data directory accordingly after you download the full data.
10 | 1. The datasets can be downloaded from [Ubuntu dataset](https://www.dropbox.com/s/2fdn26rj6h9bpvl/ubuntu_data.zip?dl=0), [Douban dataset](https://www.dropbox.com/s/90t0qtji9ow20ca/DoubanConversaionCorpus.zip?dl=0), and [ECD dataset](https://drive.google.com/file/d/154J-neBo20ABtSmJDvm7DK0eTuieAuvw/view?usp=sharing).
11 | 2. Unzip the dataset and put the data directory into `data/`.
12 | 
13 | #### NUP pre-training
14 | 
15 | The steps for further pre-training BERT with the NUP strategy are described below. We also provide the language model further pre-trained on the Ubuntu training set; our trained NUP language model can be downloaded here:
16 | 
17 | https://www.dropbox.com/s/d1earb9ta6drqoy/ubuntu_nup_bert_base.zip?dl=0
18 | 
19 | You can unzip the model, put it into the `ubuntu_nup_bert_base` directory, and then use it during model training.
20 | 
21 | 1. Run `make_lm_data.py` to convert the original training data into a single file with one sentence (utterance) per line and one blank line between documents (dialogue contexts); see the format sketch at the end of this section.
22 | 
23 | ```
24 | python nup_lm_finetuning/make_lm_data.py \
25 |     --data_file data/ubuntu_data/train.txt \
26 |     --output_file data/ubuntu_data/lm_train.txt
27 | ```
28 | 
29 | 2. Use `pregenerate_training_data_NUP.py` to pre-process the data into training examples following the NUP methodology.
30 | 
31 | ```
32 | python nup_lm_finetuning/pregenerate_training_data_NUP.py \
33 |     --train_corpus data/ubuntu_data/lm_train.txt \
34 |     --bert_model bert-base-uncased \
35 |     --do_lower_case \
36 |     --output_dir data/ubuntu_data/ubuntu_NUP \
37 |     --epochs_to_generate 1 \
38 |     --max_seq_len 256
39 | ```
40 | 
41 | 3. Train on the pregenerated data using `finetune_on_pregenerated.py`, pointing it to the folder created by `pregenerate_training_data_NUP.py`.
42 | 
43 | ```
44 | python nup_lm_finetuning/finetune_on_pregenerated.py \
45 |     --pregenerated_data data/ubuntu_data/ubuntu_NUP \
46 |     --bert_model bert-base-uncased \
47 |     --train_batch_size 12 \
48 |     --reduce_memory \
49 |     --do_lower_case \
50 |     --output_dir ubuntu_finetuned_lm \
51 |     --epochs 1
52 | ```
53 | 
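For reference, each line of `train.txt` holds a 0/1 label followed by the dialogue utterances (the final field being the response candidate), all separated by tabs; `make_lm_data.py` keeps only the label-1 lines and writes one utterance per line with a blank line after each dialogue. The sketch below uses an invented two-turn dialogue purely for illustration (`<TAB>` marks a tab character):

```
# train.txt: one example per line, fields separated by tabs (shown as <TAB>)
1<TAB>how do i install java ?<TAB>try apt-get install default-jre<TAB>thanks , that worked
0<TAB>how do i install java ?<TAB>try apt-get install default-jre<TAB>did you see the game last night ?

# lm_train.txt: built from the label-1 lines only, one utterance per line,
# followed by a blank line that separates dialogues
how do i install java ?
try apt-get install default-jre
thanks , that worked

```
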
54 | #### Model training
55 | 
56 | 1. Train a model
57 | 
58 | Change the `--bert_model` parameter to the path of the NUP-pretrained language model if needed, e.g., `ubuntu_nup_bert_base` for the Ubuntu dataset.
59 | 
60 | An example:
61 | 
62 | ```
63 | python run_IE_CoAtt_CNN_DCM.py \
64 |     --data_dir data/ubuntu_data \
65 |     --task_name ubuntu \
66 |     --train_batch_size 64 \
67 |     --eval_batch_size 64 \
68 |     --max_seq_length 384 \
69 |     --max_utterance_num 20 \
70 |     --bert_model bert-base-uncased \
71 |     --cache_flag ubuntu \
72 |     --learning_rate 3e-5 \
73 |     --num_train_epochs 2 \
74 |     --do_train \
75 |     --do_lower_case \
76 |     --output_dir experiments/ubuntu
77 | ```
78 | 
79 | 2. Evaluation: run the same command with `--do_eval` instead of `--do_train` to evaluate on the test set.
80 | 
81 | ```
82 | python run_IE_CoAtt_CNN_DCM.py \
83 |     --data_dir data/ubuntu_data \
84 |     --task_name ubuntu \
85 |     --train_batch_size 64 \
86 |     --eval_batch_size 64 \
87 |     --max_seq_length 384 \
88 |     --max_utterance_num 20 \
89 |     --bert_model bert-base-uncased \
90 |     --cache_flag ubuntu \
91 |     --learning_rate 3e-5 \
92 |     --num_train_epochs 2 \
93 |     --do_eval \
94 |     --do_lower_case \
95 |     --output_dir experiments/ubuntu
96 | ```
97 | 
98 | ### Requirements
99 | 
100 | Python 3.6 + PyTorch 1.0.1
101 | 
102 | 
103 | 
104 | 
--------------------------------------------------------------------------------
/nup_lm_finetuning/finetune_on_pregenerated.py:
--------------------------------------------------------------------------------
1 | from argparse import ArgumentParser
2 | from pathlib import Path
3 | import torch
4 | import logging
5 | import json
6 | import random
7 | import numpy as np
8 | from collections import namedtuple
9 | from tempfile import TemporaryDirectory
10 | 
11 | from torch.utils.data import DataLoader, Dataset, RandomSampler
12 | from torch.utils.data.distributed import DistributedSampler
13 | from tqdm import tqdm
14 | 
15 | from pytorch_pretrained_bert.modeling import BertForPreTraining
16 | from pytorch_pretrained_bert.tokenization import BertTokenizer
17 | from pytorch_pretrained_bert.optimization import BertAdam, WarmupLinearSchedule
18 | 
19 | InputFeatures = namedtuple("InputFeatures", "input_ids input_mask segment_ids lm_label_ids is_next")
20 | 
21 | log_format = '%(asctime)-10s: %(message)s'
22 | logging.basicConfig(level=logging.INFO, format=log_format)
23 | 
24 | 
25 | def convert_example_to_features(example, tokenizer, max_seq_length):
26 |     tokens = example["tokens"]
27 |     segment_ids = example["segment_ids"]
28 |     is_random_next = example["is_random_next"]
29 |     masked_lm_positions = example["masked_lm_positions"]
30 |     masked_lm_labels = example["masked_lm_labels"]
31 | 
32 |     assert len(tokens) == len(segment_ids) <= max_seq_length  # The preprocessed data should be already truncated
33 |     input_ids = tokenizer.convert_tokens_to_ids(tokens)
34 |     masked_label_ids = tokenizer.convert_tokens_to_ids(masked_lm_labels)
35 | 
36 |     input_array = np.zeros(max_seq_length, dtype=np.int)
37 |     input_array[:len(input_ids)] = input_ids
38 | 
39 |     mask_array = np.zeros(max_seq_length, dtype=np.bool)
40 |     mask_array[:len(input_ids)] = 1
41 | 
42 |     segment_array = np.zeros(max_seq_length, dtype=np.bool)
43 |     segment_array[:len(segment_ids)] = segment_ids
44 | 
45 |     lm_label_array = np.full(max_seq_length, dtype=np.int, fill_value=-1)
46 |     lm_label_array[masked_lm_positions] = masked_label_ids
47 | 
48 |     features = InputFeatures(input_ids=input_array,
49 |                              input_mask=mask_array,
50 |                              segment_ids=segment_array,
51 |                              lm_label_ids=lm_label_array,
52 |                              is_next=is_random_next)
53 |     return features
54 | 
55 | 
56 | class PregeneratedDataset(Dataset):
57 |     def __init__(self, training_path, epoch, tokenizer, num_data_epochs, reduce_memory=False):
58 |         self.vocab = tokenizer.vocab
59 | 
self.tokenizer = tokenizer 60 | self.epoch = epoch 61 | self.data_epoch = epoch % num_data_epochs 62 | data_file = training_path / f"epoch_{self.data_epoch}.json" 63 | metrics_file = training_path / f"epoch_{self.data_epoch}_metrics.json" 64 | assert data_file.is_file() and metrics_file.is_file() 65 | metrics = json.loads(metrics_file.read_text()) 66 | num_samples = metrics['num_training_examples'] 67 | seq_len = metrics['max_seq_len'] 68 | self.temp_dir = None 69 | self.working_dir = None 70 | if reduce_memory: 71 | self.temp_dir = TemporaryDirectory() 72 | self.working_dir = Path(self.temp_dir.name) 73 | input_ids = np.memmap(filename=self.working_dir/'input_ids.memmap', 74 | mode='w+', dtype=np.int32, shape=(num_samples, seq_len)) 75 | input_masks = np.memmap(filename=self.working_dir/'input_masks.memmap', 76 | shape=(num_samples, seq_len), mode='w+', dtype=np.bool) 77 | segment_ids = np.memmap(filename=self.working_dir/'segment_ids.memmap', 78 | shape=(num_samples, seq_len), mode='w+', dtype=np.bool) 79 | lm_label_ids = np.memmap(filename=self.working_dir/'lm_label_ids.memmap', 80 | shape=(num_samples, seq_len), mode='w+', dtype=np.int32) 81 | lm_label_ids[:] = -1 82 | is_nexts = np.memmap(filename=self.working_dir/'is_nexts.memmap', 83 | shape=(num_samples,), mode='w+', dtype=np.bool) 84 | else: 85 | input_ids = np.zeros(shape=(num_samples, seq_len), dtype=np.int32) 86 | input_masks = np.zeros(shape=(num_samples, seq_len), dtype=np.bool) 87 | segment_ids = np.zeros(shape=(num_samples, seq_len), dtype=np.bool) 88 | lm_label_ids = np.full(shape=(num_samples, seq_len), dtype=np.int32, fill_value=-1) 89 | is_nexts = np.zeros(shape=(num_samples,), dtype=np.bool) 90 | logging.info(f"Loading training examples for epoch {epoch}") 91 | with data_file.open() as f: 92 | for i, line in enumerate(tqdm(f, total=num_samples, desc="Training examples")): 93 | line = line.strip() 94 | example = json.loads(line) 95 | features = convert_example_to_features(example, tokenizer, seq_len) 96 | input_ids[i] = features.input_ids 97 | segment_ids[i] = features.segment_ids 98 | input_masks[i] = features.input_mask 99 | lm_label_ids[i] = features.lm_label_ids 100 | is_nexts[i] = features.is_next 101 | assert i == num_samples - 1 # Assert that the sample count metric was true 102 | logging.info("Loading complete!") 103 | self.num_samples = num_samples 104 | self.seq_len = seq_len 105 | self.input_ids = input_ids 106 | self.input_masks = input_masks 107 | self.segment_ids = segment_ids 108 | self.lm_label_ids = lm_label_ids 109 | self.is_nexts = is_nexts 110 | 111 | def __len__(self): 112 | return self.num_samples 113 | 114 | def __getitem__(self, item): 115 | return (torch.tensor(self.input_ids[item].astype(np.int64)), 116 | torch.tensor(self.input_masks[item].astype(np.int64)), 117 | torch.tensor(self.segment_ids[item].astype(np.int64)), 118 | torch.tensor(self.lm_label_ids[item].astype(np.int64)), 119 | torch.tensor(self.is_nexts[item].astype(np.int64))) 120 | 121 | 122 | def main(): 123 | parser = ArgumentParser() 124 | parser.add_argument('--pregenerated_data', type=Path, required=True) 125 | parser.add_argument('--output_dir', type=Path, required=True) 126 | parser.add_argument("--bert_model", type=str, required=True, help="Bert pre-trained model selected in the list: bert-base-uncased, " 127 | "bert-large-uncased, bert-base-cased, bert-base-multilingual, bert-base-chinese.") 128 | parser.add_argument("--do_lower_case", action="store_true") 129 | parser.add_argument("--reduce_memory", action="store_true", 130 
| help="Store training data as on-disc memmaps to massively reduce memory usage") 131 | 132 | parser.add_argument("--epochs", type=int, default=3, help="Number of epochs to train for") 133 | parser.add_argument("--local_rank", 134 | type=int, 135 | default=-1, 136 | help="local_rank for distributed training on gpus") 137 | parser.add_argument("--no_cuda", 138 | action='store_true', 139 | help="Whether not to use CUDA when available") 140 | parser.add_argument('--gradient_accumulation_steps', 141 | type=int, 142 | default=1, 143 | help="Number of updates steps to accumulate before performing a backward/update pass.") 144 | parser.add_argument("--train_batch_size", 145 | default=32, 146 | type=int, 147 | help="Total batch size for training.") 148 | parser.add_argument('--fp16', 149 | action='store_true', 150 | help="Whether to use 16-bit float precision instead of 32-bit") 151 | parser.add_argument('--loss_scale', 152 | type=float, default=0, 153 | help="Loss scaling to improve fp16 numeric stability. Only used when fp16 set to True.\n" 154 | "0 (default value): dynamic loss scaling.\n" 155 | "Positive power of 2: static loss scaling value.\n") 156 | parser.add_argument("--warmup_proportion", 157 | default=0.1, 158 | type=float, 159 | help="Proportion of training to perform linear learning rate warmup for. " 160 | "E.g., 0.1 = 10%% of training.") 161 | parser.add_argument("--learning_rate", 162 | default=3e-5, 163 | type=float, 164 | help="The initial learning rate for Adam.") 165 | parser.add_argument('--seed', 166 | type=int, 167 | default=42, 168 | help="random seed for initialization") 169 | args = parser.parse_args() 170 | 171 | assert args.pregenerated_data.is_dir(), \ 172 | "--pregenerated_data should point to the folder of files made by pregenerate_training_data.py!" 173 | 174 | samples_per_epoch = [] 175 | for i in range(args.epochs): 176 | epoch_file = args.pregenerated_data / f"epoch_{i}.json" 177 | metrics_file = args.pregenerated_data / f"epoch_{i}_metrics.json" 178 | if epoch_file.is_file() and metrics_file.is_file(): 179 | metrics = json.loads(metrics_file.read_text()) 180 | samples_per_epoch.append(metrics['num_training_examples']) 181 | else: 182 | if i == 0: 183 | exit("No training data was found!") 184 | print(f"Warning! 
There are fewer epochs of pregenerated data ({i}) than training epochs ({args.epochs}).") 185 | print("This script will loop over the available data, but training diversity may be negatively impacted.") 186 | num_data_epochs = i 187 | break 188 | else: 189 | num_data_epochs = args.epochs 190 | 191 | if args.local_rank == -1 or args.no_cuda: 192 | device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu") 193 | n_gpu = torch.cuda.device_count() 194 | else: 195 | torch.cuda.set_device(args.local_rank) 196 | device = torch.device("cuda", args.local_rank) 197 | n_gpu = 1 198 | # Initializes the distributed backend which will take care of sychronizing nodes/GPUs 199 | torch.distributed.init_process_group(backend='nccl') 200 | logging.info("device: {} n_gpu: {}, distributed training: {}, 16-bits training: {}".format( 201 | device, n_gpu, bool(args.local_rank != -1), args.fp16)) 202 | 203 | if args.gradient_accumulation_steps < 1: 204 | raise ValueError("Invalid gradient_accumulation_steps parameter: {}, should be >= 1".format( 205 | args.gradient_accumulation_steps)) 206 | 207 | args.train_batch_size = args.train_batch_size // args.gradient_accumulation_steps 208 | 209 | random.seed(args.seed) 210 | np.random.seed(args.seed) 211 | torch.manual_seed(args.seed) 212 | if n_gpu > 0: 213 | torch.cuda.manual_seed_all(args.seed) 214 | 215 | if args.output_dir.is_dir() and list(args.output_dir.iterdir()): 216 | logging.warning(f"Output directory ({args.output_dir}) already exists and is not empty!") 217 | args.output_dir.mkdir(parents=True, exist_ok=True) 218 | 219 | tokenizer = BertTokenizer.from_pretrained(args.bert_model, do_lower_case=args.do_lower_case) 220 | 221 | total_train_examples = 0 222 | for i in range(args.epochs): 223 | # The modulo takes into account the fact that we may loop over limited epochs of data 224 | total_train_examples += samples_per_epoch[i % len(samples_per_epoch)] 225 | 226 | num_train_optimization_steps = int( 227 | total_train_examples / args.train_batch_size / args.gradient_accumulation_steps) 228 | if args.local_rank != -1: 229 | num_train_optimization_steps = num_train_optimization_steps // torch.distributed.get_world_size() 230 | 231 | # Prepare model 232 | model = BertForPreTraining.from_pretrained(args.bert_model) 233 | if args.fp16: 234 | model.half() 235 | model.to(device) 236 | if args.local_rank != -1: 237 | try: 238 | from apex.parallel import DistributedDataParallel as DDP 239 | except ImportError: 240 | raise ImportError( 241 | "Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training.") 242 | model = DDP(model) 243 | elif n_gpu > 1: 244 | model = torch.nn.DataParallel(model) 245 | 246 | # Prepare optimizer 247 | param_optimizer = list(model.named_parameters()) 248 | no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight'] 249 | optimizer_grouped_parameters = [ 250 | {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 251 | 'weight_decay': 0.01}, 252 | {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0} 253 | ] 254 | 255 | if args.fp16: 256 | try: 257 | from apex.optimizers import FP16_Optimizer 258 | from apex.optimizers import FusedAdam 259 | except ImportError: 260 | raise ImportError( 261 | "Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training.") 262 | 263 | optimizer = FusedAdam(optimizer_grouped_parameters, 264 | lr=args.learning_rate, 265 | 
bias_correction=False, 266 | max_grad_norm=1.0) 267 | if args.loss_scale == 0: 268 | optimizer = FP16_Optimizer(optimizer, dynamic_loss_scale=True) 269 | else: 270 | optimizer = FP16_Optimizer(optimizer, static_loss_scale=args.loss_scale) 271 | warmup_linear = WarmupLinearSchedule(warmup=args.warmup_proportion, 272 | t_total=num_train_optimization_steps) 273 | else: 274 | optimizer = BertAdam(optimizer_grouped_parameters, 275 | lr=args.learning_rate, 276 | warmup=args.warmup_proportion, 277 | t_total=num_train_optimization_steps) 278 | 279 | global_step = 0 280 | logging.info("***** Running training *****") 281 | logging.info(f" Num examples = {total_train_examples}") 282 | logging.info(" Batch size = %d", args.train_batch_size) 283 | logging.info(" Num steps = %d", num_train_optimization_steps) 284 | model.train() 285 | for epoch in range(args.epochs): 286 | epoch_dataset = PregeneratedDataset(epoch=epoch, training_path=args.pregenerated_data, tokenizer=tokenizer, 287 | num_data_epochs=num_data_epochs, reduce_memory=args.reduce_memory) 288 | if args.local_rank == -1: 289 | train_sampler = RandomSampler(epoch_dataset) 290 | else: 291 | train_sampler = DistributedSampler(epoch_dataset) 292 | train_dataloader = DataLoader(epoch_dataset, sampler=train_sampler, batch_size=args.train_batch_size) 293 | tr_loss = 0 294 | nb_tr_examples, nb_tr_steps = 0, 0 295 | with tqdm(total=len(train_dataloader), desc=f"Epoch {epoch}") as pbar: 296 | for step, batch in enumerate(train_dataloader): 297 | batch = tuple(t.to(device) for t in batch) 298 | input_ids, input_mask, segment_ids, lm_label_ids, is_next = batch 299 | loss = model(input_ids, segment_ids, input_mask, lm_label_ids, is_next) 300 | if n_gpu > 1: 301 | loss = loss.mean() # mean() to average on multi-gpu. 
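# With gradient accumulation, the per-micro-batch loss is divided by
# gradient_accumulation_steps below, so the gradients summed over the accumulated
# micro-batches match a single update at the effective (full) batch size.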
302 | if args.gradient_accumulation_steps > 1: 303 | loss = loss / args.gradient_accumulation_steps 304 | if args.fp16: 305 | optimizer.backward(loss) 306 | else: 307 | loss.backward() 308 | tr_loss += loss.item() 309 | nb_tr_examples += input_ids.size(0) 310 | nb_tr_steps += 1 311 | pbar.update(1) 312 | mean_loss = tr_loss * args.gradient_accumulation_steps / nb_tr_steps 313 | pbar.set_postfix_str(f"Loss: {mean_loss:.5f}") 314 | if (step + 1) % args.gradient_accumulation_steps == 0: 315 | if args.fp16: 316 | # modify learning rate with special warm up BERT uses 317 | # if args.fp16 is False, BertAdam is used that handles this automatically 318 | lr_this_step = args.learning_rate * warmup_linear.get_lr(global_step/num_train_optimization_steps, 319 | args.warmup_proportion) 320 | for param_group in optimizer.param_groups: 321 | param_group['lr'] = lr_this_step 322 | optimizer.step() 323 | optimizer.zero_grad() 324 | global_step += 1 325 | 326 | # Save a trained model 327 | logging.info("** ** * Saving fine-tuned model ** ** * ") 328 | model_to_save = model.module if hasattr(model, 'module') else model # Only save the model it-self 329 | output_model_file = args.output_dir / "pytorch_model.bin" 330 | torch.save(model_to_save.state_dict(), str(output_model_file)) 331 | 332 | 333 | if __name__ == '__main__': 334 | main() 335 | -------------------------------------------------------------------------------- /nup_lm_finetuning/make_lm_data.py: -------------------------------------------------------------------------------- 1 | import re 2 | import argparse 3 | 4 | def make_lm_data(input_file, output_file): 5 | """Reads a tab separated value file.""" 6 | with open(input_file, 'r', encoding='utf-8') as rf, open(output_file, 'w', encoding='utf-8') as wf: 7 | # cnt = 0 8 | for line in rf: 9 | # cnt += 1 10 | line = line.strip().split('\t') 11 | if line[0] == '1': 12 | for i in range(1, len(line) - 1): 13 | line[i] = re.compile('[\\x00-\\x08\\x0b-\\x0c\\x0e-\\x1f]').sub(' ', line[i]) 14 | if len(line[i].strip()) >= 1: 15 | wf.write(line[i] + '\n') 16 | line[-1] = re.compile('[\\x00-\\x08\\x0b-\\x0c\\x0e-\\x1f]').sub(' ', line[-1]) 17 | if len(line[-1].strip()) >= 1: 18 | wf.write(line[-1]) 19 | 20 | wf.write('\n') 21 | 22 | def main(): 23 | parser = argparse.ArgumentParser() 24 | 25 | ## Required parameters 26 | parser.add_argument("--data_file", 27 | default=None, 28 | type=str, 29 | help="The input data dir.") 30 | parser.add_argument("--output_file", 31 | default=None, 32 | type=str, 33 | help="The output data dir.") 34 | 35 | args = parser.parse_args() 36 | 37 | make_lm_data(args.data_file, args.output_file) 38 | 39 | if __name__ == "__main__": 40 | main() -------------------------------------------------------------------------------- /nup_lm_finetuning/pregenerate_training_data_NUP.py: -------------------------------------------------------------------------------- 1 | from argparse import ArgumentParser 2 | from pathlib import Path 3 | from tqdm import tqdm, trange 4 | from tempfile import TemporaryDirectory 5 | import shelve 6 | 7 | from random import random, randrange, randint, shuffle, choice, sample 8 | from pytorch_pretrained_bert.tokenization import BertTokenizer 9 | import numpy as np 10 | import json 11 | 12 | 13 | class DocumentDatabase: 14 | def __init__(self, reduce_memory=False): 15 | if reduce_memory: 16 | self.temp_dir = TemporaryDirectory() 17 | self.working_dir = Path(self.temp_dir.name) 18 | self.document_shelf_filepath = self.working_dir / 'shelf.db' 19 | self.document_shelf 
= shelve.open(str(self.document_shelf_filepath), 20 | flag='n', protocol=-1) 21 | self.documents = None 22 | else: 23 | self.documents = [] 24 | self.document_shelf = None 25 | self.document_shelf_filepath = None 26 | self.temp_dir = None 27 | self.doc_lengths = [] 28 | self.doc_cumsum = None 29 | self.cumsum_max = None 30 | self.reduce_memory = reduce_memory 31 | 32 | def add_document(self, document): 33 | if not document: 34 | return 35 | if self.reduce_memory: 36 | current_idx = len(self.doc_lengths) 37 | self.document_shelf[str(current_idx)] = document 38 | else: 39 | self.documents.append(document) 40 | self.doc_lengths.append(len(document)) 41 | 42 | def _precalculate_doc_weights(self): 43 | self.doc_cumsum = np.cumsum(self.doc_lengths) 44 | self.cumsum_max = self.doc_cumsum[-1] 45 | 46 | def sample_doc(self, current_idx, sentence_weighted=True): 47 | # Uses the current iteration counter to ensure we don't sample the same doc twice 48 | if sentence_weighted: 49 | # With sentence weighting, we sample docs proportionally to their sentence length 50 | if self.doc_cumsum is None or len(self.doc_cumsum) != len(self.doc_lengths): 51 | self._precalculate_doc_weights() 52 | rand_start = self.doc_cumsum[current_idx] 53 | rand_end = rand_start + self.cumsum_max - self.doc_lengths[current_idx] 54 | sentence_index = randrange(rand_start, rand_end) % self.cumsum_max 55 | sampled_doc_index = np.searchsorted(self.doc_cumsum, sentence_index, side='right') 56 | else: 57 | # If we don't use sentence weighting, then every doc has an equal chance to be chosen 58 | sampled_doc_index = (current_idx + randrange(1, len(self.doc_lengths))) % len(self.doc_lengths) 59 | assert sampled_doc_index != current_idx 60 | if self.reduce_memory: 61 | return self.document_shelf[str(sampled_doc_index)] 62 | else: 63 | return self.documents[sampled_doc_index] 64 | 65 | def __len__(self): 66 | return len(self.doc_lengths) 67 | 68 | def __getitem__(self, item): 69 | if self.reduce_memory: 70 | return self.document_shelf[str(item)] 71 | else: 72 | return self.documents[item] 73 | 74 | def __enter__(self): 75 | return self 76 | 77 | def __exit__(self, exc_type, exc_val, traceback): 78 | if self.document_shelf is not None: 79 | self.document_shelf.close() 80 | if self.temp_dir is not None: 81 | self.temp_dir.cleanup() 82 | 83 | 84 | def truncate_seq_pair(tokens_a, tokens_b, max_num_tokens): 85 | """Truncates a pair of sequences to a maximum sequence length. Lifted from Google's BERT repo.""" 86 | while True: 87 | total_length = len(tokens_a) + len(tokens_b) 88 | if total_length <= max_num_tokens: 89 | break 90 | 91 | trunc_tokens = tokens_a if len(tokens_a) > len(tokens_b) else tokens_b 92 | assert len(trunc_tokens) >= 1 93 | 94 | # We want to sometimes truncate from the front and sometimes from the 95 | # back to add more randomness and avoid biases. 96 | if random() < 0.5: 97 | del trunc_tokens[0] 98 | else: 99 | trunc_tokens.pop() 100 | 101 | 102 | def create_masked_lm_predictions(tokens, masked_lm_prob, max_predictions_per_seq, vocab_list): 103 | """Creates the predictions for the masked LM objective. 
This is mostly copied from the Google BERT repo, but 104 | with several refactors to clean it up and remove a lot of unnecessary variables.""" 105 | cand_indices = [] 106 | for (i, token) in enumerate(tokens): 107 | if token == "[CLS]" or token == "[SEP]": 108 | continue 109 | cand_indices.append(i) 110 | 111 | num_to_mask = min(max_predictions_per_seq, 112 | max(1, int(round(len(tokens) * masked_lm_prob)))) 113 | shuffle(cand_indices) 114 | mask_indices = sorted(sample(cand_indices, num_to_mask)) 115 | masked_token_labels = [] 116 | for index in mask_indices: 117 | # 80% of the time, replace with [MASK] 118 | if random() < 0.8: 119 | masked_token = "[MASK]" 120 | else: 121 | # 10% of the time, keep original 122 | if random() < 0.5: 123 | masked_token = tokens[index] 124 | # 10% of the time, replace with random word 125 | else: 126 | masked_token = choice(vocab_list) 127 | masked_token_labels.append(tokens[index]) 128 | # Once we've saved the true label for that token, we can overwrite it with the masked version 129 | tokens[index] = masked_token 130 | 131 | return tokens, mask_indices, masked_token_labels 132 | 133 | 134 | def create_instances_from_document( 135 | doc_database, doc_idx, max_seq_length, short_seq_prob, 136 | masked_lm_prob, max_predictions_per_seq, vocab_list): 137 | """This code is mostly a duplicate of the equivalent function from Google BERT's repo. 138 | However, we make some changes and improvements. Sampling is improved and no longer requires a loop in this function. 139 | Also, documents are sampled proportionally to the number of sentences they contain, which means each sentence 140 | (rather than each document) has an equal chance of being sampled as a false example for the NextSentence task.""" 141 | document = doc_database[doc_idx] 142 | # Account for [CLS], [SEP], [SEP] 143 | max_num_tokens = max_seq_length - 3 144 | 145 | target_seq_length = max_num_tokens 146 | # if random() < short_seq_prob: 147 | # target_seq_length = randint(2, max_num_tokens) 148 | 149 | # for each dialogue context with multiple utterances we let each utterance (except the first turn) 150 | # be positive response of history conversations. And randomly sampled a utterance from the corpus 151 | # as the negative example. 
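# For example, a dialogue chunk [u1, u2, u3] yields the positive (context, response)
# pairs (u1, u2) and (u1 + u2, u3); each positive instance is paired with a negative
# instance whose response is a single utterance sampled from another document.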
152 | instances = [] 153 | current_chunk = [] 154 | current_length = 0 155 | i = 0 156 | while i < len(document): 157 | segment = document[i] 158 | if len(segment)>=1: 159 | current_chunk.append(segment) 160 | current_length += len(segment) 161 | if i == len(document) - 1 or current_length >= target_seq_length: 162 | if current_chunk: 163 | for m in range(len(current_chunk)): 164 | if m == len(current_chunk) - 1: 165 | continue 166 | tokens_a = [] 167 | tokens_b = [] 168 | for j in range(m + 1): 169 | tokens_a.extend(current_chunk[j]) 170 | # Actual next utterance 171 | tokens_b.extend(current_chunk[m + 1]) 172 | 173 | tokens_a_2 = tokens_a 174 | # Random next utterance 175 | # Sample a random document, with longer docs being sampled more frequently 176 | random_document = doc_database.sample_doc(current_idx=doc_idx, sentence_weighted=True) 177 | random_document = [sent for sent in random_document if len(sent)>=1] 178 | random_start = randrange(0, len(random_document)) 179 | tokens_c = random_document[random_start] 180 | 181 | truncate_seq_pair(tokens_a, tokens_b, max_num_tokens) 182 | truncate_seq_pair(tokens_a_2, tokens_c, max_num_tokens) 183 | 184 | assert len(tokens_a) >= 1 , print("document:",document) 185 | assert len(tokens_b) >= 1, print("current_chunk:",current_chunk) 186 | assert len(tokens_c) >= 1 , print("random_document",random_document) 187 | 188 | tokens = ["[CLS]"] + tokens_a + ["[SEP]"] + tokens_b + ["[SEP]"] 189 | tokens_neg = ["[CLS]"] + tokens_a_2 + ["[SEP]"] + tokens_c + ["[SEP]"] 190 | # The segment IDs are 0 for the [CLS] token, the A tokens and the first [SEP] 191 | # They are 1 for the B tokens and the final [SEP] 192 | segment_ids = [0 for _ in range(len(tokens_a) + 2)] + [1 for _ in range(len(tokens_b) + 1)] 193 | segment_ids_neg = [0 for _ in range(len(tokens_a_2) + 2)] + [1 for _ in range(len(tokens_c) + 1)] 194 | 195 | tokens, masked_lm_positions, masked_lm_labels = create_masked_lm_predictions( 196 | tokens, masked_lm_prob, max_predictions_per_seq, vocab_list) 197 | tokens_neg, masked_lm_positions_neg, masked_lm_labels_neg = create_masked_lm_predictions( 198 | tokens_neg, masked_lm_prob, max_predictions_per_seq, vocab_list) 199 | 200 | instance = { 201 | "tokens": tokens, 202 | "segment_ids": segment_ids, 203 | "is_random_next": False, 204 | "masked_lm_positions": masked_lm_positions, 205 | "masked_lm_labels": masked_lm_labels} 206 | instances.append(instance) 207 | 208 | instance_neg = { 209 | "tokens": tokens_neg, 210 | "segment_ids": segment_ids_neg, 211 | "is_random_next": True, 212 | "masked_lm_positions": masked_lm_positions_neg, 213 | "masked_lm_labels": masked_lm_labels_neg} 214 | instances.append(instance_neg) 215 | current_chunk = [] 216 | current_length = 0 217 | i += 1 218 | 219 | return instances 220 | 221 | 222 | def main(): 223 | parser = ArgumentParser() 224 | parser.add_argument('--train_corpus', type=Path, required=True) 225 | parser.add_argument("--output_dir", type=Path, required=True) 226 | parser.add_argument("--bert_model", type=str, required=True, 227 | choices=["bert-base-uncased", "bert-large-uncased", "bert-base-cased", 228 | "bert-base-multilingual", "bert-base-chinese"]) 229 | parser.add_argument("--do_lower_case", action="store_true") 230 | 231 | parser.add_argument("--reduce_memory", action="store_true", 232 | help="Reduce memory usage for large datasets by keeping data on disc rather than in memory") 233 | 234 | parser.add_argument("--epochs_to_generate", type=int, default=3, 235 | help="Number of epochs of data to 
pregenerate") 236 | parser.add_argument("--max_seq_len", type=int, default=128) 237 | parser.add_argument("--short_seq_prob", type=float, default=0.1, 238 | help="Probability of making a short sentence as a training example") 239 | parser.add_argument("--masked_lm_prob", type=float, default=0.15, 240 | help="Probability of masking each token for the LM task") 241 | parser.add_argument("--max_predictions_per_seq", type=int, default=20, 242 | help="Maximum number of tokens to mask in each sequence") 243 | 244 | args = parser.parse_args() 245 | 246 | tokenizer = BertTokenizer.from_pretrained(args.bert_model, do_lower_case=args.do_lower_case) 247 | vocab_list = list(tokenizer.vocab.keys()) 248 | with DocumentDatabase(reduce_memory=args.reduce_memory) as docs: 249 | with args.train_corpus.open() as f: 250 | doc = [] 251 | for line in tqdm(f, desc="Loading Dataset", unit=" lines"): 252 | line = line.strip() 253 | if line == "": 254 | docs.add_document(doc) 255 | doc = [] 256 | else: 257 | tokens = tokenizer.tokenize(line) 258 | doc.append(tokens) 259 | #debug 260 | # if len(docs) > 33000: 261 | # break 262 | if doc: 263 | docs.add_document(doc) # If the last doc didn't end on a newline, make sure it still gets added 264 | 265 | if len(docs) <= 1: 266 | exit("ERROR: No document breaks were found in the input file! These are necessary to allow the script to " 267 | "ensure that random NextSentences are not sampled from the same document. Please add blank lines to " 268 | "indicate breaks between documents in your input file. If your dataset does not contain multiple " 269 | "documents, blank lines can be inserted at any natural boundary, such as the ends of chapters, " 270 | "sections or paragraphs.") 271 | 272 | args.output_dir.mkdir(exist_ok=True) 273 | for epoch in trange(args.epochs_to_generate, desc="Epoch"): 274 | epoch_filename = args.output_dir / f"epoch_{epoch}.json" 275 | num_instances = 0 276 | with epoch_filename.open('w') as epoch_file: 277 | for doc_idx in trange(len(docs), desc="Document"): 278 | doc_instances = create_instances_from_document( 279 | docs, doc_idx, max_seq_length=args.max_seq_len, short_seq_prob=args.short_seq_prob, 280 | masked_lm_prob=args.masked_lm_prob, max_predictions_per_seq=args.max_predictions_per_seq, 281 | vocab_list=vocab_list) 282 | doc_instances = [json.dumps(instance) for instance in doc_instances] 283 | for instance in doc_instances: 284 | epoch_file.write(instance + '\n') 285 | num_instances += 1 286 | metrics_file = args.output_dir / f"epoch_{epoch}_metrics.json" 287 | with metrics_file.open('w') as metrics_file: 288 | metrics = { 289 | "num_training_examples": num_instances, 290 | "max_seq_len": args.max_seq_len 291 | } 292 | metrics_file.write(json.dumps(metrics)) 293 | print("finish") 294 | 295 | 296 | if __name__ == '__main__': 297 | main() 298 | -------------------------------------------------------------------------------- /pytorch_pretrained_bert/__init__.py: -------------------------------------------------------------------------------- 1 | __version__ = "0.6.2" 2 | from .tokenization import BertTokenizer, BasicTokenizer, WordpieceTokenizer 3 | 4 | 5 | from .modeling import (BertConfig, BertModel, BertForPreTraining, 6 | BertForMaskedLM, BertForNextSentencePrediction, 7 | BertForSequenceClassification, 8 | load_tf_weights_in_bert) 9 | 10 | 11 | from .optimization import BertAdam 12 | 13 | 14 | from .file_utils import PYTORCH_PRETRAINED_BERT_CACHE, cached_path, WEIGHTS_NAME, CONFIG_NAME 15 | 
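# Example of how these exports are used elsewhere in this repository (an illustrative
# sketch; the model name and hyperparameter values below are arbitrary):
#
#     tokenizer = BertTokenizer.from_pretrained("bert-base-uncased", do_lower_case=True)
#     model = BertForPreTraining.from_pretrained("bert-base-uncased")
#     optimizer = BertAdam(model.parameters(), lr=3e-5, warmup=0.1, t_total=10000)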
-------------------------------------------------------------------------------- /pytorch_pretrained_bert/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LeaLiLu/DeepContextModeling/f964456751e9d893e54ae5c9c6634cf020484ab1/pytorch_pretrained_bert/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /pytorch_pretrained_bert/__pycache__/file_utils.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LeaLiLu/DeepContextModeling/f964456751e9d893e54ae5c9c6634cf020484ab1/pytorch_pretrained_bert/__pycache__/file_utils.cpython-36.pyc -------------------------------------------------------------------------------- /pytorch_pretrained_bert/__pycache__/modeling.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LeaLiLu/DeepContextModeling/f964456751e9d893e54ae5c9c6634cf020484ab1/pytorch_pretrained_bert/__pycache__/modeling.cpython-36.pyc -------------------------------------------------------------------------------- /pytorch_pretrained_bert/__pycache__/optimization.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LeaLiLu/DeepContextModeling/f964456751e9d893e54ae5c9c6634cf020484ab1/pytorch_pretrained_bert/__pycache__/optimization.cpython-36.pyc -------------------------------------------------------------------------------- /pytorch_pretrained_bert/__pycache__/tokenization.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LeaLiLu/DeepContextModeling/f964456751e9d893e54ae5c9c6634cf020484ab1/pytorch_pretrained_bert/__pycache__/tokenization.cpython-36.pyc -------------------------------------------------------------------------------- /pytorch_pretrained_bert/file_utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Utilities for working with the local dataset cache. 3 | This file is adapted from the AllenNLP library at https://github.com/allenai/allennlp 4 | Copyright by the AllenNLP authors. 
5 | """ 6 | from __future__ import (absolute_import, division, print_function, unicode_literals) 7 | 8 | import sys 9 | import json 10 | import logging 11 | import os 12 | import shutil 13 | import tempfile 14 | import fnmatch 15 | from functools import wraps 16 | from hashlib import sha256 17 | import sys 18 | from io import open 19 | 20 | import boto3 21 | import requests 22 | from botocore.exceptions import ClientError 23 | from tqdm import tqdm 24 | 25 | try: 26 | from urllib.parse import urlparse 27 | except ImportError: 28 | from urlparse import urlparse 29 | 30 | try: 31 | from pathlib import Path 32 | PYTORCH_PRETRAINED_BERT_CACHE = Path(os.getenv('PYTORCH_PRETRAINED_BERT_CACHE', 33 | Path.home() / '.pytorch_pretrained_bert')) 34 | except (AttributeError, ImportError): 35 | PYTORCH_PRETRAINED_BERT_CACHE = os.getenv('PYTORCH_PRETRAINED_BERT_CACHE', 36 | os.path.join(os.path.expanduser("~"), '.pytorch_pretrained_bert')) 37 | 38 | CONFIG_NAME = "config.json" 39 | WEIGHTS_NAME = "pytorch_model.bin" 40 | 41 | logger = logging.getLogger(__name__) # pylint: disable=invalid-name 42 | 43 | 44 | def url_to_filename(url, etag=None): 45 | """ 46 | Convert `url` into a hashed filename in a repeatable way. 47 | If `etag` is specified, append its hash to the url's, delimited 48 | by a period. 49 | """ 50 | url_bytes = url.encode('utf-8') 51 | url_hash = sha256(url_bytes) 52 | filename = url_hash.hexdigest() 53 | 54 | if etag: 55 | etag_bytes = etag.encode('utf-8') 56 | etag_hash = sha256(etag_bytes) 57 | filename += '.' + etag_hash.hexdigest() 58 | 59 | return filename 60 | 61 | 62 | def filename_to_url(filename, cache_dir=None): 63 | """ 64 | Return the url and etag (which may be ``None``) stored for `filename`. 65 | Raise ``EnvironmentError`` if `filename` or its stored metadata do not exist. 66 | """ 67 | if cache_dir is None: 68 | cache_dir = PYTORCH_PRETRAINED_BERT_CACHE 69 | if sys.version_info[0] == 3 and isinstance(cache_dir, Path): 70 | cache_dir = str(cache_dir) 71 | 72 | cache_path = os.path.join(cache_dir, filename) 73 | if not os.path.exists(cache_path): 74 | raise EnvironmentError("file {} not found".format(cache_path)) 75 | 76 | meta_path = cache_path + '.json' 77 | if not os.path.exists(meta_path): 78 | raise EnvironmentError("file {} not found".format(meta_path)) 79 | 80 | with open(meta_path, encoding="utf-8") as meta_file: 81 | metadata = json.load(meta_file) 82 | url = metadata['url'] 83 | etag = metadata['etag'] 84 | 85 | return url, etag 86 | 87 | 88 | def cached_path(url_or_filename, cache_dir=None): 89 | """ 90 | Given something that might be a URL (or might be a local path), 91 | determine which. If it's a URL, download the file and cache it, and 92 | return the path to the cached file. If it's already a local path, 93 | make sure the file exists and then return the path. 94 | """ 95 | if cache_dir is None: 96 | cache_dir = PYTORCH_PRETRAINED_BERT_CACHE 97 | if sys.version_info[0] == 3 and isinstance(url_or_filename, Path): 98 | url_or_filename = str(url_or_filename) 99 | if sys.version_info[0] == 3 and isinstance(cache_dir, Path): 100 | cache_dir = str(cache_dir) 101 | 102 | parsed = urlparse(url_or_filename) 103 | 104 | if parsed.scheme in ('http', 'https', 's3'): 105 | # URL, so get it from the cache (downloading if necessary) 106 | return get_from_cache(url_or_filename, cache_dir) 107 | elif os.path.exists(url_or_filename): 108 | # File, and it exists. 109 | return url_or_filename 110 | elif parsed.scheme == '': 111 | # File, but it doesn't exist. 
112 | raise EnvironmentError("file {} not found".format(url_or_filename)) 113 | else: 114 | # Something unknown 115 | raise ValueError("unable to parse {} as a URL or as a local path".format(url_or_filename)) 116 | 117 | 118 | def split_s3_path(url): 119 | """Split a full s3 path into the bucket name and path.""" 120 | parsed = urlparse(url) 121 | if not parsed.netloc or not parsed.path: 122 | raise ValueError("bad s3 path {}".format(url)) 123 | bucket_name = parsed.netloc 124 | s3_path = parsed.path 125 | # Remove '/' at beginning of path. 126 | if s3_path.startswith("/"): 127 | s3_path = s3_path[1:] 128 | return bucket_name, s3_path 129 | 130 | 131 | def s3_request(func): 132 | """ 133 | Wrapper function for s3 requests in order to create more helpful error 134 | messages. 135 | """ 136 | 137 | @wraps(func) 138 | def wrapper(url, *args, **kwargs): 139 | try: 140 | return func(url, *args, **kwargs) 141 | except ClientError as exc: 142 | if int(exc.response["Error"]["Code"]) == 404: 143 | raise EnvironmentError("file {} not found".format(url)) 144 | else: 145 | raise 146 | 147 | return wrapper 148 | 149 | 150 | @s3_request 151 | def s3_etag(url): 152 | """Check ETag on S3 object.""" 153 | s3_resource = boto3.resource("s3") 154 | bucket_name, s3_path = split_s3_path(url) 155 | s3_object = s3_resource.Object(bucket_name, s3_path) 156 | return s3_object.e_tag 157 | 158 | 159 | @s3_request 160 | def s3_get(url, temp_file): 161 | """Pull a file directly from S3.""" 162 | s3_resource = boto3.resource("s3") 163 | bucket_name, s3_path = split_s3_path(url) 164 | s3_resource.Bucket(bucket_name).download_fileobj(s3_path, temp_file) 165 | 166 | 167 | def http_get(url, temp_file): 168 | req = requests.get(url, stream=True) 169 | content_length = req.headers.get('Content-Length') 170 | total = int(content_length) if content_length is not None else None 171 | progress = tqdm(unit="B", total=total) 172 | for chunk in req.iter_content(chunk_size=1024): 173 | if chunk: # filter out keep-alive new chunks 174 | progress.update(len(chunk)) 175 | temp_file.write(chunk) 176 | progress.close() 177 | 178 | 179 | def get_from_cache(url, cache_dir=None): 180 | """ 181 | Given a URL, look for the corresponding dataset in the local cache. 182 | If it's not there, download it. Then return the path to the cached file. 183 | """ 184 | if cache_dir is None: 185 | cache_dir = PYTORCH_PRETRAINED_BERT_CACHE 186 | if sys.version_info[0] == 3 and isinstance(cache_dir, Path): 187 | cache_dir = str(cache_dir) 188 | 189 | if not os.path.exists(cache_dir): 190 | os.makedirs(cache_dir) 191 | 192 | # Get eTag to add to filename, if it exists. 
193 | if url.startswith("s3://"): 194 | etag = s3_etag(url) 195 | else: 196 | try: 197 | response = requests.head(url, allow_redirects=True) 198 | if response.status_code != 200: 199 | etag = None 200 | else: 201 | etag = response.headers.get("ETag") 202 | except EnvironmentError: 203 | etag = None 204 | 205 | if sys.version_info[0] == 2 and etag is not None: 206 | etag = etag.decode('utf-8') 207 | filename = url_to_filename(url, etag) 208 | 209 | # get cache path to put the file 210 | cache_path = os.path.join(cache_dir, filename) 211 | 212 | # If we don't have a connection (etag is None) and can't identify the file 213 | # try to get the last downloaded one 214 | if not os.path.exists(cache_path) and etag is None: 215 | matching_files = fnmatch.filter(os.listdir(cache_dir), filename + '.*') 216 | matching_files = list(filter(lambda s: not s.endswith('.json'), matching_files)) 217 | if matching_files: 218 | cache_path = os.path.join(cache_dir, matching_files[-1]) 219 | 220 | if not os.path.exists(cache_path): 221 | # Download to temporary file, then copy to cache dir once finished. 222 | # Otherwise you get corrupt cache entries if the download gets interrupted. 223 | with tempfile.NamedTemporaryFile() as temp_file: 224 | logger.info("%s not found in cache, downloading to %s", url, temp_file.name) 225 | 226 | # GET file object 227 | if url.startswith("s3://"): 228 | s3_get(url, temp_file) 229 | else: 230 | http_get(url, temp_file) 231 | 232 | # we are copying the file before closing it, so flush to avoid truncation 233 | temp_file.flush() 234 | # shutil.copyfileobj() starts at the current position, so go to the start 235 | temp_file.seek(0) 236 | 237 | logger.info("copying %s to cache at %s", temp_file.name, cache_path) 238 | with open(cache_path, 'wb') as cache_file: 239 | shutil.copyfileobj(temp_file, cache_file) 240 | 241 | logger.info("creating metadata file for %s", cache_path) 242 | meta = {'url': url, 'etag': etag} 243 | meta_path = cache_path + '.json' 244 | with open(meta_path, 'w') as meta_file: 245 | output_string = json.dumps(meta) 246 | if sys.version_info[0] == 2 and isinstance(output_string, str): 247 | output_string = unicode(output_string, 'utf-8') # The beauty of python 2 248 | meta_file.write(output_string) 249 | 250 | logger.info("removing temp file %s", temp_file.name) 251 | 252 | return cache_path 253 | 254 | 255 | def read_set_from_file(filename): 256 | ''' 257 | Extract a de-duped collection (set) of text from a file. 258 | Expected file format is one item per line. 259 | ''' 260 | collection = set() 261 | with open(filename, 'r', encoding='utf-8') as file_: 262 | for line in file_: 263 | collection.add(line.rstrip()) 264 | return collection 265 | 266 | 267 | def get_file_extension(path, dot=True, lower=True): 268 | ext = os.path.splitext(path)[1] 269 | ext = ext if dot else ext[1:] 270 | return ext.lower() if lower else ext 271 | -------------------------------------------------------------------------------- /pytorch_pretrained_bert/optimization.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """PyTorch optimization for BERT model.""" 16 | 17 | import math 18 | import torch 19 | from torch.optim import Optimizer 20 | from torch.optim.optimizer import required 21 | from torch.nn.utils import clip_grad_norm_ 22 | import logging 23 | import abc 24 | import sys 25 | 26 | logger = logging.getLogger(__name__) 27 | 28 | 29 | if sys.version_info >= (3, 4): 30 | ABC = abc.ABC 31 | else: 32 | ABC = abc.ABCMeta('ABC', (), {}) 33 | 34 | 35 | class _LRSchedule(ABC): 36 | """ Parent of all LRSchedules here. """ 37 | warn_t_total = False # is set to True for schedules where progressing beyond t_total steps doesn't make sense 38 | def __init__(self, warmup=0.002, t_total=-1, **kw): 39 | """ 40 | :param warmup: what fraction of t_total steps will be used for linear warmup 41 | :param t_total: how many training steps (updates) are planned 42 | :param kw: 43 | """ 44 | super(_LRSchedule, self).__init__(**kw) 45 | if t_total < 0: 46 | logger.warning("t_total value of {} results in schedule not being applied".format(t_total)) 47 | if not 0.0 <= warmup < 1.0 and not warmup == -1: 48 | raise ValueError("Invalid warmup: {} - should be in [0.0, 1.0[ or -1".format(warmup)) 49 | warmup = max(warmup, 0.) 50 | self.warmup, self.t_total = float(warmup), float(t_total) 51 | self.warned_for_t_total_at_progress = -1 52 | 53 | def get_lr(self, step, nowarn=False): 54 | """ 55 | :param step: which of t_total steps we're on 56 | :param nowarn: set to True to suppress warning regarding training beyond specified 't_total' steps 57 | :return: learning rate multiplier for current update 58 | """ 59 | if self.t_total < 0: 60 | return 1. 61 | progress = float(step) / self.t_total 62 | ret = self.get_lr_(progress) 63 | # warning for exceeding t_total (only active with warmup_linear 64 | if not nowarn and self.warn_t_total and progress > 1. and progress > self.warned_for_t_total_at_progress: 65 | logger.warning( 66 | "Training beyond specified 't_total'. Learning rate multiplier set to {}. Please set 't_total' of {} correctly." 67 | .format(ret, self.__class__.__name__)) 68 | self.warned_for_t_total_at_progress = progress 69 | # end warning 70 | return ret 71 | 72 | @abc.abstractmethod 73 | def get_lr_(self, progress): 74 | """ 75 | :param progress: value between 0 and 1 (unless going beyond t_total steps) specifying training progress 76 | :return: learning rate multiplier for current update 77 | """ 78 | return 1. 79 | 80 | 81 | class ConstantLR(_LRSchedule): 82 | def get_lr_(self, progress): 83 | return 1. 84 | 85 | 86 | class WarmupCosineSchedule(_LRSchedule): 87 | """ 88 | Linearly increases learning rate from 0 to 1 over `warmup` fraction of training steps. 89 | Decreases learning rate from 1. to 0. over remaining `1 - warmup` steps following a cosine curve. 90 | If `cycles` (default=0.5) is different from default, learning rate follows cosine function after warmup. 
91 | """ 92 | warn_t_total = True 93 | def __init__(self, warmup=0.002, t_total=-1, cycles=.5, **kw): 94 | """ 95 | :param warmup: see LRSchedule 96 | :param t_total: see LRSchedule 97 | :param cycles: number of cycles. Default: 0.5, corresponding to cosine decay from 1. at progress==warmup and 0 at progress==1. 98 | :param kw: 99 | """ 100 | super(WarmupCosineSchedule, self).__init__(warmup=warmup, t_total=t_total, **kw) 101 | self.cycles = cycles 102 | 103 | def get_lr_(self, progress): 104 | if progress < self.warmup: 105 | return progress / self.warmup 106 | else: 107 | progress = (progress - self.warmup) / (1 - self.warmup) # progress after warmup 108 | return 0.5 * (1. + math.cos(math.pi * self.cycles * 2 * progress)) 109 | 110 | 111 | class WarmupCosineWithHardRestartsSchedule(WarmupCosineSchedule): 112 | """ 113 | Linearly increases learning rate from 0 to 1 over `warmup` fraction of training steps. 114 | If `cycles` (default=1.) is different from default, learning rate follows `cycles` times a cosine decaying 115 | learning rate (with hard restarts). 116 | """ 117 | def __init__(self, warmup=0.002, t_total=-1, cycles=1., **kw): 118 | super(WarmupCosineWithHardRestartsSchedule, self).__init__(warmup=warmup, t_total=t_total, cycles=cycles, **kw) 119 | assert(cycles >= 1.) 120 | 121 | def get_lr_(self, progress): 122 | if progress < self.warmup: 123 | return progress / self.warmup 124 | else: 125 | progress = (progress - self.warmup) / (1 - self.warmup) # progress after warmup 126 | ret = 0.5 * (1. + math.cos(math.pi * ((self.cycles * progress) % 1))) 127 | return ret 128 | 129 | 130 | class WarmupCosineWithWarmupRestartsSchedule(WarmupCosineWithHardRestartsSchedule): 131 | """ 132 | All training progress is divided in `cycles` (default=1.) parts of equal length. 133 | Every part follows a schedule with the first `warmup` fraction of the training steps linearly increasing from 0. to 1., 134 | followed by a learning rate decreasing from 1. to 0. following a cosine curve. 135 | """ 136 | def __init__(self, warmup=0.002, t_total=-1, cycles=1., **kw): 137 | assert(warmup * cycles < 1.) 138 | warmup = warmup * cycles if warmup >= 0 else warmup 139 | super(WarmupCosineWithWarmupRestartsSchedule, self).__init__(warmup=warmup, t_total=t_total, cycles=cycles, **kw) 140 | 141 | def get_lr_(self, progress): 142 | progress = progress * self.cycles % 1. 143 | if progress < self.warmup: 144 | return progress / self.warmup 145 | else: 146 | progress = (progress - self.warmup) / (1 - self.warmup) # progress after warmup 147 | ret = 0.5 * (1. + math.cos(math.pi * progress)) 148 | return ret 149 | 150 | 151 | class WarmupConstantSchedule(_LRSchedule): 152 | """ 153 | Linearly increases learning rate from 0 to 1 over `warmup` fraction of training steps. 154 | Keeps learning rate equal to 1. after warmup. 155 | """ 156 | def get_lr_(self, progress): 157 | if progress < self.warmup: 158 | return progress / self.warmup 159 | return 1. 160 | 161 | 162 | class WarmupLinearSchedule(_LRSchedule): 163 | """ 164 | Linearly increases learning rate from 0 to 1 over `warmup` fraction of training steps. 165 | Linearly decreases learning rate from 1. to 0. over remaining `1 - warmup` steps. 166 | """ 167 | warn_t_total = True 168 | def get_lr_(self, progress): 169 | if progress < self.warmup: 170 | return progress / self.warmup 171 | return max((progress - 1.) / (self.warmup - 1.), 0.) 
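# Worked example for WarmupLinearSchedule with warmup=0.1: at progress 0.05 the
# multiplier is 0.05 / 0.1 = 0.5 (still warming up); at progress 0.55 it is
# (0.55 - 1) / (0.1 - 1) = 0.5 (linear decay); at progress >= 1.0 it is clamped to 0.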
172 | 173 | 174 | SCHEDULES = { 175 | None: ConstantLR, 176 | "none": ConstantLR, 177 | "warmup_cosine": WarmupCosineSchedule, 178 | "warmup_constant": WarmupConstantSchedule, 179 | "warmup_linear": WarmupLinearSchedule 180 | } 181 | 182 | 183 | class BertAdam(Optimizer): 184 | """Implements BERT version of Adam algorithm with weight decay fix. 185 | Params: 186 | lr: learning rate 187 | warmup: portion of t_total for the warmup, -1 means no warmup. Default: -1 188 | t_total: total number of training steps for the learning 189 | rate schedule, -1 means constant learning rate of 1. (no warmup regardless of warmup setting). Default: -1 190 | schedule: schedule to use for the warmup (see above). 191 | Can be `'warmup_linear'`, `'warmup_constant'`, `'warmup_cosine'`, `'none'`, `None` or a `_LRSchedule` object (see below). 192 | If `None` or `'none'`, learning rate is always kept constant. 193 | Default : `'warmup_linear'` 194 | b1: Adams b1. Default: 0.9 195 | b2: Adams b2. Default: 0.999 196 | e: Adams epsilon. Default: 1e-6 197 | weight_decay: Weight decay. Default: 0.01 198 | max_grad_norm: Maximum norm for the gradients (-1 means no clipping). Default: 1.0 199 | """ 200 | def __init__(self, params, lr=required, warmup=-1, t_total=-1, schedule='warmup_linear', 201 | b1=0.9, b2=0.999, e=1e-6, weight_decay=0.01, max_grad_norm=1.0, **kwargs): 202 | if lr is not required and lr < 0.0: 203 | raise ValueError("Invalid learning rate: {} - should be >= 0.0".format(lr)) 204 | if not isinstance(schedule, _LRSchedule) and schedule not in SCHEDULES: 205 | raise ValueError("Invalid schedule parameter: {}".format(schedule)) 206 | if not 0.0 <= b1 < 1.0: 207 | raise ValueError("Invalid b1 parameter: {} - should be in [0.0, 1.0[".format(b1)) 208 | if not 0.0 <= b2 < 1.0: 209 | raise ValueError("Invalid b2 parameter: {} - should be in [0.0, 1.0[".format(b2)) 210 | if not e >= 0.0: 211 | raise ValueError("Invalid epsilon value: {} - should be >= 0.0".format(e)) 212 | # initialize schedule object 213 | if not isinstance(schedule, _LRSchedule): 214 | schedule_type = SCHEDULES[schedule] 215 | schedule = schedule_type(warmup=warmup, t_total=t_total) 216 | else: 217 | if warmup != -1 or t_total != -1: 218 | logger.warning("warmup and t_total on the optimizer are ineffective when _LRSchedule object is provided as schedule. " 219 | "Please specify custom warmup and t_total in _LRSchedule object.") 220 | defaults = dict(lr=lr, schedule=schedule, 221 | b1=b1, b2=b2, e=e, weight_decay=weight_decay, 222 | max_grad_norm=max_grad_norm) 223 | super(BertAdam, self).__init__(params, defaults) 224 | 225 | def get_lr(self): 226 | lr = [] 227 | for group in self.param_groups: 228 | for p in group['params']: 229 | state = self.state[p] 230 | if len(state) == 0: 231 | return [0] 232 | lr_scheduled = group['lr'] 233 | lr_scheduled *= group['schedule'].get_lr(state['step']) 234 | lr.append(lr_scheduled) 235 | return lr 236 | 237 | def step(self, closure=None): 238 | """Performs a single optimization step. 239 | 240 | Arguments: 241 | closure (callable, optional): A closure that reevaluates the model 242 | and returns the loss. 
243 | """ 244 | loss = None 245 | if closure is not None: 246 | loss = closure() 247 | 248 | for group in self.param_groups: 249 | for p in group['params']: 250 | if p.grad is None: 251 | continue 252 | grad = p.grad.data 253 | if grad.is_sparse: 254 | raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead') 255 | 256 | state = self.state[p] 257 | 258 | # State initialization 259 | if len(state) == 0: 260 | state['step'] = 0 261 | # Exponential moving average of gradient values 262 | state['next_m'] = torch.zeros_like(p.data) 263 | # Exponential moving average of squared gradient values 264 | state['next_v'] = torch.zeros_like(p.data) 265 | 266 | next_m, next_v = state['next_m'], state['next_v'] 267 | beta1, beta2 = group['b1'], group['b2'] 268 | 269 | # Add grad clipping 270 | if group['max_grad_norm'] > 0: 271 | clip_grad_norm_(p, group['max_grad_norm']) 272 | 273 | # Decay the first and second moment running average coefficient 274 | # In-place operations to update the averages at the same time 275 | next_m.mul_(beta1).add_(1 - beta1, grad) 276 | next_v.mul_(beta2).addcmul_(1 - beta2, grad, grad) 277 | update = next_m / (next_v.sqrt() + group['e']) 278 | 279 | # Just adding the square of the weights to the loss function is *not* 280 | # the correct way of using L2 regularization/weight decay with Adam, 281 | # since that will interact with the m and v parameters in strange ways. 282 | # 283 | # Instead we want to decay the weights in a manner that doesn't interact 284 | # with the m/v parameters. This is equivalent to adding the square 285 | # of the weights to the loss with plain (non-momentum) SGD. 286 | if group['weight_decay'] > 0.0: 287 | update += group['weight_decay'] * p.data 288 | 289 | lr_scheduled = group['lr'] 290 | lr_scheduled *= group['schedule'].get_lr(state['step']) 291 | 292 | update_with_lr = lr_scheduled * update 293 | p.data.add_(-update_with_lr) 294 | 295 | state['step'] += 1 296 | 297 | # step_size = lr_scheduled * math.sqrt(bias_correction2) / bias_correction1 298 | # No bias correction 299 | # bias_correction1 = 1 - beta1 ** state['step'] 300 | # bias_correction2 = 1 - beta2 ** state['step'] 301 | 302 | return loss 303 | -------------------------------------------------------------------------------- /pytorch_pretrained_bert/tokenization.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | """Tokenization classes.""" 16 | 17 | from __future__ import absolute_import, division, print_function, unicode_literals 18 | 19 | import collections 20 | import logging 21 | import os 22 | import unicodedata 23 | from io import open 24 | 25 | from .file_utils import cached_path 26 | 27 | logger = logging.getLogger(__name__) 28 | 29 | PRETRAINED_VOCAB_ARCHIVE_MAP = { 30 | # 'bert-base-uncased': "E:\CS\Bert-pretrain-model/bert-base-uncased-vocab.txt", 31 | 'bert-base-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-vocab.txt", 32 | 'bert-large-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-vocab.txt", 33 | 'bert-base-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-vocab.txt", 34 | 'bert-large-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-vocab.txt", 35 | 'bert-base-multilingual-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-uncased-vocab.txt", 36 | 'bert-base-multilingual-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-cased-vocab.txt", 37 | # 'bert-base-chinese': "E:\CS\Bert-pretrain-model/bert-base-chinese-vocab.txt", 38 | 'bert-base-chinese': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese-vocab.txt", 39 | } 40 | PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP = { 41 | 'bert-base-uncased': 512, 42 | 'bert-large-uncased': 512, 43 | 'bert-base-cased': 512, 44 | 'bert-large-cased': 512, 45 | 'bert-base-multilingual-uncased': 512, 46 | 'bert-base-multilingual-cased': 512, 47 | 'bert-base-chinese': 512, 48 | } 49 | VOCAB_NAME = 'vocab.txt' 50 | 51 | 52 | def load_vocab(vocab_file): 53 | """Loads a vocabulary file into a dictionary.""" 54 | vocab = collections.OrderedDict() 55 | index = 0 56 | with open(vocab_file, "r", encoding="utf-8") as reader: 57 | while True: 58 | token = reader.readline() 59 | if not token: 60 | break 61 | token = token.strip() 62 | vocab[token] = index 63 | index += 1 64 | return vocab 65 | 66 | 67 | def whitespace_tokenize(text): 68 | """Runs basic whitespace cleaning and splitting on a piece of text.""" 69 | text = text.strip() 70 | if not text: 71 | return [] 72 | tokens = text.split() 73 | return tokens 74 | 75 | 76 | class BertTokenizer(object): 77 | """Runs end-to-end tokenization: punctuation splitting + wordpiece""" 78 | 79 | def __init__(self, vocab_file, do_lower_case=True, max_len=None, do_basic_tokenize=True, 80 | never_split=("[UNK]", "[SEP]", "[PAD]", "[CLS]", "[MASK]")): 81 | """Constructs a BertTokenizer. 82 | 83 | Args: 84 | vocab_file: Path to a one-wordpiece-per-line vocabulary file 85 | do_lower_case: Whether to lower case the input 86 | Only has an effect when do_wordpiece_only=False 87 | do_basic_tokenize: Whether to do basic tokenization before wordpiece. 88 | max_len: An artificial maximum length to truncate tokenized sequences to; 89 | Effective maximum length is always the minimum of this 90 | value (if specified) and the underlying BERT model's 91 | sequence length. 92 | never_split: List of tokens which will never be split during tokenization. 93 | Only has an effect when do_wordpiece_only=False 94 | """ 95 | if not os.path.isfile(vocab_file): 96 | raise ValueError( 97 | "Can't find a vocabulary file at path '{}'. 
To load the vocabulary from a Google pretrained " 98 | "model use `tokenizer = BertTokenizer.from_pretrained(PRETRAINED_MODEL_NAME)`".format(vocab_file)) 99 | self.vocab = load_vocab(vocab_file) 100 | self.ids_to_tokens = collections.OrderedDict( 101 | [(ids, tok) for tok, ids in self.vocab.items()]) 102 | self.do_basic_tokenize = do_basic_tokenize 103 | if do_basic_tokenize: 104 | self.basic_tokenizer = BasicTokenizer(do_lower_case=do_lower_case, 105 | never_split=never_split) 106 | self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab) 107 | self.max_len = max_len if max_len is not None else int(1e12) 108 | 109 | def tokenize(self, text): 110 | split_tokens = [] 111 | if self.do_basic_tokenize: 112 | for token in self.basic_tokenizer.tokenize(text): 113 | for sub_token in self.wordpiece_tokenizer.tokenize(token): 114 | split_tokens.append(sub_token) 115 | else: 116 | split_tokens = self.wordpiece_tokenizer.tokenize(text) 117 | return split_tokens 118 | 119 | def convert_tokens_to_ids(self, tokens): 120 | """Converts a sequence of tokens into ids using the vocab.""" 121 | ids = [] 122 | for token in tokens: 123 | ids.append(self.vocab[token]) 124 | if len(ids) > self.max_len: 125 | logger.warning( 126 | "Token indices sequence length is longer than the specified maximum " 127 | " sequence length for this BERT model ({} > {}). Running this" 128 | " sequence through BERT will result in indexing errors".format(len(ids), self.max_len) 129 | ) 130 | return ids 131 | 132 | def convert_ids_to_tokens(self, ids): 133 | """Converts a sequence of ids in wordpiece tokens using the vocab.""" 134 | tokens = [] 135 | for i in ids: 136 | tokens.append(self.ids_to_tokens[i]) 137 | return tokens 138 | 139 | def save_vocabulary(self, vocab_path): 140 | """Save the tokenizer vocabulary to a directory or file.""" 141 | index = 0 142 | if os.path.isdir(vocab_path): 143 | vocab_file = os.path.join(vocab_path, VOCAB_NAME) 144 | with open(vocab_file, "w", encoding="utf-8") as writer: 145 | for token, token_index in sorted(self.vocab.items(), key=lambda kv: kv[1]): 146 | if index != token_index: 147 | logger.warning("Saving vocabulary to {}: vocabulary indices are not consecutive." 148 | " Please check that the vocabulary is not corrupted!".format(vocab_file)) 149 | index = token_index 150 | writer.write(token + u'\n') 151 | index += 1 152 | return vocab_file 153 | 154 | @classmethod 155 | def from_pretrained(cls, pretrained_model_name_or_path, cache_dir=None, *inputs, **kwargs): 156 | """ 157 | Instantiate a PreTrainedBertModel from a pre-trained model file. 158 | Download and cache the pre-trained model file if needed. 159 | """ 160 | if pretrained_model_name_or_path in PRETRAINED_VOCAB_ARCHIVE_MAP: 161 | vocab_file = PRETRAINED_VOCAB_ARCHIVE_MAP[pretrained_model_name_or_path] 162 | if '-cased' in pretrained_model_name_or_path and kwargs.get('do_lower_case', True): 163 | logger.warning("The pre-trained model you are loading is a cased model but you have not set " 164 | "`do_lower_case` to False. We are setting `do_lower_case=False` for you but " 165 | "you may want to check this behavior.") 166 | kwargs['do_lower_case'] = False 167 | elif '-cased' not in pretrained_model_name_or_path and not kwargs.get('do_lower_case', True): 168 | logger.warning("The pre-trained model you are loading is an uncased model but you have set " 169 | "`do_lower_case` to False. 
We are setting `do_lower_case=True` for you " 170 | "but you may want to check this behavior.") 171 | kwargs['do_lower_case'] = True 172 | else: 173 | vocab_file = pretrained_model_name_or_path 174 | if os.path.isdir(vocab_file): 175 | vocab_file = os.path.join(vocab_file, VOCAB_NAME) 176 | # redirect to the cache, if necessary 177 | try: 178 | resolved_vocab_file = cached_path(vocab_file, cache_dir=cache_dir) 179 | except EnvironmentError: 180 | logger.error( 181 | "Model name '{}' was not found in model name list ({}). " 182 | "We assumed '{}' was a path or url but couldn't find any file " 183 | "associated to this path or url.".format( 184 | pretrained_model_name_or_path, 185 | ', '.join(PRETRAINED_VOCAB_ARCHIVE_MAP.keys()), 186 | vocab_file)) 187 | return None 188 | if resolved_vocab_file == vocab_file: 189 | logger.info("loading vocabulary file {}".format(vocab_file)) 190 | else: 191 | logger.info("loading vocabulary file {} from cache at {}".format( 192 | vocab_file, resolved_vocab_file)) 193 | if pretrained_model_name_or_path in PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP: 194 | # if we're using a pretrained model, ensure the tokenizer wont index sequences longer 195 | # than the number of positional embeddings 196 | max_len = PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP[pretrained_model_name_or_path] 197 | kwargs['max_len'] = min(kwargs.get('max_len', int(1e12)), max_len) 198 | # Instantiate tokenizer. 199 | tokenizer = cls(resolved_vocab_file, *inputs, **kwargs) 200 | return tokenizer 201 | 202 | 203 | class BasicTokenizer(object): 204 | """Runs basic tokenization (punctuation splitting, lower casing, etc.).""" 205 | 206 | def __init__(self, 207 | do_lower_case=True, 208 | never_split=("[UNK]", "[SEP]", "[PAD]", "[CLS]", "[MASK]")): 209 | """Constructs a BasicTokenizer. 210 | 211 | Args: 212 | do_lower_case: Whether to lower case the input. 213 | """ 214 | self.do_lower_case = do_lower_case 215 | self.never_split = never_split 216 | 217 | def tokenize(self, text): 218 | """Tokenizes a piece of text.""" 219 | text = self._clean_text(text) 220 | # This was added on November 1st, 2018 for the multilingual and Chinese 221 | # models. This is also applied to the English models now, but it doesn't 222 | # matter since the English models were not trained on any Chinese data 223 | # and generally don't have any Chinese data in them (there are Chinese 224 | # characters in the vocabulary because Wikipedia does have some Chinese 225 | # words in the English Wikipedia.). 
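        # Illustrative example: _tokenize_chinese_chars turns "我like苹果" into
        # " 我 like 苹  果 ", so whitespace_tokenize below yields ["我", "like", "苹", "果"].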
226 | text = self._tokenize_chinese_chars(text) 227 | orig_tokens = whitespace_tokenize(text) 228 | split_tokens = [] 229 | for token in orig_tokens: 230 | if self.do_lower_case and token not in self.never_split: 231 | token = token.lower() 232 | token = self._run_strip_accents(token) 233 | split_tokens.extend(self._run_split_on_punc(token)) 234 | 235 | output_tokens = whitespace_tokenize(" ".join(split_tokens)) 236 | return output_tokens 237 | 238 | def _run_strip_accents(self, text): 239 | """Strips accents from a piece of text.""" 240 | text = unicodedata.normalize("NFD", text) 241 | output = [] 242 | for char in text: 243 | cat = unicodedata.category(char) 244 | if cat == "Mn": 245 | continue 246 | output.append(char) 247 | return "".join(output) 248 | 249 | def _run_split_on_punc(self, text): 250 | """Splits punctuation on a piece of text.""" 251 | if text in self.never_split: 252 | return [text] 253 | chars = list(text) 254 | i = 0 255 | start_new_word = True 256 | output = [] 257 | while i < len(chars): 258 | char = chars[i] 259 | if _is_punctuation(char): 260 | output.append([char]) 261 | start_new_word = True 262 | else: 263 | if start_new_word: 264 | output.append([]) 265 | start_new_word = False 266 | output[-1].append(char) 267 | i += 1 268 | 269 | return ["".join(x) for x in output] 270 | 271 | def _tokenize_chinese_chars(self, text): 272 | """Adds whitespace around any CJK character.""" 273 | output = [] 274 | for char in text: 275 | cp = ord(char) 276 | if self._is_chinese_char(cp): 277 | output.append(" ") 278 | output.append(char) 279 | output.append(" ") 280 | else: 281 | output.append(char) 282 | return "".join(output) 283 | 284 | def _is_chinese_char(self, cp): 285 | """Checks whether CP is the codepoint of a CJK character.""" 286 | # This defines a "chinese character" as anything in the CJK Unicode block: 287 | # https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block) 288 | # 289 | # Note that the CJK Unicode block is NOT all Japanese and Korean characters, 290 | # despite its name. The modern Korean Hangul alphabet is a different block, 291 | # as is Japanese Hiragana and Katakana. Those alphabets are used to write 292 | # space-separated words, so they are not treated specially and handled 293 | # like the all of the other languages. 294 | if ((cp >= 0x4E00 and cp <= 0x9FFF) or # 295 | (cp >= 0x3400 and cp <= 0x4DBF) or # 296 | (cp >= 0x20000 and cp <= 0x2A6DF) or # 297 | (cp >= 0x2A700 and cp <= 0x2B73F) or # 298 | (cp >= 0x2B740 and cp <= 0x2B81F) or # 299 | (cp >= 0x2B820 and cp <= 0x2CEAF) or 300 | (cp >= 0xF900 and cp <= 0xFAFF) or # 301 | (cp >= 0x2F800 and cp <= 0x2FA1F)): # 302 | return True 303 | 304 | return False 305 | 306 | def _clean_text(self, text): 307 | """Performs invalid character removal and whitespace cleanup on text.""" 308 | output = [] 309 | for char in text: 310 | cp = ord(char) 311 | if cp == 0 or cp == 0xfffd or _is_control(char): 312 | continue 313 | if _is_whitespace(char): 314 | output.append(" ") 315 | else: 316 | output.append(char) 317 | return "".join(output) 318 | 319 | 320 | class WordpieceTokenizer(object): 321 | """Runs WordPiece tokenization.""" 322 | 323 | def __init__(self, vocab, unk_token="[UNK]", max_input_chars_per_word=100): 324 | self.vocab = vocab 325 | self.unk_token = unk_token 326 | self.max_input_chars_per_word = max_input_chars_per_word 327 | 328 | def tokenize(self, text): 329 | """Tokenizes a piece of text into its word pieces. 
330 | 331 | This uses a greedy longest-match-first algorithm to perform tokenization 332 | using the given vocabulary. 333 | 334 | For example: 335 | input = "unaffable" 336 | output = ["un", "##aff", "##able"] 337 | 338 | Args: 339 | text: A single token or whitespace separated tokens. This should have 340 | already been passed through `BasicTokenizer`. 341 | 342 | Returns: 343 | A list of wordpiece tokens. 344 | """ 345 | 346 | output_tokens = [] 347 | for token in whitespace_tokenize(text): 348 | chars = list(token) 349 | if len(chars) > self.max_input_chars_per_word: 350 | output_tokens.append(self.unk_token) 351 | continue 352 | 353 | is_bad = False 354 | start = 0 355 | sub_tokens = [] 356 | while start < len(chars): 357 | end = len(chars) 358 | cur_substr = None 359 | while start < end: 360 | substr = "".join(chars[start:end]) 361 | if start > 0: 362 | substr = "##" + substr 363 | if substr in self.vocab: 364 | cur_substr = substr 365 | break 366 | end -= 1 367 | if cur_substr is None: 368 | is_bad = True 369 | break 370 | sub_tokens.append(cur_substr) 371 | start = end 372 | 373 | if is_bad: 374 | output_tokens.append(self.unk_token) 375 | else: 376 | output_tokens.extend(sub_tokens) 377 | return output_tokens 378 | 379 | 380 | def _is_whitespace(char): 381 | """Checks whether `chars` is a whitespace character.""" 382 | # \t, \n, and \r are technically contorl characters but we treat them 383 | # as whitespace since they are generally considered as such. 384 | if char == " " or char == "\t" or char == "\n" or char == "\r": 385 | return True 386 | cat = unicodedata.category(char) 387 | if cat == "Zs": 388 | return True 389 | return False 390 | 391 | 392 | def _is_control(char): 393 | """Checks whether `chars` is a control character.""" 394 | # These are technically control characters but we count them as whitespace 395 | # characters. 396 | if char == "\t" or char == "\n" or char == "\r": 397 | return False 398 | cat = unicodedata.category(char) 399 | if cat.startswith("C"): 400 | return True 401 | return False 402 | 403 | 404 | def _is_punctuation(char): 405 | """Checks whether `chars` is a punctuation character.""" 406 | cp = ord(char) 407 | # We treat all non-letter/number ASCII as punctuation. 408 | # Characters such as "^", "$", and "`" are not in the Unicode 409 | # Punctuation class but we treat them as punctuation anyways, for 410 | # consistency. 411 | if ((cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64) or 412 | (cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126)): 413 | return True 414 | cat = unicodedata.category(char) 415 | if cat.startswith("P"): 416 | return True 417 | return False 418 | -------------------------------------------------------------------------------- /run_FE_DAtt_RNN.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. 3 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | """BERT finetuning runner.""" 17 | 18 | from __future__ import absolute_import, division, print_function 19 | 20 | import argparse 21 | import csv 22 | import logging 23 | import os 24 | import random 25 | import sys 26 | 27 | import numpy as np 28 | import torch 29 | from torch.utils.data import (DataLoader, RandomSampler, SequentialSampler, 30 | TensorDataset) 31 | from torch.utils.data.distributed import DistributedSampler 32 | from tqdm import tqdm, trange 33 | 34 | from torch.nn import CrossEntropyLoss, MSELoss 35 | from scipy.stats import pearsonr, spearmanr 36 | from sklearn.metrics import matthews_corrcoef, f1_score 37 | import pickle 38 | 39 | from pytorch_pretrained_bert.file_utils import PYTORCH_PRETRAINED_BERT_CACHE, WEIGHTS_NAME, CONFIG_NAME 40 | from pytorch_pretrained_bert.modeling import BertConfig,FE_DAtt_RNN 41 | from pytorch_pretrained_bert.tokenization import BertTokenizer 42 | from pytorch_pretrained_bert.optimization import BertAdam, WarmupLinearSchedule 43 | 44 | logger = logging.getLogger(__name__) 45 | 46 | 47 | class InputExample(object): 48 | """A single training/test example for simple sequence classification.""" 49 | 50 | def __init__(self, guid, text_a, text_b=None, label=None): 51 | """Constructs a InputExample. 52 | 53 | Args: 54 | guid: Unique id for the example. 55 | text_a: string. The untokenized text of the first sequence. For single 56 | sequence tasks, only this sequence must be specified. 57 | text_b: (Optional) string. The untokenized text of the second sequence. 58 | Only must be specified for sequence pair tasks. 59 | label: (Optional) string. The label of the example. This should be 60 | specified for train and dev examples, but not for test examples. 
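
        Example (illustrative; the strings are invented, but the field layout
        mirrors how the processors below build examples -- `text_a` holds the
        candidate response and `text_b` the list of context utterances):
            InputExample(guid="train-0",
                         text_a="try restarting network-manager",
                         text_b=["my wifi keeps dropping", "which driver do you use ?"],
                         label="1")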
61 | """ 62 | self.guid = guid 63 | self.text_a = text_a 64 | self.text_b = text_b 65 | self.label = label 66 | 67 | 68 | class InputFeatures(object): 69 | """A single set of features of data.""" 70 | 71 | def __init__(self, input_ids, input_mask, segment_ids, label_id): 72 | self.input_ids = input_ids 73 | self.input_mask = input_mask 74 | self.segment_ids = segment_ids 75 | self.label_id = label_id 76 | 77 | import re 78 | class DataProcessor(object): 79 | """Base class for data converters for sequence classification data sets.""" 80 | 81 | def get_train_examples(self, data_dir): 82 | """Gets a collection of `InputExample`s for the train set.""" 83 | raise NotImplementedError() 84 | 85 | def get_dev_examples(self, data_dir): 86 | """Gets a collection of `InputExample`s for the dev set.""" 87 | raise NotImplementedError() 88 | 89 | def get_labels(self): 90 | """Gets the list of labels for this data set.""" 91 | raise NotImplementedError() 92 | 93 | # @classmethod 94 | # def _read_tsv(cls, input_file, quotechar=None): 95 | # """Reads a tab separated value file.""" 96 | # with open(input_file, "r", encoding="utf-8") as f: 97 | # reader = csv.reader(f, delimiter="\t", quotechar=quotechar) 98 | # lines = [] 99 | # for line in reader: 100 | # if sys.version_info[0] == 2: 101 | # line = list(unicode(cell, 'utf-8') for cell in line) 102 | # lines.append(line) 103 | # return lines 104 | 105 | @classmethod 106 | def _read_data(cls, input_file): 107 | """Reads a tab separated value file.""" 108 | with open(input_file, "r", encoding="utf-8") as f: 109 | lines = [] 110 | for i,line in enumerate(f): 111 | # if i >= 1000: 112 | # break 113 | line = re.compile('[\\x00-\\x08\\x0b-\\x0c\\x0e-\\x1f\\x7f]').sub(' ', line).strip() 114 | line = line.strip().replace("_", "") 115 | parts = line.strip().split("\t") 116 | lable = parts[0] 117 | message = "" 118 | for i in range(1, len(parts) - 1, 1): 119 | part = parts[i].strip() 120 | if len(part)>0: 121 | message += part 122 | message += " [SEP] " 123 | response = parts[-1] 124 | data = {"y": lable, "m": message, "r": response} 125 | lines.append(data) 126 | return lines 127 | 128 | def _read_douban_data(cls, input_file): 129 | """Reads a tab separated value file.""" 130 | with open(input_file, "r", encoding="utf-8") as f: 131 | lines = [] 132 | label_list = [] 133 | message_list = [] 134 | response_list = [] 135 | label_any_1 = 0 136 | for ids,line in enumerate(f): 137 | # if ids >= 100: 138 | # break 139 | line = re.compile('[\\x00-\\x08\\x0b-\\x0c\\x0e-\\x1f\\x7f]').sub(' ', line).strip() 140 | line = line.strip().replace("_", "") 141 | parts = line.strip().split("\t") 142 | lable = parts[0] 143 | message = "" 144 | for i in range(1, len(parts) - 1, 1): 145 | part = parts[i].strip() 146 | if len(part) > 0: 147 | message += part 148 | message += " [SEP] " 149 | response = parts[-1] 150 | if lable == '1': 151 | label_any_1 = 1 152 | label_list.append(lable) 153 | message_list.append(message) 154 | response_list.append(response) 155 | if ids % 10 == 9: 156 | if label_any_1 == 1: 157 | for lable,message,response in zip(label_list,message_list,response_list): 158 | data = {"y": lable, "m": message, "r": response} 159 | lines.append(data) 160 | label_any_1 = 0 161 | label_list = [] 162 | message_list = [] 163 | response_list = [] 164 | return lines 165 | 166 | class UbuntuProcessor(DataProcessor): 167 | """Processor for the Ubuntu data set.""" 168 | 169 | def get_train_examples(self, data_dir): 170 | """See base class.""" 171 | logger.info("LOOKING AT 
{}".format(os.path.join(data_dir, "train.txt"))) 172 | return self._create_examples( 173 | self._read_data(os.path.join(data_dir, "train.txt")), "train") 174 | 175 | def get_dev_examples(self, data_dir): 176 | """See base class.""" 177 | return self._create_examples( 178 | self._read_data(os.path.join(data_dir, "valid.txt")), "dev") 179 | 180 | def get_test_examples(self, data_dir): 181 | """See base class.""" 182 | return self._create_examples( 183 | self._read_data(os.path.join(data_dir, "test.txt")), "test") 184 | 185 | def get_labels(self): 186 | """See base class.""" 187 | return ["0", "1"] 188 | 189 | def _create_examples(self, lines, set_type): 190 | """Creates examples for the training and dev sets.""" 191 | examples = [] 192 | for (i, line) in enumerate(lines): 193 | guid = "%s-%s" % (set_type, i) 194 | text_a = line["r"] 195 | text_b = line["m"].strip().split("[SEP]") 196 | text_b = [text.strip() for text in text_b if len(text.strip())>0] 197 | label = line["y"] 198 | examples.append( 199 | InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label)) 200 | return examples 201 | 202 | class DoubanProcessor(DataProcessor): 203 | """Processor for the MRPC data set (GLUE version).""" 204 | 205 | def get_train_examples(self, data_dir): 206 | """See base class.""" 207 | logger.info("LOOKING AT {}".format(os.path.join(data_dir, "train.txt"))) 208 | return self._create_examples( 209 | self._read_douban_data(os.path.join(data_dir, "train.txt")), "train") 210 | 211 | def get_dev_examples(self, data_dir): 212 | """See base class.""" 213 | return self._create_examples( 214 | self._read_douban_data(os.path.join(data_dir, "dev.txt")), "dev") 215 | 216 | def get_test_examples(self, data_dir): 217 | """See base class.""" 218 | return self._create_examples( 219 | self._read_douban_data(os.path.join(data_dir, "test.txt")), "test") 220 | 221 | def get_labels(self): 222 | """See base class.""" 223 | return ["0", "1"] 224 | 225 | def _create_examples(self, lines, set_type): 226 | """Creates examples for the training and dev sets.""" 227 | examples = [] 228 | #for (i, line) in enumerate(lines): 229 | for (i, line) in enumerate(lines): 230 | guid = "%s-%s" % (set_type, i) 231 | text_a = line["r"] 232 | text_b = line["m"].strip().split("[SEP]") 233 | text_b = [text.strip() for text in text_b if len(text.strip()) > 0] 234 | label = line["y"] 235 | examples.append( 236 | InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label)) 237 | return examples 238 | 239 | def convert_examples_to_features(examples, label_list, max_seq_length,max_utterance_num, 240 | tokenizer): 241 | """Loads a data file into a list of `InputBatch`s.""" 242 | 243 | label_map = {label : i for i, label in enumerate(label_list)} 244 | 245 | features = [] 246 | for (ex_index, example) in enumerate(examples): 247 | if ex_index % 10000 == 0: 248 | logger.info("Writing example %d of %d" % (ex_index, len(examples))) 249 | 250 | tokens_a = tokenizer.tokenize(example.text_a) 251 | max_turn_length = max_seq_length 252 | tokens_a = tokens_a[:(max_turn_length-2)] 253 | tokens_a = ["[CLS]"] + tokens_a + ["[SEP]"] 254 | 255 | tokens_b = [tokenizer.tokenize(text)[-(max_turn_length-2):] for text in example.text_b[-(max_utterance_num):]] 256 | tokens_b = [ list for list in tokens_b if len(list)>0] 257 | 258 | turns = [ ["[CLS]"]+ turn + ["[SEP]"] for turn in tokens_b ] 259 | turns = [tokens_a] + turns 260 | 261 | input_ids_turn = [tokenizer.convert_tokens_to_ids(turn) for turn in turns] 262 | input_mask_turn = [[1]*len(turn) for turn 
in turns] 263 | padding_turn = [[0] * (max_turn_length-len(turn)) for turn in turns] 264 | segment_ids_turn = [ [0]*(max_turn_length) for turn in turns] 265 | for i in range(len(input_ids_turn)): 266 | input_ids_turn[i] += padding_turn[i] 267 | input_mask_turn[i] += padding_turn[i] 268 | pad_turn_num = max_utterance_num +1 - len(input_ids_turn) 269 | for i in range(pad_turn_num): 270 | input_ids_turn.append([0]*(max_turn_length)) 271 | input_mask_turn.append([0]*(max_turn_length)) 272 | segment_ids_turn.append([0]*(max_turn_length)) 273 | 274 | assert len(input_ids_turn) == max_utterance_num +1 275 | assert len(input_mask_turn) == max_utterance_num +1 276 | assert len(segment_ids_turn) == max_utterance_num +1 277 | 278 | label_id = label_map[example.label] 279 | 280 | if ex_index < 5: 281 | logger.info("*** Example ***") 282 | logger.info("guid: %s" % (example.guid)) 283 | logger.info("label: %s (id = %d)" % (example.label, label_id)) 284 | 285 | features.append( 286 | InputFeatures(input_ids=input_ids_turn, 287 | input_mask=input_mask_turn, 288 | segment_ids=segment_ids_turn, 289 | label_id=label_id)) 290 | return features 291 | 292 | 293 | def _truncate_seq_pair(tokens_a, tokens_b, max_length): 294 | """Truncates a sequence pair in place to the maximum length.""" 295 | 296 | # This is a simple heuristic which will always truncate the longer sequence 297 | # one token at a time. This makes more sense than truncating an equal percent 298 | # of tokens from each, since if one sequence is very short then each token 299 | # that's truncated likely contains more information than a longer sequence. 300 | while True: 301 | total_length = len(tokens_a) + len(tokens_b) 302 | if total_length <= max_length: 303 | break 304 | if len(tokens_a) > len(tokens_b): 305 | tokens_a.pop() 306 | else: 307 | tokens_b.pop() 308 | 309 | def get_p_at_n_in_m(pred, n, m, ind): 310 | pos_score = pred[ind] 311 | curr = pred[ind:ind + m] 312 | curr = sorted(curr, reverse=True) 313 | 314 | if len(set(curr))==1: 315 | return 0 316 | if curr[n - 1] <= pos_score: 317 | return 1 318 | return 0 319 | 320 | 321 | def evaluate(pred, label): 322 | # assert len(data) % 10 == 0 323 | 324 | p_at_1_in_2 = 0.0 325 | p_at_1_in_10 = 0.0 326 | p_at_2_in_10 = 0.0 327 | p_at_5_in_10 = 0.0 328 | 329 | length = int(len(pred) / 10) 330 | 331 | for i in range(0, length): 332 | ind = i * 10 333 | 334 | # if label[ind] != 1: 335 | # print(i,ind) 336 | # print(label) 337 | # print(label[ind]) 338 | assert label[ind] == 1 339 | 340 | p_at_1_in_2 += get_p_at_n_in_m(pred, 1, 2, ind) 341 | p_at_1_in_10 += get_p_at_n_in_m(pred, 1, 10, ind) 342 | p_at_2_in_10 += get_p_at_n_in_m(pred, 2, 10, ind) 343 | p_at_5_in_10 += get_p_at_n_in_m(pred, 5, 10, ind) 344 | 345 | return (p_at_1_in_2 / length, p_at_1_in_10 / length, p_at_2_in_10 / length, p_at_5_in_10 / length) 346 | 347 | def simple_accuracy(preds, labels): 348 | return (preds == labels).mean() 349 | 350 | def ComputeR10(scores,labels,count = 10): 351 | total = 0 352 | correct1 = 0 353 | correct5 = 0 354 | correct2 = 0 355 | correct10 = 0 356 | #删除全0的例子 test 357 | for i in range(len(labels)): 358 | if labels[i] == 1: 359 | #print(i) 360 | total = total+1 361 | sublist = scores[i:i+count] 362 | #print(np.argmax(sublist)) 363 | if np.argmax(sublist) < 1: 364 | correct1 = correct1 + 1 365 | if np.argmax(sublist) < 2: 366 | correct2 = correct2 + 1 367 | if np.argmax(sublist) < 5: 368 | correct5 = correct5 + 1 369 | if np.argmax(sublist) < 10: 370 | correct10 = correct10 + 1 371 | # if max(sublist) == 
scores[i]: 372 | # correct = correct + 1 373 | print(correct1, correct5, correct10, total) 374 | return (float(correct1)/ total, float(correct2)/ total, float(correct5)/ total, float(correct10)/ total) 375 | 376 | def ComputeR2_1(scores,labels,count = 2): 377 | total = 0 378 | correct = 0 379 | for i in range(len(labels)): 380 | if labels[i] == 1: 381 | total = total+1 382 | sublist = scores[i:i+count] 383 | if max(sublist) == scores[i]: 384 | correct = correct + 1 385 | return (float(correct)/ total) 386 | 387 | def mean_average_precision(sort_data): 388 | # to do 389 | count_1 = 0 390 | sum_precision = 0 391 | for index in range(len(sort_data)): 392 | if sort_data[index][1] == 1: 393 | count_1 += 1 394 | sum_precision += 1.0 * count_1 / (index + 1) 395 | return sum_precision / count_1 396 | 397 | 398 | def mean_reciprocal_rank(sort_data): 399 | sort_lable = [s_d[1] for s_d in sort_data] 400 | assert 1 in sort_lable 401 | return 1.0 / (1 + sort_lable.index(1)) 402 | 403 | 404 | def precision_at_position_1(sort_data): 405 | if sort_data[0][1] == 1: 406 | return 1 407 | else: 408 | return 0 409 | 410 | 411 | def recall_at_position_k_in_10(sort_data, k): 412 | sort_lable = [s_d[1] for s_d in sort_data] 413 | select_lable = sort_lable[:k] 414 | return 1.0 * select_lable.count(1) / sort_lable.count(1) 415 | 416 | def evaluation_one_session(data): 417 | sort_data = sorted(data, key=lambda x: x[0], reverse=True) 418 | m_a_p = mean_average_precision(sort_data) 419 | m_r_r = mean_reciprocal_rank(sort_data) 420 | p_1 = precision_at_position_1(sort_data) 421 | r_1 = recall_at_position_k_in_10(sort_data, 1) 422 | r_2 = recall_at_position_k_in_10(sort_data, 2) 423 | r_5 = recall_at_position_k_in_10(sort_data, 5) 424 | return m_a_p, m_r_r, p_1, r_1, r_2, r_5 425 | 426 | def evaluate_douban(pred, label): 427 | sum_m_a_p = 0 428 | sum_m_r_r = 0 429 | sum_p_1 = 0 430 | sum_r_1 = 0 431 | sum_r_2 = 0 432 | sum_r_5 = 0 433 | 434 | total_num = 0 435 | data = [] 436 | #print(label) 437 | for i in range(0, len(label)): 438 | if i % 10 == 0: 439 | data = [] 440 | data.append((float(pred[i]), int(label[i]))) 441 | if i % 10 == 9: 442 | total_num += 1 443 | m_a_p, m_r_r, p_1, r_1, r_2, r_5 = evaluation_one_session(data) 444 | sum_m_a_p += m_a_p 445 | sum_m_r_r += m_r_r 446 | sum_p_1 += p_1 447 | sum_r_1 += r_1 448 | sum_r_2 += r_2 449 | sum_r_5 += r_5 450 | 451 | return (1.0 * sum_m_a_p / total_num, 1.0 * sum_m_r_r / total_num, 1.0 * sum_p_1 / total_num, 452 | 1.0 * sum_r_1 / total_num, 1.0 * sum_r_2 / total_num, 1.0 * sum_r_5 / total_num) 453 | 454 | 455 | def compute_metrics(task_name, preds, labels): 456 | assert len(preds) == len(labels) 457 | preds_logits = preds[:, 1] # 预测为1的概率 458 | 459 | if task_name == "ubuntu" or task_name == "ecd": 460 | return {"recall@2 recall@10(1,2,5)": evaluate(preds_logits, labels)} 461 | elif task_name == "douban": 462 | return {"MAP MRR P@1 recall@10(1,2,5)": evaluate_douban(preds_logits, labels)} 463 | else: 464 | raise KeyError(task_name) 465 | 466 | 467 | def main(): 468 | parser = argparse.ArgumentParser() 469 | 470 | ## Required parameters 471 | parser.add_argument("--data_dir", 472 | default="C:/Users/xzhzhang/Desktop/project/multiturn/data/ubuntu_data", 473 | type=str, 474 | help="The input data dir. 
Should contain the .tsv files (or other data files) for the task.") 475 | parser.add_argument("--bert_model", default="C:/Users/xzhzhang/Desktop/project/google-tuned-bert/LARGE-BERT-UNCASED", type=str, 476 | help="Bert pre-trained model selected in the list: bert-base-uncased, " 477 | "bert-large-uncased, bert-base-cased, bert-large-cased, bert-base-multilingual-uncased, " 478 | "bert-base-multilingual-cased, bert-base-chinese.") 479 | parser.add_argument("--task_name", 480 | default="ubuntu", 481 | type=str, 482 | help="The name of the task to train.") 483 | parser.add_argument("--output_dir", 484 | default="output_ubuntu", 485 | type=str, 486 | help="The output directory where the model predictions and checkpoints will be written.") 487 | 488 | ## Other parameters 489 | parser.add_argument("--cache_dir", 490 | default="", 491 | type=str, 492 | help="Where do you want to store the pre-trained models downloaded from s3") 493 | parser.add_argument("--max_seq_length", 494 | default=128, 495 | type=int, 496 | help="The maximum total input sequence length after WordPiece tokenization. \n" 497 | "Sequences longer than this will be truncated, and sequences shorter \n" 498 | "than this will be padded.") 499 | parser.add_argument("--max_utterance_num", 500 | default=10, 501 | type=int, 502 | help="The maximum total utterance number.") 503 | parser.add_argument("--cache_flag", 504 | default="separate_encode", 505 | type=str, 506 | help="The output directory where the model predictions and checkpoints will be written.") 507 | parser.add_argument("--do_train", 508 | action='store_true', 509 | help="Whether to run training.") 510 | parser.add_argument("--do_eval", 511 | action='store_true', 512 | help="Whether to run eval on the dev set.") 513 | parser.add_argument("--do_lower_case", 514 | action='store_true', 515 | help="Set this flag if you are using an uncased model.") 516 | parser.add_argument("--train_batch_size", 517 | default=32, 518 | type=int, 519 | help="Total batch size for training.") 520 | parser.add_argument("--eval_batch_size", 521 | default=8, 522 | type=int, 523 | help="Total batch size for eval.") 524 | parser.add_argument("--learning_rate", 525 | default=5e-5, 526 | type=float, 527 | help="The initial learning rate for Adam.") 528 | parser.add_argument("--num_train_epochs", 529 | default=3.0, 530 | type=float, 531 | help="Total number of training epochs to perform.") 532 | parser.add_argument("--warmup_proportion", 533 | default=0.1, 534 | type=float, 535 | help="Proportion of training to perform linear learning rate warmup for. " 536 | "E.g., 0.1 = 10%% of training.") 537 | parser.add_argument("--no_cuda", 538 | action='store_true', 539 | help="Whether not to use CUDA when available") 540 | parser.add_argument("--local_rank", 541 | type=int, 542 | default=-1, 543 | help="local_rank for distributed training on gpus") 544 | parser.add_argument('--seed', 545 | type=int, 546 | default=42, 547 | help="random seed for initialization") 548 | parser.add_argument('--gradient_accumulation_steps', 549 | type=int, 550 | default=1, 551 | help="Number of updates steps to accumulate before performing a backward/update pass.") 552 | parser.add_argument('--fp16', 553 | action='store_true', 554 | help="Whether to use 16-bit float precision instead of 32-bit") 555 | parser.add_argument('--loss_scale', 556 | type=float, default=0, 557 | help="Loss scaling to improve fp16 numeric stability. 
Only used when fp16 set to True.\n" 558 | "0 (default value): dynamic loss scaling.\n" 559 | "Positive power of 2: static loss scaling value.\n") 560 | parser.add_argument('--server_ip', type=str, default='', help="Can be used for distant debugging.") 561 | parser.add_argument('--server_port', type=str, default='', help="Can be used for distant debugging.") 562 | args = parser.parse_args() 563 | 564 | if args.server_ip and args.server_port: 565 | # Distant debugging - see https://code.visualstudio.com/docs/python/debugging#_attach-to-a-local-script 566 | import ptvsd 567 | print("Waiting for debugger attach") 568 | ptvsd.enable_attach(address=(args.server_ip, args.server_port), redirect_output=True) 569 | ptvsd.wait_for_attach() 570 | 571 | processors = { 572 | "ubuntu": UbuntuProcessor, 573 | "douban": DoubanProcessor, 574 | "ecd": UbuntuProcessor, 575 | } 576 | 577 | if args.local_rank == -1 or args.no_cuda: 578 | device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu") 579 | n_gpu = torch.cuda.device_count() 580 | else: 581 | torch.cuda.set_device(args.local_rank) 582 | device = torch.device("cuda", args.local_rank) 583 | n_gpu = 1 584 | # Initializes the distributed backend which will take care of sychronizing nodes/GPUs 585 | torch.distributed.init_process_group(backend='nccl') 586 | 587 | logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s', 588 | datefmt = '%m/%d/%Y %H:%M:%S', 589 | level = logging.INFO if args.local_rank in [-1, 0] else logging.WARN) 590 | 591 | logger.info("device: {} n_gpu: {}, distributed training: {}, 16-bits training: {}".format( 592 | device, n_gpu, bool(args.local_rank != -1), args.fp16)) 593 | 594 | if args.gradient_accumulation_steps < 1: 595 | raise ValueError("Invalid gradient_accumulation_steps parameter: {}, should be >= 1".format( 596 | args.gradient_accumulation_steps)) 597 | 598 | args.train_batch_size = args.train_batch_size // args.gradient_accumulation_steps 599 | 600 | random.seed(args.seed) 601 | np.random.seed(args.seed) 602 | torch.manual_seed(args.seed) 603 | if n_gpu > 0: 604 | torch.cuda.manual_seed_all(args.seed) 605 | 606 | if not args.do_train and not args.do_eval: 607 | raise ValueError("At least one of `do_train` or `do_eval` must be True.") 608 | 609 | if os.path.exists(args.output_dir) and os.listdir(args.output_dir) and args.do_train: 610 | raise ValueError("Output directory ({}) already exists and is not empty.".format(args.output_dir)) 611 | if not os.path.exists(args.output_dir): 612 | os.makedirs(args.output_dir) 613 | 614 | task_name = args.task_name.lower() 615 | 616 | if task_name not in processors: 617 | raise ValueError("Task not found: %s" % (task_name)) 618 | 619 | processor = processors[task_name]() 620 | 621 | label_list = processor.get_labels() 622 | num_labels = len(label_list) 623 | 624 | tokenizer = BertTokenizer.from_pretrained(args.bert_model, do_lower_case=args.do_lower_case) 625 | 626 | train_examples = None 627 | num_train_optimization_steps = None 628 | if args.do_train: 629 | train_examples = processor.get_train_examples(args.data_dir) 630 | num_train_optimization_steps = int( 631 | len(train_examples) / args.train_batch_size / args.gradient_accumulation_steps) * args.num_train_epochs 632 | if args.local_rank != -1: 633 | num_train_optimization_steps = num_train_optimization_steps // torch.distributed.get_world_size() 634 | 635 | # Prepare model 636 | cache_dir = args.cache_dir if args.cache_dir else 
os.path.join(str(PYTORCH_PRETRAINED_BERT_CACHE), 637 | 'distributed_{}'.format(args.local_rank)) 638 | 639 | model = FE_DAtt_RNN.from_pretrained(args.bert_model, 640 | cache_dir=cache_dir, 641 | num_labels=num_labels) 642 | if args.fp16: 643 | model.half() 644 | model.to(device) 645 | if args.local_rank != -1: 646 | try: 647 | from apex.parallel import DistributedDataParallel as DDP 648 | except ImportError: 649 | raise ImportError( 650 | "Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training.") 651 | 652 | model = DDP(model) 653 | elif n_gpu > 1: 654 | model = torch.nn.DataParallel(model) 655 | 656 | # Prepare optimizer 657 | if args.do_train: 658 | param_optimizer = list(model.named_parameters()) 659 | no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight'] 660 | optimizer_grouped_parameters = [ 661 | {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01}, 662 | {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0} 663 | ] 664 | if args.fp16: 665 | try: 666 | from apex.optimizers import FP16_Optimizer 667 | from apex.optimizers import FusedAdam 668 | except ImportError: 669 | raise ImportError( 670 | "Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training.") 671 | 672 | optimizer = FusedAdam(optimizer_grouped_parameters, 673 | lr=args.learning_rate, 674 | bias_correction=False, 675 | max_grad_norm=1.0) 676 | if args.loss_scale == 0: 677 | optimizer = FP16_Optimizer(optimizer, dynamic_loss_scale=True) 678 | else: 679 | optimizer = FP16_Optimizer(optimizer, static_loss_scale=args.loss_scale) 680 | warmup_linear = WarmupLinearSchedule(warmup=args.warmup_proportion, 681 | t_total=num_train_optimization_steps) 682 | 683 | else: 684 | optimizer = BertAdam(optimizer_grouped_parameters, 685 | lr=args.learning_rate, 686 | warmup=args.warmup_proportion, 687 | t_total=num_train_optimization_steps) 688 | 689 | global_step = 0 690 | nb_tr_steps = 0 691 | tr_loss = 0 692 | if args.do_train: 693 | cached_train_features_file = args.data_dir + '_{0}_{1}_{2}_{3}_{4}_{5}'.format( 694 | list(filter(None, args.bert_model.split('/'))).pop(), "train", str(args.task_name), 695 | str(args.max_seq_length), 696 | str(args.max_utterance_num), str(args.cache_flag)) 697 | train_features = None 698 | try: 699 | with open(cached_train_features_file, "rb") as reader: 700 | train_features = pickle.load(reader) 701 | except: 702 | train_features = convert_examples_to_features( 703 | train_examples, label_list, args.max_seq_length, args.max_utterance_num,tokenizer) 704 | if args.local_rank == -1 or torch.distributed.get_rank() == 0: 705 | logger.info(" Saving train features into cached file %s", cached_train_features_file) 706 | with open(cached_train_features_file, "wb") as writer: 707 | pickle.dump(train_features, writer) 708 | 709 | logger.info("***** Running training *****") 710 | logger.info(" Num examples = %d", len(train_examples)) 711 | logger.info(" Batch size = %d", args.train_batch_size) 712 | logger.info(" Num steps = %d", num_train_optimization_steps) 713 | all_input_ids = torch.tensor([f.input_ids for f in train_features], dtype=torch.long) 714 | all_input_mask = torch.tensor([f.input_mask for f in train_features], dtype=torch.long) 715 | all_segment_ids = torch.tensor([f.segment_ids for f in train_features], dtype=torch.long) 716 | all_label_ids = torch.tensor([f.label_id for f in train_features], dtype=torch.long) 717 | 718 | 
train_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_label_ids) 719 | if args.local_rank == -1: 720 | train_sampler = RandomSampler(train_data) 721 | else: 722 | train_sampler = DistributedSampler(train_data) 723 | train_dataloader = DataLoader(train_data, sampler=train_sampler, batch_size=args.train_batch_size) 724 | 725 | eval_examples = processor.get_dev_examples(args.data_dir) 726 | cached_train_features_file = args.data_dir + '_{0}_{1}_{2}_{3}_{4}_{5}'.format( 727 | list(filter(None, args.bert_model.split('/'))).pop(), "valid", str(args.task_name), 728 | str(args.max_seq_length), 729 | str(args.max_utterance_num), str(args.cache_flag)) 730 | eval_features = None 731 | try: 732 | with open(cached_train_features_file, "rb") as reader: 733 | eval_features = pickle.load(reader) 734 | except: 735 | eval_features = convert_examples_to_features( 736 | eval_examples, label_list, args.max_seq_length,args.max_utterance_num, tokenizer) 737 | if args.local_rank == -1 or torch.distributed.get_rank() == 0: 738 | logger.info(" Saving eval features into cached file %s", cached_train_features_file) 739 | with open(cached_train_features_file, "wb") as writer: 740 | pickle.dump(eval_features, writer) 741 | 742 | logger.info("***** Running evaluation *****") 743 | logger.info(" Num examples = %d", len(eval_examples)) 744 | logger.info(" Batch size = %d", args.eval_batch_size) 745 | all_input_ids = torch.tensor([f.input_ids for f in eval_features], dtype=torch.long) 746 | all_input_mask = torch.tensor([f.input_mask for f in eval_features], dtype=torch.long) 747 | all_segment_ids = torch.tensor([f.segment_ids for f in eval_features], dtype=torch.long) 748 | 749 | all_label_ids = torch.tensor([f.label_id for f in eval_features], dtype=torch.long) 750 | 751 | eval_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_label_ids) 752 | # Run prediction for full data 753 | eval_sampler = SequentialSampler(eval_data) 754 | eval_dataloader = DataLoader(eval_data, sampler=eval_sampler, batch_size=args.eval_batch_size) 755 | 756 | model.train() 757 | for epoch in trange(int(args.num_train_epochs), desc="Epoch"): 758 | 759 | tr_loss = 0 760 | nb_tr_examples, nb_tr_steps = 0, 0 761 | for step, batch in enumerate(tqdm(train_dataloader, desc="Iteration")): 762 | batch = tuple(t.to(device) for t in batch) 763 | input_ids, input_mask, segment_ids, label_ids= batch 764 | 765 | # define a new function to compute loss values for both output_modes 766 | logits = model(input_ids, segment_ids, input_mask, labels=None) 767 | 768 | loss_fct = CrossEntropyLoss() 769 | loss = loss_fct(logits.view(-1, num_labels), label_ids.view(-1)) 770 | 771 | if n_gpu > 1: 772 | loss = loss.mean() # mean() to average on multi-gpu. 
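            # Gradient accumulation (illustrative numbers, not the defaults): with
            # --train_batch_size 64 and --gradient_accumulation_steps 4, the batch size
            # was divided down to 16 earlier, the loss below is scaled by 1/4, and
            # optimizer.step() runs once every 4 iterations -- an effective batch of 64.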
773 | if args.gradient_accumulation_steps > 1: 774 | loss = loss / args.gradient_accumulation_steps 775 | 776 | if args.fp16: 777 | optimizer.backward(loss) 778 | else: 779 | loss.backward() 780 | 781 | tr_loss += loss.item() 782 | nb_tr_examples += input_ids.size(0) 783 | nb_tr_steps += 1 784 | if (step + 1) % args.gradient_accumulation_steps == 0: 785 | if args.fp16: 786 | # modify learning rate with special warm up BERT uses 787 | # if args.fp16 is False, BertAdam is used that handles this automatically 788 | lr_this_step = args.learning_rate * warmup_linear.get_lr( 789 | global_step / num_train_optimization_steps, 790 | args.warmup_proportion) 791 | for param_group in optimizer.param_groups: 792 | param_group['lr'] = lr_this_step 793 | optimizer.step() 794 | optimizer.zero_grad() 795 | global_step += 1 796 | 797 | 798 | # Save a trained model, configuration and tokenizer 799 | model_to_save = model.module if hasattr(model, 'module') else model # Only save the model it-self 800 | 801 | # If we save using the predefined names, we can load using `from_pretrained` 802 | output_model_file = os.path.join(args.output_dir, str(epoch) + "_" + WEIGHTS_NAME) 803 | output_config_file = os.path.join(args.output_dir, CONFIG_NAME) 804 | 805 | torch.save(model_to_save.state_dict(), output_model_file) 806 | model_to_save.config.to_json_file(output_config_file) 807 | tokenizer.save_vocabulary(args.output_dir) 808 | 809 | # Load a trained model and vocabulary that you have fine-tuned 810 | model_state_dict = torch.load(output_model_file) 811 | eval_model = FE_DAtt_RNN.from_pretrained(args.bert_model, state_dict=model_state_dict, num_labels=num_labels) 812 | # tokenizer = BertTokenizer.from_pretrained(args.output_dir, do_lower_case=args.do_lower_case) 813 | eval_model.to(device) 814 | eval_model.eval() 815 | nb_eval_steps = 0 816 | preds = [] 817 | 818 | for input_ids, input_mask, segment_ids, label_ids in tqdm(eval_dataloader, desc="Evaluating"): 819 | input_ids = input_ids.to(device) 820 | input_mask = input_mask.to(device) 821 | segment_ids = segment_ids.to(device) 822 | label_ids = label_ids.to(device) 823 | 824 | with torch.no_grad(): 825 | logits = eval_model(input_ids, segment_ids, input_mask, labels=None) 826 | 827 | 828 | nb_eval_steps += 1 829 | if len(preds) == 0: 830 | preds.append(logits.detach().cpu().numpy()) 831 | else: 832 | preds[0] = np.append( 833 | preds[0], logits.detach().cpu().numpy(), axis=0) 834 | 835 | preds = preds[0] 836 | # print(preds) 837 | result = compute_metrics(task_name, preds, all_label_ids.numpy()) 838 | 839 | result['global_step'] = global_step 840 | 841 | output_eval_file = os.path.join(args.output_dir, "eval_results.txt") 842 | with open(output_eval_file, "a") as writer: 843 | logger.info("***** Eval results *****") 844 | for key in sorted(result.keys()): 845 | logger.info(" %s = %s", key, str(result[key])) 846 | writer.write("%s = %s\n" % (key, str(result[key]))) 847 | # writer.write(preds.__str__()) 848 | # with open(output_eval_file, "a") as writer: 849 | # writer.write(preds.__str__()) 850 | 851 | else: 852 | #output_model_file = 'experiments/separateInput/1_pytorch_model.bin' 853 | output_model_file = os.path.join(args.output_dir, "pytorch_model.bin") 854 | model_state_dict = torch.load(output_model_file) 855 | model = FE_DAtt_RNN.from_pretrained(args.bert_model, state_dict=model_state_dict, 856 | num_labels=num_labels) 857 | 858 | #model.to(device) 859 | 860 | if args.do_eval and (args.local_rank == -1 or torch.distributed.get_rank() == 0): 861 | 
eval_examples = processor.get_dev_examples(args.data_dir) 862 | eval_features = convert_examples_to_features( 863 | eval_examples, label_list, args.max_seq_length, args.max_utterance_num, tokenizer) 864 | logger.info("***** Running evaluation *****") 865 | logger.info(" Num examples = %d", len(eval_examples)) 866 | logger.info(" Batch size = %d", args.eval_batch_size) 867 | all_input_ids = torch.tensor([f.input_ids for f in eval_features], dtype=torch.long) 868 | all_input_mask = torch.tensor([f.input_mask for f in eval_features], dtype=torch.long) 869 | all_segment_ids = torch.tensor([f.segment_ids for f in eval_features], dtype=torch.long) 870 | 871 | all_label_ids = torch.tensor([f.label_id for f in eval_features], dtype=torch.long) 872 | 873 | eval_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_label_ids) 874 | # Run prediction for full data 875 | eval_sampler = SequentialSampler(eval_data) 876 | eval_dataloader = DataLoader(eval_data, sampler=eval_sampler, batch_size=args.eval_batch_size) 877 | 878 | model.eval() 879 | eval_loss = 0 880 | nb_eval_steps = 0 881 | preds = [] 882 | 883 | for input_ids, input_mask, segment_ids, label_ids in tqdm(eval_dataloader, desc="Evaluating"): 884 | input_ids = input_ids.to(device) 885 | input_mask = input_mask.to(device) 886 | segment_ids = segment_ids.to(device) 887 | label_ids = label_ids.to(device) 888 | 889 | with torch.no_grad(): 890 | logits = model(input_ids, segment_ids, input_mask, labels=None) 891 | 892 | # create eval loss and other metric required by the task 893 | loss_fct = CrossEntropyLoss() 894 | tmp_eval_loss = loss_fct(logits.view(-1, num_labels), label_ids.view(-1)) 895 | 896 | eval_loss += tmp_eval_loss.mean().item() 897 | nb_eval_steps += 1 898 | if len(preds) == 0: 899 | preds.append(logits.detach().cpu().numpy()) 900 | else: 901 | preds[0] = np.append( 902 | preds[0], logits.detach().cpu().numpy(), axis=0) 903 | 904 | eval_loss = eval_loss / nb_eval_steps 905 | preds = preds[0] 906 | #print(preds) 907 | result = compute_metrics(task_name, preds, all_label_ids.numpy()) 908 | loss = tr_loss/nb_tr_steps if args.do_train else None 909 | 910 | result['eval_loss'] = eval_loss 911 | result['global_step'] = global_step 912 | result['loss'] = loss 913 | 914 | output_eval_file = os.path.join(args.output_dir, "eval_results.txt") 915 | with open(output_eval_file, "a") as writer: 916 | logger.info("***** Eval results *****") 917 | for key in sorted(result.keys()): 918 | logger.info(" %s = %s", key, str(result[key])) 919 | writer.write("%s = %s\n" % (key, str(result[key]))) 920 | 921 | if __name__ == "__main__": 922 | main() 923 | -------------------------------------------------------------------------------- /run_IE_MHAtt_CNN.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. 3 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | """BERT finetuning runner.""" 17 | 18 | from __future__ import absolute_import, division, print_function 19 | 20 | import argparse 21 | import csv 22 | import logging 23 | import os 24 | import random 25 | import sys 26 | import pickle 27 | import numpy as np 28 | import torch 29 | from torch.utils.data import (DataLoader, RandomSampler, SequentialSampler, 30 | TensorDataset) 31 | from torch.utils.data.distributed import DistributedSampler 32 | from tqdm import tqdm, trange 33 | 34 | from torch.nn import CrossEntropyLoss, MSELoss 35 | from scipy.stats import pearsonr, spearmanr 36 | from sklearn.metrics import matthews_corrcoef, f1_score 37 | 38 | from pytorch_pretrained_bert.file_utils import PYTORCH_PRETRAINED_BERT_CACHE, WEIGHTS_NAME, CONFIG_NAME 39 | from pytorch_pretrained_bert.modeling import BertConfig, IE_MHAtt_CNN 40 | from pytorch_pretrained_bert.tokenization import BertTokenizer 41 | from pytorch_pretrained_bert.optimization import BertAdam, WarmupLinearSchedule 42 | import re 43 | 44 | logger = logging.getLogger(__name__) 45 | 46 | 47 | class InputExample(object): 48 | """A single training/test example for simple sequence classification.""" 49 | 50 | def __init__(self, guid, text_a, text_b=None, label=None): 51 | """Constructs a InputExample. 52 | 53 | Args: 54 | guid: Unique id for the example. 55 | text_a: string. The untokenized text of the first sequence. For single 56 | sequence tasks, only this sequence must be specified. 57 | text_b: (Optional) string. The untokenized text of the second sequence. 58 | Only must be specified for sequence pair tasks. 59 | label: (Optional) string. The label of the example. This should be 60 | specified for train and dev examples, but not for test examples. 
61 | """ 62 | self.guid = guid 63 | self.text_a = text_a 64 | self.text_b = text_b 65 | self.label = label 66 | 67 | 68 | class InputFeatures(object): 69 | """A single set of features of data.""" 70 | 71 | def __init__(self, input_ids, input_mask, segment_ids, response_len, sep_pos,context_len, label_id): 72 | self.input_ids = input_ids 73 | self.input_mask = input_mask 74 | self.segment_ids = segment_ids 75 | self.response_len = response_len 76 | self.sep_pos = sep_pos 77 | self.context_len = context_len 78 | self.label_id = label_id 79 | 80 | 81 | class DataProcessor(object): 82 | """Base class for data converters for sequence classification data sets.""" 83 | 84 | def get_train_examples(self, data_dir): 85 | """Gets a collection of `InputExample`s for the train set.""" 86 | raise NotImplementedError() 87 | 88 | def get_dev_examples(self, data_dir): 89 | """Gets a collection of `InputExample`s for the dev set.""" 90 | raise NotImplementedError() 91 | 92 | def get_labels(self): 93 | """Gets the list of labels for this data set.""" 94 | raise NotImplementedError() 95 | 96 | # @classmethod 97 | # def _read_tsv(cls, input_file, quotechar=None): 98 | # """Reads a tab separated value file.""" 99 | # with open(input_file, "r", encoding="utf-8") as f: 100 | # reader = csv.reader(f, delimiter="\t", quotechar=quotechar) 101 | # lines = [] 102 | # for line in reader: 103 | # if sys.version_info[0] == 2: 104 | # line = list(unicode(cell, 'utf-8') for cell in line) 105 | # lines.append(line) 106 | # return lines 107 | 108 | @classmethod 109 | def _read_data(cls, input_file): 110 | """Reads a tab separated value file.""" 111 | with open(input_file, "r", encoding="utf-8") as f: 112 | lines = [] 113 | for i,line in enumerate(f): 114 | # if i >100: 115 | # break 116 | line = re.compile('[\\x00-\\x08\\x0b-\\x0c\\x0e-\\x1f\\x7f]').sub(' ', line).strip() 117 | line = line.strip().replace("_", "") 118 | parts = line.strip().split("\t") 119 | lable = parts[0] 120 | message = "" 121 | for i in range(1, len(parts) - 1, 1): 122 | part = parts[i].strip() 123 | if len(part) > 0: 124 | if i != len(parts) - 2: 125 | message += part 126 | message += "[SEP]" 127 | else: 128 | message += part 129 | response = parts[-1] 130 | data = {"y": lable, "m": message, "r": response} 131 | lines.append(data) 132 | return lines 133 | 134 | def _read_douban_data(cls, input_file): 135 | """Reads a tab separated value file.""" 136 | with open(input_file, "r", encoding="utf-8") as f: 137 | lines = [] 138 | label_list = [] 139 | message_list = [] 140 | response_list = [] 141 | label_any_1 = 0 142 | for ids,line in enumerate(f): 143 | # if ids >= 100: 144 | # break 145 | line = re.compile('[\\x00-\\x08\\x0b-\\x0c\\x0e-\\x1f\\x7f]').sub(' ', line).strip() 146 | line = line.strip().replace("_", "") 147 | parts = line.strip().split("\t") 148 | lable = parts[0] 149 | message = "" 150 | for i in range(1, len(parts) - 1, 1): 151 | part = parts[i].strip() 152 | if len(part) > 0: 153 | message += part 154 | message += " [SEP] " 155 | response = parts[-1] 156 | if lable == '1': 157 | label_any_1 = 1 158 | label_list.append(lable) 159 | message_list.append(message) 160 | response_list.append(response) 161 | if ids % 10 == 9: 162 | if label_any_1 == 1: 163 | for lable,message,response in zip(label_list,message_list,response_list): 164 | data = {"y": lable, "m": message, "r": response} 165 | lines.append(data) 166 | label_any_1 = 0 167 | label_list = [] 168 | message_list = [] 169 | response_list = [] 170 | return lines 171 | 172 | class 
UbuntuProcessor(DataProcessor): 173 | """Processor for the MRPC data set (GLUE version).""" 174 | 175 | def get_train_examples(self, data_dir): 176 | """See base class.""" 177 | logger.info("LOOKING AT {}".format(os.path.join(data_dir, "train.txt"))) 178 | return self._create_examples( 179 | self._read_data(os.path.join(data_dir, "train.txt")), "train") 180 | 181 | def get_dev_examples(self, data_dir): 182 | """See base class.""" 183 | return self._create_examples( 184 | self._read_data(os.path.join(data_dir, "valid.txt")), "dev") ## dev.txt for ECD 185 | 186 | def get_test_examples(self, data_dir): 187 | """See base class.""" 188 | return self._create_examples( 189 | self._read_data(os.path.join(data_dir, "test.txt")), "test") 190 | 191 | def get_labels(self): 192 | """See base class.""" 193 | return ["0", "1"] 194 | 195 | def _create_examples(self, lines, set_type): 196 | """Creates examples for the training and dev sets.""" 197 | examples = [] 198 | for (i, line) in enumerate(lines): 199 | guid = "%s-%s" % (set_type, i) 200 | text_a = line["r"] 201 | text_b = line["m"].strip().split("[SEP]") 202 | label = line["y"] 203 | examples.append( 204 | InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label)) 205 | return examples 206 | 207 | class DoubanProcessor(DataProcessor): 208 | """Processor for the MRPC data set (GLUE version).""" 209 | 210 | def get_train_examples(self, data_dir): 211 | """See base class.""" 212 | logger.info("LOOKING AT {}".format(os.path.join(data_dir, "train.txt"))) 213 | return self._create_examples( 214 | self._read_douban_data(os.path.join(data_dir, "train.txt")), "train") 215 | 216 | def get_dev_examples(self, data_dir): 217 | """See base class.""" 218 | return self._create_examples( 219 | self._read_douban_data(os.path.join(data_dir, "dev.txt")), "dev") 220 | 221 | def get_test_examples(self, data_dir): 222 | """See base class.""" 223 | return self._create_examples( 224 | self._read_douban_data(os.path.join(data_dir, "test.txt")), "test") 225 | 226 | def get_labels(self): 227 | """See base class.""" 228 | return ["0", "1"] 229 | 230 | def _create_examples(self, lines, set_type): 231 | """Creates examples for the training and dev sets.""" 232 | examples = [] 233 | #for (i, line) in enumerate(lines): 234 | for (i, line) in enumerate(lines): 235 | guid = "%s-%s" % (set_type, i) 236 | text_a = line["r"] 237 | text_b = line["m"] 238 | label = line["y"] 239 | examples.append( 240 | InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label)) 241 | return examples 242 | 243 | def convert_examples_to_features(examples, label_list, max_seq_length, max_utterance_num, 244 | tokenizer): 245 | """Loads a data file into a list of `InputBatch`s.""" 246 | 247 | label_map = {label : i for i, label in enumerate(label_list)} 248 | 249 | features = [] 250 | for (ex_index, example) in enumerate(examples): 251 | if ex_index % 10000 == 0: 252 | logger.info("Writing example %d of %d" % (ex_index, len(examples))) 253 | 254 | tokens_a = tokenizer.tokenize(example.text_a) 255 | 256 | tokens_b = [] 257 | 258 | for idx, text in enumerate(example.text_b): 259 | if len(text.strip())>0: 260 | tokens_b.extend(tokenizer.tokenize(text) + ["[SEP]"]) 261 | #print(len(tokens_b)) 262 | num_sep = len(example.text_b) - 1 263 | #print(num_sep) 264 | _truncate_seq_pair(tokens_a, tokens_b, max_seq_length - 2) 265 | 266 | tokens = ["[CLS]"] + tokens_a + ["[SEP]"] 267 | segment_ids = [0] * len(tokens) 268 | response_len = len(tokens) 269 | 270 | context_len = [] 271 | 
context_len.append(response_len) 272 | sep_pos = [] 273 | tokens_b_raw = " ".join(tokens_b) 274 | tokens_b = [] 275 | current_pos = response_len - 1 276 | sep_pos.append(current_pos + 1) 277 | for toks in tokens_b_raw.split("[SEP]")[-max_utterance_num - 1:-1]: 278 | context_len.append(len(toks.split()) + 1) 279 | tokens_b.extend(toks.split()) 280 | tokens_b.extend(["[SEP]"]) 281 | current_pos += context_len[-1] 282 | sep_pos.append(current_pos + 1) 283 | tokens += tokens_b 284 | 285 | segment_ids += [1] * (len(tokens_b)) 286 | 287 | input_ids = tokenizer.convert_tokens_to_ids(tokens) 288 | 289 | # The mask has 1 for real tokens and 0 for padding tokens. Only real 290 | # tokens are attended to. 291 | input_mask = [1] * len(input_ids) 292 | 293 | # Zero-pad up to the sequence length. 294 | padding = [0] * (max_seq_length - len(input_ids)) 295 | input_ids += padding 296 | input_mask += padding 297 | segment_ids += padding 298 | 299 | context_len += [0] * (max_utterance_num + 1 - len(context_len)) 300 | sep_pos += [0] * (max_utterance_num + 1 - len(sep_pos)) 301 | # if len(input_ids) != max_seq_length: 302 | # print(len(input_ids)) 303 | # print(max_seq_length) 304 | assert len(sep_pos) == max_utterance_num + 1 305 | assert len(input_ids) == max_seq_length 306 | assert len(input_mask) == max_seq_length 307 | assert len(segment_ids) == max_seq_length 308 | assert len(context_len) == max_utterance_num + 1 309 | 310 | label_id = label_map[example.label] 311 | 312 | if ex_index < 5: 313 | logger.info("*** Example ***") 314 | logger.info("guid: %s" % (example.guid)) 315 | logger.info("tokens: %s" % " ".join( 316 | [str(x) for x in tokens])) 317 | logger.info("input_ids: %s" % " ".join([str(x) for x in input_ids])) 318 | logger.info("input_mask: %s" % " ".join([str(x) for x in input_mask])) 319 | logger.info( 320 | "segment_ids: %s" % " ".join([str(x) for x in segment_ids])) 321 | logger.info("label: %s (id = %d)" % (example.label, label_id)) 322 | 323 | features.append( 324 | InputFeatures(input_ids=input_ids, 325 | input_mask=input_mask, 326 | segment_ids=segment_ids, 327 | response_len=response_len, 328 | sep_pos=sep_pos, 329 | context_len = context_len, 330 | label_id=label_id)) 331 | return features 332 | 333 | 334 | def _truncate_seq_pair(tokens_a, tokens_b, max_length): 335 | """Truncates a sequence pair in place to the maximum length.""" 336 | 337 | # This is a simple heuristic which will always truncate the longer sequence 338 | # one token at a time. This makes more sense than truncating an equal percent 339 | # of tokens from each, since if one sequence is very short then each token 340 | # that's truncated likely contains more information than a longer sequence. 
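    # A worked example of the heuristic above, using illustrative lengths:
    # with max_length = 8, a 3-token response (tokens_a) and a 7-token flattened
    # context (tokens_b, [SEP] markers included) give total_length = 10, so two
    # tokens must be dropped. tokens_b is the longer sequence both times, and in
    # this variant tokens_b.pop(0) trims it from the FRONT, i.e. the oldest
    # context tokens go first and the turns closest to the response are kept;
    # an over-long response would instead be trimmed from its end via tokens_a.pop().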
341 | while True: 342 | total_length = len(tokens_a) + len(tokens_b) 343 | if total_length <= max_length: 344 | break 345 | if len(tokens_a) > len(tokens_b): 346 | tokens_a.pop() 347 | else: 348 | tokens_b.pop(0) 349 | 350 | def get_p_at_n_in_m(pred, n, m, ind): 351 | pos_score = pred[ind]; 352 | curr = pred[ind:ind + m]; 353 | curr = sorted(curr, reverse=True) 354 | 355 | if curr[n - 1] <= pos_score: 356 | return 1; 357 | return 0; 358 | 359 | 360 | def evaluate(pred, label): 361 | # assert len(data) % 10 == 0 362 | 363 | p_at_1_in_2 = 0.0 364 | p_at_1_in_10 = 0.0 365 | p_at_2_in_10 = 0.0 366 | p_at_5_in_10 = 0.0 367 | 368 | length = int(len(pred) / 10) 369 | 370 | for i in range(0, length): 371 | ind = i * 10 372 | 373 | # if label[ind] != 1: 374 | # print(i,ind) 375 | # print(label) 376 | # print(label[ind]) 377 | assert label[ind] == 1 378 | 379 | p_at_1_in_2 += get_p_at_n_in_m(pred, 1, 2, ind) 380 | p_at_1_in_10 += get_p_at_n_in_m(pred, 1, 10, ind) 381 | p_at_2_in_10 += get_p_at_n_in_m(pred, 2, 10, ind) 382 | p_at_5_in_10 += get_p_at_n_in_m(pred, 5, 10, ind) 383 | 384 | return (p_at_1_in_2 / length, p_at_1_in_10 / length, p_at_2_in_10 / length, p_at_5_in_10 / length) 385 | 386 | def simple_accuracy(preds, labels): 387 | return (preds == labels).mean() 388 | 389 | def ComputeR10(scores,labels,count = 10): 390 | total = 0 391 | correct1 = 0 392 | correct5 = 0 393 | correct2 = 0 394 | correct10 = 0 395 | #删除全0的例子 test 396 | for i in range(len(labels)): 397 | if labels[i] == 1: 398 | #print(i) 399 | total = total+1 400 | sublist = scores[i:i+count] 401 | #print(np.argmax(sublist)) 402 | if np.argmax(sublist) < 1: 403 | correct1 = correct1 + 1 404 | if np.argmax(sublist) < 2: 405 | correct2 = correct2 + 1 406 | if np.argmax(sublist) < 5: 407 | correct5 = correct5 + 1 408 | if np.argmax(sublist) < 10: 409 | correct10 = correct10 + 1 410 | # if max(sublist) == scores[i]: 411 | # correct = correct + 1 412 | print(correct1, correct5, correct10, total) 413 | return (float(correct1)/ total, float(correct2)/ total, float(correct5)/ total, float(correct10)/ total) 414 | 415 | def ComputeR2_1(scores,labels,count = 2): 416 | total = 0 417 | correct = 0 418 | for i in range(len(labels)): 419 | if labels[i] == 1: 420 | total = total+1 421 | sublist = scores[i:i+count] 422 | if max(sublist) == scores[i]: 423 | correct = correct + 1 424 | return (float(correct)/ total) 425 | 426 | def mean_average_precision(sort_data): 427 | # to do 428 | count_1 = 0 429 | sum_precision = 0 430 | for index in range(len(sort_data)): 431 | if sort_data[index][1] == 1: 432 | count_1 += 1 433 | sum_precision += 1.0 * count_1 / (index + 1) 434 | return sum_precision / count_1 435 | 436 | 437 | def mean_reciprocal_rank(sort_data): 438 | sort_lable = [s_d[1] for s_d in sort_data] 439 | assert 1 in sort_lable 440 | return 1.0 / (1 + sort_lable.index(1)) 441 | 442 | 443 | def precision_at_position_1(sort_data): 444 | if sort_data[0][1] == 1: 445 | return 1 446 | else: 447 | return 0 448 | 449 | 450 | def recall_at_position_k_in_10(sort_data, k): 451 | sort_lable = [s_d[1] for s_d in sort_data] 452 | select_lable = sort_lable[:k] 453 | return 1.0 * select_lable.count(1) / sort_lable.count(1) 454 | 455 | def evaluation_one_session(data): 456 | sort_data = sorted(data, key=lambda x: x[0], reverse=True) 457 | m_a_p = mean_average_precision(sort_data) 458 | m_r_r = mean_reciprocal_rank(sort_data) 459 | p_1 = precision_at_position_1(sort_data) 460 | r_1 = recall_at_position_k_in_10(sort_data, 1) 461 | r_2 = 
recall_at_position_k_in_10(sort_data, 2) 462 | r_5 = recall_at_position_k_in_10(sort_data, 5) 463 | return m_a_p, m_r_r, p_1, r_1, r_2, r_5 464 | 465 | def evaluate_douban(pred, label): 466 | sum_m_a_p = 0 467 | sum_m_r_r = 0 468 | sum_p_1 = 0 469 | sum_r_1 = 0 470 | sum_r_2 = 0 471 | sum_r_5 = 0 472 | 473 | total_num = 0 474 | data = [] 475 | #print(label) 476 | for i in range(0, len(label)): 477 | if i % 10 == 0: 478 | data = [] 479 | data.append((float(pred[i]), int(label[i]))) 480 | if i % 10 == 9: 481 | total_num += 1 482 | m_a_p, m_r_r, p_1, r_1, r_2, r_5 = evaluation_one_session(data) 483 | sum_m_a_p += m_a_p 484 | sum_m_r_r += m_r_r 485 | sum_p_1 += p_1 486 | sum_r_1 += r_1 487 | sum_r_2 += r_2 488 | sum_r_5 += r_5 489 | # print('total num: %s' %total_num) 490 | # print('MAP: %s' %(1.0*sum_m_a_p/total_num)) 491 | # print('MRR: %s' %(1.0*sum_m_r_r/total_num)) 492 | # print('P@1: %s' %(1.0*sum_p_1/total_num)) 493 | return (1.0 * sum_m_a_p / total_num, 1.0 * sum_m_r_r / total_num, 1.0 * sum_p_1 / total_num, 494 | 1.0 * sum_r_1 / total_num, 1.0 * sum_r_2 / total_num, 1.0 * sum_r_5 / total_num) 495 | 496 | def compute_metrics(task_name, preds, labels): 497 | assert len(preds) == len(labels) 498 | preds_logits = preds[:, 1] # 预测为1的概率 499 | 500 | if task_name == "ubuntu" or task_name == "ecd": 501 | return {"recall@2 recall@10(1,2,5)": evaluate(preds_logits, labels)} 502 | elif task_name == "douban": 503 | return {"MAP MRR P@1 recall@10(1,2,5)": evaluate_douban(preds_logits, labels)} 504 | else: 505 | raise KeyError(task_name) 506 | 507 | 508 | def main(): 509 | parser = argparse.ArgumentParser() 510 | 511 | ## Required parameters 512 | parser.add_argument("--data_dir", 513 | default="E:/lab/Bert-Multi-turn-Response-Selection/data/ubuntu_data", 514 | type=str, 515 | help="The input data dir. Should contain the .tsv files (or other data files) for the task.") 516 | parser.add_argument("--bert_model", default="bert-base-uncased", type=str, 517 | help="Bert pre-trained model selected in the list: bert-base-uncased, " 518 | "bert-large-uncased, bert-base-cased, bert-large-cased, bert-base-multilingual-uncased, " 519 | "bert-base-multilingual-cased, bert-base-chinese.") 520 | parser.add_argument("--task_name", 521 | default="ubuntu", 522 | type=str, 523 | help="The name of the task to train.") 524 | parser.add_argument("--output_dir", 525 | default="output_ubuntu22", 526 | type=str, 527 | help="The output directory where the model predictions and checkpoints will be written.") 528 | parser.add_argument("--max_utterance_num", 529 | default=10, 530 | type=int, 531 | help="The maximum total utterance number.") 532 | parser.add_argument("--cache_flag", 533 | default="sequence_match", 534 | type=str, 535 | help="The data features cache will be written.") 536 | 537 | ## Other parameters 538 | parser.add_argument("--cache_dir", 539 | default="", 540 | type=str, 541 | help="Where do you want to store the pre-trained models downloaded from s3") 542 | parser.add_argument("--max_seq_length", 543 | default=100, 544 | type=int, 545 | help="The maximum total input sequence length after WordPiece tokenization. 
\n" 546 | "Sequences longer than this will be truncated, and sequences shorter \n" 547 | "than this will be padded.") 548 | parser.add_argument("--do_train", 549 | action='store_true', 550 | help="Whether to run training.") 551 | parser.add_argument("--do_eval", 552 | action='store_true', 553 | help="Whether to run eval on the dev set.") 554 | parser.add_argument("--do_lower_case", 555 | action='store_true', 556 | help="Set this flag if you are using an uncased model.") 557 | parser.add_argument("--train_batch_size", 558 | default=4, 559 | type=int, 560 | help="Total batch size for training.") 561 | parser.add_argument("--eval_batch_size", 562 | default=4, 563 | type=int, 564 | help="Total batch size for eval.") 565 | parser.add_argument("--learning_rate", 566 | default=5e-5, 567 | type=float, 568 | help="The initial learning rate for Adam.") 569 | parser.add_argument("--num_train_epochs", 570 | default=3.0, 571 | type=float, 572 | help="Total number of training epochs to perform.") 573 | parser.add_argument("--warmup_proportion", 574 | default=0.1, 575 | type=float, 576 | help="Proportion of training to perform linear learning rate warmup for. " 577 | "E.g., 0.1 = 10%% of training.") 578 | parser.add_argument("--no_cuda", 579 | action='store_true', 580 | help="Whether not to use CUDA when available") 581 | parser.add_argument("--local_rank", 582 | type=int, 583 | default=-1, 584 | help="local_rank for distributed training on gpus") 585 | parser.add_argument('--seed', 586 | type=int, 587 | default=42, 588 | help="random seed for initialization") 589 | parser.add_argument('--gradient_accumulation_steps', 590 | type=int, 591 | default=1, 592 | help="Number of updates steps to accumulate before performing a backward/update pass.") 593 | parser.add_argument('--fp16', 594 | action='store_true', 595 | help="Whether to use 16-bit float precision instead of 32-bit") 596 | parser.add_argument('--loss_scale', 597 | type=float, default=0, 598 | help="Loss scaling to improve fp16 numeric stability. 
Only used when fp16 set to True.\n" 599 | "0 (default value): dynamic loss scaling.\n" 600 | "Positive power of 2: static loss scaling value.\n") 601 | parser.add_argument('--server_ip', type=str, default='', help="Can be used for distant debugging.") 602 | parser.add_argument('--server_port', type=str, default='', help="Can be used for distant debugging.") 603 | args = parser.parse_args() 604 | 605 | if args.server_ip and args.server_port: 606 | # Distant debugging - see https://code.visualstudio.com/docs/python/debugging#_attach-to-a-local-script 607 | import ptvsd 608 | print("Waiting for debugger attach") 609 | ptvsd.enable_attach(address=(args.server_ip, args.server_port), redirect_output=True) 610 | ptvsd.wait_for_attach() 611 | 612 | processors = { 613 | "ubuntu": UbuntuProcessor, 614 | "douban": DoubanProcessor, 615 | "ecd": UbuntuProcessor, 616 | } 617 | 618 | if args.local_rank == -1 or args.no_cuda: 619 | device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu") 620 | n_gpu = torch.cuda.device_count() 621 | else: 622 | torch.cuda.set_device(args.local_rank) 623 | device = torch.device("cuda", args.local_rank) 624 | n_gpu = 1 625 | # Initializes the distributed backend which will take care of sychronizing nodes/GPUs 626 | torch.distributed.init_process_group(backend='nccl') 627 | 628 | logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s', 629 | datefmt = '%m/%d/%Y %H:%M:%S', 630 | level = logging.INFO if args.local_rank in [-1, 0] else logging.WARN) 631 | 632 | logger.info("device: {} n_gpu: {}, distributed training: {}, 16-bits training: {}".format( 633 | device, n_gpu, bool(args.local_rank != -1), args.fp16)) 634 | 635 | if args.gradient_accumulation_steps < 1: 636 | raise ValueError("Invalid gradient_accumulation_steps parameter: {}, should be >= 1".format( 637 | args.gradient_accumulation_steps)) 638 | 639 | args.train_batch_size = args.train_batch_size // args.gradient_accumulation_steps 640 | 641 | random.seed(args.seed) 642 | np.random.seed(args.seed) 643 | torch.manual_seed(args.seed) 644 | if n_gpu > 0: 645 | torch.cuda.manual_seed_all(args.seed) 646 | 647 | if not args.do_train and not args.do_eval: 648 | raise ValueError("At least one of `do_train` or `do_eval` must be True.") 649 | 650 | if os.path.exists(args.output_dir) and os.listdir(args.output_dir) and args.do_train: 651 | raise ValueError("Output directory ({}) already exists and is not empty.".format(args.output_dir)) 652 | if not os.path.exists(args.output_dir): 653 | os.makedirs(args.output_dir) 654 | 655 | task_name = args.task_name.lower() 656 | 657 | if task_name not in processors: 658 | raise ValueError("Task not found: %s" % (task_name)) 659 | 660 | processor = processors[task_name]() 661 | 662 | label_list = processor.get_labels() 663 | num_labels = len(label_list) 664 | 665 | tokenizer = BertTokenizer.from_pretrained(args.bert_model, do_lower_case=args.do_lower_case) 666 | 667 | train_examples = None 668 | num_train_optimization_steps = None 669 | if args.do_train: 670 | train_examples = processor.get_train_examples(args.data_dir) 671 | num_train_optimization_steps = int( 672 | len(train_examples) / args.train_batch_size / args.gradient_accumulation_steps) * args.num_train_epochs 673 | if args.local_rank != -1: 674 | num_train_optimization_steps = num_train_optimization_steps // torch.distributed.get_world_size() 675 | 676 | # Prepare model 677 | cache_dir = args.cache_dir if args.cache_dir else 
os.path.join(str(PYTORCH_PRETRAINED_BERT_CACHE), 678 | 'distributed_{}'.format(args.local_rank)) 679 | model = IE_MHAtt_CNN.from_pretrained(args.bert_model, 680 | cache_dir=cache_dir, 681 | num_labels=num_labels, 682 | max_turn_num = args.max_utterance_num 683 | ) 684 | 685 | if args.fp16: 686 | model.half() 687 | model.to(device) 688 | if args.local_rank != -1: 689 | try: 690 | from apex.parallel import DistributedDataParallel as DDP 691 | except ImportError: 692 | raise ImportError( 693 | "Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training.") 694 | 695 | model = DDP(model) 696 | elif n_gpu > 1: 697 | model = torch.nn.DataParallel(model) 698 | 699 | # Prepare optimizer 700 | if args.do_train: 701 | param_optimizer = list(model.named_parameters()) 702 | no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight'] 703 | optimizer_grouped_parameters = [ 704 | {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01}, 705 | {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0} 706 | ] 707 | if args.fp16: 708 | try: 709 | from apex.optimizers import FP16_Optimizer 710 | from apex.optimizers import FusedAdam 711 | except ImportError: 712 | raise ImportError( 713 | "Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training.") 714 | 715 | optimizer = FusedAdam(optimizer_grouped_parameters, 716 | lr=args.learning_rate, 717 | bias_correction=False, 718 | max_grad_norm=1.0) 719 | if args.loss_scale == 0: 720 | optimizer = FP16_Optimizer(optimizer, dynamic_loss_scale=True) 721 | else: 722 | optimizer = FP16_Optimizer(optimizer, static_loss_scale=args.loss_scale) 723 | warmup_linear = WarmupLinearSchedule(warmup=args.warmup_proportion, 724 | t_total=num_train_optimization_steps) 725 | 726 | else: 727 | optimizer = BertAdam(optimizer_grouped_parameters, 728 | lr=args.learning_rate, 729 | warmup=args.warmup_proportion, 730 | t_total=num_train_optimization_steps) 731 | 732 | global_step = 0 733 | nb_tr_steps = 0 734 | tr_loss = 0 735 | if args.do_train: 736 | cached_train_features_file = args.data_dir + '_{0}_{1}_{2}_{3}_{4}_{5}'.format( 737 | list(filter(None, args.bert_model.split('/'))).pop(), "train",str(args.task_name), str(args.max_seq_length), 738 | str(args.max_utterance_num), str(args.cache_flag)) 739 | train_features = None 740 | try: 741 | with open(cached_train_features_file, "rb") as reader: 742 | train_features = pickle.load(reader) 743 | except: 744 | train_features = convert_examples_to_features( 745 | train_examples, label_list, args.max_seq_length, args.max_utterance_num, tokenizer) 746 | if args.local_rank == -1 or torch.distributed.get_rank() == 0: 747 | logger.info(" Saving train features into cached file %s", cached_train_features_file) 748 | with open(cached_train_features_file, "wb") as writer: 749 | pickle.dump(train_features, writer) 750 | 751 | logger.info("***** Running training *****") 752 | logger.info(" Num examples = %d", len(train_examples)) 753 | logger.info(" Batch size = %d", args.train_batch_size) 754 | logger.info(" Num steps = %d", num_train_optimization_steps) 755 | all_input_ids = torch.tensor([f.input_ids for f in train_features], dtype=torch.long) 756 | all_input_mask = torch.tensor([f.input_mask for f in train_features], dtype=torch.long) 757 | all_segment_ids = torch.tensor([f.segment_ids for f in train_features], dtype=torch.long) 758 | all_response_len = torch.tensor([f.response_len for 
f in train_features], dtype=torch.long) 759 | all_sep_pos = torch.tensor([f.sep_pos for f in train_features], dtype=torch.long) 760 | all_context_len = torch.tensor([f.context_len for f in train_features], dtype=torch.long) 761 | all_label_ids = torch.tensor([f.label_id for f in train_features], dtype=torch.long) 762 | 763 | train_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_response_len, all_sep_pos,all_context_len, all_label_ids) 764 | if args.local_rank == -1: 765 | train_sampler = RandomSampler(train_data) 766 | else: 767 | train_sampler = DistributedSampler(train_data) 768 | train_dataloader = DataLoader(train_data, sampler=train_sampler, batch_size=args.train_batch_size) 769 | 770 | 771 | eval_examples = processor.get_dev_examples(args.data_dir) 772 | cached_train_features_file = args.data_dir + '_{0}_{1}_{2}_{3}_{4}_{5}'.format( 773 | list(filter(None, args.bert_model.split('/'))).pop(), "valid",str(args.task_name), str(args.max_seq_length), 774 | str(args.max_utterance_num), str(args.cache_flag)) 775 | eval_features = None 776 | try: 777 | with open(cached_train_features_file, "rb") as reader: 778 | eval_features = pickle.load(reader) 779 | except: 780 | eval_features = convert_examples_to_features( 781 | eval_examples, label_list, args.max_seq_length, args.max_utterance_num, tokenizer) 782 | if args.local_rank == -1 or torch.distributed.get_rank() == 0: 783 | logger.info(" Saving eval features into cached file %s", cached_train_features_file) 784 | with open(cached_train_features_file, "wb") as writer: 785 | pickle.dump(eval_features, writer) 786 | 787 | logger.info("***** Running evaluation *****") 788 | logger.info(" Num examples = %d", len(eval_examples)) 789 | logger.info(" Batch size = %d", args.eval_batch_size) 790 | all_input_ids = torch.tensor([f.input_ids for f in eval_features], dtype=torch.long) 791 | all_input_mask = torch.tensor([f.input_mask for f in eval_features], dtype=torch.long) 792 | all_segment_ids = torch.tensor([f.segment_ids for f in eval_features], dtype=torch.long) 793 | all_response_len = torch.tensor([f.response_len for f in eval_features], dtype=torch.long) 794 | all_sep_pos = torch.tensor([f.sep_pos for f in eval_features], dtype=torch.long) 795 | all_context_len = torch.tensor([f.context_len for f in eval_features], dtype=torch.long) 796 | all_label_ids = torch.tensor([f.label_id for f in eval_features], dtype=torch.long) 797 | 798 | eval_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_response_len, all_sep_pos, all_context_len, all_label_ids) 799 | # Run prediction for full data 800 | eval_sampler = SequentialSampler(eval_data) 801 | eval_dataloader = DataLoader(eval_data, sampler=eval_sampler, batch_size=args.eval_batch_size) 802 | 803 | model.train() 804 | for epoch in trange(int(args.num_train_epochs), desc="Epoch"): 805 | 806 | tr_loss = 0 807 | nb_tr_examples, nb_tr_steps = 0, 0 808 | for step, batch in enumerate(tqdm(train_dataloader, desc="Iteration")): 809 | batch = tuple(t.to(device) for t in batch) 810 | input_ids, input_mask, segment_ids, response_len, sep_pos, context_len, label_ids = batch 811 | 812 | # define a new function to compute loss values for both output_modes 813 | logits = model(input_ids, segment_ids, input_mask, response_len, sep_pos, context_len, labels=None) 814 | 815 | loss_fct = CrossEntropyLoss() 816 | loss = loss_fct(logits.view(-1, num_labels), label_ids.view(-1)) 817 | 818 | if n_gpu > 1: 819 | loss = loss.mean() # mean() to average on multi-gpu. 
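                # Gradient accumulation: the loss is scaled by 1/gradient_accumulation_steps
                # below, gradients from that many mini-batches are summed, and optimizer.step()
                # only fires every gradient_accumulation_steps-th batch. Since
                # args.train_batch_size was already divided by this factor above, the
                # effective batch size per update matches the --train_batch_size passed
                # on the command line (up to integer rounding).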
820 | if args.gradient_accumulation_steps > 1: 821 | loss = loss / args.gradient_accumulation_steps 822 | 823 | if args.fp16: 824 | optimizer.backward(loss) 825 | else: 826 | loss.backward() 827 | 828 | tr_loss += loss.item() 829 | nb_tr_examples += input_ids.size(0) 830 | nb_tr_steps += 1 831 | if (step + 1) % args.gradient_accumulation_steps == 0: 832 | if args.fp16: 833 | # modify learning rate with special warm up BERT uses 834 | # if args.fp16 is False, BertAdam is used that handles this automatically 835 | lr_this_step = args.learning_rate * warmup_linear.get_lr( 836 | global_step / num_train_optimization_steps, 837 | args.warmup_proportion) 838 | for param_group in optimizer.param_groups: 839 | param_group['lr'] = lr_this_step 840 | optimizer.step() 841 | optimizer.zero_grad() 842 | global_step += 1 843 | 844 | 845 | # Save a trained model, configuration and tokenizer 846 | model_to_save = model.module if hasattr(model, 'module') else model # Only save the model it-self 847 | 848 | # If we save using the predefined names, we can load using `from_pretrained` 849 | output_model_file = os.path.join(args.output_dir, str(epoch) + "_" + WEIGHTS_NAME) 850 | output_config_file = os.path.join(args.output_dir, CONFIG_NAME) 851 | 852 | torch.save(model_to_save.state_dict(), output_model_file) 853 | model_to_save.config.to_json_file(output_config_file) 854 | tokenizer.save_vocabulary(args.output_dir) 855 | 856 | # Load a trained model and vocabulary that you have fine-tuned 857 | model_state_dict = torch.load(output_model_file) 858 | eval_model = IE_MHAtt_CNN.from_pretrained(args.bert_model, state_dict=model_state_dict, num_labels=num_labels,max_turn_num = args.max_utterance_num) 859 | # tokenizer = BertTokenizer.from_pretrained(args.output_dir, do_lower_case=args.do_lower_case) 860 | eval_model.to(device) 861 | eval_model.eval() 862 | eval_loss = 0 863 | nb_eval_steps = 0 864 | preds = [] 865 | 866 | for input_ids, input_mask, segment_ids, response_len, sep_pos, context_len,label_ids in tqdm(eval_dataloader, desc="Evaluating"): 867 | input_ids = input_ids.to(device) 868 | input_mask = input_mask.to(device) 869 | segment_ids = segment_ids.to(device) 870 | response_len = response_len.to(device) 871 | sep_pos = sep_pos.to(device) 872 | context_len = context_len.to(device) 873 | label_ids = label_ids.to(device) 874 | 875 | with torch.no_grad(): 876 | logits = eval_model(input_ids, segment_ids, input_mask, response_len, sep_pos,context_len, labels=None) 877 | 878 | # create eval loss and other metric required by the task 879 | loss_fct = CrossEntropyLoss() 880 | tmp_eval_loss = loss_fct(logits.view(-1, num_labels), label_ids.view(-1)) 881 | 882 | 883 | eval_loss += tmp_eval_loss.mean().item() 884 | nb_eval_steps += 1 885 | if len(preds) == 0: 886 | preds.append(logits.detach().cpu().numpy()) 887 | else: 888 | preds[0] = np.append( 889 | preds[0], logits.detach().cpu().numpy(), axis=0) 890 | 891 | eval_loss = eval_loss / nb_eval_steps 892 | preds = preds[0] 893 | # print(preds) 894 | result = compute_metrics(task_name, preds, all_label_ids.numpy()) 895 | loss = tr_loss / nb_tr_steps if args.do_train else None 896 | 897 | result['eval_loss'] = eval_loss 898 | result['global_step'] = global_step 899 | result['loss'] = loss 900 | 901 | output_eval_file = os.path.join(args.output_dir, "eval_results.txt") 902 | with open(output_eval_file, "a") as writer: 903 | logger.info("***** Eval results *****") 904 | for key in sorted(result.keys()): 905 | logger.info(" %s = %s", key, str(result[key])) 906 | 
writer.write("%s = %s\n" % (key, str(result[key]))) 907 | 908 | else: 909 | output_model_file = os.path.join(args.output_dir, "pytorch_model.bin") 910 | model_state_dict = torch.load(output_model_file) 911 | model = IE_MHAtt_CNN.from_pretrained(args.bert_model, state_dict=model_state_dict, 912 | num_labels=num_labels,max_turn_num = args.max_utterance_num) 913 | model.to(device) 914 | 915 | if args.do_eval and (args.local_rank == -1 or torch.distributed.get_rank() == 0): 916 | eval_examples = processor.get_test_examples(args.data_dir) 917 | eval_features = convert_examples_to_features( 918 | eval_examples, label_list, args.max_seq_length, args.max_utterance_num, tokenizer) 919 | logger.info("***** Running evaluation *****") 920 | logger.info(" Num examples = %d", len(eval_examples)) 921 | logger.info(" Batch size = %d", args.eval_batch_size) 922 | all_input_ids = torch.tensor([f.input_ids for f in eval_features], dtype=torch.long) 923 | all_input_mask = torch.tensor([f.input_mask for f in eval_features], dtype=torch.long) 924 | all_segment_ids = torch.tensor([f.segment_ids for f in eval_features], dtype=torch.long) 925 | all_response_len = torch.tensor([f.response_len for f in eval_features], dtype=torch.long) 926 | all_sep_pos = torch.tensor([f.sep_pos for f in eval_features], dtype=torch.long) 927 | all_context_len = torch.tensor([f.context_len for f in eval_features], dtype=torch.long) 928 | all_label_ids = torch.tensor([f.label_id for f in eval_features], dtype=torch.long) 929 | 930 | eval_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_response_len, all_sep_pos, 931 | all_context_len, all_label_ids) 932 | # Run prediction for full data 933 | eval_sampler = SequentialSampler(eval_data) 934 | eval_dataloader = DataLoader(eval_data, sampler=eval_sampler, batch_size=args.eval_batch_size) 935 | 936 | model.eval() 937 | eval_loss = 0 938 | nb_eval_steps = 0 939 | preds = [] 940 | 941 | for input_ids, input_mask, segment_ids, response_len, sep_pos, context_len, label_ids in tqdm(eval_dataloader, 942 | desc="Evaluating"): 943 | input_ids = input_ids.to(device) 944 | input_mask = input_mask.to(device) 945 | segment_ids = segment_ids.to(device) 946 | response_len = response_len.to(device) 947 | sep_pos = sep_pos.to(device) 948 | context_len = context_len.to(device) 949 | label_ids = label_ids.to(device) 950 | 951 | with torch.no_grad(): 952 | logits = model(input_ids, segment_ids, input_mask, response_len, sep_pos,context_len, labels=None) 953 | 954 | # create eval loss and other metric required by the task 955 | 956 | loss_fct = CrossEntropyLoss() 957 | tmp_eval_loss = loss_fct(logits.view(-1, num_labels), label_ids.view(-1)) 958 | 959 | eval_loss += tmp_eval_loss.mean().item() 960 | nb_eval_steps += 1 961 | if len(preds) == 0: 962 | preds.append(logits.detach().cpu().numpy()) 963 | else: 964 | preds[0] = np.append( 965 | preds[0], logits.detach().cpu().numpy(), axis=0) 966 | 967 | eval_loss = eval_loss / nb_eval_steps 968 | preds = preds[0] 969 | # print(preds) 970 | result = compute_metrics(task_name, preds, all_label_ids.numpy()) 971 | loss = tr_loss / nb_tr_steps if args.do_train else None 972 | 973 | result['eval_loss'] = eval_loss 974 | result['global_step'] = global_step 975 | result['loss'] = loss 976 | 977 | output_eval_file = os.path.join(args.output_dir, "eval_results.txt") 978 | with open(output_eval_file, "a") as writer: 979 | logger.info("***** Eval results *****") 980 | for key in sorted(result.keys()): 981 | logger.info(" %s = %s", key, str(result[key])) 
982 | writer.write("%s = %s\n" % (key, str(result[key]))) 983 | 984 | if __name__ == "__main__": 985 | main() 986 | -------------------------------------------------------------------------------- /run_bert.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. 3 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | """BERT finetuning runner.""" 17 | 18 | from __future__ import absolute_import, division, print_function 19 | 20 | import argparse 21 | import csv 22 | import logging 23 | import os 24 | import random 25 | import sys 26 | import pickle 27 | 28 | import numpy as np 29 | import torch 30 | from torch.utils.data import (DataLoader, RandomSampler, SequentialSampler, 31 | TensorDataset) 32 | from torch.utils.data.distributed import DistributedSampler 33 | from tqdm import tqdm, trange 34 | 35 | from torch.nn import CrossEntropyLoss, MSELoss 36 | from scipy.stats import pearsonr, spearmanr 37 | from sklearn.metrics import matthews_corrcoef, f1_score 38 | 39 | from pytorch_pretrained_bert.file_utils import PYTORCH_PRETRAINED_BERT_CACHE, WEIGHTS_NAME, CONFIG_NAME 40 | from pytorch_pretrained_bert.modeling import BertForSequenceClassification, BertConfig 41 | from pytorch_pretrained_bert.tokenization import BertTokenizer 42 | from pytorch_pretrained_bert.optimization import BertAdam, WarmupLinearSchedule 43 | import re 44 | 45 | logger = logging.getLogger(__name__) 46 | 47 | 48 | class InputExample(object): 49 | """A single training/test example for simple sequence classification.""" 50 | 51 | def __init__(self, guid, text_a, text_b=None, label=None): 52 | """Constructs a InputExample. 53 | 54 | Args: 55 | guid: Unique id for the example. 56 | text_a: string. The untokenized text of the first sequence. For single 57 | sequence tasks, only this sequence must be specified. 58 | text_b: (Optional) string. The untokenized text of the second sequence. 59 | Only must be specified for sequence pair tasks. 60 | label: (Optional) string. The label of the example. This should be 61 | specified for train and dev examples, but not for test examples. 
62 | """ 63 | self.guid = guid 64 | self.text_a = text_a 65 | self.text_b = text_b 66 | self.label = label 67 | 68 | 69 | class InputFeatures(object): 70 | """A single set of features of data.""" 71 | 72 | def __init__(self, input_ids, input_mask, segment_ids, label_id): 73 | self.input_ids = input_ids 74 | self.input_mask = input_mask 75 | self.segment_ids = segment_ids 76 | self.label_id = label_id 77 | 78 | 79 | class DataProcessor(object): 80 | """Base class for data converters for sequence classification data sets.""" 81 | 82 | def get_train_examples(self, data_dir): 83 | """Gets a collection of `InputExample`s for the train set.""" 84 | raise NotImplementedError() 85 | 86 | def get_dev_examples(self, data_dir): 87 | """Gets a collection of `InputExample`s for the dev set.""" 88 | raise NotImplementedError() 89 | 90 | def get_labels(self): 91 | """Gets the list of labels for this data set.""" 92 | raise NotImplementedError() 93 | 94 | @classmethod 95 | def _read_data(cls, input_file): 96 | """Reads a tab separated value file.""" 97 | with open(input_file, "r", encoding="utf-8") as f: 98 | lines = [] 99 | for line in f: 100 | line = re.compile('[\\x00-\\x08\\x0b-\\x0c\\x0e-\\x1f\\x7f]').sub(' ', line).strip() 101 | line = line.strip().replace("_", "") 102 | parts = line.strip().split("\t") 103 | lable = parts[0] 104 | message = "" 105 | for i in range(1, len(parts) - 1, 1): 106 | part = parts[i].strip() 107 | if len(part) > 0: 108 | message += part 109 | message += " [SEP] " # 可以试试其他分隔符,例如[unused100] 110 | response = parts[-1] 111 | data = {"y": lable, "m": message, "r": response} 112 | lines.append(data) 113 | return lines 114 | 115 | def _read_douban_data(cls, input_file): 116 | """Reads a tab separated value file.""" 117 | with open(input_file, "r", encoding="utf-8") as f: 118 | lines = [] 119 | label_list = [] 120 | message_list = [] 121 | response_list = [] 122 | label_any_1 = 0 123 | for ids,line in enumerate(f): 124 | # if ids >= 100: 125 | # break 126 | line = re.compile('[\\x00-\\x08\\x0b-\\x0c\\x0e-\\x1f\\x7f]').sub(' ', line).strip() 127 | line = line.strip().replace("_", "") 128 | parts = line.strip().split("\t") 129 | lable = parts[0] 130 | message = "" 131 | for i in range(1, len(parts) - 1, 1): 132 | part = parts[i].strip() 133 | if len(part) > 0: 134 | message += part 135 | message += " [SEP] " 136 | response = parts[-1] 137 | if lable == '1': 138 | label_any_1 = 1 139 | label_list.append(lable) 140 | message_list.append(message) 141 | response_list.append(response) 142 | if ids % 10 == 9: 143 | if label_any_1 == 1: 144 | for lable,message,response in zip(label_list,message_list,response_list): 145 | data = {"y": lable, "m": message, "r": response} 146 | lines.append(data) 147 | label_any_1 = 0 148 | label_list = [] 149 | message_list = [] 150 | response_list = [] 151 | return lines 152 | 153 | 154 | class UbuntuProcessor(DataProcessor): 155 | """Processor for the MRPC data set (GLUE version).""" 156 | 157 | def get_train_examples(self, data_dir): 158 | """See base class.""" 159 | logger.info("LOOKING AT {}".format(os.path.join(data_dir, "train.txt"))) 160 | return self._create_examples( 161 | self._read_data(os.path.join(data_dir, "train.txt")), "train") 162 | 163 | def get_dev_examples(self, data_dir): 164 | """See base class.""" 165 | return self._create_examples( 166 | self._read_data(os.path.join(data_dir, "valid.txt")), "dev") ## dev.txt for ECD 167 | 168 | def get_test_examples(self, data_dir): 169 | """See base class.""" 170 | return self._create_examples( 171 | 
self._read_data(os.path.join(data_dir, "test.txt")), "test") 172 | 173 | def get_labels(self): 174 | """See base class.""" 175 | return ["0", "1"] 176 | 177 | def _create_examples(self, lines, set_type): 178 | """Creates examples for the training and dev sets.""" 179 | examples = [] 180 | #for (i, line) in enumerate(lines): 181 | for (i, line) in enumerate(lines): 182 | guid = "%s-%s" % (set_type, i) 183 | text_a = line["r"] 184 | text_b = line["m"] 185 | label = line["y"] 186 | examples.append( 187 | InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label)) 188 | return examples 189 | 190 | class DoubanProcessor(DataProcessor): 191 | """Processor for the MRPC data set (GLUE version).""" 192 | 193 | def get_train_examples(self, data_dir): 194 | """See base class.""" 195 | logger.info("LOOKING AT {}".format(os.path.join(data_dir, "train.txt"))) 196 | return self._create_examples( 197 | self._read_douban_data(os.path.join(data_dir, "train.txt")), "train") 198 | 199 | def get_dev_examples(self, data_dir): 200 | """See base class.""" 201 | return self._create_examples( 202 | self._read_douban_data(os.path.join(data_dir, "dev.txt")), "dev") 203 | 204 | def get_test_examples(self, data_dir): 205 | """See base class.""" 206 | return self._create_examples( 207 | self._read_douban_data(os.path.join(data_dir, "test.txt")), "test") 208 | 209 | def get_labels(self): 210 | """See base class.""" 211 | return ["0", "1"] 212 | 213 | def _create_examples(self, lines, set_type): 214 | """Creates examples for the training and dev sets.""" 215 | examples = [] 216 | #for (i, line) in enumerate(lines): 217 | for (i, line) in enumerate(lines): 218 | guid = "%s-%s" % (set_type, i) 219 | text_a = line["r"] 220 | text_b = line["m"] 221 | label = line["y"] 222 | examples.append( 223 | InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label)) 224 | return examples 225 | 226 | 227 | def convert_examples_to_features(examples, label_list, max_seq_length, 228 | tokenizer): 229 | """Loads a data file into a list of `InputBatch`s.""" 230 | 231 | label_map = {label : i for i, label in enumerate(label_list)} 232 | 233 | features = [] 234 | for (ex_index, example) in enumerate(examples): 235 | if ex_index % 10000 == 0: 236 | logger.info("Writing example %d of %d" % (ex_index, len(examples))) 237 | 238 | tokens_a = tokenizer.tokenize(example.text_a) 239 | 240 | tokens_b = None 241 | if example.text_b: 242 | tokens_b = tokenizer.tokenize(example.text_b) 243 | # Modifies `tokens_a` and `tokens_b` in place so that the total 244 | # length is less than the specified length. 245 | # Account for [CLS], [SEP], [SEP] with "- 3" 246 | _truncate_seq_pair(tokens_a, tokens_b, max_seq_length - 3) 247 | else: 248 | # Account for [CLS] and [SEP] with "- 2" 249 | if len(tokens_a) > max_seq_length - 2: 250 | tokens_a = tokens_a[:(max_seq_length - 2)] 251 | 252 | # The convention in BERT is: 253 | # (a) For sequence pairs: 254 | # tokens: [CLS] is this jack ##son ##ville ? [SEP] no it is not . [SEP] 255 | # type_ids: 0 0 0 0 0 0 0 0 1 1 1 1 1 1 256 | # (b) For single sequences: 257 | # tokens: [CLS] the dog is hairy . [SEP] 258 | # type_ids: 0 0 0 0 0 0 0 259 | # 260 | # Where "type_ids" are used to indicate whether this is the first 261 | # sequence or the second sequence. The embedding vectors for `type=0` and 262 | # `type=1` were learned during pre-training and are added to the wordpiece 263 | # embedding vector (and position vector). 
This is not *strictly* necessary 264 | # since the [SEP] token unambiguously separates the sequences, but it makes 265 | # it easier for the model to learn the concept of sequences. 266 | # 267 | # For classification tasks, the first vector (corresponding to [CLS]) is 268 | # used as as the "sentence vector". Note that this only makes sense because 269 | # the entire model is fine-tuned. 270 | tokens = ["[CLS]"] + tokens_a + ["[SEP]"] 271 | segment_ids = [0] * len(tokens) 272 | 273 | if tokens_b: 274 | tokens += tokens_b + ["[SEP]"] 275 | segment_ids += [1] * (len(tokens_b) + 1) 276 | 277 | input_ids = tokenizer.convert_tokens_to_ids(tokens) 278 | 279 | # The mask has 1 for real tokens and 0 for padding tokens. Only real 280 | # tokens are attended to. 281 | input_mask = [1] * len(input_ids) 282 | 283 | # Zero-pad up to the sequence length. 284 | padding = [0] * (max_seq_length - len(input_ids)) 285 | input_ids += padding 286 | input_mask += padding 287 | segment_ids += padding 288 | 289 | assert len(input_ids) == max_seq_length 290 | assert len(input_mask) == max_seq_length 291 | assert len(segment_ids) == max_seq_length 292 | 293 | label_id = label_map[example.label] 294 | 295 | if ex_index < 5: 296 | logger.info("*** Example ***") 297 | logger.info("guid: %s" % (example.guid)) 298 | logger.info("tokens: %s" % " ".join( 299 | [str(x) for x in tokens])) 300 | logger.info("input_ids: %s" % " ".join([str(x) for x in input_ids])) 301 | logger.info("input_mask: %s" % " ".join([str(x) for x in input_mask])) 302 | logger.info( 303 | "segment_ids: %s" % " ".join([str(x) for x in segment_ids])) 304 | logger.info("label: %s (id = %d)" % (example.label, label_id)) 305 | 306 | features.append( 307 | InputFeatures(input_ids=input_ids, 308 | input_mask=input_mask, 309 | segment_ids=segment_ids, 310 | label_id=label_id)) 311 | return features 312 | 313 | 314 | def _truncate_seq_pair(tokens_a, tokens_b, max_length): 315 | """Truncates a sequence pair in place to the maximum length.""" 316 | 317 | # This is a simple heuristic which will always truncate the longer sequence 318 | # one token at a time. This makes more sense than truncating an equal percent 319 | # of tokens from each, since if one sequence is very short then each token 320 | # that's truncated likely contains more information than a longer sequence. 
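    # Worked example, using illustrative lengths: here max_length is
    # max_seq_length - 3 (room for [CLS] and two [SEP] tokens), so with
    # max_seq_length = 128, a 10-token response (tokens_a) and a 130-token
    # context (tokens_b), 15 tokens are popped one at a time from the END of
    # tokens_b, i.e. the tail of the context immediately preceding the response
    # is what gets cut in this script.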
321 | while True: 322 | total_length = len(tokens_a) + len(tokens_b) 323 | if total_length <= max_length: 324 | break 325 | if len(tokens_a) > len(tokens_b): 326 | tokens_a.pop() 327 | else: 328 | tokens_b.pop() 329 | 330 | 331 | def mean_average_precision(sort_data): 332 | # to do 333 | count_1 = 0 334 | sum_precision = 0 335 | for index in range(len(sort_data)): 336 | if sort_data[index][1] == 1: 337 | count_1 += 1 338 | sum_precision += 1.0 * count_1 / (index + 1) 339 | return sum_precision / count_1 340 | 341 | 342 | def mean_reciprocal_rank(sort_data): 343 | sort_lable = [s_d[1] for s_d in sort_data] 344 | assert 1 in sort_lable 345 | return 1.0 / (1 + sort_lable.index(1)) 346 | 347 | 348 | def precision_at_position_1(sort_data): 349 | if sort_data[0][1] == 1: 350 | return 1 351 | else: 352 | return 0 353 | 354 | 355 | def recall_at_position_k_in_10(sort_data, k): 356 | sort_lable = [s_d[1] for s_d in sort_data] 357 | select_lable = sort_lable[:k] 358 | return 1.0 * select_lable.count(1) / sort_lable.count(1) 359 | 360 | 361 | def evaluation_one_session(data): 362 | sort_data = sorted(data, key=lambda x: x[0], reverse=True) 363 | m_a_p = mean_average_precision(sort_data) 364 | m_r_r = mean_reciprocal_rank(sort_data) 365 | p_1 = precision_at_position_1(sort_data) 366 | r_1 = recall_at_position_k_in_10(sort_data, 1) 367 | r_2 = recall_at_position_k_in_10(sort_data, 2) 368 | r_5 = recall_at_position_k_in_10(sort_data, 5) 369 | return m_a_p, m_r_r, p_1, r_1, r_2, r_5 370 | 371 | def get_p_at_n_in_m(pred, n, m, ind): 372 | pos_score = pred[ind] 373 | curr = pred[ind:ind + m] 374 | curr = sorted(curr, reverse=True) 375 | 376 | if curr[n - 1] <= pos_score: 377 | return 1 378 | return 0 379 | 380 | 381 | def evaluate(pred, label): 382 | # assert len(data) % 10 == 0 383 | 384 | p_at_1_in_2 = 0.0 385 | p_at_1_in_10 = 0.0 386 | p_at_2_in_10 = 0.0 387 | p_at_5_in_10 = 0.0 388 | 389 | length = int(len(pred) / 10) 390 | 391 | for i in range(0, length): 392 | ind = i * 10 393 | 394 | # if label[ind] != 1: 395 | # print(i,ind) 396 | # print(label) 397 | # print(label[ind]) 398 | assert label[ind] == 1 399 | 400 | p_at_1_in_2 += get_p_at_n_in_m(pred, 1, 2, ind) 401 | p_at_1_in_10 += get_p_at_n_in_m(pred, 1, 10, ind) 402 | p_at_2_in_10 += get_p_at_n_in_m(pred, 2, 10, ind) 403 | p_at_5_in_10 += get_p_at_n_in_m(pred, 5, 10, ind) 404 | 405 | return (p_at_1_in_2 / length, p_at_1_in_10 / length, p_at_2_in_10 / length, p_at_5_in_10 / length) 406 | 407 | def evaluate_douban(pred, label): 408 | sum_m_a_p = 0 409 | sum_m_r_r = 0 410 | sum_p_1 = 0 411 | sum_r_1 = 0 412 | sum_r_2 = 0 413 | sum_r_5 = 0 414 | 415 | total_num = 0 416 | data = [] 417 | #print(label) 418 | for i in range(0, len(label)): 419 | if i % 10 == 0: 420 | data = [] 421 | data.append((float(pred[i]), int(label[i]))) 422 | if i % 10 == 9: 423 | total_num += 1 424 | m_a_p, m_r_r, p_1, r_1, r_2, r_5 = evaluation_one_session(data) 425 | sum_m_a_p += m_a_p 426 | sum_m_r_r += m_r_r 427 | sum_p_1 += p_1 428 | sum_r_1 += r_1 429 | sum_r_2 += r_2 430 | sum_r_5 += r_5 431 | # print('total num: %s' %total_num) 432 | # print('MAP: %s' %(1.0*sum_m_a_p/total_num)) 433 | # print('MRR: %s' %(1.0*sum_m_r_r/total_num)) 434 | # print('P@1: %s' %(1.0*sum_p_1/total_num)) 435 | return (1.0 * sum_m_a_p / total_num, 1.0 * sum_m_r_r / total_num, 1.0 * sum_p_1 / total_num, 436 | 1.0 * sum_r_1 / total_num, 1.0 * sum_r_2 / total_num, 1.0 * sum_r_5 / total_num) 437 | 438 | def simple_accuracy(preds, labels): 439 | return (preds == labels).mean() 440 | 441 | def 
ComputeR10(scores,labels,count = 10): 442 | total = 0 443 | correct1 = 0 444 | correct5 = 0 445 | correct2 = 0 446 | correct10 = 0 447 | #删除全0的例子 test 448 | for i in range(len(labels)): 449 | if labels[i] == 1: 450 | #print(i) 451 | total = total+1 452 | sublist = scores[i:i+count] 453 | #print(np.argmax(sublist)) 454 | if np.argmax(sublist) < 1: 455 | correct1 = correct1 + 1 456 | if np.argmax(sublist) < 2: 457 | correct2 = correct2 + 1 458 | if np.argmax(sublist) < 5: 459 | correct5 = correct5 + 1 460 | if np.argmax(sublist) < 10: 461 | correct10 = correct10 + 1 462 | # if max(sublist) == scores[i]: 463 | # correct = correct + 1 464 | print(correct1, correct5, correct10, total) 465 | return (float(correct1)/ total, float(correct2)/ total, float(correct5)/ total, float(correct10)/ total) 466 | 467 | def ComputeR2_1(scores,labels,count = 2): 468 | total = 0 469 | correct = 0 470 | for i in range(len(labels)): 471 | if labels[i] == 1: 472 | total = total+1 473 | sublist = scores[i:i+count] 474 | if max(sublist) == scores[i]: 475 | correct = correct + 1 476 | return (float(correct)/ total) 477 | 478 | def compute_metrics(task_name, preds, labels): 479 | assert len(preds) == len(labels) 480 | preds_logits = preds[:, 1] # 预测为1的概率 481 | 482 | if task_name == "ubuntu" or task_name == "ecd": 483 | return {"recall@2 recall@10(1,2,5)": evaluate(preds_logits, labels)} 484 | elif task_name =="douban": 485 | return {"MAP MRR P@1 recall@10(1,2,5)": evaluate_douban(preds_logits, labels)} 486 | else: 487 | raise KeyError(task_name) 488 | 489 | 490 | def main(): 491 | parser = argparse.ArgumentParser() 492 | 493 | ## Required parameters 494 | parser.add_argument("--data_dir", 495 | default="data/ubuntu_data", 496 | type=str, 497 | help="The input data dir. Should contain the .tsv files (or other data files) for the task.") 498 | parser.add_argument("--bert_model", default="C:/Users/xzhzhang/Desktop/project/google-tuned-bert/BASE-BERT-UNCASED", type=str, 499 | help="Bert pre-trained model selected in the list: bert-base-uncased, " 500 | "bert-large-uncased, bert-base-cased, bert-large-cased, bert-base-multilingual-uncased, " 501 | "bert-base-multilingual-cased, bert-base-chinese.") 502 | parser.add_argument("--task_name", 503 | default="ubuntu", 504 | type=str, 505 | help="The name of the task to train.") 506 | parser.add_argument("--output_dir", 507 | default="output_ubuntu", 508 | type=str, 509 | help="The output directory where the model predictions and checkpoints will be written.") 510 | 511 | ## Other parameters 512 | parser.add_argument("--cache_dir", 513 | default="", 514 | type=str, 515 | help="Where do you want to store the pre-trained models downloaded from s3") 516 | parser.add_argument("--max_seq_length", 517 | default=128, 518 | type=int, 519 | help="The maximum total input sequence length after WordPiece tokenization. 
\n" 520 | "Sequences longer than this will be truncated, and sequences shorter \n" 521 | "than this will be padded.") 522 | parser.add_argument("--cache_flag", 523 | default="bert", 524 | type=str, 525 | help="The data features cache will be written.") 526 | parser.add_argument("--do_train", 527 | action='store_true', 528 | help="Whether to run training.") 529 | parser.add_argument("--do_eval", 530 | action='store_true', 531 | help="Whether to run eval on the dev set.") 532 | parser.add_argument("--do_lower_case", 533 | action='store_true', 534 | help="Set this flag if you are using an uncased model.") 535 | parser.add_argument("--train_batch_size", 536 | default=32, 537 | type=int, 538 | help="Total batch size for training.") 539 | parser.add_argument("--eval_batch_size", 540 | default=8, 541 | type=int, 542 | help="Total batch size for eval.") 543 | parser.add_argument("--learning_rate", 544 | default=5e-5, 545 | type=float, 546 | help="The initial learning rate for Adam.") 547 | parser.add_argument("--num_train_epochs", 548 | default=3.0, 549 | type=float, 550 | help="Total number of training epochs to perform.") 551 | parser.add_argument("--warmup_proportion", 552 | default=0.1, 553 | type=float, 554 | help="Proportion of training to perform linear learning rate warmup for. " 555 | "E.g., 0.1 = 10%% of training.") 556 | parser.add_argument("--no_cuda", 557 | action='store_true', 558 | help="Whether not to use CUDA when available") 559 | parser.add_argument("--local_rank", 560 | type=int, 561 | default=-1, 562 | help="local_rank for distributed training on gpus") 563 | parser.add_argument('--seed', 564 | type=int, 565 | default=42, 566 | help="random seed for initialization") 567 | parser.add_argument('--gradient_accumulation_steps', 568 | type=int, 569 | default=1, 570 | help="Number of updates steps to accumulate before performing a backward/update pass.") 571 | parser.add_argument('--fp16', 572 | action='store_true', 573 | help="Whether to use 16-bit float precision instead of 32-bit") 574 | parser.add_argument('--loss_scale', 575 | type=float, default=0, 576 | help="Loss scaling to improve fp16 numeric stability. 
Only used when fp16 set to True.\n" 577 | "0 (default value): dynamic loss scaling.\n" 578 | "Positive power of 2: static loss scaling value.\n") 579 | parser.add_argument('--server_ip', type=str, default='', help="Can be used for distant debugging.") 580 | parser.add_argument('--server_port', type=str, default='', help="Can be used for distant debugging.") 581 | args = parser.parse_args() 582 | 583 | if args.server_ip and args.server_port: 584 | # Distant debugging - see https://code.visualstudio.com/docs/python/debugging#_attach-to-a-local-script 585 | import ptvsd 586 | print("Waiting for debugger attach") 587 | ptvsd.enable_attach(address=(args.server_ip, args.server_port), redirect_output=True) 588 | ptvsd.wait_for_attach() 589 | 590 | processors = { 591 | "ubuntu": UbuntuProcessor, 592 | "douban": DoubanProcessor, 593 | "ecd": UbuntuProcessor, 594 | } 595 | 596 | if args.local_rank == -1 or args.no_cuda: 597 | device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu") 598 | n_gpu = torch.cuda.device_count() 599 | else: 600 | torch.cuda.set_device(args.local_rank) 601 | device = torch.device("cuda", args.local_rank) 602 | n_gpu = 1 603 | # Initializes the distributed backend which will take care of sychronizing nodes/GPUs 604 | torch.distributed.init_process_group(backend='nccl') 605 | 606 | logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s', 607 | datefmt = '%m/%d/%Y %H:%M:%S', 608 | level = logging.INFO if args.local_rank in [-1, 0] else logging.WARN) 609 | 610 | logger.info("device: {} n_gpu: {}, distributed training: {}, 16-bits training: {}".format( 611 | device, n_gpu, bool(args.local_rank != -1), args.fp16)) 612 | 613 | if args.gradient_accumulation_steps < 1: 614 | raise ValueError("Invalid gradient_accumulation_steps parameter: {}, should be >= 1".format( 615 | args.gradient_accumulation_steps)) 616 | 617 | args.train_batch_size = args.train_batch_size // args.gradient_accumulation_steps 618 | 619 | random.seed(args.seed) 620 | np.random.seed(args.seed) 621 | torch.manual_seed(args.seed) 622 | if n_gpu > 0: 623 | torch.cuda.manual_seed_all(args.seed) 624 | 625 | if not args.do_train and not args.do_eval: 626 | raise ValueError("At least one of `do_train` or `do_eval` must be True.") 627 | 628 | if os.path.exists(args.output_dir) and os.listdir(args.output_dir) and args.do_train: 629 | raise ValueError("Output directory ({}) already exists and is not empty.".format(args.output_dir)) 630 | if not os.path.exists(args.output_dir): 631 | os.makedirs(args.output_dir) 632 | 633 | task_name = args.task_name.lower() 634 | 635 | if task_name not in processors: 636 | raise ValueError("Task not found: %s" % (task_name)) 637 | 638 | processor = processors[task_name]() 639 | 640 | label_list = processor.get_labels() 641 | num_labels = len(label_list) 642 | 643 | tokenizer = BertTokenizer.from_pretrained(args.bert_model, do_lower_case=args.do_lower_case) 644 | 645 | train_examples = None 646 | num_train_optimization_steps = None 647 | if args.do_train: 648 | train_examples = processor.get_train_examples(args.data_dir) 649 | num_train_optimization_steps = int( 650 | len(train_examples) / args.train_batch_size / args.gradient_accumulation_steps) * args.num_train_epochs 651 | if args.local_rank != -1: 652 | num_train_optimization_steps = num_train_optimization_steps // torch.distributed.get_world_size() 653 | 654 | # Prepare model 655 | cache_dir = args.cache_dir if args.cache_dir else 
656 |                                                                    'distributed_{}'.format(args.local_rank))
657 |     model = BertForSequenceClassification.from_pretrained(args.bert_model,
658 |                                                            cache_dir=cache_dir,
659 |                                                            num_labels=num_labels)
660 |     if args.fp16:
661 |         model.half()
662 |     model.to(device)
663 |     if args.local_rank != -1:
664 |         try:
665 |             from apex.parallel import DistributedDataParallel as DDP
666 |         except ImportError:
667 |             raise ImportError(
668 |                 "Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training.")
669 | 
670 |         model = DDP(model)
671 |     elif n_gpu > 1:
672 |         model = torch.nn.DataParallel(model)
673 | 
674 |     # Prepare optimizer
675 |     if args.do_train:
676 |         param_optimizer = list(model.named_parameters())
677 |         no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
678 |         optimizer_grouped_parameters = [
679 |             {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01},
680 |             {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
681 |         ]
682 |         if args.fp16:
683 |             try:
684 |                 from apex.optimizers import FP16_Optimizer
685 |                 from apex.optimizers import FusedAdam
686 |             except ImportError:
687 |                 raise ImportError(
688 |                     "Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training.")
689 | 
690 |             optimizer = FusedAdam(optimizer_grouped_parameters,
691 |                                   lr=args.learning_rate,
692 |                                   bias_correction=False,
693 |                                   max_grad_norm=1.0)
694 |             if args.loss_scale == 0:
695 |                 optimizer = FP16_Optimizer(optimizer, dynamic_loss_scale=True)
696 |             else:
697 |                 optimizer = FP16_Optimizer(optimizer, static_loss_scale=args.loss_scale)
698 |             warmup_linear = WarmupLinearSchedule(warmup=args.warmup_proportion,
699 |                                                  t_total=num_train_optimization_steps)
700 | 
701 |         else:
702 |             optimizer = BertAdam(optimizer_grouped_parameters,
703 |                                  lr=args.learning_rate,
704 |                                  warmup=args.warmup_proportion,
705 |                                  t_total=num_train_optimization_steps)
706 | 
707 |     global_step = 0
708 |     nb_tr_steps = 0
709 |     tr_loss = 0
710 |     if args.do_train:
711 |         cached_train_features_file = args.data_dir + '_{0}_{1}_{2}_{3}_{4}'.format(
712 |             list(filter(None, args.bert_model.split('/'))).pop(), "train", str(args.task_name),
713 |             str(args.max_seq_length), str(args.cache_flag))
714 |         train_features = None
715 |         try:
716 |             with open(cached_train_features_file, "rb") as reader:
717 |                 train_features = pickle.load(reader)
718 |         except:
719 |             train_features = convert_examples_to_features(
720 |                 train_examples, label_list, args.max_seq_length, tokenizer)
721 |             if args.local_rank == -1 or torch.distributed.get_rank() == 0:
722 |                 logger.info(" Saving train features into cached file %s", cached_train_features_file)
723 |                 with open(cached_train_features_file, "wb") as writer:
724 |                     pickle.dump(train_features, writer)
725 |         logger.info("***** Running training *****")
726 |         logger.info(" Num examples = %d", len(train_examples))
727 |         logger.info(" Batch size = %d", args.train_batch_size)
728 |         logger.info(" Num steps = %d", num_train_optimization_steps)
729 |         all_input_ids = torch.tensor([f.input_ids for f in train_features], dtype=torch.long)
730 |         all_input_mask = torch.tensor([f.input_mask for f in train_features], dtype=torch.long)
731 |         all_segment_ids = torch.tensor([f.segment_ids for f in train_features], dtype=torch.long)
732 |         all_label_ids = torch.tensor([f.label_id for f in train_features], dtype=torch.long)
733 | 
734 |         train_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_label_ids)
735 |         if args.local_rank == -1:
736 |             train_sampler = RandomSampler(train_data)
737 |         else:
738 |             train_sampler = DistributedSampler(train_data)
739 |         train_dataloader = DataLoader(train_data, sampler=train_sampler, batch_size=args.train_batch_size)
740 | 
741 | 
742 |         eval_examples = processor.get_dev_examples(args.data_dir)
743 |         cached_train_features_file = args.data_dir + '_{0}_{1}_{2}_{3}_{4}'.format(
744 |             list(filter(None, args.bert_model.split('/'))).pop(), "valid", str(args.task_name),
745 |             str(args.max_seq_length), str(args.cache_flag))
746 |         eval_features = None
747 |         try:
748 |             with open(cached_train_features_file, "rb") as reader:
749 |                 eval_features = pickle.load(reader)
750 |         except:
751 |             eval_features = convert_examples_to_features(
752 |                 eval_examples, label_list, args.max_seq_length, tokenizer)
753 |             if args.local_rank == -1 or torch.distributed.get_rank() == 0:
754 |                 logger.info(" Saving eval features into cached file %s", cached_train_features_file)
755 |                 with open(cached_train_features_file, "wb") as writer:
756 |                     pickle.dump(eval_features, writer)
757 |         logger.info("***** Running evaluation *****")
758 |         logger.info(" Num examples = %d", len(eval_examples))
759 |         logger.info(" Batch size = %d", args.eval_batch_size)
760 |         all_input_ids = torch.tensor([f.input_ids for f in eval_features], dtype=torch.long)
761 |         all_input_mask = torch.tensor([f.input_mask for f in eval_features], dtype=torch.long)
762 |         all_segment_ids = torch.tensor([f.segment_ids for f in eval_features], dtype=torch.long)
763 | 
764 |         all_label_ids = torch.tensor([f.label_id for f in eval_features], dtype=torch.long)
765 | 
766 |         eval_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_label_ids)
767 |         # Run prediction for full data
768 |         eval_sampler = SequentialSampler(eval_data)
769 |         eval_dataloader = DataLoader(eval_data, sampler=eval_sampler, batch_size=args.eval_batch_size)
770 | 
771 |         model.train()
772 |         for epoch in trange(int(args.num_train_epochs), desc="Epoch"):
773 | 
774 |             tr_loss = 0
775 |             nb_tr_examples, nb_tr_steps = 0, 0
776 |             for step, batch in enumerate(tqdm(train_dataloader, desc="Iteration")):
777 |                 batch = tuple(t.to(device) for t in batch)
778 |                 input_ids, input_mask, segment_ids, label_ids = batch
779 | 
780 |                 # forward pass returns the logits; the classification loss is computed below
781 |                 logits = model(input_ids, segment_ids, input_mask, labels=None)
782 | 
783 |                 loss_fct = CrossEntropyLoss()
784 |                 loss = loss_fct(logits.view(-1, num_labels), label_ids.view(-1))
785 | 
786 | 
787 |                 if n_gpu > 1:
788 |                     loss = loss.mean()  # mean() to average on multi-gpu.
789 |                 if args.gradient_accumulation_steps > 1:
790 |                     loss = loss / args.gradient_accumulation_steps
791 | 
792 |                 if args.fp16:
793 |                     optimizer.backward(loss)
794 |                 else:
795 |                     loss.backward()
796 | 
797 |                 tr_loss += loss.item()
798 |                 nb_tr_examples += input_ids.size(0)
799 |                 nb_tr_steps += 1
800 |                 if (step + 1) % args.gradient_accumulation_steps == 0:
801 |                     if args.fp16:
802 |                         # modify learning rate with special warm up BERT uses
803 |                         # if args.fp16 is False, BertAdam is used that handles this automatically
804 |                         lr_this_step = args.learning_rate * warmup_linear.get_lr(
805 |                             global_step / num_train_optimization_steps,
806 |                             args.warmup_proportion)
807 |                         for param_group in optimizer.param_groups:
808 |                             param_group['lr'] = lr_this_step
809 |                     optimizer.step()
810 |                     optimizer.zero_grad()
811 |                     global_step += 1
812 | 
813 | 
814 |             # Save a trained model, configuration and tokenizer
815 |             model_to_save = model.module if hasattr(model, 'module') else model  # Only save the model itself
816 | 
817 |             # If we save using the predefined names, we can load using `from_pretrained`
818 |             output_model_file = os.path.join(args.output_dir, str(epoch) + "_" + WEIGHTS_NAME)
819 |             output_config_file = os.path.join(args.output_dir, CONFIG_NAME)
820 | 
821 |             torch.save(model_to_save.state_dict(), output_model_file)
822 |             model_to_save.config.to_json_file(output_config_file)
823 |             tokenizer.save_vocabulary(args.output_dir)
824 | 
825 |             # Load a trained model and vocabulary that you have fine-tuned
826 |             model_state_dict = torch.load(output_model_file)
827 |             eval_model = BertForSequenceClassification.from_pretrained(args.bert_model, state_dict=model_state_dict, num_labels=num_labels)
828 |             # tokenizer = BertTokenizer.from_pretrained(args.output_dir, do_lower_case=args.do_lower_case)
829 |             eval_model.to(device)
830 |             eval_model.eval()
831 |             eval_loss = 0
832 |             nb_eval_steps = 0
833 |             preds = []
834 | 
835 |             for input_ids, input_mask, segment_ids, label_ids in tqdm(eval_dataloader, desc="Evaluating"):
836 |                 input_ids = input_ids.to(device)
837 |                 input_mask = input_mask.to(device)
838 |                 segment_ids = segment_ids.to(device)
839 |                 label_ids = label_ids.to(device)
840 | 
841 |                 with torch.no_grad():
842 |                     logits = eval_model(input_ids, segment_ids, input_mask, labels=None)
843 | 
844 |                 # create eval loss and other metrics required by the task
845 | 
846 |                 loss_fct = CrossEntropyLoss()
847 |                 tmp_eval_loss = loss_fct(logits.view(-1, num_labels), label_ids.view(-1))
848 | 
849 |                 eval_loss += tmp_eval_loss.mean().item()
850 |                 nb_eval_steps += 1
851 |                 if len(preds) == 0:
852 |                     preds.append(logits.detach().cpu().numpy())
853 |                 else:
854 |                     preds[0] = np.append(
855 |                         preds[0], logits.detach().cpu().numpy(), axis=0)
856 | 
857 |             eval_loss = eval_loss / nb_eval_steps
858 |             preds = preds[0]
859 |             # print(preds)
860 |             result = compute_metrics(task_name, preds, all_label_ids.numpy())
861 |             loss = tr_loss / nb_tr_steps if args.do_train else None
862 | 
863 |             result['eval_loss'] = eval_loss
864 |             result['global_step'] = global_step
865 |             result['loss'] = loss
866 | 
867 |             output_eval_file = os.path.join(args.output_dir, "eval_results.txt")
868 |             with open(output_eval_file, "a") as writer:
869 |                 logger.info("***** Eval results *****")
870 |                 for key in sorted(result.keys()):
871 |                     logger.info(" %s = %s", key, str(result[key]))
872 |                     writer.write("%s = %s\n" % (key, str(result[key])))
873 | 
874 |     else:
875 |         output_model_file = os.path.join(args.output_dir, "pytorch_model.bin")
876 |         model_state_dict = torch.load(output_model_file)
877 |         model = BertForSequenceClassification.from_pretrained(args.bert_model, state_dict=model_state_dict,
878 |                                                               num_labels=num_labels)
879 | 
880 |         model.to(device)
881 | 
882 |     if args.do_eval and (args.local_rank == -1 or torch.distributed.get_rank() == 0):
883 |         eval_examples = processor.get_test_examples(args.data_dir)
884 |         eval_features = convert_examples_to_features(
885 |             eval_examples, label_list, args.max_seq_length, tokenizer)
886 |         logger.info("***** Running evaluation *****")
887 |         logger.info(" Num examples = %d", len(eval_examples))
888 |         logger.info(" Batch size = %d", args.eval_batch_size)
889 |         all_input_ids = torch.tensor([f.input_ids for f in eval_features], dtype=torch.long)
890 |         all_input_mask = torch.tensor([f.input_mask for f in eval_features], dtype=torch.long)
891 |         all_segment_ids = torch.tensor([f.segment_ids for f in eval_features], dtype=torch.long)
892 | 
893 |         all_label_ids = torch.tensor([f.label_id for f in eval_features], dtype=torch.long)
894 | 
895 |         eval_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_label_ids)
896 |         # Run prediction for full data
897 |         eval_sampler = SequentialSampler(eval_data)
898 |         eval_dataloader = DataLoader(eval_data, sampler=eval_sampler, batch_size=args.eval_batch_size)
899 | 
900 |         model.eval()
901 |         eval_loss = 0
902 |         nb_eval_steps = 0
903 |         preds = []
904 | 
905 |         for input_ids, input_mask, segment_ids, label_ids in tqdm(eval_dataloader, desc="Evaluating"):
906 |             input_ids = input_ids.to(device)
907 |             input_mask = input_mask.to(device)
908 |             segment_ids = segment_ids.to(device)
909 |             label_ids = label_ids.to(device)
910 | 
911 |             with torch.no_grad():
912 |                 logits = model(input_ids, segment_ids, input_mask, labels=None)
913 | 
914 |             # create eval loss and other metric required by the task
915 | 
916 |             loss_fct = CrossEntropyLoss()
917 |             tmp_eval_loss = loss_fct(logits.view(-1, num_labels), label_ids.view(-1))
918 | 
919 |             eval_loss += tmp_eval_loss.mean().item()
920 |             nb_eval_steps += 1
921 |             if len(preds) == 0:
922 |                 preds.append(logits.detach().cpu().numpy())
923 |             else:
924 |                 preds[0] = np.append(
925 |                     preds[0], logits.detach().cpu().numpy(), axis=0)
926 | 
927 |         eval_loss = eval_loss / nb_eval_steps
928 |         preds = preds[0]
929 |         #print(preds)
930 |         result = compute_metrics(task_name, preds, all_label_ids.numpy())
931 |         loss = tr_loss/nb_tr_steps if args.do_train else None
932 | 
933 |         result['eval_loss'] = eval_loss
934 |         result['global_step'] = global_step
935 |         result['loss'] = loss
936 | 
937 |         output_eval_file = os.path.join(args.output_dir, "eval_results.txt")
938 |         with open(output_eval_file, "a") as writer:
939 |             logger.info("***** Eval results *****")
940 |             for key in sorted(result.keys()):
941 |                 logger.info(" %s = %s", key, str(result[key]))
942 |                 writer.write("%s = %s\n" % (key, str(result[key])))
943 | 
944 | if __name__ == "__main__":
945 |     main()
946 | 
--------------------------------------------------------------------------------