├── .gitignore ├── ACL 2022 Video talk.pdf ├── LICENSE ├── README.md ├── __init__.py ├── coreference_resolution.py ├── data ├── Annotation Guideline for _evaluating information-seeking conversations_.pdf ├── Qualification test for evaluation task.pdf └── human_annotation_data.json ├── figs ├── autorewrite.png └── example.png ├── interface.py ├── models ├── __init__.py ├── bert │ ├── __init__.py │ ├── interface.py │ ├── modeling.py │ ├── run_quac_dataset_utils.py │ └── run_quac_train.py ├── excord │ ├── __init__.py │ ├── interface.py │ ├── modeling_auto.py │ ├── modeling_roberta.py │ ├── quac.py │ └── quac_metrics.py ├── graphflow │ ├── __init__.py │ ├── interface.py │ ├── layers │ │ ├── __init__.py │ │ ├── attention.py │ │ ├── common.py │ │ └── graphs.py │ ├── model.py │ ├── model_handler.py │ ├── models │ │ ├── __init__.py │ │ └── graphflow.py │ ├── utils │ │ ├── __init__.py │ │ ├── bert_utils.py │ │ ├── constants.py │ │ ├── coqa │ │ │ ├── __init__.py │ │ │ └── eval_utils.py │ │ ├── data_utils.py │ │ ├── doqa │ │ │ ├── __init__.py │ │ │ └── eval_utils.py │ │ ├── eval_utils.py │ │ ├── generic_utils.py │ │ ├── io_utils.py │ │ ├── logger.py │ │ ├── process_utils.py │ │ ├── quac │ │ │ ├── __init__.py │ │ │ └── eval_utils.py │ │ ├── radam.py │ │ └── timer.py │ └── word_model.py └── ham │ ├── __init__.py │ ├── cqa_gen_batches.py │ ├── cqa_model.py │ ├── cqa_rl_supports.py │ ├── cqa_supports.py │ ├── interface.py │ ├── modeling.py │ ├── optimization.py │ ├── reindent.py │ ├── scorer.py │ └── tokenization.py ├── requirements.txt ├── run.sh ├── run_quac_eval.py └── run_quac_eval_util.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | -------------------------------------------------------------------------------- /ACL 2022 Video talk.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/princeton-nlp/EvalConvQA/fbf34196b4d8e39d4ecfe36353c9e394101af5eb/ACL 2022 Video talk.pdf -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 Princeton Natural Language Processing 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Ditch the Gold Standard: Re-evaluating Conversational Question Answering 2 | This is the repository for our ACL'2022 paper [Ditch the Gold Standard: Re-evaluating Conversational Question Answering](https://arxiv.org/pdf/2112.08812.pdf). The slides for our ACL presentation can be found [here](https://github.com/princeton-nlp/EvalConvQA/blob/main/ACL%202022%20Video%20talk.pdf). 3 | 4 | ## Quick links 5 | * [Overview](#Overview) 6 | * [Human Evaluation Dataset](#Human-Evaluation-Dataset) 7 | * [Automatic model evaluation interface](#Automatic-model-evaluation-interface) 8 | * [Setup](#Setup) 9 | * [Install dependencies](#Install-dependencies) 10 | * [Download the datasets](#Download-the-datasets) 11 | * [Evaluating existing models](#Evaluating-existing-models) 12 | * [BERT](#BERT) 13 | * [GraphFlow](#GraphFlow) 14 | * [HAM](#HAM) 15 | * [ExCorD](#ExCorD) 16 | * [Evaluating your own model](#Evaluating-your-own-model) 17 | * [Citation](#Citation) 18 | 19 | ## Overview 20 | 21 | In this work, we conduct the first large-scale human evaluation of state-of-the-art conversational QA systems. 
In our evaluation, human annotators chat with conversational QA models about passages from the [QuAC](https://quac.ai) development set, and after that the annotators judge the correctness of model answers. We release the human annotated dataset in the following section. 22 | 23 | We also identify a critical issue with the current automatic evaluation, which pre-collectes human-human conversations and uses ground-truth answers as conversational history (differences between different evaluations are shown in the following figure). By comparison, we find that the automatic evaluation does not always agree with the human evaluation. We propose a new evaluation protocol that is based on predicted history and question rewriting. Our experiments show that the new protocol better reflects real-world performance compared to the original automatic evaluation. We also provide the new evaluation protocol code in the following. 24 | 25 | ![Different evaluation protocols](figs/example.png) 26 | 27 | ## Human Evaluation Dataset 28 | You can download the human annotation dataset from `data/human_annotation_data.json`. The json file is structured as follows: 29 | 30 | ``` 31 | {"data": 32 | [{ 33 | # The model evaluated. One of `bert4quac`, `graphflow`, `ham`, `excord` 34 | "model_name": "graphflow", 35 | 36 | # The passage used in this conversation. 37 | "context": "Azaria wrote and directed the 2004 short film Nobody's Perfect, ...", 38 | 39 | # The ID from the original QuAC dataset. 40 | "dialog_id": "C_f0555dd820d84564a189474bbfffd4a1_1_0", 41 | 42 | # The conversation, which contains a list of QA pairs. 43 | "qas": [{ 44 | 45 | # The number of the turn 46 | "turn_id": 0, 47 | 48 | # The question from the human annotator 49 | "question": "What is some voice work he's done?", 50 | 51 | # The answer from the model 52 | "answer": "Azaria wrote and directed the 2004 short film Nobody's Perfect,", 53 | 54 | # Whether the question is valid (annotated by our human annotator) 55 | "valid": "y", 56 | 57 | # Whether the question is answerable (annotated by our human annotator) 58 | "answerable": "y", 59 | 60 | # Whether the model's answer is correct (annotated by our human annotator) 61 | "correct": "y", 62 | 63 | # Human annotator selects an answer, ONLY IF they marked the answer as incorrect 64 | "gold_anno": ["Azaria wrote and directed ..."] 65 | }, 66 | ... 67 | ] 68 | }, 69 | ... 70 | ] 71 | ``` 72 | 73 | ## Automatic model evaluation interface 74 | 75 | We provide a convenient interface to test model performance on a few evaluation protocols compared in our paper, including `Auto-Pred`, `Auto-Replace` and our proposed evaluation protocol, `Auto-Rewrite`, which better demonstrates models' performance in human-model conversations. Please refer to our paper for more details. Following is a figure describing how Auto-Rewrite works. 76 | 77 | ![Auto-rewrite](figs/autorewrite.png) 78 | 79 | ## Setup 80 | 81 | ### Install dependencies 82 | 83 | Please install all dependency packages using the following command: 84 | ```bash 85 | pip install -r requirements.txt 86 | ``` 87 | 88 | ### Download the datasets 89 | 90 | Our experiments use [QuAC dataset](https://quac.ai) for passages and conversations, and the test set of [CANARD dataset](https://sites.google.com/view/qanta/projects/canard) for context-independent questions in `Auto-Replace`. 
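The human annotation file itself ships with this repository (`data/human_annotation_data.json`) and needs no separate download. As a quick sanity check, it can be inspected with a few lines of Python; the snippet below is only an illustrative sketch (the per-model aggregation is ours and not part of the evaluation scripts) and assumes the `"y"` flag values shown in the schema above:

```python
# Illustrative sketch: per-model fraction of answers judged correct by annotators.
# Assumes the JSON schema documented in the Human Evaluation Dataset section;
# only turns marked valid and answerable are counted.
import json
from collections import defaultdict

with open("data/human_annotation_data.json") as f:
    dialogs = json.load(f)["data"]

correct, total = defaultdict(int), defaultdict(int)
for dialog in dialogs:
    model = dialog["model_name"]  # one of bert4quac, graphflow, ham, excord
    for qa in dialog["qas"]:
        if qa["valid"] == "y" and qa["answerable"] == "y":
            total[model] += 1
            correct[model] += int(qa["correct"] == "y")

for model in sorted(total):
    print(f"{model}: {correct[model]}/{total[model]} answers judged correct")
```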
91 | 92 | ## Evaluating existing models 93 | 94 | We provide our implementations for the four models that we used in our paper: BERT, [GraphFlow](https://www.ijcai.org/Proceedings/2020/171), [HAM](https://dl.acm.org/doi/abs/10.1145/3357384.3357905), [ExCorD](https://aclanthology.org/2021.acl-long.478/). We modified exisiting implementation online to use model predictions as conversation history. Below are the instructions to run evaluation script on each of these models. 95 | 96 | ### BERT 97 | We implemented and trained our own BERT model. 98 | ```bash 99 | # Run Training 100 | python run_quac_train.py \ 101 | --type bert \ 102 | --model_name_or_path bert-base-uncased \ 103 | --do_train \ 104 | --output_dir ${directory_to_save_model} \ 105 | --overwrite_output_dir \ 106 | --train_file ${path_to_quac_train_file} \ 107 | --train_batch_size 8 \ 108 | --gradient_accumulation_steps 4 \ 109 | --max_seq_length 512 \ 110 | --learning_rate 3e-5 \ 111 | --history_len 2 \ 112 | --warmup_proportion 0.1 \ 113 | --max_grad_norm -1 \ 114 | --weight_decay 0.01 \ 115 | --rationale_beta 0 \ # important for BERT 116 | 117 | # Run Evaluation (Auto-Rewrite as example) 118 | python run_quac_eval.py \ 119 | --type bert \ 120 | --output_dir ${directory-to-model-checkpoint} \ 121 | --write_dir ${directory-to-write-evaluation-result} \ 122 | --predict_file val_v0.2.json \ 123 | --max_seq_length 512 \ 124 | --doc_stride 128 \ 125 | --max_query_length 64 \ 126 | --match_metric f1 \ 127 | --add_background \ 128 | --skip_entity \ 129 | --rewrite \ 130 | --start_i ${index_of_first_passage_to_eval} \ 131 | --end_i ${index_of_last_passage_to_eval_exclusive} \ 132 | ``` 133 | 134 | 135 | ### GraphFlow 136 | We did not find an uploaded model checkpoint so we trained our own using [their training script](https://github.com/hugochan/GraphFlow). 
```bash
# Download Stanford CoreNLP package
wget https://nlp.stanford.edu/software/stanford-corenlp-latest.zip
unzip stanford-corenlp-latest.zip
rm -f stanford-corenlp-latest.zip

# Start Stanford CoreNLP server
java -mx4g -cp "${directory_to_stanford_corenlp_package}" edu.stanford.nlp.pipeline.StanfordCoreNLPServer -port 9000 &

# Run Evaluation (Auto-Rewrite as example)
python run_quac_eval.py \
  --type graphflow \
  --predict_file ${path-to-annotated-dev-json-file} \
  --output_dir ${directory-to-model-checkpoint} \
  --saved_vocab_file ${directory-to-saved-model-vocab} \
  --pretrained ${directory-to-model-checkpoint} \
  --write_dir ${directory-to-write-evaluation-result} \
  --match_metric f1 \
  --add_background \
  --skip_entity \
  --rewrite \
  --fix_vocab_embed \
  --f_qem \
  --f_pos \
  --f_ner \
  --use_ques_marker \
  --use_gnn \
  --temporal_gnn \
  --use_bert \
  --use_bert_weight \
  --shuffle \
  --out_predictions \
  --predict_raw_text \
  --out_pred_in_folder \
  --optimizer adamax \
  --start_i ${index_of_first_passage_to_eval} \
  --end_i ${index_of_last_passage_to_eval_exclusive}
```

### HAM
The original model checkpoint can be downloaded from [CodaLab](https://worksheets.codalab.org/rest/bundles/0x5c08cb0fb90c4afd8a2811bb63023cce/contents/blob/).

```bash
# Run Evaluation (Auto-Replace as example)
python run_quac_eval.py \
  --type ham \
  --output_dir ${directory-to-model-checkpoint} \
  --write_dir ${directory-to-write-evaluation-result} \
  --predict_file val_v0.2.json \
  --max_seq_length 512 \
  --doc_stride 128 \
  --max_query_length 64 \
  --do_lower_case \
  --history_len 6 \
  --match_metric f1 \
  --add_background \
  --skip_entity \
  --replace \
  --init_checkpoint ${directory-to-model-checkpoint}/model_52000.ckpt \
  --bert_config_file ${directory-to-pretrained-bert-large-uncased}/bert_config.json \
  --vocab_file ${directory-to-model-checkpoint}/vocab.txt \
  --MTL_mu 0.8 \
  --MTL_lambda 0.1 \
  --mtl_input reduce_mean \
  --max_answer_length 40 \
  --max_considered_history_turns 4 \
  --bert_hidden 1024 \
  --fine_grained_attention \
  --better_hae \
  --MTL \
  --use_history_answer_marker \
  --start_i ${index_of_first_passage_to_eval} \
  --end_i ${index_of_last_passage_to_eval_exclusive}
```

### ExCorD
The original model checkpoint can be downloaded from [their repo](https://drive.google.com/file/d/1Xf0-XUvGi7jgiAAdA5BQLk7p5ikc_wOl/view?usp=sharing).

```bash
# Run Evaluation (Auto-Rewrite as example)
python run_quac_eval.py \
  --type excord \
  --output_dir ${directory-to-model-checkpoint} \
  --write_dir ${directory-to-write-evaluation-result} \
  --predict_file val_v0.2.json \
  --max_seq_length 512 \
  --doc_stride 128 \
  --max_query_length 64 \
  --match_metric f1 \
  --add_background \
  --skip_entity \
  --rewrite \
  --start_i ${index_of_first_passage_to_eval} \
  --end_i ${index_of_last_passage_to_eval_exclusive}
```

## Evaluating your own model
One can follow our existing implementations for the four models to implement evaluation for their own models.
To do so, please add a directory under `models` and write a customized model class following the template `interface.py` and our example implementations. 238 | 239 | ## Citation 240 | 241 | ```bibtex 242 | @inproceedings{li2022ditch, 243 | title = "Ditch the Gold Standard: Re-evaluating Conversational Question Answering", 244 | author = "Li, Huihan and 245 | Gao, Tianyu and 246 | Goenka, Manan and 247 | Chen, Danqi", 248 | booktitle = "Proceedings of the 60th Annual Meeting of the Association for Computational Linguistics (Volume 1: Long Papers)", 249 | year = "2022", 250 | url = "https://aclanthology.org/2022.acl-long.555", 251 | pages = "8074--8085", 252 | } 253 | ``` 254 | -------------------------------------------------------------------------------- /__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/princeton-nlp/EvalConvQA/fbf34196b4d8e39d4ecfe36353c9e394101af5eb/__init__.py -------------------------------------------------------------------------------- /coreference_resolution.py: -------------------------------------------------------------------------------- 1 | from allennlp.predictors.predictor import Predictor 2 | import allennlp_models.tagging 3 | from ncr.replace_corefs import resolve 4 | from collections import Counter 5 | import string 6 | import difflib 7 | 8 | from numpy import True_ 9 | 10 | predictor = Predictor.from_path("https://storage.googleapis.com/allennlp-public-models/coref-spanbert-large-2020.02.27.tar.gz") 11 | PRONOUNS = { 12 | 'all', 'another', 'any', 'anybody', 'anyone', 'anything', 'as', 'aught', 'both', 'each other', 'each', 'either', 13 | 'enough', 'everybody', 'everyone', 'everything', 'few', 'he', 'her', 'hers', 'herself', 'him', 'himself', 'his', 14 | 'i', 'idem', 'it', 'its', 'itself', 'many', 'me', 'mine', 'most', 'my', 'myself', 'naught', 'neither', 'no one', 15 | 'nobody', 'none', 'nothing', 'nought', 'one another', 'one', 'other', 'others', 'ought', 'our', 'ours', 'ourself', 16 | 'ourselves', 'several', 'she', 'some', 'somebody', 'someone', 'something', 'somewhat', 'such', 'suchlike', 'that', 17 | 'thee', 'their', 'theirs', 'theirself', 'theirselves', 'them', 'themself', 'themselves', 'there', 'these', 'they', 18 | 'thine', 'this', 'those', 'thou', 'thy', 'thyself', 'us', 'we', 'what', 'whatever', 'whatnot', 'whatsoever', 19 | 'whence', 'where', 'whereby', 'wherefrom', 'wherein', 'whereinto', 'whereof', 'whereon', 'wheresoever', 'whereto', 20 | 'whereunto', 'wherever', 'wherewith', 'wherewithal', 'whether', 'which', 'whichever', 'whichsoever', 'who', 21 | 'whoever', 'whom', 'whomever', 'whomso', 'whomsoever', 'whose', 'whosesoever', 'whosever', 'whoso', 'whosoever', 22 | 'ye', 'yon', 'yonder', 'you', 'your', 'yours', 'yourself', 'yourselves' 23 | } 24 | ARTICLES = {'the','this','that','these','those','a','an'} 25 | 26 | def get_overlap(s1, s2): 27 | s = difflib.SequenceMatcher(None, s1, s2) 28 | pos_a, pos_b, size = s.find_longest_match(0, len(s1), 0, len(s2)) 29 | return s1[pos_a:pos_a+size], pos_a+size,pos_b+size 30 | 31 | def resolve_coreference(text = ""): 32 | 33 | result_dict = predictor.predict(document=text) 34 | 35 | text_formatted= { 36 | 'document': result_dict['document'], 37 | 'clusters': result_dict['clusters'] 38 | } 39 | 40 | resolved_toks = resolve(text_formatted['document'], text_formatted['clusters']) 41 | replaced_text = ' '.join(resolved_toks) 42 | return replaced_text 43 | 44 | def is_entity(word): 45 | 46 | tokens = word.split() 47 | tokens = [t 
for t in tokens if (t not in ARTICLES)] 48 | is_entity = True if all([t[0].isupper() for t in tokens]) else False 49 | return is_entity 50 | 51 | def find_coreference_f1s(text1="", text2="", skip_entity=True): 52 | 53 | result_dict1 = predictor.predict(document=text1) 54 | result_dict2 = predictor.predict(document=text2) 55 | text_formatted1 = { 56 | 'document': result_dict1['document'], 57 | 'clusters': result_dict1['clusters'] 58 | } 59 | text_formatted2 = { 60 | 'document': result_dict2['document'], 61 | 'clusters': result_dict2['clusters'] 62 | } 63 | 64 | print("Cluster1:",text_formatted1['clusters']) 65 | print("Cluster2:",text_formatted2['clusters']) 66 | q_start1 = max(idx for idx, val in enumerate( 67 | text_formatted1['document']) if val == '>')+1 68 | subs1 = [] 69 | for cluster in text_formatted1['clusters']: 70 | for r in cluster: 71 | if r[0] >= q_start1: 72 | subs1.append(cluster) 73 | break 74 | subs1.sort(key=lambda c: c[-1][0]) 75 | print("Subs1:",subs1) 76 | 77 | q_start2 = max(idx for idx, val in enumerate( 78 | text_formatted2['document']) if val == '>')+1 79 | subs2 = [] 80 | for cluster in text_formatted2['clusters']: 81 | for r in cluster: 82 | if r[0] >= q_start2: 83 | subs2.append(cluster) 84 | break 85 | 86 | subs2.sort(key=lambda c: c[-1][0]) 87 | print("Subs2:",subs2) 88 | 89 | nouns1 = [] 90 | for cluster in subs1: 91 | cluster_strings = list(map(lambda x: " ".join( 92 | text_formatted1['document'][x[0]:x[1]+1]), cluster)) 93 | if skip_entity and is_entity(cluster_strings[-1]): 94 | continue 95 | 96 | set_strings = set([s for s in cluster_strings if s.lower() not in PRONOUNS]) 97 | 98 | span_lens = list(map(len, set_strings)) 99 | head_span_idx = None 100 | for i, span_len in enumerate(span_lens): 101 | if span_len > 0: 102 | head_span_idx = i 103 | break 104 | if head_span_idx is None: 105 | nouns1.append(cluster_strings[0]) 106 | else: 107 | nouns1.append(list(set_strings)[head_span_idx]) 108 | 109 | nouns2 = [] 110 | for cluster in subs2: 111 | cluster_strings = list(map(lambda x: " ".join( 112 | text_formatted2['document'][x[0]:x[1]+1]), cluster)) 113 | if skip_entity and is_entity(cluster_strings[-1]): 114 | continue 115 | 116 | set_strings = set([s for s in cluster_strings if s.lower() not in PRONOUNS]) 117 | 118 | span_lens = list(map(len, set_strings)) 119 | head_span_idx = None 120 | for i, span_len in enumerate(span_lens): 121 | if span_len > 0: 122 | head_span_idx = i 123 | break 124 | if head_span_idx is None: 125 | nouns2.append(cluster_strings[0]) 126 | else: 127 | nouns2.append(list(set_strings)[head_span_idx]) 128 | 129 | resolved_toks1 = resolve( 130 | text_formatted1['document'], text_formatted1['clusters']) 131 | 132 | resolved_toks2 = resolve( 133 | text_formatted2['document'], text_formatted2['clusters']) 134 | 135 | f1s = [] 136 | 137 | def f1(lst1, lst2): 138 | common = Counter(lst1) & Counter(lst2) 139 | num_same = sum(common.values()) 140 | if num_same == 0: 141 | f1 = 0 142 | else: 143 | precision = 1.0 * num_same / len(lst2) 144 | recall = 1.0 * num_same / len(lst1) 145 | f1 = (2 * precision * recall) / (precision + recall) 146 | return f1 147 | 148 | if len(nouns1) == 0 or len(nouns2) == 0 or len(nouns1) != len(nouns2): 149 | f1s = [0] * max(len(nouns1), len(nouns2)) 150 | else: 151 | short_list, long_list = nouns1, nouns2 152 | for i in range(len(long_list)): 153 | max_f1 = 0 154 | for j in range(len(short_list)): 155 | f1_temp = f1(long_list[i].lower().split(), short_list[j].lower().split()) 156 | if f1_temp > max_f1: 157 | max_f1 
= f1_temp 158 | f1s.append(max_f1) 159 | 160 | return f1s, " ".join(resolved_toks1), " ".join(resolved_toks2) 161 | 162 | if __name__ == "__main__": 163 | s1 = "Why did he fight the Dutch? Dutch colonial rule was becoming unpopular among local farmers because of tax rises, crop failures Is there any interesting information? Diponogoro was widely believed to be the Ratu Adil, the just ruler predicted in the Pralembang Jayabaya. What is the Pralembrang Jayabaya?" 164 | s2 = "Why did he fight the Dutch? due to their lack of coherent strategy and commitment in fighting Diponegoro's guerrilla warfare. Is there any interesting information? CANNOTANSWER What is the Pralembrang Jayabaya?" 165 | s3 = "How many shows did she do after her comeback? On 12 July 2012, Reddy returned to the musical stage at Croce's Jazz Bar in San Diego and for a benefit concert for the arts at St. Genevieve High School Did she perform anywhere after that? Reddy appeared in downtown Los Angeles at the 2017 Women's March on January 21. What did she sing at the Womens March?" 166 | s4 = "How many shows did she do after her comeback? CANNOTANSWER Did she perform anywhere after that? Reddy performed at the Paramount nightclub at The Crown & Anchor in Provincetown on 13 October 2013. What did she sing at the Womens March?" 167 | print(find_coreference_f1s(s3,s4)) 168 | 169 | 170 | -------------------------------------------------------------------------------- /data/Annotation Guideline for _evaluating information-seeking conversations_.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/princeton-nlp/EvalConvQA/fbf34196b4d8e39d4ecfe36353c9e394101af5eb/data/Annotation Guideline for _evaluating information-seeking conversations_.pdf -------------------------------------------------------------------------------- /data/Qualification test for evaluation task.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/princeton-nlp/EvalConvQA/fbf34196b4d8e39d4ecfe36353c9e394101af5eb/data/Qualification test for evaluation task.pdf -------------------------------------------------------------------------------- /figs/autorewrite.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/princeton-nlp/EvalConvQA/fbf34196b4d8e39d4ecfe36353c9e394101af5eb/figs/autorewrite.png -------------------------------------------------------------------------------- /figs/example.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/princeton-nlp/EvalConvQA/fbf34196b4d8e39d4ecfe36353c9e394101af5eb/figs/example.png -------------------------------------------------------------------------------- /interface.py: -------------------------------------------------------------------------------- 1 | # import transformers dependencies 2 | 3 | # import model dependencies 4 | 5 | class QAModel(): 6 | def __init__(self, args): 7 | self.args = args 8 | 9 | # load model 10 | self.model = None 11 | 12 | # Initialize conversation history 13 | self.QA_history = [] 14 | 15 | def tokenizer(self): 16 | 17 | # load tokenizer 18 | self.tokenizer = None 19 | return self.tokenizer 20 | 21 | def load_partial_examples(self, file_name): 22 | 23 | # Pre-load partially filled examples from train/dev file 24 | # Loaded to a list of passages, each with a list of QAs 25 | # We will construct the augmented question later 26 | 
partial_examples = [] 27 | return partial_examples 28 | 29 | def predict_one_automatic_turn(self, partial_example, unique_id, example_idx): 30 | 31 | # Construct the augmented question here 32 | question = partial_example.question_text 33 | 34 | # Run prediction here. Your model might predict these fields. 35 | prediction_string = "" 36 | prediction_start = 0 37 | prediction_end = -1 38 | 39 | # Append predictions to QA history as your model will use it 40 | self.QA_history.append((example_idx, question, (prediction_string, prediction_start, prediction_end))) 41 | return prediction_string, unique_id + 1 42 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/princeton-nlp/EvalConvQA/fbf34196b4d8e39d4ecfe36353c9e394101af5eb/models/__init__.py -------------------------------------------------------------------------------- /models/bert/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/princeton-nlp/EvalConvQA/fbf34196b4d8e39d4ecfe36353c9e394101af5eb/models/bert/__init__.py -------------------------------------------------------------------------------- /models/bert/interface.py: -------------------------------------------------------------------------------- 1 | from models.bert.modeling import BertForQuAC, RobertaForQuAC 2 | from transformers import AutoTokenizer 3 | from models.bert.run_quac_dataset_utils import read_partial_quac_examples_extern, read_one_quac_example_extern, convert_one_example_to_features, recover_predicted_answer, RawResult 4 | 5 | import numpy as np 6 | import torch 7 | import torch.nn as nn 8 | from torch.nn import CrossEntropyLoss 9 | from torch.utils.data import (DataLoader, RandomSampler, TensorDataset) 10 | 11 | class BertOrg(): 12 | def __init__(self, args): 13 | self.args = args 14 | self.model = BertForQuAC.from_pretrained(self.args.model_name_or_path) 15 | self.QA_history = [] 16 | torch.manual_seed(args.seed) 17 | self.device = torch.device("cuda" if torch.cuda.is_available() 18 | and not args.no_cuda else "cpu") 19 | self.model = self.model.to(self.device) 20 | 21 | def tokenizer(self): 22 | tokenizer = AutoTokenizer.from_pretrained( 23 | self.args.model_name_or_path, do_lower_case=self.args.do_lower_case) 24 | return tokenizer 25 | 26 | def load_partial_examples(self, partial_eval_examples_file): 27 | paragraphs = read_partial_quac_examples_extern(partial_eval_examples_file) 28 | return paragraphs 29 | 30 | def predict_one_automatic_turn(self, partial_example, unique_id, example_idx, tokenizer): 31 | question = partial_example.question_text 32 | turn = int(partial_example.qas_id.split("#")[1]) 33 | example = read_one_quac_example_extern(partial_example, self.QA_history, history_len=2, add_QA_tag=False) 34 | 35 | curr_eval_features, next_unique_id= convert_one_example_to_features(example=example, unique_id=unique_id, example_index=example_idx, tokenizer=tokenizer, max_seq_length=self.args.max_seq_length, 36 | doc_stride=self.args.doc_stride, max_query_length=self.args.max_query_length) 37 | all_input_ids = torch.tensor([f.input_ids for f in curr_eval_features], 38 | dtype=torch.long) 39 | all_input_mask = torch.tensor([f.input_mask for f in curr_eval_features], 40 | dtype=torch.long) 41 | all_segment_ids = torch.tensor([f.segment_ids for f in curr_eval_features], 42 | dtype=torch.long) 43 | all_feature_index = 
torch.arange(all_input_ids.size(0), 44 | dtype=torch.long) 45 | eval_data = TensorDataset(all_input_ids, all_input_mask, 46 | all_segment_ids, all_feature_index) 47 | # Run prediction for full data 48 | 49 | eval_dataloader = DataLoader(eval_data, 50 | sampler=None, 51 | batch_size=1) 52 | curr_results = [] 53 | # Run prediction for current example 54 | for input_ids, input_mask, segment_ids, feature_indices in eval_dataloader: 55 | 56 | input_ids = input_ids.to(self.device) 57 | input_mask = input_mask.to(self.device) 58 | segment_ids = segment_ids.to(self.device) 59 | print(type(input_ids[0]), type(input_mask[0]), type(segment_ids[0])) 60 | # Assume the logits are a list of one item 61 | with torch.no_grad(): 62 | batch_start_logits, batch_end_logits, batch_yes_logits, batch_no_logits, batch_unk_logits = self.model( 63 | input_ids, segment_ids, input_mask) 64 | for i, feature_index in enumerate(feature_indices): 65 | start_logits = batch_start_logits[i].detach().cpu().tolist() 66 | end_logits = batch_end_logits[i].detach().cpu().tolist() 67 | yes_logits = batch_yes_logits[i].detach().cpu().tolist() 68 | no_logits = batch_no_logits[i].detach().cpu().tolist() 69 | unk_logits = batch_unk_logits[i].detach().cpu().tolist() 70 | eval_feature = curr_eval_features[feature_index.item()] 71 | unique_id = int(eval_feature.unique_id) 72 | curr_results.append( 73 | RawResult(unique_id=unique_id, 74 | start_logits=start_logits, 75 | end_logits=end_logits, 76 | yes_logits=yes_logits, 77 | no_logits=no_logits, 78 | unk_logits=unk_logits)) 79 | predicted_answer = recover_predicted_answer( 80 | example=example, features=curr_eval_features, results=curr_results, tokenizer=tokenizer, n_best_size=self.args.n_best_size, max_answer_length=self.args.max_answer_length, 81 | do_lower_case=self.args.do_lower_case, verbose_logging=self.args.verbose_logging) 82 | self.QA_history.append((turn, question, (predicted_answer, None, None))) 83 | return predicted_answer, next_unique_id 84 | 85 | class RobertaOrg(): 86 | def __init__(self, args, device): 87 | self.args = args 88 | self.model = RobertaForQuAC.from_pretrained(self.args.model_name_or_path) 89 | self.QA_history = [] 90 | self.device = device 91 | 92 | def tokenizer(self): 93 | tokenizer = AutoTokenizer.from_pretrained( 94 | self.args.model_name_or_path, do_lower_case=self.args.do_lower_case) 95 | return tokenizer 96 | 97 | def load_partial_examples(self, cached_partial_eval_examples_file): 98 | paragraphs = read_partial_quac_examples_extern(cached_partial_eval_examples_file) 99 | return paragraphs 100 | 101 | def predict_one_human_turn(self, paragraph, question): 102 | return 103 | 104 | def predict_one_automatic_turn(self, partial_example, unique_id, example_idx, tokenizer): 105 | question = partial_example.question_text 106 | turn = int(partial_example.qas_id.split("#")[1]) 107 | example = read_one_quac_example_extern(partial_example, self.QA_history, history_len=2, add_QA_tag=False) 108 | 109 | curr_eval_features, next_unique_id= convert_one_example_to_features(example=example, unique_id=unique_id, example_index=example_idx, tokenizer=tokenizer, max_seq_length=self.args.max_seq_length, 110 | doc_stride=self.args.doc_stride, max_query_length=self.args.max_query_length) 111 | all_input_ids = torch.tensor([f.input_ids for f in curr_eval_features], 112 | dtype=torch.long) 113 | all_input_mask = torch.tensor([f.input_mask for f in curr_eval_features], 114 | dtype=torch.long) 115 | all_segment_ids = torch.tensor([f.segment_ids for f in curr_eval_features], 116 | 
dtype=torch.long) 117 | all_feature_index = torch.arange(all_input_ids.size(0), 118 | dtype=torch.long) 119 | eval_data = TensorDataset(all_input_ids, all_input_mask, 120 | all_segment_ids, all_feature_index) 121 | # Run prediction for full data 122 | 123 | eval_dataloader = DataLoader(eval_data, 124 | sampler=None, 125 | batch_size=1) 126 | curr_results = [] 127 | # Run prediction for current example 128 | for input_ids, input_mask, segment_ids, feature_indices in eval_dataloader: 129 | 130 | input_ids = input_ids.to(self.device) 131 | input_mask = input_mask.to(self.device) 132 | segment_ids = segment_ids.to(self.device) 133 | # Assume the logits are a list of one item 134 | with torch.no_grad(): 135 | batch_start_logits, batch_end_logits, batch_yes_logits, batch_no_logits, batch_unk_logits = self.model( 136 | input_ids, segment_ids, input_mask) 137 | for i, feature_index in enumerate(feature_indices): 138 | start_logits = batch_start_logits[i].detach().cpu().tolist() 139 | end_logits = batch_end_logits[i].detach().cpu().tolist() 140 | yes_logits = batch_yes_logits[i].detach().cpu().tolist() 141 | no_logits = batch_no_logits[i].detach().cpu().tolist() 142 | unk_logits = batch_unk_logits[i].detach().cpu().tolist() 143 | eval_feature = curr_eval_features[feature_index.item()] 144 | unique_id = int(eval_feature.unique_id) 145 | curr_results.append( 146 | RawResult(unique_id=unique_id, 147 | start_logits=start_logits, 148 | end_logits=end_logits, 149 | yes_logits=yes_logits, 150 | no_logits=no_logits, 151 | unk_logits=unk_logits)) 152 | predicted_answer = recover_predicted_answer( 153 | example=example, features=curr_eval_features, results=curr_results, n_best_size=self.args.n_best_size, max_answer_length=self.args.max_answer_length, 154 | do_lower_case=self.args.do_lower_case, verbose_logging=self.args.verbose_logging) 155 | self.QA_history.append((turn, question, (predicted_answer, None, None))) 156 | return predicted_answer, next_unique_id -------------------------------------------------------------------------------- /models/bert/modeling.py: -------------------------------------------------------------------------------- 1 | from transformers import BertModel, BertPreTrainedModel, RobertaForQuestionAnswering 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch.nn import CrossEntropyLoss 5 | import random 6 | import torch 7 | 8 | class Multi_linear_layer(nn.Module): 9 | def __init__(self, 10 | n_layers, 11 | input_size, 12 | hidden_size, 13 | output_size, 14 | activation=None): 15 | super(Multi_linear_layer, self).__init__() 16 | self.linears = nn.ModuleList() 17 | self.linears.append(nn.Linear(input_size, hidden_size)) 18 | for _ in range(1, n_layers - 1): 19 | self.linears.append(nn.Linear(hidden_size, hidden_size)) 20 | self.linears.append(nn.Linear(hidden_size, output_size)) 21 | self.activation = getattr(F, activation) 22 | 23 | def forward(self, x): 24 | for linear in self.linears[:-1]: 25 | x = self.activation(linear(x)) 26 | linear = self.linears[-1] 27 | x = linear(x) 28 | return x 29 | 30 | class BertForQuAC(BertPreTrainedModel): 31 | def __init__( 32 | self, 33 | config, 34 | output_attentions=False, 35 | keep_multihead_output=False, 36 | n_layers=2, 37 | activation='relu', 38 | beta=100, 39 | ): 40 | super(BertForQuAC, self).__init__(config) 41 | self.output_attentions = output_attentions 42 | self.bert = BertModel(config) 43 | hidden_size = config.hidden_size 44 | self.rational_l = Multi_linear_layer(n_layers, hidden_size, 45 | hidden_size, 1, 
activation) 46 | self.logits_l = Multi_linear_layer(n_layers, hidden_size, hidden_size, 47 | 2, activation) 48 | self.unk_l = Multi_linear_layer(n_layers, hidden_size, hidden_size, 1, 49 | activation) 50 | self.attention_l = Multi_linear_layer(n_layers, hidden_size, 51 | hidden_size, 1, activation) 52 | self.yn_l = Multi_linear_layer(n_layers, hidden_size, hidden_size, 2, 53 | activation) 54 | self.beta = beta 55 | 56 | self.init_weights() 57 | 58 | def forward( 59 | self, 60 | input_ids, 61 | token_type_ids=None, 62 | attention_mask=None, 63 | start_positions=None, 64 | end_positions=None, 65 | rational_mask=None, 66 | cls_idx = None, 67 | head_mask=None, 68 | ): 69 | # mask some words on inputs_ids 70 | # if self.training and self.mask_p > 0: 71 | # batch_size = input_ids.size(0) 72 | # for i in range(batch_size): 73 | # len_c, len_qc = token_type_ids[i].sum( 74 | # dim=0).detach().item(), attention_mask[i].sum( 75 | # dim=0).detach().item() 76 | # masked_idx = random.sample(range(len_qc - len_c, len_qc), 77 | # int(len_c * self.mask_p)) 78 | # input_ids[i, masked_idx] = 100 79 | 80 | outputs = self.bert( 81 | input_ids, 82 | token_type_ids=token_type_ids, 83 | attention_mask=attention_mask, 84 | # output_all_encoded_layers=False, 85 | head_mask=head_mask, 86 | ) 87 | # print(outputs) 88 | if self.output_attentions: 89 | all_attentions, sequence_output, cls_outputs = outputs 90 | else: 91 | final_hidden=outputs.last_hidden_state 92 | pooled_output =outputs.pooler_output 93 | # print("Final_hidden:",final_hidden) 94 | rational_logits = self.rational_l(final_hidden) 95 | rational_logits = torch.sigmoid(rational_logits) 96 | 97 | final_hidden = final_hidden * rational_logits 98 | 99 | logits = self.logits_l(final_hidden) 100 | 101 | start_logits, end_logits = logits.split(1, dim=-1) 102 | 103 | start_logits, end_logits = start_logits.squeeze( 104 | -1), end_logits.squeeze(-1) 105 | 106 | segment_mask = token_type_ids.type(final_hidden.dtype) 107 | 108 | rational_logits = rational_logits.squeeze(-1) * segment_mask 109 | 110 | start_logits = start_logits * rational_logits 111 | 112 | end_logits = end_logits * rational_logits 113 | 114 | unk_logits = self.unk_l(pooled_output) 115 | 116 | attention = self.attention_l(final_hidden).squeeze(-1) 117 | 118 | attention.data.masked_fill_(attention_mask.eq(0), -float('inf')) 119 | 120 | attention = F.softmax(attention, dim=-1) 121 | 122 | attention_pooled_output = (attention.unsqueeze(-1) * 123 | final_hidden).sum(dim=-2) 124 | 125 | yn_logits = self.yn_l(attention_pooled_output) 126 | 127 | yes_logits, no_logits = yn_logits.split(1, dim=-1) 128 | 129 | start_logits.data.masked_fill_(attention_mask.eq(0), -float('inf')) 130 | end_logits.data.masked_fill_(attention_mask.eq(0), -float('inf')) 131 | 132 | new_start_logits = torch.cat( 133 | (yes_logits, no_logits, unk_logits, start_logits), dim=-1) 134 | new_end_logits = torch.cat( 135 | (yes_logits, no_logits, unk_logits, end_logits), dim=-1) 136 | 137 | if start_positions is not None and end_positions is not None: 138 | 139 | start_positions, end_positions = start_positions + cls_idx, end_positions + cls_idx 140 | 141 | # If we are on multi-GPU, split add a dimension 142 | if len(start_positions.size()) > 1: 143 | start_positions = start_positions.squeeze(-1) 144 | if len(end_positions.size()) > 1: 145 | end_positions = end_positions.squeeze(-1) 146 | # sometimes the start/end positions are outside our model inputs, we ignore these terms 147 | ignored_index = new_start_logits.size(1) 148 | 
start_positions.clamp_(0, ignored_index) 149 | end_positions.clamp_(0, ignored_index) 150 | 151 | span_loss_fct = CrossEntropyLoss(ignore_index=ignored_index) 152 | 153 | start_loss = span_loss_fct(new_start_logits, start_positions) 154 | end_loss = span_loss_fct(new_end_logits, end_positions) 155 | 156 | # rational part 157 | alpha = 0.25 158 | gamma = 2. 159 | rational_mask = rational_mask.type(final_hidden.dtype) 160 | 161 | rational_loss = -alpha * ( 162 | (1 - rational_logits)**gamma 163 | ) * rational_mask * torch.log(rational_logits + 1e-7) - ( 164 | 1 - alpha) * (rational_logits**gamma) * ( 165 | 1 - rational_mask) * torch.log(1 - rational_logits + 1e-7) 166 | 167 | rational_loss = (rational_loss * 168 | segment_mask).sum() / segment_mask.sum() 169 | # end 170 | 171 | assert not torch.isnan(rational_loss) 172 | 173 | total_loss = (start_loss + 174 | end_loss) / 2 + rational_loss * self.beta 175 | return total_loss 176 | 177 | return start_logits, end_logits, yes_logits, no_logits, unk_logits 178 | 179 | 180 | class RobertaForQuAC(RobertaForQuestionAnswering): 181 | def __init__( 182 | self, 183 | config, 184 | output_attentions=False, 185 | keep_multihead_output=False, 186 | n_layers=2, 187 | activation='relu', 188 | beta=100, 189 | ): 190 | super(RobertaForQuAC, self).__init__(config) 191 | self.output_attentions = output_attentions 192 | hidden_size = config.hidden_size 193 | self.rational_l = Multi_linear_layer(n_layers, hidden_size, 194 | hidden_size, 1, activation) 195 | self.logits_l = Multi_linear_layer(n_layers, hidden_size, hidden_size, 196 | 2, activation) 197 | self.unk_l = Multi_linear_layer(n_layers, hidden_size, hidden_size, 1, 198 | activation) 199 | self.attention_l = Multi_linear_layer(n_layers, hidden_size, 200 | hidden_size, 1, activation) 201 | self.yn_l = Multi_linear_layer(n_layers, hidden_size, hidden_size, 2, 202 | activation) 203 | self.beta = beta 204 | 205 | self.init_weights() 206 | 207 | def forward( 208 | self, 209 | input_ids, 210 | token_type_ids=None, 211 | attention_mask=None, 212 | start_positions=None, 213 | end_positions=None, 214 | rational_mask=None, 215 | cls_idx = None, 216 | head_mask=None, 217 | ): 218 | # mask some words on inputs_ids 219 | # if self.training and self.mask_p > 0: 220 | # batch_size = input_ids.size(0) 221 | # for i in range(batch_size): 222 | # len_c, len_qc = token_type_ids[i].sum( 223 | # dim=0).detach().item(), attention_mask[i].sum( 224 | # dim=0).detach().item() 225 | # masked_idx = random.sample(range(len_qc - len_c, len_qc), 226 | # int(len_c * self.mask_p)) 227 | # input_ids[i, masked_idx] = 100 228 | 229 | outputs = self.roberta( 230 | input_ids, 231 | token_type_ids=None, # warning: should we use token_type_ids in roberta? 
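            # Note: RoBERTa is pretrained with a single token type and has no segment
            # embeddings, so passing token_type_ids=None to the encoder is the usual choice.
            # The token_type_ids argument of this forward() is still needed below, where it
            # serves as the passage segment mask for the rationale logits.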
232 | attention_mask=attention_mask, 233 | # output_all_encoded_layers=False, 234 | head_mask=head_mask, 235 | ) 236 | if self.output_attentions: 237 | all_attentions, sequence_output, cls_outputs = outputs 238 | else: 239 | final_hidden, pooled_output = outputs 240 | 241 | rational_logits = self.rational_l(final_hidden) 242 | rational_logits = torch.sigmoid(rational_logits) 243 | 244 | final_hidden = final_hidden * rational_logits 245 | 246 | logits = self.logits_l(final_hidden) 247 | 248 | start_logits, end_logits = logits.split(1, dim=-1) 249 | 250 | start_logits, end_logits = start_logits.squeeze( 251 | -1), end_logits.squeeze(-1) 252 | 253 | segment_mask = token_type_ids.type(final_hidden.dtype) 254 | 255 | rational_logits = rational_logits.squeeze(-1) * segment_mask 256 | 257 | start_logits = start_logits * rational_logits 258 | 259 | end_logits = end_logits * rational_logits 260 | 261 | unk_logits = self.unk_l(pooled_output) 262 | 263 | attention = self.attention_l(final_hidden).squeeze(-1) 264 | 265 | attention.data.masked_fill_(attention_mask.eq(0), -float('inf')) 266 | 267 | attention = F.softmax(attention, dim=-1) 268 | 269 | attention_pooled_output = (attention.unsqueeze(-1) * 270 | final_hidden).sum(dim=-2) 271 | 272 | yn_logits = self.yn_l(attention_pooled_output) 273 | 274 | yes_logits, no_logits = yn_logits.split(1, dim=-1) 275 | 276 | start_logits.data.masked_fill_(attention_mask.eq(0), -float('inf')) 277 | end_logits.data.masked_fill_(attention_mask.eq(0), -float('inf')) 278 | 279 | new_start_logits = torch.cat( 280 | (yes_logits, no_logits, unk_logits, start_logits), dim=-1) 281 | new_end_logits = torch.cat( 282 | (yes_logits, no_logits, unk_logits, end_logits), dim=-1) 283 | 284 | if start_positions is not None and end_positions is not None: 285 | 286 | start_positions, end_positions = start_positions + cls_idx, end_positions + cls_idx 287 | 288 | # If we are on multi-GPU, split add a dimension 289 | if len(start_positions.size()) > 1: 290 | start_positions = start_positions.squeeze(-1) 291 | if len(end_positions.size()) > 1: 292 | end_positions = end_positions.squeeze(-1) 293 | # sometimes the start/end positions are outside our model inputs, we ignore these terms 294 | ignored_index = new_start_logits.size(1) 295 | start_positions.clamp_(0, ignored_index) 296 | end_positions.clamp_(0, ignored_index) 297 | 298 | span_loss_fct = CrossEntropyLoss(ignore_index=ignored_index) 299 | 300 | start_loss = span_loss_fct(new_start_logits, start_positions) 301 | end_loss = span_loss_fct(new_end_logits, end_positions) 302 | 303 | # rational part 304 | alpha = 0.25 305 | gamma = 2. 
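            # alpha and gamma are the weighting and focusing factors of the focal-loss-style
            # binary cross-entropy computed on the rationale logits below; the loss is
            # averaged over passage tokens only, selected via segment_mask.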
306 | rational_mask = rational_mask.type(final_hidden.dtype) 307 | 308 | rational_loss = -alpha * ( 309 | (1 - rational_logits)**gamma 310 | ) * rational_mask * torch.log(rational_logits + 1e-7) - ( 311 | 1 - alpha) * (rational_logits**gamma) * ( 312 | 1 - rational_mask) * torch.log(1 - rational_logits + 1e-7) 313 | 314 | rational_loss = (rational_loss * 315 | segment_mask).sum() / segment_mask.sum() 316 | # end 317 | 318 | assert not torch.isnan(rational_loss) 319 | 320 | total_loss = (start_loss + 321 | end_loss) / 2 + rational_loss * self.beta 322 | return total_loss 323 | 324 | return start_logits, end_logits, yes_logits, no_logits, unk_logits 325 | -------------------------------------------------------------------------------- /models/excord/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/princeton-nlp/EvalConvQA/fbf34196b4d8e39d4ecfe36353c9e394101af5eb/models/excord/__init__.py -------------------------------------------------------------------------------- /models/excord/interface.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import DataLoader, RandomSampler, SequentialSampler 2 | 3 | from transformers import ( 4 | MODEL_FOR_QUESTION_ANSWERING_MAPPING, 5 | WEIGHTS_NAME, 6 | AdamW, 7 | AutoConfig, 8 | AutoTokenizer, 9 | get_linear_schedule_with_warmup, 10 | RobertaTokenizer 11 | ) 12 | from models.excord.quac import ( 13 | QuacProcessor, 14 | quac_convert_example_to_features_pt, 15 | QuacResult, 16 | ) 17 | from models.excord.quac_metrics import ( 18 | compute_one_prediction_logits, 19 | ) 20 | from models.excord.modeling_auto import AutoModelForQuestionAnswering 21 | 22 | 23 | import numpy as np 24 | import torch 25 | import torch.nn as nn 26 | from torch.nn import CrossEntropyLoss 27 | from torch.utils.data import (DataLoader, RandomSampler, TensorDataset) 28 | 29 | class Excord(): 30 | def __init__(self, args): 31 | self.args = args 32 | config = AutoConfig.from_pretrained(args.model_name_or_path) 33 | self.model = AutoModelForQuestionAnswering.from_pretrained( 34 | self.args.model_name_or_path, config=config) 35 | self.QA_history = [] 36 | torch.manual_seed(args.seed) 37 | self.device = torch.device("cuda" if torch.cuda.is_available() 38 | and not args.no_cuda else "cpu") 39 | self.model = self.model.to(self.device) 40 | 41 | def tokenizer(self): 42 | self.tokenizer = RobertaTokenizer.from_pretrained( 43 | self.args.model_name_or_path, do_lower_case=self.args.do_lower_case) 44 | return self.tokenizer 45 | 46 | def load_partial_examples(self, predict_file): 47 | processor = QuacProcessor( 48 | tokenizer=self.tokenizer, orig_history=False) 49 | partial_examples = processor.get_partial_dev_examples(data_dir=None, filename=predict_file) 50 | return partial_examples 51 | 52 | def to_list(tensor): 53 | return tensor.detach().cpu().tolist() 54 | 55 | def predict_one_automatic_turn(self, partial_example, unique_id, example_idx, tokenizer): 56 | processor = QuacProcessor(tokenizer=tokenizer, orig_history=False) 57 | question = partial_example.question_text 58 | 59 | example = processor.process_one_dev_example(self.QA_history, example_idx, partial_example) 60 | dataset, features = quac_convert_example_to_features_pt( 61 | example, tokenizer, self.args.max_seq_length, self.args.doc_stride, self.args.max_query_length) 62 | new_features = [] 63 | next_unique_id = unique_id 64 | for example_feature in features: 65 | example_feature.unique_id = 
next_unique_id 66 | new_features.append(example_feature) 67 | next_unique_id += 1 68 | features = new_features 69 | del new_features 70 | # Run prediction for full data 71 | 72 | eval_sampler = SequentialSampler(dataset) 73 | eval_dataloader = DataLoader( 74 | dataset, sampler=eval_sampler, batch_size=1) 75 | curr_results = [] 76 | # Run prediction for current example 77 | for batch in eval_dataloader: 78 | 79 | batch = tuple(t.to(self.device) for t in batch) 80 | # Assume the logits are a list of one item 81 | with torch.no_grad(): 82 | inputs = { 83 | "input_ids": batch[0], 84 | "attention_mask": batch[1], 85 | } 86 | feature_indices = batch[3] 87 | 88 | temp_outputs = self.model(**inputs) 89 | batch_start_logits = temp_outputs.start_logits 90 | batch_end_logts = temp_outputs.end_logits 91 | outputs = [batch_start_logits, batch_end_logts] 92 | for i, feature_index in enumerate(feature_indices): 93 | eval_feature = features[feature_index.item()] 94 | example_unique_id = int(eval_feature.unique_id) 95 | 96 | output = [Excord.to_list(output[i]) for output in outputs] 97 | 98 | start_logits, end_logits = output 99 | result = QuacResult(example_unique_id, start_logits, 100 | end_logits,0) 101 | 102 | curr_results.append(result) 103 | prediction, nbest_prediction = compute_one_prediction_logits( 104 | example, 105 | features, 106 | curr_results, 107 | self.args.n_best_size, 108 | self.args.max_answer_length, 109 | self.args.do_lower_case, 110 | self.args.verbose_logging, 111 | self.args.null_score_diff_threshold, 112 | tokenizer 113 | ) 114 | self.QA_history.append((example_idx, question, (prediction, None, None))) 115 | return prediction, next_unique_id 116 | -------------------------------------------------------------------------------- /models/excord/modeling_auto.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The HuggingFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """ Auto Model class. 
""" 16 | 17 | 18 | import warnings 19 | from collections import OrderedDict 20 | 21 | from transformers import ( 22 | AutoConfig, 23 | RobertaConfig, 24 | ) 25 | from transformers.models.auto.configuration_auto import replace_list_option_in_docstrings 26 | 27 | from transformers.file_utils import( 28 | add_start_docstrings, 29 | ) 30 | 31 | from transformers.configuration_utils import PretrainedConfig 32 | 33 | from transformers.utils import logging 34 | 35 | from models.excord.modeling_roberta import ( 36 | RobertaForQuestionAnswering 37 | ) 38 | 39 | 40 | logger = logging.get_logger(__name__) 41 | 42 | 43 | MODEL_FOR_QUESTION_ANSWERING_MAPPING = OrderedDict( 44 | [ 45 | (RobertaConfig, RobertaForQuestionAnswering), 46 | ] 47 | ) 48 | 49 | AUTO_MODEL_PRETRAINED_DOCSTRING = r""" 50 | The model class to instantiate is selected based on the :obj:`model_type` property of the config object (either 51 | passed as an argument or loaded from :obj:`pretrained_model_name_or_path` if possible), or when it's missing, 52 | by falling back to using pattern matching on :obj:`pretrained_model_name_or_path`: 53 | List options 54 | The model is set in evaluation mode by default using ``model.eval()`` (so for instance, dropout modules are 55 | deactivated). To train the model, you should first set it back in training mode with ``model.train()`` 56 | Args: 57 | pretrained_model_name_or_path: 58 | Can be either: 59 | - A string with the `shortcut name` of a pretrained model to load from cache or download, e.g., 60 | ``bert-base-uncased``. 61 | - A string with the `identifier name` of a pretrained model that was user-uploaded to our S3, e.g., 62 | ``dbmdz/bert-base-german-cased``. 63 | - A path to a `directory` containing model weights saved using 64 | :func:`~transformers.PreTrainedModel.save_pretrained`, e.g., ``./my_model_directory/``. 65 | - A path or url to a `tensorflow index checkpoint file` (e.g, ``./tf_model/model.ckpt.index``). In 66 | this case, ``from_tf`` should be set to :obj:`True` and a configuration object should be provided 67 | as ``config`` argument. This loading path is slower than converting the TensorFlow checkpoint in 68 | a PyTorch model using the provided conversion scripts and loading the PyTorch model afterwards. 69 | model_args (additional positional arguments, `optional`): 70 | Will be passed along to the underlying model ``__init__()`` method. 71 | config (:class:`~transformers.PretrainedConfig`, `optional`): 72 | Configuration for the model to use instead of an automatically loaded configuration. Configuration can 73 | be automatically loaded when: 74 | - The model is a model provided by the library (loaded with the `shortcut name` string of a 75 | pretrained model). 76 | - The model was saved using :meth:`~transformers.PreTrainedModel.save_pretrained` and is reloaded 77 | by supplying the save directory. 78 | - The model is loaded by supplying a local directory as ``pretrained_model_name_or_path`` and a 79 | configuration JSON file named `config.json` is found in the directory. 80 | state_dict (`Dict[str, torch.Tensor]`, `optional`): 81 | A state dictionary to use instead of a state dictionary loaded from saved weights file. 82 | This option can be used if you want to create a model from a pretrained configuration but load your own 83 | weights. In this case though, you should check if using 84 | :func:`~transformers.PreTrainedModel.save_pretrained` and 85 | :func:`~transformers.PreTrainedModel.from_pretrained` is not a simpler option. 
86 | cache_dir (:obj:`str`, `optional`): 87 | Path to a directory in which a downloaded pretrained model configuration should be cached if the 88 | standard cache should not be used. 89 | from_tf (:obj:`bool`, `optional`, defaults to :obj:`False`): 90 | Load the model weights from a TensorFlow checkpoint save file (see docstring of 91 | ``pretrained_model_name_or_path`` argument). 92 | force_download (:obj:`bool`, `optional`, defaults to :obj:`False`): 93 | Whether or not to force the (re-)download of the model weights and configuration files, overriding the 94 | cached versions if they exist. 95 | resume_download (:obj:`bool`, `optional`, defaults to :obj:`False`): 96 | Whether or not to delete incompletely received files. Will attempt to resume the download if such a 97 | file exists. 98 | proxies (:obj:`Dict[str, str], `optional`): 99 | A dictionary of proxy servers to use by protocol or endpoint, e.g., :obj:`{'http': 'foo.bar:3128', 100 | 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request. 101 | output_loading_info(:obj:`bool`, `optional`, defaults to :obj:`False`): 102 | Whether ot not to also return a dictionary containing missing keys, unexpected keys and error messages. 103 | local_files_only(:obj:`bool`, `optional`, defaults to :obj:`False`): 104 | Whether or not to only look at local files (e.g., not try downloading the model). 105 | use_cdn(:obj:`bool`, `optional`, defaults to :obj:`True`): 106 | Whether or not to use Cloudfront (a Content Delivery Network, or CDN) when searching for the model on 107 | our S3 (faster). Should be set to :obj:`False` for checkpoints larger than 20GB. 108 | kwargs (additional keyword arguments, `optional`): 109 | Can be used to update the configuration object (after it being loaded) and initiate the model (e.g., 110 | :obj:`output_attentions=True`). Behaves differently depending on whether a ``config`` is provided or 111 | automatically loaded: 112 | - If a configuration is provided with ``config``, ``**kwargs`` will be directly passed to the 113 | underlying model's ``__init__`` method (we assume all relevant updates to the configuration have 114 | already been done) 115 | - If a configuration is not provided, ``kwargs`` will be first passed to the configuration class 116 | initialization function (:func:`~transformers.PretrainedConfig.from_pretrained`). Each key of 117 | ``kwargs`` that corresponds to a configuration attribute will be used to override said attribute 118 | with the supplied ``kwargs`` value. Remaining keys that do not correspond to any configuration 119 | attribute will be passed to the underlying model's ``__init__`` function. 120 | """ 121 | 122 | class AutoModelForQuestionAnswering: 123 | r""" 124 | This is a generic model class that will be instantiated as one of the model classes of the library---with a 125 | question answering head---when created with the when created with the 126 | :meth:`~transformers.AutoModeForQuestionAnswering.from_pretrained` class method or the 127 | :meth:`~transformers.AutoModelForQuestionAnswering.from_config` class method. 128 | This class cannot be instantiated directly using ``__init__()`` (throws an error). 129 | """ 130 | 131 | def __init__(self): 132 | raise EnvironmentError( 133 | "AutoModelForQuestionAnswering is designed to be instantiated " 134 | "using the `AutoModelForQuestionAnswering.from_pretrained(pretrained_model_name_or_path)` or " 135 | "`AutoModelForQuestionAnswering.from_config(config)` methods." 
136 | ) 137 | 138 | @classmethod 139 | @replace_list_option_in_docstrings(MODEL_FOR_QUESTION_ANSWERING_MAPPING, use_model_types=False) 140 | def from_config(cls, config): 141 | r""" 142 | Instantiates one of the model classes of the library---with a question answering head---from a configuration. 143 | Note: 144 | Loading a model from its configuration file does **not** load the model weights. It only affects the 145 | model's configuration. Use :meth:`~transformers.AutoModelForQuestionAnswering.from_pretrained` to load the 146 | model weights. 147 | Args: 148 | config (:class:`~transformers.PretrainedConfig`): 149 | The model class to instantiate is selected based on the configuration class: 150 | List options 151 | Examples:: 152 | >>> from transformers import AutoConfig, AutoModelForQuestionAnswering 153 | >>> # Download configuration from S3 and cache. 154 | >>> config = AutoConfig.from_pretrained('bert-base-uncased') 155 | >>> model = AutoModelForQuestionAnswering.from_config(config) 156 | """ 157 | if type(config) in MODEL_FOR_QUESTION_ANSWERING_MAPPING.keys(): 158 | return MODEL_FOR_QUESTION_ANSWERING_MAPPING[type(config)](config) 159 | 160 | raise ValueError( 161 | "Unrecognized configuration class {} for this kind of AutoModel: {}.\n" 162 | "Model type should be one of {}.".format( 163 | config.__class__, 164 | cls.__name__, 165 | ", ".join(c.__name__ for c in MODEL_FOR_QUESTION_ANSWERING_MAPPING.keys()), 166 | ) 167 | ) 168 | 169 | @classmethod 170 | @replace_list_option_in_docstrings(MODEL_FOR_QUESTION_ANSWERING_MAPPING) 171 | @add_start_docstrings( 172 | "Instantiate one of the model classes of the library---with a question answering head---from a " 173 | "pretrained model.", 174 | AUTO_MODEL_PRETRAINED_DOCSTRING, 175 | ) 176 | def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs): 177 | r""" 178 | Examples:: 179 | >>> from transformers import AutoConfig, AutoModelForQuestionAnswering 180 | >>> # Download model and configuration from S3 and cache. 
181 |         >>> model = AutoModelForQuestionAnswering.from_pretrained('bert-base-uncased')
182 |         >>> # Update configuration during loading
183 |         >>> model = AutoModelForQuestionAnswering.from_pretrained('bert-base-uncased', output_attentions=True)
184 |         >>> model.config.output_attentions
185 |         True
186 |         >>> # Loading from a TF checkpoint file instead of a PyTorch model (slower)
187 |         >>> config = AutoConfig.from_json_file('./tf_model/bert_tf_model_config.json')
188 |         >>> model = AutoModelForQuestionAnswering.from_pretrained('./tf_model/bert_tf_checkpoint.ckpt.index', from_tf=True, config=config)
189 |         """
190 |         config = kwargs.pop("config", None)
191 |         if not isinstance(config, PretrainedConfig):
192 |             config, kwargs = AutoConfig.from_pretrained(
193 |                 pretrained_model_name_or_path, return_unused_kwargs=True, **kwargs
194 |             )
195 | 
196 |         if type(config) in MODEL_FOR_QUESTION_ANSWERING_MAPPING.keys():
197 |             return MODEL_FOR_QUESTION_ANSWERING_MAPPING[type(config)].from_pretrained(
198 |                 pretrained_model_name_or_path, *model_args, config=config, **kwargs
199 |             )
200 | 
201 |         raise ValueError(
202 |             "Unrecognized configuration class {} for this kind of AutoModel: {}.\n"
203 |             "Model type should be one of {}.".format(
204 |                 config.__class__,
205 |                 cls.__name__,
206 |                 ", ".join(c.__name__ for c in MODEL_FOR_QUESTION_ANSWERING_MAPPING.keys()),
207 |             )
208 |         )
209 | 
--------------------------------------------------------------------------------
/models/excord/modeling_roberta.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | from torch.nn import CrossEntropyLoss, BCEWithLogitsLoss
4 | 
5 | from transformers.models.roberta.modeling_roberta import (
6 |     RobertaPreTrainedModel,
7 |     RobertaModel
8 | )
9 | 
10 | from transformers.file_utils import (
11 |     add_start_docstrings,
12 |     add_code_sample_docstrings,
13 | )
14 | 
15 | from transformers.modeling_outputs import (
16 |     QuestionAnsweringModelOutput,
17 | )
18 | 
19 | _CONFIG_FOR_DOC = "RobertaConfig"
20 | _TOKENIZER_FOR_DOC = "RobertaTokenizer"
21 | 
22 | ROBERTA_START_DOCSTRING = r"""
23 |     This model inherits from :class:`~transformers.PreTrainedModel`. Check the superclass documentation for the generic
24 |     methods the library implements for all its models (such as downloading or saving, resizing the input embeddings,
25 |     pruning heads etc.)
26 |     This model is also a PyTorch `torch.nn.Module <https://pytorch.org/docs/stable/nn.html#torch.nn.Module>`__
27 |     subclass. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to
28 |     general usage and behavior.
29 |     Parameters:
30 |         config (:class:`~transformers.RobertaConfig`): Model configuration class with all the parameters of the
31 |             model. Initializing with a config file does not load the weights associated with the model, only the
32 |             configuration. Check out the :meth:`~transformers.PreTrainedModel.from_pretrained` method to load the model
33 |             weights.
34 | """
35 | 
36 | class RobertaClassificationHead(nn.Module):
37 |     """Head for sentence-level classification tasks."""
38 | 
39 |     def __init__(self, config, class_num):
40 |         super().__init__()
41 |         self.dense = nn.Linear(config.hidden_size, config.hidden_size)
42 |         self.dropout = nn.Dropout(config.hidden_dropout_prob)
43 |         self.out_proj = nn.Linear(config.hidden_size, class_num)
44 | 
45 |     def forward(self, features, **kwargs):
46 |         x = features[:, 0, :]  # take <s> token (equiv.
to [CLS]) 47 | x = self.dropout(x) 48 | x = self.dense(x) 49 | x = torch.tanh(x) 50 | x = self.dropout(x) 51 | x = self.out_proj(x) 52 | return x 53 | 54 | @add_start_docstrings( 55 | """ 56 | Roberta Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear 57 | layers on top of the hidden-states output to compute `span start logits` and `span end logits`). 58 | """, 59 | ROBERTA_START_DOCSTRING, 60 | ) 61 | class RobertaForQuestionAnswering(RobertaPreTrainedModel): 62 | authorized_unexpected_keys = [r"pooler"] 63 | authorized_missing_keys = [r"position_ids"] 64 | 65 | def __init__(self, config, class_num=1): 66 | super().__init__(config) 67 | self.class_num = class_num # quac: 1 class (answerable), 68 | 69 | self.roberta = RobertaModel(config, add_pooling_layer=False) 70 | self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels) 71 | self.classifier = RobertaClassificationHead(config, class_num) 72 | 73 | self.init_weights() 74 | 75 | @add_code_sample_docstrings( 76 | tokenizer_class=_TOKENIZER_FOR_DOC, 77 | checkpoint="roberta-base", 78 | output_type=QuestionAnsweringModelOutput, 79 | config_class=_CONFIG_FOR_DOC, 80 | ) 81 | def forward( 82 | self, 83 | input_ids=None, 84 | attention_mask=None, 85 | token_type_ids=None, 86 | position_ids=None, 87 | head_mask=None, 88 | inputs_embeds=None, 89 | start_positions=None, 90 | end_positions=None, 91 | is_impossible=None, 92 | output_attentions=None, 93 | output_hidden_states=None, 94 | return_dict=None, 95 | ): 96 | r""" 97 | start_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`): 98 | Labels for position (index) of the start of the labelled span for computing the token classification loss. 99 | Positions are clamped to the length of the sequence (:obj:`sequence_length`). Position outside of the 100 | sequence are not taken into account for computing the loss. 101 | end_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`): 102 | Labels for position (index) of the end of the labelled span for computing the token classification loss. 103 | Positions are clamped to the length of the sequence (:obj:`sequence_length`). Position outside of the 104 | sequence are not taken into account for computing the loss. 
105 | """ 106 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 107 | 108 | outputs = self.roberta( 109 | input_ids, 110 | attention_mask=attention_mask, 111 | token_type_ids=token_type_ids, 112 | position_ids=position_ids, 113 | head_mask=head_mask, 114 | inputs_embeds=inputs_embeds, 115 | output_attentions=output_attentions, 116 | output_hidden_states=output_hidden_states, 117 | return_dict=return_dict, 118 | ) 119 | 120 | sequence_output = outputs[0] 121 | 122 | logits = self.qa_outputs(sequence_output) 123 | start_logits, end_logits = logits.split(1, dim=-1) 124 | start_logits = start_logits.squeeze(-1) 125 | end_logits = end_logits.squeeze(-1) 126 | class_logits = self.classifier(sequence_output) 127 | 128 | total_loss = None 129 | if start_positions is not None and end_positions is not None: 130 | # If we are on multi-GPU, split add a dimension 131 | if len(start_positions.size()) > 1: 132 | start_positions = start_positions.squeeze(-1) 133 | if len(end_positions.size()) > 1: 134 | end_positions = end_positions.squeeze(-1) 135 | # sometimes the start/end positions are outside our model inputs, we ignore these terms 136 | ignored_index = start_logits.size(1) 137 | start_positions.clamp_(0, ignored_index) 138 | end_positions.clamp_(0, ignored_index) 139 | 140 | loss_fct = CrossEntropyLoss(ignore_index=ignored_index) 141 | start_loss = loss_fct(start_logits, start_positions) 142 | end_loss = loss_fct(end_logits, end_positions) 143 | 144 | if self.class_num < 2: # quac 145 | class_loss_fct = BCEWithLogitsLoss() 146 | class_loss = class_loss_fct(class_logits.squeeze(), is_impossible.squeeze()) 147 | else: # coqa 148 | class_loss_fct = CrossEntropyLoss(ignore_index=3) 149 | class_loss = class_loss_fct(class_logits, is_impossible) 150 | 151 | total_loss = (start_loss + end_loss + class_loss) / 3 152 | 153 | if not return_dict: 154 | output = (start_logits, end_logits, class_logits) 155 | return ((total_loss,) + output) if total_loss is not None else output 156 | 157 | return QuestionAnsweringModelOutput( 158 | loss=total_loss, 159 | start_logits=start_logits, 160 | end_logits=end_logits, 161 | hidden_states=outputs.hidden_states, 162 | attentions=outputs.attentions, 163 | ) 164 | -------------------------------------------------------------------------------- /models/graphflow/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/princeton-nlp/EvalConvQA/fbf34196b4d8e39d4ecfe36353c9e394101af5eb/models/graphflow/__init__.py -------------------------------------------------------------------------------- /models/graphflow/interface.py: -------------------------------------------------------------------------------- 1 | from numpy.lib.function_base import _quantile_is_valid 2 | from torch.utils.data import DataLoader, RandomSampler, SequentialSampler 3 | from transformers import BertTokenizer, BertModel 4 | from models.graphflow.model import QuACModel 5 | from models.graphflow.utils.data_utils import QADataset, sanitize_input, vectorize_input_turn 6 | 7 | import spacy 8 | from spacy.tokens import Doc 9 | 10 | from pycorenlp import StanfordCoreNLP 11 | 12 | import numpy as np 13 | import torch 14 | import torch.nn as nn 15 | from torch.nn import CrossEntropyLoss 16 | from torch.utils.data import (DataLoader, RandomSampler, TensorDataset) 17 | 18 | class QuacExample: 19 | """ 20 | A single training/test example for the Squad dataset, as loaded from disk. 
21 | 22 | Args: 23 | qas_id: The example's unique identifier 24 | question_text: The question string 25 | context_text: The context string 26 | answer_text: The answer string 27 | start_position_character: The character position of the start of the answer 28 | title: The title of the example 29 | answers: None by default, this is used during evaluation. Holds answers as well as their start positions. 30 | is_impossible: False by default, set to True if the example has no possible answer. 31 | """ 32 | 33 | def __init__( 34 | self, 35 | qas_id, 36 | question_text, 37 | ): 38 | self.qas_id = qas_id 39 | self.question_text = question_text 40 | 41 | def __str__(self): 42 | return self.__repr__() 43 | 44 | def __repr__(self): 45 | s = "" 46 | s += "qas_id: %s" % (self.qas_id) 47 | s += "\nquestion_text: %s" % (self.question_text) 48 | 49 | return s 50 | 51 | def _str(s): 52 | """ Convert PTB tokens to normal tokens """ 53 | if (s.lower() == '-lrb-'): 54 | s = '(' 55 | elif (s.lower() == '-rrb-'): 56 | s = ')' 57 | elif (s.lower() == '-lsb-'): 58 | s = '[' 59 | elif (s.lower() == '-rsb-'): 60 | s = ']' 61 | elif (s.lower() == '-lcb-'): 62 | s = '{' 63 | elif (s.lower() == '-rcb-'): 64 | s = '}' 65 | return s 66 | 67 | class WhitespaceTokenizer(object): 68 | 69 | def __init__(self, vocab): 70 | self.vocab = vocab 71 | 72 | def __call__(self, text): 73 | words = text.split(' ') 74 | # All tokens 'own' a subsequent space character in this tokenizer 75 | spaces = [True] * len(words) 76 | return Doc(self.vocab, words=words, spaces=spaces) 77 | 78 | 79 | class ExampleProcessor(): 80 | def __init__(self) -> None: 81 | self.corenlp = StanfordCoreNLP('http://localhost:9000') 82 | self.parser = spacy.load("en_core_web_sm") 83 | self.parser.tokenizer = WhitespaceTokenizer(self.parser.vocab) 84 | 85 | def process(self, text): 86 | paragraph = self.corenlp.annotate(text, properties={ 87 | 'annotators': 'tokenize, ssplit, pos, ner', 88 | 'outputFormat': 'json', 89 | 'ssplit.newlineIsSentenceBreak': 'two'}) 90 | output = {'word': [], 91 | 'pos': [], 92 | 'ner': [], 93 | 'offsets': []} 94 | for sent in paragraph['sentences']: 95 | for token in sent['tokens']: 96 | output['word'].append(_str(token['word'])) 97 | output['pos'].append(token['pos']) 98 | output['ner'].append(token['ner']) 99 | output['offsets'].append( 100 | (token['characterOffsetBegin'], token['characterOffsetEnd'])) 101 | return output 102 | 103 | class GraphFlow(): 104 | def __init__(self, args): 105 | self.args = vars(args) 106 | self.device = torch.device("cuda" if torch.cuda.is_available() 107 | and not self.args['no_cuda'] else "cpu") 108 | self.args['device'] = self.device 109 | self.model = QuACModel(self.args) 110 | self.QA_history = [] 111 | self.history = [] 112 | torch.manual_seed(self.args['seed']) 113 | 114 | self.bert_model = BertModel.from_pretrained( 115 | self.args['bert_model']).to(self.device) 116 | self.bert_model.eval() 117 | self.model.init_saved_network(self.args['model_name_or_path']) 118 | self.model.network = self.model.network.to(self.device) 119 | self.question_processor = ExampleProcessor() 120 | 121 | 122 | def tokenizer(self): 123 | self.tokenizer = BertTokenizer.from_pretrained( 124 | self.args['bert_model'], do_lower_case=self.args['do_lower_case']) 125 | return self.tokenizer 126 | 127 | def load_partial_examples(self, predict_file): 128 | partial_test_set = QADataset(predict_file, self.args) 129 | return partial_test_set 130 | 131 | def predict_one_automatic_turn(self, partial_example, unique_id, example_idx, 
tokenizer): 132 | example = {'id': partial_example['cid'], 133 | 'evidence': partial_example['evidence'], 134 | 'raw_evidence': partial_example['raw_evidence']} 135 | 136 | qa = { 137 | 'turn_id': partial_example['turn_id'], 138 | 'question': partial_example['question'], 139 | 'answers': partial_example['answers'], 140 | 'targets': partial_example['targets'], 141 | 'span_mask': partial_example['span_mask'], 142 | 'unk_answer_targets': partial_example['unk_answer_targets'], 143 | 'yesno_targets': partial_example['yesno_targets'], 144 | 'followup_targets': partial_example['followup_targets'] 145 | } 146 | temp = [] 147 | marker = [] 148 | n_history = len(self.history) if self.args['n_history'] < 0 else min( 149 | self.args['n_history'], len(self.history)) 150 | if n_history > 0: 151 | count = sum([not self.args['no_pre_question'], 152 | not self.args['no_pre_answer']]) * len(self.history[-n_history:]) 153 | for q, a, a_s, a_e in self.history[-n_history:]: 154 | if not self.args['no_pre_question']: 155 | temp.extend(q) 156 | marker.extend([count] * len(q)) 157 | count -= 1 158 | if not self.args['no_pre_answer']: 159 | temp.extend(a) 160 | marker.extend([count] * len(a)) 161 | count -= 1 162 | original_question = qa['question']['word'] 163 | temp.extend(original_question) 164 | marker.extend([0] * len(original_question)) 165 | qa['question']['word'] = temp 166 | qa['question']['marker'] = marker 167 | 168 | example['turns'] = [qa] 169 | test_loader = DataLoader([example], batch_size=1, shuffle=False, 170 | collate_fn=lambda x: x, pin_memory=True) 171 | for step, input_batch in enumerate(test_loader): 172 | input_batch = sanitize_input(input_batch, self.args, self.model.word_dict, 173 | self.model.feature_dict, self.tokenizer, training=False) 174 | x_batch = vectorize_input_turn( 175 | input_batch, self.args, self.bert_model, self.history, len(self.history), device=self.device) 176 | if not x_batch: 177 | continue # When there are no target spans present in the batch 178 | 179 | res = self.model.predict( 180 | x_batch, step, update=False, out_predictions=True) 181 | prediction = res['predictions'][0][0] 182 | span = res['token_spans'][0][0] 183 | tokens = res['token_lists'][0][0] 184 | self.history.append((original_question, tokens, span[0], span[1])) 185 | self.QA_history.append((len(self.QA_history), " ".join(original_question), (" ".join(tokens), span[0], span[1]))) # turn, q_text, a_text 186 | 187 | return prediction, unique_id+1 188 | 189 | def convert_example(self, partial_example): 190 | example = QuacExample( 191 | qas_id=partial_example['turn_id'], 192 | question_text=" ".join(partial_example['question']['word']), 193 | ) 194 | return example 195 | 196 | -------------------------------------------------------------------------------- /models/graphflow/layers/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/princeton-nlp/EvalConvQA/fbf34196b4d8e39d4ecfe36353c9e394101af5eb/models/graphflow/layers/__init__.py -------------------------------------------------------------------------------- /models/graphflow/layers/attention.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Created on Nov, 2018 3 | 4 | @author: hugo 5 | 6 | ''' 7 | import torch 8 | import torch.nn as nn 9 | 10 | import warnings 11 | # warnings.simplefilter("error") 12 | warnings.simplefilter("ignore", UserWarning) 13 | warnings.simplefilter("ignore", Warning) 14 | 15 | INF = 1e20 16 | class 
Context2QuestionAttention(nn.Module): 17 | def __init__(self, dim, hidden_size): 18 | super(Context2QuestionAttention, self).__init__() 19 | self.linear_sim = nn.Linear(dim, hidden_size, bias=False) 20 | 21 | def forward(self, context, questions, out_questions, ques_mask=None): 22 | """ 23 | Parameters 24 | :context, (batch_size, ?, ctx_size, dim) 25 | :questions, (batch_size, turn_size, ques_size, dim) 26 | :out_questions, (batch_size, turn_size, ques_size, ?) 27 | :ques_mask, (batch_size, turn_size, ques_size) 28 | 29 | Returns 30 | :ques_emb, (batch_size, turn_size, ctx_size, dim) 31 | """ 32 | # shape: (batch_size, ?, ctx_size, dim), ? equals 1 or turn_size 33 | context_fc = torch.relu(self.linear_sim(context)) 34 | # shape: (batch_size, turn_size, ques_size, dim) 35 | questions_fc = torch.relu(self.linear_sim(questions)) 36 | 37 | # shape: (batch_size, turn_size, ctx_size, ques_size) 38 | attention = torch.matmul(context_fc, questions_fc.transpose(-1, -2)) 39 | if ques_mask is not None: 40 | # print("Context2Question Attention") 41 | # print(1 - ques_mask.byte().unsqueeze(2)) 42 | # print((1 - ques_mask.byte().unsqueeze(2)).to(torch.bool)) 43 | # print((1 - ques_mask.byte().unsqueeze(2)).to(torch.bool).dtype) 44 | mask = (1 - ques_mask.byte().unsqueeze(2)).to(torch.bool) 45 | attention = attention.masked_fill_(mask, -INF) 46 | prob = torch.softmax(attention, dim=-1) 47 | # shape: (batch_size, turn_size, ctx_size, ?) 48 | ques_emb = torch.matmul(prob, out_questions) 49 | return ques_emb 50 | 51 | class SelfAttention(nn.Module): 52 | def __init__(self, input_size, hidden_size): 53 | super(SelfAttention, self).__init__() 54 | self.W1 = torch.Tensor(input_size, hidden_size) 55 | self.W1 = nn.Parameter(nn.init.xavier_uniform_(self.W1)) 56 | self.W2 = torch.Tensor(hidden_size, 1) 57 | self.W2 = nn.Parameter(nn.init.xavier_uniform_(self.W2)) 58 | 59 | def forward(self, x, attention_mask=None): 60 | attention = torch.mm(torch.tanh(torch.mm(x.view(-1, x.size(-1)), self.W1)), self.W2).view(x.size(0), -1) 61 | if attention_mask is not None: 62 | # Exclude masked elements from the softmax 63 | # print("Self Attention") 64 | # print(1-attention_mask.byte()) 65 | # print((1-attention_mask.byte()).to(torch.bool).dtype) 66 | # print(type((1-attention_mask.byte()).to(torch.bool))) 67 | mask = (1-attention_mask.byte()).to(torch.bool) 68 | attention = attention.masked_fill_(mask, -INF) 69 | 70 | probs = torch.softmax(attention, dim=-1).unsqueeze(1) 71 | weighted_x = torch.bmm(probs, x).squeeze(1) 72 | return weighted_x 73 | -------------------------------------------------------------------------------- /models/graphflow/layers/common.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Created on Nov, 2018 3 | 4 | @author: hugo 5 | 6 | ''' 7 | import torch 8 | import torch.nn as nn 9 | from torch.nn.utils.rnn import pad_packed_sequence, pack_padded_sequence 10 | import torch.nn.functional as F 11 | from ..utils.generic_utils import to_cuda 12 | 13 | 14 | class GatedFusion(nn.Module): 15 | def __init__(self, hidden_size): 16 | super(GatedFusion, self).__init__() 17 | '''GatedFusion module''' 18 | self.fc_z = nn.Linear(4 * hidden_size, hidden_size, bias=True) 19 | 20 | def forward(self, h_state, input): 21 | z = torch.sigmoid(self.fc_z(torch.cat([h_state, input, h_state * input, h_state - input], -1))) 22 | h_state = (1 - z) * h_state + z * input 23 | return h_state 24 | 25 | class GRUStep(nn.Module): 26 | def __init__(self, hidden_size, input_size): 27 | 
super(GRUStep, self).__init__() 28 | '''GRU module''' 29 | self.linear_z = nn.Linear(hidden_size + input_size, hidden_size, bias=False) 30 | self.linear_r = nn.Linear(hidden_size + input_size, hidden_size, bias=False) 31 | self.linear_t = nn.Linear(hidden_size + input_size, hidden_size, bias=False) 32 | 33 | def forward(self, h_state, input): 34 | z = torch.sigmoid(self.linear_z(torch.cat([h_state, input], -1))) 35 | r = torch.sigmoid(self.linear_r(torch.cat([h_state, input], -1))) 36 | t = torch.tanh(self.linear_t(torch.cat([r * h_state, input], -1))) 37 | h_state = (1 - z) * h_state + z * t 38 | return h_state 39 | 40 | class EncoderRNN(nn.Module): 41 | def __init__(self, input_size, hidden_size, \ 42 | bidirectional=False, rnn_type='lstm', rnn_dropout=None, rnn_input_dropout=None, device=None): 43 | super(EncoderRNN, self).__init__() 44 | if not rnn_type in ('lstm', 'gru'): 45 | raise RuntimeError('rnn_type is expected to be lstm or gru, got {}'.format(rnn_type)) 46 | if bidirectional: 47 | print('[ Using bidirectional {} encoder ]'.format(rnn_type)) 48 | else: 49 | print('[ Using {} encoder ]'.format(rnn_type)) 50 | if bidirectional and hidden_size % 2 != 0: 51 | raise RuntimeError('hidden_size is expected to be even in the bidirectional mode!') 52 | self.rnn_type = rnn_type 53 | self.rnn_dropout = rnn_dropout 54 | self.rnn_input_dropout = rnn_input_dropout 55 | self.device = device 56 | self.hidden_size = hidden_size // 2 if bidirectional else hidden_size 57 | self.num_directions = 2 if bidirectional else 1 58 | model = nn.LSTM if rnn_type == 'lstm' else nn.GRU 59 | self.model = model(input_size, self.hidden_size, 1, batch_first=True, bidirectional=bidirectional) 60 | 61 | def forward(self, x, x_len): 62 | """x: [batch_size * max_length * emb_dim] 63 | x_len: [batch_size] 64 | """ 65 | x = dropout(x, self.rnn_input_dropout, shared_axes=[-2], training=self.training) 66 | sorted_x_len, indx = torch.sort(x_len, 0, descending=True) 67 | x = pack_padded_sequence(x[indx], sorted_x_len.data.tolist(), batch_first=True) 68 | 69 | h0 = to_cuda(torch.zeros(self.num_directions, x_len.size(0), self.hidden_size), self.device) 70 | if self.rnn_type == 'lstm': 71 | c0 = to_cuda(torch.zeros(self.num_directions, x_len.size(0), self.hidden_size), self.device) 72 | packed_h, (packed_h_t, _) = self.model(x, (h0, c0)) 73 | packed_h_t = torch.cat([packed_h_t[i] for i in range(packed_h_t.size(0))], -1) 74 | else: 75 | packed_h, packed_h_t = self.model(x, h0) 76 | packed_h_t = packed_h_t.transpose(0, 1).contiguous().view(x_len.size(0), -1) 77 | 78 | hh, _ = pad_packed_sequence(packed_h, batch_first=True) 79 | 80 | # restore the sorting 81 | _, inverse_indx = torch.sort(indx, 0) 82 | restore_hh = hh[inverse_indx] 83 | restore_packed_h_t = packed_h_t[inverse_indx] 84 | restore_hh = dropout(restore_hh, self.rnn_dropout, shared_axes=[-2], training=self.training) 85 | restore_packed_h_t = dropout(restore_packed_h_t, self.rnn_dropout, training=self.training) 86 | return restore_hh, restore_packed_h_t 87 | 88 | 89 | def dropout(x, drop_prob, shared_axes=[], training=False): 90 | """ 91 | Apply dropout to input tensor. 92 | Parameters 93 | ---------- 94 | input_tensor: ``torch.FloatTensor`` 95 | A tensor of shape ``(batch_size, ..., num_timesteps, embedding_dim)`` 96 | Returns 97 | ------- 98 | output: ``torch.FloatTensor`` 99 | A tensor of shape ``(batch_size, ..., num_timesteps, embedding_dim)`` with dropout applied. 
100 | """ 101 | if drop_prob == 0 or drop_prob == None or (not training): 102 | return x 103 | 104 | sz = list(x.size()) 105 | for i in shared_axes: 106 | sz[i] = 1 107 | mask = x.new(*sz).bernoulli_(1. - drop_prob).div_(1. - drop_prob) 108 | mask = mask.expand_as(x) 109 | return x * mask 110 | -------------------------------------------------------------------------------- /models/graphflow/layers/graphs.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Created on Nov, 2018 3 | 4 | @author: hugo 5 | 6 | ''' 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | 11 | from ..utils.generic_utils import to_cuda, get_range_vector, get_sinusoid_encoding_table 12 | from ..layers.common import * 13 | import warnings 14 | warnings.simplefilter("ignore", UserWarning) 15 | warnings.simplefilter("ignore", Warning) 16 | 17 | INF = 1e20 18 | VERY_SMALL_NUMBER = 1e-12 19 | COMBINE_RATIO = 0.9 20 | 21 | class GraphLearner(nn.Module): 22 | def __init__(self, input_size, hidden_size, topk, epsilon, n_spatial_kernels, use_spatial_kernels=True, \ 23 | use_position_enc=False, position_emb_size=10, max_position_distance=160, num_pers=1, device=None): 24 | super(GraphLearner, self).__init__() 25 | self.device = device 26 | self.topk = topk 27 | self.epsilon = epsilon 28 | self.use_spatial_kernels = use_spatial_kernels 29 | self.use_position_enc = use_position_enc 30 | self.max_position_distance = max_position_distance 31 | # self.linear_sim = nn.Linear(input_size, hidden_size, bias=False) 32 | 33 | self.weight_tensor = torch.Tensor(num_pers, input_size) 34 | self.weight_tensor = nn.Parameter(nn.init.xavier_uniform_(self.weight_tensor)) 35 | print('[ Multi-perspective GraphLearner: {} ]'.format(num_pers)) 36 | 37 | 38 | if use_spatial_kernels: 39 | print('[ Using spatial Gaussian kernels ]') 40 | if use_position_enc: 41 | print('[ Using sinusoid position encoding ]') 42 | # Position encoding 43 | self.position_enc = nn.Embedding.from_pretrained( 44 | get_sinusoid_encoding_table(self.max_position_distance + 1, position_emb_size, padding_idx=0, device=device), 45 | freeze=True) 46 | 47 | # Parameters of the Gaussian kernels 48 | self.mean_dis = nn.Parameter(torch.Tensor(n_spatial_kernels, position_emb_size)) 49 | self.mean_dis.data.uniform_(-1, 1) 50 | self.precision_inv_dis = nn.Parameter(torch.Tensor(n_spatial_kernels, position_emb_size)) 51 | self.precision_inv_dis.data.uniform_(0.0, 1.0) 52 | else: 53 | # Parameters of the Gaussian kernels 54 | self.mean_dis = nn.Parameter(torch.Tensor(n_spatial_kernels, 1)) 55 | self.mean_dis.data.uniform_(0, 1) 56 | self.precision_inv_dis = nn.Parameter(torch.Tensor(n_spatial_kernels, 1)) 57 | self.precision_inv_dis.data.uniform_(0.0, 1.0) 58 | 59 | def forward(self, context, ctx_mask): 60 | """ 61 | Parameters 62 | :context, (batch_size, turn_size, ctx_size, dim) 63 | :ctx_mask, (batch_size, ctx_size) 64 | 65 | Returns 66 | :adjacency_matrix, (batch_size, turn_size, ctx_size, ctx_size) 67 | """ 68 | markoff_value = -INF 69 | 70 | # 1) 71 | # context_fc = torch.relu(self.linear_sim(context)) 72 | # attention = torch.matmul(context_fc, context_fc.transpose(-1, -2)) 73 | 74 | # # 2) 75 | # context_fc = context.unsqueeze(2) * self.weight_tensor.unsqueeze(0).unsqueeze(0).unsqueeze(-2) 76 | # attention = torch.mean(torch.matmul(context_fc, context_fc.transpose(-1, -2)), dim=2) 77 | 78 | 79 | # 3) Best attention mechanism 80 | context_fc = context.unsqueeze(2) * 
torch.relu(self.weight_tensor).unsqueeze(0).unsqueeze(0).unsqueeze(-2) 81 | attention = torch.mean(torch.matmul(context_fc, context.unsqueeze(2).transpose(-1, -2)), dim=2) 82 | 83 | 84 | # # 4)weighted cosine 85 | # context_fc = context.unsqueeze(2) * self.weight_tensor.unsqueeze(0).unsqueeze(0).unsqueeze(-2) 86 | # context_norm = F.normalize(context_fc, p=2, dim=-1) 87 | # attention = torch.matmul(context_norm, context_norm.transpose(-1, -2)).mean(2) 88 | # markoff_value = 0 89 | 90 | 91 | if ctx_mask is not None: 92 | # print("ctx mask") 93 | mask1 = (1 - ctx_mask.byte().unsqueeze(1).unsqueeze(-1)).to(torch.bool) 94 | # print(mask1.dtype) 95 | attention = attention.masked_fill_(mask1, markoff_value) 96 | mask2 = (1 - ctx_mask.byte().unsqueeze(1).unsqueeze(-2)).to(torch.bool) 97 | # print(mask2.dtype) 98 | attention = attention.masked_fill_(mask2, markoff_value) 99 | 100 | if self.use_spatial_kernels: 101 | # shape: (batch_size, turn_size, n_spatial_kernels, ctx_size, ctx_size) 102 | spatial_attention = self.get_spatial_attention(attention.shape[:3]) 103 | # joint_attention = COMBINE_RATIO * torch.softmax(attention, dim=-1).unsqueeze(2) + (1 - COMBINE_RATIO) * spatial_attention / torch.sum(spatial_attention, dim=-1, keepdim=True) 104 | weighted_adjacency_matrix = self.build_knn_neighbourhood(attention, self.topk, attention, spatial_attention) 105 | else: 106 | if self.topk is not None: 107 | weighted_adjacency_matrix = self.build_knn_neighbourhood(attention, self.topk) 108 | 109 | if self.epsilon is not None: 110 | weighted_adjacency_matrix = self.build_epsilon_neighbourhood(attention, self.epsilon, markoff_value) 111 | 112 | return weighted_adjacency_matrix 113 | 114 | 115 | def build_epsilon_neighbourhood(self, attention, epsilon, markoff_value): 116 | mask = (attention > epsilon).detach().float() 117 | weighted_adjacency_matrix = attention * mask + markoff_value * (1 - mask) 118 | return weighted_adjacency_matrix 119 | 120 | 121 | def build_knn_neighbourhood(self, attention, topk, semantic_attention=None, spatial_attention=None, markoff_value=-INF): 122 | knn_val, knn_ind = torch.topk(attention, topk, dim=-1) 123 | if self.use_spatial_kernels: 124 | # semantic_attention = semantic_attention.unsqueeze(2).expand(-1, -1, spatial_attention.size(2), -1, -1) 125 | semantic_attn_chosen = torch.gather(semantic_attention, dim=-1, index=knn_ind) 126 | semantic_attn_chosen = torch.softmax(semantic_attn_chosen, dim=-1) 127 | 128 | expand_knn_ind = knn_ind.unsqueeze(2).expand(-1, -1, spatial_attention.size(2), -1, -1) 129 | spatial_attn_chosen = torch.gather(spatial_attention, dim=-1, index=expand_knn_ind) 130 | spatial_attn_chosen = spatial_attn_chosen / torch.sum(spatial_attn_chosen, dim=-1, keepdim=True) 131 | 132 | attn_chosen = semantic_attn_chosen.unsqueeze(2) * spatial_attn_chosen 133 | weighted_adjacency_matrix = to_cuda(torch.zeros_like(spatial_attention).scatter_(-1, expand_knn_ind, attn_chosen), self.device) 134 | else: 135 | weighted_adjacency_matrix = to_cuda((markoff_value * torch.ones_like(attention)).scatter_(-1, knn_ind, knn_val), self.device) 136 | return weighted_adjacency_matrix 137 | 138 | def get_spatial_attention(self, shape): 139 | # Compute pseudo-coordinates for context words 140 | batch_size, turn_size, ctx_size = shape 141 | ctx_token_idx = get_range_vector(ctx_size, self.device) 142 | pseudo_coord = ctx_token_idx.unsqueeze(-1) - ctx_token_idx.unsqueeze(0) 143 | if self.use_position_enc: 144 | # Truncate 145 | pseudo_coord = torch.clamp(torch.abs(pseudo_coord) + 1, 
max=self.max_position_distance) 146 | pseudo_coord = self.position_enc(pseudo_coord) 147 | # Use Gaussian kernel to model attention over distance 148 | spatial_attention = self.get_multivariate_gaussian_weights(pseudo_coord) 149 | else: 150 | # Truncate & scale 151 | # pseudo_coord = torch.clamp(pseudo_coord, min=-self.max_position_distance, max=self.max_position_distance) 152 | pseudo_coord = torch.clamp(torch.abs(pseudo_coord.float()), max=self.max_position_distance) / self.max_position_distance 153 | # Use Gaussian kernel to model attention over distance 154 | spatial_attention = self.get_gaussian_weights(pseudo_coord) 155 | 156 | # shape: (batch_size, turn_size, n_spatial_kernels, ctx_size, ctx_size) 157 | spatial_attention = spatial_attention.unsqueeze(0).unsqueeze(0).expand(batch_size, turn_size, -1, -1, -1) 158 | return spatial_attention 159 | 160 | def get_gaussian_weights(self, pseudo_coord): 161 | ''' 162 | ## Inputs: 163 | - pseudo_coord (ctx_size, ctx_size) 164 | ## Returns: 165 | - weights (n_spatial_kernels, ctx_size, ctx_size) 166 | ''' 167 | # compute weights 168 | diff = (pseudo_coord.view(1, -1) - self.mean_dis)**2 169 | weights = torch.exp(-0.5 * diff * (self.precision_inv_dis**2)) 170 | 171 | # shape: (n_spatial_kernels, ctx_size, ctx_size) 172 | weights = weights.view((-1,) + pseudo_coord.shape) 173 | return weights 174 | 175 | def get_multivariate_gaussian_weights(self, pseudo_coord): 176 | ''' 177 | ## Inputs: 178 | - pseudo_coord (ctx_size, ctx_size, dim) 179 | ## Returns: 180 | - weights (n_spatial_kernels, ctx_size, ctx_size) 181 | ''' 182 | # compute weights 183 | diff = (pseudo_coord.view(1, -1, pseudo_coord.size(-1)) - self.mean_dis.view(-1, 1, self.mean_dis.size(-1)))**2 184 | weights = torch.exp(-0.5 * torch.sum(diff * (self.precision_inv_dis.unsqueeze(1))**2, dim=-1)) 185 | 186 | # shape: (n_spatial_kernels, ctx_size, ctx_size) 187 | weights = weights.view((-1,) + pseudo_coord.shape[:2]) 188 | return weights 189 | 190 | class ContextGraphNN(nn.Module): 191 | def __init__(self, hidden_size, n_spatial_kernels, use_spatial_kernels=True, graph_hops=1, bignn=False, device=None): 192 | super(ContextGraphNN, self).__init__() 193 | print('[ Using {}-hop ContextGraphNN ]'.format(graph_hops)) 194 | self.graph_hops = graph_hops 195 | self.use_spatial_kernels = use_spatial_kernels 196 | if self.use_spatial_kernels: 197 | self.linear_kernels = nn.ModuleList([nn.Linear(hidden_size, hidden_size // n_spatial_kernels, bias=False) for _ in range(n_spatial_kernels)]) 198 | else: 199 | n_spatial_kernels = 1 200 | self.gru_step = GRUStep(hidden_size, hidden_size // n_spatial_kernels * n_spatial_kernels) 201 | if bignn: 202 | self.gated_fusion = GatedFusion(hidden_size) 203 | self.update = self.bignn_update 204 | else: 205 | self.update = self.gnn_update 206 | 207 | print('[ Using graph type: dynamic ]') 208 | 209 | 210 | def forward(self, node_state, weighted_adjacency_matrix): 211 | node_state = self.update(node_state, weighted_adjacency_matrix) 212 | return node_state 213 | 214 | def bignn_update(self, node_state, weighted_adjacency_matrix): 215 | weighted_adjacency_matrix_in = torch.softmax(weighted_adjacency_matrix, dim=-1) 216 | weighted_adjacency_matrix_out = torch.softmax(weighted_adjacency_matrix.transpose(-1, -2), dim=-1) 217 | 218 | for _ in range(self.graph_hops): 219 | agg_state_in = self.aggregate_avgpool(node_state, weighted_adjacency_matrix_in) 220 | agg_state_out = self.aggregate_avgpool(node_state, weighted_adjacency_matrix_out) 221 | agg_state = 
self.gated_fusion(agg_state_in, agg_state_out) 222 | node_state = self.gru_step(node_state, agg_state) 223 | return node_state 224 | 225 | def gnn_update(self, node_state, weighted_adjacency_matrix): 226 | weighted_adjacency_matrix = torch.softmax(weighted_adjacency_matrix, dim=-1) 227 | 228 | 229 | for _ in range(self.graph_hops): 230 | agg_state = self.aggregate_avgpool(node_state, weighted_adjacency_matrix) 231 | node_state = self.gru_step(node_state, agg_state) 232 | return node_state 233 | 234 | def aggregate_avgpool(self, node_state, weighted_adjacency_matrix): 235 | # Information aggregation 236 | if self.use_spatial_kernels: 237 | # Joint aggregation 238 | agg_state = torch.cat([self.linear_kernels[i](torch.matmul(weighted_adjacency_matrix[:, i], node_state)) for i in range(weighted_adjacency_matrix.size(1))], -1) 239 | else: 240 | agg_state = torch.matmul(weighted_adjacency_matrix, node_state) 241 | return agg_state 242 | 243 | 244 | # Static GNN 245 | class StaticContextGraphNN(nn.Module): 246 | def __init__(self, hidden_size, graph_hops=1, device=None): 247 | super(StaticContextGraphNN, self).__init__() 248 | print('[ Using {}-hop GraphNN ]'.format(graph_hops)) 249 | self.device = device 250 | self.graph_hops = graph_hops 251 | self.linear_max = nn.Linear(hidden_size, hidden_size, bias=False) 252 | 253 | # Static graph 254 | self.static_graph_mp = GraphMessagePassing() 255 | self.static_gated_fusion = GatedFusion(hidden_size) 256 | self.static_gru_step = GRUStep(hidden_size, hidden_size) 257 | 258 | print('[ Using graph type: static ]') 259 | 260 | def forward(self, node_state, adj): 261 | '''Static graph update''' 262 | node2edge, edge2node = adj 263 | 264 | # Shape: (batch_size, num_edges, num_nodes) 265 | node2edge = to_cuda(torch.stack([torch.Tensor(x.A) for x in node2edge], dim=0), self.device) 266 | # Shape: (batch_size, num_nodes, num_edges) 267 | edge2node = to_cuda(torch.stack([torch.Tensor(x.A) for x in edge2node], dim=0), self.device) 268 | 269 | for _ in range(self.graph_hops): 270 | bw_agg_state = self.static_graph_mp(node_state, node2edge, edge2node) 271 | fw_agg_state = self.static_graph_mp(node_state, edge2node.transpose(1, 2), node2edge.transpose(1, 2)) 272 | agg_state = self.static_gated_fusion(fw_agg_state, bw_agg_state) 273 | node_state = self.static_gru_step(node_state, agg_state) 274 | return node_state 275 | 276 | 277 | class GraphMessagePassing(nn.Module): 278 | def __init__(self): 279 | super(GraphMessagePassing, self).__init__() 280 | 281 | def forward(self, node_state, node2edge, edge2node): 282 | node2edge_emb = torch.bmm(node2edge, node_state) # batch_size x num_edges x hidden_size 283 | 284 | # Add self-loop 285 | norm_ = torch.sum(edge2node, 2, keepdim=True) + 1 286 | agg_state = (torch.bmm(edge2node, node2edge_emb) + node_state) / norm_ 287 | return agg_state 288 | -------------------------------------------------------------------------------- /models/graphflow/models/__init__.py: -------------------------------------------------------------------------------- 1 | from .graphflow import GraphFlow 2 | -------------------------------------------------------------------------------- /models/graphflow/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .timer import * 2 | from .logger import * 3 | from .data_utils import * 4 | from .io_utils import * 5 | from .eval_utils import * 6 | -------------------------------------------------------------------------------- 
/models/graphflow/utils/bert_utils.py: -------------------------------------------------------------------------------- 1 | from collections import defaultdict, namedtuple 2 | 3 | import torch 4 | 5 | 6 | # When using the sliding window trick for long sequences, 7 | # we take the representation of each token with maximal context. 8 | # Take average of the BERT embeddings of these BPE sub-tokens 9 | # as the embedding for the word. 10 | # Take *weighted* average of the word embeddings through all layers. 11 | 12 | def extract_bert_ques_hidden_states(all_encoder_layers, max_doc_len, features, weighted_avg=False): 13 | num_layers, batch_size, turn_size, num_chunk, max_token_len, bert_dim = all_encoder_layers.shape 14 | out_features = torch.Tensor(num_layers, batch_size, turn_size, max_doc_len, bert_dim).fill_(0) 15 | device = all_encoder_layers.get_device() if all_encoder_layers.is_cuda else None 16 | if device is not None: 17 | out_features = out_features.to(device) 18 | 19 | token_count = [] 20 | # Map BERT tokens to doc words 21 | for i, ex_feature in enumerate(features): # Example 22 | ex_token_count = [] 23 | for t, para_feature in enumerate(ex_feature): # Turn 24 | para_token_count = defaultdict(int) 25 | for j, chunk_feature in enumerate(para_feature): # Chunk 26 | for k in chunk_feature.token_is_max_context: # Token 27 | if chunk_feature.token_is_max_context[k]: 28 | doc_word_idx = chunk_feature.token_to_orig_map[k] 29 | out_features[:, i, t, doc_word_idx] += all_encoder_layers[:, i, t, j, k] 30 | para_token_count[doc_word_idx] += 1 31 | ex_token_count.append(para_token_count) 32 | token_count.append(ex_token_count) 33 | 34 | for i, ex_token_count in enumerate(token_count): 35 | for t, para_token_count in enumerate(ex_token_count): 36 | for doc_word_idx, count in para_token_count.items(): 37 | out_features[:, i, t, doc_word_idx] /= count 38 | 39 | # Average through all layers 40 | if not weighted_avg: 41 | out_features = torch.mean(out_features, 0) 42 | return out_features 43 | 44 | def extract_bert_ctx_hidden_states(all_encoder_layers, max_doc_len, features, weighted_avg=False): 45 | num_layers, batch_size, num_chunk, max_token_len, bert_dim = all_encoder_layers.shape 46 | out_features = torch.Tensor(num_layers, batch_size, max_doc_len, bert_dim).fill_(0) 47 | device = all_encoder_layers.get_device() if all_encoder_layers.is_cuda else None 48 | if device is not None: 49 | out_features = out_features.to(device) 50 | 51 | token_count = [] 52 | # Map BERT tokens to doc words 53 | for i, ex_feature in enumerate(features): # Example 54 | ex_token_count = defaultdict(int) 55 | for j, chunk_feature in enumerate(ex_feature): # Chunk 56 | for k in chunk_feature.token_is_max_context: # Token 57 | if chunk_feature.token_is_max_context[k]: 58 | doc_word_idx = chunk_feature.token_to_orig_map[k] 59 | out_features[:, i, doc_word_idx] += all_encoder_layers[:, i, j, k] 60 | ex_token_count[doc_word_idx] += 1 61 | token_count.append(ex_token_count) 62 | 63 | for i, ex_token_count in enumerate(token_count): 64 | for doc_word_idx, count in ex_token_count.items(): 65 | out_features[:, i, doc_word_idx] /= count 66 | 67 | # Average through all layers 68 | if not weighted_avg: 69 | out_features = torch.mean(out_features, 0) 70 | return out_features 71 | 72 | def convert_text_to_bert_features(text, bert_tokenizer, max_seq_length, doc_stride): 73 | # The convention in BERT is: 74 | # (a) For sequence pairs: 75 | # tokens: [CLS] is this jack ##son ##ville ? [SEP] no it is not . 
[SEP] 76 | # type_ids: 0 0 0 0 0 0 0 0 1 1 1 1 1 1 77 | # (b) For single sequences: 78 | # tokens: [CLS] the dog is hairy . [SEP] 79 | # type_ids: 0 0 0 0 0 0 0 80 | 81 | tok_to_orig_index = [] 82 | all_doc_tokens = [] 83 | for (i, token) in enumerate(text): 84 | sub_tokens = bert_tokenizer.wordpiece_tokenizer.tokenize(token.lower()) 85 | for sub_ in sub_tokens: 86 | tok_to_orig_index.append(i) 87 | all_doc_tokens.append(sub_) 88 | 89 | # The -2 accounts for [CLS] and [SEP] 90 | max_tokens_for_doc = max_seq_length - 2 91 | 92 | # We can have documents that are longer than the maximum sequence length. 93 | # To deal with this we do a sliding window approach, where we take chunks 94 | # of the up to our max length with a stride of `doc_stride`. 95 | _DocSpan = namedtuple( # pylint: disable=invalid-name 96 | "DocSpan", ["start", "length"]) 97 | doc_spans = [] 98 | start_offset = 0 99 | while start_offset < len(all_doc_tokens): 100 | length = len(all_doc_tokens) - start_offset 101 | if length > max_tokens_for_doc: 102 | length = max_tokens_for_doc 103 | doc_spans.append(_DocSpan(start=start_offset, length=length)) 104 | if start_offset + length == len(all_doc_tokens): 105 | break 106 | start_offset += min(length, doc_stride) 107 | 108 | out_features = [] 109 | for (doc_span_index, doc_span) in enumerate(doc_spans): 110 | tokens = [] 111 | token_to_orig_map = {} 112 | token_is_max_context = {} 113 | segment_ids = [] 114 | tokens.append("[CLS]") 115 | segment_ids.append(0) 116 | 117 | for i in range(doc_span.length): 118 | split_token_index = doc_span.start + i 119 | token_to_orig_map[len(tokens)] = tok_to_orig_index[split_token_index] 120 | 121 | is_max_context = _check_is_max_context(doc_spans, doc_span_index, 122 | split_token_index) 123 | token_is_max_context[len(tokens)] = is_max_context 124 | tokens.append(all_doc_tokens[split_token_index]) 125 | segment_ids.append(0) 126 | tokens.append("[SEP]") 127 | segment_ids.append(0) 128 | 129 | input_ids = bert_tokenizer.convert_tokens_to_ids(tokens) 130 | 131 | # The mask has 1 for real tokens and 0 for padding tokens. Only real 132 | # tokens are attended to. 133 | input_mask = [1] * len(input_ids) 134 | 135 | feature = BertInputFeatures( 136 | doc_span_index=doc_span_index, 137 | tokens=tokens, 138 | token_to_orig_map=token_to_orig_map, 139 | token_is_max_context=token_is_max_context, 140 | input_ids=input_ids, 141 | input_mask=input_mask, 142 | segment_ids=segment_ids) 143 | out_features.append(feature) 144 | return out_features 145 | 146 | 147 | def _check_is_max_context(doc_spans, cur_span_index, position): 148 | """Check if this is the 'max context' doc span for the token.""" 149 | 150 | # Because of the sliding window approach taken to scoring documents, a single 151 | # token can appear in multiple documents. E.g. 152 | # Doc: the man went to the store and bought a gallon of milk 153 | # Span A: the man went to the 154 | # Span B: to the store and bought 155 | # Span C: and bought a gallon of 156 | # ... 157 | # 158 | # Now the word 'bought' will have two scores from spans B and C. We only 159 | # want to consider the score with "maximum context", which we define as 160 | # the *minimum* of its left and right context (the *sum* of left and 161 | # right context will always be the same, of course). 162 | # 163 | # In the example the maximum context for 'bought' would be span C since 164 | # it has 1 left context and 3 right context, while span B has 4 left context 165 | # and 0 right context. 
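# As a concrete instance of the scoring rule implemented below
# (min(num_left_context, num_right_context) + 0.01 * doc_span.length, with both spans of length 5):
# span B scores min(4, 0) + 0.05 = 0.05 for 'bought', span C scores min(1, 3) + 0.05 = 1.05,
# so span C is kept as the max-context span for 'bought'.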
166 | best_score = None 167 | best_span_index = None 168 | for (span_index, doc_span) in enumerate(doc_spans): 169 | end = doc_span.start + doc_span.length - 1 170 | if position < doc_span.start: 171 | continue 172 | if position > end: 173 | continue 174 | num_left_context = position - doc_span.start 175 | num_right_context = end - position 176 | score = min(num_left_context, num_right_context) + 0.01 * doc_span.length 177 | if best_score is None or score > best_score: 178 | best_score = score 179 | best_span_index = span_index 180 | 181 | return cur_span_index == best_span_index 182 | 183 | class BertInputFeatures(object): 184 | """A single set of BERT features of data.""" 185 | 186 | def __init__(self, 187 | doc_span_index, 188 | tokens, 189 | token_to_orig_map, 190 | token_is_max_context, 191 | input_ids, 192 | input_mask, 193 | segment_ids): 194 | self.doc_span_index = doc_span_index 195 | self.tokens = tokens 196 | self.token_to_orig_map = token_to_orig_map 197 | self.token_is_max_context = token_is_max_context 198 | self.input_ids = input_ids 199 | self.input_mask = input_mask 200 | self.segment_ids = segment_ids 201 | -------------------------------------------------------------------------------- /models/graphflow/utils/constants.py: -------------------------------------------------------------------------------- 1 | """ 2 | Module to handle universal/general constants used across files. 3 | """ 4 | 5 | ################################################################################ 6 | # Constants # 7 | ################################################################################ 8 | 9 | # GENERAL CONSTANTS: 10 | VERY_SMALL_NUMBER = 1e-12 11 | 12 | _UNK_POS = 'unk_pos' 13 | _UNK_NER = 'unk_ner' 14 | 15 | _UNK_TOKEN = '<>' 16 | _QUESTION_SYMBOL = '' 17 | _ANSWER_SYMBOL = '' 18 | 19 | 20 | # CoQA CONSTANTS: 21 | CoQA_UNK_ANSWER = 'unknown' 22 | CoQA_YES_ANSWER = 'yes' 23 | CoQA_NO_ANSWER = 'no' 24 | 25 | CoQA_UNK_ANSWER_LABEL = 0 26 | CoQA_ANSWER_YES_LABEL = 1 27 | CoQA_ANSWER_NO_LABEL = 2 28 | CoQA_ANSWER_SPAN_LABEL = 3 29 | CoQA_ANSWER_CLASS_NUM = 4 30 | 31 | 32 | # QuAC 33 | QuAC_UNK_ANSWER = 'cannotanswer' 34 | 35 | QuAC_YESNO_YES = 'y' 36 | QuAC_YESNO_NO = 'n' 37 | QuAC_YESNO_OTHER = 'x' 38 | 39 | QuAC_YESNO_YES_LABEL = 0 40 | QuAC_YESNO_NO_LABEL = 1 41 | QuAC_YESNO_OTHER_LABEL = 2 42 | QuAC_YESNO_CLASS_NUM = 3 43 | 44 | QuAC_FOLLOWUP_YES = 'y' 45 | QuAC_FOLLOWUP_NO = 'n' 46 | QuAC_FOLLOWUP_OTHER = 'm' 47 | 48 | QuAC_FOLLOWUP_YES_LABEL = 0 49 | QuAC_FOLLOWUP_NO_LABEL = 1 50 | QuAC_FOLLOWUP_OTHER_LABEL = 2 51 | QuAC_FOLLOWUP_CLASS_NUM = 3 52 | 53 | # LOG FILES ## 54 | 55 | _CONFIG_FILE = "config.json" 56 | _SAVED_WEIGHTS_FILE = "params.saved" 57 | _GOLD_PREDICTION_FILE = "gold_pred.json" 58 | _PRED_PREDICTION_FILE = "pred_pred.json" 59 | _TEST_PREDICTION_FILE = "test_pred.json" 60 | -------------------------------------------------------------------------------- /models/graphflow/utils/coqa/__init__.py: -------------------------------------------------------------------------------- 1 | from .eval_utils import * 2 | -------------------------------------------------------------------------------- /models/graphflow/utils/coqa/eval_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import re 3 | import string 4 | from collections import Counter 5 | 6 | 7 | ################################################################################ 8 | # Text Processing Helper Functions # 9 | 
################################################################################ 10 | 11 | 12 | def normalize_text(s): 13 | """Lower text and remove punctuation, articles and extra whitespace.""" 14 | def remove_articles(text): 15 | return re.sub(r'\b(a|an|the)\b', ' ', text) 16 | 17 | def white_space_fix(text): 18 | return ' '.join(text.split()) 19 | 20 | def remove_punc(text): 21 | exclude = set(string.punctuation) 22 | return ''.join(ch for ch in text if ch not in exclude) 23 | 24 | def lower(text): 25 | return text.lower() 26 | 27 | return white_space_fix(remove_articles(remove_punc(lower(s)))) 28 | 29 | 30 | def compute_eval_metric(eval_metric, predictions, ground_truths, cross_eval=True): 31 | fns = {'f1': compute_f1_score, 32 | 'em': compute_em_score} 33 | 34 | def metric_max_over_ground_truths(metric_fn, prediction, ground_truths): 35 | scores_for_ground_truths = [] 36 | for ground_truth in ground_truths: 37 | score = metric_fn(normalize_text(prediction), normalize_text(ground_truth)) 38 | scores_for_ground_truths.append(score) 39 | return max(scores_for_ground_truths) 40 | 41 | values = [] 42 | for prediction, ground_truth_set in zip(predictions, ground_truths): 43 | if cross_eval and len(ground_truth_set) > 1: 44 | _scores = [] 45 | for i in range(len(ground_truth_set)): 46 | _ground_truth_set = [] 47 | for j in range(len(ground_truth_set)): 48 | if j != i: 49 | _ground_truth_set.append(ground_truth_set[j]) 50 | _scores.append(metric_max_over_ground_truths(fns[eval_metric], prediction, _ground_truth_set)) 51 | value = np.mean(_scores) 52 | else: 53 | value = metric_max_over_ground_truths(fns[eval_metric], prediction, ground_truth_set) 54 | values.append(value) 55 | return np.mean(values) 56 | 57 | 58 | def compute_f1_score(prediction, ground_truth): 59 | common = Counter(prediction.split()) & Counter(ground_truth.split()) 60 | num_same = sum(common.values()) 61 | if num_same == 0: 62 | return 0 63 | precision = 1.0 * num_same / len(prediction.split()) 64 | recall = 1.0 * num_same / len(ground_truth.split()) 65 | f1 = (2 * precision * recall) / (precision + recall) 66 | return f1 67 | 68 | 69 | def compute_em_score(prediction, ground_truth): 70 | return 1.0 if prediction == ground_truth else 0.0 71 | -------------------------------------------------------------------------------- /models/graphflow/utils/doqa/__init__.py: -------------------------------------------------------------------------------- 1 | from .eval_utils import * 2 | -------------------------------------------------------------------------------- /models/graphflow/utils/doqa/eval_utils.py: -------------------------------------------------------------------------------- 1 | import json, string, re 2 | from collections import Counter, defaultdict 3 | from argparse import ArgumentParser 4 | 5 | 6 | def is_overlapping(x1, x2, y1, y2): 7 | return max(x1, y1) <= min(x2, y2) 8 | 9 | def normalize_answer(s): 10 | """Lower text and remove punctuation, articles and extra whitespace.""" 11 | def remove_articles(text): 12 | return re.sub(r'\b(a|an|the)\b', ' ', text) 13 | def white_space_fix(text): 14 | return ' '.join(text.split()) 15 | def remove_punc(text): 16 | exclude = set(string.punctuation) 17 | return ''.join(ch for ch in text if ch not in exclude) 18 | def lower(text): 19 | return text.lower() 20 | return white_space_fix(remove_articles(remove_punc(lower(s)))) 21 | 22 | def f1_score(prediction, ground_truth): 23 | prediction_tokens = normalize_answer(prediction).split() 24 | ground_truth_tokens = 
normalize_answer(ground_truth).split() 25 | common = Counter(prediction_tokens) & Counter(ground_truth_tokens) 26 | num_same = sum(common.values()) 27 | if num_same == 0: 28 | return 0 29 | precision = 1.0 * num_same / len(prediction_tokens) 30 | recall = 1.0 * num_same / len(ground_truth_tokens) 31 | f1 = (2 * precision * recall) / (precision + recall) 32 | return f1 33 | 34 | def exact_match_score(prediction, ground_truth): 35 | return (normalize_answer(prediction) == normalize_answer(ground_truth)) 36 | 37 | def leave_one_out_max(prediction, ground_truths, article): 38 | if len(ground_truths) == 1: 39 | return metric_max_over_ground_truths(prediction, ground_truths, article)[1] 40 | else: 41 | t_f1 = [] 42 | # leave out one ref every time 43 | for i in range(len(ground_truths)): 44 | idxes = list(range(len(ground_truths))) 45 | idxes.pop(i) 46 | refs = [ground_truths[z] for z in idxes] 47 | t_f1.append(metric_max_over_ground_truths(prediction, refs, article)[1]) 48 | return 1.0 * sum(t_f1) / len(t_f1) 49 | 50 | 51 | def metric_max_over_ground_truths(prediction, ground_truths, article): 52 | scores_for_ground_truths = [] 53 | for ground_truth in ground_truths: 54 | score = compute_span_overlap(prediction, ground_truth, article) 55 | scores_for_ground_truths.append(score) 56 | return max(scores_for_ground_truths, key=lambda x: x[1]) 57 | 58 | 59 | def handle_cannot(refs): 60 | num_cannot = 0 61 | num_spans = 0 62 | for ref in refs: 63 | if ref == 'CANNOTANSWER': 64 | num_cannot += 1 65 | else: 66 | num_spans += 1 67 | if num_cannot >= num_spans: 68 | refs = ['CANNOTANSWER'] 69 | else: 70 | refs = [x for x in refs if x != 'CANNOTANSWER'] 71 | return refs 72 | 73 | 74 | def leave_one_out(refs): 75 | if len(refs) == 1: 76 | return 1. 77 | splits = [] 78 | for r in refs: 79 | splits.append(r.split()) 80 | t_f1 = 0.0 81 | for i in range(len(refs)): 82 | m_f1 = 0 83 | for j in range(len(refs)): 84 | if i == j: 85 | continue 86 | f1_ij = f1_score(refs[i], refs[j]) 87 | if f1_ij > m_f1: 88 | m_f1 = f1_ij 89 | t_f1 += m_f1 90 | return t_f1 / len(refs) 91 | 92 | 93 | def compute_span_overlap(pred_span, gt_span, text): 94 | if gt_span == 'CANNOTANSWER': 95 | if pred_span == 'CANNOTANSWER': 96 | return 'Exact match', 1.0 97 | return 'No overlap', 0. 98 | fscore = f1_score(pred_span, gt_span) 99 | pred_start = text.find(pred_span) 100 | gt_start = text.find(gt_span) 101 | 102 | if pred_start == -1 or gt_start == -1: 103 | return 'Span indexing error', fscore 104 | 105 | pred_end = pred_start + len(pred_span) 106 | gt_end = gt_start + len(gt_span) 107 | 108 | fscore = f1_score(pred_span, gt_span) 109 | overlap = is_overlapping(pred_start, pred_end, gt_start, gt_end) 110 | 111 | if exact_match_score(pred_span, gt_span): 112 | return 'Exact match', fscore 113 | if overlap: 114 | return 'Partial overlap', fscore 115 | else: 116 | return 'No overlap', fscore 117 | 118 | 119 | def eval_fn(gold_results, pred_results, raw_context, min_f1=0.4): 120 | total_qs = 0. 121 | f1_stats = defaultdict(list) 122 | human_f1 = [] 123 | HEQ = 0. 124 | DHEQ = 0. 125 | total_dials = 0. 126 | for dial_idx, ex_results in enumerate(gold_results): 127 | good_dial = 1. 
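# good_dial stays 1 only while every turn of this dialog that passes the human-agreement
# filter (hf1 >= min_f1) reaches human F1; it is summed into DHEQ, the dialog-level
# human-equivalence count, after the turn loop below.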
128 | for turn_idx, turn_results in enumerate(ex_results): 129 | gold_spans = handle_cannot(turn_results) 130 | hf1 = leave_one_out(gold_spans) 131 | 132 | pred_span = pred_results[dial_idx][turn_idx] 133 | 134 | max_overlap, _ = metric_max_over_ground_truths( \ 135 | pred_span, gold_spans, raw_context[dial_idx]) 136 | max_f1 = leave_one_out_max( \ 137 | pred_span, gold_spans, raw_context[dial_idx]) 138 | 139 | # dont eval on low agreement instances 140 | if hf1 < min_f1: 141 | continue 142 | 143 | human_f1.append(hf1) 144 | if max_f1 >= hf1: 145 | HEQ += 1. 146 | else: 147 | good_dial = 0. 148 | f1_stats[max_overlap].append(max_f1) 149 | total_qs += 1. 150 | DHEQ += good_dial 151 | total_dials += 1 152 | 153 | DHEQ_score = DHEQ / total_dials 154 | if total_qs == 0: 155 | HEQ_score = 0 156 | else: 157 | HEQ_score = HEQ / total_qs 158 | all_f1s = sum(f1_stats.values(), []) 159 | if len(all_f1s) == 0: 160 | overall_f1 = 0 161 | else: 162 | overall_f1 = sum(all_f1s) / len(all_f1s) 163 | metric_json = {"f1": overall_f1, "heq": HEQ_score, "dheq": DHEQ_score} 164 | return metric_json, total_qs, total_dials 165 | -------------------------------------------------------------------------------- /models/graphflow/utils/eval_utils.py: -------------------------------------------------------------------------------- 1 | import re 2 | import string 3 | 4 | 5 | ################################################################################ 6 | # Text Processing Helper Functions # 7 | ################################################################################ 8 | 9 | 10 | def normalize_text(s): 11 | """Lower text and remove punctuation, articles and extra whitespace.""" 12 | def remove_articles(text): 13 | return re.sub(r'\b(a|an|the)\b', ' ', text) 14 | 15 | def white_space_fix(text): 16 | return ' '.join(text.split()) 17 | 18 | def remove_punc(text): 19 | exclude = set(string.punctuation) 20 | return ''.join(ch for ch in text if ch not in exclude) 21 | 22 | def lower(text): 23 | return text.lower() 24 | 25 | return white_space_fix(remove_articles(remove_punc(lower(s)))) 26 | 27 | 28 | class AverageMeter(object): 29 | """Computes and stores the average and current value.""" 30 | def __init__(self): 31 | self.history = [] 32 | self.last = None 33 | self.val = 0 34 | self.sum = 0 35 | self.count = 0 36 | 37 | def reset(self): 38 | self.last = self.mean() 39 | self.history.append(self.last) 40 | self.val = 0 41 | self.sum = 0 42 | self.count = 0 43 | 44 | def update(self, val, n=1): 45 | self.val = val 46 | self.sum += val * n 47 | self.count += n 48 | 49 | def mean(self): 50 | if self.count == 0: 51 | return 0. 
52 | return self.sum / self.count 53 | -------------------------------------------------------------------------------- /models/graphflow/utils/generic_utils.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Created on Nov, 2018 3 | 4 | @author: hugo 5 | 6 | ''' 7 | import yaml 8 | import numpy as np 9 | import torch 10 | import torch.nn as nn 11 | import torch.nn.functional as F 12 | 13 | 14 | def get_sinusoid_encoding_table(n_position, d_hid, padding_idx=None, device=None): 15 | ''' Sinusoid position encoding table ''' 16 | 17 | def cal_angle(position, hid_idx): 18 | return position / np.power(10000, 2 * (hid_idx // 2) / d_hid) 19 | 20 | def get_posi_angle_vec(position): 21 | return [cal_angle(position, hid_j) for hid_j in range(d_hid)] 22 | 23 | sinusoid_table = np.array([get_posi_angle_vec(pos_i) for pos_i in range(n_position)]) 24 | 25 | sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2]) # dim 2i 26 | sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2]) # dim 2i+1 27 | 28 | if padding_idx is not None: 29 | # zero vector for padding dimension 30 | sinusoid_table[padding_idx] = 0. 31 | 32 | sinusoid_table = torch.Tensor(sinusoid_table) 33 | return sinusoid_table.to(device) if device else sinusoid_table 34 | 35 | def get_range_vector(size, device): 36 | """ 37 | Returns a range vector with the desired size, starting at 0. The CUDA implementation 38 | is meant to avoid copy data from CPU to GPU. 39 | """ 40 | if device.type == 'cuda': 41 | return torch.cuda.LongTensor(size, device=device).fill_(1).cumsum(0) - 1 42 | else: 43 | return torch.arange(0, size, dtype=torch.long) 44 | 45 | def to_cuda(x, device=None): 46 | if device: 47 | x = x.to(device) 48 | return x 49 | 50 | def batched_diag(x, device=None): 51 | # Input: a 2D tensor 52 | # Output: a 3D tensor 53 | x_diag = torch.zeros(x.size(0), x.size(1), x.size(1)) 54 | _ = x_diag.as_strided(x.size(), [x_diag.stride(0), x_diag.size(2) + 1]).copy_(x) 55 | return to_cuda(x_diag, device) 56 | 57 | def create_mask(x, N, device=None): 58 | x = x.data 59 | mask = np.zeros((x.size(0), N)) 60 | for i in range(x.size(0)): 61 | mask[i, :x[i]] = 1 62 | return to_cuda(torch.Tensor(mask), device) 63 | 64 | def get_config(config_path="config.yml"): 65 | with open(config_path, "r") as setting: 66 | config = yaml.load(setting) 67 | return config 68 | -------------------------------------------------------------------------------- /models/graphflow/utils/io_utils.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Created on Nov, 2018 3 | 4 | @author: hugo 5 | 6 | ''' 7 | import json 8 | import numpy as np 9 | 10 | 11 | def dump_ndarray(data, path_to_file): 12 | try: 13 | with open(path_to_file, 'wb') as f: 14 | np.save(f, data) 15 | except Exception as e: 16 | raise e 17 | 18 | def load_ndarray(path_to_file): 19 | try: 20 | with open(path_to_file, 'rb') as f: 21 | data = np.load(f) 22 | except Exception as e: 23 | raise e 24 | 25 | return data 26 | 27 | def dump_ndjson(data, file): 28 | try: 29 | with open(file, 'w') as f: 30 | for each in data: 31 | f.write(json.dumps(each) + '\n') 32 | except Exception as e: 33 | raise e 34 | 35 | def load_ndjson(file, return_type='array'): 36 | if return_type == 'array': 37 | return load_ndjson_to_array(file) 38 | elif return_type == 'dict': 39 | return load_ndjson_to_dict(file) 40 | else: 41 | raise RuntimeError('Unknown return_type: %s' % return_type) 42 | 43 | def dump_json(data, file, indent=None): 44 | try: 45 
| with open(file, 'w') as f: 46 | json.dump(data, f, indent=indent) 47 | except Exception as e: 48 | raise e 49 | 50 | def load_json(file): 51 | try: 52 | with open(file, 'r') as f: 53 | data = json.load(f) 54 | except Exception as e: 55 | raise e 56 | return data 57 | -------------------------------------------------------------------------------- /models/graphflow/utils/logger.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import sys 4 | from . import constants as Constants 5 | 6 | 7 | class DummyLogger(object): 8 | def __init__(self, config, dirname=None, pretrained=None): 9 | self.config = config 10 | if dirname is None: 11 | if pretrained is None: 12 | raise Exception('Either --dir or --pretrained needs to be specified.') 13 | self.dirname = pretrained 14 | else: 15 | self.dirname = dirname 16 | if not os.path.exists(dirname): 17 | # raise Exception('Directory already exists: {}'.format(dirname)) 18 | os.makedirs(dirname) 19 | os.mkdir(os.path.join(dirname, 'metrics')) 20 | self.log_json(config, os.path.join(self.dirname, Constants._CONFIG_FILE)) 21 | if config['logging']: 22 | self.f_metric = open(os.path.join(self.dirname, 'metrics', 'loss_f1_em.log'), 'a') 23 | 24 | def log_json(self, data, filename, mode='w'): 25 | with open(filename, mode) as outfile: 26 | outfile.write(json.dumps(data, indent=4, ensure_ascii=False)) 27 | 28 | def log(self, data, filename): 29 | print(data) 30 | 31 | def write_to_file(self, text): 32 | if self.config['logging']: 33 | self.f_metric.writelines(text + '\n') 34 | self.f_metric.flush() 35 | 36 | def close(self): 37 | if self.config['logging']: 38 | self.f_metric.close() 39 | 40 | class Logger(object): 41 | def __init__(self, log_file): 42 | self.terminal = sys.stdout 43 | self.log = open(log_file, "a") 44 | 45 | def write(self, message): 46 | self.terminal.write(message) 47 | self.log.write(message) 48 | self.log.flush() 49 | 50 | def flush(self): 51 | pass 52 | -------------------------------------------------------------------------------- /models/graphflow/utils/process_utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | This file takes a QuAC data file as input and generates the input files for training a conversational reading comprehension model. 
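Tokenization, POS and NER tags are obtained from a Stanford CoreNLP server (expected at http://localhost:9000), and dependency trees are built with spaCy over the pre-tokenized text.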
3 | """ 4 | 5 | 6 | import argparse 7 | import json 8 | import re 9 | import time 10 | import string 11 | from collections import Counter, defaultdict 12 | import spacy 13 | from spacy.tokens import Doc 14 | 15 | from pycorenlp import StanfordCoreNLP 16 | 17 | 18 | def _str(s): 19 | """ Convert PTB tokens to normal tokens """ 20 | if (s.lower() == '-lrb-'): 21 | s = '(' 22 | elif (s.lower() == '-rrb-'): 23 | s = ')' 24 | elif (s.lower() == '-lsb-'): 25 | s = '[' 26 | elif (s.lower() == '-rsb-'): 27 | s = ']' 28 | elif (s.lower() == '-lcb-'): 29 | s = '{' 30 | elif (s.lower() == '-rcb-'): 31 | s = '}' 32 | return s 33 | 34 | class WhitespaceTokenizer(object): 35 | 36 | def __init__(self, vocab): 37 | self.vocab = vocab 38 | 39 | def __call__(self, text): 40 | words = text.split(' ') 41 | # All tokens 'own' a subsequent space character in this tokenizer 42 | spaces = [True] * len(words) 43 | return Doc(self.vocab, words=words, spaces=spaces) 44 | 45 | class ExampleProcessor(): 46 | def __init__(self) -> None: 47 | self.corenlp = StanfordCoreNLP('http://localhost:9000') 48 | self.parser = spacy.load('en') 49 | self.parser.tokenizer = WhitespaceTokenizer(self.parser.vocab) 50 | 51 | def process(self, text): 52 | paragraph = self.corenlp.annotate(text, properties={ 53 | 'annotators': 'tokenize, ssplit, pos, ner', 54 | 'outputFormat': 'json', 55 | 'ssplit.newlineIsSentenceBreak': 'two'}) 56 | 57 | output = {'word': [], 58 | # 'lemma': [], 59 | 'pos': [], 60 | 'ner': [], 61 | 'offsets': []} 62 | 63 | for sent in paragraph['sentences']: 64 | for token in sent['tokens']: 65 | output['word'].append(_str(token['word'])) 66 | output['pos'].append(token['pos']) 67 | output['ner'].append(token['ner']) 68 | output['offsets'].append( 69 | (token['characterOffsetBegin'], token['characterOffsetEnd'])) 70 | return output 71 | 72 | def extract_sent_dep_tree(self, text): 73 | if len(text) == 0: 74 | return {'g_features': [], 'g_adj': {}, 'num_edges': 0} 75 | 76 | doc = self.parser(text) 77 | g_features = [] 78 | dep_tree = defaultdict(list) 79 | boundary_nodes = [] 80 | num_edges = 0 81 | for sent in doc.sents: 82 | boundary_nodes.append(sent[-1].i) 83 | for each in sent: 84 | g_features.append(each.text) 85 | if each.i != each.head.i: # Not a root 86 | dep_tree[each.head.i].append( 87 | {'node': each.i, 'edge': each.dep_}) 88 | num_edges += 1 89 | 90 | for i in range(len(boundary_nodes) - 1): 91 | # Add connection between neighboring dependency trees 92 | dep_tree[boundary_nodes[i]].append( 93 | {'node': boundary_nodes[i] + 1, 'edge': 'neigh'}) 94 | dep_tree[boundary_nodes[i] + 95 | 1].append({'node': boundary_nodes[i], 'edge': 'neigh'}) 96 | num_edges += 2 97 | 98 | info = {'g_features': g_features, 99 | 'g_adj': dep_tree, 100 | 'num_edges': num_edges 101 | } 102 | return info 103 | 104 | 105 | def extract_sent_dep_order_tree(self, text): 106 | '''Keep both dependency and ordering info''' 107 | if len(text) == 0: 108 | return {'g_features': [], 'g_adj': {}, 'num_edges': 0} 109 | 110 | doc = self.parser(text) 111 | g_features = [] 112 | dep_tree = defaultdict(list) 113 | 114 | num_edges = 0 115 | # Add word ordering info 116 | for i in range(len(doc) - 1): 117 | dep_tree[i].append({'node': i + 1, 'edge': 'neigh'}) 118 | dep_tree[i + 1].append({'node': i, 'edge': 'neigh'}) 119 | num_edges += 2 120 | 121 | # Add dependency info 122 | for sent in doc.sents: 123 | for each in sent: 124 | g_features.append(each.text) 125 | # Not a root 126 | if each.i != each.head.i and abs(each.head.i - each.i) != 1: 127 | 
dep_tree[each.head.i].append( 128 | {'node': each.i, 'edge': each.dep_}) 129 | num_edges += 1 130 | 131 | info = {'g_features': g_features, 132 | 'g_adj': dep_tree, 133 | 'num_edges': num_edges 134 | } 135 | return info 136 | -------------------------------------------------------------------------------- /models/graphflow/utils/quac/__init__.py: -------------------------------------------------------------------------------- 1 | from .eval_utils import * 2 | -------------------------------------------------------------------------------- /models/graphflow/utils/quac/eval_utils.py: -------------------------------------------------------------------------------- 1 | import json, string, re 2 | from collections import Counter, defaultdict 3 | from argparse import ArgumentParser 4 | 5 | 6 | def is_overlapping(x1, x2, y1, y2): 7 | return max(x1, y1) <= min(x2, y2) 8 | 9 | def normalize_answer(s): 10 | """Lower text and remove punctuation, articles and extra whitespace.""" 11 | def remove_articles(text): 12 | return re.sub(r'\b(a|an|the)\b', ' ', text) 13 | def white_space_fix(text): 14 | return ' '.join(text.split()) 15 | def remove_punc(text): 16 | exclude = set(string.punctuation) 17 | return ''.join(ch for ch in text if ch not in exclude) 18 | def lower(text): 19 | return text.lower() 20 | return white_space_fix(remove_articles(remove_punc(lower(s)))) 21 | 22 | def f1_score(prediction, ground_truth): 23 | prediction_tokens = normalize_answer(prediction).split() 24 | ground_truth_tokens = normalize_answer(ground_truth).split() 25 | common = Counter(prediction_tokens) & Counter(ground_truth_tokens) 26 | num_same = sum(common.values()) 27 | if num_same == 0: 28 | return 0 29 | precision = 1.0 * num_same / len(prediction_tokens) 30 | recall = 1.0 * num_same / len(ground_truth_tokens) 31 | f1 = (2 * precision * recall) / (precision + recall) 32 | return f1 33 | 34 | def exact_match_score(prediction, ground_truth): 35 | return (normalize_answer(prediction) == normalize_answer(ground_truth)) 36 | 37 | def leave_one_out_max(prediction, ground_truths, article): 38 | if len(ground_truths) == 1: 39 | return metric_max_over_ground_truths(prediction, ground_truths, article)[1] 40 | else: 41 | t_f1 = [] 42 | # leave out one ref every time 43 | for i in range(len(ground_truths)): 44 | idxes = list(range(len(ground_truths))) 45 | idxes.pop(i) 46 | refs = [ground_truths[z] for z in idxes] 47 | t_f1.append(metric_max_over_ground_truths(prediction, refs, article)[1]) 48 | return 1.0 * sum(t_f1) / len(t_f1) 49 | 50 | 51 | def metric_max_over_ground_truths(prediction, ground_truths, article): 52 | scores_for_ground_truths = [] 53 | for ground_truth in ground_truths: 54 | score = compute_span_overlap(prediction, ground_truth, article) 55 | scores_for_ground_truths.append(score) 56 | return max(scores_for_ground_truths, key=lambda x: x[1]) 57 | 58 | 59 | def handle_cannot(refs): 60 | num_cannot = 0 61 | num_spans = 0 62 | for ref in refs: 63 | if ref == 'CANNOTANSWER': 64 | num_cannot += 1 65 | else: 66 | num_spans += 1 67 | if num_cannot >= num_spans: 68 | refs = ['CANNOTANSWER'] 69 | else: 70 | refs = [x for x in refs if x != 'CANNOTANSWER'] 71 | return refs 72 | 73 | 74 | def leave_one_out(refs): 75 | if len(refs) == 1: 76 | return 1. 
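# For every reference answer, take its best F1 against each of the other references and average
# these maxima; the result estimates annotator agreement (the "human F1") for the question.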
77 | splits = [] 78 | for r in refs: 79 | splits.append(r.split()) 80 | t_f1 = 0.0 81 | for i in range(len(refs)): 82 | m_f1 = 0 83 | for j in range(len(refs)): 84 | if i == j: 85 | continue 86 | f1_ij = f1_score(refs[i], refs[j]) 87 | if f1_ij > m_f1: 88 | m_f1 = f1_ij 89 | t_f1 += m_f1 90 | return t_f1 / len(refs) 91 | 92 | 93 | def compute_span_overlap(pred_span, gt_span, text): 94 | if gt_span == 'CANNOTANSWER': 95 | if pred_span == 'CANNOTANSWER': 96 | return 'Exact match', 1.0 97 | return 'No overlap', 0. 98 | fscore = f1_score(pred_span, gt_span) 99 | pred_start = text.find(pred_span) 100 | gt_start = text.find(gt_span) 101 | 102 | if pred_start == -1 or gt_start == -1: 103 | return 'Span indexing error', fscore 104 | 105 | pred_end = pred_start + len(pred_span) 106 | gt_end = gt_start + len(gt_span) 107 | 108 | fscore = f1_score(pred_span, gt_span) 109 | overlap = is_overlapping(pred_start, pred_end, gt_start, gt_end) 110 | 111 | if exact_match_score(pred_span, gt_span): 112 | return 'Exact match', fscore 113 | if overlap: 114 | return 'Partial overlap', fscore 115 | else: 116 | return 'No overlap', fscore 117 | 118 | 119 | def eval_fn(gold_results, pred_results, raw_context, min_f1=0.4): 120 | total_qs = 0. 121 | f1_stats = defaultdict(list) 122 | human_f1 = [] 123 | HEQ = 0. 124 | DHEQ = 0. 125 | total_dials = 0. 126 | for dial_idx, ex_results in enumerate(gold_results): 127 | good_dial = 1. 128 | # print("Ex:",ex_results) 129 | for turn_idx, turn_results in enumerate(ex_results): 130 | # print("Turn:",turn_results) 131 | gold_spans = handle_cannot(turn_results) 132 | hf1 = leave_one_out(gold_spans) 133 | 134 | pred_span = pred_results[dial_idx][turn_idx] 135 | 136 | max_overlap, _ = metric_max_over_ground_truths( \ 137 | pred_span, gold_spans, raw_context[dial_idx]) 138 | max_f1 = leave_one_out_max( \ 139 | pred_span, gold_spans, raw_context[dial_idx]) 140 | 141 | # dont eval on low agreement instances 142 | # if hf1 < min_f1: 143 | # continue 144 | 145 | human_f1.append(hf1) 146 | if max_f1 >= hf1: 147 | HEQ += 1. 148 | else: 149 | good_dial = 0. 150 | f1_stats[max_overlap].append(max_f1) 151 | total_qs += 1. 
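# A question counts toward HEQ when the model's leave-one-out F1 reaches the human agreement
# F1 (hf1) computed above; otherwise the whole dialog is marked as not human-equivalent.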
152 | # print("Total_q:",total_qs) 153 | DHEQ += good_dial 154 | total_dials += 1 155 | # print("total_qs:",total_qs) 156 | DHEQ_score = DHEQ / total_dials 157 | HEQ_score = HEQ / total_qs 158 | all_f1s = sum(f1_stats.values(), []) 159 | overall_f1 = sum(all_f1s) / len(all_f1s) 160 | metric_json = {"f1": overall_f1, "heq": HEQ_score, "dheq": DHEQ_score} 161 | return metric_json, total_qs, total_dials 162 | -------------------------------------------------------------------------------- /models/graphflow/utils/radam.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | from torch.optim.optimizer import Optimizer, required 4 | 5 | class RAdam(Optimizer): 6 | 7 | def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0): 8 | defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay) 9 | self.buffer = [[None, None, None] for ind in range(10)] 10 | super(RAdam, self).__init__(params, defaults) 11 | 12 | def __setstate__(self, state): 13 | super(RAdam, self).__setstate__(state) 14 | 15 | def step(self, closure=None): 16 | 17 | loss = None 18 | if closure is not None: 19 | loss = closure() 20 | 21 | for group in self.param_groups: 22 | 23 | for p in group['params']: 24 | if p.grad is None: 25 | continue 26 | grad = p.grad.data.float() 27 | if grad.is_sparse: 28 | raise RuntimeError('RAdam does not support sparse gradients') 29 | 30 | p_data_fp32 = p.data.float() 31 | 32 | state = self.state[p] 33 | 34 | if len(state) == 0: 35 | state['step'] = 0 36 | state['exp_avg'] = torch.zeros_like(p_data_fp32) 37 | state['exp_avg_sq'] = torch.zeros_like(p_data_fp32) 38 | else: 39 | state['exp_avg'] = state['exp_avg'].type_as(p_data_fp32) 40 | state['exp_avg_sq'] = state['exp_avg_sq'].type_as(p_data_fp32) 41 | 42 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] 43 | beta1, beta2 = group['betas'] 44 | 45 | exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad) 46 | exp_avg.mul_(beta1).add_(1 - beta1, grad) 47 | 48 | state['step'] += 1 49 | buffered = self.buffer[int(state['step'] % 10)] 50 | if state['step'] == buffered[0]: 51 | N_sma, step_size = buffered[1], buffered[2] 52 | else: 53 | buffered[0] = state['step'] 54 | beta2_t = beta2 ** state['step'] 55 | N_sma_max = 2 / (1 - beta2) - 1 56 | N_sma = N_sma_max - 2 * state['step'] * beta2_t / (1 - beta2_t) 57 | buffered[1] = N_sma 58 | 59 | # more conservative since it's an approximated value 60 | if N_sma >= 5: 61 | step_size = group['lr'] * math.sqrt((1 - beta2_t) * (N_sma - 4) / (N_sma_max - 4) * (N_sma - 2) / N_sma * N_sma_max / (N_sma_max - 2)) / (1 - beta1 ** state['step']) 62 | else: 63 | step_size = group['lr'] / (1 - beta1 ** state['step']) 64 | buffered[2] = step_size 65 | 66 | if group['weight_decay'] != 0: 67 | p_data_fp32.add_(-group['weight_decay'] * group['lr'], p_data_fp32) 68 | 69 | # more conservative since it's an approximated value 70 | if N_sma >= 5: 71 | denom = exp_avg_sq.sqrt().add_(group['eps']) 72 | p_data_fp32.addcdiv_(-step_size, exp_avg, denom) 73 | else: 74 | p_data_fp32.add_(-step_size, exp_avg) 75 | 76 | p.data.copy_(p_data_fp32) 77 | 78 | return loss 79 | 80 | class PlainRAdam(Optimizer): 81 | 82 | def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0): 83 | defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay) 84 | 85 | super(RAdam, self).__init__(params, defaults) 86 | 87 | def __setstate__(self, state): 88 | super(RAdam, self).__setstate__(state) 89 | 90 | def 
step(self, closure=None): 91 | 92 | loss = None 93 | if closure is not None: 94 | loss = closure() 95 | 96 | for group in self.param_groups: 97 | 98 | for p in group['params']: 99 | if p.grad is None: 100 | continue 101 | grad = p.grad.data.float() 102 | if grad.is_sparse: 103 | raise RuntimeError('RAdam does not support sparse gradients') 104 | 105 | p_data_fp32 = p.data.float() 106 | 107 | state = self.state[p] 108 | 109 | if len(state) == 0: 110 | state['step'] = 0 111 | state['exp_avg'] = torch.zeros_like(p_data_fp32) 112 | state['exp_avg_sq'] = torch.zeros_like(p_data_fp32) 113 | else: 114 | state['exp_avg'] = state['exp_avg'].type_as(p_data_fp32) 115 | state['exp_avg_sq'] = state['exp_avg_sq'].type_as(p_data_fp32) 116 | 117 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] 118 | beta1, beta2 = group['betas'] 119 | 120 | exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad) 121 | exp_avg.mul_(beta1).add_(1 - beta1, grad) 122 | 123 | state['step'] += 1 124 | beta2_t = beta2 ** state['step'] 125 | N_sma_max = 2 / (1 - beta2) - 1 126 | N_sma = N_sma_max - 2 * state['step'] * beta2_t / (1 - beta2_t) 127 | 128 | if group['weight_decay'] != 0: 129 | p_data_fp32.add_(-group['weight_decay'] * group['lr'], p_data_fp32) 130 | 131 | # more conservative since it's an approximated value 132 | if N_sma >= 5: 133 | step_size = group['lr'] * math.sqrt((1 - beta2_t) * (N_sma - 4) / (N_sma_max - 4) * (N_sma - 2) / N_sma * N_sma_max / (N_sma_max - 2)) / (1 - beta1 ** state['step']) 134 | denom = exp_avg_sq.sqrt().add_(group['eps']) 135 | p_data_fp32.addcdiv_(-step_size, exp_avg, denom) 136 | else: 137 | step_size = group['lr'] / (1 - beta1 ** state['step']) 138 | p_data_fp32.add_(-step_size, exp_avg) 139 | 140 | p.data.copy_(p_data_fp32) 141 | 142 | return loss 143 | 144 | 145 | class AdamW(Optimizer): 146 | 147 | def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, warmup = 0): 148 | defaults = dict(lr=lr, betas=betas, eps=eps, 149 | weight_decay=weight_decay, amsgrad=amsgrad, use_variance=True, warmup = warmup) 150 | super(AdamW, self).__init__(params, defaults) 151 | 152 | def __setstate__(self, state): 153 | super(AdamW, self).__setstate__(state) 154 | 155 | def step(self, closure=None): 156 | loss = None 157 | if closure is not None: 158 | loss = closure() 159 | 160 | for group in self.param_groups: 161 | 162 | for p in group['params']: 163 | if p.grad is None: 164 | continue 165 | grad = p.grad.data.float() 166 | if grad.is_sparse: 167 | raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead') 168 | 169 | p_data_fp32 = p.data.float() 170 | 171 | state = self.state[p] 172 | 173 | if len(state) == 0: 174 | state['step'] = 0 175 | state['exp_avg'] = torch.zeros_like(p_data_fp32) 176 | state['exp_avg_sq'] = torch.zeros_like(p_data_fp32) 177 | else: 178 | state['exp_avg'] = state['exp_avg'].type_as(p_data_fp32) 179 | state['exp_avg_sq'] = state['exp_avg_sq'].type_as(p_data_fp32) 180 | 181 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] 182 | beta1, beta2 = group['betas'] 183 | 184 | state['step'] += 1 185 | 186 | exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad) 187 | exp_avg.mul_(beta1).add_(1 - beta1, grad) 188 | 189 | denom = exp_avg_sq.sqrt().add_(group['eps']) 190 | bias_correction1 = 1 - beta1 ** state['step'] 191 | bias_correction2 = 1 - beta2 ** state['step'] 192 | 193 | if group['warmup'] > state['step']: 194 | scheduled_lr = 1e-8 + state['step'] * group['lr'] / group['warmup'] 195 | else: 196 
| scheduled_lr = group['lr'] 197 | 198 | step_size = group['lr'] * math.sqrt(bias_correction2) / bias_correction1 199 | 200 | if group['weight_decay'] != 0: 201 | p_data_fp32.add_(-group['weight_decay'] * scheduled_lr, p_data_fp32) 202 | 203 | p_data_fp32.addcdiv_(-step_size, exp_avg, denom) 204 | 205 | p.data.copy_(p_data_fp32) 206 | 207 | return loss 208 | -------------------------------------------------------------------------------- /models/graphflow/utils/timer.py: -------------------------------------------------------------------------------- 1 | import time 2 | 3 | 4 | class Timer(object): 5 | """Computes elapsed time.""" 6 | def __init__(self, name): 7 | self.name = name 8 | self.running = True 9 | self.total = 0 10 | self.start = round(time.time(), 2) 11 | self.intervalTime = round(time.time(), 2) 12 | print("<> <> <> Starting Timer [{}] <> <> <>".format(self.name)) 13 | 14 | def reset(self): 15 | self.running = True 16 | self.total = 0 17 | self.start = round(time.time(), 2) 18 | return self 19 | 20 | def interval(self, intervalName=''): 21 | intervalTime = self._to_hms(round(time.time() - self.intervalTime, 2)) 22 | print("<> <> Timer [{}] <> <> Interval [{}]: {} <> <>".format(self.name, intervalName, intervalTime)) 23 | self.intervalTime = round(time.time(), 2) 24 | return intervalTime 25 | 26 | def stop(self): 27 | if self.running: 28 | self.running = False 29 | self.total += round(time.time() - self.start, 2) 30 | return self 31 | 32 | def resume(self): 33 | if not self.running: 34 | self.running = True 35 | self.start = round(time.time(), 2) 36 | return self 37 | 38 | def time(self): 39 | if self.running: 40 | return round(self.total + time.time() - self.start, 2) 41 | return self.total 42 | 43 | def finish(self): 44 | if self.running: 45 | self.running = False 46 | self.total += round(time.time() - self.start, 2) 47 | elapsed = self._to_hms(self.total) 48 | print("<> <> <> Finished Timer [{}] <> <> <> Total time elapsed: {} <> <> <>".format(self.name, elapsed)) 49 | 50 | def _to_hms(self, seconds): 51 | m, s = divmod(seconds, 60) 52 | h, m = divmod(m, 60) 53 | return "%dh %02dm %02ds" % (h, m, s) 54 | -------------------------------------------------------------------------------- /models/graphflow/word_model.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Module to handle word vectors and initializing embeddings. 
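Built vocabularies and embedding matrices are cached next to saved_vocab_file (as .vocab and .npy files), so they only have to be constructed once and are reloaded on later runs.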
4 | """ 5 | import os 6 | import string 7 | from collections import Counter 8 | import numpy as np 9 | 10 | from .utils import constants as Constants 11 | from .utils import dump_ndarray, load_ndarray, dump_json, load_json, Timer 12 | 13 | 14 | ################################################################################ 15 | # WordModel Class # 16 | ################################################################################ 17 | 18 | class GloveModel(object): 19 | 20 | def __init__(self, filename): 21 | self.word_vecs = {} 22 | self.vocab = [] 23 | with open(filename, 'r') as input_file: 24 | for line in input_file.readlines(): 25 | splitLine = line.split(' ') 26 | w = splitLine[0] 27 | self.word_vecs[w] = np.array([float(val) for val in splitLine[1:]]) 28 | self.vocab.append(w) 29 | self.vector_size = len(self.word_vecs[w]) 30 | 31 | def word_vec(self, word): 32 | word_list = [word, word.lower(), word.upper(), word.title(), string.capwords(word, '_')] 33 | 34 | for w in word_list: 35 | if w in self.word_vecs: 36 | return self.word_vecs[w] 37 | return None 38 | 39 | 40 | class WordModel(object): 41 | """Class to get pretrained word vectors for a list of sentences. Can be used 42 | for any pretrained word vectors. 43 | """ 44 | 45 | def __init__(self, saved_vocab_file=None, embed_size=None, filename=None, embed_type='glove', top_n=None, additional_vocab=Counter()): 46 | vocab_path = saved_vocab_file + '.vocab' 47 | word_vec_path = saved_vocab_file + '.npy' 48 | if os.path.exists(vocab_path) and \ 49 | os.path.exists(word_vec_path): 50 | print('Loading pre-built vocabs stored in {}'.format(saved_vocab_file)) 51 | self.vocab = load_json(vocab_path) 52 | self.word_vecs = load_ndarray(word_vec_path) 53 | self.vocab_size = len(self.vocab) + 1 54 | self.embed_size = self.word_vecs.shape[1] 55 | assert self.embed_size == embed_size 56 | else: 57 | print('Building vocabs...') 58 | if filename is None: 59 | if embed_size is None: 60 | raise Exception('Either embed_file or embed_size needs to be specified.') 61 | self.embed_size = embed_size 62 | self._model = None 63 | else: 64 | self.set_model(filename, embed_type) 65 | self.embed_size = self._model.vector_size 66 | 67 | # padding: 0 68 | self.vocab = {Constants._UNK_TOKEN: 1, Constants._QUESTION_SYMBOL: 2, Constants._ANSWER_SYMBOL: 3} 69 | n_added = 0 70 | for w, count in additional_vocab.most_common(): 71 | if w not in self.vocab: 72 | self.vocab[w] = len(self.vocab) + 1 73 | n_added += 1 74 | # print('Added {} words to the vocab in total.'.format(n_added)) 75 | 76 | self.vocab_size = len(self.vocab) + 1 77 | print('Vocab size: {}'.format(self.vocab_size)) 78 | # self.word_vecs = np.random.rand(self.vocab_size, self.embed_size) * 0.2 - 0.1 79 | self.word_vecs = np.random.uniform(-0.08, 0.08, (self.vocab_size, self.embed_size)) 80 | i = 0. 
81 | if self._model is not None: 82 | for word in self.vocab: 83 | emb = self._model.word_vec(word) 84 | if emb is not None: 85 | i += 1 86 | self.word_vecs[self.vocab[word]] = emb 87 | self.word_vecs[0] = 0 88 | print('Get_wordemb hit ratio: {}'.format(i / len(self.vocab))) 89 | dump_json(self.vocab, vocab_path) 90 | print('Saved vocab to {}'.format(vocab_path)) 91 | dump_ndarray(self.word_vecs, word_vec_path) 92 | print('Saved word_vecs to {}'.format(word_vec_path)) 93 | 94 | def set_model(self, filename, embed_type='glove'): 95 | timer = Timer('Load {}'.format(filename)) 96 | if embed_type == 'glove': 97 | self._model = GloveModel(filename) 98 | else: 99 | from gensim.models.keyedvectors import KeyedVectors 100 | self._model = KeyedVectors.load_word2vec_format(filename, binary=True 101 | if embed_type == 'word2vec' else False) 102 | print('Embeddings: vocab = {}, embed_size = {}'.format(len(self._model.vocab), self._model.vector_size)) 103 | timer.finish() 104 | 105 | def get_vocab(self): 106 | return self.vocab 107 | 108 | def get_word_vecs(self): 109 | return self.word_vecs 110 | -------------------------------------------------------------------------------- /models/ham/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/princeton-nlp/EvalConvQA/fbf34196b4d8e39d4ecfe36353c9e394101af5eb/models/ham/__init__.py -------------------------------------------------------------------------------- /models/ham/cqa_gen_batches.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import collections 6 | import json 7 | import math 8 | import os 9 | import six 10 | import tensorflow as tf 11 | import numpy as np 12 | 13 | def cqa_gen_batches(features, batch_size, num_epoches, shuffle=False): 14 | num_examples = len(features) 15 | 16 | if shuffle: 17 | np.random.seed(0) 18 | idx = np.random.permutation(num_examples) 19 | features_shuffled = np.asarray(features)[idx] 20 | else: 21 | features_shuffled = np.asarray(features) 22 | 23 | num_steps = math.ceil(num_examples / batch_size) 24 | for _ in range(int(num_epoches)): 25 | i = 0 26 | for _ in range(num_steps): 27 | batch_features = features_shuffled[i: i + batch_size] 28 | 29 | batch_unique_ids = [] 30 | batch_input_ids = [] 31 | batch_input_mask = [] 32 | batch_segment_ids = [] 33 | batch_start_positions = [] 34 | batch_end_positions = [] 35 | batch_history_answer_marker = [] 36 | batch_metadata = [] 37 | for feature in batch_features: 38 | batch_unique_ids.append(feature.unique_id) 39 | batch_input_ids.append(feature.input_ids) 40 | batch_input_mask.append(feature.input_mask) 41 | batch_segment_ids.append(feature.segment_ids) 42 | batch_start_positions.append(feature.start_position) 43 | batch_end_positions.append(feature.end_position) 44 | batch_history_answer_marker.append(feature.history_answer_marker) 45 | batch_metadata.append(feature.metadata) 46 | 47 | i += batch_size 48 | 49 | yield (batch_unique_ids, batch_input_ids, batch_input_mask, batch_segment_ids, 50 | batch_start_positions, batch_end_positions, batch_history_answer_marker, batch_metadata) 51 | 52 | def cqa_gen_example_batches(examples, batch_size, num_epoches, shuffle=False): 53 | num_examples = len(examples) 54 | 55 | if shuffle: 56 | np.random.seed(0) 57 | idx = np.random.permutation(num_examples) 58 | examples_shuffled = np.asarray(examples)[idx] 
59 | else: 60 | examples_shuffled = np.asarray(examples) 61 | 62 | num_steps = math.ceil(num_examples / batch_size) 63 | for _ in range(int(num_epoches)): 64 | i = 0 65 | for _ in range(num_steps): 66 | batch_examples = examples_shuffled[i: i + batch_size] 67 | i += batch_size 68 | yield batch_examples 69 | 70 | 71 | def cqa_gen_example_aware_batches(features, example_tracker, variation_tracker, example_features_nums, batch_size, num_epoches, shuffle=False): 72 | 73 | # the training examples have been shuffled before this function, so no need to shuffle here 74 | 75 | # num_examples = len(features) 76 | 77 | # if shuffle: 78 | # np.random.seed(0) 79 | # idx = np.random.permutation(num_examples) 80 | # features_shuffled = np.asarray(features)[idx] 81 | # else: 82 | # features_shuffled = np.asarray(features) 83 | 84 | # num_steps = math.ceil(num_examples / batch_size) 85 | 86 | for _ in range(int(num_epoches)): 87 | # we greedily select all the features that are generated by the next example, 88 | # as long as the sum of example_features does not exceed FLAGS.train_batch_size 89 | start_example_index, end_example_index = 0, 0 90 | while start_example_index in example_tracker: 91 | features_sum = example_features_nums[start_example_index] 92 | while features_sum <= batch_size: 93 | end_example_index += 1 94 | try: 95 | features_sum += example_features_nums[end_example_index] 96 | except: 97 | break 98 | 99 | start_index = example_tracker.index(start_example_index) 100 | # sometimes an example generates more features than a batch can handle 101 | if end_example_index == start_example_index: 102 | end_example_index += 1 103 | try: 104 | end_index = example_tracker.index(end_example_index) 105 | except: 106 | end_index = None 107 | 108 | batch_features = features[start_index: end_index] 109 | batch_example_tracker = example_tracker[start_index: end_index] 110 | batch_variation_tracker = variation_tracker[start_index: end_index] 111 | 112 | start_example_index = end_example_index 113 | assert len(batch_features) > 0 114 | yield batch_features, batch_example_tracker, batch_variation_tracker 115 | 116 | print('epoch finished!') 117 | 118 | def cqa_gen_example_aware_batch_single(features, example_tracker, variation_tracker, batch_size): 119 | # this is for history attention. suppose example 1 has 3 variations (e1.1, e1.2, e1.3), and each variation has two features 120 | # due to the sliding window approach. so example 1 has features (e1.1.1, e1.1.2, e1.2.1, e1.2.2, e1.3.1, e1.3.2) 121 | # we put (e1.1.1, e1.2.1, e1.3.1) in a batch, because we will compute history attention on them and get the weighted sum 122 | # as the representation for e1. We also include features from the same example or other example in this batch and provide a slide mask 123 | # to distinguish them. 
So the batch looks like (e1.1.1, e1.2.1, e1.3.1, e1.1.2, e1.2.2, e1.3.2, e2.1.1, e2.2.1), 124 | # and the slice mask looks like (3, 3, 2), with each elements denote the number of features for each (xample_index, feature_index) combo 125 | 126 | 127 | prev_e_tracker, prev_v_tracker = None, None 128 | f_tracker = 0 # feature tracker, denotes the feature index for each variation 129 | features_dict = {} 130 | for feature, e_tracker, v_tracker in zip(features, example_tracker, variation_tracker): 131 | # get the f_tracker 132 | if e_tracker == prev_e_tracker and v_tracker == prev_v_tracker: 133 | f_tracker += 1 134 | else: 135 | f_tracker = 0 136 | prev_e_tracker, prev_v_tracker = e_tracker, v_tracker 137 | 138 | key = (e_tracker, f_tracker) 139 | if key not in features_dict: 140 | features_dict[key] = [] 141 | features_dict[key].append(feature) 142 | 143 | feature_groups = list(features_dict.values()) 144 | 145 | # we greedily select all the features that belong the next feature group, 146 | # as long as the sum of example_features does not exceed FLAGS.train_batch_size 147 | batch_features = [] 148 | batch_slice_mask = [] 149 | batch_slice_num = None 150 | 151 | # after the weighted sum of history, we get a new representation for the example feature 152 | # this feature will be fed into the prediction layer 153 | # this feature share the input_ids, etc with the entire feature group 154 | # we use this feature to compute loss 155 | output_features = [] 156 | 157 | for feature_group in feature_groups: 158 | len_feature_group = len(feature_group) 159 | batch_features.extend(feature_group) 160 | batch_slice_mask.append(len_feature_group) 161 | batch_slice_num = len(batch_slice_mask) 162 | # batch_slice_mask += [1] * (len(batch_features) - len(batch_slice_mask)) 163 | output_features.append(feature_group[0]) 164 | if batch_size >= len(batch_slice_mask): 165 | batch_slice_mask += [1] * (batch_size - len(batch_slice_mask)) 166 | else: 167 | raise ValueError("Need larger batch size, current batch_slice_mask length is {}".format(len(batch_slice_mask))) 168 | return batch_features, batch_slice_mask, batch_slice_num, output_features 169 | 170 | def cqa_gen_example_aware_batches_v2(features, example_tracker, variation_tracker, example_features_nums, batch_size, num_epoches, shuffle=False): 171 | # this is for history attention. suppose example 1 has 3 variations (e1.1, e1.2, e1.3), and each variation has two features 172 | # due to the sliding window approach. so example 1 has features (e1.1.1, e1.1.2, e1.2.1, e1.2.2, e1.3.1, e1.3.2) 173 | # we put (e1.1.1, e1.2.1, e1.3.1) in a batch, because we will compute history attention on them and get the weighted sum 174 | # as the representation for e1. We also include features from the same example or other example in this batch and provide a slide mask 175 | # to distinguish them. 
So the batch looks like (e1.1.1, e1.2.1, e1.3.1, e1.1.2, e1.2.2, e1.3.2, e2.1.1, e2.2.1), 176 | # and the slice mask looks like (3, 3, 2), with each elements denote the number of features for each (xample_index, feature_index) combo 177 | 178 | # the training examples have been shuffled before this function, so no need to shuffle here 179 | 180 | # num_examples = len(features) 181 | 182 | # if shuffle: 183 | # np.random.seed(0) 184 | # idx = np.random.permutation(num_examples) 185 | # features_shuffled = np.asarray(features)[idx] 186 | # else: 187 | # features_shuffled = np.asarray(features) 188 | 189 | # num_steps = math.ceil(num_examples / batch_size) 190 | 191 | prev_e_tracker, prev_v_tracker = None, None 192 | f_tracker = 0 # feature tracker, denotes the feature index for each variation 193 | features_dict = {} 194 | for feature, e_tracker, v_tracker in zip(features, example_tracker, variation_tracker): 195 | # get the f_tracker 196 | if e_tracker == prev_e_tracker and v_tracker == prev_v_tracker: 197 | f_tracker += 1 198 | else: 199 | f_tracker = 0 200 | prev_e_tracker, prev_v_tracker = e_tracker, v_tracker 201 | 202 | key = (e_tracker, f_tracker) 203 | if key not in features_dict: 204 | features_dict[key] = [] 205 | features_dict[key].append(feature) 206 | 207 | feature_groups = list(features_dict.values()) 208 | 209 | if shuffle: 210 | np.random.seed(0) 211 | np.random.shuffle(feature_groups) 212 | # idx = np.random.permutation(len(feature_groups)) 213 | # feature_groups = np.asarray(feature_groups)[idx] 214 | 215 | for _ in range(int(num_epoches)): 216 | # we greedily select all the features that belong the next feature group, 217 | # as long as the sum of example_features does not exceed FLAGS.train_batch_size 218 | batch_features = [] 219 | batch_slice_mask = [] 220 | batch_slice_num = None 221 | 222 | # after the weighted sum of history, we get a new representation for the example feature 223 | # this feature will be fed into the prediction layer 224 | # this feature share the input_ids, etc with the entire feature group 225 | # we use this feature to compute loss 226 | output_features = [] 227 | 228 | for feature_group in feature_groups: 229 | len_feature_group = len(feature_group) 230 | if len(batch_features) + len_feature_group <= batch_size: 231 | batch_features.extend(feature_group) 232 | batch_slice_mask.append(len_feature_group) 233 | output_features.append(feature_group[0]) 234 | else: 235 | batch_slice_num = len(batch_slice_mask) 236 | batch_slice_mask += [1] * (batch_size - len(batch_slice_mask)) 237 | yield batch_features, batch_slice_mask, batch_slice_num, output_features 238 | 239 | batch_features = [] 240 | batch_slice_mask = [] 241 | batch_slice_num = None 242 | output_features = [] 243 | 244 | batch_features.extend(feature_group) 245 | batch_slice_mask.append(len_feature_group) 246 | output_features.append(feature_group[0]) 247 | 248 | if len(batch_features) > 0: 249 | batch_slice_num = len(batch_slice_mask) 250 | batch_slice_mask += [1] * (batch_size - len(batch_slice_mask)) 251 | yield batch_features, batch_slice_mask, batch_slice_num, output_features 252 | 253 | print('epoch finished!', 'shuffle={}'.format(shuffle)) 254 | 255 | 256 | 257 | # for _ in range(int(num_epoches)): 258 | # start_example_index = 0 259 | # end_example_index = start_example_index + example_batch_size # this is actually the first example index in the next batch 260 | 261 | # while start_example_index in example_tracker: 262 | # start_index = example_tracker.index(start_example_index) 263 
| # try: 264 | # end_index = example_tracker.index(end_example_index) 265 | # except: 266 | # end_index = None 267 | # batch_features = features[start_index: end_index] 268 | # batch_example_tracker = example_tracker[start_index: end_index] 269 | # batch_variation_tracker = variation_tracker[start_index: end_index] 270 | 271 | # start_example_index += example_batch_size 272 | # end_example_index += example_batch_size 273 | 274 | # yield batch_features, batch_example_tracker, batch_variation_tracker 275 | 276 | # print('epoch finished!') 277 | -------------------------------------------------------------------------------- /models/ham/interface.py: -------------------------------------------------------------------------------- 1 | from models.ham.cqa_supports import * 2 | from models.ham.cqa_model import * 3 | from models.ham.cqa_gen_batches import * 4 | from models.ham.cqa_rl_supports import * 5 | import os 6 | import tensorflow as tf 7 | import models.ham.modeling as modeling 8 | import models.ham.tokenization as tokenization 9 | 10 | 11 | class BertHAM(): 12 | def __init__(self, args): 13 | self.QA_history = [] 14 | self.args=args 15 | 16 | bert_config = modeling.BertConfig.from_json_file(self.args.bert_config_file) 17 | 18 | # tf Graph input 19 | self.unique_ids = tf.placeholder(tf.int32, shape=[None], name='unique_ids') 20 | self.input_ids = tf.placeholder(tf.int32, shape=[None, self.args.max_seq_length], name='input_ids') 21 | self.input_mask = tf.placeholder(tf.int32, shape=[None, self.args.max_seq_length], name='input_mask') 22 | self.segment_ids = tf.placeholder(tf.int32, shape=[None, self.args.max_seq_length], name='segment_ids') 23 | self.start_positions = tf.placeholder(tf.int32, shape=[None], name='start_positions') 24 | self.end_positions = tf.placeholder(tf.int32, shape=[None], name='end_positions') 25 | self.history_answer_marker = tf.placeholder(tf.int32, shape=[None, self.args.max_seq_length], name='history_answer_marker') 26 | self.training = tf.placeholder(tf.bool, name='training') 27 | self.yesno_labels = tf.placeholder(tf.int32, shape=[None], name='yesno_labels') 28 | self.followup_labels = tf.placeholder(tf.int32, shape=[None], name='followup_labels') 29 | 30 | # a unique combo of (e_tracker, f_tracker) is called a slice 31 | self.slice_mask = tf.placeholder(tf.int32, shape=[self.args.predict_batch_size], name='slice_mask') 32 | self.slice_num = tf.placeholder(tf.int32, shape=None, name='slice_num') 33 | # for auxiliary loss 34 | self.aux_start_positions = tf.placeholder(tf.int32, shape=[None], name='aux_start_positions') 35 | self.aux_end_positions = tf.placeholder(tf.int32, shape=[None], name='aux_end_positions') 36 | 37 | bert_representation, cls_representation = bert_rep( 38 | bert_config=bert_config, 39 | is_training=self.training, 40 | input_ids=self.input_ids, 41 | input_mask=self.input_mask, 42 | segment_ids=self.segment_ids, 43 | history_answer_marker=self.history_answer_marker, 44 | use_one_hot_embeddings=False 45 | ) 46 | 47 | reduce_mean_representation = tf.reduce_mean(bert_representation, axis=1) 48 | reduce_max_representation = tf.reduce_max(bert_representation, axis=1) 49 | 50 | if self.args.history_attention_input == 'CLS': 51 | history_attention_input = cls_representation 52 | elif self.args.history_attention_input == 'reduce_mean': 53 | history_attention_input = reduce_mean_representation 54 | elif self.args.history_attention_input == 'reduce_max': 55 | history_attention_input = reduce_max_representation 56 | else: 57 | 
print('FLAGS.history_attention_input not specified') 58 | 59 | if self.args.mtl_input == 'CLS': 60 | mtl_input = cls_representation 61 | elif self.args.mtl_input == 'reduce_mean': 62 | mtl_input = reduce_mean_representation 63 | elif self.args.mtl_input == 'reduce_max': 64 | mtl_input = reduce_max_representation 65 | else: 66 | print('FLAGS.mtl_input not specified') 67 | 68 | if self.args.disable_attention: 69 | new_bert_representation, new_mtl_input, self.attention_weights = disable_history_attention_net(bert_representation, 70 | history_attention_input, mtl_input, 71 | self.slice_mask, 72 | self.slice_num, self.args) 73 | else: 74 | if self.args.fine_grained_attention: 75 | new_bert_representation, new_mtl_input, self.attention_weights = fine_grained_history_attention_net(bert_representation, 76 | mtl_input, 77 | self.slice_mask, 78 | self.slice_num, self.args) 79 | 80 | else: 81 | new_bert_representation, new_mtl_input, self.attention_weights = history_attention_net(bert_representation, 82 | history_attention_input, mtl_input, 83 | self.slice_mask, 84 | self.slice_num, self.args) 85 | 86 | (self.start_logits, self.end_logits) = cqa_model(new_bert_representation, self.args) 87 | self.yesno_logits = yesno_model(new_mtl_input) 88 | self.followup_logits = followup_model(new_mtl_input) 89 | 90 | tvars = tf.trainable_variables() 91 | 92 | initialized_variable_names = {} 93 | if self.args.init_checkpoint: 94 | (assignment_map, initialized_variable_names) = modeling.get_assigment_map_from_checkpoint(tvars, self.args.init_checkpoint) 95 | tf.train.init_from_checkpoint(self.args.init_checkpoint, assignment_map) 96 | 97 | # compute loss 98 | seq_length = modeling.get_shape_list(self.input_ids)[1] 99 | def compute_loss(logits, positions): 100 | one_hot_positions = tf.one_hot( 101 | positions, depth=seq_length, dtype=tf.float32) 102 | log_probs = tf.nn.log_softmax(logits, axis=-1) 103 | loss = -tf.reduce_mean(tf.reduce_sum(one_hot_positions * log_probs, axis=-1)) 104 | return loss 105 | 106 | start_loss = compute_loss(self.start_logits, self.start_positions) 107 | end_loss = compute_loss(self.end_logits, self.end_positions) 108 | 109 | yesno_loss = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(logits=self.yesno_logits, labels=self.yesno_labels)) 110 | followup_loss = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(logits=self.followup_logits, labels=self.followup_labels)) 111 | 112 | if self.args.MTL: 113 | cqa_loss = (start_loss + end_loss) / 2.0 114 | if self.args.MTL_lambda < 1: 115 | self.total_loss = self.args.MTL_mu * cqa_loss * cqa_loss + self.args.MTL_lambda * yesno_loss + self.args.MTL_lambda * followup_loss 116 | else: 117 | self.total_loss = cqa_loss + yesno_loss + followup_loss 118 | 119 | else: 120 | self.total_loss = (start_loss + end_loss) / 2.0 121 | 122 | # Initializing the variables 123 | init = tf.global_variables_initializer() 124 | tf.get_default_graph().finalize() 125 | self.session = tf.Session() 126 | self.session.run(init) 127 | 128 | def tokenizer(self): 129 | tokenizer = tokenization.FullTokenizer(vocab_file=self.args.vocab_file, do_lower_case=self.args.do_lower_case) 130 | return tokenizer 131 | 132 | def load_partial_examples(self, partial_eval_examples_file): 133 | paragraphs = read_partial_quac_examples_extern(input_file=partial_eval_examples_file) 134 | return paragraphs 135 | 136 | def predict_one_automatic_turn(self, partial_example, unique_id, example_idx, tokenizer): 137 | question = partial_example.question_text 138 | turn = 
int(partial_example.qas_id.split("#")[1]) 139 | char_to_word_offset = partial_example.char_to_word_offset 140 | example = read_one_quac_example_extern(partial_example, self.QA_history, char_to_word_offset, self.args.history_len, self.args.use_history_answer_marker, self.args.only_history_answer) 141 | val_features, val_example_tracker, val_variation_tracker, val_example_features_nums, unique_id = convert_one_example_to_variations_and_then_features( 142 | example=example, example_index=example_idx, tokenizer=tokenizer, 143 | max_seq_length=self.args.max_seq_length, doc_stride=self.args.doc_stride, 144 | max_query_length=self.args.max_query_length, max_considered_history_turns=self.args.max_considered_history_turns, 145 | reformulate_question=self.args.reformulate_question, front_padding=self.args.front_padding, append_self=self.args.append_self,unique_id=unique_id) 146 | 147 | val_batch = cqa_gen_example_aware_batch_single(val_features, val_example_tracker, val_variation_tracker, self.args.predict_batch_size) 148 | 149 | batch_results = [] 150 | batch_features, batch_slice_mask, batch_slice_num, output_features = val_batch 151 | 152 | fd = convert_features_to_feed_dict(batch_features) # feed_dict 153 | fd_output = convert_features_to_feed_dict(output_features) 154 | 155 | if self.args.better_hae: 156 | turn_features = get_turn_features(fd['metadata']) 157 | fd['history_answer_marker'] = fix_history_answer_marker_for_bhae(fd['history_answer_marker'], turn_features) 158 | 159 | if self.args.history_ngram != 1: 160 | batch_slice_mask, group_batch_features = group_histories(batch_features, fd['history_answer_marker'], 161 | batch_slice_mask, batch_slice_num) 162 | fd = convert_features_to_feed_dict(group_batch_features) 163 | 164 | feed_dict={self.unique_ids: fd['unique_ids'], self.input_ids: fd['input_ids'], 165 | self.input_mask: fd['input_mask'], self.segment_ids: fd['segment_ids'], 166 | self.start_positions: fd_output['start_positions'], self.end_positions: fd_output['end_positions'], 167 | self.history_answer_marker: fd['history_answer_marker'], self.slice_mask: batch_slice_mask, 168 | self.slice_num: batch_slice_num, 169 | self.aux_start_positions: fd['start_positions'], self.aux_end_positions: fd['end_positions'], 170 | self.yesno_labels: fd_output['yesno'], self.followup_labels: fd_output['followup'], self.training: False} 171 | 172 | start_logits_res, end_logits_res, yesno_logits_res, followup_logits_res, batch_total_loss, attention_weights_res = self.session.run([self.start_logits, self.end_logits, self.yesno_logits, self.followup_logits, 173 | self.total_loss, self.attention_weights], feed_dict=feed_dict) 174 | 175 | for each_unique_id, each_start_logits, each_end_logits, each_yesno_logits, each_followup_logits in zip(fd_output['unique_ids'], start_logits_res, end_logits_res, yesno_logits_res, followup_logits_res): 176 | 177 | each_unique_id = int(each_unique_id) 178 | each_start_logits = [float(x) for x in each_start_logits.flat] 179 | each_end_logits = [float(x) for x in each_end_logits.flat] 180 | each_yesno_logits = [float(x) for x in each_yesno_logits.flat] 181 | each_followup_logits = [float(x) for x in each_followup_logits.flat] 182 | batch_results.append(RawResult(unique_id=each_unique_id, start_logits=each_start_logits, 183 | end_logits=each_end_logits, yesno_logits=each_yesno_logits, 184 | followup_logits=each_followup_logits)) 185 | 186 | pred_text, pred_answer_s, pred_answer_e, pred_yesno, pred_followup = get_last_prediction(example, example_idx, output_features, 
batch_results, self.args.n_best_size, self.args.max_answer_length, self.args.do_lower_case) 187 | 188 | self.QA_history.append((turn, question, (pred_text, pred_answer_s, pred_answer_e))) 189 | 190 | return pred_text, unique_id -------------------------------------------------------------------------------- /models/ham/optimization.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """Functions and classes related to optimization (weight updates).""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import re 22 | import tensorflow as tf 23 | 24 | 25 | def create_optimizer(loss, init_lr, num_train_steps, num_warmup_steps, use_tpu): 26 | """Creates an optimizer training op.""" 27 | global_step = tf.train.get_or_create_global_step() 28 | 29 | learning_rate = tf.constant(value=init_lr, shape=[], dtype=tf.float32) 30 | 31 | # Implements linear decay of the learning rate. 32 | learning_rate = tf.train.polynomial_decay( 33 | learning_rate, 34 | global_step, 35 | num_train_steps, 36 | end_learning_rate=0.0, 37 | power=1.0, 38 | cycle=False) 39 | 40 | # Implements linear warmup. I.e., if global_step < num_warmup_steps, the 41 | # learning rate will be `global_step/num_warmup_steps * init_lr`. 42 | if num_warmup_steps: 43 | global_steps_int = tf.cast(global_step, tf.int32) 44 | warmup_steps_int = tf.constant(num_warmup_steps, dtype=tf.int32) 45 | 46 | global_steps_float = tf.cast(global_steps_int, tf.float32) 47 | warmup_steps_float = tf.cast(warmup_steps_int, tf.float32) 48 | 49 | warmup_percent_done = global_steps_float / warmup_steps_float 50 | warmup_learning_rate = init_lr * warmup_percent_done 51 | 52 | is_warmup = tf.cast(global_steps_int < warmup_steps_int, tf.float32) 53 | learning_rate = ( 54 | (1.0 - is_warmup) * learning_rate + is_warmup * warmup_learning_rate) 55 | 56 | # It is recommended that you use this optimizer for fine tuning, since this 57 | # is how the model was trained (note that the Adam m/v variables are NOT 58 | # loaded from init_checkpoint.) 
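# The learning_rate passed below already folds in linear warmup (init_lr * global_step / num_warmup_steps
# for the first num_warmup_steps steps) and a linear decay (polynomial power 1.0) from init_lr to 0
# over num_train_steps.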
59 | optimizer = AdamWeightDecayOptimizer( 60 | learning_rate=learning_rate, 61 | weight_decay_rate=0.01, 62 | beta_1=0.9, 63 | beta_2=0.999, 64 | epsilon=1e-6, 65 | exclude_from_weight_decay=["LayerNorm", "layer_norm", "bias"]) 66 | 67 | if use_tpu: 68 | optimizer = tf.contrib.tpu.CrossShardOptimizer(optimizer) 69 | 70 | tvars = tf.trainable_variables() 71 | # we only optimize the CQA model, we do not optimize the RL model 72 | # vars_to_optimize = [v for v in tvars if (v.name.startswith('bert') or v.name.startswith('cls'))] 73 | # print('vars_to_optimize', vars_to_optimize) 74 | # grads = tf.gradients(loss, vars_to_optimize) 75 | 76 | # if FLAGS.freeze_bert: 77 | # vars_to_optimize = [v for v in tvars if not v.name.startswith('bert')] 78 | # grads = tf.gradients(loss, vars_to_optimize) 79 | # (grads, _) = tf.clip_by_global_norm(grads, clip_norm=1.0) 80 | # train_op = optimizer.apply_gradients(zip(grads, vars_to_optimize), global_step=global_step) 81 | # else: 82 | grads = tf.gradients(loss, tvars) 83 | 84 | # This is how the model was pre-trained. 85 | (grads, _) = tf.clip_by_global_norm(grads, clip_norm=1.0) 86 | 87 | # train_op = optimizer.apply_gradients( 88 | # zip(grads, vars_to_optimize), global_step=global_step) 89 | train_op = optimizer.apply_gradients( 90 | zip(grads, tvars), global_step=global_step) 91 | 92 | new_global_step = global_step + 1 93 | train_op = tf.group(train_op, [global_step.assign(new_global_step)]) 94 | 95 | tf.summary.scalar('gobal_step', global_step) 96 | tf.summary.scalar('learning_rate', learning_rate) 97 | return train_op 98 | 99 | 100 | class AdamWeightDecayOptimizer(tf.train.Optimizer): 101 | """A basic Adam optimizer that includes "correct" L2 weight decay.""" 102 | 103 | def __init__(self, 104 | learning_rate, 105 | weight_decay_rate=0.0, 106 | beta_1=0.9, 107 | beta_2=0.999, 108 | epsilon=1e-6, 109 | exclude_from_weight_decay=None, 110 | name="AdamWeightDecayOptimizer"): 111 | """Constructs a AdamWeightDecayOptimizer.""" 112 | super(AdamWeightDecayOptimizer, self).__init__(False, name) 113 | 114 | self.learning_rate = learning_rate 115 | self.weight_decay_rate = weight_decay_rate 116 | self.beta_1 = beta_1 117 | self.beta_2 = beta_2 118 | self.epsilon = epsilon 119 | self.exclude_from_weight_decay = exclude_from_weight_decay 120 | 121 | ################# 122 | self.m_and_v = [] 123 | ################# 124 | 125 | def apply_gradients(self, grads_and_vars, global_step=None, name=None): 126 | """See base class.""" 127 | assignments = [] 128 | for (grad, param) in grads_and_vars: 129 | if grad is None or param is None: 130 | continue 131 | 132 | ################ 133 | with tf.control_dependencies([grad]): 134 | ################ 135 | 136 | param_name = self._get_variable_name(param.name) 137 | 138 | m = tf.get_variable( 139 | name=param_name + "/adam_m", 140 | shape=param.shape.as_list(), 141 | dtype=tf.float32, 142 | trainable=False, 143 | initializer=tf.zeros_initializer()) 144 | v = tf.get_variable( 145 | name=param_name + "/adam_v", 146 | shape=param.shape.as_list(), 147 | dtype=tf.float32, 148 | trainable=False, 149 | initializer=tf.zeros_initializer()) 150 | 151 | self.m_and_v.append(m) 152 | self.m_and_v.append(v) 153 | tf.add_to_collection('OPT_VARS', m) 154 | tf.add_to_collection('OPT_VARS', v) 155 | 156 | # Standard Adam update. 
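# Adam moments with decoupled weight decay (no bias correction in this implementation):
#   m <- beta_1 * m + (1 - beta_1) * grad
#   v <- beta_2 * v + (1 - beta_2) * grad^2
#   update = m / (sqrt(v) + epsilon)  (+ weight_decay_rate * param where decay applies)
#   param <- param - learning_rate * update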
157 | next_m = ( 158 | tf.multiply(self.beta_1, m) + tf.multiply(1.0 - self.beta_1, grad)) 159 | next_v = ( 160 | tf.multiply(self.beta_2, v) + tf.multiply(1.0 - self.beta_2, 161 | tf.square(grad))) 162 | 163 | update = next_m / (tf.sqrt(next_v) + self.epsilon) 164 | 165 | # Just adding the square of the weights to the loss function is *not* 166 | # the correct way of using L2 regularization/weight decay with Adam, 167 | # since that will interact with the m and v parameters in strange ways. 168 | # 169 | # Instead we want ot decay the weights in a manner that doesn't interact 170 | # with the m/v parameters. This is equivalent to adding the square 171 | # of the weights to the loss with plain (non-momentum) SGD. 172 | if self._do_use_weight_decay(param_name): 173 | update += self.weight_decay_rate * param 174 | 175 | update_with_lr = self.learning_rate * update 176 | 177 | next_param = param - update_with_lr 178 | 179 | assignments.extend( 180 | [param.assign(next_param), 181 | m.assign(next_m), 182 | v.assign(next_v)]) 183 | return tf.group(*assignments, name=name) 184 | 185 | # def apply_gradients(self, grads_and_vars, global_step=None, name=None): 186 | # """See base class.""" 187 | # assignments = [] 188 | # for (grad, param) in grads_and_vars: 189 | # if grad is None or param is None: 190 | # continue 191 | 192 | # param_name = self._get_variable_name(param.name) 193 | 194 | # m = tf.get_variable( 195 | # name=param_name + "/adam_m", 196 | # shape=param.shape.as_list(), 197 | # dtype=tf.float32, 198 | # trainable=False, 199 | # initializer=tf.zeros_initializer()) 200 | # v = tf.get_variable( 201 | # name=param_name + "/adam_v", 202 | # shape=param.shape.as_list(), 203 | # dtype=tf.float32, 204 | # trainable=False, 205 | # initializer=tf.zeros_initializer()) 206 | 207 | # # Standard Adam update. 208 | # next_m = ( 209 | # tf.multiply(self.beta_1, m) + tf.multiply(1.0 - self.beta_1, grad)) 210 | # next_v = ( 211 | # tf.multiply(self.beta_2, v) + tf.multiply(1.0 - self.beta_2, 212 | # tf.square(grad))) 213 | 214 | # update = next_m / (tf.sqrt(next_v) + self.epsilon) 215 | 216 | # # Just adding the square of the weights to the loss function is *not* 217 | # # the correct way of using L2 regularization/weight decay with Adam, 218 | # # since that will interact with the m and v parameters in strange ways. 219 | # # 220 | # # Instead we want ot decay the weights in a manner that doesn't interact 221 | # # with the m/v parameters. This is equivalent to adding the square 222 | # # of the weights to the loss with plain (non-momentum) SGD. 
223 | # if self._do_use_weight_decay(param_name): 224 | # update += self.weight_decay_rate * param 225 | 226 | # update_with_lr = self.learning_rate * update 227 | 228 | # next_param = param - update_with_lr 229 | 230 | # assignments.extend( 231 | # [param.assign(next_param), 232 | # m.assign(next_m), 233 | # v.assign(next_v)]) 234 | # return tf.group(*assignments, name=name) 235 | 236 | def _do_use_weight_decay(self, param_name): 237 | """Whether to use L2 weight decay for `param_name`.""" 238 | if not self.weight_decay_rate: 239 | return False 240 | if self.exclude_from_weight_decay: 241 | for r in self.exclude_from_weight_decay: 242 | if re.search(r, param_name) is not None: 243 | return False 244 | return True 245 | 246 | def _get_variable_name(self, param_name): 247 | """Get the variable name from the tensor name.""" 248 | m = re.match("^(.*):\\d+$", param_name) 249 | if m is not None: 250 | param_name = m.group(1) 251 | return param_name 252 | 253 | # def apply_gradients(self, grads_and_vars, global_step=None, name=None): 254 | # """See base class.""" 255 | # assignments = [] 256 | # for (grad, param) in grads_and_vars: 257 | # if grad is None or param is None: 258 | # continue 259 | 260 | # with tf.control_dependencies([grad]): 261 | # param_name = self._get_variable_name(param.name) 262 | 263 | # m = tf.get_variable( 264 | # name=param_name + "/adam_m", 265 | # shape=param.shape.as_list(), 266 | # dtype=tf.float32, 267 | # trainable=False, 268 | # initializer=tf.zeros_initializer()) 269 | # v = tf.get_variable( 270 | # name=param_name + "/adam_v", 271 | # shape=param.shape.as_list(), 272 | # dtype=tf.float32, 273 | # trainable=False, 274 | # initializer=tf.zeros_initializer()) 275 | # self.m_and_v.append(m) 276 | # self.m_and_v.append(v) 277 | # tf.add_to_collection('OPT_VARS', m) 278 | # tf.add_to_collection('OPT_VARS', v) 279 | 280 | # # Standard Adam update. 281 | # next_m = ( 282 | # tf.multiply(self.beta_1, m) + tf.multiply(1.0 - self.beta_1, grad)) 283 | # next_v = ( 284 | # tf.multiply(self.beta_2, v) + tf.multiply(1.0 - self.beta_2, 285 | # tf.square(grad))) 286 | 287 | # update = next_m / (tf.sqrt(next_v) + self.epsilon) 288 | 289 | # # Just adding the square of the weights to the loss function is *not* 290 | # # the correct way of using L2 regularization/weight decay with Adam, 291 | # # since that will interact with the m and v parameters in strange ways. 292 | # # 293 | # # Instead we want ot decay the weights in a manner that doesn't interact 294 | # # with the m/v parameters. This is equivalent to adding the square 295 | # # of the weights to the loss with plain (non-momentum) SGD. 296 | # if self._do_use_weight_decay(param_name): 297 | # update += self.weight_decay_rate * param 298 | 299 | # update_with_lr = self.learning_rate * update 300 | 301 | # next_param = param - update_with_lr 302 | 303 | # assignments.extend( 304 | # [param.assign(next_param), 305 | # m.assign(next_m), 306 | # v.assign(next_v)]) 307 | # return tf.group(*assignments, name=name) -------------------------------------------------------------------------------- /models/ham/reindent.py: -------------------------------------------------------------------------------- 1 | #! /usr/bin/env python3 2 | 3 | # Released to the public domain, by Tim Peters, 03 October 2000. 4 | 5 | """reindent [-d][-r][-v] [ path ... ] 6 | 7 | -d (--dryrun) Dry run. Analyze, but don't make any changes to, files. 8 | -r (--recurse) Recurse. Search for all .py files in subdirectories too. 9 | -n (--nobackup) No backup. 
Does not make a ".bak" file before reindenting. 10 | -v (--verbose) Verbose. Print informative msgs; else no output. 11 | -h (--help) Help. Print this usage information and exit. 12 | 13 | Change Python (.py) files to use 4-space indents and no hard tab characters. 14 | Also trim excess spaces and tabs from ends of lines, and remove empty lines 15 | at the end of files. Also ensure the last line ends with a newline. 16 | 17 | If no paths are given on the command line, reindent operates as a filter, 18 | reading a single source file from standard input and writing the transformed 19 | source to standard output. In this case, the -d, -r and -v flags are 20 | ignored. 21 | 22 | You can pass one or more file and/or directory paths. When a directory 23 | path, all .py files within the directory will be examined, and, if the -r 24 | option is given, likewise recursively for subdirectories. 25 | 26 | If output is not to standard output, reindent overwrites files in place, 27 | renaming the originals with a .bak extension. If it finds nothing to 28 | change, the file is left alone. If reindent does change a file, the changed 29 | file is a fixed-point for future runs (i.e., running reindent on the 30 | resulting .py file won't change it again). 31 | 32 | The hard part of reindenting is figuring out what to do with comment 33 | lines. So long as the input files get a clean bill of health from 34 | tabnanny.py, reindent should do a good job. 35 | 36 | The backup file is a copy of the one that is being reindented. The ".bak" 37 | file is generated with shutil.copy(), but some corner cases regarding 38 | user/group and permissions could leave the backup file more readable than 39 | you'd prefer. You can always use the --nobackup option to prevent this. 40 | """ 41 | 42 | __version__ = "1" 43 | 44 | import tokenize 45 | import os 46 | import shutil 47 | import sys 48 | 49 | verbose = False 50 | recurse = False 51 | dryrun = False 52 | makebackup = True 53 | 54 | 55 | def usage(msg=None): 56 | if msg is None: 57 | msg = __doc__ 58 | print(msg, file=sys.stderr) 59 | 60 | 61 | def errprint(*args): 62 | sys.stderr.write(" ".join(str(arg) for arg in args)) 63 | sys.stderr.write("\n") 64 | 65 | 66 | def main(): 67 | import getopt 68 | global verbose, recurse, dryrun, makebackup 69 | try: 70 | opts, args = getopt.getopt(sys.argv[1:], "drnvh", 71 | ["dryrun", "recurse", "nobackup", "verbose", "help"]) 72 | except getopt.error as msg: 73 | usage(msg) 74 | return 75 | for o, a in opts: 76 | if o in ('-d', '--dryrun'): 77 | dryrun = True 78 | elif o in ('-r', '--recurse'): 79 | recurse = True 80 | elif o in ('-n', '--nobackup'): 81 | makebackup = False 82 | elif o in ('-v', '--verbose'): 83 | verbose = True 84 | elif o in ('-h', '--help'): 85 | usage() 86 | return 87 | if not args: 88 | r = Reindenter(sys.stdin) 89 | r.run() 90 | r.write(sys.stdout) 91 | return 92 | for arg in args: 93 | check(arg) 94 | 95 | 96 | def check(file): 97 | if os.path.isdir(file) and not os.path.islink(file): 98 | if verbose: 99 | print("listing directory", file) 100 | names = os.listdir(file) 101 | for name in names: 102 | fullname = os.path.join(file, name) 103 | if ((recurse and os.path.isdir(fullname) and 104 | not os.path.islink(fullname) and 105 | not os.path.split(fullname)[1].startswith(".")) 106 | or name.lower().endswith(".py")): 107 | check(fullname) 108 | return 109 | 110 | if verbose: 111 | print("checking", file, "...", end=' ') 112 | with open(file, 'rb') as f: 113 | encoding, _ = tokenize.detect_encoding(f.readline) 114 
| try: 115 | with open(file, encoding=encoding) as f: 116 | r = Reindenter(f) 117 | except IOError as msg: 118 | errprint("%s: I/O Error: %s" % (file, str(msg))) 119 | return 120 | 121 | newline = r.newlines 122 | if isinstance(newline, tuple): 123 | errprint("%s: mixed newlines detected; cannot process file" % file) 124 | return 125 | 126 | if r.run(): 127 | if verbose: 128 | print("changed.") 129 | if dryrun: 130 | print("But this is a dry run, so leaving it alone.") 131 | if not dryrun: 132 | bak = file + ".bak" 133 | if makebackup: 134 | shutil.copyfile(file, bak) 135 | if verbose: 136 | print("backed up", file, "to", bak) 137 | with open(file, "w", encoding=encoding, newline=newline) as f: 138 | r.write(f) 139 | if verbose: 140 | print("wrote new", file) 141 | return True 142 | else: 143 | if verbose: 144 | print("unchanged.") 145 | return False 146 | 147 | 148 | def _rstrip(line, JUNK='\n \t'): 149 | """Return line stripped of trailing spaces, tabs, newlines. 150 | 151 | Note that line.rstrip() instead also strips sundry control characters, 152 | but at least one known Emacs user expects to keep junk like that, not 153 | mentioning Barry by name or anything . 154 | """ 155 | 156 | i = len(line) 157 | while i > 0 and line[i - 1] in JUNK: 158 | i -= 1 159 | return line[:i] 160 | 161 | 162 | class Reindenter: 163 | 164 | def __init__(self, f): 165 | self.find_stmt = 1 # next token begins a fresh stmt? 166 | self.level = 0 # current indent level 167 | 168 | # Raw file lines. 169 | self.raw = f.readlines() 170 | 171 | # File lines, rstripped & tab-expanded. Dummy at start is so 172 | # that we can use tokenize's 1-based line numbering easily. 173 | # Note that a line is all-blank iff it's "\n". 174 | self.lines = [_rstrip(line).expandtabs() + "\n" 175 | for line in self.raw] 176 | self.lines.insert(0, None) 177 | self.index = 1 # index into self.lines of next line 178 | 179 | # List of (lineno, indentlevel) pairs, one for each stmt and 180 | # comment line. indentlevel is -1 for comment lines, as a 181 | # signal that tokenize doesn't know what to do about them; 182 | # indeed, they're our headache! 183 | self.stats = [] 184 | 185 | # Save the newlines found in the file so they can be used to 186 | # create output without mutating the newlines. 187 | self.newlines = f.newlines 188 | 189 | def run(self): 190 | tokens = tokenize.generate_tokens(self.getline) 191 | for _token in tokens: 192 | self.tokeneater(*_token) 193 | # Remove trailing empty lines. 194 | lines = self.lines 195 | while lines and lines[-1] == "\n": 196 | lines.pop() 197 | # Sentinel. 198 | stats = self.stats 199 | stats.append((len(lines), 0)) 200 | # Map count of leading spaces to # we want. 201 | have2want = {} 202 | # Program after transformation. 203 | after = self.after = [] 204 | # Copy over initial empty lines -- there's nothing to do until 205 | # we see a line with *something* on it. 206 | i = stats[0][0] 207 | after.extend(lines[1:i]) 208 | for i in range(len(stats) - 1): 209 | thisstmt, thislevel = stats[i] 210 | nextstmt = stats[i + 1][0] 211 | have = getlspace(lines[thisstmt]) 212 | want = thislevel * 4 213 | if want < 0: 214 | # A comment line. 215 | if have: 216 | # An indented comment line. If we saw the same 217 | # indentation before, reuse what it most recently 218 | # mapped to. 219 | want = have2want.get(have, -1) 220 | if want < 0: 221 | # Then it probably belongs to the next real stmt. 
222 | for j in range(i + 1, len(stats) - 1): 223 | jline, jlevel = stats[j] 224 | if jlevel >= 0: 225 | if have == getlspace(lines[jline]): 226 | want = jlevel * 4 227 | break 228 | if want < 0: # Maybe it's a hanging 229 | # comment like this one, 230 | # in which case we should shift it like its base 231 | # line got shifted. 232 | for j in range(i - 1, -1, -1): 233 | jline, jlevel = stats[j] 234 | if jlevel >= 0: 235 | want = have + (getlspace(after[jline - 1]) - 236 | getlspace(lines[jline])) 237 | break 238 | if want < 0: 239 | # Still no luck -- leave it alone. 240 | want = have 241 | else: 242 | want = 0 243 | assert want >= 0 244 | have2want[have] = want 245 | diff = want - have 246 | if diff == 0 or have == 0: 247 | after.extend(lines[thisstmt:nextstmt]) 248 | else: 249 | for line in lines[thisstmt:nextstmt]: 250 | if diff > 0: 251 | if line == "\n": 252 | after.append(line) 253 | else: 254 | after.append(" " * diff + line) 255 | else: 256 | remove = min(getlspace(line), -diff) 257 | after.append(line[remove:]) 258 | return self.raw != self.after 259 | 260 | def write(self, f): 261 | f.writelines(self.after) 262 | 263 | # Line-getter for tokenize. 264 | def getline(self): 265 | if self.index >= len(self.lines): 266 | line = "" 267 | else: 268 | line = self.lines[self.index] 269 | self.index += 1 270 | return line 271 | 272 | # Line-eater for tokenize. 273 | def tokeneater(self, type, token, slinecol, end, line, 274 | INDENT=tokenize.INDENT, 275 | DEDENT=tokenize.DEDENT, 276 | NEWLINE=tokenize.NEWLINE, 277 | COMMENT=tokenize.COMMENT, 278 | NL=tokenize.NL): 279 | 280 | if type == NEWLINE: 281 | # A program statement, or ENDMARKER, will eventually follow, 282 | # after some (possibly empty) run of tokens of the form 283 | # (NL | COMMENT)* (INDENT | DEDENT+)? 284 | self.find_stmt = 1 285 | 286 | elif type == INDENT: 287 | self.find_stmt = 1 288 | self.level += 1 289 | 290 | elif type == DEDENT: 291 | self.find_stmt = 1 292 | self.level -= 1 293 | 294 | elif type == COMMENT: 295 | if self.find_stmt: 296 | self.stats.append((slinecol[0], -1)) 297 | # but we're still looking for a new stmt, so leave 298 | # find_stmt alone 299 | 300 | elif type == NL: 301 | pass 302 | 303 | elif self.find_stmt: 304 | # This is the first "real token" following a NEWLINE, so it 305 | # must be the first token of the next program statement, or an 306 | # ENDMARKER. 307 | self.find_stmt = 0 308 | if line: # not endmarker 309 | self.stats.append((slinecol[0], self.level)) 310 | 311 | 312 | # Count number of leading blanks. 313 | def getlspace(line): 314 | i, n = 0, len(line) 315 | while i < n and line[i] == " ": 316 | i += 1 317 | return i 318 | 319 | 320 | if __name__ == '__main__': 321 | main() 322 | -------------------------------------------------------------------------------- /models/ham/tokenization.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """Tokenization classes.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import collections 22 | import unicodedata 23 | import six 24 | import tensorflow as tf 25 | 26 | 27 | def convert_to_unicode(text): 28 | """Converts `text` to Unicode (if it's not already), assuming utf-8 input.""" 29 | if six.PY3: 30 | if isinstance(text, str): 31 | return text 32 | elif isinstance(text, bytes): 33 | return text.decode("utf-8", "ignore") 34 | else: 35 | raise ValueError("Unsupported string type: %s" % (type(text))) 36 | elif six.PY2: 37 | if isinstance(text, str): 38 | return text.decode("utf-8", "ignore") 39 | elif isinstance(text, unicode): 40 | return text 41 | else: 42 | raise ValueError("Unsupported string type: %s" % (type(text))) 43 | else: 44 | raise ValueError("Not running on Python2 or Python 3?") 45 | 46 | 47 | def printable_text(text): 48 | """Returns text encoded in a way suitable for print or `tf.logging`.""" 49 | 50 | # These functions want `str` for both Python2 and Python3, but in one case 51 | # it's a Unicode string and in the other it's a byte string. 52 | if six.PY3: 53 | if isinstance(text, str): 54 | return text 55 | elif isinstance(text, bytes): 56 | return text.decode("utf-8", "ignore") 57 | else: 58 | raise ValueError("Unsupported string type: %s" % (type(text))) 59 | elif six.PY2: 60 | if isinstance(text, str): 61 | return text 62 | elif isinstance(text, unicode): 63 | return text.encode("utf-8") 64 | else: 65 | raise ValueError("Unsupported string type: %s" % (type(text))) 66 | else: 67 | raise ValueError("Not running on Python2 or Python 3?") 68 | 69 | 70 | def load_vocab(vocab_file): 71 | """Loads a vocabulary file into a dictionary.""" 72 | vocab = collections.OrderedDict() 73 | index = 0 74 | with tf.gfile.GFile(vocab_file, "r") as reader: 75 | while True: 76 | token = convert_to_unicode(reader.readline()) 77 | if not token: 78 | break 79 | token = token.strip() 80 | vocab[token] = index 81 | index += 1 82 | return vocab 83 | 84 | 85 | def convert_tokens_to_ids(vocab, tokens): 86 | """Converts a sequence of tokens into ids using the vocab.""" 87 | ids = [] 88 | for token in tokens: 89 | ids.append(vocab[token]) 90 | return ids 91 | 92 | 93 | def whitespace_tokenize(text): 94 | """Runs basic whitespace cleaning and splitting on a peice of text.""" 95 | text = text.strip() 96 | if not text: 97 | return [] 98 | tokens = text.split() 99 | return tokens 100 | 101 | 102 | class FullTokenizer(object): 103 | """Runs end-to-end tokenziation.""" 104 | 105 | def __init__(self, vocab_file, do_lower_case=True): 106 | self.vocab = load_vocab(vocab_file) 107 | self.basic_tokenizer = BasicTokenizer(do_lower_case=do_lower_case) 108 | self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab) 109 | 110 | def tokenize(self, text): 111 | split_tokens = [] 112 | for token in self.basic_tokenizer.tokenize(text): 113 | for sub_token in self.wordpiece_tokenizer.tokenize(token): 114 | split_tokens.append(sub_token) 115 | 116 | return split_tokens 117 | 118 | def convert_tokens_to_ids(self, tokens): 119 | return convert_tokens_to_ids(self.vocab, tokens) 120 | 121 | 122 | class BasicTokenizer(object): 123 | """Runs basic tokenization (punctuation splitting, lower casing, etc.).""" 124 | 125 | def __init__(self, do_lower_case=True): 126 | """Constructs a BasicTokenizer. 
127 | 128 | Args: 129 | do_lower_case: Whether to lower case the input. 130 | """ 131 | self.do_lower_case = do_lower_case 132 | 133 | def tokenize(self, text): 134 | """Tokenizes a piece of text.""" 135 | text = convert_to_unicode(text) 136 | text = self._clean_text(text) 137 | orig_tokens = whitespace_tokenize(text) 138 | split_tokens = [] 139 | for token in orig_tokens: 140 | if self.do_lower_case: 141 | token = token.lower() 142 | token = self._run_strip_accents(token) 143 | split_tokens.extend(self._run_split_on_punc(token)) 144 | 145 | output_tokens = whitespace_tokenize(" ".join(split_tokens)) 146 | return output_tokens 147 | 148 | def _run_strip_accents(self, text): 149 | """Strips accents from a piece of text.""" 150 | text = unicodedata.normalize("NFD", text) 151 | output = [] 152 | for char in text: 153 | cat = unicodedata.category(char) 154 | if cat == "Mn": 155 | continue 156 | output.append(char) 157 | return "".join(output) 158 | 159 | def _run_split_on_punc(self, text): 160 | """Splits punctuation on a piece of text.""" 161 | chars = list(text) 162 | i = 0 163 | start_new_word = True 164 | output = [] 165 | while i < len(chars): 166 | char = chars[i] 167 | if _is_punctuation(char): 168 | output.append([char]) 169 | start_new_word = True 170 | else: 171 | if start_new_word: 172 | output.append([]) 173 | start_new_word = False 174 | output[-1].append(char) 175 | i += 1 176 | 177 | return ["".join(x) for x in output] 178 | 179 | def _clean_text(self, text): 180 | """Performs invalid character removal and whitespace cleanup on text.""" 181 | output = [] 182 | for char in text: 183 | cp = ord(char) 184 | if cp == 0 or cp == 0xfffd or _is_control(char): 185 | continue 186 | if _is_whitespace(char): 187 | output.append(" ") 188 | else: 189 | output.append(char) 190 | return "".join(output) 191 | 192 | 193 | class WordpieceTokenizer(object): 194 | """Runs WordPiece tokenziation.""" 195 | 196 | def __init__(self, vocab, unk_token="[UNK]", max_input_chars_per_word=100): 197 | self.vocab = vocab 198 | self.unk_token = unk_token 199 | self.max_input_chars_per_word = max_input_chars_per_word 200 | 201 | def tokenize(self, text): 202 | """Tokenizes a piece of text into its word pieces. 203 | 204 | This uses a greedy longest-match-first algorithm to perform tokenization 205 | using the given vocabulary. 206 | 207 | For example: 208 | input = "unaffable" 209 | output = ["un", "##aff", "##able"] 210 | 211 | Args: 212 | text: A single token or whitespace separated tokens. This should have 213 | already been passed through `BasicTokenizer. 214 | 215 | Returns: 216 | A list of wordpiece tokens. 
217 | """ 218 | 219 | text = convert_to_unicode(text) 220 | 221 | output_tokens = [] 222 | for token in whitespace_tokenize(text): 223 | chars = list(token) 224 | if len(chars) > self.max_input_chars_per_word: 225 | output_tokens.append(self.unk_token) 226 | continue 227 | 228 | is_bad = False 229 | start = 0 230 | sub_tokens = [] 231 | while start < len(chars): 232 | end = len(chars) 233 | cur_substr = None 234 | while start < end: 235 | substr = "".join(chars[start:end]) 236 | if start > 0: 237 | substr = "##" + substr 238 | if substr in self.vocab: 239 | cur_substr = substr 240 | break 241 | end -= 1 242 | if cur_substr is None: 243 | is_bad = True 244 | break 245 | sub_tokens.append(cur_substr) 246 | start = end 247 | 248 | if is_bad: 249 | output_tokens.append(self.unk_token) 250 | else: 251 | output_tokens.extend(sub_tokens) 252 | return output_tokens 253 | 254 | 255 | def _is_whitespace(char): 256 | """Checks whether `chars` is a whitespace character.""" 257 | # \t, \n, and \r are technically contorl characters but we treat them 258 | # as whitespace since they are generally considered as such. 259 | if char == " " or char == "\t" or char == "\n" or char == "\r": 260 | return True 261 | cat = unicodedata.category(char) 262 | if cat == "Zs": 263 | return True 264 | return False 265 | 266 | 267 | def _is_control(char): 268 | """Checks whether `chars` is a control character.""" 269 | # These are technically control characters but we count them as whitespace 270 | # characters. 271 | if char == "\t" or char == "\n" or char == "\r": 272 | return False 273 | cat = unicodedata.category(char) 274 | if cat.startswith("C"): 275 | return True 276 | return False 277 | 278 | 279 | def _is_punctuation(char): 280 | """Checks whether `chars` is a punctuation character.""" 281 | cp = ord(char) 282 | # We treat all non-letter/number ASCII as punctuation. 283 | # Characters such as "^", "$", and "`" are not in the Unicode 284 | # Punctuation class but we treat them as punctuation anyways, for 285 | # consistency. 
286 | if ((cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64) or 287 | (cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126)): 288 | return True 289 | cat = unicodedata.category(char) 290 | if cat.startswith("P"): 291 | return True 292 | return False 293 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | allennlp==2.5.0 2 | allennlp-models==2.5.0 3 | nested-coref-resolver==0.0.2 4 | en-core-web-sm==3.0.0 5 | nltk>=3.6.4 6 | pycorenlp==0.3.0 7 | pytorch-pretrained-bert==0.6.2 8 | sentencepiece==0.1.96 9 | spacy==3.0.6 10 | torch==1.8.1 11 | torchtext==0.9.1 12 | torchvision==0.9.1 13 | transformers==4.6.1 14 | tqdm 15 | -------------------------------------------------------------------------------- /run.sh: -------------------------------------------------------------------------------- 1 | python run_quac_eval.py \ 2 | --type model_alias \ 3 | --output_dir dir_of_model \ 4 | --write_dir dir_to_write_result \ 5 | --predict_file path_to_quac_dev \ 6 | --match_metric f1 \ 7 | --add_background \ 8 | --skip_entity \ 9 | --rewrite \ 10 | --start_i 0 \ 11 | --end_i 1000 \ -------------------------------------------------------------------------------- /run_quac_eval_util.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import, division, print_function 2 | from io import open 3 | import json 4 | 5 | from coreference_resolution import find_coreference_f1s 6 | 7 | def filter_with_coreference(partial_example, background, gold_answers, QA_history, history_len=2, match_metric='em', add_background=False, skip_entity=True): 8 | """Append the previous predicted answers to the context during evaluation""" 9 | qa_gold = "" 10 | qa_pred = "" 11 | i = 0 12 | total_length = len(QA_history) 13 | while i < history_len: 14 | index = total_length-1-i 15 | if index >= 0: 16 | qa_gold = gold_answers[QA_history[index][0]] + ' ' + qa_gold #using turn_id to access gold answers 17 | qa_gold = QA_history[index][1] + ' ' + qa_gold 18 | 19 | qa_pred = QA_history[index][2][0] + ' ' + qa_pred # answer text 20 | qa_pred = QA_history[index][1] + ' ' + qa_pred # question 21 | i+=1 22 | qa_gold += " "+ partial_example.question_text 23 | qa_pred += " "+ partial_example.question_text 24 | 25 | if add_background: 26 | qa_gold = background + " " + qa_gold 27 | qa_pred = background + " " + qa_pred 28 | 29 | f1s, resolved_gold, resolved_pred = find_coreference_f1s(qa_gold, qa_pred, skip_entity) 30 | 31 | if match_metric == 'em': 32 | modified_gold = resolved_gold.split("< Q >")[-1].strip() 33 | modified_pred = resolved_pred.split("< Q >")[-1].strip() 34 | skip = (modified_gold != modified_pred) 35 | elif match_metric == 'f1': 36 | skip = False if all([f1 > 0 for f1 in f1s]) else True 37 | 38 | return skip 39 | 40 | def rewrite_with_coreference(partial_example, background, gold_answers, QA_history, history_len=2, match_metric='f1', add_background=True, skip_entity=True): 41 | qa_gold = "" 42 | qa_pred = "" 43 | i = 0 44 | total_length = len(QA_history) 45 | while i < history_len: 46 | index = total_length-1-i 47 | if index >= 0: 48 | qa_gold = gold_answers[QA_history[index][0]] + ' ' + qa_gold #using turn_id to access gold answers 49 | qa_gold = QA_history[index][1] + ' ' + qa_gold 50 | 51 | qa_pred = QA_history[index][2][0] + ' ' + qa_pred # answer text 52 | qa_pred = QA_history[index][1] + ' ' + qa_pred # question 53 | i+=1 54 | 
qa_gold += " "+ partial_example.question_text 55 | qa_pred += " "+ partial_example.question_text 56 | 57 | if add_background: 58 | qa_gold = background + " " + qa_gold 59 | qa_pred = background + " " + qa_pred 60 | 61 | f1s, resolved_gold, resolved_pred = find_coreference_f1s(qa_gold, qa_pred, skip_entity) 62 | 63 | modified_gold = resolved_gold.split("< Q >")[-1].strip() 64 | modified_pred = resolved_pred.split("< Q >")[-1].strip() 65 | 66 | if match_metric == 'em': 67 | skip = (modified_gold != modified_pred) 68 | elif match_metric == 'f1': 69 | skip = False if all([f1 > 0 for f1 in f1s]) else True 70 | 71 | return skip, modified_gold 72 | 73 | def write_automatic_eval_result(json_file, evaluation_result): 74 | """evaluation_result = {passage_index: {"CID": ..., 75 | "Predictions": [ 76 | (qa_id, span), 77 | ... 78 | ]}, ...}""" 79 | 80 | with open(json_file, 'w') as fout: 81 | for passage_index, predictions in evaluation_result.items(): 82 | output_dict = {'best_span_str': [], 'qid': [], 'yesno':[], 'followup': []} 83 | for qa_id, span in predictions["Predictions"]: 84 | output_dict['best_span_str'].append(span) 85 | output_dict['qid'].append(qa_id) 86 | output_dict['yesno'].append('y') 87 | output_dict['followup'].append('y') 88 | fout.write(json.dumps(output_dict) + '\n') 89 | 90 | def write_invalid_category(json_file, skip_dictionary): 91 | with open(json_file, 'w') as fout: 92 | fout.write(json.dumps(skip_dictionary, indent=2)) 93 | 94 | def load_context_indep_questions(canard_path): 95 | data = {} 96 | with open(canard_path, encoding="utf-8") as c: 97 | canard_data = json.load(c) 98 | for entry in canard_data: 99 | cid = entry["QuAC_dialog_id"] 100 | turn = entry["Question_no"]-1 101 | rewrite = entry["Rewrite"] 102 | qid = cid + "_q#" + str(turn) 103 | data[qid] = rewrite 104 | return data 105 | 106 | 107 | --------------------------------------------------------------------------------
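As a usage note, here is a minimal sketch of how the helpers in `run_quac_eval_util.py` can be called directly; the file paths, dialogue id, and answer spans below are hypothetical placeholders, and `run_quac_eval.py` (invoked via `run.sh`) remains the intended entry point.

```python
from run_quac_eval_util import (
    load_context_indep_questions,
    write_automatic_eval_result,
    write_invalid_category,
)

# Context-independent (CANARD-style) rewrites keyed by QuAC question id
# ("<dialog_id>_q#<turn>"). The path is a placeholder for a local CANARD file.
rewrites = load_context_indep_questions("data/canard_dev.json")
print(rewrites.get("C_example_dialogue_0_q#0", "no rewrite found"))

# write_automatic_eval_result expects a dict mapping a passage index to the
# dialogue id and the (question id, predicted span) pairs collected for it.
evaluation_result = {
    0: {
        "CID": "C_example_dialogue_0",  # hypothetical dialogue id
        "Predictions": [
            ("C_example_dialogue_0_q#0", "a predicted answer span"),
            ("C_example_dialogue_0_q#1", "another predicted answer span"),
        ],
    },
}

# Writes one JSON line per passage with best_span_str/qid/yesno/followup fields.
write_automatic_eval_result("predictions.json", evaluation_result)

# Questions skipped as invalid can be dumped alongside for inspection.
write_invalid_category("skipped.json", {"C_example_dialogue_0_q#2": "incoherent question"})
```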