├── requirements.txt ├── QA_model ├── utils.py ├── model_QuAC.py ├── model_CoQA.py ├── detail_model.py └── layers.py ├── download.sh ├── README.md ├── predict_CoQA.py ├── predict_QuAC.py ├── CoQA_eval.py ├── preprocess_QuAC.py ├── train_CoQA.py ├── preprocess_CoQA.py ├── train_QuAC.py └── general_utils.py /requirements.txt: -------------------------------------------------------------------------------- 1 | numpy 2 | pandas 3 | msgpack-python 4 | spacy 5 | allennlp 6 | torch 7 | -------------------------------------------------------------------------------- /QA_model/utils.py: -------------------------------------------------------------------------------- 1 | class AverageMeter(object): 2 | """Computes and stores the average and current value.""" 3 | def __init__(self): 4 | self.reset() 5 | 6 | def reset(self): 7 | self.val = 0 8 | self.avg = 0 9 | self.sum = 0 10 | self.count = 0 11 | 12 | def update(self, val, n=1): 13 | self.val = val 14 | self.sum += val * n 15 | self.count += n 16 | self.avg = self.sum / self.count 17 | -------------------------------------------------------------------------------- /download.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | # Download QuAC 4 | mkdir -p QuAC_data 5 | wget https://s3.amazonaws.com/my89public/quac/train.json -O QuAC_data/train.json 6 | wget https://s3.amazonaws.com/my89public/quac/val.json -O QuAC_data/dev.json 7 | 8 | # Download CoQA 9 | mkdir -p CoQA 10 | wget https://worksheets.codalab.org/rest/bundles/0xe3674fd34560425786f97541ec91aeb8/contents/blob/ -O CoQA/train.json 11 | wget https://worksheets.codalab.org/rest/bundles/0xe254829ab81946198433c4da847fb485/contents/blob/ -O CoQA/dev.json 12 | 13 | # Download GloVe 14 | mkdir -p glove 15 | wget http://nlp.stanford.edu/data/glove.840B.300d.zip -O glove/glove.840B.300d.zip 16 | unzip glove/glove.840B.300d.zip -d glove 17 | 18 | # Download CoVe 19 | wget https://s3.amazonaws.com/research.metamind.io/cove/wmtlstm-b142a7f2.pth -O glove/MT-LSTM.pth 20 | 21 | # Download SpaCy English language models 22 | python -m spacy download en 23 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # FlowQA 2 | 3 | This is our first attempt to make state-of-the-art single-turn QA models conversational. 4 | Feel free to build on top of our code to build an even stronger conversational QA model. 5 | 6 | For more details, please see: [FlowQA: Grasping Flow in History for Conversational Machine Comprehension](https://arxiv.org/abs/1810.06683) 7 | 8 | #### Step 1: 9 | perform the following: 10 | ```shell 11 | pip install -r requirements.txt 12 | ``` 13 | to install all dependent python packages. 14 | 15 | #### Step 2: 16 | download necessary files using: 17 | ```shell 18 | ./download.sh 19 | ``` 20 | 21 | #### Step 3: 22 | preprocess the data files using: 23 | ```shell 24 | python preprocess_QuAC.py 25 | python preprocess_CoQA.py 26 | ``` 27 | 28 | #### Step 4: 29 | run the training code using: 30 | ```shell 31 | python train_QuAC.py 32 | python train_CoQA.py 33 | ``` 34 | For naming the output model, you can do 35 | ```shell 36 | python train_OOOO.py --name XXX 37 | ``` 38 | Remove any answer marking by: 39 | ```shell 40 | python train_OOOO.py --explicit_dialog_ctx 0 41 | ``` 42 | `OOOO` is the name of the dataset (QuAC or CoQA). 
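The flags can be combined; for example, to train a QuAC model named `no_ctx` without any answer marking (the run name is just an illustrative placeholder):
```shell
python train_QuAC.py --name no_ctx --explicit_dialog_ctx 0
```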
43 | 44 | #### Step 5: 45 | Do prediction with answer thresholding using 46 | ```shell 47 | python predict_OOOO.py -m models_XXX/best_model.pt --show SS 48 | ``` 49 | `XXX` is the name you used during train.py. 50 | `SS` is the number of dialog examples to be shown. 51 | `OOOO` is the name of the dataset (QuAC or CoQA). 52 | -------------------------------------------------------------------------------- /predict_CoQA.py: -------------------------------------------------------------------------------- 1 | import re 2 | import json 3 | import os 4 | import sys 5 | import random 6 | import string 7 | import logging 8 | import argparse 9 | from os.path import basename 10 | from shutil import copyfile 11 | from datetime import datetime 12 | from collections import Counter 13 | import torch 14 | import msgpack 15 | import pickle 16 | import pandas as pd 17 | import numpy as np 18 | from QA_model.model_CoQA import QAModel 19 | from CoQA_eval import CoQAEvaluator 20 | from general_utils import score, BatchGen_CoQA 21 | 22 | parser = argparse.ArgumentParser( 23 | description='Predict using a Dialog QA model.' 24 | ) 25 | parser.add_argument('--dev_dir', default='CoQA/') 26 | 27 | parser.add_argument('-o', '--output_dir', default='pred_out/') 28 | parser.add_argument('--number', type=int, default=-1, help='id of the current prediction') 29 | parser.add_argument('-m', '--model', default='', 30 | help='testing model pathname, e.g. "models/checkpoint_epoch_11.pt"') 31 | 32 | parser.add_argument('-bs', '--batch_size', default=1) 33 | 34 | parser.add_argument('--show', type=int, default=3) 35 | parser.add_argument('--seed', type=int, default=1023, 36 | help='random seed for data shuffling, dropout, etc.') 37 | parser.add_argument('--cuda', type=bool, default=torch.cuda.is_available(), 38 | help='whether to use GPU acceleration.') 39 | 40 | args = parser.parse_args() 41 | if args.model == '': 42 | print("model file is not provided") 43 | sys.exit(-1) 44 | if args.model[-3:] != '.pt': 45 | print("does not recognize the model file") 46 | sys.exit(-1) 47 | 48 | # create prediction output dir 49 | os.makedirs(args.output_dir, exist_ok=True) 50 | # count the number of prediction files 51 | if args.number == -1: 52 | args.number = len(os.listdir(args.output_dir))+1 53 | args.output = args.output_dir + 'pred' + str(args.number) + '.pckl' 54 | 55 | random.seed(args.seed) 56 | np.random.seed(args.seed) 57 | torch.manual_seed(args.seed) 58 | if args.cuda: 59 | torch.cuda.manual_seed_all(args.seed) 60 | 61 | log = logging.getLogger(__name__) 62 | log.setLevel(logging.DEBUG) 63 | ch = logging.StreamHandler(sys.stdout) 64 | ch.setLevel(logging.INFO) 65 | formatter = logging.Formatter(fmt='%(asctime)s %(message)s', datefmt='%m/%d/%Y %I:%M:%S') 66 | ch.setFormatter(formatter) 67 | log.addHandler(ch) 68 | 69 | def main(): 70 | log.info('[program starts.]') 71 | checkpoint = torch.load(args.model) 72 | opt = checkpoint['config'] 73 | opt['task_name'] = 'CoQA' 74 | opt['cuda'] = args.cuda 75 | opt['seed'] = args.seed 76 | if opt.get('do_hierarchical_query') is None: 77 | opt['do_hierarchical_query'] = False 78 | state_dict = checkpoint['state_dict'] 79 | log.info('[model loaded.]') 80 | 81 | test, test_embedding = load_dev_data(opt) 82 | model = QAModel(opt, state_dict = state_dict) 83 | CoQAEval = CoQAEvaluator("CoQA/dev.json") 84 | log.info('[Data loaded.]') 85 | 86 | model.setup_eval_embed(test_embedding) 87 | 88 | if args.cuda: 89 | model.cuda() 90 | 91 | batches = BatchGen_CoQA(test, batch_size=args.batch_size, 
evaluation=True, gpu=args.cuda, dialog_ctx=opt['explicit_dialog_ctx'], precompute_elmo=16 // args.batch_size) 92 | sample_idx = random.sample(range(len(batches)), args.show) 93 | 94 | with open("CoQA/dev.json", "r", encoding="utf8") as f: 95 | dev_data = json.load(f) 96 | 97 | list_of_ids = [] 98 | for article in dev_data['data']: 99 | id = article["id"] 100 | for Qs in article["questions"]: 101 | tid = Qs["turn_id"] 102 | list_of_ids.append((id, tid)) 103 | 104 | predictions = [] 105 | for i, batch in enumerate(batches): 106 | prediction = model.predict(batch) 107 | predictions.extend(prediction) 108 | 109 | if not (i in sample_idx): 110 | continue 111 | 112 | print("Story: ", batch[-4][0]) 113 | for j in range(len(batch[-2][0])): 114 | print("Q: ", batch[-2][0][j]) 115 | print("A: ", prediction[j]) 116 | print("Gold A: ", batch[-1][0][j]) 117 | print("---") 118 | print("") 119 | 120 | assert(len(list_of_ids) == len(predictions)) 121 | official_predictions = [] 122 | for ids, pred in zip(list_of_ids, predictions): 123 | official_predictions.append({ 124 | "id": ids[0], 125 | "turn_id": ids[1], 126 | "answer": pred}) 127 | with open("model_prediction.json", "w", encoding="utf8") as f: 128 | json.dump(official_predictions, f) 129 | 130 | f1 = CoQAEval.compute_turn_score_seq(predictions) 131 | log.warning("Test F1: {:.3f}".format(f1 * 100.0)) 132 | 133 | def load_dev_data(opt): # can be extended to true test set 134 | with open(os.path.join(args.dev_dir, 'dev_meta.msgpack'), 'rb') as f: 135 | meta = msgpack.load(f, encoding='utf8') 136 | embedding = torch.Tensor(meta['embedding']) 137 | assert opt['embedding_dim'] == embedding.size(1) 138 | 139 | with open(os.path.join(args.dev_dir, 'dev_data.msgpack'), 'rb') as f: 140 | data = msgpack.load(f, encoding='utf8') 141 | 142 | assert opt['num_features'] == len(data['context_features'][0][0]) + opt['explicit_dialog_ctx'] * 3 143 | 144 | dev = {'context': list(zip( 145 | data['context_ids'], 146 | data['context_tags'], 147 | data['context_ents'], 148 | data['context'], 149 | data['context_span'], 150 | data['1st_question'], 151 | data['context_tokenized'])), 152 | 'qa': list(zip( 153 | data['question_CID'], 154 | data['question_ids'], 155 | data['context_features'], 156 | data['answer_start'], 157 | data['answer_end'], 158 | data['rationale_start'], 159 | data['rationale_end'], 160 | data['answer_choice'], 161 | data['question'], 162 | data['answer'], 163 | data['question_tokenized'])) 164 | } 165 | 166 | return dev, embedding 167 | 168 | if __name__ == '__main__': 169 | main() 170 | -------------------------------------------------------------------------------- /predict_QuAC.py: -------------------------------------------------------------------------------- 1 | import re 2 | import os 3 | import sys 4 | import random 5 | import string 6 | import logging 7 | import argparse 8 | from os.path import basename 9 | from shutil import copyfile 10 | from datetime import datetime 11 | from collections import Counter 12 | import torch 13 | import msgpack 14 | import pickle 15 | import pandas as pd 16 | import numpy as np 17 | from QA_model.model_QuAC import QAModel 18 | from general_utils import score, BatchGen_QuAC, find_best_score_and_thresh 19 | 20 | parser = argparse.ArgumentParser( 21 | description='Predict using a Dialog QA model.' 
22 | ) 23 | parser.add_argument('--dev_dir', default='QuAC_data/') 24 | 25 | parser.add_argument('-o', '--output_dir', default='pred_out/') 26 | parser.add_argument('--number', type=int, default=-1, help='id of the current prediction') 27 | parser.add_argument('-m', '--model', default='', 28 | help='testing model pathname, e.g. "models/checkpoint_epoch_11.pt"') 29 | 30 | parser.add_argument('-bs', '--batch_size', type=int, default=4) 31 | parser.add_argument('--no_ans', type=float, default=0) 32 | parser.add_argument('--min_f1', type=float, default=0.4) 33 | 34 | parser.add_argument('--show', type=int, default=3) 35 | parser.add_argument('--seed', type=int, default=1023, 36 | help='random seed for data shuffling, dropout, etc.') 37 | parser.add_argument('--cuda', type=bool, default=torch.cuda.is_available(), 38 | help='whether to use GPU acceleration.') 39 | 40 | args = parser.parse_args() 41 | if args.model == '': 42 | print("model file is not provided") 43 | sys.exit(-1) 44 | if args.model[-3:] != '.pt': 45 | print("does not recognize the model file") 46 | sys.exit(-1) 47 | 48 | # create prediction output dir 49 | os.makedirs(args.output_dir, exist_ok=True) 50 | # count the number of prediction files 51 | if args.number == -1: 52 | args.number = len(os.listdir(args.output_dir))+1 53 | args.output = args.output_dir + 'pred' + str(args.number) + '.pckl' 54 | 55 | random.seed(args.seed) 56 | np.random.seed(args.seed) 57 | torch.manual_seed(args.seed) 58 | if args.cuda: 59 | torch.cuda.manual_seed_all(args.seed) 60 | 61 | log = logging.getLogger(__name__) 62 | log.setLevel(logging.DEBUG) 63 | ch = logging.StreamHandler(sys.stdout) 64 | ch.setLevel(logging.INFO) 65 | formatter = logging.Formatter(fmt='%(asctime)s %(message)s', datefmt='%m/%d/%Y %I:%M:%S') 66 | ch.setFormatter(formatter) 67 | log.addHandler(ch) 68 | 69 | def main(): 70 | log.info('[program starts.]') 71 | checkpoint = torch.load(args.model) 72 | opt = checkpoint['config'] 73 | opt['task_name'] = 'QuAC' 74 | opt['cuda'] = args.cuda 75 | opt['seed'] = args.seed 76 | if opt.get('disperse_flow') is None: 77 | opt['disperse_flow'] = False 78 | if opt.get('rationale_lambda') is None: 79 | opt['rationale_lambda'] = 0.0 80 | if opt.get('no_dialog_flow') is None: 81 | opt['no_dialog_flow'] = False 82 | if opt.get('do_hierarchical_query') is None: 83 | opt['do_hierarchical_query'] = False 84 | state_dict = checkpoint['state_dict'] 85 | log.info('[model loaded.]') 86 | 87 | test, test_embedding, test_answer = load_dev_data(opt) 88 | model = QAModel(opt, state_dict = state_dict) 89 | log.info('[Data loaded.]') 90 | 91 | model.setup_eval_embed(test_embedding) 92 | 93 | if args.cuda: 94 | model.cuda() 95 | 96 | batches = BatchGen_QuAC(test, batch_size=args.batch_size, evaluation=True, gpu=args.cuda, dialog_ctx=opt['explicit_dialog_ctx'], use_dialog_act=opt['use_dialog_act'], precompute_elmo=opt['elmo_batch_size'] // args.batch_size) 97 | sample_idx = random.sample(range(len(batches)), args.show) 98 | 99 | predictions = [] 100 | no_ans_scores = [] 101 | for i, batch in enumerate(batches): 102 | prediction, noans = model.predict(batch, No_Ans_Threshold=args.no_ans) 103 | predictions.extend(prediction) 104 | no_ans_scores.extend(noans) 105 | 106 | if not (i in sample_idx): 107 | continue 108 | 109 | print("Context: ", batch[-4][0]) 110 | for j in range(len(batch[-2][0])): 111 | print("Q: ", batch[-2][0][j]) 112 | print("A: ", prediction[0][j]) 113 | print(" True A: ", batch[-1][0][j], "| Follow up" if batch[-6][0][j].item() // 10 else "| Don't 
follow up") 114 | print(" Val. A: ", test_answer[args.batch_size * i][j]) 115 | print("") 116 | 117 | 118 | pred_out = {'predictions': predictions, 'no_ans_scores': no_ans_scores} 119 | with open(args.output, 'wb') as f: 120 | pickle.dump(pred_out, f) 121 | 122 | f1, h_f1, HEQ_Q, HEQ_D = score(predictions, test_answer, min_F1=args.min_f1) 123 | log.warning("Test F1: {:.2f}, HEQ_Q: {:.2f}, HEQ_D: {:.2f}".format(f1, HEQ_Q, HEQ_D)) 124 | 125 | def load_dev_data(opt): # can be extended to true test set 126 | with open(os.path.join(args.dev_dir, 'dev_meta.msgpack'), 'rb') as f: 127 | meta = msgpack.load(f, encoding='utf8') 128 | embedding = torch.Tensor(meta['embedding']) 129 | assert opt['embedding_dim'] == embedding.size(1) 130 | 131 | with open(os.path.join(args.dev_dir, 'dev_data.msgpack'), 'rb') as f: 132 | data = msgpack.load(f, encoding='utf8') 133 | 134 | assert opt['num_features'] == len(data['context_features'][0][0]) + opt['explicit_dialog_ctx'] * (opt['use_dialog_act']*3 + 2) 135 | 136 | dev = {'context': list(zip( 137 | data['context_ids'], 138 | data['context_tags'], 139 | data['context_ents'], 140 | data['context'], 141 | data['context_span'], 142 | data['1st_question'], 143 | data['context_tokenized'])), 144 | 'qa': list(zip( 145 | data['question_CID'], 146 | data['question_ids'], 147 | data['context_features'], 148 | data['answer_start'], 149 | data['answer_end'], 150 | data['answer_choice'], 151 | data['question'], 152 | data['answer'], 153 | data['question_tokenized'])) 154 | } 155 | 156 | dev_answer = [] 157 | for i, CID in enumerate(data['question_CID']): 158 | if len(dev_answer) <= CID: 159 | dev_answer.append([]) 160 | dev_answer[CID].append(data['all_answer'][i]) 161 | 162 | return dev, embedding, dev_answer 163 | 164 | if __name__ == '__main__': 165 | main() 166 | -------------------------------------------------------------------------------- /QA_model/model_QuAC.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.optim as optim 4 | import torch.nn.functional as F 5 | import numpy as np 6 | import logging 7 | 8 | from torch.nn import Parameter 9 | from torch.autograd import Variable 10 | from .utils import AverageMeter 11 | from .detail_model import FlowQA 12 | 13 | logger = logging.getLogger(__name__) 14 | 15 | 16 | class QAModel(object): 17 | """ 18 | High level model that handles intializing the underlying network 19 | architecture, saving, updating examples, and predicting examples. 20 | """ 21 | 22 | def __init__(self, opt, embedding=None, state_dict=None): 23 | # Book-keeping. 24 | self.opt = opt 25 | self.updates = state_dict['updates'] if state_dict else 0 26 | self.eval_embed_transfer = True 27 | self.train_loss = AverageMeter() 28 | 29 | # Building network. 30 | self.network = FlowQA(opt, embedding) 31 | if state_dict: 32 | new_state = set(self.network.state_dict().keys()) 33 | for k in list(state_dict['network'].keys()): 34 | if k not in new_state: 35 | del state_dict['network'][k] 36 | self.network.load_state_dict(state_dict['network']) 37 | 38 | # Building optimizer. 
39 | parameters = [p for p in self.network.parameters() if p.requires_grad] 40 | if opt['optimizer'] == 'sgd': 41 | self.optimizer = optim.SGD(parameters, opt['learning_rate'], 42 | momentum=opt['momentum'], 43 | weight_decay=opt['weight_decay']) 44 | elif opt['optimizer'] == 'adamax': 45 | self.optimizer = optim.Adamax(parameters, 46 | weight_decay=opt['weight_decay']) 47 | elif opt['optimizer'] == 'adadelta': 48 | self.optimizer = optim.Adadelta(parameters, rho=0.95, weight_decay=opt['weight_decay']) 49 | else: 50 | raise RuntimeError('Unsupported optimizer: %s' % opt['optimizer']) 51 | if state_dict: 52 | self.optimizer.load_state_dict(state_dict['optimizer']) 53 | 54 | if opt['fix_embeddings']: 55 | wvec_size = 0 56 | else: 57 | wvec_size = (opt['vocab_size'] - opt['tune_partial']) * opt['embedding_dim'] 58 | self.total_param = sum([p.nelement() for p in parameters]) - wvec_size 59 | 60 | def update(self, batch): 61 | # Train mode 62 | self.network.train() 63 | torch.set_grad_enabled(True) 64 | 65 | # Transfer to GPU 66 | if self.opt['cuda']: 67 | inputs = [e.cuda(non_blocking=True) for e in batch[:9]] 68 | overall_mask = batch[9].cuda(non_blocking=True) 69 | 70 | answer_s = batch[10].cuda(non_blocking=True) 71 | answer_e = batch[11].cuda(non_blocking=True) 72 | answer_c = batch[12].cuda(non_blocking=True) 73 | else: 74 | inputs = [e for e in batch[:9]] 75 | overall_mask = batch[9] 76 | 77 | answer_s = batch[10] 78 | answer_e = batch[11] 79 | answer_c = batch[12] 80 | 81 | # Run forward 82 | # output: [batch_size, question_num, context_len], [batch_size, question_num] 83 | score_s, score_e, score_no_answ = self.network(*inputs) 84 | 85 | # Compute loss and accuracies 86 | loss = self.opt['elmo_lambda'] * (self.network.elmo.scalar_mix_0.scalar_parameters[0] ** 2 87 | + self.network.elmo.scalar_mix_0.scalar_parameters[1] ** 2 88 | + self.network.elmo.scalar_mix_0.scalar_parameters[2] ** 2) # ELMo L2 regularization 89 | all_no_answ = (answer_c == 0) 90 | answer_s.masked_fill_(all_no_answ, -100) # ignore_index is -100 in F.cross_entropy 91 | answer_e.masked_fill_(all_no_answ, -100) 92 | 93 | for i in range(overall_mask.size(0)): 94 | q_num = sum(overall_mask[i]) # the true question number for this sampled context 95 | 96 | target_s = answer_s[i, :q_num] # Size: q_num 97 | target_e = answer_e[i, :q_num] 98 | target_c = answer_c[i, :q_num] 99 | target_no_answ = all_no_answ[i, :q_num] 100 | 101 | # single_loss is averaged across q_num 102 | if self.opt['question_normalize']: 103 | single_loss = F.binary_cross_entropy_with_logits(score_no_answ[i, :q_num], target_no_answ.float()) * q_num.item() / 8.0 104 | single_loss = single_loss + F.cross_entropy(score_s[i, :q_num], target_s) * (q_num - sum(target_no_answ)).item() / 7.0 105 | single_loss = single_loss + F.cross_entropy(score_e[i, :q_num], target_e) * (q_num - sum(target_no_answ)).item() / 7.0 106 | else: 107 | single_loss = F.binary_cross_entropy_with_logits(score_no_answ[i, :q_num], target_no_answ.float()) \ 108 | + F.cross_entropy(score_s[i, :q_num], target_s) + F.cross_entropy(score_e[i, :q_num], target_e) 109 | 110 | loss = loss + (single_loss / overall_mask.size(0)) 111 | self.train_loss.update(loss.item(), overall_mask.size(0)) 112 | 113 | # Clear gradients and run backward 114 | self.optimizer.zero_grad() 115 | loss.backward() 116 | 117 | # Clip gradients 118 | torch.nn.utils.clip_grad_norm_(self.network.parameters(), 119 | self.opt['grad_clipping']) 120 | 121 | # Update parameters 122 | self.optimizer.step() 123 | self.updates += 1 
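# The eval_embed_transfer flag set below marks the evaluation embedding as stale;
# it is re-synced lazily via update_eval_embed() at the start of the next predict() call.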
124 | 125 | # Reset any partially fixed parameters (e.g. rare words) 126 | self.reset_embeddings() 127 | self.eval_embed_transfer = True 128 | 129 | def predict(self, batch, No_Ans_Threshold=None): 130 | # Eval mode 131 | self.network.eval() 132 | torch.set_grad_enabled(False) 133 | 134 | # Transfer trained embedding to evaluation embedding 135 | if self.eval_embed_transfer: 136 | self.update_eval_embed() 137 | self.eval_embed_transfer = False 138 | 139 | # Transfer to GPU 140 | if self.opt['cuda']: 141 | inputs = [e.cuda(non_blocking=True) for e in batch[:9]] 142 | else: 143 | inputs = [e for e in batch[:9]] 144 | 145 | # Run forward 146 | # output: [batch_size, question_num, context_len], [batch_size, question_num] 147 | score_s, score_e, score_no_answ = self.network(*inputs) 148 | score_s = F.softmax(score_s, dim=2) 149 | score_e = F.softmax(score_e, dim=2) 150 | 151 | # Transfer to CPU/normal tensors for numpy ops 152 | score_s = score_s.data.cpu() 153 | score_e = score_e.data.cpu() 154 | score_no_answ = score_no_answ.data.cpu() 155 | 156 | # Get argmax text spans 157 | text = batch[13] 158 | spans = batch[14] 159 | overall_mask = batch[9] 160 | 161 | predictions, no_ans_scores = [], [] 162 | max_len = self.opt['max_len'] or score_s.size(2) 163 | 164 | for i in range(overall_mask.size(0)): 165 | dialog_pred, dialog_noans = [], [] 166 | 167 | for j in range(overall_mask.size(1)): 168 | if overall_mask[i, j] == 0: # this dialog has ended 169 | break 170 | 171 | dialog_noans.append(score_no_answ[i, j].item()) 172 | if No_Ans_Threshold is not None and score_no_answ[i, j] > No_Ans_Threshold: 173 | dialog_pred.append("CANNOTANSWER") 174 | else: 175 | scores = torch.ger(score_s[i, j], score_e[i, j]) 176 | scores.triu_().tril_(max_len - 1) 177 | scores = scores.numpy() 178 | s_idx, e_idx = np.unravel_index(np.argmax(scores), scores.shape) 179 | 180 | s_offset, e_offset = spans[i][s_idx][0], spans[i][e_idx][1] 181 | dialog_pred.append(text[i][s_offset:e_offset]) 182 | 183 | predictions.append(dialog_pred) 184 | no_ans_scores.append(dialog_noans) 185 | 186 | return predictions, no_ans_scores # list of (list of strings), list of (list of floats) 187 | 188 | # allow the evaluation embedding be larger than training embedding 189 | # this is helpful if we have pretrained word embeddings 190 | def setup_eval_embed(self, eval_embed, padding_idx = 0): 191 | # eval_embed should be a supermatrix of training embedding 192 | self.network.eval_embed = nn.Embedding(eval_embed.size(0), 193 | eval_embed.size(1), 194 | padding_idx = padding_idx) 195 | self.network.eval_embed.weight.data = eval_embed 196 | for p in self.network.eval_embed.parameters(): 197 | p.requires_grad = False 198 | self.eval_embed_transfer = True 199 | 200 | if hasattr(self.network, 'CoVe'): 201 | self.network.CoVe.setup_eval_embed(eval_embed) 202 | 203 | def update_eval_embed(self): 204 | # update evaluation embedding to trained embedding 205 | if self.opt['tune_partial'] > 0: 206 | offset = self.opt['tune_partial'] 207 | self.network.eval_embed.weight.data[0:offset] \ 208 | = self.network.embedding.weight.data[0:offset] 209 | else: 210 | offset = 10 211 | self.network.eval_embed.weight.data[0:offset] \ 212 | = self.network.embedding.weight.data[0:offset] 213 | 214 | def reset_embeddings(self): 215 | # Reset fixed embeddings to original value 216 | if self.opt['tune_partial'] > 0: 217 | offset = self.opt['tune_partial'] 218 | if offset < self.network.embedding.weight.data.size(0): 219 | self.network.embedding.weight.data[offset:] \ 220 | 
= self.network.fixed_embedding 221 | 222 | def get_pretrain(self, state_dict): 223 | own_state = self.network.state_dict() 224 | for name, param in state_dict.items(): 225 | if name not in own_state: 226 | continue 227 | if isinstance(param, Parameter): 228 | param = param.data 229 | try: 230 | own_state[name].copy_(param) 231 | except: 232 | print("Skip", name) 233 | continue 234 | 235 | def save(self, filename, epoch): 236 | params = { 237 | 'state_dict': { 238 | 'network': self.network.state_dict(), 239 | 'optimizer': self.optimizer.state_dict(), 240 | 'updates': self.updates # how many updates 241 | }, 242 | 'config': self.opt, 243 | 'epoch': epoch 244 | } 245 | try: 246 | torch.save(params, filename) 247 | logger.info('model saved to {}'.format(filename)) 248 | except BaseException: 249 | logger.warn('[ WARN: Saving failed... continuing anyway. ]') 250 | 251 | def save_for_predict(self, filename, epoch): 252 | network_state = dict([(k, v) for k, v in self.network.state_dict().items() if k[0:4] != 'CoVe']) 253 | if 'eval_embed.weight' in network_state: 254 | del network_state['eval_embed.weight'] 255 | if 'fixed_embedding' in network_state: 256 | del network_state['fixed_embedding'] 257 | params = { 258 | 'state_dict': {'network': network_state}, 259 | 'config': self.opt, 260 | } 261 | try: 262 | torch.save(params, filename) 263 | logger.info('model saved to {}'.format(filename)) 264 | except BaseException: 265 | logger.warn('[ WARN: Saving failed... continuing anyway. ]') 266 | 267 | def cuda(self): 268 | self.network.cuda() 269 | -------------------------------------------------------------------------------- /CoQA_eval.py: -------------------------------------------------------------------------------- 1 | """Official evaluation script for CoQA. 2 | 3 | The code is based partially on SQuAD 2.0 evaluation script. 
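Typical standalone usage (the file paths here are illustrative):
    python CoQA_eval.py --data-file CoQA/dev.json --pred-file model_prediction.json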
4 | """ 5 | import argparse 6 | import json 7 | import re 8 | import string 9 | import sys 10 | 11 | from collections import Counter, OrderedDict 12 | 13 | OPTS = None 14 | 15 | out_domain = ["reddit", "science"] 16 | in_domain = ["mctest", "gutenberg", "race", "cnn", "wikipedia"] 17 | domain_mappings = {"mctest":"children_stories", "gutenberg":"literature", "race":"mid-high_school", "cnn":"news", "wikipedia":"wikipedia", "science":"science", "reddit":"reddit"} 18 | 19 | 20 | class CoQAEvaluator(): 21 | 22 | def __init__(self, gold_file): 23 | self.gold_data, self.gold_list, self.id_to_source = CoQAEvaluator.gold_answers_to_dict(gold_file) 24 | 25 | @staticmethod 26 | def gold_answers_to_dict(gold_file): 27 | dataset = json.load(open(gold_file)) 28 | gold_dict = {} 29 | gold_list = [] 30 | id_to_source = {} 31 | for story in dataset['data']: 32 | source = story['source'] 33 | story_id = story['id'] 34 | id_to_source[story_id] = source 35 | questions = story['questions'] 36 | multiple_answers = [story['answers']] 37 | multiple_answers += story['additional_answers'].values() 38 | for i, qa in enumerate(questions): 39 | qid = qa['turn_id'] 40 | if i + 1 != qid: 41 | sys.stderr.write("Turn id should match index {}: {}\n".format(i + 1, qa)) 42 | gold_answers = [] 43 | for answers in multiple_answers: 44 | answer = answers[i] 45 | if qid != answer['turn_id']: 46 | sys.stderr.write("Question turn id does match answer: {} {}\n".format(qa, answer)) 47 | gold_answers.append(answer['input_text']) 48 | key = (story_id, qid) 49 | if key in gold_dict: 50 | sys.stderr.write("Gold file has duplicate stories: {}".format(source)) 51 | gold_dict[key] = gold_answers 52 | gold_list.append(gold_answers) 53 | return gold_dict, gold_list, id_to_source 54 | 55 | @staticmethod 56 | def preds_to_dict(pred_file): 57 | preds = json.load(open(pred_file)) 58 | pred_dict = {} 59 | for pred in preds: 60 | pred_dict[(pred['id'], pred['turn_id'])] = pred['answer'] 61 | return pred_dict 62 | 63 | @staticmethod 64 | def normalize_answer(s): 65 | """Lower text and remove punctuation, storys and extra whitespace.""" 66 | 67 | def remove_articles(text): 68 | regex = re.compile(r'\b(a|an|the)\b', re.UNICODE) 69 | return re.sub(regex, ' ', text) 70 | 71 | def white_space_fix(text): 72 | return ' '.join(text.split()) 73 | 74 | def remove_punc(text): 75 | exclude = set(string.punctuation) 76 | return ''.join(ch for ch in text if ch not in exclude) 77 | 78 | def lower(text): 79 | return text.lower() 80 | 81 | return white_space_fix(remove_articles(remove_punc(lower(s)))) 82 | 83 | @staticmethod 84 | def get_tokens(s): 85 | if not s: return [] 86 | return CoQAEvaluator.normalize_answer(s).split() 87 | 88 | @staticmethod 89 | def compute_exact(a_gold, a_pred): 90 | return int(CoQAEvaluator.normalize_answer(a_gold) == CoQAEvaluator.normalize_answer(a_pred)) 91 | 92 | @staticmethod 93 | def compute_f1(a_gold, a_pred): 94 | gold_toks = CoQAEvaluator.get_tokens(a_gold) 95 | pred_toks = CoQAEvaluator.get_tokens(a_pred) 96 | common = Counter(gold_toks) & Counter(pred_toks) 97 | num_same = sum(common.values()) 98 | if len(gold_toks) == 0 or len(pred_toks) == 0: 99 | # If either is no-answer, then F1 is 1 if they agree, 0 otherwise 100 | return int(gold_toks == pred_toks) 101 | if num_same == 0: 102 | return 0 103 | precision = 1.0 * num_same / len(pred_toks) 104 | recall = 1.0 * num_same / len(gold_toks) 105 | f1 = (2 * precision * recall) / (precision + recall) 106 | return f1 107 | 108 | @staticmethod 109 | def 
_compute_turn_score(a_gold_list, a_pred): 110 | f1_sum = 0.0 111 | em_sum = 0.0 112 | if len(a_gold_list) > 1: 113 | for i in range(len(a_gold_list)): 114 | # exclude the current answer 115 | gold_answers = a_gold_list[0:i] + a_gold_list[i + 1:] 116 | em_sum += max(CoQAEvaluator.compute_exact(a, a_pred) for a in gold_answers) 117 | f1_sum += max(CoQAEvaluator.compute_f1(a, a_pred) for a in gold_answers) 118 | else: 119 | em_sum += max(CoQAEvaluator.compute_exact(a, a_pred) for a in a_gold_list) 120 | f1_sum += max(CoQAEvaluator.compute_f1(a, a_pred) for a in a_gold_list) 121 | 122 | return {'em': em_sum / max(1, len(a_gold_list)), 'f1': f1_sum / max(1, len(a_gold_list))} 123 | 124 | def compute_turn_score_seq(self, preds): 125 | ''' Added by Hsin-Yuan Huang for sequential evaluation. ''' 126 | assert(len(self.gold_list) == len(preds)) 127 | 128 | score = 0 129 | for i in range(len(preds)): 130 | score += CoQAEvaluator._compute_turn_score(self.gold_list[i], preds[i])['f1'] 131 | return score / len(preds) 132 | 133 | def compute_turn_score(self, story_id, turn_id, a_pred): 134 | ''' This is the function what you are probably looking for. a_pred is the answer string your model predicted. ''' 135 | key = (story_id, turn_id) 136 | a_gold_list = self.gold_data[key] 137 | return CoQAEvaluator._compute_turn_score(a_gold_list, a_pred) 138 | 139 | def get_raw_scores(self, pred_data): 140 | ''''Returns a dict with score with each turn prediction''' 141 | exact_scores = {} 142 | f1_scores = {} 143 | for story_id, turn_id in self.gold_data: 144 | key = (story_id, turn_id) 145 | if key not in pred_data: 146 | sys.stderr.write('Missing prediction for {} and turn_id: {}\n'.format(story_id, turn_id)) 147 | continue 148 | a_pred = pred_data[key] 149 | scores = self.compute_turn_score(story_id, turn_id, a_pred) 150 | # Take max over all gold answers 151 | exact_scores[key] = scores['em'] 152 | f1_scores[key] = scores['f1'] 153 | return exact_scores, f1_scores 154 | 155 | def get_raw_scores_human(self): 156 | ''''Returns a dict with score for each turn''' 157 | exact_scores = {} 158 | f1_scores = {} 159 | for story_id, turn_id in self.gold_data: 160 | key = (story_id, turn_id) 161 | f1_sum = 0.0 162 | em_sum = 0.0 163 | if len(self.gold_data[key]) > 1: 164 | for i in range(len(self.gold_data[key])): 165 | # exclude the current answer 166 | gold_answers = self.gold_data[key][0:i] + self.gold_data[key][i + 1:] 167 | em_sum += max(CoQAEvaluator.compute_exact(a, self.gold_data[key][i]) for a in gold_answers) 168 | f1_sum += max(CoQAEvaluator.compute_f1(a, self.gold_data[key][i]) for a in gold_answers) 169 | else: 170 | exit("Gold answers should be multiple: {}={}".format(key, self.gold_data[key])) 171 | exact_scores[key] = em_sum / len(self.gold_data[key]) 172 | f1_scores[key] = f1_sum / len(self.gold_data[key]) 173 | return exact_scores, f1_scores 174 | 175 | def human_performance(self): 176 | exact_scores, f1_scores = self.get_raw_scores_human() 177 | return self.get_domain_scores(exact_scores, f1_scores) 178 | 179 | def model_performance(self, pred_data): 180 | exact_scores, f1_scores = self.get_raw_scores(pred_data) 181 | return self.get_domain_scores(exact_scores, f1_scores) 182 | 183 | def get_domain_scores(self, exact_scores, f1_scores): 184 | sources = {} 185 | for source in in_domain + out_domain: 186 | sources[source] = Counter() 187 | 188 | for story_id, turn_id in self.gold_data: 189 | key = (story_id, turn_id) 190 | source = self.id_to_source[story_id] 191 | sources[source]['em_total'] += 
exact_scores.get(key, 0) 192 | sources[source]['f1_total'] += f1_scores.get(key, 0) 193 | sources[source]['turn_count'] += 1 194 | 195 | scores = OrderedDict() 196 | in_domain_em_total = 0.0 197 | in_domain_f1_total = 0.0 198 | in_domain_turn_count = 0 199 | 200 | out_domain_em_total = 0.0 201 | out_domain_f1_total = 0.0 202 | out_domain_turn_count = 0 203 | 204 | for source in in_domain + out_domain: 205 | domain = domain_mappings[source] 206 | scores[domain] = {} 207 | scores[domain]['em'] = round(sources[source]['em_total'] / max(1, sources[source]['turn_count']) * 100, 1) 208 | scores[domain]['f1'] = round(sources[source]['f1_total'] / max(1, sources[source]['turn_count']) * 100, 1) 209 | scores[domain]['turns'] = sources[source]['turn_count'] 210 | if source in in_domain: 211 | in_domain_em_total += sources[source]['em_total'] 212 | in_domain_f1_total += sources[source]['f1_total'] 213 | in_domain_turn_count += sources[source]['turn_count'] 214 | elif source in out_domain: 215 | out_domain_em_total += sources[source]['em_total'] 216 | out_domain_f1_total += sources[source]['f1_total'] 217 | out_domain_turn_count += sources[source]['turn_count'] 218 | 219 | scores["in_domain"] = {'em': round(in_domain_em_total / max(1, in_domain_turn_count) * 100, 1), 220 | 'f1': round(in_domain_f1_total / max(1, in_domain_turn_count) * 100, 1), 221 | 'turns': in_domain_turn_count} 222 | scores["out_domain"] = {'em': round(out_domain_em_total / max(1, out_domain_turn_count) * 100, 1), 223 | 'f1': round(out_domain_f1_total / max(1, out_domain_turn_count) * 100, 1), 224 | 'turns': out_domain_turn_count} 225 | 226 | em_total = in_domain_em_total + out_domain_em_total 227 | f1_total = in_domain_f1_total + out_domain_f1_total 228 | turn_count = in_domain_turn_count + out_domain_turn_count 229 | scores["overall"] = {'em': round(em_total / max(1, turn_count) * 100, 1), 230 | 'f1': round(f1_total / max(1, turn_count) * 100, 1), 231 | 'turns': turn_count} 232 | 233 | return scores 234 | 235 | def parse_args(): 236 | parser = argparse.ArgumentParser('Official evaluation script for CoQA.') 237 | parser.add_argument('--data-file', dest="data_file", help='Input data JSON file.') 238 | parser.add_argument('--pred-file', dest="pred_file", help='Model predictions.') 239 | parser.add_argument('--out-file', '-o', metavar='eval.json', 240 | help='Write accuracy metrics to file (default is stdout).') 241 | parser.add_argument('--verbose', '-v', action='store_true') 242 | parser.add_argument('--human', dest="human", action='store_true') 243 | if len(sys.argv) == 1: 244 | parser.print_help() 245 | sys.exit(1) 246 | return parser.parse_args() 247 | 248 | def main(): 249 | evaluator = CoQAEvaluator(OPTS.data_file) 250 | 251 | if OPTS.human: 252 | print(json.dumps(evaluator.human_performance(), indent=2)) 253 | 254 | if OPTS.pred_file: 255 | with open(OPTS.pred_file) as f: 256 | pred_data = CoQAEvaluator.preds_to_dict(OPTS.pred_file) 257 | print(json.dumps(evaluator.model_performance(pred_data), indent=2)) 258 | 259 | if __name__ == '__main__': 260 | OPTS = parse_args() 261 | main() 262 | -------------------------------------------------------------------------------- /QA_model/model_CoQA.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.optim as optim 4 | import torch.nn.functional as F 5 | import numpy as np 6 | import logging 7 | 8 | from torch.nn import Parameter 9 | from torch.autograd import Variable 10 | from .utils import 
AverageMeter 11 | from .detail_model import FlowQA 12 | 13 | logger = logging.getLogger(__name__) 14 | 15 | 16 | class QAModel(object): 17 | """ 18 | High level model that handles intializing the underlying network 19 | architecture, saving, updating examples, and predicting examples. 20 | """ 21 | 22 | def __init__(self, opt, embedding=None, state_dict=None): 23 | # Book-keeping. 24 | self.opt = opt 25 | self.updates = state_dict['updates'] if state_dict else 0 26 | self.eval_embed_transfer = True 27 | self.train_loss = AverageMeter() 28 | 29 | # Building network. 30 | self.network = FlowQA(opt, embedding) 31 | if state_dict: 32 | new_state = set(self.network.state_dict().keys()) 33 | for k in list(state_dict['network'].keys()): 34 | if k not in new_state: 35 | del state_dict['network'][k] 36 | self.network.load_state_dict(state_dict['network']) 37 | 38 | # Building optimizer. 39 | parameters = [p for p in self.network.parameters() if p.requires_grad] 40 | if opt['optimizer'] == 'sgd': 41 | self.optimizer = optim.SGD(parameters, opt['learning_rate'], 42 | momentum=opt['momentum'], 43 | weight_decay=opt['weight_decay']) 44 | elif opt['optimizer'] == 'adamax': 45 | self.optimizer = optim.Adamax(parameters, 46 | weight_decay=opt['weight_decay']) 47 | elif opt['optimizer'] == 'adadelta': 48 | self.optimizer = optim.Adadelta(parameters, rho=0.95, weight_decay=opt['weight_decay']) 49 | else: 50 | raise RuntimeError('Unsupported optimizer: %s' % opt['optimizer']) 51 | if state_dict: 52 | self.optimizer.load_state_dict(state_dict['optimizer']) 53 | 54 | if opt['fix_embeddings']: 55 | wvec_size = 0 56 | else: 57 | wvec_size = (opt['vocab_size'] - opt['tune_partial']) * opt['embedding_dim'] 58 | self.total_param = sum([p.nelement() for p in parameters]) - wvec_size 59 | 60 | def update(self, batch): 61 | # Train mode 62 | self.network.train() 63 | torch.set_grad_enabled(True) 64 | 65 | # Transfer to GPU 66 | if self.opt['cuda']: 67 | inputs = [e.cuda(non_blocking=True) for e in batch[:9]] 68 | overall_mask = batch[9].cuda(non_blocking=True) 69 | 70 | answer_s = batch[10].cuda(non_blocking=True) 71 | answer_e = batch[11].cuda(non_blocking=True) 72 | answer_c = batch[12].cuda(non_blocking=True) 73 | rationale_s = batch[13].cuda(non_blocking=True) 74 | rationale_e = batch[14].cuda(non_blocking=True) 75 | else: 76 | inputs = [e for e in batch[:9]] 77 | overall_mask = batch[9] 78 | 79 | answer_s = batch[10] 80 | answer_e = batch[11] 81 | answer_c = batch[12] 82 | rationale_s = batch[13] 83 | rationale_e = batch[14] 84 | 85 | # Run forward 86 | # output: [batch_size, question_num, context_len], [batch_size, question_num] 87 | score_s, score_e, score_c = self.network(*inputs) 88 | 89 | # Compute loss and accuracies 90 | loss = self.opt['elmo_lambda'] * (self.network.elmo.scalar_mix_0.scalar_parameters[0] ** 2 91 | + self.network.elmo.scalar_mix_0.scalar_parameters[1] ** 2 92 | + self.network.elmo.scalar_mix_0.scalar_parameters[2] ** 2) # ELMo L2 regularization 93 | all_no_span = (answer_c != 3) 94 | answer_s.masked_fill_(all_no_span, -100) # ignore_index is -100 in F.cross_entropy 95 | answer_e.masked_fill_(all_no_span, -100) 96 | rationale_s.masked_fill_(all_no_span, -100) # ignore_index is -100 in F.cross_entropy 97 | rationale_e.masked_fill_(all_no_span, -100) 98 | 99 | for i in range(overall_mask.size(0)): 100 | q_num = sum(overall_mask[i]) # the true question number for this sampled context 101 | 102 | target_s = answer_s[i, :q_num] # Size: q_num 103 | target_e = answer_e[i, :q_num] 104 | target_c = 
answer_c[i, :q_num] 105 | target_s_r = rationale_s[i, :q_num] 106 | target_e_r = rationale_e[i, :q_num] 107 | target_no_span = all_no_span[i, :q_num] 108 | 109 | # single_loss is averaged across q_num 110 | single_loss = (F.cross_entropy(score_c[i, :q_num], target_c) * q_num.item() / 15.0 111 | + F.cross_entropy(score_s[i, :q_num], target_s) * (q_num - sum(target_no_span)).item() / 12.0 112 | + F.cross_entropy(score_e[i, :q_num], target_e) * (q_num - sum(target_no_span)).item() / 12.0) 113 | #+ self.opt['rationale_lambda'] * F.cross_entropy(score_s_r[i, :q_num], target_s_r) * (q_num - sum(target_no_span)).item() / 12.0 114 | #+ self.opt['rationale_lambda'] * F.cross_entropy(score_e_r[i, :q_num], target_e_r) * (q_num - sum(target_no_span)).item() / 12.0) 115 | 116 | loss = loss + (single_loss / overall_mask.size(0)) 117 | self.train_loss.update(loss.item(), overall_mask.size(0)) 118 | 119 | # Clear gradients and run backward 120 | self.optimizer.zero_grad() 121 | loss.backward() 122 | 123 | # Clip gradients 124 | torch.nn.utils.clip_grad_norm_(self.network.parameters(), 125 | self.opt['grad_clipping']) 126 | 127 | # Update parameters 128 | self.optimizer.step() 129 | self.updates += 1 130 | 131 | # Reset any partially fixed parameters (e.g. rare words) 132 | self.reset_embeddings() 133 | self.eval_embed_transfer = True 134 | 135 | def predict(self, batch): 136 | # Eval mode 137 | self.network.eval() 138 | torch.set_grad_enabled(False) 139 | 140 | # Transfer trained embedding to evaluation embedding 141 | if self.eval_embed_transfer: 142 | self.update_eval_embed() 143 | self.eval_embed_transfer = False 144 | 145 | # Transfer to GPU 146 | if self.opt['cuda']: 147 | inputs = [e.cuda(non_blocking=True) for e in batch[:9]] 148 | else: 149 | inputs = [e for e in batch[:9]] 150 | 151 | # Run forward 152 | # output: [batch_size, question_num, context_len], [batch_size, question_num] 153 | score_s, score_e, score_c = self.network(*inputs) 154 | score_s = F.softmax(score_s, dim=2) 155 | score_e = F.softmax(score_e, dim=2) 156 | 157 | # Transfer to CPU/normal tensors for numpy ops 158 | score_s = score_s.data.cpu() 159 | score_e = score_e.data.cpu() 160 | score_c = score_c.data.cpu() 161 | 162 | # Get argmax text spans 163 | text = batch[-4] 164 | spans = batch[-3] 165 | overall_mask = batch[9] 166 | 167 | predictions = [] 168 | max_len = self.opt['max_len'] or score_s.size(2) 169 | 170 | for i in range(overall_mask.size(0)): 171 | for j in range(overall_mask.size(1)): 172 | if overall_mask[i, j] == 0: # this dialog has ended 173 | break 174 | 175 | ans_type = np.argmax(score_c[i, j]) 176 | 177 | if ans_type == 0: 178 | predictions.append("unknown") 179 | elif ans_type == 1: 180 | predictions.append("Yes") 181 | elif ans_type == 2: 182 | predictions.append("No") 183 | else: 184 | scores = torch.ger(score_s[i, j], score_e[i, j]) 185 | scores.triu_().tril_(max_len - 1) 186 | scores = scores.numpy() 187 | s_idx, e_idx = np.unravel_index(np.argmax(scores), scores.shape) 188 | 189 | s_offset, e_offset = spans[i][s_idx][0], spans[i][e_idx][1] 190 | predictions.append(text[i][s_offset:e_offset]) 191 | 192 | return predictions # list of (list of strings) 193 | 194 | # allow the evaluation embedding be larger than training embedding 195 | # this is helpful if we have pretrained word embeddings 196 | def setup_eval_embed(self, eval_embed, padding_idx = 0): 197 | # eval_embed should be a supermatrix of training embedding 198 | self.network.eval_embed = nn.Embedding(eval_embed.size(0), 199 | eval_embed.size(1), 
200 | padding_idx = padding_idx) 201 | self.network.eval_embed.weight.data = eval_embed 202 | for p in self.network.eval_embed.parameters(): 203 | p.requires_grad = False 204 | self.eval_embed_transfer = True 205 | 206 | if hasattr(self.network, 'CoVe'): 207 | self.network.CoVe.setup_eval_embed(eval_embed) 208 | 209 | def update_eval_embed(self): 210 | # update evaluation embedding to trained embedding 211 | if self.opt['tune_partial'] > 0: 212 | offset = self.opt['tune_partial'] 213 | self.network.eval_embed.weight.data[0:offset] \ 214 | = self.network.embedding.weight.data[0:offset] 215 | else: 216 | offset = 10 217 | self.network.eval_embed.weight.data[0:offset] \ 218 | = self.network.embedding.weight.data[0:offset] 219 | 220 | def reset_embeddings(self): 221 | # Reset fixed embeddings to original value 222 | if self.opt['tune_partial'] > 0: 223 | offset = self.opt['tune_partial'] 224 | if offset < self.network.embedding.weight.data.size(0): 225 | self.network.embedding.weight.data[offset:] \ 226 | = self.network.fixed_embedding 227 | 228 | def get_pretrain(self, state_dict): 229 | own_state = self.network.state_dict() 230 | for name, param in state_dict.items(): 231 | if name not in own_state: 232 | continue 233 | if isinstance(param, Parameter): 234 | param = param.data 235 | try: 236 | own_state[name].copy_(param) 237 | except: 238 | print("Skip", name) 239 | continue 240 | 241 | def save(self, filename, epoch): 242 | params = { 243 | 'state_dict': { 244 | 'network': self.network.state_dict(), 245 | 'optimizer': self.optimizer.state_dict(), 246 | 'updates': self.updates # how many updates 247 | }, 248 | 'config': self.opt, 249 | 'epoch': epoch 250 | } 251 | try: 252 | torch.save(params, filename) 253 | logger.info('model saved to {}'.format(filename)) 254 | except BaseException: 255 | logger.warn('[ WARN: Saving failed... continuing anyway. ]') 256 | 257 | def save_for_predict(self, filename, epoch): 258 | network_state = dict([(k, v) for k, v in self.network.state_dict().items() if k[0:4] != 'CoVe']) 259 | if 'eval_embed.weight' in network_state: 260 | del network_state['eval_embed.weight'] 261 | if 'fixed_embedding' in network_state: 262 | del network_state['fixed_embedding'] 263 | params = { 264 | 'state_dict': {'network': network_state}, 265 | 'config': self.opt, 266 | } 267 | try: 268 | torch.save(params, filename) 269 | logger.info('model saved to {}'.format(filename)) 270 | except BaseException: 271 | logger.warn('[ WARN: Saving failed... continuing anyway. ]') 272 | 273 | def cuda(self): 274 | self.network.cuda() 275 | -------------------------------------------------------------------------------- /preprocess_QuAC.py: -------------------------------------------------------------------------------- 1 | import re 2 | import json 3 | import spacy 4 | import msgpack 5 | import unicodedata 6 | import numpy as np 7 | import pandas as pd 8 | import argparse 9 | import collections 10 | import multiprocessing 11 | import logging 12 | import random 13 | from allennlp.modules.elmo import batch_to_ids 14 | from general_utils import flatten_json, normalize_text, build_embedding, load_glove_vocab, pre_proc, get_context_span, find_answer_span, feature_gen, token2id 15 | 16 | parser = argparse.ArgumentParser( 17 | description='Preprocessing train + dev files, about 20 minutes to run on Servers.' 
18 | ) 19 | parser.add_argument('--wv_file', default='glove/glove.840B.300d.txt', 20 | help='path to word vector file.') 21 | parser.add_argument('--wv_dim', type=int, default=300, 22 | help='word vector dimension.') 23 | parser.add_argument('--sort_all', action='store_true', 24 | help='sort the vocabulary by frequencies of all words.' 25 | 'Otherwise consider question words first.') 26 | parser.add_argument('--threads', type=int, default=multiprocessing.cpu_count(), 27 | help='number of threads for preprocessing.') 28 | parser.add_argument('--no_match', action='store_true', 29 | help='do not extract the three exact matching features.') 30 | parser.add_argument('--seed', type=int, default=1023, 31 | help='random seed for data shuffling, embedding init, etc.') 32 | 33 | 34 | args = parser.parse_args() 35 | trn_file = 'QuAC_data/train.json' 36 | dev_file = 'QuAC_data/dev.json' 37 | wv_file = args.wv_file 38 | wv_dim = args.wv_dim 39 | nlp = spacy.load('en', disable=['parser']) 40 | 41 | random.seed(args.seed) 42 | np.random.seed(args.seed) 43 | 44 | logging.basicConfig(format='%(asctime)s %(message)s', level=logging.DEBUG, 45 | datefmt='%m/%d/%Y %I:%M:%S') 46 | log = logging.getLogger(__name__) 47 | 48 | log.info('start data preparing... (using {} threads)'.format(args.threads)) 49 | 50 | glove_vocab = load_glove_vocab(wv_file, wv_dim) # return a "set" of vocabulary 51 | log.info('glove loaded.') 52 | 53 | #=============================================================== 54 | #=================== Work on training data ===================== 55 | #=============================================================== 56 | 57 | def proc_train(ith, article): 58 | rows = [] 59 | 60 | for paragraph in article['paragraphs']: 61 | context = paragraph['context'] 62 | for qa in paragraph['qas']: 63 | question = qa['question'] 64 | answers = qa['orig_answer'] 65 | 66 | answer = answers['text'] 67 | answer_start = answers['answer_start'] 68 | answer_end = answers['answer_start'] + len(answers['text']) 69 | answer_choice = 0 if answer == 'CANNOTANSWER' else\ 70 | 1 if qa['yesno'] == 'y' else\ 71 | 2 if qa['yesno'] == 'n' else\ 72 | 3 # Not a yes/no question 73 | if answer_choice != 0: 74 | """ 75 | 0: Do not ask a follow up question! 76 | 1: Definitely ask a follow up question! 77 | 2: Not too important, but you can ask a follow up. 
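(The combined label computed below is answer_choice = yesno_code + 10 * followup_code;
 e.g. a 'yes' answer with followup 'y' is stored as 1 + 10 * 1 = 11.)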
78 | """ 79 | answer_choice += 10 * (0 if qa['followup'] == "n" else\ 80 | 1 if qa['followup'] == "y" else\ 81 | 2) 82 | else: 83 | answer_start, answer_end = -1, -1 84 | rows.append((ith, question, answer, answer_start, answer_end, answer_choice)) 85 | return rows, context 86 | 87 | train, train_context = flatten_json(trn_file, proc_train) 88 | train = pd.DataFrame(train, columns=['context_idx', 'question', 'answer', 89 | 'answer_start', 'answer_end', 'answer_choice']) 90 | log.info('train json data flattened.') 91 | 92 | print(train) 93 | 94 | trC_iter = (pre_proc(c) for c in train_context) 95 | trQ_iter = (pre_proc(q) for q in train.question) 96 | trC_docs = [doc for doc in nlp.pipe(trC_iter, batch_size=64, n_threads=args.threads)] 97 | trQ_docs = [doc for doc in nlp.pipe(trQ_iter, batch_size=64, n_threads=args.threads)] 98 | 99 | # tokens 100 | trC_tokens = [[normalize_text(w.text) for w in doc] for doc in trC_docs] 101 | trQ_tokens = [[normalize_text(w.text) for w in doc] for doc in trQ_docs] 102 | trC_unnorm_tokens = [[w.text for w in doc] for doc in trC_docs] 103 | log.info('All tokens for training are obtained.') 104 | 105 | train_context_span = [get_context_span(a, b) for a, b in zip(train_context, trC_unnorm_tokens)] 106 | 107 | ans_st_token_ls, ans_end_token_ls = [], [] 108 | for ans_st, ans_end, idx in zip(train.answer_start, train.answer_end, train.context_idx): 109 | ans_st_token, ans_end_token = find_answer_span(train_context_span[idx], ans_st, ans_end) 110 | ans_st_token_ls.append(ans_st_token) 111 | ans_end_token_ls.append(ans_end_token) 112 | 113 | train['answer_start_token'], train['answer_end_token'] = ans_st_token_ls, ans_end_token_ls 114 | initial_len = len(train) 115 | train.dropna(inplace=True) # modify self DataFrame 116 | log.info('drop {0}/{1} inconsistent samples.'.format(initial_len - len(train), initial_len)) 117 | log.info('answer span for training is generated.') 118 | 119 | # features 120 | trC_tags, trC_ents, trC_features = feature_gen(trC_docs, train.context_idx, trQ_docs, args.no_match) 121 | log.info('features for training is generated: {}, {}, {}'.format(len(trC_tags), len(trC_ents), len(trC_features))) 122 | 123 | def build_train_vocab(questions, contexts): # vocabulary will also be sorted accordingly 124 | if args.sort_all: 125 | counter = collections.Counter(w for doc in questions + contexts for w in doc) 126 | vocab = sorted([t for t in counter if t in glove_vocab], key=counter.get, reverse=True) 127 | else: 128 | counter_c = collections.Counter(w for doc in contexts for w in doc) 129 | counter_q = collections.Counter(w for doc in questions for w in doc) 130 | counter = counter_c + counter_q 131 | vocab = sorted([t for t in counter_q if t in glove_vocab], key=counter_q.get, reverse=True) 132 | vocab += sorted([t for t in counter_c.keys() - counter_q.keys() if t in glove_vocab], 133 | key=counter.get, reverse=True) 134 | total = sum(counter.values()) 135 | matched = sum(counter[t] for t in vocab) 136 | log.info('vocab {1}/{0} OOV {2}/{3} ({4:.4f}%)'.format( 137 | len(counter), len(vocab), (total - matched), total, (total - matched) / total * 100)) 138 | vocab.insert(0, "") 139 | vocab.insert(1, "") 140 | vocab.insert(2, "") 141 | vocab.insert(3, "") 142 | return vocab 143 | 144 | # vocab 145 | tr_vocab = build_train_vocab(trQ_tokens, trC_tokens) 146 | trC_ids = token2id(trC_tokens, tr_vocab, unk_id=1) 147 | trQ_ids = token2id(trQ_tokens, tr_vocab, unk_id=1) 148 | trQ_tokens = [[""] + doc + [""] for doc in trQ_tokens] 149 | trQ_ids = [[2] + qsent + 
[3] for qsent in trQ_ids] 150 | print(trQ_ids[:10]) 151 | # tags 152 | vocab_tag = [''] + list(nlp.tagger.labels) 153 | trC_tag_ids = token2id(trC_tags, vocab_tag) 154 | # entities 155 | vocab_ent = list(set([ent for sent in trC_ents for ent in sent])) 156 | trC_ent_ids = token2id(trC_ents, vocab_ent, unk_id=0) 157 | 158 | log.info('Found {} POS tags.'.format(len(vocab_tag))) 159 | log.info('Found {} entity tags: {}'.format(len(vocab_ent), vocab_ent)) 160 | log.info('vocabulary for training is built.') 161 | 162 | tr_embedding = build_embedding(wv_file, tr_vocab, wv_dim) 163 | log.info('got embedding matrix for training.') 164 | 165 | # don't store row name in csv 166 | #train.to_csv('QuAC_data/train.csv', index=False, encoding='utf8') 167 | 168 | meta = { 169 | 'vocab': tr_vocab, 170 | 'embedding': tr_embedding.tolist() 171 | } 172 | with open('QuAC_data/train_meta.msgpack', 'wb') as f: 173 | msgpack.dump(meta, f) 174 | 175 | prev_CID, first_question = -1, [] 176 | for i, CID in enumerate(train.context_idx): 177 | if not (CID == prev_CID): 178 | first_question.append(i) 179 | prev_CID = CID 180 | 181 | result = { 182 | 'question_ids': trQ_ids, 183 | 'context_ids': trC_ids, 184 | 'context_features': trC_features, # exact match, tf 185 | 'context_tags': trC_tag_ids, # POS tagging 186 | 'context_ents': trC_ent_ids, # Entity recognition 187 | 'context': train_context, 188 | 'context_span': train_context_span, 189 | '1st_question': first_question, 190 | 'question_CID': train.context_idx.tolist(), 191 | 'question': train.question.tolist(), 192 | 'answer': train.answer.tolist(), 193 | 'answer_start': train.answer_start_token.tolist(), 194 | 'answer_end': train.answer_end_token.tolist(), 195 | 'answer_choice': train.answer_choice.tolist(), 196 | 'context_tokenized': trC_tokens, 197 | 'question_tokenized': trQ_tokens 198 | } 199 | with open('QuAC_data/train_data.msgpack', 'wb') as f: 200 | msgpack.dump(result, f) 201 | 202 | log.info('saved training to disk.') 203 | 204 | #========================================================== 205 | #=================== Work on dev data ===================== 206 | #========================================================== 207 | 208 | def proc_dev(ith, article): 209 | rows = [] 210 | 211 | for paragraph in article['paragraphs']: 212 | context = paragraph['context'] 213 | for qa in paragraph['qas']: 214 | question = qa['question'] 215 | answers = qa['orig_answer'] 216 | 217 | answer = answers['text'] 218 | answer_start = answers['answer_start'] 219 | answer_end = answers['answer_start'] + len(answers['text']) 220 | answer_choice = 0 if answer == 'CANNOTANSWER' else\ 221 | 1 if qa['yesno'] == 'y' else\ 222 | 2 if qa['yesno'] == 'n' else\ 223 | 3 # Not a yes/no question 224 | if answer_choice != 0: 225 | """ 226 | 0: Do not ask a follow up question! 227 | 1: Definitely ask a follow up question! 228 | 2: Not too important, but you can ask a follow up. 
229 | """ 230 | answer_choice += 10 * (0 if qa['followup'] == "n" else\ 231 | 1 if qa['followup'] == "y" else\ 232 | 2) 233 | else: 234 | answer_start, answer_end = -1, -1 235 | 236 | ans_ls = [] 237 | for ans in qa['answers']: 238 | ans_ls.append(ans['text']) 239 | 240 | rows.append((ith, question, answer, answer_start, answer_end, answer_choice, ans_ls)) 241 | return rows, context 242 | 243 | dev, dev_context = flatten_json(dev_file, proc_dev) 244 | dev = pd.DataFrame(dev, columns=['context_idx', 'question', 'answer', 245 | 'answer_start', 'answer_end', 'answer_choice', 'all_answer']) 246 | log.info('dev json data flattened.') 247 | 248 | print(dev) 249 | 250 | devC_iter = (pre_proc(c) for c in dev_context) 251 | devQ_iter = (pre_proc(q) for q in dev.question) 252 | devC_docs = [doc for doc in nlp.pipe( 253 | devC_iter, batch_size=64, n_threads=args.threads)] 254 | devQ_docs = [doc for doc in nlp.pipe( 255 | devQ_iter, batch_size=64, n_threads=args.threads)] 256 | 257 | # tokens 258 | devC_tokens = [[normalize_text(w.text) for w in doc] for doc in devC_docs] 259 | devQ_tokens = [[normalize_text(w.text) for w in doc] for doc in devQ_docs] 260 | devC_unnorm_tokens = [[w.text for w in doc] for doc in devC_docs] 261 | log.info('All tokens for dev are obtained.') 262 | 263 | dev_context_span = [get_context_span(a, b) for a, b in zip(dev_context, devC_unnorm_tokens)] 264 | log.info('context span for dev is generated.') 265 | 266 | ans_st_token_ls, ans_end_token_ls = [], [] 267 | for ans_st, ans_end, idx in zip(dev.answer_start, dev.answer_end, dev.context_idx): 268 | ans_st_token, ans_end_token = find_answer_span(dev_context_span[idx], ans_st, ans_end) 269 | ans_st_token_ls.append(ans_st_token) 270 | ans_end_token_ls.append(ans_end_token) 271 | 272 | dev['answer_start_token'], dev['answer_end_token'] = ans_st_token_ls, ans_end_token_ls 273 | initial_len = len(dev) 274 | dev.dropna(inplace=True) # modify self DataFrame 275 | log.info('drop {0}/{1} inconsistent samples.'.format(initial_len - len(dev), initial_len)) 276 | log.info('answer span for dev is generated.') 277 | 278 | # features 279 | devC_tags, devC_ents, devC_features = feature_gen(devC_docs, dev.context_idx, devQ_docs, args.no_match) 280 | log.info('features for dev is generated: {}, {}, {}'.format(len(devC_tags), len(devC_ents), len(devC_features))) 281 | 282 | def build_dev_vocab(questions, contexts): # most vocabulary comes from tr_vocab 283 | existing_vocab = set(tr_vocab) 284 | new_vocab = list(set([w for doc in questions + contexts for w in doc if w not in existing_vocab and w in glove_vocab])) 285 | vocab = tr_vocab + new_vocab 286 | log.info('train vocab {0}, total vocab {1}'.format(len(tr_vocab), len(vocab))) 287 | return vocab 288 | 289 | # vocab 290 | dev_vocab = build_dev_vocab(devQ_tokens, devC_tokens) # tr_vocab is a subset of dev_vocab 291 | devC_ids = token2id(devC_tokens, dev_vocab, unk_id=1) 292 | devQ_ids = token2id(devQ_tokens, dev_vocab, unk_id=1) 293 | devQ_tokens = [[""] + doc + [""] for doc in devQ_tokens] 294 | devQ_ids = [[2] + qsent + [3] for qsent in devQ_ids] 295 | print(devQ_ids[:10]) 296 | # tags 297 | devC_tag_ids = token2id(devC_tags, vocab_tag) # vocab_tag same as training 298 | # entities 299 | devC_ent_ids = token2id(devC_ents, vocab_ent, unk_id=0) # vocab_ent same as training 300 | log.info('vocabulary for dev is built.') 301 | 302 | dev_embedding = build_embedding(wv_file, dev_vocab, wv_dim) 303 | # tr_embedding is a submatrix of dev_embedding 304 | log.info('got embedding matrix for dev.') 305 
| 306 | # don't store row name in csv 307 | #dev.to_csv('QuAC_data/dev.csv', index=False, encoding='utf8') 308 | 309 | meta = { 310 | 'vocab': dev_vocab, 311 | 'embedding': dev_embedding.tolist() 312 | } 313 | with open('QuAC_data/dev_meta.msgpack', 'wb') as f: 314 | msgpack.dump(meta, f) 315 | 316 | prev_CID, first_question = -1, [] 317 | for i, CID in enumerate(dev.context_idx): 318 | if not (CID == prev_CID): 319 | first_question.append(i) 320 | prev_CID = CID 321 | 322 | result = { 323 | 'question_ids': devQ_ids, 324 | 'context_ids': devC_ids, 325 | 'context_features': devC_features, # exact match, tf 326 | 'context_tags': devC_tag_ids, # POS tagging 327 | 'context_ents': devC_ent_ids, # Entity recognition 328 | 'context': dev_context, 329 | 'context_span': dev_context_span, 330 | '1st_question': first_question, 331 | 'question_CID': dev.context_idx.tolist(), 332 | 'question': dev.question.tolist(), 333 | 'answer': dev.answer.tolist(), 334 | 'answer_start': dev.answer_start_token.tolist(), 335 | 'answer_end': dev.answer_end_token.tolist(), 336 | 'answer_choice': dev.answer_choice.tolist(), 337 | 'all_answer': dev.all_answer.tolist(), 338 | 'context_tokenized': devC_tokens, 339 | 'question_tokenized': devQ_tokens 340 | } 341 | with open('QuAC_data/dev_data.msgpack', 'wb') as f: 342 | msgpack.dump(result, f) 343 | 344 | log.info('saved dev to disk.') 345 | -------------------------------------------------------------------------------- /train_CoQA.py: -------------------------------------------------------------------------------- 1 | import os 2 | import re 3 | import sys 4 | import random 5 | import string 6 | import logging 7 | import argparse 8 | from shutil import copyfile 9 | from datetime import datetime 10 | from collections import Counter 11 | import torch 12 | import msgpack 13 | import pandas as pd 14 | import numpy as np 15 | from QA_model.model_CoQA import QAModel 16 | from CoQA_eval import CoQAEvaluator 17 | from general_utils import find_best_score_and_thresh, BatchGen_CoQA 18 | 19 | parser = argparse.ArgumentParser( 20 | description='Train a Dialog QA model.' 
21 | ) 22 | 23 | # system 24 | parser.add_argument('--task_name', default='CoQA') 25 | parser.add_argument('--name', default='', help='additional name of the current run') 26 | parser.add_argument('--log_file', default='output.log', 27 | help='path for log file.') 28 | parser.add_argument('--log_per_updates', type=int, default=20, 29 | help='log model loss per x updates (mini-batches).') 30 | 31 | parser.add_argument('--train_dir', default='CoQA/') 32 | parser.add_argument('--dev_dir', default='CoQA/') 33 | parser.add_argument('--answer_type_num', type=int, default=4) 34 | 35 | parser.add_argument('--model_dir', default='models', 36 | help='path to store saved models.') 37 | parser.add_argument('--eval_per_epoch', type=int, default=1, 38 | help='perform evaluation per x epoches.') 39 | parser.add_argument('--MTLSTM_path', default='glove/MT-LSTM.pth') 40 | parser.add_argument('--save_all', dest='save_best_only', action='store_false', help='save all models.') 41 | parser.add_argument('--do_not_save', action='store_true', help='don\'t save any model') 42 | parser.add_argument('--save_for_predict', action='store_true') 43 | parser.add_argument('--seed', type=int, default=1023, 44 | help='random seed for data shuffling, dropout, etc.') 45 | parser.add_argument('--cuda', type=bool, default=torch.cuda.is_available(), 46 | help='whether to use GPU acceleration.') 47 | # training 48 | parser.add_argument('-e', '--epoches', type=int, default=30) 49 | parser.add_argument('-bs', '--batch_size', type=int, default=1) 50 | parser.add_argument('-ebs', '--elmo_batch_size', type=int, default=12) 51 | parser.add_argument('-rs', '--resume', default='', 52 | help='previous model pathname. ' 53 | 'e.g. "models/checkpoint_epoch_11.pt"') 54 | parser.add_argument('-ro', '--resume_options', action='store_true', 55 | help='use previous model options, ignore the cli and defaults.') 56 | parser.add_argument('-rlr', '--reduce_lr', type=float, default=0., 57 | help='reduce initial (resumed) learning rate by this factor.') 58 | parser.add_argument('-op', '--optimizer', default='adamax', 59 | help='supported optimizer: adamax, sgd, adadelta, adam') 60 | parser.add_argument('-gc', '--grad_clipping', type=float, default=10) 61 | parser.add_argument('-wd', '--weight_decay', type=float, default=0) 62 | parser.add_argument('-lr', '--learning_rate', type=float, default=0.1, 63 | help='only applied to SGD.') 64 | parser.add_argument('-mm', '--momentum', type=float, default=0, 65 | help='only applied to SGD.') 66 | parser.add_argument('-tp', '--tune_partial', type=int, default=1000, 67 | help='finetune top-x embeddings (including , ).') 68 | parser.add_argument('--fix_embeddings', action='store_true', 69 | help='if true, `tune_partial` will be ignored.') 70 | parser.add_argument('--elmo_lambda', type=float, default=0.0) 71 | parser.add_argument('--rationale_lambda', type=float, default=0.0) 72 | parser.add_argument('--no_question_normalize', dest='question_normalize', action='store_false') # when set, do dialog normalize 73 | parser.add_argument('--pretrain', default='') 74 | 75 | # model 76 | parser.add_argument('--explicit_dialog_ctx', type=int, default=1) 77 | parser.add_argument('--no_dialog_flow', action='store_true') 78 | parser.add_argument('--no_hierarchical_query', dest='do_hierarchical_query', action='store_false') 79 | parser.add_argument('--no_prealign', dest='do_prealign', action='store_false') 80 | 81 | parser.add_argument('--final_output_att_hidden', type=int, default=250) 82 | 
parser.add_argument('--question_merge', default='linear_self_attn') 83 | parser.add_argument('--no_ptr_update', dest='do_ptr_update', action='store_false') 84 | parser.add_argument('--no_ptr_net_indep_attn', dest='ptr_net_indep_attn', action='store_false') 85 | parser.add_argument('--ptr_net_attn_type', default='Bilinear', help="Attention for answer span output: Bilinear, MLP or Default") 86 | 87 | parser.add_argument('--do_residual_rnn', dest='do_residual_rnn', action='store_true') 88 | parser.add_argument('--do_residual_everything', dest='do_residual_everything', action='store_true') 89 | parser.add_argument('--do_residual', dest='do_residual', action='store_true') 90 | parser.add_argument('--rnn_layers', type=int, default=1, help="Default number of RNN layers") 91 | parser.add_argument('--rnn_type', default='lstm', 92 | help='supported types: rnn, gru, lstm') 93 | parser.add_argument('--concat_rnn', dest='concat_rnn', action='store_true') 94 | 95 | parser.add_argument('--deep_inter_att_do_similar', type=int, default=0) 96 | parser.add_argument('--deep_att_hidden_size_per_abstr', type=int, default=250) 97 | 98 | parser.add_argument('--hidden_size', type=int, default=125) 99 | parser.add_argument('--self_attention_opt', type=int, default=1) # 0: no self attention 100 | 101 | parser.add_argument('--no_elmo', dest='use_elmo', action='store_false') 102 | parser.add_argument('--no_em', action='store_true') 103 | 104 | parser.add_argument('--no_wemb', dest='use_wemb', action='store_false') # word embedding 105 | parser.add_argument('--CoVe_opt', type=int, default=1) # contexualized embedding option 106 | parser.add_argument('--no_pos', dest='use_pos', action='store_false') # pos tagging 107 | parser.add_argument('--pos_size', type=int, default=51, help='how many kinds of POS tags.') 108 | parser.add_argument('--pos_dim', type=int, default=12, help='the embedding dimension for POS tags.') 109 | parser.add_argument('--no_ner', dest='use_ner', action='store_false') # named entity 110 | parser.add_argument('--ner_size', type=int, default=19, help='how many kinds of named entity tags.') 111 | parser.add_argument('--ner_dim', type=int, default=8, help='the embedding dimension for named entity tags.') 112 | 113 | parser.add_argument('--prealign_hidden', type=int, default=300) 114 | parser.add_argument('--prealign_option', type=int, default=2, help='0: No prealign, 1, 2, ...: Different options') 115 | 116 | parser.add_argument('--no_seq_dropout', dest='do_seq_dropout', action='store_false') 117 | parser.add_argument('--my_dropout_p', type=float, default=0.4) 118 | parser.add_argument('--dropout_emb', type=float, default=0.4) 119 | 120 | parser.add_argument('--max_len', type=int, default=15) 121 | 122 | args = parser.parse_args() 123 | 124 | if args.name != '': 125 | args.model_dir = args.model_dir + '_' + args.name 126 | args.log_file = os.path.dirname(args.log_file) + 'output_' + args.name + '.log' 127 | 128 | # set model dir 129 | model_dir = args.model_dir 130 | os.makedirs(model_dir, exist_ok=True) 131 | model_dir = os.path.abspath(model_dir) 132 | 133 | # set random seed 134 | random.seed(args.seed) 135 | np.random.seed(args.seed) 136 | torch.manual_seed(args.seed) 137 | if args.cuda: 138 | torch.cuda.manual_seed_all(args.seed) 139 | 140 | # setup logger 141 | log = logging.getLogger(__name__) 142 | log.setLevel(logging.DEBUG) 143 | fh = logging.FileHandler(args.log_file) 144 | fh.setLevel(logging.DEBUG) 145 | ch = logging.StreamHandler(sys.stdout) 146 | ch.setLevel(logging.INFO) 147 | formatter 
= logging.Formatter(fmt='%(asctime)s %(message)s', datefmt='%m/%d/%Y %I:%M:%S') 148 | fh.setFormatter(formatter) 149 | ch.setFormatter(formatter) 150 | log.addHandler(fh) 151 | log.addHandler(ch) 152 | 153 | def main(): 154 | log.info('[program starts.]') 155 | opt = vars(args) # changing opt will change args 156 | train, train_embedding, opt = load_train_data(opt) 157 | dev, dev_embedding = load_dev_data(opt) 158 | opt['num_features'] += args.explicit_dialog_ctx * 3 # dialog_act + previous answer 159 | if opt['use_elmo'] == False: 160 | opt['elmo_batch_size'] = 0 161 | CoQAEval = CoQAEvaluator("CoQA/dev.json") 162 | log.info('[Data loaded.]') 163 | 164 | if args.resume: 165 | log.info('[loading previous model...]') 166 | checkpoint = torch.load(args.resume) 167 | if args.resume_options: 168 | opt = checkpoint['config'] 169 | state_dict = checkpoint['state_dict'] 170 | model = QAModel(opt, train_embedding, state_dict) 171 | epoch_0 = checkpoint['epoch'] + 1 172 | for i in range(checkpoint['epoch']): 173 | random.shuffle(list(range(len(train)))) # synchronize random seed 174 | if args.reduce_lr: 175 | lr_decay(model.optimizer, lr_decay=args.reduce_lr) 176 | else: 177 | model = QAModel(opt, train_embedding) 178 | epoch_0 = 1 179 | 180 | if args.pretrain: 181 | pretrain_model = torch.load(args.pretrain) 182 | state_dict = pretrain_model['state_dict']['network'] 183 | 184 | model.get_pretrain(state_dict) 185 | 186 | model.setup_eval_embed(dev_embedding) 187 | log.info("[dev] Total number of params: {}".format(model.total_param)) 188 | 189 | if args.cuda: 190 | model.cuda() 191 | 192 | if args.resume: 193 | batches = BatchGen_CoQA(dev, batch_size=args.batch_size, evaluation=True, gpu=args.cuda, dialog_ctx=args.explicit_dialog_ctx) 194 | predictions = [] 195 | for batch in batches: 196 | phrases, noans = model.predict(batch) 197 | predictions.extend(phrases) 198 | f1 = CoQAEval.compute_turn_score_seq(predictions) 199 | log.info("[dev F1: {:.3f}]".format(f1)) 200 | best_val_score = f1 201 | else: 202 | best_val_score = 0.0 203 | 204 | for epoch in range(epoch_0, epoch_0 + args.epoches): 205 | log.warning('Epoch {}'.format(epoch)) 206 | 207 | # train 208 | batches = BatchGen_CoQA(train, batch_size=args.batch_size, gpu=args.cuda, dialog_ctx=args.explicit_dialog_ctx, precompute_elmo=args.elmo_batch_size // args.batch_size) 209 | start = datetime.now() 210 | for i, batch in enumerate(batches): 211 | model.update(batch) 212 | if i % args.log_per_updates == 0: 213 | log.info('updates[{0:6}] train loss[{1:.5f}] remaining[{2}]'.format( 214 | model.updates, model.train_loss.avg, 215 | str((datetime.now() - start) / (i + 1) * (len(batches) - i - 1)).split('.')[0])) 216 | 217 | # eval 218 | if epoch % args.eval_per_epoch == 0: 219 | batches = BatchGen_CoQA(dev, batch_size=args.batch_size, evaluation=True, gpu=args.cuda, dialog_ctx=args.explicit_dialog_ctx, precompute_elmo=args.elmo_batch_size // args.batch_size) 220 | predictions = [] 221 | for batch in batches: 222 | phrases = model.predict(batch) 223 | predictions.extend(phrases) 224 | f1 = CoQAEval.compute_turn_score_seq(predictions) 225 | 226 | # save 227 | if args.save_best_only: 228 | if f1 > best_val_score: 229 | best_val_score = f1 230 | model_file = os.path.join(model_dir, 'best_model.pt') 231 | model.save(model_file, epoch) 232 | log.info('[new best model saved.]') 233 | else: 234 | model_file = os.path.join(model_dir, 'checkpoint_epoch_{}.pt'.format(epoch)) 235 | model.save(model_file, epoch) 236 | if f1 > best_val_score: 237 | best_val_score = 
f1 238 | copyfile(os.path.join(model_dir, model_file), 239 | os.path.join(model_dir, 'best_model.pt')) 240 | log.info('[new best model saved.]') 241 | 242 | log.warning("Epoch {} - dev F1: {:.3f} (Best F1: {:.3f})".format(epoch, f1 * 100.0, best_val_score * 100.0)) 243 | 244 | def lr_decay(optimizer, lr_decay): 245 | for param_group in optimizer.param_groups: 246 | param_group['lr'] *= lr_decay 247 | log.info('[learning rate reduced by {}]'.format(lr_decay)) 248 | return optimizer 249 | 250 | def load_train_data(opt): 251 | with open(os.path.join(args.train_dir, 'train_meta.msgpack'), 'rb') as f: 252 | meta = msgpack.load(f, encoding='utf8') 253 | embedding = torch.Tensor(meta['embedding']) 254 | opt['vocab_size'] = embedding.size(0) 255 | opt['embedding_dim'] = embedding.size(1) 256 | 257 | with open(os.path.join(args.train_dir, 'train_data.msgpack'), 'rb') as f: 258 | data = msgpack.load(f, encoding='utf8') 259 | #data_orig = pd.read_csv(os.path.join(args.train_dir, 'train.csv')) 260 | 261 | opt['num_features'] = len(data['context_features'][0][0]) 262 | 263 | train = {'context': list(zip( 264 | data['context_ids'], 265 | data['context_tags'], 266 | data['context_ents'], 267 | data['context'], 268 | data['context_span'], 269 | data['1st_question'], 270 | data['context_tokenized'])), 271 | 'qa': list(zip( 272 | data['question_CID'], 273 | data['question_ids'], 274 | data['context_features'], 275 | data['answer_start'], 276 | data['answer_end'], 277 | data['rationale_start'], 278 | data['rationale_end'], 279 | data['answer_choice'], 280 | data['question'], 281 | data['answer'], 282 | data['question_tokenized'])) 283 | } 284 | return train, embedding, opt 285 | 286 | def load_dev_data(opt): # can be extended to true test set 287 | with open(os.path.join(args.dev_dir, 'dev_meta.msgpack'), 'rb') as f: 288 | meta = msgpack.load(f, encoding='utf8') 289 | embedding = torch.Tensor(meta['embedding']) 290 | assert opt['embedding_dim'] == embedding.size(1) 291 | 292 | with open(os.path.join(args.dev_dir, 'dev_data.msgpack'), 'rb') as f: 293 | data = msgpack.load(f, encoding='utf8') 294 | #data_orig = pd.read_csv(os.path.join(args.dev_dir, 'dev.csv')) 295 | 296 | assert opt['num_features'] == len(data['context_features'][0][0]) 297 | 298 | dev = {'context': list(zip( 299 | data['context_ids'], 300 | data['context_tags'], 301 | data['context_ents'], 302 | data['context'], 303 | data['context_span'], 304 | data['1st_question'], 305 | data['context_tokenized'])), 306 | 'qa': list(zip( 307 | data['question_CID'], 308 | data['question_ids'], 309 | data['context_features'], 310 | data['answer_start'], 311 | data['answer_end'], 312 | data['rationale_start'], 313 | data['rationale_end'], 314 | data['answer_choice'], 315 | data['question'], 316 | data['answer'], 317 | data['question_tokenized'])) 318 | } 319 | 320 | return dev, embedding 321 | 322 | if __name__ == '__main__': 323 | main() 324 | -------------------------------------------------------------------------------- /preprocess_CoQA.py: -------------------------------------------------------------------------------- 1 | import re 2 | import json 3 | import spacy 4 | import msgpack 5 | import unicodedata 6 | import numpy as np 7 | import pandas as pd 8 | import argparse 9 | import collections 10 | import multiprocessing 11 | import logging 12 | import random 13 | from allennlp.modules.elmo import batch_to_ids 14 | from general_utils import flatten_json, free_text_to_span, normalize_text, build_embedding, load_glove_vocab, pre_proc, get_context_span, 
find_answer_span, feature_gen, token2id 15 | 16 | parser = argparse.ArgumentParser( 17 | description='Preprocessing train + dev files, about 15 minutes to run on Servers.' 18 | ) 19 | parser.add_argument('--wv_file', default='glove/glove.840B.300d.txt', 20 | help='path to word vector file.') 21 | parser.add_argument('--wv_dim', type=int, default=300, 22 | help='word vector dimension.') 23 | parser.add_argument('--sort_all', action='store_true', 24 | help='sort the vocabulary by frequencies of all words.' 25 | 'Otherwise consider question words first.') 26 | parser.add_argument('--threads', type=int, default=multiprocessing.cpu_count(), 27 | help='number of threads for preprocessing.') 28 | parser.add_argument('--no_match', action='store_true', 29 | help='do not extract the three exact matching features.') 30 | parser.add_argument('--seed', type=int, default=1023, 31 | help='random seed for data shuffling, embedding init, etc.') 32 | 33 | args = parser.parse_args() 34 | trn_file = 'CoQA/train.json' 35 | dev_file = 'CoQA/dev.json' 36 | wv_file = args.wv_file 37 | wv_dim = args.wv_dim 38 | nlp = spacy.load('en', disable=['parser']) 39 | 40 | random.seed(args.seed) 41 | np.random.seed(args.seed) 42 | 43 | logging.basicConfig(format='%(asctime)s %(message)s', level=logging.DEBUG, 44 | datefmt='%m/%d/%Y %I:%M:%S') 45 | log = logging.getLogger(__name__) 46 | 47 | log.info('start data preparing... (using {} threads)'.format(args.threads)) 48 | 49 | glove_vocab = load_glove_vocab(wv_file, wv_dim) # return a "set" of vocabulary 50 | log.info('glove loaded.') 51 | 52 | #=============================================================== 53 | #=================== Work on training data ===================== 54 | #=============================================================== 55 | 56 | def proc_train(ith, article): 57 | rows = [] 58 | context = article['story'] 59 | 60 | for j, (question, answers) in enumerate(zip(article['questions'], article['answers'])): 61 | gold_answer = answers['input_text'] 62 | span_answer = answers['span_text'] 63 | 64 | answer, char_i, char_j = free_text_to_span(gold_answer, span_answer) 65 | answer_choice = 0 if answer == '__NA__' else\ 66 | 1 if answer == '__YES__' else\ 67 | 2 if answer == '__NO__' else\ 68 | 3 # Not a yes/no question 69 | 70 | if answer_choice == 3: 71 | answer_start = answers['span_start'] + char_i 72 | answer_end = answers['span_start'] + char_j 73 | else: 74 | answer_start, answer_end = -1, -1 75 | 76 | rationale = answers['span_text'] 77 | rationale_start = answers['span_start'] 78 | rationale_end = answers['span_end'] 79 | 80 | q_text = question['input_text'] 81 | if j > 0: 82 | q_text = article['answers'][j-1]['input_text'] + " // " + q_text 83 | 84 | rows.append((ith, q_text, answer, answer_start, answer_end, rationale, rationale_start, rationale_end, answer_choice)) 85 | return rows, context 86 | 87 | train, train_context = flatten_json(trn_file, proc_train) 88 | train = pd.DataFrame(train, columns=['context_idx', 'question', 'answer', 'answer_start', 'answer_end', 'rationale', 'rationale_start', 'rationale_end', 'answer_choice']) 89 | log.info('train json data flattened.') 90 | 91 | print(train) 92 | 93 | trC_iter = (pre_proc(c) for c in train_context) 94 | trQ_iter = (pre_proc(q) for q in train.question) 95 | trC_docs = [doc for doc in nlp.pipe(trC_iter, batch_size=64, n_threads=args.threads)] 96 | trQ_docs = [doc for doc in nlp.pipe(trQ_iter, batch_size=64, n_threads=args.threads)] 97 | 98 | # tokens 99 | trC_tokens = [[normalize_text(w.text) for w 
in doc] for doc in trC_docs] 100 | trQ_tokens = [[normalize_text(w.text) for w in doc] for doc in trQ_docs] 101 | trC_unnorm_tokens = [[w.text for w in doc] for doc in trC_docs] 102 | log.info('All tokens for training are obtained.') 103 | 104 | train_context_span = [get_context_span(a, b) for a, b in zip(train_context, trC_unnorm_tokens)] 105 | 106 | ans_st_token_ls, ans_end_token_ls = [], [] 107 | for ans_st, ans_end, idx in zip(train.answer_start, train.answer_end, train.context_idx): 108 | ans_st_token, ans_end_token = find_answer_span(train_context_span[idx], ans_st, ans_end) 109 | ans_st_token_ls.append(ans_st_token) 110 | ans_end_token_ls.append(ans_end_token) 111 | 112 | ration_st_token_ls, ration_end_token_ls = [], [] 113 | for ration_st, ration_end, idx in zip(train.rationale_start, train.rationale_end, train.context_idx): 114 | ration_st_token, ration_end_token = find_answer_span(train_context_span[idx], ration_st, ration_end) 115 | ration_st_token_ls.append(ration_st_token) 116 | ration_end_token_ls.append(ration_end_token) 117 | 118 | train['answer_start_token'], train['answer_end_token'] = ans_st_token_ls, ans_end_token_ls 119 | train['rationale_start_token'], train['rationale_end_token'] = ration_st_token_ls, ration_end_token_ls 120 | 121 | initial_len = len(train) 122 | train.dropna(inplace=True) # modify self DataFrame 123 | log.info('drop {0}/{1} inconsistent samples.'.format(initial_len - len(train), initial_len)) 124 | log.info('answer span for training is generated.') 125 | 126 | # features 127 | trC_tags, trC_ents, trC_features = feature_gen(trC_docs, train.context_idx, trQ_docs, args.no_match) 128 | log.info('features for training is generated: {}, {}, {}'.format(len(trC_tags), len(trC_ents), len(trC_features))) 129 | 130 | def build_train_vocab(questions, contexts): # vocabulary will also be sorted accordingly 131 | if args.sort_all: 132 | counter = collections.Counter(w for doc in questions + contexts for w in doc) 133 | vocab = sorted([t for t in counter if t in glove_vocab], key=counter.get, reverse=True) 134 | else: 135 | counter_c = collections.Counter(w for doc in contexts for w in doc) 136 | counter_q = collections.Counter(w for doc in questions for w in doc) 137 | counter = counter_c + counter_q 138 | vocab = sorted([t for t in counter_q if t in glove_vocab], key=counter_q.get, reverse=True) 139 | vocab += sorted([t for t in counter_c.keys() - counter_q.keys() if t in glove_vocab], 140 | key=counter.get, reverse=True) 141 | total = sum(counter.values()) 142 | matched = sum(counter[t] for t in vocab) 143 | log.info('vocab {1}/{0} OOV {2}/{3} ({4:.4f}%)'.format( 144 | len(counter), len(vocab), (total - matched), total, (total - matched) / total * 100)) 145 | vocab.insert(0, "<PAD>") 146 | vocab.insert(1, "<UNK>") 147 | vocab.insert(2, "<S>") 148 | vocab.insert(3, "</S>") # special tokens: 0 = padding, 1 = OOV, 2/3 = question start/end 149 | return vocab 150 | 151 | # vocab 152 | tr_vocab = build_train_vocab(trQ_tokens, trC_tokens) 153 | trC_ids = token2id(trC_tokens, tr_vocab, unk_id=1) 154 | trQ_ids = token2id(trQ_tokens, tr_vocab, unk_id=1) 155 | trQ_tokens = [["<S>"] + doc + ["</S>"] for doc in trQ_tokens] 156 | trQ_ids = [[2] + qsent + [3] for qsent in trQ_ids] 157 | print(trQ_ids[:10]) 158 | # tags 159 | vocab_tag = [''] + list(nlp.tagger.labels) 160 | trC_tag_ids = token2id(trC_tags, vocab_tag) 161 | # entities 162 | vocab_ent = list(set([ent for sent in trC_ents for ent in sent])) 163 | trC_ent_ids = token2id(trC_ents, vocab_ent, unk_id=0) 164 | 165 | log.info('Found {} POS tags.'.format(len(vocab_tag))) 166 | log.info('Found {} entity
tags: {}'.format(len(vocab_ent), vocab_ent)) 167 | log.info('vocabulary for training is built.') 168 | 169 | tr_embedding = build_embedding(wv_file, tr_vocab, wv_dim) 170 | log.info('got embedding matrix for training.') 171 | 172 | meta = { 173 | 'vocab': tr_vocab, 174 | 'embedding': tr_embedding.tolist() 175 | } 176 | with open('CoQA/train_meta.msgpack', 'wb') as f: 177 | msgpack.dump(meta, f) 178 | 179 | prev_CID, first_question = -1, [] 180 | for i, CID in enumerate(train.context_idx): 181 | if not (CID == prev_CID): 182 | first_question.append(i) 183 | prev_CID = CID 184 | 185 | result = { 186 | 'question_ids': trQ_ids, 187 | 'context_ids': trC_ids, 188 | 'context_features': trC_features, # exact match, tf 189 | 'context_tags': trC_tag_ids, # POS tagging 190 | 'context_ents': trC_ent_ids, # Entity recognition 191 | 'context': train_context, 192 | 'context_span': train_context_span, 193 | '1st_question': first_question, 194 | 'question_CID': train.context_idx.tolist(), 195 | 'question': train.question.tolist(), 196 | 'answer': train.answer.tolist(), 197 | 'answer_start': train.answer_start_token.tolist(), 198 | 'answer_end': train.answer_end_token.tolist(), 199 | 'rationale_start': train.rationale_start_token.tolist(), 200 | 'rationale_end': train.rationale_end_token.tolist(), 201 | 'answer_choice': train.answer_choice.tolist(), 202 | 'context_tokenized': trC_tokens, 203 | 'question_tokenized': trQ_tokens 204 | } 205 | with open('CoQA/train_data.msgpack', 'wb') as f: 206 | msgpack.dump(result, f) 207 | 208 | log.info('saved training to disk.') 209 | 210 | #========================================================== 211 | #=================== Work on dev data ===================== 212 | #========================================================== 213 | 214 | def proc_dev(ith, article): 215 | rows = [] 216 | context = article['story'] 217 | 218 | for j, (question, answers) in enumerate(zip(article['questions'], article['answers'])): 219 | gold_answer = answers['input_text'] 220 | span_answer = answers['span_text'] 221 | 222 | answer, char_i, char_j = free_text_to_span(gold_answer, span_answer) 223 | answer_choice = 0 if answer == '__NA__' else\ 224 | 1 if answer == '__YES__' else\ 225 | 2 if answer == '__NO__' else\ 226 | 3 # Not a yes/no question 227 | 228 | if answer_choice == 3: 229 | answer_start = answers['span_start'] + char_i 230 | answer_end = answers['span_start'] + char_j 231 | else: 232 | answer_start, answer_end = -1, -1 233 | 234 | rationale = answers['span_text'] 235 | rationale_start = answers['span_start'] 236 | rationale_end = answers['span_end'] 237 | 238 | q_text = question['input_text'] 239 | if j > 0: 240 | q_text = article['answers'][j-1]['input_text'] + " // " + q_text 241 | 242 | rows.append((ith, q_text, answer, answer_start, answer_end, rationale, rationale_start, rationale_end, answer_choice)) 243 | return rows, context 244 | 245 | dev, dev_context = flatten_json(dev_file, proc_dev) 246 | dev = pd.DataFrame(dev, columns=['context_idx', 'question', 'answer', 'answer_start', 'answer_end', 'rationale', 'rationale_start', 'rationale_end', 'answer_choice']) 247 | log.info('dev json data flattened.') 248 | 249 | print(dev) 250 | 251 | devC_iter = (pre_proc(c) for c in dev_context) 252 | devQ_iter = (pre_proc(q) for q in dev.question) 253 | devC_docs = [doc for doc in nlp.pipe( 254 | devC_iter, batch_size=64, n_threads=args.threads)] 255 | devQ_docs = [doc for doc in nlp.pipe( 256 | devQ_iter, batch_size=64, n_threads=args.threads)] 257 | 258 | # tokens 259 | devC_tokens 
= [[normalize_text(w.text) for w in doc] for doc in devC_docs] 260 | devQ_tokens = [[normalize_text(w.text) for w in doc] for doc in devQ_docs] 261 | devC_unnorm_tokens = [[w.text for w in doc] for doc in devC_docs] 262 | log.info('All tokens for dev are obtained.') 263 | 264 | dev_context_span = [get_context_span(a, b) for a, b in zip(dev_context, devC_unnorm_tokens)] 265 | log.info('context span for dev is generated.') 266 | 267 | ans_st_token_ls, ans_end_token_ls = [], [] 268 | for ans_st, ans_end, idx in zip(dev.answer_start, dev.answer_end, dev.context_idx): 269 | ans_st_token, ans_end_token = find_answer_span(dev_context_span[idx], ans_st, ans_end) 270 | ans_st_token_ls.append(ans_st_token) 271 | ans_end_token_ls.append(ans_end_token) 272 | 273 | ration_st_token_ls, ration_end_token_ls = [], [] 274 | for ration_st, ration_end, idx in zip(dev.rationale_start, dev.rationale_end, dev.context_idx): 275 | ration_st_token, ration_end_token = find_answer_span(dev_context_span[idx], ration_st, ration_end) 276 | ration_st_token_ls.append(ration_st_token) 277 | ration_end_token_ls.append(ration_end_token) 278 | 279 | dev['answer_start_token'], dev['answer_end_token'] = ans_st_token_ls, ans_end_token_ls 280 | dev['rationale_start_token'], dev['rationale_end_token'] = ration_st_token_ls, ration_end_token_ls 281 | 282 | initial_len = len(dev) 283 | dev.dropna(inplace=True) # modify self DataFrame 284 | log.info('drop {0}/{1} inconsistent samples.'.format(initial_len - len(dev), initial_len)) 285 | log.info('answer span for dev is generated.') 286 | 287 | # features 288 | devC_tags, devC_ents, devC_features = feature_gen(devC_docs, dev.context_idx, devQ_docs, args.no_match) 289 | log.info('features for dev is generated: {}, {}, {}'.format(len(devC_tags), len(devC_ents), len(devC_features))) 290 | 291 | def build_dev_vocab(questions, contexts): # most vocabulary comes from tr_vocab 292 | existing_vocab = set(tr_vocab) 293 | new_vocab = list(set([w for doc in questions + contexts for w in doc if w not in existing_vocab and w in glove_vocab])) 294 | vocab = tr_vocab + new_vocab 295 | log.info('train vocab {0}, total vocab {1}'.format(len(tr_vocab), len(vocab))) 296 | return vocab 297 | 298 | # vocab 299 | dev_vocab = build_dev_vocab(devQ_tokens, devC_tokens) # tr_vocab is a subset of dev_vocab 300 | devC_ids = token2id(devC_tokens, dev_vocab, unk_id=1) 301 | devQ_ids = token2id(devQ_tokens, dev_vocab, unk_id=1) 302 | devQ_tokens = [["<S>"] + doc + ["</S>"] for doc in devQ_tokens] # same question start/end sentinels as training 303 | devQ_ids = [[2] + qsent + [3] for qsent in devQ_ids] 304 | print(devQ_ids[:10]) 305 | # tags 306 | devC_tag_ids = token2id(devC_tags, vocab_tag) # vocab_tag same as training 307 | # entities 308 | devC_ent_ids = token2id(devC_ents, vocab_ent, unk_id=0) # vocab_ent same as training 309 | log.info('vocabulary for dev is built.') 310 | 311 | dev_embedding = build_embedding(wv_file, dev_vocab, wv_dim) 312 | # tr_embedding is a submatrix of dev_embedding 313 | log.info('got embedding matrix for dev.') 314 | 315 | # don't store row name in csv 316 | #dev.to_csv('QuAC_data/dev.csv', index=False, encoding='utf8') 317 | 318 | meta = { 319 | 'vocab': dev_vocab, 320 | 'embedding': dev_embedding.tolist() 321 | } 322 | with open('CoQA/dev_meta.msgpack', 'wb') as f: 323 | msgpack.dump(meta, f) 324 | 325 | prev_CID, first_question = -1, [] 326 | for i, CID in enumerate(dev.context_idx): 327 | if not (CID == prev_CID): 328 | first_question.append(i) 329 | prev_CID = CID 330 | 331 | result = { 332 | 'question_ids': devQ_ids, 333 |
'context_ids': devC_ids, 334 | 'context_features': devC_features, # exact match, tf 335 | 'context_tags': devC_tag_ids, # POS tagging 336 | 'context_ents': devC_ent_ids, # Entity recognition 337 | 'context': dev_context, 338 | 'context_span': dev_context_span, 339 | '1st_question': first_question, 340 | 'question_CID': dev.context_idx.tolist(), 341 | 'question': dev.question.tolist(), 342 | 'answer': dev.answer.tolist(), 343 | 'answer_start': dev.answer_start_token.tolist(), 344 | 'answer_end': dev.answer_end_token.tolist(), 345 | 'rationale_start': dev.rationale_start_token.tolist(), 346 | 'rationale_end': dev.rationale_end_token.tolist(), 347 | 'answer_choice': dev.answer_choice.tolist(), 348 | 'context_tokenized': devC_tokens, 349 | 'question_tokenized': devQ_tokens 350 | } 351 | with open('CoQA/dev_data.msgpack', 'wb') as f: 352 | msgpack.dump(result, f) 353 | 354 | log.info('saved dev to disk.') 355 | -------------------------------------------------------------------------------- /train_QuAC.py: -------------------------------------------------------------------------------- 1 | import os 2 | import re 3 | import sys 4 | import random 5 | import string 6 | import logging 7 | import argparse 8 | from shutil import copyfile 9 | from datetime import datetime 10 | from collections import Counter 11 | import torch 12 | import msgpack 13 | import pandas as pd 14 | import numpy as np 15 | from QA_model.model_QuAC import QAModel 16 | from general_utils import find_best_score_and_thresh, BatchGen_QuAC 17 | 18 | parser = argparse.ArgumentParser( 19 | description='Train a Dialog QA model.' 20 | ) 21 | 22 | # system 23 | parser.add_argument('--task_name', default='QuAC') 24 | parser.add_argument('--name', default='', help='additional name of the current run') 25 | parser.add_argument('--log_file', default='output.log', 26 | help='path for log file.') 27 | parser.add_argument('--log_per_updates', type=int, default=20, 28 | help='log model loss per x updates (mini-batches).') 29 | 30 | parser.add_argument('--train_dir', default='QuAC_data/') 31 | parser.add_argument('--dev_dir', default='QuAC_data/') 32 | parser.add_argument('--answer_type_num', type=int, default=1) 33 | 34 | parser.add_argument('--model_dir', default='models', 35 | help='path to store saved models.') 36 | parser.add_argument('--eval_per_epoch', type=int, default=1, 37 | help='perform evaluation per x epoches.') 38 | parser.add_argument('--MTLSTM_path', default='glove/MT-LSTM.pth') 39 | parser.add_argument('--save_all', dest='save_best_only', action='store_false', help='save all models.') 40 | parser.add_argument('--do_not_save', action='store_true', help='don\'t save any model') 41 | parser.add_argument('--save_for_predict', action='store_true') 42 | parser.add_argument('--seed', type=int, default=1023, 43 | help='random seed for data shuffling, dropout, etc.') 44 | parser.add_argument('--cuda', type=bool, default=torch.cuda.is_available(), 45 | help='whether to use GPU acceleration.') 46 | # training 47 | parser.add_argument('-e', '--epoches', type=int, default=30) 48 | parser.add_argument('-bs', '--batch_size', type=int, default=3) 49 | parser.add_argument('-ebs', '--elmo_batch_size', type=int, default=12) 50 | parser.add_argument('-rs', '--resume', default='', 51 | help='previous model pathname. ' 52 | 'e.g. 
"models/checkpoint_epoch_11.pt"') 53 | parser.add_argument('-ro', '--resume_options', action='store_true', 54 | help='use previous model options, ignore the cli and defaults.') 55 | parser.add_argument('-rlr', '--reduce_lr', type=float, default=0., 56 | help='reduce initial (resumed) learning rate by this factor.') 57 | parser.add_argument('-op', '--optimizer', default='adamax', 58 | help='supported optimizer: adamax, sgd, adadelta, adam') 59 | parser.add_argument('-gc', '--grad_clipping', type=float, default=10) 60 | parser.add_argument('-wd', '--weight_decay', type=float, default=0) 61 | parser.add_argument('-lr', '--learning_rate', type=float, default=0.1, 62 | help='only applied to SGD.') 63 | parser.add_argument('-mm', '--momentum', type=float, default=0, 64 | help='only applied to SGD.') 65 | parser.add_argument('-tp', '--tune_partial', type=int, default=1000, 66 | help='finetune top-x embeddings (including , ).') 67 | parser.add_argument('--fix_embeddings', action='store_true', 68 | help='if true, `tune_partial` will be ignored.') 69 | parser.add_argument('--elmo_lambda', type=float, default=0.0) 70 | parser.add_argument('--no_question_normalize', dest='question_normalize', action='store_false') # when set, do dialog normalize 71 | parser.add_argument('--pretrain', default='') 72 | 73 | # model 74 | parser.add_argument('--explicit_dialog_ctx', type=int, default=2) 75 | parser.add_argument('--use_dialog_act', action='store_true') 76 | parser.add_argument('--no_dialog_flow', action='store_true') 77 | parser.add_argument('--no_hierarchical_query', dest='do_hierarchical_query', action='store_false') 78 | parser.add_argument('--no_prealign', dest='do_prealign', action='store_false') 79 | 80 | parser.add_argument('--final_output_att_hidden', type=int, default=250) 81 | parser.add_argument('--question_merge', default='linear_self_attn') 82 | parser.add_argument('--no_ptr_update', dest='do_ptr_update', action='store_false') 83 | parser.add_argument('--no_ptr_net_indep_attn', dest='ptr_net_indep_attn', action='store_false') 84 | parser.add_argument('--ptr_net_attn_type', default='Bilinear', help="Attention for answer span output: Bilinear, MLP or Default") 85 | 86 | parser.add_argument('--do_residual_rnn', dest='do_residual_rnn', action='store_true') 87 | parser.add_argument('--do_residual_everything', dest='do_residual_everything', action='store_true') 88 | parser.add_argument('--do_residual', dest='do_residual', action='store_true') 89 | parser.add_argument('--rnn_layers', type=int, default=1, help="Default number of RNN layers") 90 | parser.add_argument('--rnn_type', default='lstm', 91 | help='supported types: rnn, gru, lstm') 92 | parser.add_argument('--concat_rnn', dest='concat_rnn', action='store_true') 93 | 94 | parser.add_argument('--hidden_size', type=int, default=125) 95 | parser.add_argument('--self_attention_opt', type=int, default=1) # 0: no self attention 96 | 97 | parser.add_argument('--deep_inter_att_do_similar', type=int, default=0) 98 | parser.add_argument('--deep_att_hidden_size_per_abstr', type=int, default=250) 99 | 100 | parser.add_argument('--no_elmo', dest='use_elmo', action='store_false') 101 | parser.add_argument('--no_em', action='store_true') 102 | 103 | parser.add_argument('--no_wemb', dest='use_wemb', action='store_false') # word embedding 104 | parser.add_argument('--CoVe_opt', type=int, default=1) # contexualized embedding option 105 | parser.add_argument('--no_pos', dest='use_pos', action='store_false') # pos tagging 106 | parser.add_argument('--pos_size', 
type=int, default=51, help='how many kinds of POS tags.') 107 | parser.add_argument('--pos_dim', type=int, default=12, help='the embedding dimension for POS tags.') 108 | parser.add_argument('--no_ner', dest='use_ner', action='store_false') # named entity 109 | parser.add_argument('--ner_size', type=int, default=19, help='how many kinds of named entity tags.') 110 | parser.add_argument('--ner_dim', type=int, default=8, help='the embedding dimension for named entity tags.') 111 | 112 | parser.add_argument('--prealign_hidden', type=int, default=300) 113 | parser.add_argument('--prealign_option', type=int, default=2, help='0: No prealign, 1, 2, ...: Different options') 114 | 115 | parser.add_argument('--no_seq_dropout', dest='do_seq_dropout', action='store_false') 116 | parser.add_argument('--my_dropout_p', type=float, default=0.4) 117 | parser.add_argument('--dropout_emb', type=float, default=0.4) 118 | 119 | parser.add_argument('--max_len', type=int, default=35) 120 | 121 | args = parser.parse_args() 122 | 123 | if args.name != '': 124 | args.model_dir = args.model_dir + '_' + args.name 125 | args.log_file = os.path.dirname(args.log_file) + 'output_' + args.name + '.log' 126 | 127 | # set model dir 128 | model_dir = args.model_dir 129 | os.makedirs(model_dir, exist_ok=True) 130 | model_dir = os.path.abspath(model_dir) 131 | 132 | # set random seed 133 | random.seed(args.seed) 134 | np.random.seed(args.seed) 135 | torch.manual_seed(args.seed) 136 | if args.cuda: 137 | torch.cuda.manual_seed_all(args.seed) 138 | 139 | # setup logger 140 | log = logging.getLogger(__name__) 141 | log.setLevel(logging.DEBUG) 142 | fh = logging.FileHandler(args.log_file) 143 | fh.setLevel(logging.DEBUG) 144 | ch = logging.StreamHandler(sys.stdout) 145 | ch.setLevel(logging.INFO) 146 | formatter = logging.Formatter(fmt='%(asctime)s %(message)s', datefmt='%m/%d/%Y %I:%M:%S') 147 | fh.setFormatter(formatter) 148 | ch.setFormatter(formatter) 149 | log.addHandler(fh) 150 | log.addHandler(ch) 151 | 152 | def main(): 153 | log.info('[program starts.]') 154 | opt = vars(args) # changing opt will change args 155 | train, train_embedding, opt = load_train_data(opt) 156 | dev, dev_embedding, dev_answer = load_dev_data(opt) 157 | opt['num_features'] += args.explicit_dialog_ctx * (args.use_dialog_act*3 + 2) # dialog_act + previous answer 158 | if opt['use_elmo'] == False: 159 | opt['elmo_batch_size'] = 0 160 | log.info('[Data loaded.]') 161 | 162 | if args.resume: 163 | log.info('[loading previous model...]') 164 | checkpoint = torch.load(args.resume) 165 | if args.resume_options: 166 | opt = checkpoint['config'] 167 | state_dict = checkpoint['state_dict'] 168 | model = QAModel(opt, train_embedding, state_dict) 169 | epoch_0 = checkpoint['epoch'] + 1 170 | for i in range(checkpoint['epoch']): 171 | random.shuffle(list(range(len(train)))) # synchronize random seed 172 | if args.reduce_lr: 173 | lr_decay(model.optimizer, lr_decay=args.reduce_lr) 174 | else: 175 | model = QAModel(opt, train_embedding) 176 | epoch_0 = 1 177 | 178 | if args.pretrain: 179 | pretrain_model = torch.load(args.pretrain) 180 | state_dict = pretrain_model['state_dict']['network'] 181 | 182 | model.get_pretrain(state_dict) 183 | 184 | model.setup_eval_embed(dev_embedding) 185 | log.info("[dev] Total number of params: {}".format(model.total_param)) 186 | 187 | if args.cuda: 188 | model.cuda() 189 | 190 | if args.resume: 191 | batches = BatchGen_QuAC(dev, batch_size=args.batch_size, evaluation=True, gpu=args.cuda, dialog_ctx=args.explicit_dialog_ctx, 
use_dialog_act=args.use_dialog_act) 192 | predictions, no_ans_scores = [], [] 193 | for batch in batches: 194 | phrases, noans = model.predict(batch) 195 | predictions.extend(phrases) 196 | no_ans_scores.extend(noans) 197 | f1, na, thresh = find_best_score_and_thresh(predictions, dev_answer, no_ans_scores) 198 | log.info("[dev F1: {} NA: {} TH: {}]".format(f1, na, thresh)) 199 | best_val_score, best_na, best_thresh = f1, na, thresh 200 | else: 201 | best_val_score, best_na, best_thresh = 0.0, 0.0, 0.0 202 | 203 | for epoch in range(epoch_0, epoch_0 + args.epoches): 204 | log.warning('Epoch {}'.format(epoch)) 205 | # train 206 | batches = BatchGen_QuAC(train, batch_size=args.batch_size, gpu=args.cuda, dialog_ctx=args.explicit_dialog_ctx, use_dialog_act=args.use_dialog_act, precompute_elmo=args.elmo_batch_size // args.batch_size) 207 | start = datetime.now() 208 | for i, batch in enumerate(batches): 209 | model.update(batch) 210 | if i % args.log_per_updates == 0: 211 | log.info('updates[{0:6}] train loss[{1:.5f}] remaining[{2}]'.format( 212 | model.updates, model.train_loss.avg, 213 | str((datetime.now() - start) / (i + 1) * (len(batches) - i - 1)).split('.')[0])) 214 | 215 | # eval 216 | if epoch % args.eval_per_epoch == 0: 217 | batches = BatchGen_QuAC(dev, batch_size=args.batch_size, evaluation=True, gpu=args.cuda, dialog_ctx=args.explicit_dialog_ctx, use_dialog_act=args.use_dialog_act, precompute_elmo=args.elmo_batch_size // args.batch_size) 218 | predictions, no_ans_scores = [], [] 219 | for batch in batches: 220 | phrases, noans = model.predict(batch) 221 | predictions.extend(phrases) 222 | no_ans_scores.extend(noans) 223 | f1, na, thresh = find_best_score_and_thresh(predictions, dev_answer, no_ans_scores) 224 | 225 | # save 226 | if args.save_best_only: 227 | if f1 > best_val_score: 228 | best_val_score, best_na, best_thresh = f1, na, thresh 229 | model_file = os.path.join(model_dir, 'best_model.pt') 230 | model.save(model_file, epoch) 231 | log.info('[new best model saved.]') 232 | else: 233 | model_file = os.path.join(model_dir, 'checkpoint_epoch_{}.pt'.format(epoch)) 234 | model.save(model_file, epoch) 235 | if f1 > best_val_score: 236 | best_val_score, best_na, best_thresh = f1, na, thresh 237 | copyfile(os.path.join(model_dir, model_file), 238 | os.path.join(model_dir, 'best_model.pt')) 239 | log.info('[new best model saved.]') 240 | 241 | log.warning("Epoch {} - dev F1: {:.3f} NA: {:.3f} TH: {:.3f} (best F1: {:.3f} NA: {:.3f} TH: {:.3f})".format(epoch, f1, na, thresh, best_val_score, best_na, best_thresh)) 242 | 243 | def lr_decay(optimizer, lr_decay): 244 | for param_group in optimizer.param_groups: 245 | param_group['lr'] *= lr_decay 246 | log.info('[learning rate reduced by {}]'.format(lr_decay)) 247 | return optimizer 248 | 249 | def load_train_data(opt): 250 | with open(os.path.join(args.train_dir, 'train_meta.msgpack'), 'rb') as f: 251 | meta = msgpack.load(f, encoding='utf8') 252 | embedding = torch.Tensor(meta['embedding']) 253 | opt['vocab_size'] = embedding.size(0) 254 | opt['embedding_dim'] = embedding.size(1) 255 | 256 | with open(os.path.join(args.train_dir, 'train_data.msgpack'), 'rb') as f: 257 | data = msgpack.load(f, encoding='utf8') 258 | #data_orig = pd.read_csv(os.path.join(args.train_dir, 'train.csv')) 259 | 260 | opt['num_features'] = len(data['context_features'][0][0]) 261 | 262 | train = {'context': list(zip( 263 | data['context_ids'], 264 | data['context_tags'], 265 | data['context_ents'], 266 | data['context'], 267 | data['context_span'], 268 | 
data['1st_question'], 269 | data['context_tokenized'])), 270 | 'qa': list(zip( 271 | data['question_CID'], 272 | data['question_ids'], 273 | data['context_features'], 274 | data['answer_start'], 275 | data['answer_end'], 276 | data['answer_choice'], 277 | data['question'], 278 | data['answer'], 279 | data['question_tokenized'])) 280 | } 281 | return train, embedding, opt 282 | 283 | def load_dev_data(opt): # can be extended to true test set 284 | with open(os.path.join(args.dev_dir, 'dev_meta.msgpack'), 'rb') as f: 285 | meta = msgpack.load(f, encoding='utf8') 286 | embedding = torch.Tensor(meta['embedding']) 287 | assert opt['embedding_dim'] == embedding.size(1) 288 | 289 | with open(os.path.join(args.dev_dir, 'dev_data.msgpack'), 'rb') as f: 290 | data = msgpack.load(f, encoding='utf8') 291 | #data_orig = pd.read_csv(os.path.join(args.dev_dir, 'dev.csv')) 292 | 293 | assert opt['num_features'] == len(data['context_features'][0][0]) 294 | 295 | dev = {'context': list(zip( 296 | data['context_ids'], 297 | data['context_tags'], 298 | data['context_ents'], 299 | data['context'], 300 | data['context_span'], 301 | data['1st_question'], 302 | data['context_tokenized'])), 303 | 'qa': list(zip( 304 | data['question_CID'], 305 | data['question_ids'], 306 | data['context_features'], 307 | data['answer_start'], 308 | data['answer_end'], 309 | data['answer_choice'], 310 | data['question'], 311 | data['answer'], 312 | data['question_tokenized'])) 313 | } 314 | 315 | dev_answer = [] 316 | for i, CID in enumerate(data['question_CID']): 317 | if len(dev_answer) <= CID: 318 | dev_answer.append([]) 319 | dev_answer[CID].append(data['all_answer'][i]) 320 | 321 | return dev, embedding, dev_answer 322 | 323 | if __name__ == '__main__': 324 | main() 325 | -------------------------------------------------------------------------------- /QA_model/detail_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import pickle 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from allennlp.modules.elmo import Elmo 6 | from allennlp.nn.util import remove_sentence_boundaries 7 | from . 
import layers 8 | 9 | class FlowQA(nn.Module): 10 | """Network for the FlowQA Module.""" 11 | def __init__(self, opt, embedding=None, padding_idx=0): 12 | super(FlowQA, self).__init__() 13 | 14 | # Input size to RNN: word emb + char emb + question emb + manual features 15 | doc_input_size = 0 16 | que_input_size = 0 17 | 18 | layers.set_my_dropout_prob(opt['my_dropout_p']) 19 | layers.set_seq_dropout(opt['do_seq_dropout']) 20 | 21 | if opt['use_wemb']: 22 | # Word embeddings 23 | self.embedding = nn.Embedding(opt['vocab_size'], 24 | opt['embedding_dim'], 25 | padding_idx=padding_idx) 26 | if embedding is not None: 27 | self.embedding.weight.data = embedding 28 | if opt['fix_embeddings'] or opt['tune_partial'] == 0: 29 | opt['fix_embeddings'] = True 30 | opt['tune_partial'] = 0 31 | for p in self.embedding.parameters(): 32 | p.requires_grad = False 33 | else: 34 | assert opt['tune_partial'] < embedding.size(0) 35 | fixed_embedding = embedding[opt['tune_partial']:] 36 | # a persistent buffer for the nn.Module 37 | self.register_buffer('fixed_embedding', fixed_embedding) 38 | self.fixed_embedding = fixed_embedding 39 | embedding_dim = opt['embedding_dim'] 40 | doc_input_size += embedding_dim 41 | que_input_size += embedding_dim 42 | else: 43 | opt['fix_embeddings'] = True 44 | opt['tune_partial'] = 0 45 | 46 | if opt['CoVe_opt'] > 0: 47 | self.CoVe = layers.MTLSTM(opt, embedding) 48 | CoVe_size = self.CoVe.output_size 49 | doc_input_size += CoVe_size 50 | que_input_size += CoVe_size 51 | 52 | if opt['use_elmo']: 53 | options_file = "https://s3-us-west-2.amazonaws.com/allennlp/models/elmo/2x4096_512_2048cnn_2xhighway_5.5B/elmo_2x4096_512_2048cnn_2xhighway_5.5B_options.json" 54 | weight_file = "https://s3-us-west-2.amazonaws.com/allennlp/models/elmo/2x4096_512_2048cnn_2xhighway_5.5B/elmo_2x4096_512_2048cnn_2xhighway_5.5B_weights.hdf5" 55 | self.elmo = Elmo(options_file, weight_file, 1, dropout=0) 56 | doc_input_size += 1024 57 | que_input_size += 1024 58 | if opt['use_pos']: 59 | self.pos_embedding = nn.Embedding(opt['pos_size'], opt['pos_dim']) 60 | doc_input_size += opt['pos_dim'] 61 | if opt['use_ner']: 62 | self.ner_embedding = nn.Embedding(opt['ner_size'], opt['ner_dim']) 63 | doc_input_size += opt['ner_dim'] 64 | 65 | if opt['do_prealign']: 66 | self.pre_align = layers.GetAttentionHiddens(embedding_dim, opt['prealign_hidden'], similarity_attention=True) 67 | doc_input_size += embedding_dim 68 | if opt['no_em']: 69 | doc_input_size += opt['num_features'] - 3 70 | else: 71 | doc_input_size += opt['num_features'] 72 | 73 | # Setup the vector size for [doc, question] 74 | # they will be modified in the following code 75 | doc_hidden_size, que_hidden_size = doc_input_size, que_input_size 76 | print('Initially, the vector_sizes [doc, query] are', doc_hidden_size, que_hidden_size) 77 | 78 | flow_size = opt['hidden_size'] 79 | 80 | # RNN document encoder 81 | self.doc_rnn1 = layers.StackedBRNN(doc_hidden_size, opt['hidden_size'], num_layers=1) 82 | self.dialog_flow1 = layers.StackedBRNN(opt['hidden_size'] * 2, opt['hidden_size'], num_layers=1, rnn_type=nn.GRU, bidir=False) 83 | self.doc_rnn2 = layers.StackedBRNN(opt['hidden_size'] * 2 + flow_size + CoVe_size, opt['hidden_size'], num_layers=1) 84 | self.dialog_flow2 = layers.StackedBRNN(opt['hidden_size'] * 2, opt['hidden_size'], num_layers=1, rnn_type=nn.GRU, bidir=False) 85 | doc_hidden_size = opt['hidden_size'] * 2 86 | 87 | # RNN question encoder 88 | self.question_rnn, que_hidden_size = layers.RNN_from_opt(que_hidden_size, opt['hidden_size'], 
opt, 89 | num_layers=2, concat_rnn=opt['concat_rnn'], add_feat=CoVe_size) 90 | 91 | # Output sizes of rnn encoders 92 | print('After Input LSTM, the vector_sizes [doc, query] are [', doc_hidden_size, que_hidden_size, '] * 2') 93 | 94 | # Deep inter-attention 95 | self.deep_attn = layers.DeepAttention(opt, abstr_list_cnt=2, deep_att_hidden_size_per_abstr=opt['deep_att_hidden_size_per_abstr'], do_similarity=opt['deep_inter_att_do_similar'], word_hidden_size=embedding_dim+CoVe_size, no_rnn=True) 96 | 97 | self.deep_attn_rnn, doc_hidden_size = layers.RNN_from_opt(self.deep_attn.att_final_size + flow_size, opt['hidden_size'], opt, num_layers=1) 98 | self.dialog_flow3 = layers.StackedBRNN(doc_hidden_size, opt['hidden_size'], num_layers=1, rnn_type=nn.GRU, bidir=False) 99 | 100 | # Question understanding and compression 101 | self.high_lvl_qrnn, que_hidden_size = layers.RNN_from_opt(que_hidden_size * 2, opt['hidden_size'], opt, num_layers = 1, concat_rnn = True) 102 | 103 | # Self attention on context 104 | att_size = doc_hidden_size + 2 * opt['hidden_size'] * 2 105 | 106 | if opt['self_attention_opt'] > 0: 107 | self.highlvl_self_att = layers.GetAttentionHiddens(att_size, opt['deep_att_hidden_size_per_abstr']) 108 | self.high_lvl_crnn, doc_hidden_size = layers.RNN_from_opt(doc_hidden_size * 2 + flow_size, opt['hidden_size'], opt, num_layers = 1, concat_rnn = False) 109 | print('Self deep-attention {} rays in {}-dim space'.format(opt['deep_att_hidden_size_per_abstr'], att_size)) 110 | elif opt['self_attention_opt'] == 0: 111 | self.high_lvl_crnn, doc_hidden_size = layers.RNN_from_opt(doc_hidden_size + flow_size, opt['hidden_size'], opt, num_layers = 1, concat_rnn = False) 112 | 113 | print('Before answer span finding, hidden size are', doc_hidden_size, que_hidden_size) 114 | 115 | # Question merging 116 | self.self_attn = layers.LinearSelfAttn(que_hidden_size) 117 | if opt['do_hierarchical_query']: 118 | self.hier_query_rnn = layers.StackedBRNN(que_hidden_size, opt['hidden_size'], num_layers=1, rnn_type=nn.GRU, bidir=False) 119 | que_hidden_size = opt['hidden_size'] 120 | 121 | # Attention for span start/end 122 | self.get_answer = layers.GetSpanStartEnd(doc_hidden_size, que_hidden_size, opt, 123 | opt['ptr_net_indep_attn'], opt["ptr_net_attn_type"], opt['do_ptr_update']) 124 | 125 | self.ans_type_prediction = layers.BilinearLayer(doc_hidden_size * 2, que_hidden_size, opt['answer_type_num']) 126 | 127 | # Store config 128 | self.opt = opt 129 | 130 | def forward(self, x1, x1_c, x1_f, x1_pos, x1_ner, x1_mask, x2_full, x2_c, x2_full_mask): 131 | """Inputs: 132 | x1 = document word indices [batch * len_d] 133 | x1_c = document char indices [batch * len_d * len_w] or [1] 134 | x1_f = document word features indices [batch * q_num * len_d * nfeat] 135 | x1_pos = document POS tags [batch * len_d] 136 | x1_ner = document entity tags [batch * len_d] 137 | x1_mask = document padding mask [batch * len_d] 138 | x2_full = question word indices [batch * q_num * len_q] 139 | x2_c = question char indices [(batch * q_num) * len_q * len_w] 140 | x2_full_mask = question padding mask [batch * q_num * len_q] 141 | """ 142 | 143 | # precomputing ELMo is only for context (to speedup computation) 144 | if self.opt['use_elmo'] and self.opt['elmo_batch_size'] > self.opt['batch_size']: # precomputing ELMo is used 145 | if x1_c.dim() != 1: # precomputation is needed 146 | precomputed_bilm_output = self.elmo._elmo_lstm(x1_c) 147 | self.precomputed_layer_activations = [t.detach().cpu() for t in 
precomputed_bilm_output['activations']] 148 | self.precomputed_mask_with_bos_eos = precomputed_bilm_output['mask'].detach().cpu() 149 | self.precomputed_cnt = 0 150 | 151 | # get precomputed ELMo 152 | layer_activations = [t[x1.size(0) * self.precomputed_cnt: x1.size(0) * (self.precomputed_cnt + 1), :, :] for t in self.precomputed_layer_activations] 153 | mask_with_bos_eos = self.precomputed_mask_with_bos_eos[x1.size(0) * self.precomputed_cnt: x1.size(0) * (self.precomputed_cnt + 1), :] 154 | if x1.is_cuda: 155 | layer_activations = [t.cuda() for t in layer_activations] 156 | mask_with_bos_eos = mask_with_bos_eos.cuda() 157 | 158 | representations = [] 159 | for i in range(len(self.elmo._scalar_mixes)): 160 | scalar_mix = getattr(self.elmo, 'scalar_mix_{}'.format(i)) 161 | representation_with_bos_eos = scalar_mix(layer_activations, mask_with_bos_eos) 162 | representation_without_bos_eos, mask_without_bos_eos = remove_sentence_boundaries( 163 | representation_with_bos_eos, mask_with_bos_eos 164 | ) 165 | representations.append(self.elmo._dropout(representation_without_bos_eos)) 166 | 167 | x1_elmo = representations[0][:, :x1.size(1), :] 168 | self.precomputed_cnt += 1 169 | 170 | precomputed_elmo = True 171 | else: 172 | precomputed_elmo = False 173 | 174 | """ 175 | x1_full = document word indices [batch * q_num * len_d] 176 | x1_full_mask = document padding mask [batch * q_num * len_d] 177 | """ 178 | x1_full = x1.unsqueeze(1).expand(x2_full.size(0), x2_full.size(1), x1.size(1)).contiguous() 179 | x1_full_mask = x1_mask.unsqueeze(1).expand(x2_full.size(0), x2_full.size(1), x1.size(1)).contiguous() 180 | 181 | drnn_input_list, qrnn_input_list = [], [] 182 | 183 | x2 = x2_full.view(-1, x2_full.size(-1)) 184 | x2_mask = x2_full_mask.view(-1, x2_full.size(-1)) 185 | 186 | if self.opt['use_wemb']: 187 | # Word embedding for both document and question 188 | emb = self.embedding if self.training else self.eval_embed 189 | x1_emb = emb(x1) 190 | x2_emb = emb(x2) 191 | # Dropout on embeddings 192 | if self.opt['dropout_emb'] > 0: 193 | x1_emb = layers.dropout(x1_emb, p=self.opt['dropout_emb'], training=self.training) 194 | x2_emb = layers.dropout(x2_emb, p=self.opt['dropout_emb'], training=self.training) 195 | 196 | drnn_input_list.append(x1_emb) 197 | qrnn_input_list.append(x2_emb) 198 | 199 | if self.opt['CoVe_opt'] > 0: 200 | x1_cove_mid, x1_cove_high = self.CoVe(x1, x1_mask) 201 | x2_cove_mid, x2_cove_high = self.CoVe(x2, x2_mask) 202 | # Dropout on contexualized embeddings 203 | if self.opt['dropout_emb'] > 0: 204 | x1_cove_mid = layers.dropout(x1_cove_mid, p=self.opt['dropout_emb'], training=self.training) 205 | x1_cove_high = layers.dropout(x1_cove_high, p=self.opt['dropout_emb'], training=self.training) 206 | x2_cove_mid = layers.dropout(x2_cove_mid, p=self.opt['dropout_emb'], training=self.training) 207 | x2_cove_high = layers.dropout(x2_cove_high, p=self.opt['dropout_emb'], training=self.training) 208 | 209 | drnn_input_list.append(x1_cove_mid) 210 | qrnn_input_list.append(x2_cove_mid) 211 | 212 | if self.opt['use_elmo']: 213 | if not precomputed_elmo: 214 | x1_elmo = self.elmo(x1_c)['elmo_representations'][0]#torch.zeros(x1_emb.size(0), x1_emb.size(1), 1024, dtype=x1_emb.dtype, layout=x1_emb.layout, device=x1_emb.device) 215 | x2_elmo = self.elmo(x2_c)['elmo_representations'][0]#torch.zeros(x2_emb.size(0), x2_emb.size(1), 1024, dtype=x2_emb.dtype, layout=x2_emb.layout, device=x2_emb.device) 216 | # Dropout on contexualized embeddings 217 | if self.opt['dropout_emb'] > 0: 218 | x1_elmo = 
layers.dropout(x1_elmo, p=self.opt['dropout_emb'], training=self.training) 219 | x2_elmo = layers.dropout(x2_elmo, p=self.opt['dropout_emb'], training=self.training) 220 | 221 | drnn_input_list.append(x1_elmo) 222 | qrnn_input_list.append(x2_elmo) 223 | 224 | if self.opt['use_pos']: 225 | x1_pos_emb = self.pos_embedding(x1_pos) 226 | drnn_input_list.append(x1_pos_emb) 227 | 228 | if self.opt['use_ner']: 229 | x1_ner_emb = self.ner_embedding(x1_ner) 230 | drnn_input_list.append(x1_ner_emb) 231 | 232 | x1_input = torch.cat(drnn_input_list, dim=2) 233 | x2_input = torch.cat(qrnn_input_list, dim=2) 234 | 235 | def expansion_for_doc(z): 236 | return z.unsqueeze(1).expand(z.size(0), x2_full.size(1), z.size(1), z.size(2)).contiguous().view(-1, z.size(1), z.size(2)) 237 | 238 | x1_emb_expand = expansion_for_doc(x1_emb) 239 | x1_cove_high_expand = expansion_for_doc(x1_cove_high) 240 | #x1_elmo_expand = expansion_for_doc(x1_elmo) 241 | if self.opt['no_em']: 242 | x1_f = x1_f[:, :, :, 3:] 243 | 244 | x1_input = torch.cat([expansion_for_doc(x1_input), x1_f.view(-1, x1_f.size(-2), x1_f.size(-1))], dim=2) 245 | x1_mask = x1_full_mask.view(-1, x1_full_mask.size(-1)) 246 | 247 | if self.opt['do_prealign']: 248 | x1_atten = self.pre_align(x1_emb_expand, x2_emb, x2_mask) 249 | x1_input = torch.cat([x1_input, x1_atten], dim=2) 250 | 251 | # === Start processing the dialog === 252 | # cur_h: [batch_size * max_qa_pair, context_length, hidden_state] 253 | # flow : fn (rnn) 254 | # x1_full: [batch_size, max_qa_pair, context_length] 255 | def flow_operation(cur_h, flow): 256 | flow_in = cur_h.transpose(0, 1).view(x1_full.size(2), x1_full.size(0), x1_full.size(1), -1) 257 | flow_in = flow_in.transpose(0, 2).contiguous().view(x1_full.size(1), x1_full.size(0) * x1_full.size(2), -1).transpose(0, 1) 258 | # [bsz * context_length, max_qa_pair, hidden_state] 259 | flow_out = flow(flow_in) 260 | # [bsz * context_length, max_qa_pair, flow_hidden_state_dim (hidden_state/2)] 261 | if self.opt['no_dialog_flow']: 262 | flow_out = flow_out * 0 263 | 264 | flow_out = flow_out.transpose(0, 1).view(x1_full.size(1), x1_full.size(0), x1_full.size(2), -1).transpose(0, 2).contiguous() 265 | flow_out = flow_out.view(x1_full.size(2), x1_full.size(0) * x1_full.size(1), -1).transpose(0, 1) 266 | # [bsz * max_qa_pair, context_length, flow_hidden_state_dim] 267 | return flow_out 268 | 269 | # Encode document with RNN 270 | doc_abstr_ls = [] 271 | 272 | doc_hiddens = self.doc_rnn1(x1_input, x1_mask) 273 | doc_hiddens_flow = flow_operation(doc_hiddens, self.dialog_flow1) 274 | 275 | doc_abstr_ls.append(doc_hiddens) 276 | 277 | doc_hiddens = self.doc_rnn2(torch.cat((doc_hiddens, doc_hiddens_flow, x1_cove_high_expand), dim=2), x1_mask) 278 | doc_hiddens_flow = flow_operation(doc_hiddens, self.dialog_flow2) 279 | doc_abstr_ls.append(doc_hiddens) 280 | 281 | #with open('flow_bef_att.pkl', 'wb') as output: 282 | # pickle.dump(doc_hiddens_flow, output, pickle.HIGHEST_PROTOCOL) 283 | #while(1): 284 | # pass 285 | 286 | # Encode question with RNN 287 | _, que_abstr_ls = self.question_rnn(x2_input, x2_mask, return_list=True, additional_x=x2_cove_high) 288 | 289 | # Final question layer 290 | question_hiddens = self.high_lvl_qrnn(torch.cat(que_abstr_ls, 2), x2_mask) 291 | que_abstr_ls += [question_hiddens] 292 | 293 | # Main Attention Fusion Layer 294 | doc_info = self.deep_attn([torch.cat([x1_emb_expand, x1_cove_high_expand], 2)], doc_abstr_ls, 295 | [torch.cat([x2_emb, x2_cove_high], 2)], que_abstr_ls, x1_mask, x2_mask) 296 | 297 | doc_hiddens = 
self.deep_attn_rnn(torch.cat((doc_info, doc_hiddens_flow), dim=2), x1_mask) 298 | doc_hiddens_flow = flow_operation(doc_hiddens, self.dialog_flow3) 299 | 300 | doc_abstr_ls += [doc_hiddens] 301 | 302 | # Self Attention Fusion Layer 303 | x1_att = torch.cat(doc_abstr_ls, 2) 304 | 305 | if self.opt['self_attention_opt'] > 0: 306 | highlvl_self_attn_hiddens = self.highlvl_self_att(x1_att, x1_att, x1_mask, x3=doc_hiddens, drop_diagonal=True) 307 | doc_hiddens = self.high_lvl_crnn(torch.cat([doc_hiddens, highlvl_self_attn_hiddens, doc_hiddens_flow], dim=2), x1_mask) 308 | elif self.opt['self_attention_opt'] == 0: 309 | doc_hiddens = self.high_lvl_crnn(torch.cat([doc_hiddens, doc_hiddens_flow], dim=2), x1_mask) 310 | 311 | doc_abstr_ls += [doc_hiddens] 312 | 313 | # Merge the question hidden vectors 314 | q_merge_weights = self.self_attn(question_hiddens, x2_mask) 315 | question_avg_hidden = layers.weighted_avg(question_hiddens, q_merge_weights) 316 | if self.opt['do_hierarchical_query']: 317 | question_avg_hidden = self.hier_query_rnn(question_avg_hidden.view(x1_full.size(0), x1_full.size(1), -1)) 318 | question_avg_hidden = question_avg_hidden.contiguous().view(-1, question_avg_hidden.size(-1)) 319 | 320 | # Get Start, End span 321 | start_scores, end_scores = self.get_answer(doc_hiddens, question_avg_hidden, x1_mask) 322 | all_start_scores = start_scores.view_as(x1_full) # batch x q_num x len_d 323 | all_end_scores = end_scores.view_as(x1_full) # batch x q_num x len_d 324 | 325 | # Get whether there is an answer 326 | doc_avg_hidden = torch.cat((torch.max(doc_hiddens, dim=1)[0], torch.mean(doc_hiddens, dim=1)), dim=1) 327 | class_scores = self.ans_type_prediction(doc_avg_hidden, question_avg_hidden) 328 | all_class_scores = class_scores.view(x1_full.size(0), x1_full.size(1), -1) # batch x q_num x class_num 329 | all_class_scores = all_class_scores.squeeze(-1) # when class_num = 1 330 | 331 | return all_start_scores, all_end_scores, all_class_scores 332 | -------------------------------------------------------------------------------- /QA_model/layers.py: -------------------------------------------------------------------------------- 1 | import math 2 | import random 3 | import msgpack 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | from torch.autograd import Variable 8 | from torch.nn.parameter import Parameter 9 | from torch.nn.utils.rnn import pad_packed_sequence as unpack 10 | from torch.nn.utils.rnn import pack_padded_sequence as pack 11 | 12 | # ------------------------------------------------------------------------------ 13 | # Neural Modules 14 | # ------------------------------------------------------------------------------ 15 | 16 | def set_seq_dropout(option): # option = True or False 17 | global do_seq_dropout 18 | do_seq_dropout = option 19 | 20 | def set_my_dropout_prob(p): # p between 0 to 1 21 | global my_dropout_p 22 | my_dropout_p = p 23 | 24 | def seq_dropout(x, p=0, training=False): 25 | """ 26 | x: batch * len * input_size 27 | """ 28 | if training == False or p == 0: 29 | return x 30 | dropout_mask = 1.0 / (1-p) * torch.bernoulli((1-p) * (x.new_zeros(x.size(0), x.size(2)) + 1)) 31 | return dropout_mask.unsqueeze(1).expand_as(x) * x 32 | 33 | def dropout(x, p=0, training=False): 34 | """ 35 | x: (batch * len * input_size) or (any other shape) 36 | """ 37 | if do_seq_dropout and len(x.size()) == 3: # if x is (batch * len * input_size) 38 | return seq_dropout(x, p=p, training=training) 39 | else: 40 | return F.dropout(x, p=p, 
training=training) 41 | 42 | class StackedBRNN(nn.Module): 43 | def __init__(self, input_size, hidden_size, num_layers, rnn_type=nn.LSTM, concat_layers=False, do_residual=False, add_feat=0, dialog_flow=False, bidir=True): 44 | super(StackedBRNN, self).__init__() 45 | self.num_layers = num_layers 46 | self.concat_layers = concat_layers 47 | self.do_residual = do_residual 48 | self.dialog_flow = dialog_flow 49 | self.hidden_size = hidden_size 50 | 51 | self.rnns = nn.ModuleList() 52 | for i in range(num_layers): 53 | input_size = input_size if i == 0 else (2 * hidden_size + add_feat if i == 1 else 2 * hidden_size) 54 | if self.dialog_flow == True: 55 | input_size += 2 * hidden_size 56 | self.rnns.append(rnn_type(input_size, hidden_size,num_layers=1,bidirectional=bidir)) 57 | 58 | def forward(self, x, x_mask=None, return_list=False, additional_x=None, previous_hiddens=None): 59 | # return_list: return a list for layers of hidden vectors 60 | # Transpose batch and sequence dims 61 | x = x.transpose(0, 1) 62 | if additional_x is not None: 63 | additional_x = additional_x.transpose(0, 1) 64 | 65 | # Encode all layers 66 | hiddens = [x] 67 | for i in range(self.num_layers): 68 | rnn_input = hiddens[-1] 69 | if i == 1 and additional_x is not None: 70 | rnn_input = torch.cat((rnn_input, additional_x), 2) 71 | # Apply dropout to input 72 | if my_dropout_p > 0: 73 | rnn_input = dropout(rnn_input, p=my_dropout_p, training=self.training) 74 | if self.dialog_flow == True: 75 | if previous_hiddens is not None: 76 | dialog_memory = previous_hiddens[i-1].transpose(0, 1) 77 | else: 78 | dialog_memory = rnn_input.new_zeros((rnn_input.size(0), rnn_input.size(1), self.hidden_size * 2)) 79 | rnn_input = torch.cat((rnn_input, dialog_memory), 2) 80 | # Forward 81 | rnn_output = self.rnns[i](rnn_input)[0] 82 | if self.do_residual and i > 0: 83 | rnn_output = rnn_output + hiddens[-1] 84 | hiddens.append(rnn_output) 85 | 86 | # Transpose back 87 | hiddens = [h.transpose(0, 1) for h in hiddens] 88 | 89 | # Concat hidden layers 90 | if self.concat_layers: 91 | output = torch.cat(hiddens[1:], 2) 92 | else: 93 | output = hiddens[-1] 94 | 95 | if return_list: 96 | return output, hiddens[1:] 97 | else: 98 | return output 99 | 100 | def RNN_from_opt(input_size_, hidden_size_, opt, num_layers=-1, concat_rnn=None, add_feat=0, dialog_flow=False): 101 | RNN_TYPES = {'lstm': nn.LSTM, 'gru': nn.GRU, 'rnn': nn.RNN} 102 | new_rnn = StackedBRNN( 103 | input_size=input_size_, 104 | hidden_size=hidden_size_, 105 | num_layers=num_layers if num_layers > 0 else opt['rnn_layers'], 106 | rnn_type=RNN_TYPES[opt['rnn_type']], 107 | concat_layers=concat_rnn if concat_rnn is not None else opt['concat_rnn'], 108 | do_residual=opt['do_residual_rnn'] or opt['do_residual_everything'], 109 | add_feat=add_feat, 110 | dialog_flow=dialog_flow 111 | ) 112 | output_size = 2 * hidden_size_ 113 | if (concat_rnn if concat_rnn is not None else opt['concat_rnn']): 114 | output_size *= num_layers if num_layers > 0 else opt['rnn_layers'] 115 | return new_rnn, output_size 116 | 117 | class MemoryLasagna_Time(nn.Module): 118 | def __init__(self, input_size, hidden_size, rnn_type='lstm'): 119 | super(MemoryLasagna_Time, self).__init__() 120 | RNN_TYPES = {'lstm': nn.LSTMCell, 'gru': nn.GRUCell} 121 | 122 | self.rnn = RNN_TYPES[rnn_type](input_size, hidden_size) 123 | self.rnn_type = rnn_type 124 | self.input_size = input_size 125 | self.hidden_size = hidden_size 126 | 127 | def forward(self, x, memory): 128 | if self.training: 129 | x = x * self.dropout_mask 130 
| 131 | memory = self.rnn(x.contiguous().view(-1, x.size(-1)), memory) 132 | if self.rnn_type == 'lstm': 133 | h = memory[0].view(x.size(0), x.size(1), -1) 134 | else: 135 | h = memory.view(x.size(0), x.size(1), -1) 136 | return h, memory 137 | 138 | def get_init(self, sample_tensor): 139 | global my_dropout_p 140 | self.dropout_mask = 1.0 / (1-my_dropout_p) * torch.bernoulli((1-my_dropout_p) * (sample_tensor.new_zeros(sample_tensor.size(0), sample_tensor.size(1), self.input_size) + 1)) 141 | 142 | h = sample_tensor.new_zeros(sample_tensor.size(0), sample_tensor.size(1), self.hidden_size).float() 143 | memory = sample_tensor.new_zeros(sample_tensor.size(0) * sample_tensor.size(1), self.hidden_size).float() 144 | if self.rnn_type == 'lstm': 145 | memory = (memory, memory) 146 | return h, memory 147 | 148 | class MTLSTM(nn.Module): 149 | def __init__(self, opt, embedding=None, padding_idx=0): 150 | """Initialize an MTLSTM 151 | 152 | Arguments: 153 | embedding (Float Tensor): If not None, initialize embedding matrix with specified embedding vectors 154 | """ 155 | super(MTLSTM, self).__init__() 156 | 157 | self.embedding = nn.Embedding(opt['vocab_size'], opt['embedding_dim'], padding_idx=padding_idx) 158 | if embedding is not None: 159 | self.embedding.weight.data = embedding 160 | 161 | state_dict = torch.load(opt['MTLSTM_path']) 162 | self.rnn1 = nn.LSTM(300, 300, num_layers=1, bidirectional=True) 163 | self.rnn2 = nn.LSTM(600, 300, num_layers=1, bidirectional=True) 164 | 165 | state_dict1 = dict([(name, param.data) if isinstance(param, Parameter) else (name, param) 166 | for name, param in state_dict.items() if '0' in name]) 167 | state_dict2 = dict([(name.replace('1', '0'), param.data) if isinstance(param, Parameter) else (name.replace('1', '0'), param) 168 | for name, param in state_dict.items() if '1' in name]) 169 | self.rnn1.load_state_dict(state_dict1) 170 | self.rnn2.load_state_dict(state_dict2) 171 | 172 | for p in self.embedding.parameters(): 173 | p.requires_grad = False 174 | for p in self.rnn1.parameters(): 175 | p.requires_grad = False 176 | for p in self.rnn2.parameters(): 177 | p.requires_grad = False 178 | 179 | self.output_size = 600 180 | 181 | def setup_eval_embed(self, eval_embed, padding_idx=0): 182 | """Allow evaluation vocabulary size to be greater than training vocabulary size 183 | 184 | Arguments: 185 | eval_embed (Float Tensor): Initialize eval_embed to be the specified embedding vectors 186 | """ 187 | self.eval_embed = nn.Embedding(eval_embed.size(0), eval_embed.size(1), padding_idx = padding_idx) 188 | self.eval_embed.weight.data = eval_embed 189 | 190 | for p in self.eval_embed.parameters(): 191 | p.requires_grad = False 192 | 193 | def forward(self, x_idx, x_mask): 194 | """A pretrained MT-LSTM (McCann et. al. 2017). 195 | This LSTM was trained with 300d 840B GloVe on the WMT 2017 machine translation dataset. 196 | 197 | Arguments: 198 | x_idx (Long Tensor): a Long Tensor of size (batch * len). 199 | x_mask (Byte Tensor): a Byte Tensor of mask for the input tensor (batch * len). 
200 | """ 201 | emb = self.embedding if self.training else self.eval_embed 202 | x_hiddens = emb(x_idx) 203 | 204 | lengths = x_mask.data.eq(0).long().sum(dim=1) 205 | lens, indices = torch.sort(lengths, 0, True) 206 | 207 | output1, _ = self.rnn1(pack(x_hiddens[indices], lens.tolist(), batch_first=True)) 208 | output2, _ = self.rnn2(output1) 209 | 210 | output1 = unpack(output1, batch_first=True)[0] 211 | output2 = unpack(output2, batch_first=True)[0] 212 | 213 | _, _indices = torch.sort(indices, 0) 214 | output1 = output1[_indices] 215 | output2 = output2[_indices] 216 | 217 | return output1, output2 218 | 219 | # Attention layers 220 | class AttentionScore(nn.Module): 221 | """ 222 | sij = Relu(Wx1)DRelu(Wx2) 223 | """ 224 | def __init__(self, input_size, attention_hidden_size, similarity_score = False): 225 | super(AttentionScore, self).__init__() 226 | self.linear = nn.Linear(input_size, attention_hidden_size, bias=False) 227 | 228 | if similarity_score: 229 | self.linear_final = Parameter(torch.ones(1, 1, 1) / (attention_hidden_size ** 0.5), requires_grad = False) 230 | else: 231 | self.linear_final = Parameter(torch.ones(1, 1, attention_hidden_size), requires_grad = True) 232 | 233 | def forward(self, x1, x2): 234 | """ 235 | x1: batch * len1 * input_size 236 | x2: batch * len2 * input_size 237 | scores: batch * len1 * len2 238 | """ 239 | x1 = dropout(x1, p=my_dropout_p, training=self.training) 240 | x2 = dropout(x2, p=my_dropout_p, training=self.training) 241 | 242 | x1_rep = self.linear(x1.contiguous().view(-1, x1.size(-1))).view(x1.size(0), x1.size(1), -1) 243 | x2_rep = self.linear(x2.contiguous().view(-1, x2.size(-1))).view(x2.size(0), x2.size(1), -1) 244 | 245 | x1_rep = F.relu(x1_rep) 246 | x2_rep = F.relu(x2_rep) 247 | final_v = self.linear_final.expand_as(x2_rep) 248 | 249 | x2_rep_v = final_v * x2_rep 250 | scores = x1_rep.bmm(x2_rep_v.transpose(1, 2)) 251 | return scores 252 | 253 | 254 | class GetAttentionHiddens(nn.Module): 255 | def __init__(self, input_size, attention_hidden_size, similarity_attention = False): 256 | super(GetAttentionHiddens, self).__init__() 257 | self.scoring = AttentionScore(input_size, attention_hidden_size, similarity_score=similarity_attention) 258 | 259 | def forward(self, x1, x2, x2_mask, x3=None, scores=None, return_scores=False, drop_diagonal=False): 260 | """ 261 | Using x1, x2 to calculate attention score, but x1 will take back info from x3. 262 | If x3 is not specified, x1 will attend on x2. 
263 | 264 | x1: batch * len1 * x1_input_size 265 | x2: batch * len2 * x2_input_size 266 | x2_mask: batch * len2 267 | 268 | x3: batch * len2 * x3_input_size (or None) 269 | """ 270 | if x3 is None: 271 | x3 = x2 272 | 273 | if scores is None: 274 | scores = self.scoring(x1, x2) 275 | 276 | # Mask padding 277 | x2_mask = x2_mask.unsqueeze(1).expand_as(scores) 278 | scores.data.masked_fill_(x2_mask.data, -float('inf')) 279 | if drop_diagonal: 280 | assert(scores.size(1) == scores.size(2)) 281 | diag_mask = torch.diag(scores.data.new(scores.size(1)).zero_() + 1).byte().unsqueeze(0).expand_as(scores) 282 | scores.data.masked_fill_(diag_mask, -float('inf')) 283 | 284 | # Normalize with softmax 285 | alpha = F.softmax(scores, dim=2) 286 | 287 | # Take weighted average 288 | matched_seq = alpha.bmm(x3) 289 | if return_scores: 290 | return matched_seq, scores 291 | else: 292 | return matched_seq 293 | 294 | class DeepAttention(nn.Module): 295 | def __init__(self, opt, abstr_list_cnt, deep_att_hidden_size_per_abstr, do_similarity=False, word_hidden_size=None, do_self_attn=False, dialog_flow=False, no_rnn=False): 296 | super(DeepAttention, self).__init__() 297 | 298 | self.no_rnn = no_rnn 299 | 300 | word_hidden_size = opt['embedding_dim'] if word_hidden_size is None else word_hidden_size 301 | abstr_hidden_size = opt['hidden_size'] * 2 302 | 303 | att_size = abstr_hidden_size * abstr_list_cnt + word_hidden_size 304 | 305 | self.int_attn_list = nn.ModuleList() 306 | for i in range(abstr_list_cnt+1): 307 | self.int_attn_list.append(GetAttentionHiddens(att_size, deep_att_hidden_size_per_abstr, similarity_attention=do_similarity)) 308 | 309 | rnn_input_size = abstr_hidden_size * abstr_list_cnt * 2 + (opt['hidden_size'] * 2) 310 | 311 | self.att_final_size = rnn_input_size 312 | if not self.no_rnn: 313 | self.rnn, self.output_size = RNN_from_opt(rnn_input_size, opt['hidden_size'], opt, num_layers=1, dialog_flow=dialog_flow) 314 | #print('Deep attention x {}: Each with {} rays in {}-dim space'.format(abstr_list_cnt, deep_att_hidden_size_per_abstr, att_size)) 315 | #print('Deep attention RNN input {} -> output {}'.format(self.rnn_input_size, self.output_size)) 316 | 317 | self.opt = opt 318 | self.do_self_attn = do_self_attn 319 | 320 | def forward(self, x1_word, x1_abstr, x2_word, x2_abstr, x1_mask, x2_mask, return_bef_rnn=False, previous_hiddens=None): 321 | """ 322 | x1_word, x2_word, x1_abstr, x2_abstr are lists of 3D tensors. 323 | 3D tensor: batch_size * length * hidden_size 324 | """ 325 | # the last tensor of x2_abstr is an additional tensor 326 | x1_att = torch.cat(x1_word + x1_abstr, 2) 327 | x2_att = torch.cat(x2_word + x2_abstr[:-1], 2) 328 | x1 = torch.cat(x1_abstr, 2) 329 | 330 | x2_list = x2_abstr 331 | for i in range(len(x2_list)): 332 | attn_hiddens = self.int_attn_list[i](x1_att, x2_att, x2_mask, x3=x2_list[i], drop_diagonal=self.do_self_attn) 333 | x1 = torch.cat((x1, attn_hiddens), 2) 334 | 335 | if not self.no_rnn: 336 | x1_hiddens = self.rnn(x1, x1_mask, previous_hiddens=previous_hiddens) 337 | if return_bef_rnn: 338 | return x1_hiddens, x1 339 | else: 340 | return x1_hiddens 341 | else: 342 | return x1 343 | 344 | # For summarizing a set of vectors into a single vector 345 | class LinearSelfAttn(nn.Module): 346 | """Self attention over a sequence: 347 | * o_i = softmax(Wx_i) for x_i in X. 
348 | """ 349 | def __init__(self, input_size): 350 | super(LinearSelfAttn, self).__init__() 351 | self.linear = nn.Linear(input_size, 1) 352 | 353 | def forward(self, x, x_mask): 354 | """ 355 | x = batch * len * hdim 356 | x_mask = batch * len 357 | """ 358 | x = dropout(x, p=my_dropout_p, training=self.training) 359 | 360 | x_flat = x.contiguous().view(-1, x.size(-1)) 361 | scores = self.linear(x_flat).view(x.size(0), x.size(1)) 362 | scores.data.masked_fill_(x_mask.data, -float('inf')) 363 | alpha = F.softmax(scores, dim=1) 364 | return alpha 365 | 366 | # For attending the span in document from the query 367 | class BilinearSeqAttn(nn.Module): 368 | """A bilinear attention layer over a sequence X w.r.t y: 369 | * o_i = x_i'Wy for x_i in X. 370 | """ 371 | def __init__(self, x_size, y_size, opt, identity=False): 372 | super(BilinearSeqAttn, self).__init__() 373 | if not identity: 374 | self.linear = nn.Linear(y_size, x_size) 375 | else: 376 | self.linear = None 377 | 378 | def forward(self, x, y, x_mask): 379 | """ 380 | x = batch * len * h1 381 | y = batch * h2 382 | x_mask = batch * len 383 | """ 384 | x = dropout(x, p=my_dropout_p, training=self.training) 385 | y = dropout(y, p=my_dropout_p, training=self.training) 386 | 387 | Wy = self.linear(y) if self.linear is not None else y 388 | xWy = x.bmm(Wy.unsqueeze(2)).squeeze(2) 389 | xWy.data.masked_fill_(x_mask.data, -float('inf')) 390 | return xWy 391 | 392 | class GetSpanStartEnd(nn.Module): 393 | # supports MLP attention and GRU for pointer network updating 394 | def __init__(self, x_size, h_size, opt, do_indep_attn=True, attn_type="Bilinear", do_ptr_update=True): 395 | super(GetSpanStartEnd, self).__init__() 396 | 397 | self.attn = BilinearSeqAttn(x_size, h_size, opt) 398 | self.attn2 = BilinearSeqAttn(x_size, h_size, opt) if do_indep_attn else None 399 | 400 | self.rnn = nn.GRUCell(x_size, h_size) if do_ptr_update else None 401 | 402 | def forward(self, x, h0, x_mask): 403 | """ 404 | x = batch * len * x_size 405 | h0 = batch * h_size 406 | x_mask = batch * len 407 | """ 408 | st_scores = self.attn(x, h0, x_mask) 409 | # st_scores = batch * len 410 | 411 | if self.rnn is not None: 412 | ptr_net_in = torch.bmm(F.softmax(st_scores, dim=1).unsqueeze(1), x).squeeze(1) 413 | ptr_net_in = dropout(ptr_net_in, p=my_dropout_p, training=self.training) 414 | h0 = dropout(h0, p=my_dropout_p, training=self.training) 415 | h1 = self.rnn(ptr_net_in, h0) 416 | # h1 same size as h0 417 | else: 418 | h1 = h0 419 | 420 | end_scores = self.attn(x, h1, x_mask) if self.attn2 is None else\ 421 | self.attn2(x, h1, x_mask) 422 | # end_scores = batch * len 423 | return st_scores, end_scores 424 | 425 | class BilinearLayer(nn.Module): 426 | def __init__(self, x_size, y_size, class_num): 427 | super(BilinearLayer, self).__init__() 428 | self.linear = nn.Linear(y_size, x_size * class_num) 429 | self.class_num = class_num 430 | 431 | def forward(self, x, y): 432 | """ 433 | x = batch * h1 434 | y = batch * h2 435 | """ 436 | x = dropout(x, p=my_dropout_p, training=self.training) 437 | y = dropout(y, p=my_dropout_p, training=self.training) 438 | 439 | Wy = self.linear(y) 440 | Wy = Wy.view(Wy.size(0), self.class_num, x.size(1)) 441 | xWy = torch.sum(x.unsqueeze(1).expand_as(Wy) * Wy, dim=2) 442 | return xWy.squeeze(-1) # size = batch * class_num 443 | 444 | # ------------------------------------------------------------------------------ 445 | # Functional 446 | # ------------------------------------------------------------------------------ 447 | 448 | # by 
default in PyTorch, +-*/ are all element-wise 449 | def uniform_weights(x, x_mask): # used in lego_reader.py 450 | """Return uniform weights over non-masked input.""" 451 | alpha = Variable(torch.ones(x.size(0), x.size(1))) 452 | if x.data.is_cuda: 453 | alpha = alpha.cuda() 454 | alpha = alpha * x_mask.eq(0).float() 455 | alpha = alpha / alpha.sum(1).expand(alpha.size()) 456 | return alpha 457 | 458 | # bmm: batch matrix multiplication 459 | # unsqueeze: add singleton dimension 460 | # squeeze: remove singleton dimension 461 | def weighted_avg(x, weights): # used in lego_reader.py 462 | """ x = batch * len * d 463 | weights = batch * len 464 | """ 465 | return weights.unsqueeze(1).bmm(x).squeeze(1) 466 | -------------------------------------------------------------------------------- /general_utils.py: -------------------------------------------------------------------------------- 1 | import re 2 | import os 3 | import sys 4 | import random 5 | import string 6 | import logging 7 | import argparse 8 | import unicodedata 9 | from shutil import copyfile 10 | from datetime import datetime 11 | from collections import Counter 12 | import torch 13 | import msgpack 14 | import json 15 | import numpy as np 16 | import pandas as pd 17 | from allennlp.modules.elmo import batch_to_ids 18 | 19 | #=========================================================================== 20 | #================= All for preprocessing SQuAD data set ==================== 21 | #=========================================================================== 22 | 23 | def len_preserved_normalize_answer(s): 24 | """Lower text and remove punctuation, articles and extra whitespace.""" 25 | 26 | def len_preserved_space(matchobj): 27 | return ' ' * len(matchobj.group(0)) 28 | 29 | def remove_articles(text): 30 | return re.sub(r'\b(a|an|the)\b', len_preserved_space, text) 31 | 32 | def remove_punc(text): 33 | exclude = set(string.punctuation) 34 | return ''.join(ch if ch not in exclude else " " for ch in text) 35 | 36 | def lower(text): 37 | return text.lower() 38 | 39 | return remove_articles(remove_punc(lower(s))) 40 | 41 | def split_with_span(s): 42 | if s.split() == []: 43 | return [], [] 44 | else: 45 | return zip(*[(m.group(0), (m.start(), m.end()-1)) for m in re.finditer(r'\S+', s)]) 46 | 47 | def free_text_to_span(free_text, full_text): 48 | if free_text == "unknown": 49 | return "__NA__", -1, -1 50 | if normalize_answer(free_text) == "yes": 51 | return "__YES__", -1, -1 52 | if normalize_answer(free_text) == "no": 53 | return "__NO__", -1, -1 54 | 55 | free_ls = len_preserved_normalize_answer(free_text).split() 56 | full_ls, full_span = split_with_span(len_preserved_normalize_answer(full_text)) 57 | if full_ls == []: 58 | return full_text, 0, len(full_text) 59 | 60 | max_f1, best_index = 0.0, (0, len(full_ls)-1) 61 | free_cnt = Counter(free_ls) 62 | for i in range(len(full_ls)): 63 | full_cnt = Counter() 64 | for j in range(len(full_ls)): 65 | if i+j >= len(full_ls): break 66 | full_cnt[full_ls[i+j]] += 1 67 | 68 | common = free_cnt & full_cnt 69 | num_same = sum(common.values()) 70 | if num_same == 0: continue 71 | 72 | precision = 1.0 * num_same / (j + 1) 73 | recall = 1.0 * num_same / len(free_ls) 74 | f1 = (2 * precision * recall) / (precision + recall) 75 | 76 | if max_f1 < f1: 77 | max_f1 = f1 78 | best_index = (i, j) 79 | 80 | assert(best_index is not None) 81 | (best_i, best_j) = best_index 82 | char_i, char_j = full_span[best_i][0], full_span[best_i+best_j][1]+1 83 | 84 | return full_text[char_i:char_j], char_i, 
char_j 85 | 86 | def flatten_json(file, proc_func): 87 | with open(file, encoding="utf8") as f: 88 | data = json.load(f)['data'] 89 | rows, contexts = [], [] 90 | for i in range(len(data)): 91 | partial_rows, context = proc_func(i, data[i]) 92 | rows.extend(partial_rows) 93 | contexts.append(context) 94 | return rows, contexts 95 | 96 | def normalize_text(text): 97 | return unicodedata.normalize('NFD', text) 98 | 99 | def load_glove_vocab(file, wv_dim): 100 | vocab = set() 101 | with open(file, encoding="utf8") as f: 102 | for line in f: 103 | elems = line.split() 104 | token = normalize_text(''.join(elems[0:-wv_dim])) 105 | vocab.add(token) 106 | return vocab 107 | 108 | def space_extend(matchobj): 109 | return ' ' + matchobj.group(0) + ' ' 110 | 111 | def pre_proc(text): 112 | # make hyphens, spaces clean 113 | text = re.sub(u'-|\u2010|\u2011|\u2012|\u2013|\u2014|\u2015|%|\[|\]|:|\(|\)|/', space_extend, text) 114 | text = text.strip(' \n') 115 | text = re.sub('\s+', ' ', text) 116 | return text 117 | 118 | def feature_gen(C_docs, Q_CID, Q_docs, no_match): 119 | C_tags = [[w.tag_ for w in doc] for doc in C_docs] 120 | C_ents = [[w.ent_type_ for w in doc] for doc in C_docs] 121 | C_features = [] 122 | 123 | for question, context_id in zip(Q_docs, Q_CID): 124 | context = C_docs[context_id] 125 | 126 | counter_ = Counter(w.text.lower() for w in context) 127 | total = sum(counter_.values()) 128 | term_freq = [counter_[w.text.lower()] / total for w in context] 129 | 130 | if no_match: 131 | C_features.append(list(zip(term_freq))) 132 | else: 133 | question_word = {w.text for w in question} 134 | question_lower = {w.text.lower() for w in question} 135 | question_lemma = {w.lemma_ if w.lemma_ != '-PRON-' else w.text.lower() for w in question} 136 | match_origin = [w.text in question_word for w in context] 137 | match_lower = [w.text.lower() in question_lower for w in context] 138 | match_lemma = [(w.lemma_ if w.lemma_ != '-PRON-' else w.text.lower()) in question_lemma for w in context] 139 | C_features.append(list(zip(match_origin, match_lower, match_lemma, term_freq))) 140 | 141 | return C_tags, C_ents, C_features 142 | 143 | def get_context_span(context, context_token): 144 | p_str = 0 145 | p_token = 0 146 | t_span = [] 147 | while p_str < len(context): 148 | if re.match('\s', context[p_str]): 149 | p_str += 1 150 | continue 151 | 152 | token = context_token[p_token] 153 | token_len = len(token) 154 | if context[p_str:p_str + token_len] != token: 155 | log.info("Something wrong with get_context_span()") 156 | return [] 157 | t_span.append((p_str, p_str + token_len)) 158 | 159 | p_str += token_len 160 | p_token += 1 161 | return t_span 162 | 163 | def find_answer_span(context_span, answer_start, answer_end): 164 | if answer_start == -1 and answer_end == -1: 165 | return (-1, -1) 166 | 167 | t_start, t_end = 0, 0 168 | for token_id, (s, t) in enumerate(context_span): 169 | if s <= answer_start: 170 | t_start = token_id 171 | if t <= answer_end: 172 | t_end = token_id 173 | 174 | if t_start == -1 or t_end == -1: 175 | print(context_span, answer_start, answer_end) 176 | return (None, None) 177 | else: 178 | return (t_start, t_end) 179 | 180 | def build_embedding(embed_file, targ_vocab, wv_dim): 181 | vocab_size = len(targ_vocab) 182 | emb = np.random.uniform(-1, 1, (vocab_size, wv_dim)) 183 | emb[0] = 0 # should be all 0 (using broadcast) 184 | 185 | w2id = {w: i for i, w in enumerate(targ_vocab)} 186 | with open(embed_file, encoding="utf8") as f: 187 | for line in f: 188 | elems = line.split() 
189 | token = normalize_text(''.join(elems[0:-wv_dim])) 190 | if token in w2id: 191 | emb[w2id[token]] = [float(v) for v in elems[-wv_dim:]] 192 | return emb 193 | 194 | def token2id(docs, vocab, unk_id=None): 195 | w2id = {w: i for i, w in enumerate(vocab)} 196 | ids = [[w2id[w] if w in w2id else unk_id for w in doc] for doc in docs] 197 | return ids 198 | 199 | #=========================================================================== 200 | #================ For batch generation (train & predict) =================== 201 | #=========================================================================== 202 | 203 | class BatchGen_CoQA: 204 | def __init__(self, data, batch_size, gpu, dialog_ctx=0, evaluation=False, context_maxlen=100000, precompute_elmo=0): 205 | ''' 206 | input: 207 | data - see train.py 208 | batch_size - int 209 | ''' 210 | self.dialog_ctx = dialog_ctx 211 | self.batch_size = batch_size 212 | self.context_maxlen = context_maxlen 213 | self.precompute_elmo = precompute_elmo 214 | 215 | self.eval = evaluation 216 | self.gpu = gpu 217 | 218 | self.context_num = len(data['context']) 219 | self.question_num = len(data['qa']) 220 | self.data = data 221 | 222 | def __len__(self): 223 | return (self.context_num + self.batch_size - 1) // self.batch_size 224 | 225 | def __iter__(self): 226 | # Random permutation for the context 227 | idx_perm = range(0, self.context_num) 228 | if not self.eval: 229 | idx_perm = np.random.permutation(idx_perm) 230 | 231 | batch_size = self.batch_size 232 | for batch_i in range((self.context_num + self.batch_size - 1) // self.batch_size): 233 | 234 | batch_idx = idx_perm[self.batch_size * batch_i: self.batch_size * (batch_i+1)] 235 | 236 | context_batch = [self.data['context'][i] for i in batch_idx] 237 | batch_size = len(context_batch) 238 | 239 | context_batch = list(zip(*context_batch)) 240 | 241 | # Process Context Tokens 242 | context_len = max(len(x) for x in context_batch[0]) 243 | if not self.eval: 244 | context_len = min(context_len, self.context_maxlen) 245 | context_id = torch.LongTensor(batch_size, context_len).fill_(0) 246 | for i, doc in enumerate(context_batch[0]): 247 | select_len = min(len(doc), context_len) 248 | context_id[i, :select_len] = torch.LongTensor(doc[:select_len]) 249 | 250 | # Process Context POS Tags 251 | context_tag = torch.LongTensor(batch_size, context_len).fill_(0) 252 | for i, doc in enumerate(context_batch[1]): 253 | select_len = min(len(doc), context_len) 254 | context_tag[i, :select_len] = torch.LongTensor(doc[:select_len]) 255 | 256 | # Process Context Named Entity 257 | context_ent = torch.LongTensor(batch_size, context_len).fill_(0) 258 | for i, doc in enumerate(context_batch[2]): 259 | select_len = min(len(doc), context_len) 260 | context_ent[i, :select_len] = torch.LongTensor(doc[:select_len]) 261 | 262 | if self.precompute_elmo > 0: 263 | if batch_i % self.precompute_elmo == 0: 264 | precompute_idx = idx_perm[self.batch_size * batch_i: self.batch_size * (batch_i+self.precompute_elmo)] 265 | elmo_tokens = [self.data['context'][i][6] for i in precompute_idx] 266 | context_cid = batch_to_ids(elmo_tokens) 267 | else: 268 | context_cid = torch.LongTensor(1).fill_(0) 269 | else: 270 | context_cid = batch_to_ids(context_batch[6]) 271 | 272 | # Process Questions (number = batch * Qseq) 273 | qa_data = self.data['qa'] 274 | 275 | question_num, question_len = 0, 0 276 | question_batch = [] 277 | for first_QID in context_batch[5]: 278 | i, question_seq = 0, [] 279 | while True: 280 | if first_QID + i >= len(qa_data) 
or qa_data[first_QID + i][0] != qa_data[first_QID][0]: # their corresponding context ID is different 281 | break 282 | question_seq.append(first_QID + i) 283 | question_len = max(question_len, len(qa_data[first_QID + i][1])) 284 | i += 1 285 | question_batch.append(question_seq) 286 | question_num = max(question_num, i) 287 | 288 | question_id = torch.LongTensor(batch_size, question_num, question_len).fill_(0) 289 | question_tokens = [] 290 | for i, q_seq in enumerate(question_batch): 291 | for j, id in enumerate(q_seq): 292 | doc = qa_data[id][1] 293 | question_id[i, j, :len(doc)] = torch.LongTensor(doc) 294 | question_tokens.append(qa_data[id][10]) 295 | 296 | for j in range(len(q_seq), question_num): 297 | question_id[i, j, :2] = torch.LongTensor([2, 3]) 298 | question_tokens.append(["", ""]) 299 | 300 | question_cid = batch_to_ids(question_tokens) 301 | 302 | # Process Context-Question Features 303 | feature_len = len(qa_data[0][2][0]) 304 | context_feature = torch.Tensor(batch_size, question_num, context_len, feature_len + (self.dialog_ctx * 3)).fill_(0) 305 | for i, q_seq in enumerate(question_batch): 306 | for j, id in enumerate(q_seq): 307 | doc = qa_data[id][2] 308 | select_len = min(len(doc), context_len) 309 | context_feature[i, j, :select_len, :feature_len] = torch.Tensor(doc[:select_len]) 310 | 311 | for prv_ctx in range(0, self.dialog_ctx): 312 | if j > prv_ctx: 313 | prv_id = id - prv_ctx - 1 314 | prv_ans_st, prv_ans_end, prv_rat_st, prv_rat_end, prv_ans_choice = qa_data[prv_id][3], qa_data[prv_id][4], qa_data[prv_id][5], qa_data[prv_id][6], qa_data[prv_id][7] 315 | 316 | if prv_ans_choice == 3: 317 | # There is an answer 318 | for k in range(prv_ans_st, prv_ans_end + 1): 319 | if k >= context_len: 320 | break 321 | context_feature[i, j, k, feature_len + prv_ctx * 3 + 1] = 1 322 | else: 323 | context_feature[i, j, :select_len, feature_len + prv_ctx * 3 + 2] = 1 324 | 325 | # Process Answer (w/ raw question, answer text) 326 | answer_s = torch.LongTensor(batch_size, question_num).fill_(0) 327 | answer_e = torch.LongTensor(batch_size, question_num).fill_(0) 328 | rationale_s = torch.LongTensor(batch_size, question_num).fill_(0) 329 | rationale_e = torch.LongTensor(batch_size, question_num).fill_(0) 330 | answer_c = torch.LongTensor(batch_size, question_num).fill_(0) 331 | overall_mask = torch.ByteTensor(batch_size, question_num).fill_(0) 332 | question, answer = [], [] 333 | for i, q_seq in enumerate(question_batch): 334 | question_pack, answer_pack = [], [] 335 | for j, id in enumerate(q_seq): 336 | answer_s[i, j], answer_e[i, j], rationale_s[i, j], rationale_e[i, j], answer_c[i, j] = qa_data[id][3], qa_data[id][4], qa_data[id][5], qa_data[id][6], qa_data[id][7] 337 | overall_mask[i, j] = 1 338 | question_pack.append(qa_data[id][8]) 339 | answer_pack.append(qa_data[id][9]) 340 | question.append(question_pack) 341 | answer.append(answer_pack) 342 | 343 | # Process Masks 344 | context_mask = torch.eq(context_id, 0) 345 | question_mask = torch.eq(question_id, 0) 346 | 347 | text = list(context_batch[3]) # raw text 348 | span = list(context_batch[4]) # character span for each words 349 | 350 | if self.gpu: # page locked memory for async data transfer 351 | context_id = context_id.pin_memory() 352 | context_feature = context_feature.pin_memory() 353 | context_tag = context_tag.pin_memory() 354 | context_ent = context_ent.pin_memory() 355 | context_mask = context_mask.pin_memory() 356 | question_id = question_id.pin_memory() 357 | question_mask = question_mask.pin_memory() 358 | 
answer_s = answer_s.pin_memory() 359 | answer_e = answer_e.pin_memory() 360 | rationale_s = rationale_s.pin_memory() 361 | rationale_e = rationale_e.pin_memory() 362 | answer_c = answer_c.pin_memory() 363 | overall_mask = overall_mask.pin_memory() 364 | context_cid = context_cid.pin_memory() 365 | question_cid = question_cid.pin_memory() 366 | 367 | yield (context_id, context_cid, context_feature, context_tag, context_ent, context_mask, 368 | question_id, question_cid, question_mask, overall_mask, 369 | answer_s, answer_e, answer_c, rationale_s, rationale_e, 370 | text, span, question, answer) 371 | 372 | class BatchGen_QuAC: 373 | def __init__(self, data, batch_size, gpu, dialog_ctx=0, use_dialog_act=False, evaluation=False, context_maxlen=100000, precompute_elmo=0): 374 | ''' 375 | input: 376 | data - see train.py 377 | batch_size - int 378 | ''' 379 | self.dialog_ctx = dialog_ctx 380 | self.use_dialog_act = use_dialog_act 381 | self.batch_size = batch_size 382 | self.context_maxlen = context_maxlen 383 | self.precompute_elmo = precompute_elmo 384 | 385 | self.eval = evaluation 386 | self.gpu = gpu 387 | 388 | self.context_num = len(data['context']) 389 | self.question_num = len(data['qa']) 390 | self.data = data 391 | 392 | def __len__(self): 393 | return (self.context_num + self.batch_size - 1) // self.batch_size 394 | 395 | def __iter__(self): 396 | # Random permutation for the context 397 | idx_perm = range(0, self.context_num) 398 | if not self.eval: 399 | idx_perm = np.random.permutation(idx_perm) 400 | 401 | batch_size = self.batch_size 402 | for batch_i in range((self.context_num + self.batch_size - 1) // self.batch_size): 403 | 404 | batch_idx = idx_perm[self.batch_size * batch_i: self.batch_size * (batch_i+1)] 405 | 406 | context_batch = [self.data['context'][i] for i in batch_idx] 407 | batch_size = len(context_batch) 408 | 409 | context_batch = list(zip(*context_batch)) 410 | 411 | # Process Context Tokens 412 | context_len = max(len(x) for x in context_batch[0]) 413 | if not self.eval: 414 | context_len = min(context_len, self.context_maxlen) 415 | context_id = torch.LongTensor(batch_size, context_len).fill_(0) 416 | for i, doc in enumerate(context_batch[0]): 417 | select_len = min(len(doc), context_len) 418 | context_id[i, :select_len] = torch.LongTensor(doc[:select_len]) 419 | 420 | # Process Context POS Tags 421 | context_tag = torch.LongTensor(batch_size, context_len).fill_(0) 422 | for i, doc in enumerate(context_batch[1]): 423 | select_len = min(len(doc), context_len) 424 | context_tag[i, :select_len] = torch.LongTensor(doc[:select_len]) 425 | 426 | # Process Context Named Entity 427 | context_ent = torch.LongTensor(batch_size, context_len).fill_(0) 428 | for i, doc in enumerate(context_batch[2]): 429 | select_len = min(len(doc), context_len) 430 | context_ent[i, :select_len] = torch.LongTensor(doc[:select_len]) 431 | 432 | if self.precompute_elmo > 0: 433 | if batch_i % self.precompute_elmo == 0: 434 | precompute_idx = idx_perm[self.batch_size * batch_i: self.batch_size * (batch_i+self.precompute_elmo)] 435 | elmo_tokens = [self.data['context'][i][6] for i in precompute_idx] 436 | context_cid = batch_to_ids(elmo_tokens) 437 | else: 438 | context_cid = torch.LongTensor(1).fill_(0) 439 | else: 440 | context_cid = batch_to_ids(context_batch[6]) 441 | 442 | # Process Questions (number = batch * Qseq) 443 | qa_data = self.data['qa'] 444 | 445 | question_num, question_len = 0, 0 446 | question_batch = [] 447 | for first_QID in context_batch[5]: 448 | i, question_seq = 0, [] 
449 | while True: 450 | if first_QID + i >= len(qa_data) or qa_data[first_QID + i][0] != qa_data[first_QID][0]: # their corresponding context ID is different 451 | break 452 | question_seq.append(first_QID + i) 453 | question_len = max(question_len, len(qa_data[first_QID + i][1])) 454 | i += 1 455 | question_batch.append(question_seq) 456 | question_num = max(question_num, i) 457 | 458 | question_id = torch.LongTensor(batch_size, question_num, question_len).fill_(0) 459 | question_tokens = [] 460 | for i, q_seq in enumerate(question_batch): 461 | for j, id in enumerate(q_seq): 462 | doc = qa_data[id][1] 463 | question_id[i, j, :len(doc)] = torch.LongTensor(doc) 464 | question_tokens.append(qa_data[id][8]) 465 | 466 | for j in range(len(q_seq), question_num): 467 | question_id[i, j, :2] = torch.LongTensor([2, 3]) 468 | question_tokens.append(["", ""]) 469 | 470 | question_cid = batch_to_ids(question_tokens) 471 | 472 | # Process Context-Question Features 473 | feature_len = len(qa_data[0][2][0]) 474 | context_feature = torch.Tensor(batch_size, question_num, context_len, feature_len + (self.dialog_ctx * (self.use_dialog_act*3+2))).fill_(0) 475 | for i, q_seq in enumerate(question_batch): 476 | for j, id in enumerate(q_seq): 477 | doc = qa_data[id][2] 478 | select_len = min(len(doc), context_len) 479 | context_feature[i, j, :select_len, :feature_len] = torch.Tensor(doc[:select_len]) 480 | 481 | for prv_ctx in range(0, self.dialog_ctx): 482 | if j > prv_ctx: 483 | prv_id = id - prv_ctx - 1 484 | prv_ans_st, prv_ans_end, prv_ans_choice = qa_data[prv_id][3], qa_data[prv_id][4], qa_data[prv_id][5] 485 | 486 | # dialog act: don't follow-up, follow-up, maybe follow-up (prv_ans_choice // 10) 487 | if self.use_dialog_act: 488 | context_feature[i, j, :select_len, feature_len + prv_ctx * (self.use_dialog_act*3+2) + 2 + (prv_ans_choice // 10)] = 1 489 | 490 | if prv_ans_choice == 0: # indicating that the previous reply is NO ANSWER 491 | context_feature[i, j, :select_len, feature_len + prv_ctx * (self.use_dialog_act*3+2) + 1] = 1 492 | continue 493 | 494 | # There is an answer 495 | for k in range(prv_ans_st, prv_ans_end + 1): 496 | if k >= context_len: 497 | break 498 | context_feature[i, j, k, feature_len + prv_ctx * (self.use_dialog_act*3+2)] = 1 499 | 500 | # Process Answer (w/ raw question, answer text) 501 | answer_s = torch.LongTensor(batch_size, question_num).fill_(0) 502 | answer_e = torch.LongTensor(batch_size, question_num).fill_(0) 503 | answer_c = torch.LongTensor(batch_size, question_num).fill_(0) 504 | overall_mask = torch.ByteTensor(batch_size, question_num).fill_(0) 505 | question, answer = [], [] 506 | for i, q_seq in enumerate(question_batch): 507 | question_pack, answer_pack = [], [] 508 | for j, id in enumerate(q_seq): 509 | answer_s[i, j], answer_e[i, j], answer_c[i, j] = qa_data[id][3], qa_data[id][4], qa_data[id][5] 510 | overall_mask[i, j] = 1 511 | question_pack.append(qa_data[id][6]) 512 | answer_pack.append(qa_data[id][7]) 513 | question.append(question_pack) 514 | answer.append(answer_pack) 515 | 516 | # Process Masks 517 | context_mask = torch.eq(context_id, 0) 518 | question_mask = torch.eq(question_id, 0) 519 | 520 | text = list(context_batch[3]) # raw text 521 | span = list(context_batch[4]) # character span for each words 522 | 523 | if self.gpu: # page locked memory for async data transfer 524 | context_id = context_id.pin_memory() 525 | context_feature = context_feature.pin_memory() 526 | context_tag = context_tag.pin_memory() 527 | context_ent = 
context_ent.pin_memory() 528 | context_mask = context_mask.pin_memory() 529 | question_id = question_id.pin_memory() 530 | question_mask = question_mask.pin_memory() 531 | answer_s = answer_s.pin_memory() 532 | answer_e = answer_e.pin_memory() 533 | answer_c = answer_c.pin_memory() 534 | overall_mask = overall_mask.pin_memory() 535 | context_cid = context_cid.pin_memory() 536 | question_cid = question_cid.pin_memory() 537 | 538 | yield (context_id, context_cid, context_feature, context_tag, context_ent, context_mask, 539 | question_id, question_cid, question_mask, overall_mask, 540 | answer_s, answer_e, answer_c, 541 | text, span, question, answer) 542 | 543 | #=========================================================================== 544 | #========================== For QuAC evaluation ============================ 545 | #=========================================================================== 546 | 547 | def normalize_answer(s): 548 | """Lower text and remove punctuation, articles and extra whitespace.""" 549 | def remove_articles(text): 550 | return re.sub(r'\b(a|an|the)\b', ' ', text) 551 | 552 | def white_space_fix(text): 553 | return ' '.join(text.split()) 554 | 555 | def remove_punc(text): 556 | exclude = set(string.punctuation) 557 | return ''.join(ch for ch in text if ch not in exclude) 558 | 559 | def lower(text): 560 | return text.lower() 561 | 562 | return white_space_fix(remove_articles(remove_punc(lower(s)))) 563 | 564 | def f1_score(prediction, ground_truth): 565 | prediction_tokens = normalize_answer(prediction).split() 566 | ground_truth_tokens = normalize_answer(ground_truth).split() 567 | common = Counter(prediction_tokens) & Counter(ground_truth_tokens) 568 | num_same = sum(common.values()) 569 | if num_same == 0: 570 | return 0 571 | precision = 1.0 * num_same / len(prediction_tokens) 572 | recall = 1.0 * num_same / len(ground_truth_tokens) 573 | f1 = (2 * precision * recall) / (precision + recall) 574 | return f1 575 | 576 | def single_score(prediction, ground_truth): 577 | if prediction == "CANNOTANSWER" and ground_truth == "CANNOTANSWER": 578 | return 1.0 579 | elif prediction == "CANNOTANSWER" or ground_truth == "CANNOTANSWER": 580 | return 0.0 581 | else: 582 | return f1_score(prediction, ground_truth) 583 | 584 | def handle_cannot(refs): 585 | num_cannot = 0 586 | num_spans = 0 587 | for ref in refs: 588 | if ref == 'CANNOTANSWER': num_cannot += 1 589 | else: num_spans += 1 590 | 591 | if num_cannot >= num_spans: 592 | refs = ['CANNOTANSWER'] 593 | else: 594 | refs = [x for x in refs if x != 'CANNOTANSWER'] 595 | return refs 596 | 597 | def leave_one_out(refs): 598 | if len(refs) == 1: 599 | return 1.0 600 | 601 | t_f1 = 0.0 602 | for i in range(len(refs)): 603 | m_f1 = 0 604 | new_refs = refs[:i] + refs[i+1:] 605 | 606 | for j in range(len(new_refs)): 607 | f1_ij = single_score(refs[i], new_refs[j]) 608 | 609 | if f1_ij > m_f1: 610 | m_f1 = f1_ij 611 | t_f1 += m_f1 612 | 613 | return t_f1 / len(refs) 614 | 615 | def leave_one_out_max(prediction, ground_truths): 616 | scores_for_ground_truths = [] 617 | for ground_truth in ground_truths: 618 | scores_for_ground_truths.append(single_score(prediction, ground_truth)) 619 | 620 | if len(scores_for_ground_truths) == 1: 621 | return scores_for_ground_truths[0] 622 | else: 623 | # leave out one ref every time 624 | t_f1 = [] 625 | for i in range(len(scores_for_ground_truths)): 626 | t_f1.append(max(scores_for_ground_truths[:i] + scores_for_ground_truths[i+1:])) 627 | return 1.0 * sum(t_f1) / len(t_f1) 628 | 629 | 
def find_best_score_and_thresh(pred, truth, no_ans_score, min_F1=0.4): 630 | pred = [p for dialog_p in pred for p in dialog_p] 631 | truth = [t for dialog_t in truth for t in dialog_t] 632 | no_ans_score = [n for dialog_n in no_ans_score for n in dialog_n] 633 | 634 | clean_pred, clean_truth, clean_noans = [], [], [] 635 | 636 | all_f1 = [] 637 | for p, t, n in zip(pred, truth, no_ans_score): 638 | clean_t = handle_cannot(t) 639 | human_F1 = leave_one_out(clean_t) 640 | if human_F1 < min_F1: continue 641 | 642 | clean_pred.append(p) 643 | clean_truth.append(clean_t) 644 | clean_noans.append(n) 645 | all_f1.append(leave_one_out_max(p, clean_t)) 646 | 647 | cur_f1, best_f1 = sum(all_f1), sum(all_f1) 648 | best_thresh = max(clean_noans) + 1 649 | 650 | cur_noans, best_noans, noans_cnt = 0, 0, 0 651 | sort_idx = sorted(range(len(clean_noans)), key=lambda k: clean_noans[k], reverse=True) 652 | for i in sort_idx: 653 | if clean_truth[i] == ['CANNOTANSWER']: 654 | cur_f1 += 1 655 | cur_noans += 1 656 | noans_cnt += 1 657 | else: 658 | cur_f1 -= all_f1[i] 659 | cur_noans -= 1 660 | 661 | if cur_f1 > best_f1: 662 | best_f1 = cur_f1 663 | best_noans = cur_noans 664 | best_thresh = clean_noans[i] - 1e-7 665 | 666 | return 100.0 * best_f1 / len(clean_pred), 100.0 * (len(clean_pred) - noans_cnt + best_noans) / len(clean_pred), best_thresh 667 | 668 | def score(model_results, human_results, min_F1=0.4): 669 | Q_at_least_human, total_Qs = 0.0, 0.0 670 | D_at_least_human, total_Ds = 0.0, 0.0 671 | total_machine_f1, total_human_f1 = 0.0, 0.0 672 | 673 | assert len(human_results) == len(model_results) 674 | for human_diag_ans, model_diag_ans in zip(human_results, model_results): 675 | good_dialog = 1.0 676 | 677 | assert len(human_diag_ans) == len(model_diag_ans) 678 | for human_ans, model_ans in zip(human_diag_ans, model_diag_ans): 679 | # model_ans is (text, choice) 680 | # human_ans is a list of (text, choice) 681 | 682 | # human_ans[0] is the original dialog answer 683 | clean_human_ans = handle_cannot(human_ans) 684 | human_F1 = leave_one_out(clean_human_ans) 685 | 686 | if human_F1 < min_F1: continue 687 | 688 | machine_f1 = leave_one_out_max(model_ans, clean_human_ans) 689 | total_machine_f1 += machine_f1 690 | total_human_f1 += human_F1 691 | 692 | if machine_f1 >= human_F1: 693 | Q_at_least_human += 1.0 694 | else: 695 | good_dialog = 0.0 696 | total_Qs += 1.0 697 | 698 | D_at_least_human += good_dialog 699 | total_Ds += 1.0 700 | 701 | return 100.0 * total_machine_f1 / total_Qs, 100.0 * total_human_f1 / total_Qs, 100.0 * Q_at_least_human / total_Qs, 100.0 * D_at_least_human / total_Ds 702 | --------------------------------------------------------------------------------
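The QuAC-style evaluation helpers defined above in `general_utils.py` (`handle_cannot`, `leave_one_out`, `leave_one_out_max`) can be exercised on their own. The snippet below is a minimal usage sketch, not part of the original repository: it assumes `general_utils.py` is importable (i.e. the packages from `requirements.txt` are installed) and uses made-up toy answers purely for illustration.

```python
# Toy illustration of the leave-one-out F1 scheme used by score() above.
# The references and prediction here are invented for this example.
from general_utils import handle_cannot, leave_one_out, leave_one_out_max

refs = ["in the Grand Canyon", "the Grand Canyon", "CANNOTANSWER"]
pred = "Grand Canyon"

clean_refs = handle_cannot(refs)                    # drop CANNOTANSWER when span answers dominate
human_f1 = leave_one_out(clean_refs)                # human agreement: each reference vs. the rest
machine_f1 = leave_one_out_max(pred, clean_refs)    # prediction vs. held-out references

# In score(), a question only counts toward the metric when human_f1 >= min_F1 (default 0.4).
print(human_f1, machine_f1)
```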