├── .dockerignore ├── .gitignore ├── README.md ├── docker ├── Dockerfile └── wrap.sh ├── download.sh ├── editor_model ├── base.py └── double.py ├── evaluator.py ├── inference.py ├── list_exp.py ├── metric.py ├── model ├── __init__.py ├── base.py ├── editor.py ├── entail.py ├── retrieve.py └── span.py ├── preprocess_editor.py ├── preprocess_sharc.py ├── requirements.txt ├── train_editor.py └── train_sharc.py /.dockerignore: -------------------------------------------------------------------------------- 1 | cache/ 2 | editor_save 3 | editor_save/**/* 4 | save 5 | save/**/* 6 | pred 7 | pred/**/* 8 | .git 9 | .git/**/* 10 | sharc 11 | sharc/**/* 12 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | *.py[cod] 2 | *.txt 3 | *.json* 4 | *.pt 5 | cache/ 6 | *save/ 7 | embeddings/ 8 | stanfordnlp_resources/ 9 | node_modules/ 10 | sharc/ 11 | trained/ 12 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # E3: Entailment-driven Extracting and Editing for Conversational Machine Reading 2 | 3 | This repository contains the source code for the paper [E3: Entailment-driven Extracting and Editing for Conversational Machine Reading](https://arxiv.org/abs/1906.05373). 4 | This work was published at ACL 2019. 5 | If you find the paper or this repository helpful in your work, please use the following citation: 6 | 7 | ``` 8 | @inproceedings{ zhong2019e3, 9 | title={ E3: Entailment-driven Extracting and Editing for Conversational Machine Reading }, 10 | author={ Zhong, Victor and Zettlemoyer, Luke }, 11 | booktitle={ ACL }, 12 | year={ 2019 } 13 | } 14 | ``` 15 | 16 | The output results from this codebase have minor differences from those reported in the paper due to library versions. 17 | The most consistent way to replicate the experiments is via the Docker instructions. 18 | Once ran, inference on the dev set should produce something like the following: 19 | 20 | ``` 21 | {'bleu_1': 0.6714, 22 | 'bleu_2': 0.6059, 23 | 'bleu_3': 0.5646, 24 | 'bleu_4': 0.5367, 25 | 'combined': 0.39372312, 26 | 'macro_accuracy': 0.7336, 27 | 'micro_accuracy': 0.6802} 28 | ``` 29 | 30 | In any event, the model binaries used for our submission are included in the `/opt/save` directory of the docker image `vzhong/e3`. 31 | For correspondence, please contact [Victor Zhong](mailto://victor@victorzhong.com). 32 | 33 | 34 | ## Non-Docker instructions 35 | 36 | If you have docker, scroll down to the (much shorter) docker instructions. 37 | 38 | 39 | ### Setup 40 | 41 | First we will install the dependencies required. 42 | 43 | ```bash 44 | pip install -r requirements.txt 45 | ``` 46 | 47 | Next we'll download the pretrained BERT parameters and vocabulary, word embeddings, and Stanford NLP. 48 | This is a big download ~10GB. 49 | 50 | ```bash 51 | # StanfordNLP, BERT, and ShARC data 52 | ./download.sh 53 | 54 | # Spacy data for evaluator 55 | python -m spacy download en_core_web_md 56 | 57 | # word embeddings 58 | python -c "import embeddings as e; e.GloveEmbedding()" 59 | python -c "import embeddings as e; e.KazumaCharEmbedding()" 60 | ``` 61 | 62 | 63 | ### Training 64 | 65 | The E3 model is trained in two parts due to data imbalance (there are many more turn examples than full dialogue trees). 66 | The first part consists of everything except for the editor. 
67 | The second part trains the editor alone, because it relies on unique dialogue trees, of which there are few compared to the total number of turn examples. 68 | We start by preprocessing the data. 69 | This command will print out some statistics from preprocessing the train/dev sets. 70 | 71 | ```bash 72 | ./preprocess_sharc.py 73 | ``` 74 | 75 | Now, we will train the model without the editor. 76 | With a Titan-X, this takes roughly 2 hours to complete. 77 | For more options, check out `python train_sharc.py --help` 78 | 79 | ```bash 80 | CUDA_VISIBLE_DEVICES=0 python train_sharc.py 81 | ``` 82 | 83 | Now, we will train the editor. 84 | Again, with a Titan-X, this takes roughly 20 minutes to complete. 85 | For more options, check out `python train_editor.py --help` 86 | 87 | ```bash 88 | ./preprocess_editor_sharc.py 89 | CUDA_VISIBLE_DEVICES=0 python train_editor.py 90 | ``` 91 | 92 | To evaluate the models, run `inference.py`. 93 | For more options, check out `python inference.py --help` 94 | 95 | ```bash 96 | CUDA_VISIBLE_DEVICES=0 python inference.py --retrieval save/default-entail/best.pt --editor editor_save/default-double/best.pt --verify 97 | ``` 98 | 99 | If you want to tune the models, you can also use `list_exp.py` to visualize the experiment results. 100 | The model ablations from our paper are found in the `model` directory. 101 | Namely, `base` is the BERTQA model (referred to in the paper as `E3-{edit,entail,extract}`), `retrieve` is the `E3-{edit,entail}` model, and `entail` is the `E3-{edit}` model. 102 | You can choose amongst these models using the `--model` flag in `train_sharc.py`. 103 | 104 | 105 | ## Docker instructions 106 | 107 | If you have `docker` (and `nvidia-docker`), then there is no need to install dependencies. 108 | You still need to clone this repo and run `download.sh`. 109 | For convenience, I've made a wrapper script that pass through your username and mounts the current directory. 110 | From inside the directory, to preprocess and train the model: 111 | 112 | ```bash 113 | docker/wrap.sh python preprocess_sharc.py 114 | NV_GPU=0 docker/wrap.sh python train_sharc.py 115 | docker/wrap.sh python preprocess_editor.py 116 | NV_GPU=0 docker/wrap.sh python train_editor.py 117 | ``` 118 | 119 | To evaluate the model and dump predictions in an output folder: 120 | 121 | ```bash 122 | NV_GPU=0 docker/wrap.sh python inference.py --retrieval save/default-entail/best.pt --editor editor_save/default-double/best.pt --verify 123 | ``` 124 | 125 | To reproduce our submission results with the included model binaries: 126 | 127 | ```bash 128 | NV_GPU=0 docker/wrap.sh python inference.py --retrieval /opt/save/retrieval.pt --editor /opt/save/editor.pt --verify 129 | ``` 130 | -------------------------------------------------------------------------------- /docker/Dockerfile: -------------------------------------------------------------------------------- 1 | FROM anibali/pytorch:cuda-8.0 2 | 3 | RUN sudo apt-get update --quiet 4 | RUN sudo apt-get install gcc -yq 5 | 6 | RUN whoami 7 | ENV HOME=/home/user 8 | 9 | RUN sudo mkdir -p /opt/code 10 | RUN sudo chown -R user /opt/code 11 | WORKDIR /opt/code 12 | 13 | COPY requirements.txt . 
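# The source tree itself is never copied into the image: docker/wrap.sh bind-mounts it over
# /opt/code at runtime. Copying requirements.txt on its own keeps the pip install layer below
# cached across code changes.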
14 | RUN pip install -r requirements.txt 15 | 16 | # download data files 17 | RUN mkdir cache 18 | RUN python -m spacy download en_core_web_md 19 | RUN python -c "import embeddings as e; e.GloveEmbedding()" 20 | RUN python -c "import embeddings as e; e.KazumaCharEmbedding()" 21 | 22 | RUN sudo apt-get install wget -yq 23 | RUN wget --quiet https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased.tar.gz -O cache/bert-base-uncased.tar.gz 24 | RUN wget --quiet https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-vocab.txt -O cache/bert-base-uncased-vocab.txt 25 | 26 | # add some models 27 | RUN sudo mkdir -p /opt/save 28 | RUN sudo chown -R user /opt/save 29 | COPY trained/* /opt/save/ 30 | -------------------------------------------------------------------------------- /docker/wrap.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | set -x 3 | nvidia-docker run --rm \ 4 | -v $PWD:/opt/code \ 5 | -u $(id -u $USER):$(id -g $USER) \ 6 | vzhong/e3 \ 7 | $@ 8 | -------------------------------------------------------------------------------- /download.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | mkdir -p cache 3 | 4 | # BERT parameters 5 | wget https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased.tar.gz -O cache/bert-base-uncased.tar.gz 6 | wget https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-vocab.txt -O cache/bert-base-uncased-vocab.txt 7 | 8 | # Stanford NLP data 9 | echo 'Y' | python -c "import stanfordnlp; stanfordnlp.download('en', resource_dir='cache')" 10 | 11 | # NOTE: the following lines are no longer necessary because the data files have been included in the repo - vzhong 12 | wget https://sharc-data.github.io/data/sharc1-official.zip 13 | unzip sharc1-official.zip 14 | rm sharc1-official.zip 15 | mv sharc1-official sharc 16 | 17 | -------------------------------------------------------------------------------- /editor_model/base.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import importlib 4 | from torch.nn import functional as F 5 | from torch.nn.utils.rnn import pad_sequence 6 | from model.base import Module as Base 7 | from model.editor import Decoder 8 | from metric import compute_f1 9 | from preprocess_sharc import detokenize 10 | 11 | 12 | class Module(Base): 13 | 14 | def __init__(self, args, vocab=None): 15 | super().__init__(args) 16 | self.denc = self.args.bert_hidden_size 17 | vocab = vocab or torch.load(os.path.join(args.data, 'vocab.pt')) 18 | self.emb = vocab['emb'] 19 | self.vocab = vocab['vocab'] 20 | self.decoder = Decoder(self.denc, self.emb, dropout=self.args.dropout) 21 | 22 | @classmethod 23 | def load_module(cls, name): 24 | return importlib.import_module('editor_model.{}'.format(name)).Module 25 | 26 | def create_input_tensors(self, batch): 27 | feat = { 28 | k: torch.stack([e[k] for e in batch], dim=0).to(self.device) 29 | for k in ['inp_ids', 'type_ids', 'inp_mask'] 30 | } 31 | feat['inp_mask'] = feat['inp_mask'].float() 32 | feat['out_vids'] = pad_sequence([e['out_vids'] for e in batch], batch_first=True, padding_value=-1).to(self.device) if self.training else None 33 | return feat 34 | 35 | def forward(self, batch): 36 | out = self.create_input_tensors(batch) 37 | out['bert_enc'], _ = bert_enc, _ = self.bert(out['inp_ids'], out['type_ids'], out['inp_mask'], output_all_encoded_layers=False) 
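# The BERT encoding of "[CLS] span [SEP] snippet [SEP]" conditions the attentional LSTM decoder
# (model.editor.Decoder). During training the gold question ids (out_vids) provide teacher
# forcing; at inference the decoder greedily emits up to 30 tokens and extract_preds truncates
# the output at the first 'eos'.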
38 | out['dec'] = self.decoder.forward(bert_enc, out['inp_mask'], out['out_vids'], max_decode_len=30) 39 | return out 40 | 41 | def extract_preds(self, out, batch): 42 | preds = [] 43 | for pred, ex in zip(out['dec'].max(2)[1].tolist(), batch): 44 | pred = self.vocab.index2word(pred) 45 | if 'eos' in pred: 46 | pred = pred[:pred.index('eos')] 47 | preds.append({ 48 | 'utterance_id': ex['utterance_id'], 49 | 'answer': ' '.join(pred), 50 | }) 51 | return preds 52 | 53 | def compute_loss(self, out, batch): 54 | return {'dec': F.cross_entropy(out['dec'].view(-1, len(self.vocab)), out['out_vids'].view(-1), ignore_index=-1)} 55 | 56 | def compute_metrics(self, preds, batch): 57 | f1s = [compute_f1(p['answer'], detokenize(e['question'])) for p, e in zip(preds, batch)] 58 | return {'f1': sum(f1s) / len(f1s)} 59 | -------------------------------------------------------------------------------- /editor_model/double.py: -------------------------------------------------------------------------------- 1 | from torch.nn import functional as F 2 | from torch.nn.utils.rnn import pad_sequence 3 | from editor_model.base import Module as Base 4 | from model.editor import Decoder 5 | from preprocess_sharc import detokenize 6 | 7 | 8 | class Module(Base): 9 | 10 | def __init__(self, args, vocab=None): 11 | super().__init__(args, vocab=vocab) 12 | self.decoder_after = Decoder(self.denc, self.emb, dropout=self.args.dropout) 13 | 14 | def create_input_tensors(self, batch): 15 | feat = super().create_input_tensors(batch) 16 | feat['before_vids'] = pad_sequence([e['before_vids'] for e in batch], batch_first=True, padding_value=-1).to(self.device) if self.training else None 17 | feat['after_vids'] = pad_sequence([e['after_vids'] for e in batch], batch_first=True, padding_value=-1).to(self.device) if self.training else None 18 | return feat 19 | 20 | def forward(self, batch): 21 | out = self.create_input_tensors(batch) 22 | out['bert_enc'], _ = bert_enc, _ = self.bert(out['inp_ids'], out['type_ids'], out['inp_mask'], output_all_encoded_layers=False) 23 | out['before'] = self.decoder.forward(bert_enc, out['inp_mask'], out['before_vids'], max_decode_len=10) 24 | out['after'] = self.decoder_after.forward(bert_enc, out['inp_mask'], out['after_vids'], max_decode_len=10) 25 | return out 26 | 27 | def extract_preds(self, out, batch): 28 | preds = [] 29 | for before, after, ex in zip( 30 | out['before'].max(2)[1].tolist(), 31 | out['after'].max(2)[1].tolist(), 32 | batch): 33 | before = self.vocab.index2word(before) 34 | if 'eos' in before: 35 | before = before[:before.index('eos')] 36 | after = self.vocab.index2word(after) 37 | if 'eos' in after: 38 | after = after[:after.index('eos')] 39 | s, e = ex['span'] 40 | middle = detokenize(ex['inp'][s:e+1]) 41 | preds.append({ 42 | 'utterance_id': ex['utterance_id'], 43 | 'answer': '{} {} {}'.format(' '.join(before), middle, ' '.join(after)), 44 | }) 45 | return preds 46 | 47 | def compute_loss(self, out, batch): 48 | return { 49 | 'before': F.cross_entropy(out['before'].view(-1, len(self.vocab)), out['before_vids'].view(-1), ignore_index=-1), 50 | 'after': F.cross_entropy(out['after'].view(-1, len(self.vocab)), out['after_vids'].view(-1), ignore_index=-1), 51 | } 52 | -------------------------------------------------------------------------------- /evaluator.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import json 4 | import collections 5 | import math 6 | import numpy as np 7 | from sklearn.metrics import 
accuracy_score, confusion_matrix 8 | import spacy 9 | 10 | 11 | nlp = spacy.load('en_core_web_md') 12 | 13 | 14 | class ClassificationEvaluator: 15 | def __init__(self, labels=None): 16 | self.labels = labels 17 | 18 | def evaluate(self, y_true, y_pred): 19 | if not self.labels: 20 | self.labels = list(set(y_true)) 21 | 22 | # micro_accuracy = sum([y_t == y_p for y_t, y_p in zip(y_true, y_pred)]) / len(y_true) 23 | micro_accuracy = accuracy_score(y_true, y_pred) 24 | results = {} 25 | results["micro_accuracy"] = float("{0:.4f}".format(micro_accuracy)) #int(100 * micro_accuracy) / 100 26 | 27 | conf_mat = confusion_matrix(y_true, y_pred, labels=self.labels) 28 | conf_mat_norm = conf_mat.astype('float') / conf_mat.sum(axis=1)[:, np.newaxis] 29 | macro_accuracy = np.mean([conf_mat_norm[i][i] for i in range(conf_mat_norm.shape[0])]) 30 | results["macro_accuracy"] = float("{0:.4f}".format(macro_accuracy)) #int(100 * macro_accuracy) / 100 31 | return results 32 | 33 | 34 | class MoreEvaluator: 35 | def __init__(self, max_bleu_order=4, bleu_smoothing=True): 36 | self.max_bleu_order = max_bleu_order 37 | self.bleu_smoothing = bleu_smoothing 38 | 39 | def evaluate(self, y_true, y_pred): 40 | results = {} 41 | bleu_scores = [compute_bleu([[y.split()] for y in y_true], [y.split() for y in y_pred], 42 | max_order=bleu_order, smooth=self.bleu_smoothing)[0] 43 | for bleu_order in range(1, self.max_bleu_order + 1)] 44 | 45 | for bleu_order, bleu_score in enumerate(bleu_scores): 46 | results["bleu_" + str(bleu_order + 1)] = float("{0:.4f}".format(bleu_score)) 47 | return results 48 | 49 | 50 | class CombinedEvaluator: 51 | def __init__(self, labels=['yes', 'no', 'more', 'irrelevant'], accuracy_targets=['yes', 'no', 'irrelevant']): 52 | self.labels = labels 53 | self.accuracy_targets = accuracy_targets 54 | self.classification_evaluator = ClassificationEvaluator(labels=labels) 55 | self.more_evaluator = MoreEvaluator() 56 | 57 | def replace_follow_up_with_more(self, y_list): 58 | return [y.lower() if y.lower() in self.accuracy_targets else 'more' for y in y_list] 59 | 60 | def extract_follow_ups(self, y_true, y_pred): 61 | extracted = [(y_t, y_p) for (y_t, y_p) in zip(y_true, y_pred) if 62 | y_t.lower() not in self.labels and y_p.lower() not in self.labels] 63 | if extracted: 64 | return zip(*extracted) 65 | else: 66 | return [], [] 67 | 68 | def evaluate(self, y_true, y_pred): 69 | 70 | # Classification 71 | classification_y_true = self.replace_follow_up_with_more(y_true) 72 | classification_y_pred = self.replace_follow_up_with_more(y_pred) 73 | results = self.classification_evaluator.evaluate(classification_y_true, classification_y_pred) 74 | 75 | # Follow Up Generation 76 | num_true_follow_ups = len([y_t for y_t in y_true if y_t.lower() not in self.labels]) 77 | num_pred_follow_ups = len([y_p for y_p in y_pred if y_p.lower() not in self.labels]) 78 | # print(f'{num_true_follow_ups} follow-ups in ground truth. {num_pred_follow_ups} follow-ups predicted | {len(generation_y_true)} follow-up questions used for BLEU evaluation.') 79 | generation_y_true, generation_y_pred = self.extract_follow_ups(y_true, y_pred) 80 | if generation_y_true and generation_y_pred: 81 | results.update(self.more_evaluator.evaluate(generation_y_true, generation_y_pred)) 82 | else: 83 | results.update({'bleu_{}'.format(i): 0. 
for i in range(1, 5)}) 84 | return results 85 | 86 | 87 | def prepro(text): 88 | doc = nlp(text, disable=['parser', 'tagger', 'ner']) 89 | result = "" 90 | for token in doc: 91 | orth = token.text 92 | if orth == "": 93 | result += " " 94 | elif orth == " ": 95 | result += " " 96 | else: 97 | result += orth.lower() + " " 98 | return result.strip().replace('\n', '') 99 | 100 | 101 | def _get_ngrams(segment, max_order): 102 | """Extracts all n-grams upto a given maximum order from an input segment. 103 | 104 | Args: 105 | segment: text segment from which n-grams will be extracted. 106 | max_order: maximum length in tokens of the n-grams returned by this 107 | methods. 108 | 109 | Returns: 110 | The Counter containing all n-grams upto max_order in segment 111 | with a count of how many times each n-gram occurred. 112 | """ 113 | ngram_counts = collections.Counter() 114 | for order in range(1, max_order + 1): 115 | for i in range(0, len(segment) - order + 1): 116 | ngram = tuple(segment[i:i + order]) 117 | ngram_counts[ngram] += 1 118 | return ngram_counts 119 | 120 | 121 | def compute_bleu(reference_corpus, translation_corpus, max_order=4, 122 | smooth=False): 123 | """Computes BLEU score of translated segments against one or more references. 124 | 125 | Args: 126 | reference_corpus: list of lists of references for each translation. Each 127 | reference should be tokenized into a list of tokens. 128 | translation_corpus: list of translations to score. Each translation 129 | should be tokenized into a list of tokens. 130 | max_order: Maximum n-gram order to use when computing BLEU score. 131 | smooth: Whether or not to apply Lin et al. 2004 smoothing. 132 | 133 | Returns: 134 | 3-Tuple with the BLEU score, n-gram precisions, geometric mean of n-gram 135 | precisions and brevity penalty. 136 | """ 137 | matches_by_order = [0] * max_order 138 | possible_matches_by_order = [0] * max_order 139 | reference_length = 0 140 | translation_length = 0 141 | for (references, translation) in zip(reference_corpus, 142 | translation_corpus): 143 | reference_length += min(len(r) for r in references) 144 | translation_length += len(translation) 145 | 146 | merged_ref_ngram_counts = collections.Counter() 147 | for reference in references: 148 | merged_ref_ngram_counts |= _get_ngrams(reference, max_order) 149 | translation_ngram_counts = _get_ngrams(translation, max_order) 150 | overlap = translation_ngram_counts & merged_ref_ngram_counts 151 | for ngram in overlap: 152 | matches_by_order[len(ngram) - 1] += overlap[ngram] 153 | for order in range(1, max_order + 1): 154 | possible_matches = len(translation) - order + 1 155 | if possible_matches > 0: 156 | possible_matches_by_order[order - 1] += possible_matches 157 | 158 | precisions = [0] * max_order 159 | for i in range(0, max_order): 160 | if smooth: 161 | precisions[i] = ((matches_by_order[i] + 1.) / 162 | (possible_matches_by_order[i] + 1.)) 163 | else: 164 | if possible_matches_by_order[i] > 0: 165 | precisions[i] = (float(matches_by_order[i]) / 166 | possible_matches_by_order[i]) 167 | else: 168 | precisions[i] = 0.0 169 | 170 | if min(precisions) > 0: 171 | p_log_sum = sum((1. / max_order) * math.log(p) for p in precisions) 172 | geo_mean = math.exp(p_log_sum) 173 | else: 174 | geo_mean = 0 175 | 176 | ratio = float(translation_length) / reference_length 177 | 178 | if ratio > 1.0: 179 | bp = 1. 180 | else: 181 | bp = math.exp(1 - 1. 
/ ratio) 182 | 183 | bleu = geo_mean * bp 184 | 185 | return (bleu, precisions, bp, ratio, translation_length, reference_length) 186 | 187 | 188 | def evaluate(gold_file, prediction_file, mode='follow_ups'): 189 | assert mode in ['', 'combined', 'follow_ups', 'classification'], "Mode not recognised" 190 | 191 | with open(gold_file, 'r') as f: 192 | ground_truths = json.load(f) 193 | 194 | with open(prediction_file, 'r') as f: 195 | predictions = json.load(f) 196 | 197 | # Check if all IDs are aligned 198 | # assert len(ground_truths) == len(predictions), "Predictions and ground truths have different sample sizes" 199 | 200 | ground_truth_map = {g["utterance_id"]: g for g in ground_truths} 201 | predictions_map = {p["utterance_id"]: p for p in predictions} 202 | for k in ground_truth_map: 203 | if k not in predictions_map: 204 | predictions_map[k] = {'utterance_id': k, 'answer': 'missing'} 205 | 206 | for gid in ground_truth_map: 207 | assert gid in predictions_map 208 | 209 | # Extract answers and prepro 210 | 211 | ground_truths = [] 212 | predictions = [] 213 | 214 | for uid in ground_truth_map.keys(): 215 | ground_truths.append(prepro(ground_truth_map[uid]['answer'])) 216 | predictions.append(prepro(predictions_map[uid]['answer'])) 217 | 218 | if mode == 'follow_ups': 219 | evaluator = MoreEvaluator() 220 | results = evaluator.evaluate(ground_truths, predictions) 221 | 222 | elif mode == 'classification': 223 | evaluator = ClassificationEvaluator(labels=['yes', 'no', 'more', 'irrelevant']) 224 | results = evaluator.evaluate(ground_truths, predictions) 225 | 226 | else: 227 | evaluator = CombinedEvaluator(labels=['yes', 'no', 'more', 'irrelevant']) 228 | results = evaluator.evaluate(ground_truths, predictions) 229 | 230 | return results 231 | 232 | 233 | if __name__ == '__main__': 234 | mode = 'combined' 235 | 236 | prediction_file = sys.argv[1] 237 | gold_file = sys.argv[2] 238 | 239 | results = evaluate(gold_file, prediction_file, mode=mode) 240 | print(results) 241 | -------------------------------------------------------------------------------- /inference.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import json 4 | from pprint import pprint 5 | from argparse import ArgumentParser 6 | from model.base import Module 7 | from preprocess_sharc import tokenize, make_tag, convert_to_ids, MAX_LEN, compute_metrics 8 | from editor_model.base import Module as EditorModule 9 | from preprocess_editor import trim_span 10 | 11 | 12 | def preprocess(data): 13 | for ex in data: 14 | ex['ann'] = a = { 15 | 'snippet': tokenize(ex['snippet']), 16 | 'question': tokenize(ex['question']), 17 | 'scenario': tokenize(ex['scenario']), 18 | 'hanswer': [{'yes': 1, 'no': 0}[h['follow_up_answer'].lower()] for h in ex['history']], 19 | 'hquestion': [tokenize(h['follow_up_question']) for h in ex['history']], 20 | } 21 | inp = [make_tag('[CLS]')] + a['question'] 22 | type_ids = [0] * len(inp) 23 | sep = make_tag('[SEP]') 24 | pointer_mask = [0] * len(inp) 25 | inp += [ 26 | sep, 27 | make_tag('classes'), 28 | make_tag('yes'), 29 | make_tag('no'), 30 | make_tag('irrelevant'), 31 | sep, 32 | make_tag('document'), 33 | ] 34 | pointer_mask += [0, 0, 1, 1, 1, 0, 0] 35 | snippet_start = len(inp) 36 | offset = len(inp) 37 | inp += a['snippet'] 38 | snippet_end = len(inp) 39 | pointer_mask += [1] * len(a['snippet']) # where can the answer pointer land 40 | inp += [sep] 41 | start = len(inp) 42 | inp += [make_tag('scenario')] + a['scenario'] + [sep] 
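# At this point inp holds: [CLS] question [SEP] classes yes no irrelevant [SEP] document
# <snippet> [SEP] scenario <scenario> [SEP]. The offsets recorded around each segment
# (scen_offsets here, hist_offsets below) are stored in ex['feat'] so the positions of the
# scenario and each history turn remain recoverable after padding/truncation to MAX_LEN.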
43 | end = len(inp) 44 | scen_offsets = start, end 45 | inp += [make_tag('history')] 46 | hist_offsets = [] 47 | for hq, ha in zip(a['hquestion'], a['hanswer']): 48 | start = len(inp) 49 | inp += [make_tag('question')] + hq + [make_tag('answer'), [make_tag('yes'), make_tag('no')][ha]] 50 | end = len(inp) 51 | hist_offsets.append((start, end)) 52 | inp += [sep] 53 | type_ids += [1] * (len(inp) - len(type_ids)) 54 | input_ids = convert_to_ids(inp) 55 | input_mask = [1] * len(inp) 56 | pointer_mask += [0] * (len(inp) - len(pointer_mask)) 57 | 58 | if len(inp) > MAX_LEN: 59 | inp = inp[:MAX_LEN] 60 | input_mask = input_mask[:MAX_LEN] 61 | type_ids = type_ids[:MAX_LEN] 62 | input_ids = input_ids[:MAX_LEN] 63 | pointer_mask = pointer_mask[:MAX_LEN] 64 | pad = make_tag('pad') 65 | while len(inp) < MAX_LEN: 66 | inp.append(pad) 67 | input_mask.append(0) 68 | type_ids.append(0) 69 | input_ids.append(0) 70 | pointer_mask.append(0) 71 | 72 | assert len(inp) == len(input_mask) == len(type_ids) == len(input_ids) 73 | 74 | ex['feat'] = { 75 | 'inp': inp, 76 | 'snippet_start': snippet_start, 77 | 'snippet_end': snippet_end, 78 | 'input_ids': torch.LongTensor(input_ids), 79 | 'type_ids': torch.LongTensor(type_ids), 80 | 'input_mask': torch.LongTensor(input_mask), 81 | 'pointer_mask': torch.LongTensor(pointer_mask), 82 | 'hanswer': a['hanswer'], 83 | 'snippet_offset': offset, 84 | 'scen_offsets': scen_offsets, 85 | 'hist_offsets': hist_offsets, 86 | } 87 | return data 88 | 89 | 90 | def preprocess_editor(orig_data, preds): 91 | data = [] 92 | for orig_ex, pred in zip(orig_data, preds): 93 | if pred['answer'].lower() not in {'yes', 'no', 'irrelevant'}: 94 | s, e = pred['spans'][pred['retrieve_span']] 95 | sstart = orig_ex['feat']['snippet_start'] 96 | send = orig_ex['feat']['snippet_end'] 97 | s -= sstart 98 | e -= sstart 99 | if s < 0 or e < 0: 100 | continue 101 | snippet = orig_ex['feat']['inp'][sstart:send] 102 | s, e = trim_span(snippet, (s, e)) 103 | if e >= s: 104 | inp = [make_tag('[CLS]')] + snippet[s:e+1] + [make_tag('[SEP]')] 105 | # account for prepended tokens 106 | new_s, new_e = s + len(inp), e + len(inp) 107 | inp += snippet + [make_tag('[SEP]')] 108 | type_ids = [0] + [0] * (e+1-s) + [1] * (len(snippet) + 2) 109 | inp_ids = convert_to_ids(inp) 110 | inp_mask = [1] * len(inp) 111 | 112 | assert len(type_ids) == len(inp) == len(inp_ids) 113 | 114 | while len(inp_ids) < MAX_LEN: 115 | inp.append(make_tag('pad')) 116 | inp_ids.append(0) 117 | inp_mask.append(0) 118 | type_ids.append(0) 119 | 120 | if len(inp_ids) > MAX_LEN: 121 | inp = inp[:MAX_LEN] 122 | inp_ids = inp_ids[:MAX_LEN] 123 | inp_mask = inp_mask[:MAX_LEN] 124 | inp_mask[-1] = make_tag('[SEP]') 125 | type_ids = type_ids[:MAX_LEN] 126 | 127 | ex = { 128 | 'utterance_id': orig_ex['utterance_id'], 129 | 'span': (new_s, new_e), 130 | 'inp': inp, 131 | 'type_ids': torch.tensor(type_ids, dtype=torch.long), 132 | 'inp_ids': torch.tensor(inp_ids, dtype=torch.long), 133 | 'inp_mask': torch.tensor(inp_mask, dtype=torch.long), 134 | } 135 | data.append(ex) 136 | return data 137 | 138 | 139 | def merge_edits(preds, edits): 140 | # note: this happens in place 141 | edits = {p['utterance_id']: p for p in edits} 142 | for p in preds: 143 | p['orig_answer'] = p['answer'] 144 | if p['utterance_id'] in edits: 145 | p['answer'] = p['edit_answer'] = edits[p['utterance_id']]['answer'] 146 | return preds 147 | 148 | 149 | if __name__ == '__main__': 150 | parser = ArgumentParser() 151 | parser.add_argument('--retrieval', required=True, help='retrieval 
model to use') 152 | parser.add_argument('--editor', help='editor model to use (optional)') 153 | parser.add_argument('--fin', default='sharc/json/sharc_dev.json', help='input data file') 154 | parser.add_argument('--dout', default=os.getcwd(), help='directory to store output files') 155 | parser.add_argument('--data', default='sharc/editor_disjoint', help='editor data') 156 | parser.add_argument('--verify', action='store_true', help='run evaluation') 157 | parser.add_argument('--force', action='store_true', help='overwrite retrieval predictions') 158 | args = parser.parse_args() 159 | 160 | if not os.path.isdir(args.dout): 161 | os.makedirs(args.dout) 162 | 163 | with open(args.fin) as f: 164 | raw = json.load(f) 165 | 166 | print('preprocessing data') 167 | data = preprocess(raw) 168 | 169 | fretrieval = os.path.join(args.dout, 'retrieval_preds.json') 170 | if os.path.isfile(fretrieval) and not args.force: 171 | print('loading {}'.format(fretrieval)) 172 | with open(fretrieval) as f: 173 | retrieval_preds = json.load(f) 174 | else: 175 | print('resuming retrieval from ' + args.retrieval) 176 | retrieval = Module.load(args.retrieval) 177 | retrieval.to(retrieval.device) 178 | retrieval_preds = retrieval.run_pred(data) 179 | with open(fretrieval, 'wt') as f: 180 | json.dump(retrieval_preds, f, indent=2) 181 | 182 | if args.verify: 183 | pprint(compute_metrics(retrieval_preds, raw)) 184 | 185 | if args.editor: 186 | editor_data = preprocess_editor(data, retrieval_preds) 187 | editor = EditorModule.load(args.editor, override_args={'data': args.data}) 188 | editor.to(editor.device) 189 | raw_editor_preds = editor.run_pred(editor_data) 190 | editor_preds = merge_edits(retrieval_preds, raw_editor_preds) 191 | 192 | with open(os.path.join(args.dout, 'editor_preds.json'), 'wt') as f: 193 | json.dump(editor_preds, f, indent=2) 194 | 195 | if args.verify: 196 | pprint(compute_metrics(editor_preds, raw)) 197 | -------------------------------------------------------------------------------- /list_exp.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | import json 3 | import os 4 | import torch 5 | import tabulate 6 | from argparse import ArgumentParser 7 | 8 | 9 | if __name__ == '__main__': 10 | parser = ArgumentParser() 11 | parser.add_argument('--editor', action='store_true') 12 | parser.add_argument('--dsave', default='save') 13 | parser.add_argument('--force', '-f', action='store_true') 14 | args = parser.parse_args() 15 | 16 | rows = [] 17 | keys = ['epoch', 'dev_combined', 'dev_macro_accuracy', 'dev_micro_accuracy', 'dev_bleu_1', 'dev_bleu_4', 'dev_span_f1'] 18 | columns = ['name', 'epoch', 'combined', 'macro', 'micro', 'bleu1', 'bleu4', 'span_f1'] 19 | early = 'combined' 20 | if args.editor: 21 | keys = ['epoch', 'dev_f1'] 22 | columns = ['name', 'epoch', 'f1'] 23 | early = 'f1' 24 | for root, dirs, files in os.walk(args.dsave): 25 | if 'best.pt' in files: 26 | fbest = os.path.join(root, 'best.pt') 27 | fbest_json = os.path.join(root, 'best.json') 28 | if not os.path.isfile(fbest_json) or args.force: 29 | with open(fbest_json, 'wt') as f: 30 | metrics = torch.load(fbest, map_location='cpu')['metrics'] 31 | json.dump(metrics, f, indent=2) 32 | with open(fbest_json) as f: 33 | metrics = json.load(f) 34 | 35 | rows.append([root] + [metrics.get(k, 0) for k in keys]) 36 | rows.sort(key=lambda r: r[columns.index(early)]) 37 | print(tabulate.tabulate(rows, headers=columns)) 38 | 
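# Typical usage, assuming the save directories referenced in the README:
#   python list_exp.py --dsave save                    # extraction/entailment runs
#   python list_exp.py --editor --dsave editor_save    # editor runs, keyed on dev F1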
-------------------------------------------------------------------------------- /metric.py: -------------------------------------------------------------------------------- 1 | import re 2 | import string 3 | import collections 4 | 5 | 6 | def normalize_answer(s): 7 | """Lower text and remove punctuation, articles and extra whitespace.""" 8 | 9 | def remove_articles(text): 10 | regex = re.compile(r'\b(a|an|the)\b', re.UNICODE) 11 | return re.sub(regex, ' ', text) 12 | 13 | def white_space_fix(text): 14 | return ' '.join(text.split()) 15 | 16 | def remove_punc(text): 17 | exclude = set(string.punctuation) 18 | return ''.join(ch for ch in text if ch not in exclude) 19 | 20 | def lower(text): 21 | return text.lower() 22 | 23 | return white_space_fix(remove_articles(remove_punc(lower(s)))) 24 | 25 | 26 | def get_tokens(s): 27 | if not s: 28 | return [] 29 | return normalize_answer(s).split() 30 | 31 | 32 | def compute_exact(a_gold, a_pred): 33 | return int(normalize_answer(a_gold) == normalize_answer(a_pred)) 34 | 35 | 36 | def compute_f1(a_gold, a_pred): 37 | gold_toks = get_tokens(a_gold) 38 | pred_toks = get_tokens(a_pred) 39 | common = collections.Counter(gold_toks) & collections.Counter(pred_toks) 40 | num_same = sum(common.values()) 41 | if len(gold_toks) == 0 or len(pred_toks) == 0: 42 | # If either is no-answer, then F1 is 1 if they agree, 0 otherwise 43 | return int(gold_toks == pred_toks) 44 | if num_same == 0: 45 | return 0 46 | precision = 1.0 * num_same / len(pred_toks) 47 | recall = 1.0 * num_same / len(gold_toks) 48 | f1 = (2 * precision * recall) / (precision + recall) 49 | return f1 50 | -------------------------------------------------------------------------------- /model/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vzhong/e3/0c6b771b27463427db274802c4417355ddd90ed7/model/__init__.py -------------------------------------------------------------------------------- /model/base.py: -------------------------------------------------------------------------------- 1 | import os 2 | import shutil 3 | import torch 4 | import logging 5 | import importlib 6 | import numpy as np 7 | import json 8 | from tqdm import trange 9 | from pprint import pformat 10 | from collections import defaultdict 11 | from torch import nn 12 | from torch.nn import functional as F 13 | from preprocess_sharc import detokenize, compute_metrics, BERT_MODEL 14 | from pytorch_pretrained_bert import BertModel, BertAdam 15 | from argparse import Namespace 16 | 17 | 18 | def warmup_linear(x, warmup=0.002): 19 | if x < warmup: 20 | return x/warmup 21 | return 1.0 - x 22 | 23 | 24 | DEVICE = torch.device('cpu') 25 | if torch.cuda.is_available() and torch.cuda.device_count(): 26 | DEVICE = torch.device('cuda') 27 | torch.cuda.manual_seed_all(0) 28 | 29 | 30 | class Module(nn.Module): 31 | 32 | def __init__(self, args, device=DEVICE): 33 | super().__init__() 34 | self.args = args 35 | self.device = device 36 | self.bert = BertModel.from_pretrained(BERT_MODEL, cache_dir=None) 37 | self.dropout = nn.Dropout(self.args.dropout) 38 | self.ans_scorer = nn.Linear(self.args.bert_hidden_size, 2) 39 | self.epoch = 0 40 | 41 | @classmethod 42 | def load_module(cls, name): 43 | return importlib.import_module('model.{}'.format(name)).Module 44 | 45 | @classmethod 46 | def load(cls, fname, override_args=None): 47 | load = torch.load(fname, map_location=lambda storage, loc: storage) 48 | args = vars(load['args']) 49 | if override_args: 50 | 
args.update(override_args) 51 | args = Namespace(**args) 52 | model = cls.load_module(args.model)(args) 53 | model.load_state_dict(load['state']) 54 | return model 55 | 56 | def save(self, metrics, dsave, early_stop): 57 | files = [os.path.join(dsave, f) for f in os.listdir(dsave) if f.endswith('.pt') and f != 'best.pt'] 58 | files.sort(key=lambda x: os.path.getmtime(x), reverse=True) 59 | if len(files) > self.args.keep-1: 60 | for f in files[self.args.keep-1:]: 61 | os.remove(f) 62 | 63 | fsave = os.path.join(dsave, 'ep{}-{}.pt'.format(metrics['epoch'], metrics[early_stop])) 64 | torch.save({ 65 | 'args': self.args, 66 | 'state': self.state_dict(), 67 | 'metrics': metrics, 68 | }, fsave) 69 | fbest = os.path.join(dsave, 'best.pt') 70 | if os.path.isfile(fbest): 71 | os.remove(fbest) 72 | shutil.copy(fsave, fbest) 73 | 74 | def create_input_tensors(self, batch): 75 | feat = { 76 | k: torch.stack([e['feat'][k] for e in batch], dim=0).to(self.device) 77 | for k in ['input_ids', 'type_ids', 'input_mask', 'pointer_mask'] 78 | } 79 | # for ex in batch: 80 | # s = ex['feat']['answer_start'] 81 | # e = ex['feat']['answer_end'] 82 | # print(s, ex['feat']['pointer_mask'][s]) 83 | # print(e, ex['feat']['pointer_mask'][e]) 84 | # import pdb; pdb.set_trace() 85 | return feat 86 | 87 | def score(self, enc): 88 | return self.ans_scorer(enc) 89 | 90 | def forward(self, batch): 91 | out = self.create_input_tensors(batch) 92 | out['bert_enc'], _ = bert_enc, _ = self.bert(out['input_ids'], out['type_ids'], out['input_mask'], output_all_encoded_layers=False) 93 | scores = self.score(self.dropout(bert_enc)) 94 | out['scores'] = self.mask_scores(scores, out['pointer_mask']) 95 | return out 96 | 97 | def mask_scores(self, scores, mask): 98 | invalid = 1 - mask 99 | scores -= invalid.unsqueeze(2).expand_as(scores).float().mul(1e20) 100 | return scores 101 | 102 | def get_top_k(self, probs, k): 103 | p = list(enumerate(probs.tolist())) 104 | p.sort(key=lambda tup: tup[1], reverse=True) 105 | return p[:k] 106 | 107 | def extract_preds(self, out, batch, top_k=20): 108 | scores = out['scores'] 109 | ystart, yend = scores.split(1, dim=-1) 110 | pstart = F.softmax(ystart.squeeze(-1), dim=1) 111 | pend = F.softmax(yend.squeeze(-1), dim=1) 112 | 113 | preds = [] 114 | for pstart_i, pend_i, ex in zip(pstart, pend, batch): 115 | top_start = self.get_top_k(pstart_i, top_k) 116 | top_end = self.get_top_k(pend_i, top_k) 117 | top_preds = [] 118 | for s, ps in top_start: 119 | for e, pe in top_end: 120 | if e >= s: 121 | top_preds.append((s, e, ps*pe)) 122 | top_preds = sorted(top_preds, key=lambda tup: tup[-1], reverse=True)[:top_k] 123 | top_answers = [(detokenize(ex['feat']['inp'][s:e+1]), s, e, p) for s, e, p in top_preds] 124 | top_ans, top_s, top_e, top_p = top_answers[0] 125 | preds.append({ 126 | 'utterance_id': ex['utterance_id'], 127 | 'top_k': top_answers, 128 | 'answer': top_ans, 129 | 'spans': [(top_s, top_e)], 130 | 'retrieve_span': 0, 131 | }) 132 | return preds 133 | 134 | def compute_loss(self, out, batch): 135 | scores = out['scores'] 136 | ystart, yend = scores.split(1, dim=-1) 137 | 138 | gstart = torch.tensor([e['feat']['answer_start'] for e in batch], dtype=torch.long, device=self.device) 139 | lstart = F.cross_entropy(ystart.squeeze(-1), gstart) 140 | 141 | gend = torch.tensor([e['feat']['answer_end'] for e in batch], dtype=torch.long, device=self.device) 142 | lend = F.cross_entropy(yend.squeeze(-1), gend) 143 | return {'start': lstart, 'end': lend} 144 | 145 | def compute_metrics(self, preds, data): 
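# Keep only the single top-ranked answer string per example and defer to the shared ShARC
# scorer (preprocess_sharc.compute_metrics), which produces the micro/macro accuracy and
# BLEU numbers reported in the README.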
146 | preds = [{'utterance_id': p['utterance_id'], 'answer': p['top_k'][0][0]} for p in preds] 147 | return compute_metrics(preds, data) 148 | 149 | def run_pred(self, dev): 150 | preds = [] 151 | self.eval() 152 | for i in trange(0, len(dev), self.args.dev_batch, desc='batch'): 153 | batch = dev[i:i+self.args.dev_batch] 154 | out = self(batch) 155 | preds += self.extract_preds(out, batch) 156 | return preds 157 | 158 | def run_train(self, train, dev): 159 | if not os.path.isdir(self.args.dsave): 160 | os.makedirs(self.args.dsave) 161 | 162 | logger = logging.getLogger(self.__class__.__name__) 163 | logger.setLevel(logging.DEBUG) 164 | fh = logging.FileHandler(os.path.join(self.args.dsave, 'train.log')) 165 | fh.setLevel(logging.CRITICAL) 166 | logger.addHandler(fh) 167 | ch = logging.StreamHandler() 168 | ch.setLevel(logging.CRITICAL) 169 | logger.addHandler(ch) 170 | 171 | num_train_steps = int(len(train) / self.args.train_batch * self.args.epoch) 172 | 173 | # remove pooler 174 | param_optimizer = list(self.named_parameters()) 175 | param_optimizer = [n for n in param_optimizer if 'pooler' not in n[0]] 176 | 177 | no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight'] 178 | optimizer_grouped_parameters = [ 179 | {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01}, 180 | {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}, 181 | ] 182 | 183 | optimizer = BertAdam(optimizer_grouped_parameters, lr=self.args.learning_rate, warmup=self.args.warmup, t_total=num_train_steps) 184 | 185 | print('num_train', len(train)) 186 | print('num_dev', len(dev)) 187 | 188 | global_step = 0 189 | best_metrics = {self.args.early_stop: -float('inf')} 190 | for epoch in trange(self.args.epoch, desc='epoch'): 191 | self.epoch = epoch 192 | train = train[:] 193 | np.random.shuffle(train) 194 | 195 | stats = defaultdict(list) 196 | preds = [] 197 | self.train() 198 | for i in trange(0, len(train), self.args.train_batch, desc='batch'): 199 | batch = train[i:i+self.args.train_batch] 200 | out = self(batch) 201 | pred = self.extract_preds(out, batch) 202 | loss = self.compute_loss(out, batch) 203 | 204 | sum(loss.values()).backward() 205 | lr_this_step = self.args.learning_rate * warmup_linear(global_step/num_train_steps, self.args.warmup) 206 | for param_group in optimizer.param_groups: 207 | param_group['lr'] = lr_this_step 208 | optimizer.step() 209 | optimizer.zero_grad() 210 | global_step += 1 211 | 212 | for k, v in loss.items(): 213 | stats['loss_' + k].append(v.item()) 214 | preds += pred 215 | train_metrics = {k: sum(v) / len(v) for k, v in stats.items()} 216 | train_metrics.update(self.compute_metrics(preds, train)) 217 | 218 | stats = defaultdict(list) 219 | preds = self.run_pred(dev) 220 | dev_metrics = {k: sum(v) / len(v) for k, v in stats.items()} 221 | dev_metrics.update(self.compute_metrics(preds, dev)) 222 | 223 | metrics = {'epoch': epoch} 224 | metrics.update({'train_' + k: v for k, v in train_metrics.items()}) 225 | metrics.update({'dev_' + k: v for k, v in dev_metrics.items()}) 226 | logger.critical(pformat(metrics)) 227 | 228 | if metrics[self.args.early_stop] > best_metrics[self.args.early_stop]: 229 | logger.critical('Found new best! 
Saving to ' + self.args.dsave) 230 | best_metrics = metrics 231 | self.save(best_metrics, self.args.dsave, self.args.early_stop) 232 | with open(os.path.join(self.args.dsave, 'dev.preds.json'), 'wt') as f: 233 | json.dump(preds, f, indent=2) 234 | 235 | logger.critical('Best dev') 236 | logger.critical(pformat(best_metrics)) 237 | -------------------------------------------------------------------------------- /model/editor.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.nn import functional as F 4 | from model.entail import Module as Base 5 | from torch.nn.utils.rnn import pad_sequence 6 | from preprocess_sharc import CLASSES, detokenize 7 | 8 | 9 | class Decoder(nn.Module): 10 | 11 | def __init__(self, denc, emb, dropout=0): 12 | super().__init__() 13 | dhid = denc 14 | self.demb = emb.size(1) 15 | self.vocab_size = emb.size(0) 16 | self.emb = nn.Embedding(self.vocab_size, self.demb) 17 | self.emb.weight.data = emb 18 | 19 | self.dropout = nn.Dropout(dropout) 20 | 21 | self.attn_scorer = nn.Linear(denc, 1) 22 | 23 | self.rnn = nn.LSTMCell(denc+self.demb, dhid) 24 | 25 | self.proj = nn.Linear(denc+dhid, self.demb) 26 | 27 | self.emb0 = nn.Parameter(torch.Tensor(self.demb)) 28 | self.h0 = nn.Parameter(torch.Tensor(dhid)) 29 | self.c0 = nn.Parameter(torch.Tensor(dhid)) 30 | 31 | for p in [self.emb0, self.h0, self.c0]: 32 | nn.init.uniform_(p, -0.1, 0.1) 33 | 34 | def forward(self, enc, inp_mask, label, max_decode_len=30): 35 | max_t = label.size(1) if self.training else max_decode_len 36 | batch = enc.size(0) 37 | h_t = self.h0.repeat(batch, 1) 38 | c_t = self.c0.repeat(batch, 1) 39 | emb_t = self.emb0.repeat(batch, 1) 40 | 41 | outs = [] 42 | for t in range(max_t): 43 | h_t = self.dropout(h_t) 44 | # attend to input 45 | inp_score = enc.bmm(h_t.unsqueeze(2)).squeeze(2) - (1-inp_mask) * 1e20 46 | inp_score_norm = F.softmax(inp_score, dim=1) 47 | inp_attn = inp_score_norm.unsqueeze(2).expand_as(enc).mul(enc).sum(1) 48 | 49 | rnn_inp = self.dropout(torch.cat([inp_attn, emb_t], dim=1)) 50 | 51 | h_t, c_t = self.rnn(rnn_inp, (h_t, c_t)) 52 | 53 | proj_inp = self.dropout(torch.cat([inp_attn, h_t], dim=1)) 54 | proj = self.proj(proj_inp) 55 | 56 | out_t = proj.mm(self.emb.weight.t().detach()) 57 | outs.append(out_t) 58 | word_t = label[:, t] if self.training else out_t.max(1)[1] 59 | # get rid of -1's from unchosen spans 60 | word_t = torch.clamp(word_t, 0, self.vocab_size) 61 | emb_t = self.emb(word_t) 62 | outs = torch.stack(outs, dim=1) 63 | return outs 64 | 65 | 66 | class Module(Base): 67 | 68 | def __init__(self, args): 69 | super().__init__(args) 70 | vocab = torch.load('{}/vocab.pt'.format(args.data)) 71 | self.vocab = vocab['vocab'] 72 | self.decoder = Decoder( 73 | emb=vocab['emb'], 74 | denc=self.args.bert_hidden_size, 75 | dropout=self.args.dropout, 76 | ) 77 | 78 | def forward(self, batch): 79 | out = super().forward(batch) 80 | 81 | out['edit_scores'] = decs = [] 82 | out['edit_labels'] = labels = [] 83 | for ex, enc, spans in zip(batch, out['bert_enc'], out['spans']): 84 | inp = [enc[s:e+1] for s, e in spans] 85 | lens = [t.size(0) for t in inp] 86 | max_len = max(lens) 87 | mask = torch.tensor([[1] * l + [0] * (max_len-l) for l in lens], device=self.device, dtype=torch.float) 88 | inp = pad_sequence(inp, batch_first=True, padding_value=0) 89 | label = pad_sequence([torch.tensor(o, dtype=torch.long) for o in ex['edit_num']['out_vocab_id']], batch_first=True, 
padding_value=-1).to(self.device) if self.training else None 90 | dec = self.decoder(inp, mask, label) 91 | decs.append(dec) 92 | labels.append(label) 93 | return out 94 | 95 | def compute_loss(self, out, batch): 96 | loss = super().compute_loss(out, batch) 97 | edit_loss = 0 98 | for ex, dec in zip(batch, out['edit_scores']): 99 | label = pad_sequence([torch.tensor(o, dtype=torch.long) for o in ex['edit_num']['out_vocab_id']], batch_first=True, padding_value=-1).to(self.device) 100 | edit_loss += F.cross_entropy(dec.view(-1, dec.size(-1)), label.view(-1), ignore_index=-1) 101 | loss['edit'] = edit_loss / len(batch) * self.args.loss_editor_weight 102 | return loss 103 | 104 | def extract_preds(self, out, batch, top_k=20): 105 | preds = [] 106 | for ex, clf_i, retrieve_i, spans_i, edit_scores_i in zip(batch, out['clf_scores'].max(1)[1].tolist(), out['retrieve_scores'].max(1)[1].tolist(), out['spans'], out['edit_scores']): 107 | a = CLASSES[clf_i] 108 | edit_ids = edit_scores_i.max(2)[1].tolist() 109 | edits = [] 110 | for ids in edit_ids: 111 | words = self.vocab.index2word(ids) 112 | if 'eos' in words: 113 | words = words[:words.index('eos')] 114 | edits.append(' '.join(words)) 115 | r = None 116 | if a == 'more': 117 | s, e = spans_i[retrieve_i] 118 | r = detokenize(ex['feat']['inp'][s:e+1]) 119 | a = edits[retrieve_i] 120 | preds.append({ 121 | 'utterance_id': ex['utterance_id'], 122 | 'retrieval': r, 123 | 'answer': a, 124 | 'spans': spans_i, 125 | }) 126 | return preds 127 | -------------------------------------------------------------------------------- /model/entail.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from model.retrieve import Module as Base 3 | from model.span import Module as SpanModule 4 | from torch import nn 5 | from torch.nn import functional as F 6 | from torch.nn.utils.rnn import pad_sequence 7 | from preprocess_sharc import detokenize, CLASSES 8 | from metric import compute_f1 9 | from tqdm import trange 10 | 11 | 12 | class Module(Base): 13 | 14 | def __init__(self, args): 15 | super().__init__(args) 16 | self.span_attn_scorer = nn.Linear(self.args.bert_hidden_size, 1) 17 | self.span_retrieval_scorer = nn.Linear(self.args.bert_hidden_size+2, 1) 18 | self.inp_attn_scorer = nn.Linear(self.args.bert_hidden_size, 1) 19 | self.class_clf = nn.Linear(self.args.bert_hidden_size, len(CLASSES)) 20 | 21 | def compute_entailment(self, spans, ex): 22 | chunks = [detokenize(ex['feat']['inp'][s:e+1]) for s, e in spans] 23 | history = [0] * len(chunks) 24 | scenario = [0] * len(chunks) 25 | # history 26 | for i, c in enumerate(chunks): 27 | for q in ex['ann']['hquestion']: 28 | history[i] = max(history[i], compute_f1(c, detokenize(q))) 29 | scenario[i] = max(scenario[i], compute_f1(c, detokenize(ex['ann']['scenario']))) 30 | entail = torch.tensor([history, scenario], dtype=torch.float, device=self.device).t() 31 | return entail 32 | 33 | def forward(self, batch): 34 | out = SpanModule.forward(self, batch) 35 | if self.training: 36 | spans = out['spans'] = [ex['feat']['spans'] for ex in batch] 37 | else: 38 | spans = [[span[:2] for span in spans_i] for spans_i in self.extract_spans(out['span_scores'], batch)] 39 | spans = out['spans'] = [self.extract_bullets(s, ex) for s, ex in zip(spans, batch)] 40 | 41 | span_enc = [] 42 | out['entail'] = [] 43 | for h_i, spans_i, ex_i in zip(out['bert_enc'], spans, batch): 44 | span_h = [h_i[s:e+1] for s, e in spans_i] 45 | max_len = max([h.size(0) for h in span_h]) 46 | span_mask = 
torch.tensor([[1] * h.size(0) + [0] * (max_len-h.size(0)) for h in span_h], device=self.device, dtype=torch.float) 47 | span_h = pad_sequence(span_h, batch_first=True, padding_value=0) 48 | span_attn_mask = pad_sequence(span_mask, batch_first=True, padding_value=0) 49 | span_attn_score = self.span_attn_scorer(self.dropout(span_h)).squeeze(2) - (1-span_attn_mask).mul(1e20) 50 | span_attn = F.softmax(span_attn_score, dim=1).unsqueeze(2).expand_as(span_h).mul(self.dropout(span_h)).sum(1) 51 | span_entail = self.compute_entailment(spans_i, ex_i) 52 | out['entail'].append(span_entail) 53 | span_enc.append(torch.cat([span_attn, span_entail], dim=1)) 54 | max_len = max([h.size(0) for h in span_enc]) 55 | span_mask = torch.tensor([[1] * h.size(0) + [0] * (max_len-h.size(0)) for h in span_enc], device=self.device, dtype=torch.float) 56 | span_enc = pad_sequence(span_enc, batch_first=True, padding_value=0) 57 | out['retrieve_scores'] = self.span_retrieval_scorer(self.dropout(span_enc)).squeeze(2) - (1-span_mask).mul(1e20) 58 | 59 | inp_attn_score = self.inp_attn_scorer(self.dropout(out['bert_enc'])).squeeze(2) - (1-out['input_mask'].float()).mul(1e20) 60 | inp_attn = F.softmax(inp_attn_score, dim=1).unsqueeze(2).expand_as(out['bert_enc']).mul(self.dropout(out['bert_enc'])).sum(1) 61 | out['clf_scores'] = self.class_clf(self.dropout(inp_attn)) 62 | return out 63 | 64 | def extract_preds(self, out, batch, top_k=20): 65 | preds = super().extract_preds(out, batch, top_k=top_k) 66 | for ex, p, span_i, clf_i, retrieve_i, entail_i in zip(batch, preds, out['span_scores'], out['clf_scores'], out['retrieve_scores'], out['entail']): 67 | p['clf_scores'] = dict(list(zip(CLASSES, F.softmax(clf_i, dim=0).tolist()))) 68 | spans = [detokenize(ex['feat']['inp'][s:e+1]) for s, e in p['spans']] 69 | p['span_scores'] = dict(list(zip(spans, F.softmax(retrieve_i, dim=0).tolist()))) 70 | p['words'] = [w['sub'] for w in ex['feat']['inp'] if w['orig'] != 'pad'] 71 | p['og'] = {k: v for k, v in ex.items() if k in ['snippet', 'scenario', 'question', 'history', 'answer']} 72 | p['start_scores'] = span_i[:, 0].tolist() 73 | p['end_scores'] = span_i[:, 1].tolist() 74 | p['entail_hist_scores'] = dict(list(zip(spans, entail_i[:, 0].tolist()))) 75 | p['entail_scen_scores'] = dict(list(zip(spans, entail_i[:, 1].tolist()))) 76 | return preds 77 | -------------------------------------------------------------------------------- /model/retrieve.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from model.span import Module as Base 3 | from torch import nn 4 | from torch.nn import functional as F 5 | from torch.nn.utils.rnn import pad_sequence 6 | from preprocess_sharc import detokenize, CLASSES, compute_metrics 7 | from metric import compute_f1 8 | 9 | 10 | class Module(Base): 11 | 12 | def __init__(self, args): 13 | super().__init__(args) 14 | self.span_attn_scorer = nn.Linear(self.args.bert_hidden_size, 1) 15 | self.span_retrieval_scorer = nn.Linear(self.args.bert_hidden_size, 1) 16 | self.inp_attn_scorer = nn.Linear(self.args.bert_hidden_size, 1) 17 | self.class_clf = nn.Linear(self.args.bert_hidden_size, len(CLASSES)) 18 | 19 | def extract_bullets(self, spans, ex): 20 | mask = ex['feat']['pointer_mask'].tolist() 21 | classes_start = mask.index(1) 22 | snippet_start = classes_start + 5 23 | snippet_end = snippet_start + mask[snippet_start:].index(0) 24 | bullet_inds = [i for i in range(snippet_start, snippet_end) if ex['feat']['inp'][i]['sub'] == '*'] 25 | if bullet_inds: 26 | 
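# The snippet is a bulleted list: carve it into the segments between consecutive '*' markers,
# pool those with any surviving predicted spans, sort by length (longest first), and greedily
# keep a span only if it still covers at least one token not covered by an earlier kept span.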
bullets = [(s+1, e-1) for s, e in zip(bullet_inds, bullet_inds[1:] + [snippet_end]) if e-1 >= s+1] 27 | non_bullet_spans = [] 28 | for s, e in spans: 29 | gloss = detokenize(ex['feat']['inp']) 30 | if '*' not in gloss and '\n' not in gloss: 31 | non_bullet_spans.append((s, e)) 32 | all_spans = bullets + non_bullet_spans 33 | all_spans.sort(key=lambda tup: tup[1]-tup[0], reverse=True) 34 | covered = [False] * len(ex['feat']['inp']) 35 | keep = [] 36 | for s, e in all_spans: 37 | if not all(covered[s:e+1]): 38 | for i in range(s, e+1): 39 | covered[i] = True 40 | keep.append((s, e)) 41 | return keep 42 | else: 43 | return spans 44 | 45 | def forward(self, batch): 46 | out = super().forward(batch) 47 | if self.training: 48 | spans = out['spans'] = [ex['feat']['spans'] for ex in batch] 49 | else: 50 | spans = [[span[:2] for span in spans_i] for spans_i in self.extract_spans(out['span_scores'], batch)] 51 | spans = out['spans'] = [self.extract_bullets(s, ex) for s, ex in zip(spans, batch)] 52 | 53 | span_enc = [] 54 | for h_i, spans_i in zip(out['bert_enc'], spans): 55 | span_h = [h_i[s:e+1] for s, e in spans_i] 56 | max_len = max([h.size(0) for h in span_h]) 57 | span_mask = torch.tensor([[1] * h.size(0) + [0] * (max_len-h.size(0)) for h in span_h], device=self.device, dtype=torch.float) 58 | span_h = pad_sequence(span_h, batch_first=True, padding_value=0) 59 | span_attn_mask = pad_sequence(span_mask, batch_first=True, padding_value=0) 60 | span_attn_score = self.span_attn_scorer(self.dropout(span_h)).squeeze(2) - (1-span_attn_mask).mul(1e20) 61 | span_attn = F.softmax(span_attn_score, dim=1).unsqueeze(2).expand_as(span_h).mul(self.dropout(span_h)).sum(1) 62 | span_enc.append(span_attn) 63 | max_len = max([h.size(0) for h in span_enc]) 64 | span_mask = torch.tensor([[1] * h.size(0) + [0] * (max_len-h.size(0)) for h in span_enc], device=self.device, dtype=torch.float) 65 | span_enc = pad_sequence(span_enc, batch_first=True, padding_value=0) 66 | out['retrieve_scores'] = self.span_retrieval_scorer(self.dropout(span_enc)).squeeze(2) - (1-span_mask).mul(1e20) 67 | 68 | inp_attn_score = self.inp_attn_scorer(self.dropout(out['bert_enc'])).squeeze(2) - (1-out['input_mask'].float()).mul(1e20) 69 | inp_attn = F.softmax(inp_attn_score, dim=1).unsqueeze(2).expand_as(out['bert_enc']).mul(self.dropout(out['bert_enc'])).sum(1) 70 | out['clf_scores'] = self.class_clf(self.dropout(inp_attn)) 71 | return out 72 | 73 | def extract_preds(self, out, batch, top_k=20): 74 | preds = [] 75 | for ex, clf_i, retrieve_i, span_i in zip(batch, out['clf_scores'].max(1)[1].tolist(), out['retrieve_scores'].max(1)[1].tolist(), out['spans']): 76 | a = CLASSES[clf_i] 77 | if a == 'more': 78 | s, e = span_i[retrieve_i] 79 | a = detokenize(ex['feat']['inp'][s:e+1]) 80 | preds.append({ 81 | 'utterance_id': ex['utterance_id'], 82 | 'answer': a, 83 | 'spans': span_i, 84 | 'retrieve_span': retrieve_i, 85 | }) 86 | return preds 87 | 88 | def compute_metrics(self, preds, data): 89 | metrics = compute_metrics(preds, data) 90 | f1s = [] 91 | for p, ex in zip(preds, data): 92 | pspans = [detokenize(ex['feat']['inp'][s:e+1]) for s, e in p['spans']] 93 | gspans = [detokenize(ex['feat']['inp'][s:e+1]) for s, e in ex['feat']['spans']] 94 | f1s.append(compute_f1('\n'.join(gspans), '\n'.join(pspans))) 95 | metrics['span_f1'] = sum(f1s) / len(f1s) 96 | return metrics 97 | 98 | def compute_loss(self, out, batch): 99 | gclf = torch.tensor([ex['feat']['answer_class'] for ex in batch], device=self.device, dtype=torch.long) 100 | gretrieve = 
torch.tensor([ex['feat']['answer_span'] for ex in batch], device=self.device, dtype=torch.long) 101 | loss = { 102 | 'clf': F.cross_entropy(out['clf_scores'], gclf), 103 | 'retrieve': F.cross_entropy(out['retrieve_scores'], gretrieve, ignore_index=-1), 104 | } 105 | loss['span_start'], loss['span_end'] = self.get_span_loss(out, batch) 106 | loss['span_start'] *= self.args.loss_span_weight 107 | loss['span_end'] *= self.args.loss_span_weight 108 | return loss 109 | -------------------------------------------------------------------------------- /model/span.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from model.base import Module as Base 3 | from torch import nn 4 | from torch.nn import functional as F 5 | from preprocess_sharc import detokenize 6 | from metric import compute_f1 7 | 8 | 9 | class Module(Base): 10 | 11 | def __init__(self, args): 12 | super().__init__(args) 13 | self.span_scorer = nn.Linear(self.args.bert_hidden_size, 2) 14 | 15 | def forward(self, batch): 16 | out = super().forward(batch) 17 | span_scores = self.span_scorer(self.dropout(out['bert_enc'])) 18 | out['span_scores'] = self.mask_scores(span_scores, out['pointer_mask']).sigmoid() 19 | return out 20 | 21 | def extract_spans(self, span_scores, batch): 22 | pstart, pend = span_scores.split(1, dim=-1) 23 | spans = [] 24 | for pstart_i, pend_i, ex in zip(pstart.squeeze(-1), pend.squeeze(-1), batch): 25 | spans_i = [] 26 | sthresh = min(pstart_i.max(), self.args.thresh) 27 | start = pstart_i.ge(sthresh).tolist() 28 | for si, strig in enumerate(start): 29 | if strig: 30 | ethresh = min(pend_i[si:].max(), self.args.thresh) 31 | end = pend_i[si:].ge(ethresh).tolist() 32 | for ei, etrig in enumerate(end): 33 | ei += si 34 | if etrig: 35 | spans_i.append((si, ei, detokenize(ex['feat']['inp'][si:ei+1]), pstart_i[si].item(), pend_i[ei].item())) 36 | break 37 | spans.append(spans_i) 38 | return spans 39 | 40 | def extract_preds(self, out, batch, top_k=20): 41 | preds = super().extract_preds(out, batch, top_k=top_k) 42 | for p, s in zip(preds, self.extract_spans(out['span_scores'], batch)): 43 | p['spans'] = s 44 | return preds 45 | 46 | def compute_metrics(self, preds, data): 47 | metrics = super().compute_metrics(preds, data) 48 | f1s = [] 49 | for p, ex in zip(preds, data): 50 | pspans = [gloss for s, e, gloss, ps, pe in p['spans']] 51 | gspans = [detokenize(ex['feat']['inp'][s:e+1]) for s, e in ex['feat']['spans']] 52 | f1s.append(compute_f1('\n'.join(gspans), '\n'.join(pspans))) 53 | metrics['span_f1'] = sum(f1s) / len(f1s) 54 | return metrics 55 | 56 | def get_span_loss(self, out, batch): 57 | span_scores = out['span_scores'] 58 | ystart, yend = span_scores.split(1, dim=-1) 59 | 60 | gstart = [] 61 | gend = [] 62 | for ex in batch: 63 | gstart_i = [0] * len(ex['feat']['inp']) 64 | gend_i = [0] * len(ex['feat']['inp']) 65 | for s, e in ex['feat']['spans']: 66 | gstart_i[s] = 1 67 | gend_i[e] = 1 68 | gstart.append(gstart_i) 69 | gend.append(gend_i) 70 | gstart = torch.tensor(gstart, dtype=torch.float, device=self.device) 71 | gend = torch.tensor(gend, dtype=torch.float, device=self.device) 72 | 73 | lstart = F.binary_cross_entropy(ystart.squeeze(-1), gstart) 74 | lend = F.binary_cross_entropy(yend.squeeze(-1), gend) 75 | return lstart, lend 76 | 77 | def compute_loss(self, out, batch): 78 | loss = super().compute_loss(out, batch) 79 | loss['span_start'], loss['span_end'] = self.get_span_loss(out, batch) 80 | loss['span_start'] *= self.args.loss_span_weight 81 | 
loss['span_end'] *= self.args.loss_span_weight 82 | return loss 83 | -------------------------------------------------------------------------------- /preprocess_editor.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | import os 3 | import json 4 | import torch 5 | import embeddings 6 | import stanfordnlp 7 | from metric import compute_f1 8 | from vocab import Vocab 9 | from tqdm import tqdm 10 | from preprocess_sharc import detokenize, tokenizer, make_tag 11 | 12 | 13 | def get_orig(tokens): 14 | words = [] 15 | for i, t in enumerate(tokens): 16 | if t['orig_id'] is None or (i and t['orig_id'] == tokens[i-1]['orig_id']): 17 | continue 18 | else: 19 | words.append(t['orig'].strip().lower()) 20 | return words 21 | 22 | 23 | nlp = None 24 | 25 | 26 | def trim_span(snippet, span): 27 | global nlp 28 | if nlp is None: 29 | nlp = stanfordnlp.Pipeline(processors='tokenize,pos,lemma', models_dir='cache') 30 | bad_pos = {'DET', 'ADP', '#', 'AUX', 'SCONJ', 'CCONJ', 'PUNCT'} 31 | s, e = span 32 | words = nlp(' '.join([t['orig'] for t in snippet[s:e+1]])).sentences[0].words 33 | while words and words[0].upos in bad_pos: 34 | words = words[1:] 35 | s += 1 36 | while words and words[-1].upos in bad_pos: 37 | words.pop() 38 | e -= 1 39 | return s, e 40 | 41 | 42 | def create_split(trees, vocab, max_len=300, train=True): 43 | split = [] 44 | keys = sorted(list(trees.keys())) 45 | for k in tqdm(keys): 46 | v = trees[k] 47 | snippet = v['t_snippet'] 48 | for q_str, q_tok in v['questions'].items(): 49 | span = v['spans'][v['match'][q_str]] 50 | # trim the span a bit to facilitate editing 51 | s, e = trim_span(snippet, span) 52 | if e >= s: 53 | inp = [make_tag('[CLS]')] + snippet[s:e+1] + [make_tag('[SEP]')] 54 | # account for prepended tokens 55 | new_s, new_e = s + len(inp), e + len(inp) 56 | inp += snippet + [make_tag('[SEP]')] 57 | type_ids = [0] + [0] * (e+1-s) + [1] * (len(snippet) + 2) 58 | inp_ids = tokenizer.convert_tokens_to_ids([t['sub'] for t in inp]) 59 | inp_mask = [1] * len(inp) 60 | 61 | assert len(type_ids) == len(inp) == len(inp_ids) 62 | 63 | while len(inp_ids) < max_len: 64 | inp.append(make_tag('pad')) 65 | inp_ids.append(0) 66 | inp_mask.append(0) 67 | type_ids.append(0) 68 | 69 | if len(inp_ids) > max_len: 70 | inp = inp[:max_len] 71 | inp_ids = inp_ids[:max_len] 72 | inp_mask = inp_mask[:max_len] 73 | inp_mask[-1] = make_tag('[SEP]') 74 | type_ids = type_ids[:max_len] 75 | 76 | out = get_orig(q_tok) 77 | if train: 78 | out_vids = torch.tensor(vocab.word2index(out + ['eos'], train=train), dtype=torch.long) 79 | else: 80 | out_vids = None 81 | 82 | ex = { 83 | 'utterance_id': len(split), 84 | 'question': q_tok, 85 | 'span': (new_s, new_e), 86 | 'inp': inp, 87 | 'type_ids': torch.tensor(type_ids, dtype=torch.long), 88 | 'inp_ids': torch.tensor(inp_ids, dtype=torch.long), 89 | 'inp_mask': torch.tensor(inp_mask, dtype=torch.long), 90 | 'out': out, 91 | 'out_vids': out_vids, 92 | } 93 | split.append(ex) 94 | return split 95 | 96 | 97 | def segment(ex, vocab, threshold=0.25): 98 | s, e = ex['span'] 99 | span = ex['inp'][s:e+1] 100 | span_str = detokenize(span) 101 | ques = ex['question'] 102 | 103 | best_i, best_j, best_score = None, None, -1 104 | for i in range(len(ques)): 105 | for j in range(i, len(ques)): 106 | chunk = detokenize(ques[i:j+1]) 107 | score = compute_f1(span_str, chunk) 108 | if score > best_score: 109 | best_score, best_i, best_j = score, i, j 110 | if best_score > threshold: 111 | before = 
ex['question'][:best_i] 112 | after = ex['question'][best_j+1:] 113 | ret = { 114 | 'before': get_orig(before), 115 | 'after': get_orig(after), 116 | } 117 | ret.update({ 118 | k + '_vids': torch.tensor(vocab.word2index(v + ['eos']), dtype=torch.long) 119 | for k, v in ret.items() 120 | }) 121 | return ret 122 | else: 123 | return None 124 | 125 | 126 | if __name__ == '__main__': 127 | import joblib 128 | vocab = Vocab() 129 | with open('sharc/trees_train.json') as f: 130 | train_trees = json.load(f) 131 | with open('sharc/trees_dev.json') as f: 132 | dev_trees = json.load(f) 133 | dout = 'sharc/editor_disjoint' 134 | if not os.path.isdir(dout): 135 | os.makedirs(dout) 136 | 137 | print('Flattening train') 138 | train = create_split(train_trees, vocab) 139 | print('Flattening dev') 140 | dev = create_split(dev_trees, vocab) 141 | 142 | par = joblib.Parallel(12) 143 | print('Segmenting train') 144 | train_ba = par(joblib.delayed(segment)(ex, vocab) for ex in tqdm(train)) 145 | 146 | train_filtered = [] 147 | for ex, ba in zip(train, train_ba): 148 | if ba: 149 | ex.update(ba) 150 | train_filtered.append(ex) 151 | 152 | print('filtered train from {} to {}'.format(len(train), len(train_filtered))) 153 | print('vocab size {}'.format(len(vocab))) 154 | 155 | emb = embeddings.ConcatEmbedding([embeddings.GloveEmbedding(), embeddings.KazumaCharEmbedding()], default='zero') 156 | mat = torch.Tensor([emb.emb(w) for w in vocab._index2word]) 157 | torch.save({'vocab': vocab, 'emb': mat}, dout + '/vocab.pt') 158 | torch.save(train_filtered, dout + '/proc_train.pt') 159 | torch.save(dev, dout + '/proc_dev.pt') 160 | -------------------------------------------------------------------------------- /preprocess_sharc.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | import os 3 | import editdistance 4 | import torch 5 | import string 6 | import revtok 7 | import json 8 | from tempfile import NamedTemporaryFile 9 | from tqdm import tqdm 10 | from pprint import pprint 11 | from collections import defaultdict 12 | from pytorch_pretrained_bert.tokenization import BertTokenizer 13 | 14 | 15 | FORCE = True 16 | MAX_LEN = 300 17 | BERT_MODEL = 'cache/bert-base-uncased.tar.gz' 18 | BERT_VOCAB = 'cache/bert-base-uncased-vocab.txt' 19 | LOWERCASE = True 20 | tokenizer = BertTokenizer.from_pretrained(BERT_VOCAB, do_lower_case=LOWERCASE, cache_dir=None) 21 | MATCH_IGNORE = {'do', 'have', '?'} 22 | SPAN_IGNORE = set(string.punctuation) 23 | CLASSES = ['yes', 'no', 'irrelevant', 'more'] 24 | 25 | 26 | nlp = None 27 | 28 | 29 | def tokenize(doc): 30 | if not doc.strip(): 31 | return [] 32 | tokens = [] 33 | for i, t in enumerate(revtok.tokenize(doc)): 34 | subtokens = tokenizer.tokenize(t.strip()) 35 | for st in subtokens: 36 | tokens.append({ 37 | 'orig': t, 38 | 'sub': st, 39 | 'orig_id': i, 40 | }) 41 | return tokens 42 | 43 | 44 | def convert_to_ids(tokens): 45 | return tokenizer.convert_tokens_to_ids([t['sub'] for t in tokens]) 46 | 47 | 48 | def filter_answer(answer): 49 | return detokenize([a for a in answer if a['orig'] not in MATCH_IGNORE]) 50 | 51 | 52 | def filter_chunk(answer): 53 | return detokenize([a for a in answer if a['orig'] not in MATCH_IGNORE]) 54 | 55 | 56 | def detokenize(tokens): 57 | words = [] 58 | for i, t in enumerate(tokens): 59 | if t['orig_id'] is None or (i and t['orig_id'] == tokens[i-1]['orig_id']): 60 | continue 61 | else: 62 | words.append(t['orig']) 63 | return revtok.detokenize(words) 64 | 65 | 66 | def make_tag(tag): 67 
| return {'orig': tag, 'sub': tag, 'orig_id': tag} 68 | 69 | 70 | def compute_metrics(preds, data): 71 | import evaluator 72 | with NamedTemporaryFile('w') as fp, NamedTemporaryFile('w') as fg: 73 | json.dump(preds, fp) 74 | fp.flush() 75 | json.dump([{'utterance_id': e['utterance_id'], 'answer': e['answer']} for e in data], fg) 76 | fg.flush() 77 | results = evaluator.evaluate(fg.name, fp.name, mode='combined') 78 | results['combined'] = results['macro_accuracy'] * results['bleu_4'] 79 | return results 80 | 81 | 82 | def get_span(context, answer): 83 | answer = filter_answer(answer) 84 | best, best_score = None, float('inf') 85 | stop = False 86 | for i in range(len(context)): 87 | if stop: 88 | break 89 | for j in range(i, len(context)): 90 | chunk = filter_chunk(context[i:j+1]) 91 | if '\n' in chunk or '*' in chunk: 92 | continue 93 | score = editdistance.eval(answer, chunk) 94 | if score < best_score or (score == best_score and j-i < best[1]-best[0]): 95 | best, best_score = (i, j), score 96 | if chunk == answer: 97 | stop = True 98 | break 99 | s, e = best 100 | while not context[s]['orig'].strip() or context[s]['orig'] in SPAN_IGNORE: 101 | s += 1 102 | while not context[e]['orig'].strip() or context[s]['orig'] in SPAN_IGNORE: 103 | e -= 1 104 | return s, e 105 | 106 | 107 | def get_bullets(context): 108 | indices = [i for i, c in enumerate(context) if c['sub'] == '*'] 109 | pairs = list(zip(indices, indices[1:] + [len(context)])) 110 | cleaned = [] 111 | for s, e in pairs: 112 | while not context[e-1]['sub'].strip(): 113 | e -= 1 114 | while not context[s]['sub'].strip() or context[s]['sub'] == '*': 115 | s += 1 116 | if e - s > 2 and e - 2 < 45: 117 | cleaned.append((s, e-1)) 118 | return cleaned 119 | 120 | 121 | def extract_clauses(data, tokenizer): 122 | snippet = data['snippet'] 123 | t_snippet = tokenize(snippet) 124 | questions = data['questions'] 125 | t_questions = [tokenize(q) for q in questions] 126 | 127 | spans = [get_span(t_snippet, q) for q in t_questions] 128 | bullets = get_bullets(t_snippet) 129 | all_spans = spans + bullets 130 | coverage = [False] * len(t_snippet) 131 | sorted_by_len = sorted(all_spans, key=lambda tup: tup[1] - tup[0], reverse=True) 132 | 133 | ok = [] 134 | for s, e in sorted_by_len: 135 | if not all(coverage[s:e+1]): 136 | for i in range(s, e+1): 137 | coverage[i] = True 138 | ok.append((s, e)) 139 | ok.sort(key=lambda tup: tup[0]) 140 | 141 | match = {} 142 | match_text = {} 143 | clauses = [None] * len(ok) 144 | for q, tq in zip(questions, t_questions): 145 | best_score = float('inf') 146 | best = None 147 | for i, (s, e) in enumerate(ok): 148 | score = editdistance.eval(detokenize(tq), detokenize(t_snippet[s:e+1])) 149 | if score < best_score: 150 | best_score, best = score, i 151 | clauses[i] = tq 152 | match[q] = best 153 | s, e = ok[best] 154 | match_text[q] = detokenize(t_snippet[s:e+1]) 155 | 156 | return {'questions': {q: tq for q, tq in zip(questions, t_questions)}, 'snippet': snippet, 't_snippet': t_snippet, 'spans': ok, 'match': match, 'match_text': match_text, 'clauses': clauses} 157 | 158 | 159 | if __name__ == '__main__': 160 | for split in ['dev', 'train']: 161 | fsplit = 'sharc_train' if split == 'train' else 'sharc_dev' 162 | with open('sharc/json/{}.json'.format(fsplit)) as f: 163 | data = json.load(f) 164 | ftree = 'sharc/trees_{}.json'.format(split) 165 | if not os.path.isfile(ftree) or FORCE: 166 | tasks = {} 167 | for ex in data: 168 | for h in ex['evidence']: 169 | if 'followup_question' in h: 170 | 
h['follow_up_question'] = h['followup_question'] 171 | h['follow_up_answer'] = h['followup_answer'] 172 | del h['followup_question'] 173 | del h['followup_answer'] 174 | if ex['tree_id'] in tasks: 175 | task = tasks[ex['tree_id']] 176 | else: 177 | task = tasks[ex['tree_id']] = {'snippet': ex['snippet'], 'questions': set()} 178 | for h in ex['history'] + ex['evidence']: 179 | task['questions'].add(h['follow_up_question']) 180 | if ex['answer'].lower() not in {'yes', 'no', 'irrelevant'}: 181 | task['questions'].add(ex['answer']) 182 | keys = sorted(list(tasks.keys())) 183 | vals = [extract_clauses(tasks[k], tokenizer) for k in tqdm(keys)] 184 | mapping = {k: v for k, v in zip(keys, vals)} 185 | with open(ftree, 'wt') as f: 186 | json.dump(mapping, f, indent=2) 187 | else: 188 | with open(ftree) as f: 189 | mapping = json.load(f) 190 | fproc = 'sharc/proc_{}.pt'.format(split) 191 | if not os.path.isfile(fproc) or FORCE: 192 | stats = defaultdict(list) 193 | for ex in data: 194 | ex_answer = ex['answer'].lower() 195 | m = mapping[ex['tree_id']] 196 | ex['ann'] = a = { 197 | 'snippet': m['t_snippet'], 198 | 'clauses': m['clauses'], 199 | 'question': tokenize(ex['question']), 200 | 'scenario': tokenize(ex['scenario']), 201 | 'answer': tokenize(ex['answer']), 202 | 'hanswer': [{'yes': 1, 'no': 0}[h['follow_up_answer'].lower()] for h in ex['history']], 203 | 'hquestion': [m['questions'][h['follow_up_question']] for h in ex['history']], 204 | 'hquestion_span': [m['match'][h['follow_up_question']] for h in ex['history']], 205 | 'hquestion_span_text': [m['match_text'][h['follow_up_question']] for h in ex['history']], 206 | 'sentailed': [m['questions'][h['follow_up_question']] for h in ex['evidence']], 207 | 'sentailed_span': [m['match'][h['follow_up_question']] for h in ex['evidence']], 208 | 'sentailed_span_text': [m['match_text'][h['follow_up_question']] for h in ex['evidence']], 209 | 'spans': m['spans'], 210 | } 211 | if ex_answer not in CLASSES: 212 | a['answer_span'] = m['match'][ex['answer']] 213 | a['answer_span_text'] = m['match_text'][ex['answer']] 214 | else: 215 | a['answer_span'] = None 216 | a['answer_span_text'] = None 217 | 218 | inp = [make_tag('[CLS]')] + a['question'] 219 | type_ids = [0] * len(inp) 220 | clf_indices = { 221 | 'yes': len(inp) + 2, 222 | 'no': len(inp) + 3, 223 | 'irrelevant': len(inp) + 4, 224 | } 225 | sep = make_tag('[SEP]') 226 | pointer_mask = [0] * len(inp) 227 | inp += [ 228 | sep, 229 | make_tag('classes'), 230 | make_tag('yes'), 231 | make_tag('no'), 232 | make_tag('irrelevant'), 233 | sep, 234 | make_tag('document'), 235 | ] 236 | pointer_mask += [0, 0, 1, 1, 1, 0, 0] 237 | offset = len(inp) 238 | spans = [(s+offset, e+offset) for s, e in a['spans']] 239 | inp += a['snippet'] 240 | pointer_mask += [1] * len(a['snippet']) # where can the answer pointer land 241 | inp += [sep] 242 | start = len(inp) 243 | inp += [make_tag('scenario')] + a['scenario'] + [sep] 244 | end = len(inp) 245 | scen_offsets = start, end 246 | inp += [make_tag('history')] 247 | hist_offsets = [] 248 | for hq, ha in zip(a['hquestion'], a['hanswer']): 249 | start = len(inp) 250 | inp += [make_tag('question')] + hq + [make_tag('answer'), [make_tag('yes'), make_tag('no')][ha]] 251 | end = len(inp) 252 | hist_offsets.append((start, end)) 253 | inp += [sep] 254 | type_ids += [1] * (len(inp) - len(type_ids)) 255 | input_ids = convert_to_ids(inp) 256 | input_mask = [1] * len(inp) 257 | pointer_mask += [0] * (len(inp) - len(pointer_mask)) 258 | 259 | if ex_answer in CLASSES: 260 | start = 
clf_indices[ex_answer] 261 | end = start 262 | clf = CLASSES.index(ex_answer) 263 | answer_span = -1 264 | else: 265 | answer_span = a['answer_span'] 266 | start, end = spans[answer_span] 267 | clf = CLASSES.index('more') 268 | 269 | # for s, e in spans: 270 | # print(detokenize(inp[s:e+1])) 271 | # print(detokenize(inp[start:end+1])) 272 | # print(ex_answer) 273 | # import pdb; pdb.set_trace() 274 | 275 | if len(inp) > MAX_LEN: 276 | inp = inp[:MAX_LEN] 277 | input_mask = input_mask[:MAX_LEN] 278 | type_ids = type_ids[:MAX_LEN] 279 | input_ids = input_ids[:MAX_LEN] 280 | pointer_mask = pointer_mask[:MAX_LEN] 281 | pad = make_tag('pad') 282 | while len(inp) < MAX_LEN: 283 | inp.append(pad) 284 | input_mask.append(0) 285 | type_ids.append(0) 286 | input_ids.append(0) 287 | pointer_mask.append(0) 288 | 289 | assert len(inp) == len(input_mask) == len(type_ids) == len(input_ids) 290 | 291 | ex['feat'] = { 292 | 'inp': inp, 293 | 'input_ids': torch.LongTensor(input_ids), 294 | 'type_ids': torch.LongTensor(type_ids), 295 | 'input_mask': torch.LongTensor(input_mask), 296 | 'pointer_mask': torch.LongTensor(pointer_mask), 297 | 'spans': spans, 298 | 'hanswer': a['hanswer'], 299 | 'hquestion_span': torch.LongTensor(a['hquestion_span']), 300 | 'sentailed_span': torch.LongTensor(a['sentailed_span']), 301 | 'answer_start': start, 302 | 'answer_end': end, 303 | 'answer_class': clf, 304 | 'answer_span': answer_span, 305 | 'snippet_offset': offset, 306 | 'scen_offsets': scen_offsets, 307 | 'hist_offsets': hist_offsets, 308 | } 309 | 310 | stats['snippet_len'].append(len(ex['ann']['snippet'])) 311 | stats['scenario_len'].append(len(ex['ann']['scenario'])) 312 | stats['history_len'].append(sum([len(q) + 3 for q in ex['ann']['hquestion']])) 313 | stats['question_len'].append(len(ex['ann']['question'])) 314 | stats['inp_len'].append(sum(input_mask)) 315 | for k, v in sorted(list(stats.items()), key=lambda tup: tup[0]): 316 | print(k) 317 | print('mean: {}'.format(sum(v) / len(v))) 318 | print('min: {}'.format(min(v))) 319 | print('max: {}'.format(max(v))) 320 | preds = [{'utterance_id': e['utterance_id'], 'answer': detokenize(e['feat']['inp'][e['feat']['answer_start']:e['feat']['answer_end']+1])} for e in data] 321 | pprint(compute_metrics(preds, data)) 322 | torch.save(data, fproc) 323 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | tqdm==4.14.0 2 | stanfordnlp==0.1.1 3 | pytorch-pretrained-bert==0.4.0 4 | torch==1.0.0 5 | numpy==1.14.2 6 | ujson==1.35 7 | embeddings==0.0.6 8 | vocab==0.0.4 9 | -------------------------------------------------------------------------------- /train_editor.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import random 3 | import numpy as np 4 | from argparse import ArgumentParser 5 | from editor_model.base import Module 6 | from pprint import pprint 7 | 8 | 9 | if __name__ == '__main__': 10 | parser = ArgumentParser() 11 | parser.add_argument('--train_batch', default=10, type=int) 12 | parser.add_argument('--dev_batch', default=5, type=int) 13 | parser.add_argument('--epoch', default=20, type=int) 14 | parser.add_argument('--keep', default=2, type=int) 15 | parser.add_argument('--seed', default=3, type=int) 16 | parser.add_argument('--learning_rate', default=5e-5, type=float) 17 | parser.add_argument('--dropout', default=0.4, type=float) 18 | parser.add_argument('--warmup', 
default=0.1, type=float) 19 | parser.add_argument('--thresh', default=0.5, type=float) 20 | parser.add_argument('--debug', action='store_true') 21 | parser.add_argument('--dsave', default='editor_save/{}') 22 | parser.add_argument('--model', default='double') 23 | parser.add_argument('--prefix', default='default') 24 | parser.add_argument('--early_stop', default='dev_f1') 25 | parser.add_argument('--bert_hidden_size', default=768, type=int) 26 | parser.add_argument('--bert_model', default='bert-base-uncased') 27 | parser.add_argument('--data', default='sharc/editor_disjoint') 28 | parser.add_argument('--resume', default='') 29 | parser.add_argument('--test', action='store_true') 30 | 31 | args = parser.parse_args() 32 | args.dsave = args.dsave.format(args.prefix + '-' + args.model) 33 | # if args.model != 'base': 34 | # args.dsave += '/{}/{}'.format(args.loss_span_weight, args.loss_editor_weight) 35 | 36 | random.seed(args.seed) 37 | np.random.seed(args.seed) 38 | torch.manual_seed(args.seed) 39 | if torch.cuda.is_available(): 40 | torch.backends.cudnn.deterministic = True 41 | torch.backends.cudnn.benchmark = False 42 | 43 | limit = 10 if args.debug else None 44 | data = {k: torch.load('{}/proc_{}.pt'.format(args.data, k))[:limit] for k in ['dev', 'train']} 45 | 46 | if args.resume: 47 | print('resuming model from ' + args.resume) 48 | model = Module.load(args.resume) 49 | else: 50 | print('instanting model') 51 | model = Module.load_module(args.model)(args) 52 | 53 | model.to(model.device) 54 | 55 | if args.test: 56 | preds = model.run_pred(data['dev']) 57 | metrics = model.compute_metrics(preds, data['dev']) 58 | pprint(metrics) 59 | else: 60 | model.run_train(data['train'], data['dev']) 61 | -------------------------------------------------------------------------------- /train_sharc.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import random 3 | import numpy as np 4 | from argparse import ArgumentParser 5 | from model.base import Module 6 | from pprint import pprint 7 | 8 | 9 | if __name__ == '__main__': 10 | parser = ArgumentParser() 11 | parser.add_argument('--train_batch', default=10, type=int, help='training batch size') 12 | parser.add_argument('--dev_batch', default=5, type=int, help='dev batch size') 13 | parser.add_argument('--epoch', default=5, type=int, help='number of epochs') 14 | parser.add_argument('--keep', default=2, type=int, help='number of model saves to keep') 15 | parser.add_argument('--seed', default=3, type=int, help='random seed') 16 | parser.add_argument('--learning_rate', default=5e-5, type=float, help='learning rate') 17 | parser.add_argument('--dropout', default=0.35, type=float, help='dropout rate') 18 | parser.add_argument('--warmup', default=0.1, type=float, help='optimizer warmup') 19 | parser.add_argument('--thresh', default=0.5, type=float, help='rule extraction threshold') 20 | parser.add_argument('--loss_span_weight', default=400., type=float, help='span loss weight') 21 | parser.add_argument('--loss_editor_weight', default=1., type=float, help='editor loss weight') 22 | parser.add_argument('--debug', action='store_true', help='debug flag to load less data') 23 | parser.add_argument('--dsave', default='save/{}', help='save directory') 24 | parser.add_argument('--model', default='entail', help='model to use') 25 | parser.add_argument('--early_stop', default='dev_combined', help='early stopping metric') 26 | parser.add_argument('--bert_hidden_size', default=768, type=int, help='hidden size 
for the bert model') 27 | parser.add_argument('--data', default='sharc', help='directory for data') 28 | parser.add_argument('--prefix', default='default', help='prefix for experiment name') 29 | parser.add_argument('--resume', default='', help='model .pt file') 30 | parser.add_argument('--test', action='store_true', help='only run evaluation') 31 | 32 | args = parser.parse_args() 33 | args.dsave = args.dsave.format(args.prefix+'-'+args.model) 34 | 35 | random.seed(args.seed) 36 | np.random.seed(args.seed) 37 | torch.manual_seed(args.seed) 38 | if torch.cuda.is_available(): 39 | torch.backends.cudnn.deterministic = True 40 | torch.backends.cudnn.benchmark = False 41 | 42 | limit = 10 if args.debug else None 43 | data = {k: torch.load('{}/proc_{}.pt'.format(args.data, k))[:limit] for k in ['dev', 'train']} 44 | 45 | if args.resume: 46 | print('resuming model from ' + args.resume) 47 | model = Module.load(args.resume) 48 | else: 49 | print('instanting model') 50 | model = Module.load_module(args.model)(args) 51 | 52 | model.to(model.device) 53 | 54 | if args.test: 55 | preds = model.run_pred(data['dev']) 56 | metrics = model.compute_metrics(preds, data['dev']) 57 | pprint(metrics) 58 | else: 59 | model.run_train(data['train'], data['dev']) 60 | --------------------------------------------------------------------------------
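A few of the core routines above are easier to follow in isolation, so the sketches below restate them as small standalone Python functions. They are not code from the repository; function names, inputs, and default values are illustrative.

The span extractor in `model/span.py` decodes spans with a threshold rule rather than a single argmax: every position whose start probability clears the threshold is paired with the first later position whose end probability clears it, and the threshold is capped at the best available score so at least one span always comes out. A minimal plain-Python sketch of that decoding over unbatched lists of sigmoid probabilities; the real `extract_spans` also attaches the detokenized text and the start/end scores to each span.

```python
def decode_spans(p_start, p_end, thresh=0.5):
    """Pair each above-threshold start with the first above-threshold end at or after it."""
    spans = []
    s_thresh = min(max(p_start), thresh)  # cap at the best start so something is always extracted
    for si, ps in enumerate(p_start):
        if ps >= s_thresh:
            e_thresh = min(max(p_end[si:]), thresh)
            for ei in range(si, len(p_end)):
                if p_end[ei] >= e_thresh:
                    spans.append((si, ei))
                    break
    return spans

# decode_spans([0.9, 0.1, 0.7, 0.2], [0.2, 0.8, 0.1, 0.9]) -> [(0, 1), (2, 3)]
```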
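Both the retrieval scorer and the answer classifier in `model/retrieve.py` reduce a sequence of BERT encodings to a single vector with masked self-attention pooling: a linear layer scores each position, padded positions are pushed to a very large negative value before the softmax, and the resulting weights sum the encodings over time. A self-contained PyTorch sketch of that pooling, with illustrative names and without the dropout used in the model:

```python
import torch
from torch import nn
from torch.nn import functional as F

def attention_pool(h, mask, scorer):
    """h: (batch, time, hidden); mask: (batch, time) with 1 for real tokens;
    scorer: nn.Linear(hidden, 1). Returns a (batch, hidden) pooled encoding."""
    scores = scorer(h).squeeze(2) - (1 - mask.float()).mul(1e20)  # push padding to -inf
    attn = F.softmax(scores, dim=1).unsqueeze(2)                  # (batch, time, 1)
    return (attn * h).sum(1)

scorer = nn.Linear(8, 1)
h = torch.randn(2, 5, 8)
mask = torch.tensor([[1, 1, 1, 0, 0], [1, 1, 1, 1, 1]])
pooled = attention_pool(h, mask, scorer)  # shape (2, 8)
```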
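Because a rule snippet usually contains several clauses, `get_span_loss` in `model/span.py` treats span supervision as independent binary decisions per position rather than a single cross-entropy over one start and one end: every gold span marks one start position and one end position as positive, and the rest are negatives. A sketch for a single unbatched example, assuming the scores have already been passed through a sigmoid:

```python
import torch
from torch.nn import functional as F

def span_bce_loss(span_probs, gold_spans):
    """span_probs: (seq_len, 2) sigmoid start/end probabilities.
    gold_spans: list of (start, end) token indices; each contributes one positive
    start label and one positive end label."""
    seq_len = span_probs.size(0)
    gstart = torch.zeros(seq_len)
    gend = torch.zeros(seq_len)
    for s, e in gold_spans:
        gstart[s] = 1.0
        gend[e] = 1.0
    p_start, p_end = span_probs.unbind(dim=-1)
    return F.binary_cross_entropy(p_start, gstart), F.binary_cross_entropy(p_end, gend)

# span_bce_loss(torch.sigmoid(torch.randn(6, 2)), [(1, 3), (4, 5)])
```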
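The same greedy de-overlapping step appears twice, in `extract_clauses` in `preprocess_sharc.py` and in `extract_bullets` in the retrieval model: candidate spans are sorted by length, and a span is kept only if it still covers at least one token that a longer span has not already claimed. A standalone sketch with illustrative names:

```python
def filter_overlapping(spans, num_tokens):
    """Keep longer spans first; drop any span whose tokens are all already covered."""
    covered = [False] * num_tokens
    keep = []
    for s, e in sorted(spans, key=lambda t: t[1] - t[0], reverse=True):
        if not all(covered[s:e + 1]):
            for i in range(s, e + 1):
                covered[i] = True
            keep.append((s, e))
    keep.sort(key=lambda t: t[0])
    return keep

# filter_overlapping([(0, 5), (2, 4), (6, 8)], 10) -> [(0, 5), (6, 8)]
```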
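`get_span` in `preprocess_sharc.py` aligns each follow-up question to the snippet by exhaustively scanning chunks and keeping the one with minimum edit distance, preferring shorter chunks on ties and stopping early on an exact match. A stripped-down sketch of that search over plain word lists; the real version works on subword tokens, skips chunks containing bullets or newlines, and trims leading and trailing punctuation afterwards:

```python
import editdistance

def closest_chunk(context_words, answer):
    """Return the (start, end) chunk of context whose text is closest to answer."""
    best, best_score = None, float('inf')
    for i in range(len(context_words)):
        for j in range(i, len(context_words)):
            chunk = ' '.join(context_words[i:j + 1])
            score = editdistance.eval(answer, chunk)
            if score < best_score or (score == best_score and j - i < best[1] - best[0]):
                best, best_score = (i, j), score
            if chunk == answer:
                return best
    return best

# closest_chunk(['the', 'applicant', 'is', 'over', '65', 'years', 'old'], 'over 65 years') -> (3, 5)
```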
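`create_split` in `preprocess_editor.py` turns each (span, snippet) pair into a single BERT input of the form `[CLS] span [SEP] snippet [SEP]`, with segment ids separating the extracted span from the full snippet, then pads or truncates to a fixed length. A simplified sketch using plain string tokens and the standard BERT segment convention; the repository's version instead builds tag dictionaries via `make_tag` and also records the span's offset inside the packed input:

```python
def pack_editor_input(span_tokens, snippet_tokens, max_len=300):
    """Return (tokens, segment_ids, attention_mask), each of length max_len."""
    inp = ['[CLS]'] + list(span_tokens) + ['[SEP]']
    type_ids = [0] * len(inp)
    inp += list(snippet_tokens) + ['[SEP]']
    type_ids += [1] * (len(snippet_tokens) + 1)
    mask = [1] * len(inp)
    inp, type_ids, mask = inp[:max_len], type_ids[:max_len], mask[:max_len]
    while len(inp) < max_len:
        inp.append('[PAD]')
        type_ids.append(0)
        mask.append(0)
    return inp, type_ids, mask

# tokens, type_ids, mask = pack_editor_input(['over', '65'], ['you', 'qualify', 'if', 'you', 'are', 'over', '65'])
```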
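To build editor training targets, `segment` in `preprocess_editor.py` finds the contiguous chunk of the gold follow-up question that best matches the extracted rule span and keeps only the words before and after it, which are what the editor has to generate; examples whose best overlap falls below a threshold are dropped. A sketch of that split, with a simple whitespace-token F1 standing in for `compute_f1` from `metric.py`:

```python
def token_f1(a, b):
    """Whitespace-token F1, a stand-in for metric.compute_f1."""
    a_toks, b_toks = a.split(), b.split()
    common = sum(min(a_toks.count(t), b_toks.count(t)) for t in set(b_toks))
    if common == 0:
        return 0.0
    p, r = common / len(b_toks), common / len(a_toks)
    return 2 * p * r / (p + r)

def split_around_best_chunk(question_words, span_text, score_fn=token_f1, threshold=0.25):
    """Return (before, after) word lists around the best-matching chunk, or None."""
    best_i, best_j, best_score = None, None, -1
    for i in range(len(question_words)):
        for j in range(i, len(question_words)):
            score = score_fn(span_text, ' '.join(question_words[i:j + 1]))
            if score > best_score:
                best_score, best_i, best_j = score, i, j
    if best_score <= threshold:
        return None
    return question_words[:best_i], question_words[best_j + 1:]

# split_around_best_chunk(['do', 'you', 'own', 'your', 'home', '?'], 'own your home')
#   -> (['do', 'you'], ['?'])
```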