├── .gitignore ├── LICENSE ├── README.md ├── cache_elmo.py ├── coref_kernels.cc ├── coref_ops.py ├── data ├── test.vispro.1.1.jsonlines ├── train.vispro.1.1.jsonlines └── val.vispro.1.1.jsonlines ├── evaluate.py ├── experiments.conf ├── fig ├── case_study1.png ├── data_example.PNG └── dialog_example.PNG ├── filter_embeddings.py ├── get_char_vocab.py ├── get_im_fc.py ├── metrics.py ├── model.py ├── predict.py ├── requirements.txt ├── setup_all.sh ├── setup_training.sh ├── train.py └── util.py /.gitignore: -------------------------------------------------------------------------------- 1 | *.pyc 2 | *.so 3 | char_vocab*.txt 4 | glove*.txt 5 | glove*.txt.filtered 6 | *.hdf5 7 | logs 8 | output 9 | venv 10 | *.tgz 11 | .idea 12 | sftp-config.json -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 HKUST-KnowComp 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Visual Pronoun Coreference Resolution in Dialogues 2 | 3 | ## Introduction 4 | This is the data and the source code for the EMNLP 2019 paper "What You See is What You Get: Visual Pronoun Coreference Resolution in Dialogues". [[PAPER](https://www.aclweb.org/anthology/D19-1516.pdf)][[PPT](https://drive.google.com/open?id=1T5911qE1XrToNcMTOhKoAFEiqEgco2Sv)] 5 | 6 | ### Abstract 7 | Grounding a pronoun to the visual object it refers to requires complex reasoning from various information sources, especially in conversational scenarios. 8 | For example, when people in a conversation talk about something all speakers can see (e.g., the statue), they often directly use pronouns (e.g., it) to refer to it without previous introduction. 9 | This fact brings a huge challenge for modern natural language understanding systems, particularly conventional context-based pronoun coreference models. 10 | To tackle this challenge, in this paper, we formally define the task of visual-aware pronoun coreference resolution (PCR), and introduce VisPro, a large-scale dialogue PCR dataset, to investigate whether and how the visual information can help resolve pronouns in dialogues.
11 | We then propose a novel visual-aware PCR model, VisCoref, for this task and conduct comprehensive experiments and case studies on our dataset. 12 | Results demonstrate the importance of the visual information in this PCR case and show the effectiveness of the proposed model. 13 | 14 |
15 | 16 |
17 | 18 | You are welcome to star/fork this repository and use it to train your own model, reproduce our experiments, and follow our future work. Please kindly cite our paper: 19 | ``` 20 | @inproceedings{DBLP:conf/emnlp/YuZSSZ19, 21 | author = {Xintong Yu and 22 | Hongming Zhang and 23 | Yangqiu Song and 24 | Yan Song and 25 | Changshui Zhang}, 26 | title = {What You See is What You Get: Visual Pronoun Coreference Resolution 27 | in Dialogues}, 28 | booktitle = {Proceedings of {EMNLP-IJCNLP} 2019}, 29 | pages = {5122--5131}, 30 | publisher = {Association for Computational Linguistics}, 31 | year = {2019}, 32 | url = {https://doi.org/10.18653/v1/D19-1516}, 33 | doi = {10.18653/v1/D19-1516}, 34 | } 35 | ``` 36 | 37 | 38 | 39 | ## VisPro Dataset 40 | The VisPro dataset contains coreference annotations of 29,722 pronouns from 5,000 dialogues. 41 | 42 | The train, validation, and test splits of VisPro are in the `data` directory. 43 | 44 | ### An Example of VisPro 45 |
46 | 47 |
48 | Mentions in the same coreference cluster are in the same color. 49 | 50 | ### Annotation Format 51 | Each line contains the annotation of one dialog. 52 | ``` 53 | { 54 | "doc_key": str, # e.g. in "dl:train:0", "dl" indicates the "dialog" genre to be compatible with the CoNLL format, and it is the same for all dialogs in VisPro; "train" means that it comes from the train split of VisDial (note that the split of VisPro is not the same as that of VisDial); "0" is the original index among the randomly selected 5,000 VisDial dialogs; this key serves as an index of the dialog 55 | "image_file": str, # the image filename of the dialog 56 | "object_detection": list, # the ids of object labels from the 80 categories of the MSCOCO object detection challenge 57 | "sentences": list, 58 | "speakers": list, 59 | "clusters": list, # each element is a cluster, and each element within a cluster is a mention 60 | "correct_caption_NPs": list, # the noun phrases in the caption 61 | "pronoun_info": list 62 | } 63 | ``` 64 | 65 | Each element of `"pronoun_info"` contains the annotation of one pronoun. 66 | ``` 67 | { 68 | "current_pronoun": list, 69 | "reference_type": int, 70 | "not_discussed": bool, 71 | "candidate_NPs": list, 72 | "correct_NPs": list 73 | } 74 | ``` 75 | Text spans are denoted as [index_start, index_end] of their positions in the whole dialogue, and the indices are counted by concatenating all sentences of the dialogue together (see the loading example at the end of this README). 76 | 77 | `"current_pronoun"`, `"candidate_NPs"`, and `"correct_NPs"` are the positions of the pronoun, the candidate noun phrases, and the correct antecedent noun phrases, respectively. 78 | 79 | `"reference_type"` has 3 values: 0 for pronouns that refer to noun phrases in the text, 1 for pronouns whose antecedents are not in the candidate list, and 2 for non-referential pronouns. 80 | 81 | `"not_discussed"` indicates whether the antecedents of the pronoun are discussed in the dialogue text. 82 | 83 | Take the first dialog in the test split of VisPro as an example: 84 | ``` 85 | { 86 | "pronoun_info": [{"current_pronoun": [15, 15], "candidate_NPs": [[0, 1], [3, 4], [6, 8], [10, 10], [12, 12]], "reference_type": 0, "correct_NPs": [[0, 1], [10, 10]], "not_discussed": false}], 87 | "sentences": [["A", "firefighter", "rests", "his", "arm", "on", "a", "parking", "meter", "as", "another", "walks", "past", "."], ["Is", "he", "in", "his", "gear", "?"]], 88 | "doc_key": "dl:train:152" 89 | } 90 | ``` 91 | Here [0, 1] indicates the phrase "a firefighter", [3, 4] indicates "his arm", [6, 8] indicates "a parking meter", [10, 10] indicates "another", [12, 12] indicates "past", and [15, 15] indicates "he." 92 | For the current pronoun "he", "candidate_NPs" means that "a firefighter", "his arm", "a parking meter", "another", and "past" all serve as candidates for antecedents, while "correct_NPs" means that only "a firefighter" and "another" are correct antecedents. 93 | 94 | The "doc_key" means that this dialog has index 152 among the selected dialogs from the train split of VisDial. 95 | 96 | 97 | ## Usage of VisCoref 98 | 99 | ### An Example of VisCoref Prediction 100 |
101 | 102 |
103 | 104 | The figure shows an example of a VisCoref prediction with the image, the relevant part of the dialogue, the prediction result, and the heatmap of the text-object similarity. We indicate the target pronoun in *underlined italics* and the candidate mentions in bold. The rows of the heatmap represent the mentions in the context and the columns represent the object labels detected in the image. 105 | 106 | ### Getting Started 107 | * Install Python 3.7 and the following requirements: `pip install -r requirements.txt`. Set the default `python` on your system to Python 3.7. 108 | * Download the supplementary data for training VisCoref and the pretrained model from [Data](https://drive.google.com/open?id=1dSeGz5k57bU2GXCt7sY9krykLvmnbiVx) and extract: `tar -xzvf VisCoref.tar.gz`. 109 | * Move the VisPro data and the supplementary data ending with `.jsonlines` to the `data` directory and move the pretrained model to the `logs` directory. 110 | * Download GloVe embeddings and build custom kernels by running `setup_all.sh`. 111 | * There are 3 platform-dependent ways to build custom TensorFlow kernels. Please comment/uncomment the appropriate lines in the script. 112 | * Set up the training files by running `setup_training.sh`. 113 | 114 | ### Training Instructions 115 | 116 | * Experiment configurations are found in `experiments.conf`. 117 | * Choose an experiment that you would like to run, e.g. `best`. 118 | * For training and prediction, set the `GPU` environment variable, which the code treats as shorthand for `CUDA_VISIBLE_DEVICES`. 119 | * (optional) For the "End-to-end + Visual" baseline, first download images from [VisDial](https://visualdialog.org/data) to the `data/images` folder, then run `python get_im_fc.py` to get image features. 120 | * Training: `python train.py ` 121 | * Results are stored in the `logs` directory and can be viewed via TensorBoard. 122 | * Prediction: `python predict.py ` 123 | * Evaluation: `python evaluate.py ` 124 | 125 | ## Acknowledgment 126 | The VisPro dataset is based on [VisDial v1.0](https://visualdialog.org/). 127 | 128 | We built the training framework on top of the original [End-to-end Coreference Resolution](https://github.com/kentonl/e2e-coref). 129 | 130 | ## Others 131 | If you have questions about the data or the code, you are welcome to open an issue or send me an email; I will respond as soon as possible.
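132 | 
133 | ## Loading the Annotation (Example)
134 | As a quick reference for the span format described in "Annotation Format", here is a minimal sketch (assuming Python 3 and the jsonlines files in the `data` directory; the `span_text` helper is only for illustration) that reads the first dialog of the test split and prints each pronoun together with its candidate and correct antecedents as text:
135 | ```
136 | import json
137 | 
138 | # Read the first dialog of the test split.
139 | with open("data/test.vispro.1.1.jsonlines") as f:
140 |     dialog = json.loads(f.readline())
141 | 
142 | # Span indices are counted over all sentences of the dialogue concatenated together.
143 | tokens = [w for sentence in dialog["sentences"] for w in sentence]
144 | 
145 | def span_text(span):  # illustrative helper: [index_start, index_end] -> text
146 |     return " ".join(tokens[span[0]:span[1] + 1])
147 | 
148 | for pronoun in dialog["pronoun_info"]:
149 |     print("pronoun:   ", span_text(pronoun["current_pronoun"]))
150 |     print("candidates:", [span_text(np) for np in pronoun["candidate_NPs"]])
151 |     print("correct:   ", [span_text(np) for np in pronoun["correct_NPs"]])
152 | ```
153 | For the example dialog shown above, this prints "he" as the pronoun and "a firefighter" and "another" as the correct antecedents.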
-------------------------------------------------------------------------------- /cache_elmo.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | 4 | import numpy as np 5 | import tensorflow as tf 6 | import tensorflow_hub as hub 7 | import h5py 8 | import json 9 | import argparse 10 | 11 | def parse_args(): 12 | parser = argparse.ArgumentParser(description='cache elmo embedding') 13 | 14 | parser.add_argument('--dataset', type=str, default='vispro', 15 | help='dataset: vispro, vispro_cdd, vispro_mscoco') 16 | 17 | args = parser.parse_args() 18 | return args 19 | 20 | def build_elmo(): 21 | token_ph = tf.placeholder(tf.string, [None, None]) 22 | len_ph = tf.placeholder(tf.int32, [None]) 23 | elmo_module = hub.Module("https://tfhub.dev/google/elmo/2") 24 | lm_embeddings = elmo_module( 25 | inputs={"tokens": token_ph, "sequence_len": len_ph}, 26 | signature="tokens", as_dict=True) 27 | word_emb = lm_embeddings["word_emb"] 28 | lm_emb = tf.stack([tf.concat([word_emb, word_emb], -1), 29 | lm_embeddings["lstm_outputs1"], 30 | lm_embeddings["lstm_outputs2"]], -1) 31 | return token_ph, len_ph, lm_emb 32 | 33 | def cache_dataset(data_path, session, dataset, token_ph, len_ph, lm_emb, out_file): 34 | with open(data_path) as in_file: 35 | for doc_num, line in enumerate(in_file.readlines()): 36 | example = json.loads(line) 37 | sentences = example["sentences"] 38 | 39 | if dataset == 'vispro': 40 | caption = sentences.pop(0) 41 | 42 | max_sentence_length = max(len(s) for s in sentences) 43 | tokens = [[""] * max_sentence_length for _ in sentences] 44 | text_len = np.array([len(s) for s in sentences]) 45 | 46 | for i, sentence in enumerate(sentences): 47 | for j, word in enumerate(sentence): 48 | tokens[i][j] = word 49 | tokens = np.array(tokens) 50 | 51 | if dataset == 'vispro': 52 | # extract dialog 53 | tf_lm_emb_dial = session.run(lm_emb, feed_dict={ 54 | token_ph: tokens, 55 | len_ph: text_len 56 | }) 57 | file_key = example["doc_key"].replace("/", ":") 58 | group = out_file.create_group(file_key) 59 | for i, (e, l) in enumerate(zip(tf_lm_emb_dial, text_len)): 60 | e = e[:l, :, :] 61 | group[str(i + 1)] = e 62 | 63 | # extract caption alone 64 | # extract spans from caption 65 | caption_NPs = example['correct_caption_NPs'] 66 | file_key = file_key + ':cap' 67 | group = out_file.create_group(file_key) 68 | # caption_NPs might be empty 69 | if len(caption_NPs) == 0: 70 | continue 71 | # extract elmo feature for all spans 72 | span_len = [c[1] - c[0] + 1 for c in caption_NPs] 73 | span_list = [[""] * max(span_len) for _ in caption_NPs] 74 | for i, (span_start, span_end) in enumerate(caption_NPs): 75 | for j, index in enumerate(range(span_start, span_end + 1)): 76 | span_list[i][j] = caption[index].lower() 77 | span_list = np.array(span_list) 78 | tf_lm_emb_cap = session.run(lm_emb, feed_dict={ 79 | token_ph: span_list, 80 | len_ph: span_len 81 | }) 82 | for i, (e, l) in enumerate(zip(tf_lm_emb_cap, span_len)): 83 | e = e[:l, :, :] 84 | group[str(i)] = e 85 | 86 | else: 87 | tf_lm_emb = session.run(lm_emb, feed_dict={ 88 | token_ph: tokens, 89 | len_ph: text_len 90 | }) 91 | file_key = example["doc_key"].replace("/", ":") 92 | group = out_file.create_group(file_key) 93 | for i, (e, l) in enumerate(zip(tf_lm_emb, text_len)): 94 | e = e[:l, :, :] 95 | group[str(i)] = e 96 | 97 | if doc_num % 10 == 0: 98 | print(f"Cached {doc_num + 1} documents in {data_path}") 99 | 100 | if __name__ == 
"__main__": 101 | token_ph, len_ph, lm_emb = build_elmo() 102 | 103 | args = parse_args() 104 | if args.dataset == 'vispro': 105 | json_filenames = ['data/' + s + '.vispro.1.1.jsonlines' 106 | for s in ['train', 'val', 'test']] 107 | elif args.dataset == 'vispro_cdd': 108 | json_filenames = ['data/cdd_np.vispro.1.1.jsonlines'] 109 | elif args.dataset == 'vispro_mscoco': 110 | json_filenames = ['data/mscoco_label.jsonlines'] 111 | config = tf.ConfigProto() 112 | config.gpu_options.allow_growth = True 113 | with tf.Session(config=config) as session: 114 | session.run(tf.global_variables_initializer()) 115 | h5_filename = "data/elmo_cache.%s.hdf5" % args.dataset 116 | out_file = h5py.File(h5_filename, "w") 117 | for json_filename in json_filenames: 118 | cache_dataset(json_filename, session, args.dataset, token_ph, len_ph, lm_emb, out_file) 119 | out_file.close() 120 | -------------------------------------------------------------------------------- /coref_kernels.cc: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | #include "tensorflow/core/framework/op.h" 4 | #include "tensorflow/core/framework/shape_inference.h" 5 | #include "tensorflow/core/framework/op_kernel.h" 6 | 7 | using namespace tensorflow; 8 | 9 | REGISTER_OP("ExtractSpans") 10 | .Input("span_scores: float32") 11 | .Input("candidate_starts: int32") 12 | .Input("candidate_ends: int32") 13 | .Input("num_output_spans: int32") 14 | .Input("max_sentence_length: int32") 15 | .Attr("sort_spans: bool") 16 | .Output("output_span_indices: int32"); 17 | 18 | class ExtractSpansOp : public OpKernel { 19 | public: 20 | explicit ExtractSpansOp(OpKernelConstruction* context) : OpKernel(context) { 21 | OP_REQUIRES_OK(context, context->GetAttr("sort_spans", &_sort_spans)); 22 | } 23 | 24 | void Compute(OpKernelContext* context) override { 25 | TTypes::ConstMatrix span_scores = context->input(0).matrix(); 26 | TTypes::ConstMatrix candidate_starts = context->input(1).matrix(); 27 | TTypes::ConstMatrix candidate_ends = context->input(2).matrix(); 28 | TTypes::ConstVec num_output_spans = context->input(3).vec(); 29 | int max_sentence_length = context->input(4).scalar()(); 30 | 31 | int num_sentences = span_scores.dimension(0); 32 | int num_input_spans = span_scores.dimension(1); 33 | int max_num_output_spans = 0; 34 | for (int i = 0; i < num_sentences; i++) { 35 | if (num_output_spans(i) > max_num_output_spans) { 36 | max_num_output_spans = num_output_spans(i); 37 | } 38 | } 39 | 40 | Tensor* output_span_indices_tensor = nullptr; 41 | TensorShape output_span_indices_shape({num_sentences, max_num_output_spans}); 42 | OP_REQUIRES_OK(context, context->allocate_output(0, output_span_indices_shape, 43 | &output_span_indices_tensor)); 44 | TTypes::Matrix output_span_indices = output_span_indices_tensor->matrix(); 45 | 46 | std::vector> sorted_input_span_indices(num_sentences, 47 | std::vector(num_input_spans)); 48 | for (int i = 0; i < num_sentences; i++) { 49 | std::iota(sorted_input_span_indices[i].begin(), sorted_input_span_indices[i].end(), 0); 50 | std::sort(sorted_input_span_indices[i].begin(), sorted_input_span_indices[i].end(), 51 | [&span_scores, &i](int j1, int j2) { 52 | return span_scores(i, j2) < span_scores(i, j1); 53 | }); 54 | } 55 | 56 | for (int l = 0; l < num_sentences; l++) { 57 | std::vector top_span_indices; 58 | std::unordered_map end_to_earliest_start; 59 | std::unordered_map start_to_latest_end; 60 | 61 | int current_span_index = 0, 62 | num_selected_spans = 0; 63 | while 
(num_selected_spans < num_output_spans(l) && current_span_index < num_input_spans) { 64 | int i = sorted_input_span_indices[l][current_span_index]; 65 | bool any_crossing = false; 66 | const int start = candidate_starts(l, i); 67 | const int end = candidate_ends(l, i); 68 | for (int j = start; j <= end; ++j) { 69 | auto latest_end_iter = start_to_latest_end.find(j); 70 | if (latest_end_iter != start_to_latest_end.end() && j > start && latest_end_iter->second > end) { 71 | // Given (), exists [], such that ( [ ) ] 72 | any_crossing = true; 73 | break; 74 | } 75 | auto earliest_start_iter = end_to_earliest_start.find(j); 76 | if (earliest_start_iter != end_to_earliest_start.end() && j < end && earliest_start_iter->second < start) { 77 | // Given (), exists [], such that [ ( ] ) 78 | any_crossing = true; 79 | break; 80 | } 81 | } 82 | if (!any_crossing) { 83 | if (_sort_spans) { 84 | top_span_indices.push_back(i); 85 | } else { 86 | output_span_indices(l, num_selected_spans) = i; 87 | } 88 | ++num_selected_spans; 89 | // Update data struct. 90 | auto latest_end_iter = start_to_latest_end.find(start); 91 | if (latest_end_iter == start_to_latest_end.end() || end > latest_end_iter->second) { 92 | start_to_latest_end[start] = end; 93 | } 94 | auto earliest_start_iter = end_to_earliest_start.find(end); 95 | if (earliest_start_iter == end_to_earliest_start.end() || start < earliest_start_iter->second) { 96 | end_to_earliest_start[end] = start; 97 | } 98 | } 99 | ++current_span_index; 100 | } 101 | // Sort and populate selected span indices. 102 | if (_sort_spans) { 103 | std::sort(top_span_indices.begin(), top_span_indices.end(), 104 | [&candidate_starts, &candidate_ends, &l] (int i1, int i2) { 105 | if (candidate_starts(l, i1) < candidate_starts(l, i2)) { 106 | return true; 107 | } else if (candidate_starts(l, i1) > candidate_starts(l, i2)) { 108 | return false; 109 | } else if (candidate_ends(l, i1) < candidate_ends(l, i2)) { 110 | return true; 111 | } else if (candidate_ends(l, i1) > candidate_ends(l, i2)) { 112 | return false; 113 | } else { 114 | return i1 < i2; 115 | } 116 | }); 117 | for (int i = 0; i < num_output_spans(l); ++i) { 118 | output_span_indices(l, i) = top_span_indices[i]; 119 | } 120 | } 121 | // Pad with the first span index. 
122 | for (int i = num_selected_spans; i < max_num_output_spans; ++i) { 123 | output_span_indices(l, i) = output_span_indices(l, 0); 124 | } 125 | } 126 | } 127 | private: 128 | bool _sort_spans; 129 | }; 130 | 131 | REGISTER_KERNEL_BUILDER(Name("ExtractSpans").Device(DEVICE_CPU), ExtractSpansOp); 132 | -------------------------------------------------------------------------------- /coref_ops.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import tensorflow as tf 6 | from tensorflow.python import pywrap_tensorflow 7 | 8 | coref_op_library = tf.load_op_library("./coref_kernels.so") 9 | 10 | extract_spans = coref_op_library.extract_spans 11 | tf.NotDifferentiable("ExtractSpans") 12 | -------------------------------------------------------------------------------- /evaluate.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | from __future__ import absolute_import 3 | from __future__ import division 4 | from __future__ import print_function 5 | 6 | import json 7 | import util 8 | import numpy as np 9 | import argparse 10 | import os.path as osp 11 | 12 | parser = argparse.ArgumentParser(description='evaluate pronoun resolution on trained model') 13 | parser.add_argument('model', type=str, 14 | help='model name to evaluate') 15 | parser.add_argument('--split', type=str, default='test', 16 | help='split to evaluate, test or val') 17 | parser.add_argument('--output_dir', type=str, default='output', 18 | help='output dir') 19 | 20 | pronoun_list = ['she', 'her', 'he', 'him', 'them', 'they', 'She', 'Her', 'He', 'Him', 'Them', 'They', 'it', 'It', 'his', 'hers', 'its', 'their', 'theirs', 'His', 'Hers', 'Its', 'Their', 'Theirs'] 21 | 22 | 23 | def main(args): 24 | # load data 25 | evaluate_file = f'{args.split}.vispro.1.1.prediction.jsonlines' 26 | evaluate_file = osp.join(args.output_dir, args.model, evaluate_file) 27 | test_data = list() 28 | with open(evaluate_file, 'r') as f: 29 | for line in f: 30 | tmp_example = json.loads(line) 31 | test_data.append(tmp_example) 32 | print(f'Evaluate prediction of {evaluate_file}') 33 | 34 | # initialize variables 35 | eval_types = ['all', 'not_discussed', 'discussed'] 36 | all_antecedent = {key: 0 for key in eval_types} 37 | predict_antecedent = {key: 0 for key in eval_types} 38 | correct_antecedent = {key: 0 for key in eval_types} 39 | 40 | # deal with each dialog 41 | for i, tmp_example in enumerate(test_data): 42 | all_sentence = list() 43 | caption_len = len(tmp_example['sentences'][0]) 44 | predicted_clusters = [tuple(pc) for pc in tmp_example['predicted_clusters']] 45 | 46 | for s in tmp_example['sentences']: 47 | all_sentence += s 48 | 49 | tokens_cdd = tmp_example['cdd_sentences'] 50 | cdd_tokens_len = [len(t) for t in tokens_cdd] 51 | cdd_tokens_end = np.cumsum(cdd_tokens_len) + len(all_sentence) 52 | cdd_nps = [] 53 | for cdd_end, cdd_len in zip(cdd_tokens_end, cdd_tokens_len): 54 | cdd_nps.append([cdd_end - cdd_len, cdd_end - 1]) 55 | 56 | for pronoun_example in tmp_example['pronoun_info']: 57 | tmp_pronoun = all_sentence[pronoun_example['current_pronoun'][0]] 58 | tmp_pronoun_index = pronoun_example['current_pronoun'][0] 59 | 60 | tmp_candidate_NPs = pronoun_example['candidate_NPs'] 61 | tmp_candidate_NPs += cdd_nps 62 | tmp_candidate_NPs = tuple(tmp_candidate_NPs) 63 | tmp_correct_candidate_NPs = 
tuple(pronoun_example['correct_NPs']) 64 | not_discussed_flag = pronoun_example['not_discussed'] 65 | if not_discussed_flag: 66 | all_antecedent['not_discussed'] += len(tmp_correct_candidate_NPs) 67 | else: 68 | all_antecedent['discussed'] += len(tmp_correct_candidate_NPs) 69 | all_antecedent['all'] += len(tmp_correct_candidate_NPs) 70 | 71 | find_pronoun = False 72 | for cluster_id, coref_cluster in enumerate(predicted_clusters): 73 | for mention in coref_cluster: 74 | if mention[0] == tmp_pronoun_index: 75 | find_pronoun = True 76 | find_cluster_id = cluster_id 77 | if find_pronoun: 78 | break 79 | if find_pronoun and pronoun_example['reference_type'] == 0: 80 | coref_cluster = predicted_clusters[find_cluster_id] 81 | matched_cdd_np_ids = [] 82 | matched_crr_np_ids = [] 83 | for mention in coref_cluster: 84 | mention_start_index = mention[0] 85 | tmp_mention_span = ( 86 | mention_start_index, 87 | mention[1]) 88 | matched_np_id = util.verify_correct_NP_match(tmp_mention_span, tmp_candidate_NPs, 'cover', matched_cdd_np_ids) 89 | if matched_np_id is not None: 90 | # exclude such scenario: predict 'its' and overlap with candidate 'its eyes' 91 | # predict +1 but correct +0 92 | if tmp_mention_span[0] < len(all_sentence) and\ 93 | tmp_mention_span[0] == tmp_mention_span[1] and\ 94 | all_sentence[tmp_mention_span[0]] in pronoun_list and\ 95 | len(tmp_candidate_NPs[matched_np_id]) > 1: 96 | continue 97 | matched_cdd_np_ids.append(matched_np_id) 98 | predict_antecedent['all'] += 1 99 | if not_discussed_flag: 100 | predict_antecedent['not_discussed'] += 1 101 | else: 102 | predict_antecedent['discussed'] += 1 103 | matched_np_id = util.verify_correct_NP_match(tmp_mention_span, tmp_correct_candidate_NPs, 'cover', matched_crr_np_ids) 104 | if matched_np_id is not None: 105 | matched_crr_np_ids.append(matched_np_id) 106 | correct_antecedent['all'] += 1 107 | if not_discussed_flag: 108 | correct_antecedent['not_discussed'] += 1 109 | else: 110 | correct_antecedent['discussed'] += 1 111 | 112 | print('Pronoun resolution') 113 | results = [] 114 | for key in ['discussed', 'not_discussed', 'all']: 115 | p = 0 if predict_antecedent[key] == 0 else correct_antecedent[key] / predict_antecedent[key] 116 | r = 0 if all_antecedent[key] == 0 else correct_antecedent[key] / all_antecedent[key] 117 | f1 = 0 if p + r == 0 else 2 * p * r / (p + r) 118 | results.extend([p, r, f1]) 119 | print(key) 120 | print(f'\tP: {p * 100:.2f}, R: {r * 100:.2f}, F1: {f1 * 100:.2f}') 121 | print(f'\tall: {all_antecedent[key]}, predict: {predict_antecedent[key]}, correct: {correct_antecedent[key]}') 122 | 123 | return results 124 | 125 | 126 | if __name__ == "__main__": 127 | args = parser.parse_args() 128 | main(args) 129 | -------------------------------------------------------------------------------- /experiments.conf: -------------------------------------------------------------------------------- 1 | # Word embeddings. 2 | glove_300d { 3 | path = data/glove.840B.300d.txt 4 | size = 300 5 | } 6 | glove_300d_filtered { 7 | path = data/glove.840B.300d.txt.filtered 8 | size = 300 9 | } 10 | glove_300d_2w { 11 | path = data/glove_50_300_2.txt 12 | size = 300 13 | } 14 | glove_300d_2w_filtered { 15 | path = data/glove_50_300_2.txt.filtered 16 | size = 300 17 | } 18 | 19 | # Main configuration. 20 | best { 21 | # Computation limits. 22 | max_top_antecedents = 50 23 | max_training_sentences = 50 24 | top_span_ratio = 0.4 25 | 26 | # Model hyperparameters. 
27 | filter_widths = [3, 4, 5] 28 | filter_size = 50 29 | char_embedding_size = 8 30 | char_vocab_path = "data/char_vocab.txt" 31 | context_embeddings = ${glove_300d_filtered} 32 | head_embeddings = ${glove_300d_2w_filtered} 33 | contextualization_size = 200 34 | contextualization_layers = 3 35 | ffnn_size = 150 36 | ffnn_depth = 2 37 | feature_size = 20 38 | max_span_width = 20 39 | use_metadata = true 40 | use_features = true 41 | model_heads = true 42 | coref_depth = 2 43 | lm_layers = 3 44 | lm_size = 1024 45 | 46 | num_cdd_pool = 30 47 | use_im = true 48 | im_emb_size = 512 49 | vis_weight = 0.4 50 | ffnn_size_im = 100 51 | ffnn_depth_im = 1 52 | 53 | # End-to-End + Visual baseline 54 | use_im_fc = false 55 | im_fc_feat_path = data/resnet152_feat.hdf5 56 | im_fc_feat_size = 2048 57 | im_layer = 0 58 | im_fc_emb_size = 512 59 | im_dropout_rate = 0 60 | 61 | # Learning hyperparameters. 62 | max_gradient_norm = 5.0 63 | lstm_dropout_rate = 0.4 64 | lexical_dropout_rate = 0.5 65 | dropout_rate = 0.2 66 | optimizer = adam 67 | learning_rate = 0.001 68 | decay_rate = 0.999 69 | decay_frequency = 100 70 | random_seed = 2019 71 | max_step = 50000 72 | 73 | # Other. 74 | train_path = data/train.vispro.1.1.jsonlines 75 | eval_path = data/val.vispro.1.1.jsonlines 76 | lm_path = data/elmo_cache.vispro.hdf5 77 | cdd_path = data/cdd_np.vispro.1.1.jsonlines 78 | lm_cdd_path = data/elmo_cache.vispro_cdd.hdf5 79 | im_obj_label_path = data/mscoco_label.jsonlines 80 | lm_obj_path = data/elmo_cache.vispro_mscoco.hdf5 81 | eval_frequency = 5000 82 | report_frequency = 100 83 | log_root = /home/yuxintong/pr4vd/Visual_PCR/logs 84 | } 85 | 86 | best_predict = ${best} { 87 | context_embeddings = ${glove_300d} 88 | head_embeddings = ${glove_300d_2w} 89 | } 90 | 91 | e2e_baseline = ${best} { 92 | use_im = false 93 | } 94 | 95 | e2e_visual_baseline = ${best} { 96 | use_im_fc = true 97 | } -------------------------------------------------------------------------------- /fig/case_study1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HKUST-KnowComp/Visual_PCR/19417967514c0c9a3da2a510f291b7482326e4b7/fig/case_study1.png -------------------------------------------------------------------------------- /fig/data_example.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HKUST-KnowComp/Visual_PCR/19417967514c0c9a3da2a510f291b7482326e4b7/fig/data_example.PNG -------------------------------------------------------------------------------- /fig/dialog_example.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HKUST-KnowComp/Visual_PCR/19417967514c0c9a3da2a510f291b7482326e4b7/fig/dialog_example.PNG -------------------------------------------------------------------------------- /filter_embeddings.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | 4 | import json 5 | import argparse 6 | import os.path as osp 7 | 8 | def parse_args(): 9 | parser = argparse.ArgumentParser(description='filter glove embedding') 10 | 11 | parser.add_argument('--embedding', type=str, default='glove.840B.300d.txt', 12 | help='glove embedding file') 13 | 14 | args = parser.parse_args() 15 | return args 16 | 17 | if __name__ == "__main__": 18 | args = parse_args() 19 | 20 | json_filenames = ['data/' + s + '.vispro.1.1.jsonlines'
21 | for s in ['train', 'val', 'test']] 22 | 23 | words_to_keep = set() 24 | for json_filename in json_filenames: 25 | print(f'Open {json_filename}') 26 | with open(json_filename) as json_file: 27 | for line in json_file.readlines(): 28 | for sentence in json.loads(line)["sentences"]: 29 | words_to_keep.update(sentence) 30 | 31 | print(f"Found {len(words_to_keep)} words in {len(json_filenames)} dataset(s).") 32 | 33 | total_lines = 0 34 | kept_lines = 0 35 | out_filename = "{}.filtered".format(args.embedding) 36 | with open(osp.join('data', args.embedding)) as in_file: 37 | with open(osp.join('data', out_filename), "w") as out_file: 38 | for line in in_file.readlines(): 39 | total_lines += 1 40 | word = line.split()[0] 41 | if word in words_to_keep: 42 | kept_lines += 1 43 | out_file.write(line) 44 | 45 | print(f"Kept {kept_lines} out of {total_lines} lines.") 46 | print(f"Wrote result to {out_filename}.") 47 | -------------------------------------------------------------------------------- /get_char_vocab.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | 4 | import json 5 | 6 | 7 | def get_char_vocab(input_filenames, output_filename): 8 | vocab = set() 9 | for filename in input_filenames: 10 | with open(filename) as f: 11 | for line in f.readlines(): 12 | for sentence in json.loads(line)["sentences"]: 13 | for word in sentence: 14 | vocab.update(word) 15 | vocab = sorted(list(vocab)) 16 | with open(output_filename, "w") as f: 17 | for char in vocab: 18 | f.write(u"{}\n".format(char)) 19 | print(f"Wrote {len(vocab)} characters to {output_filename}") 20 | 21 | if __name__ == "__main__": 22 | json_filenames = ['data/' + s + '.vispro.1.1.jsonlines' 23 | for s in ['train', 'val', 'test']] 24 | get_char_vocab(json_filenames, 'data/char_vocab.txt') 25 | -------------------------------------------------------------------------------- /get_im_fc.py: -------------------------------------------------------------------------------- 1 | import os 2 | import glob 3 | import numpy as np 4 | import random 5 | import time 6 | import json 7 | import os.path as osp 8 | from PIL import Image 9 | import torch 10 | import torch.nn as nn 11 | from torchvision import models 12 | from torchvision import transforms 13 | from tqdm import tqdm 14 | import h5py 15 | from collections import OrderedDict 16 | import argparse 17 | 18 | parser = argparse.ArgumentParser(description='extract features of Resnet') 19 | 20 | parser.add_argument('--image_dir', type=str, default='data/images', 21 | help='image dir') 22 | parser.add_argument('--resnet', type=int, default=152, 23 | help='use resnet 101 or 152') 24 | parser.add_argument('--append_hdf5', action='store_true', 25 | help='append to existing hdf5, allow feature extraction after abortion') 26 | 27 | 28 | args = parser.parse_args() 29 | 30 | class myResnet(nn.Module): 31 | def __init__(self, resnet): 32 | super(myResnet, self).__init__() 33 | self.resnet = resnet 34 | 35 | def forward(self, img): 36 | x = img.unsqueeze(0) 37 | 38 | x = self.resnet.conv1(x) 39 | x = self.resnet.bn1(x) 40 | x = self.resnet.relu(x) 41 | x = self.resnet.maxpool(x) 42 | 43 | x = self.resnet.layer1(x) 44 | x = self.resnet.layer2(x) 45 | x = self.resnet.layer3(x) 46 | x = self.resnet.layer4(x) 47 | x = self.resnet.avgpool(x) 48 | x = x.view(x.size(0), -1) 49 | 50 | return x 51 | 52 | # prepare resnet 53 | if args.resnet == 101: 54 | resnet = models.resnet101(pretrained=True) 55 | 
elif args.resnet == 152: 56 | resnet = models.resnet152(pretrained=True) 57 | net = myResnet(resnet).cuda().eval() 58 | trans = transforms.Compose([ 59 | transforms.Resize((448,448)), 60 | transforms.ToTensor(), 61 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) 62 | ]) 63 | 64 | # create hdf5 65 | h5_filename = 'data/resnet{}_feat.hdf5'.format(args.resnet) 66 | if args.append_hdf5: 67 | f = h5py.File(h5_filename, 'a') 68 | written_keys = f.keys() 69 | else: 70 | f = h5py.File(h5_filename, 'w') 71 | 72 | for split in ['train', 'val', 'test']: 73 | # load data 74 | data = [json.loads(line) for line in open('data/{}.vispro.1.1.jsonlines'.format(split))] 75 | 76 | # for each image 77 | for dialog in tqdm(data): 78 | filename = dialog['image_file'] 79 | 80 | # skip dialogs whose features are already extracted (keyed by doc_key) 81 | if args.append_hdf5 and dialog['doc_key'] in written_keys: 82 | continue 83 | 84 | # extract feature and write to hdf5 85 | filename = osp.join(args.image_dir, filename) 86 | img = Image.open(filename) 87 | if len(np.array(img).shape) < 3: 88 | img = Image.merge('RGB', (img,) * 3) 89 | with torch.no_grad(): 90 | feat = net(trans(img).cuda()) 91 | feat = feat.squeeze(0).cpu().data.numpy() 92 | f.create_dataset(dialog['doc_key'], data=feat) 93 | 94 | # save result 95 | f.close() 96 | print('Results saved to ' + h5_filename) 97 | -------------------------------------------------------------------------------- /metrics.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | 4 | import util 5 | import numpy as np 6 | 7 | 8 | class PrCorefEvaluator(object): 9 | def __init__(self): 10 | self.all_coreference = 0 11 | self.predict_coreference = 0 12 | self.correct_predict_coreference = 0 13 | self.all_coref_discussed = 0 14 | self.predict_coref_discussed = 0 15 | self.correct_predict_coref_discussed = 0 16 | self.all_coref_not_discussed = 0 17 | self.predict_coref_not_discussed = 0 18 | self.correct_predict_coref_not_discussed = 0 19 | 20 | self.pronoun_list = ['she', 'her', 'he', 'him', 'them', 'they', 'She', 'Her', 'He', 'Him', 'Them', 'They', 'it', 'It', 'his', 'hers', 'its', 'their', 'theirs', 'His', 'Hers', 'Its', 'Their', 'Theirs'] 21 | 22 | def get_prf(self): 23 | results = {} 24 | results['p'] = 0 if self.predict_coreference == 0 else self.correct_predict_coreference / self.predict_coreference 25 | results['r'] = 0 if self.all_coreference == 0 else self.correct_predict_coreference / self.all_coreference 26 | results['f'] = 0 if results['p'] + results['r'] == 0 else 2 * results['p'] * results['r'] / (results['p'] + results['r']) 27 | results['p_discussed'] = 0 if self.predict_coref_discussed == 0 else self.correct_predict_coref_discussed / self.predict_coref_discussed 28 | results['r_discussed'] = 0 if self.all_coref_discussed == 0 else self.correct_predict_coref_discussed / self.all_coref_discussed 29 | results['f_discussed'] = 0 if results['p_discussed'] + results['r_discussed'] == 0 else 2 * results['p_discussed'] * results['r_discussed'] / (results['p_discussed'] + results['r_discussed']) 30 | results['p_not_discussed'] = 0 if self.predict_coref_not_discussed == 0 else self.correct_predict_coref_not_discussed / self.predict_coref_not_discussed 31 | results['r_not_discussed'] = 0 if self.all_coref_not_discussed == 0 else self.correct_predict_coref_not_discussed / self.all_coref_not_discussed 32 | results['f_not_discussed'] = 0 if results['p_not_discussed'] +
results['r_not_discussed'] == 0 else 2 * results['p_not_discussed'] * results['r_not_discussed'] / (results['p_not_discussed'] + results['r_not_discussed']) 33 | 34 | return results 35 | 36 | 37 | def update(self, predicted_clusters, pronoun_info, sentences, tokens_cdd): 38 | all_sentence = list() 39 | caption_len = len(sentences[0]) 40 | predicted_clusters = [tuple(pc) for pc in predicted_clusters] 41 | 42 | for s in sentences: 43 | all_sentence += s 44 | 45 | cdd_tokens_len = np.sum(tokens_cdd != '', axis=1) 46 | cdd_tokens_end = np.cumsum(cdd_tokens_len) + len(all_sentence) 47 | cdd_nps = [] 48 | for cdd_end, cdd_len in zip(cdd_tokens_end, cdd_tokens_len): 49 | cdd_nps.append([cdd_end - cdd_len, cdd_end - 1]) 50 | 51 | for pronoun_example in pronoun_info: 52 | tmp_pronoun_index = pronoun_example['current_pronoun'][0] 53 | 54 | tmp_candidate_NPs = pronoun_example['candidate_NPs'] 55 | tmp_candidate_NPs += cdd_nps 56 | tmp_correct_candidate_NPs = pronoun_example['correct_NPs'] 57 | 58 | if pronoun_example['not_discussed']: 59 | self.all_coref_not_discussed += len(tmp_correct_candidate_NPs) 60 | else: 61 | self.all_coref_discussed += len(tmp_correct_candidate_NPs) 62 | 63 | find_pronoun = False 64 | for coref_cluster in predicted_clusters: 65 | for mention in coref_cluster: 66 | mention_start_index = mention[0] 67 | if mention_start_index == tmp_pronoun_index: 68 | find_pronoun = True 69 | if find_pronoun and pronoun_example['reference_type'] == 0: 70 | matched_cdd_np_ids = [] 71 | matched_crr_np_ids = [] 72 | for mention in coref_cluster: 73 | mention_start_index = mention[0] 74 | tmp_mention_span = ( 75 | mention_start_index, 76 | mention[1]) 77 | matched_np_id = util.verify_correct_NP_match(tmp_mention_span, tmp_candidate_NPs, 'cover', matched_cdd_np_ids) 78 | if matched_np_id is not None: 79 | # exclude such scenario: predict 'its' and overlap with candidate 'its eyes' 80 | if tmp_mention_span[0] < len(all_sentence) and\ 81 | tmp_mention_span[0] == tmp_mention_span[1] and\ 82 | all_sentence[tmp_mention_span[0]] in self.pronoun_list and\ 83 | len(tmp_candidate_NPs[matched_np_id]) > 1: 84 | continue 85 | matched_cdd_np_ids.append(matched_np_id) 86 | self.predict_coreference += 1 87 | if pronoun_example['not_discussed']: 88 | self.predict_coref_not_discussed += 1 89 | else: 90 | self.predict_coref_discussed += 1 91 | matched_np_id = util.verify_correct_NP_match(tmp_mention_span, tmp_correct_candidate_NPs, 'cover', matched_crr_np_ids) 92 | if matched_np_id is not None: 93 | matched_crr_np_ids.append(matched_np_id) 94 | self.correct_predict_coreference += 1 95 | if pronoun_example['not_discussed']: 96 | self.correct_predict_coref_not_discussed += 1 97 | else: 98 | self.correct_predict_coref_discussed += 1 99 | break 100 | 101 | self.all_coreference += len(tmp_correct_candidate_NPs) 102 | 103 | -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | 4 | import os 5 | import math 6 | import json 7 | import threading 8 | import numpy as np 9 | import tensorflow as tf 10 | import tensorflow_hub as hub 11 | import h5py 12 | import random 13 | 14 | import util 15 | import coref_ops 16 | import metrics 17 | 18 | class VisCoref(object): 19 | def __init__(self, config): 20 | self.config = config 21 | self.context_embeddings = util.EmbeddingDictionary(config["context_embeddings"]) 22 | 
self.head_embeddings = util.EmbeddingDictionary(config["head_embeddings"], maybe_cache=self.context_embeddings) 23 | self.char_embedding_size = config["char_embedding_size"] 24 | self.char_dict = util.load_char_dict(config["char_vocab_path"]) 25 | self.max_span_width = config["max_span_width"] 26 | 27 | self.lm_layers = self.config["lm_layers"] 28 | self.lm_size = self.config["lm_size"] 29 | 30 | self.use_im = self.config["use_im"] 31 | im_obj_labels = [json.loads(line) for line in open(self.config["im_obj_label_path"], "r")] 32 | self.id2cat = {int(d["doc_key"]):d["sentences"][0] for d in im_obj_labels} 33 | if self.use_im: 34 | self.lm_obj_file = h5py.File(self.config["lm_obj_path"], "r") 35 | self.im_emb_size = self.config["im_emb_size"] 36 | 37 | # visual baseline 38 | self.use_im_fc = self.config["use_im_fc"] 39 | if self.use_im_fc: 40 | self.im_fc_file = h5py.File(self.config["im_fc_feat_path"], "r") 41 | self.im_fc_feat_size = self.config["im_fc_feat_size"] 42 | self.im_fc_emb_size = self.config["im_fc_emb_size"] 43 | 44 | self.vis_weight = self.config["vis_weight"] 45 | self.num_cdd_pool = self.config["num_cdd_pool"] 46 | self.lm_cdd_file = h5py.File(self.config["lm_cdd_path"], "r") 47 | with open(self.config["cdd_path"]) as f: 48 | self.cdd_nps = [json.loads(jsonline) for jsonline in f.readlines()] 49 | 50 | self.eval_data = None # Load eval data lazily. 51 | self.lm_file = h5py.File(self.config["lm_path"], "r") 52 | print(f'Loading elmo cache from {self.config["lm_path"]}') 53 | 54 | input_props = [] 55 | input_props.append((tf.string, [None, None])) # Tokens. 56 | input_props.append((tf.float32, [None, None, self.context_embeddings.size])) # Context embeddings. 57 | input_props.append((tf.float32, [None, None, self.head_embeddings.size])) # Head embeddings. 58 | input_props.append((tf.float32, [None, None, self.lm_size, self.lm_layers])) # LM embeddings for cap. 59 | input_props.append((tf.float32, [None, None, self.lm_size, self.lm_layers])) # LM embeddings for dial. 60 | input_props.append((tf.int32, [None, None, None])) # Character indices. 61 | input_props.append((tf.int32, [None])) # Text lengths. 62 | input_props.append((tf.int32, [None])) # Speaker IDs. 63 | input_props.append((tf.bool, [])) # Is training. 64 | input_props.append((tf.int32, [None])) # Gold starts. 65 | input_props.append((tf.int32, [None])) # Gold ends. 66 | input_props.append((tf.int32, [None])) # Cluster ids. 67 | input_props.append((tf.int32, [None])) # caption candidate starts. 68 | input_props.append((tf.int32, [None])) # caption candidate ends. 69 | input_props.append((tf.int32, [None])) # Text lengths cdd. 70 | input_props.append((tf.float32, [None, None, self.context_embeddings.size])) # Context embeddings cdd. 71 | input_props.append((tf.float32, [None, None, self.head_embeddings.size])) # Head embeddings cdd. 72 | input_props.append((tf.int32, [None, None, None])) # Character indices cdd. 73 | input_props.append((tf.float32, [None, None, self.lm_size, self.lm_layers])) # LM embeddings for cdd. 74 | input_props.append((tf.string, [None, None])) # Tokens cdd. 75 | input_props.append((tf.int32, [None])) # Text lengths obj. 76 | input_props.append((tf.float32, [None, None, self.context_embeddings.size])) # Context embeddings obj. 77 | input_props.append((tf.float32, [None, None, self.head_embeddings.size])) # Head embeddings obj. 78 | input_props.append((tf.int32, [None, None, None])) # Character indices obj. 
79 | input_props.append((tf.float32, [None, None, self.lm_size, self.lm_layers])) # LM embeddings for obj. 80 | input_props.append((tf.string, [None, None])) # Tokens obj. 81 | input_props.append((tf.bool, [])) # Has object. 82 | input_props.append((tf.float32, [self.im_fc_feat_size])) # Image features. 83 | 84 | self.queue_input_tensors = [tf.placeholder(dtype, shape) for dtype, shape in input_props] 85 | dtypes, shapes = zip(*input_props) 86 | queue = tf.PaddingFIFOQueue(capacity=10, dtypes=dtypes, shapes=shapes) 87 | self.enqueue_op = queue.enqueue(self.queue_input_tensors) 88 | self.input_tensors = queue.dequeue() 89 | 90 | self.predictions, self.loss = self.get_predictions_and_loss(*self.input_tensors) 91 | self.global_step = tf.Variable(0, name="global_step", trainable=False) 92 | self.reset_global_step = tf.assign(self.global_step, 0) 93 | self.max_eval_f1 = tf.Variable(0.0, name="max_eval_f1", trainable=False) 94 | learning_rate = tf.train.exponential_decay(self.config["learning_rate"], self.global_step, 95 | self.config["decay_frequency"], self.config["decay_rate"], staircase=True) 96 | trainable_params = tf.trainable_variables() 97 | gradients = tf.gradients(self.loss, trainable_params) 98 | gradients, _ = tf.clip_by_global_norm(gradients, self.config["max_gradient_norm"]) 99 | optimizers = { 100 | "adam" : tf.train.AdamOptimizer, 101 | "sgd" : tf.train.GradientDescentOptimizer 102 | } 103 | optimizer = optimizers[self.config["optimizer"]](learning_rate) 104 | self.train_op = optimizer.apply_gradients(zip(gradients, trainable_params), global_step=self.global_step) 105 | 106 | def start_enqueue_thread(self, session): 107 | with open(self.config["train_path"]) as f: 108 | train_examples = [json.loads(jsonline) for jsonline in f.readlines()] 109 | def _enqueue_loop(): 110 | while True: 111 | global_step = session.run(self.global_step) 112 | random.seed(self.config["random_seed"] + global_step) 113 | random.shuffle(train_examples) 114 | for example in train_examples: 115 | tensorized_example = self.tensorize_example(example, is_training=True) 116 | feed_dict = dict(zip(self.queue_input_tensors, tensorized_example)) 117 | session.run(self.enqueue_op, feed_dict=feed_dict) 118 | enqueue_thread = threading.Thread(target=_enqueue_loop) 119 | enqueue_thread.daemon = True 120 | enqueue_thread.start() 121 | 122 | def restore(self, session, step='max'): 123 | # Don't try to restore unused variables from the TF-Hub ELMo module. 124 | vars_to_restore = [v for v in tf.global_variables() if "module/" not in v.name] 125 | saver = tf.train.Saver(vars_to_restore) 126 | if step == 'max': 127 | path = "model.max.ckpt" 128 | else: 129 | path = "model-" + step 130 | checkpoint_path = os.path.join(self.config["log_dir"], path) 131 | print(f"Restoring from {checkpoint_path}") 132 | session.run(tf.global_variables_initializer()) 133 | saver.restore(session, checkpoint_path) 134 | 135 | def load_lm_embeddings(self, doc_key): 136 | if self.lm_file is None: 137 | return np.zeros([0, 0, self.lm_size, self.lm_layers]) 138 | file_key = doc_key.replace("/", ":") 139 | 140 | group_cap = self.lm_file[file_key + ':cap'] 141 | num_candidates = len(list(group_cap.keys())) 142 | candidates = [group_cap[str(i)][...] 
for i in range(num_candidates)] 143 | if len(candidates) > 0: 144 | lm_emb_cap = np.zeros([len(candidates), max(c.shape[0] for c in candidates), self.lm_size, self.lm_layers]) 145 | for i, c in enumerate(candidates): 146 | lm_emb_cap[i, :c.shape[0], :, :] = c 147 | else: 148 | # to avoid empty lm_emb_cap 149 | lm_emb_cap = np.zeros([1, 1, self.lm_size, self.lm_layers]) 150 | 151 | group = self.lm_file[file_key] 152 | num_sentences = len(list(group.keys())) 153 | sentences = [group[str(i)][...] for i in range(1, num_sentences + 1)] 154 | lm_emb_dial = np.zeros([len(sentences), max(s.shape[0] for s in sentences), self.lm_size, self.lm_layers]) 155 | for i, s in enumerate(sentences): 156 | lm_emb_dial[i, :s.shape[0], :, :] = s 157 | 158 | return [lm_emb_cap, lm_emb_dial] 159 | 160 | def load_lm_embeddings_cdd(self, examples): 161 | candidates = [self.lm_cdd_file[e['doc_key']]['0'][...] for e in examples] 162 | lm_emb_cdd = np.zeros([len(candidates), max(c.shape[0] for c in candidates), self.lm_size, self.lm_layers]) 163 | for i, c in enumerate(candidates): 164 | lm_emb_cdd[i, :c.shape[0], :, :] = c 165 | 166 | return lm_emb_cdd 167 | 168 | def load_lm_embeddings_obj(self, objs): 169 | objs = [self.lm_obj_file[str(obj)]['0'][...] for obj in objs] 170 | lm_emb_objs = np.zeros([len(objs), max(c.shape[0] for c in objs), self.lm_size, self.lm_layers]) 171 | for i, c in enumerate(objs): 172 | lm_emb_objs[i, :c.shape[0], :, :] = c 173 | 174 | return lm_emb_objs 175 | 176 | def load_im_feat(self, doc_key): 177 | if self.im_fc_file is None: 178 | return np.zeros(self.im_fc_feat_size) 179 | file_key = doc_key.replace("/", ":") 180 | im_feat = self.im_fc_file[file_key][:] 181 | return im_feat 182 | 183 | def tensorize_mentions(self, mentions): 184 | if len(mentions) > 0: 185 | starts, ends = zip(*mentions) 186 | else: 187 | starts, ends = [], [] 188 | return np.array(starts), np.array(ends) 189 | 190 | def tensorize_example(self, example, is_training): 191 | clusters = example["clusters"] 192 | 193 | gold_mentions = sorted(tuple(m) for m in util.flatten(clusters)) 194 | gold_mention_map = {m:i for i,m in enumerate(gold_mentions)} 195 | cluster_ids = np.zeros(len(gold_mentions)) 196 | for cluster_id, cluster in enumerate(clusters): 197 | for mention in cluster: 198 | cluster_ids[gold_mention_map[tuple(mention)]] = cluster_id + 1 199 | 200 | sentences = example["sentences"] 201 | 202 | max_sentence_length = max(len(s) for s in sentences) 203 | max_word_length = max(max(max(len(w) for w in s) for s in sentences), max(self.config["filter_widths"])) 204 | text_len = np.array([len(s) for s in sentences]) 205 | tokens = [[""] * max_sentence_length for _ in sentences] 206 | context_word_emb = np.zeros([len(sentences), max_sentence_length, self.context_embeddings.size]) 207 | head_word_emb = np.zeros([len(sentences), max_sentence_length, self.head_embeddings.size]) 208 | char_index = np.zeros([len(sentences), max_sentence_length, max_word_length]) 209 | for i, sentence in enumerate(sentences): 210 | for j, word in enumerate(sentence): 211 | if i == 0: 212 | word = word.lower() 213 | tokens[i][j] = word 214 | context_word_emb[i, j] = self.context_embeddings[word] 215 | head_word_emb[i, j] = self.head_embeddings[word] 216 | char_index[i, j, :len(word)] = [self.char_dict[c] for c in word] 217 | tokens = np.array(tokens) 218 | 219 | if self.num_cdd_pool > 0: 220 | # random pick samples to a candidate pool of fixed number 221 | num_cdd_pick = self.num_cdd_pool - len(example["correct_caption_NPs"]) 222 | num_cdd_pick 
= max(1, num_cdd_pick) 223 | cdd_examples = [] 224 | all_sentences = list() 225 | for sent in sentences: 226 | all_sentences += sent 227 | candidate_cur = example["pronoun_info"][-1]["candidate_NPs"] 228 | candidate_cur = [' '.join(all_sentences[c[0]:c[1]+1]) for c in candidate_cur] 229 | if not is_training: 230 | sample_times = 0 231 | while len(cdd_examples) < num_cdd_pick: 232 | if not is_training: 233 | random.seed(example["doc_key"] + str(sample_times)) 234 | cdd_cur = random.choice(self.cdd_nps) 235 | sample_times += 1 236 | else: 237 | cdd_cur = random.choice(self.cdd_nps) 238 | # samples in candidate pools should not be the same as candidate nps 239 | cdd_text = ' '.join(cdd_cur["sentences"][0]).lower() 240 | repeat_flag = False 241 | for cdd in candidate_cur: 242 | if cdd.lower() == cdd_text: 243 | repeat_flag = True 244 | break 245 | if not repeat_flag: 246 | cdd_examples.append(cdd_cur) 247 | 248 | sentences_cdd = [s["sentences"][0] for s in cdd_examples] 249 | max_sentence_length_cdd = max(len(s) for s in sentences_cdd) 250 | max_word_length_cdd = max(max(max(len(w) for w in s) for s in sentences_cdd), max(self.config["filter_widths"])) 251 | text_len_cdd = np.array([len(s) for s in sentences_cdd]) 252 | context_word_emb_cdd = np.zeros([len(sentences_cdd), max_sentence_length_cdd, self.context_embeddings.size]) 253 | head_word_emb_cdd= np.zeros([len(sentences_cdd), max_sentence_length_cdd, self.head_embeddings.size]) 254 | char_index_cdd = np.zeros([len(sentences_cdd), max_sentence_length_cdd, max_word_length_cdd]) 255 | tokens_cdd = [[""] * max_sentence_length_cdd for _ in sentences_cdd] 256 | for i, sentence_cdd in enumerate(sentences_cdd): 257 | for j, word_cdd in enumerate(sentence_cdd): 258 | tokens_cdd[i][j] = word_cdd 259 | context_word_emb_cdd[i, j] = self.context_embeddings[word_cdd] 260 | head_word_emb_cdd[i, j] = self.head_embeddings[word_cdd] 261 | char_index_cdd[i, j, :len(word_cdd)] = [self.char_dict[c] for c in word_cdd] 262 | tokens_cdd = np.array(tokens_cdd) 263 | 264 | lm_emb_cdd = self.load_lm_embeddings_cdd(cdd_examples) 265 | 266 | for len_cdd in text_len_cdd: 267 | example["speakers"].append(['caption',] * len_cdd) 268 | 269 | doc_key = example["doc_key"] 270 | if self.use_im_fc: 271 | im_feat = self.load_im_feat(doc_key) 272 | 273 | if self.use_im: 274 | detections = example["object_detection"] 275 | has_obj = len(detections) > 0 276 | detections.append(0) 277 | sentences_obj = [self.id2cat[i] for i in detections] 278 | max_sentence_length_obj = max(len(s) for s in sentences_obj) 279 | max_word_length_obj = max(max(max(len(w) for w in s) for s in sentences_obj), max(self.config["filter_widths"])) 280 | text_len_obj = np.array([len(s) for s in sentences_obj]) 281 | context_word_emb_obj = np.zeros([len(sentences_obj), max_sentence_length_obj, self.context_embeddings.size]) 282 | head_word_emb_obj= np.zeros([len(sentences_obj), max_sentence_length_obj, self.head_embeddings.size]) 283 | char_index_obj = np.zeros([len(sentences_obj), max_sentence_length_obj, max_word_length_obj]) 284 | tokens_obj = [[""] * max_sentence_length_obj for _ in sentences_obj] 285 | for i, sentence_obj in enumerate(sentences_obj): 286 | for j, word_obj in enumerate(sentence_obj): 287 | tokens_obj[i][j] = word_obj 288 | context_word_emb_obj[i, j] = self.context_embeddings[word_obj] 289 | head_word_emb_obj[i, j] = self.head_embeddings[word_obj] 290 | char_index_obj[i, j, :len(word_obj)] = [self.char_dict[c] for c in word_obj] 291 | lm_emb_obj = 
self.load_lm_embeddings_obj(example["object_detection"]) 292 | 293 | speakers = util.flatten(example["speakers"]) 294 | speaker_dict = { s:i for i,s in enumerate(set(speakers)) } 295 | speaker_ids = np.array([speaker_dict[s] for s in speakers]) 296 | 297 | caption_candidates = example["correct_caption_NPs"] 298 | if len(caption_candidates) == 0: 299 | # add 1 NP to avoid empty candidates 300 | caption_candidates = [[0, 0]] 301 | candidate_starts_caption, candidate_ends_caption = self.tensorize_mentions(caption_candidates) 302 | 303 | gold_starts, gold_ends = self.tensorize_mentions(gold_mentions) 304 | 305 | lm_emb_cap, lm_emb_dial = self.load_lm_embeddings(doc_key) 306 | 307 | example_tensors = [tokens, context_word_emb, head_word_emb, lm_emb_cap, lm_emb_dial, char_index, text_len, speaker_ids, is_training, gold_starts, gold_ends, cluster_ids, candidate_starts_caption, candidate_ends_caption] 308 | example_tensors.extend([text_len_cdd, context_word_emb_cdd, head_word_emb_cdd, char_index_cdd, lm_emb_cdd, tokens_cdd]) 309 | if self.use_im: 310 | example_tensors.extend([text_len_obj, context_word_emb_obj, head_word_emb_obj, char_index_obj, lm_emb_obj, tokens_obj, has_obj]) 311 | else: 312 | example_tensors.extend([[0], np.zeros([0, 0, self.context_embeddings.size]), 313 | np.zeros([0, 0, self.head_embeddings.size]), 314 | np.zeros([0, 0, 1]), np.zeros([0, 0, self.lm_size, self.lm_layers]), 315 | np.zeros([0, 1]), False]) 316 | if self.use_im_fc: 317 | example_tensors.append(im_feat) 318 | else: 319 | example_tensors.append(np.zeros(self.config["im_fc_feat_size"])) 320 | 321 | 322 | return example_tensors 323 | 324 | def get_candidate_labels(self, candidate_starts, candidate_ends, labeled_starts, labeled_ends, labels): 325 | same_start = tf.equal(tf.expand_dims(labeled_starts, 1), tf.expand_dims(candidate_starts, 0)) # [num_labeled, num_candidates] 326 | same_end = tf.equal(tf.expand_dims(labeled_ends, 1), tf.expand_dims(candidate_ends, 0)) # [num_labeled, num_candidates] 327 | same_span = tf.logical_and(same_start, same_end) # [num_labeled, num_candidates] 328 | candidate_labels = tf.matmul(tf.expand_dims(labels, 0), tf.to_int32(same_span)) # [1, num_candidates] 329 | candidate_labels = tf.squeeze(candidate_labels, 0) # [num_candidates] 330 | return candidate_labels 331 | 332 | def get_dropout(self, dropout_rate, is_training): 333 | return 1 - (tf.to_float(is_training) * dropout_rate) 334 | 335 | def coarse_to_fine_pruning(self, top_span_emb, top_span_mention_scores, c, top_span_cdd_pool_flag=None): 336 | k = util.shape(top_span_emb, 0) 337 | top_span_range = tf.range(k) # [k] 338 | 339 | num_cdd_in_pool = tf.reduce_sum(tf.cast(top_span_cdd_pool_flag, tf.int32)) 340 | num_cdd_in_dial = k - num_cdd_in_pool 341 | top_span_range_cdd = tf.concat([tf.zeros(num_cdd_in_pool, tf.int32), tf.range(1, num_cdd_in_dial + 1)], 0) 342 | antecedent_offsets = tf.expand_dims(top_span_range_cdd, 1) - tf.expand_dims(top_span_range_cdd, 0) # [k, k] 343 | 344 | antecedents_mask = antecedent_offsets >= 1 # [k, k] 345 | fast_antecedent_scores = tf.expand_dims(top_span_mention_scores, 1) + tf.expand_dims(top_span_mention_scores, 0) # [k, k] 346 | fast_antecedent_scores += tf.log(tf.to_float(antecedents_mask)) # [k, k] 347 | fast_antecedent_scores += self.get_fast_antecedent_scores(top_span_emb) # [k, k] 348 | 349 | _, top_antecedents = tf.nn.top_k(fast_antecedent_scores, c, sorted=False) # [k, c] 350 | top_antecedents_mask = util.batch_gather(antecedents_mask, top_antecedents) # [k, c] 351 | 
top_fast_antecedent_scores = util.batch_gather(fast_antecedent_scores, top_antecedents) # [k, c] 352 | top_antecedent_offsets = util.batch_gather(antecedent_offsets, top_antecedents) # [k, c] 353 | return top_antecedents, top_antecedents_mask, top_fast_antecedent_scores, top_antecedent_offsets 354 | 355 | def distance_pruning(self, top_span_emb, top_span_mention_scores, c): 356 | k = util.shape(top_span_emb, 0) 357 | top_antecedent_offsets = tf.tile(tf.expand_dims(tf.range(c) + 1, 0), [k, 1]) # [k, c] 358 | raw_top_antecedents = tf.expand_dims(tf.range(k), 1) - top_antecedent_offsets # [k, c] 359 | top_antecedents_mask = raw_top_antecedents >= 0 # [k, c] 360 | top_antecedents = tf.maximum(raw_top_antecedents, 0) # [k, c] 361 | 362 | top_fast_antecedent_scores = tf.expand_dims(top_span_mention_scores, 1) + tf.gather(top_span_mention_scores, top_antecedents) # [k, c] 363 | top_fast_antecedent_scores += tf.log(tf.to_float(top_antecedents_mask)) # [k, c] 364 | return top_antecedents, top_antecedents_mask, top_fast_antecedent_scores, top_antecedent_offsets 365 | 366 | def get_predictions_and_loss(self, tokens, context_word_emb, head_word_emb, lm_emb_cap, lm_emb_dial, char_index, text_len, speaker_ids, is_training, gold_starts, gold_ends, cluster_ids, candidate_starts_caption, candidate_ends_caption, text_len_cdd, context_word_emb_cdd, head_word_emb_cdd, char_index_cdd, lm_emb_cdd, tokens_cdd, text_len_obj, context_word_emb_obj, head_word_emb_obj, char_index_obj, lm_emb_obj, tokens_obj, has_obj, im_feat): 367 | self.dropout = self.get_dropout(self.config["dropout_rate"], is_training) 368 | self.lexical_dropout = self.get_dropout(self.config["lexical_dropout_rate"], is_training) 369 | self.lstm_dropout = self.get_dropout(self.config["lstm_dropout_rate"], is_training) 370 | if self.use_im_fc: 371 | self.im_dropout = self.get_dropout(self.config["im_dropout_rate"], is_training) 372 | 373 | # for all sentences including caption 374 | num_sentences = tf.shape(context_word_emb)[0] 375 | max_sentence_length = tf.shape(context_word_emb)[1] 376 | 377 | context_emb_list = [context_word_emb] 378 | head_emb_list = [head_word_emb] 379 | 380 | # get char embedding by conv1d on char embeddings of each word 381 | if self.config["char_embedding_size"] > 0: 382 | char_emb_all = tf.get_variable("char_embeddings", [len(self.char_dict), self.config["char_embedding_size"]]) 383 | char_emb = tf.gather(char_emb_all, char_index) # [num_sentences, max_sentence_length, max_word_length, emb] 384 | flattened_char_emb = tf.reshape(char_emb, [num_sentences * max_sentence_length, util.shape(char_emb, 2), util.shape(char_emb, 3)]) # [num_sentences * max_sentence_length, max_word_length, emb] 385 | flattened_aggregated_char_emb = util.cnn(flattened_char_emb, self.config["filter_widths"], self.config["filter_size"]) # [num_sentences * max_sentence_length, emb] 386 | aggregated_char_emb = tf.reshape(flattened_aggregated_char_emb, [num_sentences, max_sentence_length, util.shape(flattened_aggregated_char_emb, 1)]) # [num_sentences, max_sentence_length, emb] 387 | context_emb_list.append(aggregated_char_emb) 388 | head_emb_list.append(aggregated_char_emb) 389 | 390 | # for candidate pool 391 | num_sentences_cdd = tf.shape(context_word_emb_cdd)[0] 392 | max_sentence_length_cdd = tf.shape(context_word_emb_cdd)[1] 393 | 394 | context_emb_list_cdd = [context_word_emb_cdd] 395 | head_emb_list_cdd = [head_word_emb_cdd] 396 | 397 | # get char embedding by conv1d on char embeddings of each word 398 | if self.config["char_embedding_size"] > 
0: 399 | char_emb_cdd = tf.gather(char_emb_all, char_index_cdd) # [num_sentences, max_sentence_length, max_word_length, emb] 400 | flattened_char_emb_cdd = tf.reshape(char_emb_cdd, [num_sentences_cdd * max_sentence_length_cdd, util.shape(char_emb_cdd, 2), util.shape(char_emb_cdd, 3)]) # [num_sentences * max_sentence_length, max_word_length, emb] 401 | flattened_aggregated_char_emb_cdd = util.cnn(flattened_char_emb_cdd, self.config["filter_widths"], self.config["filter_size"]) # [num_sentences * max_sentence_length, emb] 402 | aggregated_char_emb_cdd = tf.reshape(flattened_aggregated_char_emb_cdd, [num_sentences_cdd, max_sentence_length_cdd, util.shape(flattened_aggregated_char_emb_cdd, 1)]) # [num_sentences, max_sentence_length, emb] 403 | context_emb_list_cdd.append(aggregated_char_emb_cdd) 404 | head_emb_list_cdd.append(aggregated_char_emb_cdd) 405 | 406 | context_emb_cdd = tf.concat(context_emb_list_cdd, 2) # [num_sentences, max_sentence_length, emb] 407 | head_emb_cdd = tf.concat(head_emb_list_cdd, 2) # [num_sentences, max_sentence_length, emb] 408 | 409 | # extract embedding for NPs in caption here 410 | context_emb = tf.concat(context_emb_list, 2) # [num_sentences, max_sentence_length, emb] 411 | head_emb = tf.concat(head_emb_list, 2) # [num_sentences, max_sentence_length, emb] 412 | text_len_cap = candidate_ends_caption - candidate_starts_caption + 1 413 | max_span_width_cap = tf.math.reduce_max(text_len_cap) 414 | span_indices_cap = tf.expand_dims(tf.range(max_span_width_cap), 0) + tf.expand_dims(candidate_starts_caption, 1) # [num_candidates_cap, max_span_width_cap] 415 | span_indices_cap = tf.minimum(text_len[0] - 1, span_indices_cap) # [num_candidates_cap, max_span_width_cap] 416 | context_emb_cap = tf.gather(context_emb[0], span_indices_cap) # [num_candidates_cap, max_span_width_cap, emb] 417 | head_emb_cap = tf.gather(head_emb[0], span_indices_cap) # [num_candidates_cap, max_span_width_cap, emb] 418 | 419 | # project lm_num_layer to 1 and scale 420 | lm_emb_size = util.shape(lm_emb_dial, 2) 421 | lm_num_layers = util.shape(lm_emb_dial, 3) 422 | # for sentences in dialog only 423 | num_sentences_dial = util.shape(lm_emb_dial, 0) 424 | max_sentence_length_dial = util.shape(lm_emb_dial, 1) 425 | # for caption 426 | num_candidates_cap = util.shape(lm_emb_cap, 0) 427 | max_candidate_length_cap = util.shape(lm_emb_cap, 1) 428 | # get projection and scaling parameter 429 | with tf.variable_scope("lm_aggregation"): 430 | self.lm_weights = tf.nn.softmax(tf.get_variable("lm_scores", [lm_num_layers], initializer=tf.constant_initializer(0.0))) 431 | self.lm_scaling = tf.get_variable("lm_scaling", [], initializer=tf.constant_initializer(1.0)) 432 | # for lm emb of cap 433 | flattened_lm_emb_cap = tf.reshape(lm_emb_cap, [num_candidates_cap * max_candidate_length_cap * lm_emb_size, lm_num_layers]) 434 | flattened_aggregated_lm_emb_cap = tf.matmul(flattened_lm_emb_cap, tf.expand_dims(self.lm_weights, 1)) # [num_candidates_cap * max_candidate_length_cap * emb, 1] 435 | aggregated_lm_emb_cap = tf.reshape(flattened_aggregated_lm_emb_cap, [num_candidates_cap, max_candidate_length_cap, lm_emb_size]) 436 | aggregated_lm_emb_cap *= self.lm_scaling 437 | # for lm emb of dial 438 | flattened_lm_emb_dial = tf.reshape(lm_emb_dial, [num_sentences_dial * max_sentence_length_dial * lm_emb_size, lm_num_layers]) 439 | flattened_aggregated_lm_emb_dial = tf.matmul(flattened_lm_emb_dial, tf.expand_dims(self.lm_weights, 1)) # [num_sentences_dial * max_sentence_length_dial * emb, 1] 440 | aggregated_lm_emb_dial = 
tf.reshape(flattened_aggregated_lm_emb_dial, [num_sentences_dial, max_sentence_length_dial, lm_emb_size]) 441 | aggregated_lm_emb_dial *= self.lm_scaling 442 | # for lm emb of cdd 443 | num_candidates_cdd = util.shape(lm_emb_cdd, 0) 444 | max_candidate_length_cdd = util.shape(lm_emb_cdd, 1) 445 | flattened_lm_emb_cdd = tf.reshape(lm_emb_cdd, [num_candidates_cdd * max_candidate_length_cdd * lm_emb_size, lm_num_layers]) 446 | flattened_aggregated_lm_emb_cdd = tf.matmul(flattened_lm_emb_cdd, tf.expand_dims(self.lm_weights, 1)) # [num_candidates_cdd * max_candidate_length_cdd * emb, 1] 447 | aggregated_lm_emb_cdd = tf.reshape(flattened_aggregated_lm_emb_cdd, [num_candidates_cdd, max_candidate_length_cdd, lm_emb_size]) 448 | aggregated_lm_emb_cdd *= self.lm_scaling 449 | 450 | context_emb_dial = tf.concat([context_emb[1:, :max_sentence_length_dial], aggregated_lm_emb_dial], 2) # [num_sentences_dial, max_sentence_length_dial, emb] 451 | context_emb_cap = tf.concat([context_emb_cap, aggregated_lm_emb_cap], 2) # [num_candidates_cap, max_candidate_length_cap, emb] 452 | 453 | context_emb_dial = tf.nn.dropout(context_emb_dial, self.lexical_dropout) # [num_sentences_dial, max_sentence_length_dial, emb] 454 | context_emb_cap = tf.nn.dropout(context_emb_cap, self.lexical_dropout) # [num_candidates_cap, max_candidate_length_cap, emb] 455 | head_emb_cap = tf.nn.dropout(head_emb_cap, self.lexical_dropout) # [num_candidates_cap, max_candidate_length_cap, emb] 456 | 457 | context_emb_cdd = tf.concat([context_emb_cdd, aggregated_lm_emb_cdd], 2) # [num_candidates_cdd, max_candidate_length_cdd, emb] 458 | context_emb_cdd = tf.nn.dropout(context_emb_cdd, self.lexical_dropout) # [num_candidates_cdd, max_candidate_length_cdd, emb] 459 | head_emb_cdd = tf.nn.dropout(head_emb_cdd, self.lexical_dropout) # [num_candidates_cdd, max_candidate_length_cdd, emb] 460 | 461 | # len mask for caption and dialog 462 | text_len_dial = text_len[1:] 463 | text_len_mask_dial = tf.sequence_mask(text_len_dial, maxlen=max_sentence_length_dial) # [num_sentence_dial, max_sentence_length_dial] 464 | 465 | # extract lstm feature for cap and dial, and flatten to only valid words for dial 466 | context_outputs_cap = self.lstm_contextualize(context_emb_cap, text_len_cap) # [num_candidates_cap, max_candidate_length_cap, emb] 467 | context_outputs_dial = self.lstm_contextualize(context_emb_dial, text_len_dial, text_len_mask_dial) # [num_words_dial, emb] 468 | num_words_dial = util.shape(context_outputs_dial, 0) 469 | num_words = tf.reduce_sum(text_len) 470 | context_outputs = tf.concat([tf.zeros([num_words - num_words_dial, util.shape(context_outputs_dial, 1)]), context_outputs_dial], 0) # [num_words, emb] 471 | context_outputs_cdd = self.lstm_contextualize(context_emb_cdd, text_len_cdd) # [num_candidates_cdd, max_candidate_length_cdd, emb] 472 | 473 | # flatten head embedding of only valid words 474 | sentence_indices = tf.tile(tf.expand_dims(tf.range(num_sentences), 1), [1, max_sentence_length]) # [num_sentences, max_sentence_length] 475 | text_len_mask = tf.sequence_mask(text_len, maxlen=max_sentence_length) # [num_sentence, max_sentence_length] 476 | flattened_sentence_indices = self.flatten_emb_by_sentence(sentence_indices, text_len_mask) # [num_words] 477 | flattened_head_emb = self.flatten_emb_by_sentence(head_emb, text_len_mask) # [num_words] 478 | 479 | candidate_starts = tf.tile(tf.expand_dims(tf.range(num_words), 1), [1, self.max_span_width]) # [num_words, max_span_width] 480 | candidate_ends = candidate_starts + 
tf.expand_dims(tf.range(self.max_span_width), 0) # [num_words, max_span_width] 481 | candidate_start_sentence_indices = tf.gather(flattened_sentence_indices, candidate_starts) # [num_words, max_span_width] 482 | candidate_end_sentence_indices = tf.gather(flattened_sentence_indices, tf.minimum(candidate_ends, num_words - 1)) # [num_words, max_span_width] 483 | candidate_mask = tf.logical_and(candidate_ends < num_words, tf.equal(candidate_start_sentence_indices, candidate_end_sentence_indices)) # [num_words, max_span_width] 484 | # keep candidates in dialog, exclude those in caption 485 | candidate_mask_dial = tf.logical_and(candidate_mask, candidate_starts >= text_len[0]) # [num_words, max_span_width] 486 | flattened_candidate_mask_dial = tf.reshape(candidate_mask_dial, [-1]) # [num_words * max_span_width] 487 | 488 | candidate_starts_dial = tf.boolean_mask(tf.reshape(candidate_starts, [-1]), flattened_candidate_mask_dial) # [num_candidates_dial] 489 | candidate_ends_dial = tf.boolean_mask(tf.reshape(candidate_ends, [-1]), flattened_candidate_mask_dial) # [num_candidates_dial] 490 | 491 | candidate_span_emb_dial = self.get_span_emb_dial(flattened_head_emb, context_outputs, candidate_starts_dial, candidate_ends_dial) # [num_candidates, emb] 492 | 493 | # get span emb of candidates in caption 494 | candidate_span_emb_cap = self.get_span_emb_phrases(head_emb_cap, context_outputs_cap, candidate_starts_caption, candidate_ends_caption) # [num_candidates, emb] 495 | 496 | candidate_ends_cdd = tf.cumsum(text_len_cdd) + num_words - 1 497 | candidate_starts_cdd = candidate_ends_cdd - text_len_cdd + 1 498 | candidate_span_emb_cdd = self.get_span_emb_phrases(head_emb_cdd, context_outputs_cdd, candidate_starts_cdd, candidate_ends_cdd) # [num_candidates, emb] 499 | 500 | if self.use_im: 501 | num_sentences_obj = tf.shape(context_word_emb_obj)[0] 502 | max_sentence_length_obj = tf.shape(context_word_emb_obj)[1] 503 | 504 | context_emb_list_obj = [context_word_emb_obj] 505 | head_emb_list_obj = [head_word_emb_obj] 506 | 507 | # get char embedding by conv1d on char embeddings of each word 508 | if self.config["char_embedding_size"] > 0: 509 | char_emb_obj = tf.gather(char_emb_all, char_index_obj) # [num_sentences, max_sentence_length, max_word_length, emb] 510 | flattened_char_emb_obj = tf.reshape(char_emb_obj, [num_sentences_obj * max_sentence_length_obj, util.shape(char_emb_obj, 2), util.shape(char_emb_obj, 3)]) # [num_sentences * max_sentence_length, max_word_length, emb] 511 | flattened_aggregated_char_emb_obj = util.cnn(flattened_char_emb_obj, self.config["filter_widths"], self.config["filter_size"]) # [num_sentences * max_sentence_length, emb] 512 | aggregated_char_emb_obj = tf.reshape(flattened_aggregated_char_emb_obj, [num_sentences_obj, max_sentence_length_obj, util.shape(flattened_aggregated_char_emb_obj, 1)]) # [num_sentences, max_sentence_length, emb] 513 | context_emb_list_obj.append(aggregated_char_emb_obj) 514 | head_emb_list_obj.append(aggregated_char_emb_obj) 515 | 516 | context_emb_obj = tf.concat(context_emb_list_obj, 2) # [num_sentences, max_sentence_length, emb] 517 | head_emb_obj = tf.concat(head_emb_list_obj, 2) # [num_sentences, max_sentence_length, emb] 518 | 519 | num_candidates_obj = util.shape(lm_emb_obj, 0) 520 | max_candidate_length_obj = util.shape(lm_emb_obj, 1) 521 | flattened_lm_emb_obj = tf.reshape(lm_emb_obj, [num_candidates_obj * max_candidate_length_obj * lm_emb_size, lm_num_layers]) 522 | flattened_aggregated_lm_emb_obj = tf.matmul(flattened_lm_emb_obj, 
tf.expand_dims(self.lm_weights, 1)) # [num_candidates_obj * max_candidate_length_obj * emb, 1] 523 | aggregated_lm_emb_obj = tf.reshape(flattened_aggregated_lm_emb_obj, [num_candidates_obj, max_candidate_length_obj, lm_emb_size]) 524 | aggregated_lm_emb_obj *= self.lm_scaling 525 | 526 | context_emb_obj = tf.concat([context_emb_obj, aggregated_lm_emb_obj], 2) # [num_candidates_obj, max_candidate_length_obj, emb] 527 | context_emb_obj = tf.nn.dropout(context_emb_obj, self.lexical_dropout) # [num_candidates_obj, max_candidate_length_obj, emb] 528 | head_emb_obj = tf.nn.dropout(head_emb_obj, self.lexical_dropout) # [num_candidates_obj, max_candidate_length_obj, emb] 529 | 530 | context_outputs_obj = self.lstm_contextualize(context_emb_obj, text_len_obj) # [num_candidates_obj, max_candidate_length_obj, emb] 531 | 532 | candidate_ends_obj = tf.cumsum(text_len_obj) - 1 533 | candidate_starts_obj = candidate_ends_obj - text_len_obj + 1 534 | obj_span_emb = self.get_span_emb_phrases(head_emb_obj, context_outputs_obj, candidate_starts_obj, candidate_ends_obj) # [num_candidates, emb] 535 | 536 | # concat candidates in caption here 537 | candidate_starts = tf.concat([candidate_starts_cdd, candidate_starts_caption, candidate_starts_dial], 0) 538 | candidate_ends = tf.concat([candidate_ends_cdd, candidate_ends_caption, candidate_ends_dial], 0) 539 | candidate_span_emb = tf.concat([candidate_span_emb_cdd, candidate_span_emb_cap, candidate_span_emb_dial], 0) # [num_candidates, emb] 540 | candidate_cluster_ids_cap = self.get_candidate_labels(candidate_starts_caption, candidate_ends_caption, gold_starts, gold_ends, cluster_ids) 541 | candidate_cluster_ids_dial = self.get_candidate_labels(candidate_starts_dial, candidate_ends_dial, gold_starts, gold_ends, cluster_ids) 542 | candidate_cluster_ids = tf.concat([tf.zeros([util.shape(candidate_starts_cdd, 0)], tf.int32), candidate_cluster_ids_cap, candidate_cluster_ids_dial], 0) # [num_candidates] 543 | candidate_pool_flag = tf.cast(tf.concat([tf.ones(util.shape(candidate_starts_cdd, 0) + util.shape(candidate_starts_caption, 0), tf.int32), tf.zeros(util.shape(candidate_starts_dial, 0), tf.int32)], 0), tf.bool) 544 | 545 | candidate_mention_scores = self.get_mention_scores(candidate_span_emb) # [k, 1] 546 | candidate_mention_scores = tf.squeeze(candidate_mention_scores, 1) # [k] 547 | 548 | k = tf.minimum(tf.to_int32(tf.floor(tf.to_float(util.shape(candidate_starts, 0)) * self.config["top_span_ratio"])), tf.shape(candidate_mention_scores)[0]) 549 | top_span_indices = coref_ops.extract_spans(tf.expand_dims(candidate_mention_scores, 0), 550 | tf.expand_dims(candidate_starts, 0), 551 | tf.expand_dims(candidate_ends, 0), 552 | tf.expand_dims(k, 0), 553 | util.shape(candidate_mention_scores, 0), 554 | True) # [1, k] 555 | top_span_indices.set_shape([1, None]) 556 | top_span_indices = tf.squeeze(top_span_indices, 0) # [k] 557 | # coref_ops add extra 0 to top_span_indices, have to remove it here 558 | first_index = tf.gather(top_span_indices, tf.constant([0])) 559 | valid_indices = tf.boolean_mask(top_span_indices, tf.logical_not(tf.equal(top_span_indices, first_index))) 560 | top_span_indices = tf.concat([first_index, valid_indices], 0) 561 | k = util.shape(top_span_indices, 0) 562 | 563 | # rearrange top_span to put cdd and cap first 564 | top_span_cdd_pool_flag = tf.gather(candidate_pool_flag, top_span_indices) # [k] 565 | top_span_indices_cdd_cap = tf.boolean_mask(top_span_indices, top_span_cdd_pool_flag) 566 | top_span_indices_dial = 
tf.boolean_mask(top_span_indices, tf.logical_not(top_span_cdd_pool_flag)) 567 | top_span_indices = tf.concat([top_span_indices_cdd_cap, top_span_indices_dial], 0) 568 | 569 | top_span_starts = tf.gather(candidate_starts, top_span_indices) # [k] 570 | top_span_ends = tf.gather(candidate_ends, top_span_indices) # [k] 571 | top_span_emb = tf.gather(candidate_span_emb, top_span_indices) # [k, emb] 572 | top_span_cluster_ids = tf.gather(candidate_cluster_ids, top_span_indices) # [k] 573 | top_span_mention_scores = tf.gather(candidate_mention_scores, top_span_indices) # [k] 574 | top_span_speaker_ids = tf.gather(speaker_ids, top_span_starts) # [k] 575 | 576 | top_span_cdd_pool_flag = tf.gather(candidate_pool_flag, top_span_indices) # [k] 577 | 578 | c = tf.minimum(self.config["max_top_antecedents"], k) 579 | 580 | top_antecedents, top_antecedents_mask, top_fast_antecedent_scores, top_antecedent_offsets = self.coarse_to_fine_pruning(top_span_emb, top_span_mention_scores, c, top_span_cdd_pool_flag) 581 | 582 | dummy_scores = tf.zeros([k, 1]) # [k, 1] 583 | for i in range(self.config["coref_depth"]): 584 | if self.use_im: 585 | att_grid = self.get_span_im_emb(top_span_emb, obj_span_emb) # [k, emb], [k, emb] 586 | with tf.variable_scope("coref_layer", reuse=(i > 0)): 587 | top_antecedent_emb = tf.gather(top_span_emb, top_antecedents) # [k, c, emb] 588 | if self.use_im: 589 | top_antecedent_scores_text, top_antecedent_scores_im = self.get_slow_antecedent_scores(top_span_emb, top_antecedents, top_antecedent_emb, top_antecedent_offsets, top_span_speaker_ids, im_feat, att_grid, has_obj) # [k, c] 590 | top_antecedent_scores = top_fast_antecedent_scores + (1 - self.vis_weight) * top_antecedent_scores_text + self.vis_weight * top_antecedent_scores_im 591 | else: 592 | top_antecedent_scores = top_fast_antecedent_scores + self.get_slow_antecedent_scores(top_span_emb, top_antecedents, top_antecedent_emb, top_antecedent_offsets, top_span_speaker_ids, im_feat) # [k, c] 593 | top_antecedent_weights = tf.nn.softmax(tf.concat([dummy_scores, top_antecedent_scores], 1)) # [k, c + 1] 594 | top_antecedent_emb = tf.concat([tf.expand_dims(top_span_emb, 1), top_antecedent_emb], 1) # [k, c + 1, emb] 595 | attended_span_emb = tf.reduce_sum(tf.expand_dims(top_antecedent_weights, 2) * top_antecedent_emb, 1) # [k, emb] 596 | with tf.variable_scope("f"): 597 | f = tf.sigmoid(util.projection(tf.concat([top_span_emb, attended_span_emb], 1), util.shape(top_span_emb, -1))) # [k, emb] 598 | top_span_emb = f * attended_span_emb + (1 - f) * top_span_emb # [k, emb] 599 | 600 | top_antecedent_scores = tf.concat([dummy_scores, top_antecedent_scores], 1) # [k, c + 1] 601 | 602 | top_antecedent_cluster_ids = tf.gather(top_span_cluster_ids, top_antecedents) # [k, c] 603 | top_antecedent_cluster_ids += tf.to_int32(tf.log(tf.to_float(top_antecedents_mask))) # [k, c] 604 | same_cluster_indicator = tf.equal(top_antecedent_cluster_ids, tf.expand_dims(top_span_cluster_ids, 1)) # [k, c] 605 | non_dummy_indicator = tf.expand_dims(top_span_cluster_ids > 0, 1) # [k, 1] 606 | pairwise_labels = tf.logical_and(same_cluster_indicator, non_dummy_indicator) # [k, c] 607 | dummy_labels = tf.logical_not(tf.reduce_any(pairwise_labels, 1, keepdims=True)) # [k, 1] 608 | top_antecedent_labels = tf.concat([dummy_labels, pairwise_labels], 1) # [k, c + 1] 609 | 610 | loss = self.softmax_loss(top_antecedent_scores, top_antecedent_labels) # [k] 611 | loss = tf.reduce_sum(loss) # [] 612 | 613 | outputs = [candidate_starts, candidate_ends, candidate_mention_scores, 
top_span_starts, top_span_ends, top_antecedents, top_antecedent_scores, 614 | tokens_cdd, tokens_obj] 615 | if self.use_im: 616 | outputs.append(att_grid) 617 | else: 618 | outputs.append(tf.zeros([1, 1])) 619 | 620 | return outputs, loss 621 | 622 | def get_span_emb_dial(self, head_emb, context_outputs, span_starts, span_ends): 623 | span_emb_list = [] 624 | 625 | span_start_emb = tf.gather(context_outputs, span_starts) # [k, emb] 626 | span_emb_list.append(span_start_emb) 627 | 628 | span_end_emb = tf.gather(context_outputs, span_ends) # [k, emb] 629 | span_emb_list.append(span_end_emb) 630 | 631 | span_width = 1 + span_ends - span_starts # [k] 632 | 633 | if self.config["use_features"]: 634 | span_width_index = span_width - 1 # [k] 635 | with tf.variable_scope("use_feature", reuse=tf.AUTO_REUSE): 636 | span_width_emb = tf.gather(tf.get_variable("span_width_embeddings", [self.config["max_span_width"], self.config["feature_size"]]), span_width_index) # [k, emb] 637 | span_width_emb = tf.nn.dropout(span_width_emb, self.dropout) 638 | span_emb_list.append(span_width_emb) 639 | 640 | if self.config["model_heads"]: 641 | span_indices = tf.expand_dims(tf.range(self.config["max_span_width"]), 0) + tf.expand_dims(span_starts, 1) # [k, max_span_width] 642 | span_indices = tf.minimum(util.shape(context_outputs, 0) - 1, span_indices) # [k, max_span_width] 643 | span_text_emb = tf.gather(head_emb, span_indices) # [k, max_span_width, emb] 644 | with tf.variable_scope("head_scores", reuse=tf.AUTO_REUSE): 645 | self.head_scores = util.projection(context_outputs, 1) # [num_words, 1] 646 | span_head_scores = tf.gather(self.head_scores, span_indices) # [k, max_span_width, 1] 647 | span_mask = tf.expand_dims(tf.sequence_mask(span_width, self.config["max_span_width"], dtype=tf.float32), 2) # [k, max_span_width, 1] 648 | span_head_scores += tf.log(span_mask) # [k, max_span_width, 1] 649 | span_attention = tf.nn.softmax(span_head_scores, 1) # [k, max_span_width, 1] 650 | span_head_emb = tf.reduce_sum(span_attention * span_text_emb, 1) # [k, emb] 651 | span_emb_list.append(span_head_emb) 652 | 653 | span_emb = tf.concat(span_emb_list, 1) # [k, emb] 654 | return span_emb # [k, emb] 655 | 656 | def get_span_emb_phrases(self, head_emb, context_outputs, span_starts, span_ends): 657 | # context_outputs: [num_candidates_cap, max_candidate_length_cap, emb] 658 | # head_emb [num_candidates_cap, max_span_width_cap, emb] 659 | span_emb_list = [] 660 | num_candidates = util.shape(context_outputs, 0) 661 | 662 | span_width = 1 + span_ends - span_starts # [num_candidates_cap] 663 | max_span_width = util.shape(context_outputs, 1) 664 | context_emb_size = util.shape(context_outputs, 2) 665 | 666 | context_outputs = tf.reshape(context_outputs, [-1, context_emb_size]) # [num_candidates_cap * max_candidate_length_cap, emb] 667 | span_start_indices = tf.range(num_candidates) * max_span_width # [num_candidates_cap] 668 | span_start_emb = tf.gather(context_outputs, span_start_indices) # [num_candidates_cap, emb] 669 | span_emb_list.append(span_start_emb) 670 | 671 | span_end_indices = span_start_indices + span_width - 1 # [num_candidates_cap] 672 | span_end_emb = tf.gather(context_outputs, span_end_indices) # [num_candidates_cap, emb] 673 | span_emb_list.append(span_end_emb) 674 | 675 | if self.config["use_features"]: 676 | span_width_index = span_width - 1 # [k] 677 | with tf.variable_scope("use_feature", reuse=tf.AUTO_REUSE): 678 | span_width_emb = tf.gather(tf.get_variable("span_width_embeddings", 
[self.config["max_span_width"], self.config["feature_size"]]), span_width_index) # [k, emb] 679 | span_width_emb = tf.nn.dropout(span_width_emb, self.dropout) 680 | span_emb_list.append(span_width_emb) 681 | 682 | if self.config["model_heads"]: 683 | with tf.variable_scope("head_scores", reuse=tf.AUTO_REUSE): 684 | span_head_scores = util.projection(context_outputs, 1) # [num_candidates_cap * max_span_width, 1] 685 | span_head_scores = tf.reshape(span_head_scores, [num_candidates, max_span_width, 1]) 686 | span_mask = tf.expand_dims(tf.sequence_mask(span_width, max_span_width, dtype=tf.float32), 2) # [k, max_span_width, 1] 687 | span_head_scores += tf.log(span_mask) # [k, max_span_width, 1] 688 | span_attention = tf.nn.softmax(span_head_scores, 1) # [k, max_span_width, 1] 689 | span_head_emb = tf.reduce_sum(span_attention * head_emb, 1) # [k, emb] 690 | span_emb_list.append(span_head_emb) 691 | 692 | span_emb = tf.concat(span_emb_list, 1) # [k, emb] 693 | return span_emb # [k, emb] 694 | 695 | def get_span_im_emb(self, span_emb, obj_span_emb): 696 | k = util.shape(span_emb, 0) 697 | n = util.shape(obj_span_emb, 0) 698 | with tf.variable_scope("image_attention", reuse=tf.AUTO_REUSE): 699 | # span_emb: [k, emb] 700 | map_dim = self.im_emb_size 701 | with tf.variable_scope("att_projection0"): 702 | text_map = util.projection(span_emb, map_dim) # [k, map_dim] 703 | obj_map = util.projection(obj_span_emb, map_dim) # [k, map_dim] 704 | text_map = tf.nn.relu(text_map) 705 | obj_map = tf.nn.relu(obj_map) 706 | 707 | text_map = tf.tile(tf.expand_dims(text_map, 1), [1, n, 1]) # [k, n, map_dim] 708 | obj_map = tf.tile(tf.expand_dims(obj_map, 0), [k, 1, 1]) # [k, n, map_dim] 709 | 710 | # interact via element wise map 711 | text_obj_combine = tf.nn.l2_normalize(text_map * obj_map, 2) # [k, n, map_dim] 712 | with tf.variable_scope("get_attention"): 713 | w_att = tf.get_variable('w_att', [map_dim, 1], initializer=tf.contrib.layers.xavier_initializer()) 714 | att_grid = tf.reshape(tf.matmul(tf.reshape(text_obj_combine, [-1, map_dim]), w_att), [k, n]) # [k, n] 715 | 716 | # softmax 717 | att_grid_soft = tf.nn.softmax(att_grid) # [k, n] 718 | 719 | return att_grid_soft # [k, n] 720 | 721 | def get_mention_scores(self, span_emb): 722 | with tf.variable_scope("mention_scores"): 723 | return util.ffnn(span_emb, self.config["ffnn_depth"], self.config["ffnn_size"], 1, self.dropout) # [k, 1] 724 | 725 | def softmax_loss(self, antecedent_scores, antecedent_labels): 726 | gold_scores = antecedent_scores + tf.log(tf.to_float(antecedent_labels)) # [k, max_ant + 1] 727 | marginalized_gold_scores = tf.reduce_logsumexp(gold_scores, [1]) # [k] 728 | log_norm = tf.reduce_logsumexp(antecedent_scores, [1]) # [k] 729 | return log_norm - marginalized_gold_scores # [k] 730 | 731 | def bucket_distance(self, distances): 732 | """ 733 | Places the given values (designed for distances) into 10 semi-logscale buckets: 734 | [0, 1, 2, 3, 4, 5-7, 8-15, 16-31, 32-63, 64+]. 
735 | """ 736 | logspace_idx = tf.to_int32(tf.floor(tf.log(tf.to_float(distances))/math.log(2))) + 3 737 | use_identity = tf.to_int32(distances <= 4) 738 | combined_idx = use_identity * distances + (1 - use_identity) * logspace_idx 739 | return tf.clip_by_value(combined_idx, 0, 9) 740 | 741 | def get_slow_antecedent_scores(self, top_span_emb, top_antecedents, top_antecedent_emb, top_antecedent_offsets, top_span_speaker_ids, im_feat, att_grid=None, has_obj=None): 742 | k = util.shape(top_span_emb, 0) 743 | c = util.shape(top_antecedents, 1) 744 | 745 | feature_emb_list = [] 746 | 747 | if self.config["use_metadata"]: 748 | top_antecedent_speaker_ids = tf.gather(top_span_speaker_ids, top_antecedents) # [k, c] 749 | same_speaker = tf.equal(tf.expand_dims(top_span_speaker_ids, 1), top_antecedent_speaker_ids) # [k, c] 750 | speaker_pair_emb = tf.gather(tf.get_variable("same_speaker_emb", [2, self.config["feature_size"]]), tf.to_int32(same_speaker)) # [k, c, emb] 751 | feature_emb_list.append(speaker_pair_emb) 752 | 753 | if self.use_im_fc: 754 | im_emb = tf.expand_dims(im_feat, 0) 755 | im_emb = tf.nn.dropout(im_emb, self.im_dropout) 756 | if self.config["im_layer"] > 0: 757 | for i in range(self.config["im_layer"]): 758 | im_weights = tf.get_variable("im_weights_{}".format(i), [util.shape(im_emb, 1), self.im_fc_emb_size], initializer=None) 759 | im_bias = tf.get_variable("im_bias_{}".format(i), [self.im_fc_emb_size], initializer=None) 760 | im_emb = tf.nn.xw_plus_b(im_emb, im_weights, im_bias) 761 | tiled_im_emb = tf.tile(tf.expand_dims(im_emb, 0), [k, c, 1]) # [k, c, emb] 762 | feature_emb_list.append(tiled_im_emb) 763 | 764 | if self.config["use_features"]: 765 | antecedent_distance_buckets = self.bucket_distance(top_antecedent_offsets) # [k, c] 766 | antecedent_distance_emb = tf.gather(tf.get_variable("antecedent_distance_emb", [10, self.config["feature_size"]]), antecedent_distance_buckets) # [k, c] 767 | feature_emb_list.append(antecedent_distance_emb) 768 | 769 | feature_emb = tf.concat(feature_emb_list, 2) # [k, c, emb] 770 | feature_emb = tf.nn.dropout(feature_emb, self.dropout) # [k, c, emb] 771 | 772 | target_emb = tf.expand_dims(top_span_emb, 1) # [k, 1, emb=1270] 773 | similarity_emb = top_antecedent_emb * target_emb # [k, c, emb] 774 | target_emb = tf.tile(target_emb, [1, c, 1]) # [k, c, emb] 775 | 776 | pair_emb = tf.concat([target_emb, top_antecedent_emb, similarity_emb, feature_emb], 2) # [k, c, emb=3850] 777 | 778 | with tf.variable_scope("slow_antecedent_scores"): 779 | slow_antecedent_scores = util.ffnn(pair_emb, self.config["ffnn_depth"], self.config["ffnn_size"], 1, self.dropout) # [k, c, 1] 780 | slow_antecedent_scores = tf.squeeze(slow_antecedent_scores, 2) # [k, c] 781 | 782 | if self.use_im: 783 | # att max 784 | def zero_att_max(k): 785 | return tf.zeros([k, 1]) 786 | 787 | def obj_att_max(att_grid): 788 | return tf.reduce_max(att_grid[:, :-1], axis=1, keepdims=True) # [k, 1] 789 | 790 | top_span_att_max = tf.cond(has_obj, lambda: obj_att_max(att_grid), lambda: zero_att_max(k)) # [k, 1] 791 | top_antecedent_att_max = tf.gather(top_span_att_max, top_antecedents) # [k, c, 1] 792 | target_att_max = tf.expand_dims(top_span_att_max, 2) # [k, 1, 1] 793 | similarity_emb_att = top_antecedent_att_max * target_att_max # [k, c, 1] 794 | target_emb_att = tf.tile(target_att_max, [1, c, 1]) # [k, c, 1] 795 | 796 | # att similarity 797 | top_antecedent_att = tf.gather(att_grid, top_antecedents) # [k, c, n] 798 | top_span_att = tf.expand_dims(att_grid, 1) # [k, 1, n] 799 | 800 | def 
zero_ant_att_max(k, c): 801 | return tf.zeros([k, c, 1]) 802 | 803 | def obj_ant_att_max(att_grid): 804 | return tf.reduce_max(att_grid[:, :, :-1], axis=2, keepdims=True) # [k, c, 1] 805 | 806 | top_span_antecedent_att_max = tf.cond(has_obj, lambda: obj_ant_att_max(top_antecedent_att * top_span_att), lambda: zero_ant_att_max(k, c)) # [k, c, 1] 807 | 808 | similarity_emb_att = tf.concat([similarity_emb_att, top_span_antecedent_att_max], 2) # [k, c, 2] 809 | 810 | pair_emb_im = tf.concat([target_emb_att, top_antecedent_att_max, similarity_emb_att], 2) # [k, c, 4 (+3n)] 811 | 812 | with tf.variable_scope("slow_antecedent_scores_im"): 813 | slow_antecedent_scores_im = util.ffnn(pair_emb_im, self.config["ffnn_depth_im"], self.config["ffnn_size_im"], 1, self.dropout) # [k, c, 1] 814 | slow_antecedent_scores_im = tf.squeeze(slow_antecedent_scores_im, 2) # [k, c] 815 | 816 | return slow_antecedent_scores, slow_antecedent_scores_im # [k, c] 817 | 818 | return slow_antecedent_scores # [k, c] 819 | 820 | def get_fast_antecedent_scores(self, top_span_emb): 821 | with tf.variable_scope("src_projection"): 822 | source_top_span_emb = tf.nn.dropout(util.projection(top_span_emb, util.shape(top_span_emb, -1)), self.dropout) # [k, emb] 823 | target_top_span_emb = tf.nn.dropout(top_span_emb, self.dropout) # [k, emb] 824 | return tf.matmul(source_top_span_emb, target_top_span_emb, transpose_b=True) # [k, k] 825 | 826 | def flatten_emb_by_sentence(self, emb, text_len_mask): 827 | num_sentences = tf.shape(emb)[0] 828 | max_sentence_length = tf.shape(emb)[1] 829 | 830 | emb_rank = len(emb.get_shape()) 831 | if emb_rank == 2: 832 | flattened_emb = tf.reshape(emb, [num_sentences * max_sentence_length]) 833 | elif emb_rank == 3: 834 | flattened_emb = tf.reshape(emb, [num_sentences * max_sentence_length, util.shape(emb, 2)]) 835 | else: 836 | raise ValueError("Unsupported rank: {}".format(emb_rank)) 837 | return tf.boolean_mask(flattened_emb, tf.reshape(text_len_mask, [num_sentences * max_sentence_length])) 838 | 839 | def lstm_contextualize(self, text_emb, text_len, text_len_mask=None): 840 | num_sentences = tf.shape(text_emb)[0] 841 | current_inputs = text_emb # [num_sentences, max_sentence_length, emb] 842 | 843 | for layer in range(self.config["contextualization_layers"]): 844 | with tf.variable_scope("layer_{}".format(layer), reuse=tf.AUTO_REUSE): 845 | with tf.variable_scope("fw_cell", reuse=tf.AUTO_REUSE): 846 | cell_fw = util.CustomLSTMCell(self.config["contextualization_size"], num_sentences, self.lstm_dropout) 847 | with tf.variable_scope("bw_cell", reuse=tf.AUTO_REUSE): 848 | cell_bw = util.CustomLSTMCell(self.config["contextualization_size"], num_sentences, self.lstm_dropout) 849 | state_fw = tf.contrib.rnn.LSTMStateTuple(tf.tile(cell_fw.initial_state.c, [num_sentences, 1]), tf.tile(cell_fw.initial_state.h, [num_sentences, 1])) 850 | state_bw = tf.contrib.rnn.LSTMStateTuple(tf.tile(cell_bw.initial_state.c, [num_sentences, 1]), tf.tile(cell_bw.initial_state.h, [num_sentences, 1])) 851 | 852 | (fw_outputs, bw_outputs), _ = tf.nn.bidirectional_dynamic_rnn( 853 | cell_fw=cell_fw, 854 | cell_bw=cell_bw, 855 | inputs=current_inputs, 856 | sequence_length=text_len, 857 | initial_state_fw=state_fw, 858 | initial_state_bw=state_bw) 859 | 860 | text_outputs = tf.concat([fw_outputs, bw_outputs], 2) # [num_sentences, max_sentence_length, emb] 861 | text_outputs = tf.nn.dropout(text_outputs, self.lstm_dropout) 862 | if layer > 0: 863 | highway_gates = tf.sigmoid(util.projection(text_outputs, util.shape(text_outputs, 
2))) # [num_sentences, max_sentence_length, emb] 864 | text_outputs = highway_gates * text_outputs + (1 - highway_gates) * current_inputs 865 | current_inputs = text_outputs # [num_sentences, max_sentence_length, emb] 866 | 867 | if text_len_mask is None: 868 | return text_outputs 869 | else: 870 | return self.flatten_emb_by_sentence(text_outputs, text_len_mask) 871 | 872 | def get_predicted_antecedents(self, antecedents, antecedent_scores): 873 | predicted_antecedents = [] 874 | for i, index in enumerate(np.argmax(antecedent_scores, axis=1) - 1): 875 | if index < 0: 876 | predicted_antecedents.append(-1) 877 | else: 878 | predicted_antecedents.append(antecedents[i, index]) 879 | return predicted_antecedents 880 | 881 | def get_predicted_clusters(self, top_span_starts, top_span_ends, predicted_antecedents): 882 | mention_to_predicted = {} 883 | predicted_clusters = [] 884 | for i, predicted_index in enumerate(predicted_antecedents): 885 | if predicted_index < 0: 886 | continue 887 | assert i > predicted_index 888 | predicted_antecedent = (int(top_span_starts[predicted_index]), int(top_span_ends[predicted_index])) 889 | if predicted_antecedent in mention_to_predicted: 890 | predicted_cluster = mention_to_predicted[predicted_antecedent] 891 | else: 892 | predicted_cluster = len(predicted_clusters) 893 | predicted_clusters.append([predicted_antecedent]) 894 | mention_to_predicted[predicted_antecedent] = predicted_cluster 895 | 896 | mention = (int(top_span_starts[i]), int(top_span_ends[i])) 897 | predicted_clusters[predicted_cluster].append(mention) 898 | mention_to_predicted[mention] = predicted_cluster 899 | 900 | predicted_clusters = [tuple(pc) for pc in predicted_clusters] 901 | mention_to_predicted = { m:predicted_clusters[i] for m,i in mention_to_predicted.items() } 902 | 903 | return predicted_clusters, mention_to_predicted 904 | 905 | def get_predicted_clusters_attention(self, top_span_starts, top_span_ends, att_grid, predicted_antecedents): 906 | mention_to_predicted = {} 907 | predicted_clusters = [] 908 | for i, predicted_index in enumerate(predicted_antecedents): 909 | if predicted_index < 0: 910 | continue 911 | assert i > predicted_index 912 | predicted_antecedent = (int(top_span_starts[predicted_index]), int(top_span_ends[predicted_index])) 913 | 914 | if predicted_antecedent in mention_to_predicted: 915 | predicted_cluster = mention_to_predicted[predicted_antecedent] 916 | else: 917 | predicted_cluster = len(predicted_clusters) 918 | predicted_clusters.append([predicted_antecedent]) 919 | mention_to_predicted[predicted_antecedent] = predicted_cluster 920 | 921 | mention = (int(top_span_starts[i]), int(top_span_ends[i])) 922 | predicted_clusters[predicted_cluster].append(mention) 923 | mention_to_predicted[mention] = predicted_cluster 924 | 925 | predicted_clusters = [tuple(pc) for pc in predicted_clusters] 926 | mention_to_predicted = { m:predicted_clusters[i] for m,i in mention_to_predicted.items() } 927 | 928 | # att_grid is the same order as top_span, extract them for each mention in predicted_clusters 929 | predicted_att_grids = [] 930 | for cluster in predicted_clusters: 931 | att_grid_cluster = [] 932 | for mention in cluster: 933 | find_mention = False 934 | for index, (start, end) in enumerate(zip(top_span_starts, top_span_ends)): 935 | if mention[0] == start and mention[1] == end: 936 | att_grid_cluster.append(att_grid[index]) 937 | find_mention = True 938 | break 939 | if not find_mention: 940 | raise ValueError('antecedent not found in top spans') 941 | 
predicted_att_grids.append(att_grid_cluster) 942 | 943 | return predicted_clusters, predicted_att_grids, mention_to_predicted 944 | 945 | def evaluate_coref(self, top_span_starts, top_span_ends, predicted_antecedents, gold_clusters): 946 | gold_clusters = [tuple(tuple(m) for m in gc) for gc in gold_clusters] 947 | mention_to_gold = {} 948 | for gc in gold_clusters: 949 | for mention in gc: 950 | mention_to_gold[mention] = gc 951 | 952 | predicted_clusters, mention_to_predicted = self.get_predicted_clusters(top_span_starts, top_span_ends, predicted_antecedents) 953 | return predicted_clusters 954 | 955 | def load_eval_data(self): 956 | if self.eval_data is None: 957 | def load_line(line): 958 | example = json.loads(line) 959 | return self.tensorize_example(example, is_training=False), example 960 | with open(self.config["eval_path"]) as f: 961 | self.eval_data = [load_line(l) for l in f.readlines()] 962 | num_words = sum(tensorized_example[2].sum() for tensorized_example, _ in self.eval_data) 963 | print(f"Loaded {len(self.eval_data)} eval examples.") 964 | 965 | def evaluate(self, session, official_stdout=False): 966 | self.load_eval_data() 967 | 968 | coref_predictions = {} 969 | pr_coref_evaluator = metrics.PrCorefEvaluator() 970 | 971 | for example_num, (tensorized_example, example) in enumerate(self.eval_data): 972 | feed_dict = {i:t for i,t in zip(self.input_tensors, tensorized_example)} 973 | 974 | outputs = session.run(self.predictions, feed_dict=feed_dict) 975 | candidate_starts, candidate_ends, candidate_mention_scores, top_span_starts, top_span_ends, top_antecedents, top_antecedent_scores, tokens_cdd, tokens_obj, att_grid = outputs 976 | 977 | predicted_antecedents = self.get_predicted_antecedents(top_antecedents, top_antecedent_scores) 978 | coref_predictions[example["doc_key"]] = self.evaluate_coref(top_span_starts, top_span_ends, predicted_antecedents, example["clusters"]) 979 | pr_coref_evaluator.update(coref_predictions[example["doc_key"]], example["pronoun_info"], example["sentences"], tokens_cdd) 980 | if example_num % 50 == 0: 981 | print(f"Evaluated {example_num + 1}/{len(self.eval_data)} examples.") 982 | 983 | summary_dict = {} 984 | pr_coref_results = pr_coref_evaluator.get_prf() 985 | 986 | summary_dict["Pronoun Coref average F1 (py)"] = pr_coref_results['f'] 987 | print(f"Pronoun Coref average F1 (py): {pr_coref_results['f'] * 100:.2f}%") 988 | summary_dict["Pronoun Coref average precision (py)"] = pr_coref_results['p'] 989 | print(f"Pronoun Coref average precision (py): {pr_coref_results['p'] * 100:.2f}%") 990 | summary_dict["Pronoun Coref average recall (py)"] = pr_coref_results['r'] 991 | print(f"Pronoun Coref average recall (py): {pr_coref_results['r'] * 100:.2f}%") 992 | 993 | summary_dict["Discussed Pronoun Coref average F1 (py)"] = pr_coref_results['f_discussed'] 994 | print(f"Discussed Pronoun Coref average F1 (py): {pr_coref_results['f_discussed'] * 100:.2f}%") 995 | summary_dict["Discussed Pronoun Coref average precision (py)"] = pr_coref_results['p_discussed'] 996 | print(f"Discussed Pronoun Coref average precision (py): {pr_coref_results['p_discussed'] * 100:.2f}%") 997 | summary_dict["Discussed Pronoun Coref average recall (py)"] = pr_coref_results['r_discussed'] 998 | print(f"Discussed Pronoun Coref average recall (py): {pr_coref_results['r_discussed'] * 100:.2f}%") 999 | 1000 | summary_dict["Not Discussed Pronoun Coref average F1 (py)"] = pr_coref_results['f_not_discussed'] 1001 | print(f"Not Discussed Pronoun Coref average F1 (py): 
{pr_coref_results['f_not_discussed'] * 100:.2f}%") 1002 | summary_dict["Not Discussed Pronoun Coref average precision (py)"] = pr_coref_results['p_not_discussed'] 1003 | print(f"Not Discussed Pronoun Coref average precision (py): {pr_coref_results['p_not_discussed'] * 100:.2f}%") 1004 | summary_dict["Not Discussed Pronoun Coref average recall (py)"] = pr_coref_results['r_not_discussed'] 1005 | print(f"Not Discussed Pronoun Coref average recall (py): {pr_coref_results['r_not_discussed'] * 100:.2f}%") 1006 | 1007 | average_f1 = pr_coref_results['f'] 1008 | max_eval_f1 = tf.maximum(self.max_eval_f1, average_f1) 1009 | self.update_max_f1 = tf.assign(self.max_eval_f1, max_eval_f1) 1010 | 1011 | return util.make_summary(summary_dict), average_f1 1012 | -------------------------------------------------------------------------------- /predict.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | 4 | import json 5 | import os 6 | import os.path as osp 7 | import argparse 8 | import sys 9 | import numpy as np 10 | 11 | import tensorflow as tf 12 | from model import VisCoref 13 | import util 14 | 15 | parser = argparse.ArgumentParser(description='predict coreference cluster on trained model') 16 | parser.add_argument('model', type=str, 17 | help='model name to evaluate') 18 | parser.add_argument('--step', type=str, default='max', 19 | help='global steps to restore from') 20 | parser.add_argument('--split', type=str, default='test', 21 | help='split to evaluate, test or val') 22 | parser.add_argument('--input_dir', type=str, default='data', 23 | help='input dir') 24 | parser.add_argument('--output_dir', type=str, default='output', 25 | help='output dir') 26 | 27 | 28 | class MyEncoder(json.JSONEncoder): 29 | def default(self, obj): 30 | if isinstance(obj, np.integer): 31 | return int(obj) 32 | elif isinstance(obj, np.floating): 33 | return float(obj) 34 | elif isinstance(obj, np.ndarray): 35 | return obj.tolist() 36 | else: 37 | return super(MyEncoder, self).default(obj) 38 | 39 | if __name__ == "__main__": 40 | args = parser.parse_args() 41 | if len(sys.argv) == 1: 42 | sys.argv.append(args.model) 43 | else: 44 | sys.argv[1] = args.model 45 | config = util.initialize_from_env() 46 | input_filename = args.split + '.vispro.1.1.jsonlines' 47 | output_filename = args.split + '.vispro.1.1.prediction.jsonlines' 48 | input_filename = osp.join(args.input_dir, input_filename) 49 | output_filename = osp.join(args.output_dir, args.model, output_filename) 50 | 51 | model = VisCoref(config) 52 | 53 | # Create output dir 54 | output_dir = osp.split(output_filename)[0] 55 | if not osp.exists(output_dir): 56 | os.makedirs(output_dir) 57 | 58 | configtf = tf.ConfigProto() 59 | configtf.gpu_options.allow_growth = True 60 | with tf.Session(config=configtf) as session: 61 | model.restore(session, args.step) 62 | 63 | if config["use_im"]: 64 | predicted_att_grids = {} 65 | with open(output_filename, "w") as output_file: 66 | with open(input_filename) as input_file: 67 | for example_num, line in enumerate(input_file.readlines()): 68 | example = json.loads(line) 69 | tensorized_example = model.tensorize_example(example, is_training=False) 70 | feed_dict = {i:t for i,t in zip(model.input_tensors, tensorized_example)} 71 | 72 | outputs = session.run(model.predictions, feed_dict=feed_dict) 73 | candidate_starts, candidate_ends, candidate_mention_scores, top_span_starts, top_span_ends, top_antecedents, 
top_antecedent_scores, tokens_cdd, tokens_obj, att_grid = outputs 74 | 75 | tokens_cdd_list = [] 76 | for i in range(tokens_cdd.shape[0]): 77 | cdd_np = [] 78 | for j in range(tokens_cdd.shape[1]): 79 | if tokens_cdd[i][j] != '': 80 | cdd_np.append(tokens_cdd[i][j]) 81 | tokens_cdd_list.append(cdd_np) 82 | example["cdd_sentences"] = tokens_cdd_list 83 | 84 | predicted_antecedents = model.get_predicted_antecedents(top_antecedents, top_antecedent_scores) 85 | if config["use_im"]: 86 | example["predicted_clusters"], predicted_att_grids[example["doc_key"]], _ = model.get_predicted_clusters_attention(top_span_starts, top_span_ends, att_grid, predicted_antecedents) 87 | else: 88 | example["predicted_clusters"], _ = model.get_predicted_clusters(top_span_starts, top_span_ends, predicted_antecedents) 89 | 90 | output_file.write(json.dumps(example, cls=MyEncoder)) 91 | output_file.write("\n") 92 | if example_num % 100 == 0: 93 | print(f"Decoded {example_num + 1} examples.") 94 | 95 | print(f"Output saved to {output_filename}") 96 | if config["use_im"]: 97 | output_filename = output_filename.replace('.jsonlines', '.att.npz') 98 | np.savez(output_filename, att=predicted_att_grids) 99 | print(f"Attention grids saved to {output_filename}") 100 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | tensorflow-gpu>1.13.1 2 | tensorflow-hub>0.4.0 3 | h5py 4 | nltk 5 | pyhocon 6 | scipy 7 | sklearn 8 | -------------------------------------------------------------------------------- /setup_all.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Download pretrained embeddings. 4 | curl -O https://nlp.stanford.edu/data/glove.840B.300d.zip 5 | unzip glove.840B.300d.zip -d data/ 6 | rm glove.840B.300d.zip 7 | 8 | # Build custom kernels. 
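# The two queries below read the compile and link flags of the installed
# TensorFlow, and the g++ line then compiles coref_kernels.cc into
# coref_kernels.so, the custom op library used by coref_ops.py. Pick the
# variant that matches your TensorFlow install (pip wheel, built from source,
# or macOS).
#
# Optional sanity check once the build finishes (a minimal sketch, assuming the
# .so sits in the repository root):
# python -c 'import tensorflow as tf; tf.load_op_library("./coref_kernels.so"); print("coref_kernels.so OK")'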
9 | TF_CFLAGS=( $(python -c 'import tensorflow as tf; print(" ".join(tf.sysconfig.get_compile_flags()))') ) 10 | TF_LFLAGS=( $(python -c 'import tensorflow as tf; print(" ".join(tf.sysconfig.get_link_flags()))') ) 11 | 12 | # Linux (pip) 13 | g++ -std=c++11 -shared coref_kernels.cc -o coref_kernels.so -fPIC ${TF_CFLAGS[@]} ${TF_LFLAGS[@]} -O2 -D_GLIBCXX_USE_CXX11_ABI=0 14 | 15 | # Linux (build from source) 16 | # g++ -std=c++11 -shared coref_kernels.cc -o coref_kernels.so -fPIC ${TF_CFLAGS[@]} ${TF_LFLAGS[@]} -O2 17 | 18 | # Mac 19 | #g++ -std=c++11 -shared coref_kernels.cc -o coref_kernels.so -I -fPIC ${TF_CFLAGS[@]} ${TF_LFLAGS[@]} -O2 -D_GLIBCXX_USE_CXX11_ABI=0 -undefined dynamic_lookup 20 | -------------------------------------------------------------------------------- /setup_training.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | python get_char_vocab.py 4 | python filter_embeddings.py 5 | python filter_embeddings.py --embedding glove_50_300_2.txt 6 | python cache_elmo.py --dataset vispro 7 | python cache_elmo.py --dataset vispro_cdd 8 | python cache_elmo.py --dataset vispro_mscoco -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | from __future__ import absolute_import 3 | from __future__ import division 4 | 5 | import os 6 | import time 7 | import subprocess 8 | import sys 9 | 10 | import tensorflow as tf 11 | from model import VisCoref 12 | import util 13 | 14 | def set_log_file(fname): 15 | tee = subprocess.Popen(['tee', fname], stdin=subprocess.PIPE) 16 | os.dup2(tee.stdin.fileno(), sys.stdout.fileno()) 17 | os.dup2(tee.stdin.fileno(), sys.stderr.fileno()) 18 | 19 | if __name__ == "__main__": 20 | config = util.initialize_from_env() 21 | 22 | log_dir = config["log_dir"] 23 | writer = tf.summary.FileWriter(log_dir, flush_secs=20) 24 | log_file = os.path.join(log_dir, 'train.log') 25 | set_log_file(log_file) 26 | 27 | report_frequency = config["report_frequency"] 28 | eval_frequency = config["eval_frequency"] 29 | 30 | tf.set_random_seed(config['random_seed']) 31 | 32 | model = VisCoref(config) 33 | saver = tf.train.Saver() 34 | 35 | max_f1 = 0 36 | 37 | config_tf = tf.ConfigProto() 38 | config_tf.gpu_options.allow_growth = True 39 | with tf.Session(config=config_tf) as session: 40 | session.run(tf.global_variables_initializer()) 41 | model.start_enqueue_thread(session) 42 | accumulated_loss = 0.0 43 | 44 | ckpt = tf.train.get_checkpoint_state(log_dir) 45 | if ckpt and ckpt.model_checkpoint_path: 46 | print(f"Restoring from: {ckpt.model_checkpoint_path}") 47 | saver.restore(session, ckpt.model_checkpoint_path) 48 | max_f1 = session.run(model.max_eval_f1) 49 | print(f'Restoring from max f1 of {max_f1:.2f}') 50 | 51 | initial_time = time.time() 52 | 53 | while True: 54 | tf_loss, tf_global_step, _ = session.run([model.loss, model.global_step, model.train_op]) 55 | accumulated_loss += tf_loss 56 | 57 | if tf_global_step == 1 or tf_global_step % report_frequency == 0: 58 | total_time = time.time() - initial_time 59 | steps_per_second = tf_global_step / total_time 60 | 61 | average_loss = accumulated_loss / report_frequency 62 | print(f"[{tf_global_step}] loss={average_loss:.4f}, steps/s={steps_per_second:.2f}") 63 | writer.add_summary(util.make_summary({"loss": average_loss}), tf_global_step) 64 | accumulated_loss = 0.0 65 | 66 | if tf_global_step == 1 or tf_global_step % 
eval_frequency == 0: 67 | eval_summary, eval_f1 = model.evaluate(session) 68 | _ = session.run(model.update_max_f1) 69 | saver.save(session, os.path.join(log_dir, "model"), global_step=tf_global_step) 70 | 71 | if eval_f1 > max_f1: 72 | max_f1 = eval_f1 73 | util.copy_checkpoint(os.path.join(log_dir, "model-{}".format(tf_global_step)), os.path.join(log_dir, "model.max.ckpt")) 74 | 75 | writer.add_summary(eval_summary, tf_global_step) 76 | writer.add_summary(util.make_summary({"max_eval_f1": max_f1}), tf_global_step) 77 | 78 | print(f"[{tf_global_step}] evaL_f1={eval_f1:.2f}, max_f1={max_f1:.2f}") 79 | 80 | if tf_global_step >= config['max_step']: 81 | print('Training finishes due to reaching max steps') 82 | break 83 | 84 | -------------------------------------------------------------------------------- /util.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | 4 | import os 5 | import errno 6 | import codecs 7 | import collections 8 | import shutil 9 | import sys 10 | 11 | import numpy as np 12 | import tensorflow as tf 13 | import pyhocon 14 | 15 | 16 | def initialize_from_env(): 17 | if "GPU" in os.environ: 18 | set_gpus(int(os.environ["GPU"])) 19 | else: 20 | set_gpus() 21 | 22 | name = sys.argv[1] 23 | print(f"Running experiment: {name}") 24 | 25 | config = pyhocon.ConfigFactory.parse_file("experiments.conf")[name] 26 | config["log_dir"] = mkdirs(os.path.join(config["log_root"], name)) 27 | 28 | print(pyhocon.HOCONConverter.convert(config, "hocon")) 29 | return config 30 | 31 | def copy_checkpoint(source, target): 32 | for ext in (".index", ".data-00000-of-00001"): 33 | shutil.copyfile(source + ext, target + ext) 34 | 35 | def make_summary(value_dict): 36 | return tf.Summary(value=[tf.Summary.Value(tag=k, simple_value=v) for k,v in value_dict.items()]) 37 | 38 | def flatten(l): 39 | return [item for sublist in l for item in sublist] 40 | 41 | def set_gpus(*gpus): 42 | os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(str(g) for g in gpus) 43 | print(f"Setting CUDA_VISIBLE_DEVICES to: {os.environ['CUDA_VISIBLE_DEVICES']}") 44 | 45 | def mkdirs(path): 46 | try: 47 | os.makedirs(path) 48 | except OSError as exception: 49 | if exception.errno != errno.EEXIST: 50 | raise 51 | return path 52 | 53 | def load_char_dict(char_vocab_path): 54 | vocab = [u""] 55 | with codecs.open(char_vocab_path, encoding="utf-8") as f: 56 | vocab.extend(l.strip() for l in f.readlines()) 57 | char_dict = collections.defaultdict(int) 58 | char_dict.update({c:i for i, c in enumerate(vocab)}) 59 | return char_dict 60 | 61 | def maybe_divide(x, y): 62 | return 0 if y == 0 else x / float(y) 63 | 64 | def verify_correct_NP_match(predicted_NP, gold_NPs, model, matched_gold_ids): 65 | if model == 'exact': 66 | for gold_id, tmp_gold_NP in enumerate(gold_NPs): 67 | if gold_id in matched_gold_ids: 68 | continue 69 | if tmp_gold_NP[0] == predicted_NP[0] and tmp_gold_NP[1] == predicted_NP[1]: 70 | return gold_id 71 | elif model == 'cover': 72 | for gold_id, tmp_gold_NP in enumerate(gold_NPs): 73 | if gold_id in matched_gold_ids: 74 | continue 75 | if tmp_gold_NP[0] <= predicted_NP[0] and tmp_gold_NP[1] >= predicted_NP[1]: 76 | return gold_id 77 | if tmp_gold_NP[0] >= predicted_NP[0] and tmp_gold_NP[1] <= predicted_NP[1]: 78 | return gold_id 79 | return None 80 | 81 | def projection(inputs, output_size, initializer=None): 82 | return ffnn(inputs, 0, -1, output_size, dropout=None, 
output_weights_initializer=initializer) 83 | 84 | def highway(inputs, num_layers, dropout): 85 | for i in range(num_layers): 86 | with tf.variable_scope("highway_{}".format(i)): 87 | j, f = tf.split(projection(inputs, 2 * shape(inputs, -1)), 2, -1) 88 | f = tf.sigmoid(f) 89 | j = tf.nn.relu(j) 90 | if dropout is not None: 91 | j = tf.nn.dropout(j, dropout) 92 | inputs = f * j + (1 - f) * inputs 93 | return inputs 94 | 95 | def shape(x, dim): 96 | return x.get_shape()[dim].value or tf.shape(x)[dim] 97 | 98 | def ffnn(inputs, num_hidden_layers, hidden_size, output_size, dropout, output_weights_initializer=None): 99 | if len(inputs.get_shape()) > 3: 100 | raise ValueError("FFNN with rank {} not supported".format(len(inputs.get_shape()))) 101 | 102 | if len(inputs.get_shape()) == 3: 103 | batch_size = shape(inputs, 0) 104 | seqlen = shape(inputs, 1) 105 | emb_size = shape(inputs, 2) 106 | current_inputs = tf.reshape(inputs, [batch_size * seqlen, emb_size]) 107 | else: 108 | current_inputs = inputs 109 | 110 | for i in range(num_hidden_layers): 111 | hidden_weights = tf.get_variable("hidden_weights_{}".format(i), [shape(current_inputs, 1), hidden_size]) 112 | hidden_bias = tf.get_variable("hidden_bias_{}".format(i), [hidden_size]) 113 | current_outputs = tf.nn.relu(tf.nn.xw_plus_b(current_inputs, hidden_weights, hidden_bias)) 114 | 115 | if dropout is not None: 116 | current_outputs = tf.nn.dropout(current_outputs, dropout) 117 | current_inputs = current_outputs 118 | 119 | output_weights = tf.get_variable("output_weights", [shape(current_inputs, 1), output_size], initializer=output_weights_initializer) 120 | output_bias = tf.get_variable("output_bias", [output_size]) 121 | outputs = tf.nn.xw_plus_b(current_inputs, output_weights, output_bias) 122 | 123 | if len(inputs.get_shape()) == 3: 124 | outputs = tf.reshape(outputs, [batch_size, seqlen, output_size]) 125 | return outputs 126 | 127 | def cnn(inputs, filter_sizes, num_filters): 128 | num_words = shape(inputs, 0) 129 | num_chars = shape(inputs, 1) 130 | input_size = shape(inputs, 2) 131 | outputs = [] 132 | for i, filter_size in enumerate(filter_sizes): 133 | with tf.variable_scope("conv_{}".format(i), reuse=tf.AUTO_REUSE): 134 | w = tf.get_variable("w", [filter_size, input_size, num_filters]) 135 | b = tf.get_variable("b", [num_filters]) 136 | conv = tf.nn.conv1d(inputs, w, stride=1, padding="VALID") # [num_words, num_chars - filter_size, num_filters] 137 | h = tf.nn.relu(tf.nn.bias_add(conv, b)) # [num_words, num_chars - filter_size, num_filters] 138 | pooled = tf.reduce_max(h, 1) # [num_words, num_filters] 139 | outputs.append(pooled) 140 | return tf.concat(outputs, 1) # [num_words, num_filters * len(filter_sizes)] 141 | 142 | def batch_gather(emb, indices): 143 | batch_size = shape(emb, 0) 144 | seqlen = shape(emb, 1) 145 | if len(emb.get_shape()) > 2: 146 | emb_size = shape(emb, 2) 147 | else: 148 | emb_size = 1 149 | flattened_emb = tf.reshape(emb, [batch_size * seqlen, emb_size]) # [batch_size * seqlen, emb] 150 | offset = tf.expand_dims(tf.range(batch_size) * seqlen, 1) # [batch_size, 1] 151 | gathered = tf.gather(flattened_emb, indices + offset) # [batch_size, num_indices, emb] 152 | if len(emb.get_shape()) == 2: 153 | gathered = tf.squeeze(gathered, 2) # [batch_size, num_indices] 154 | return gathered 155 | 156 | class EmbeddingDictionary(object): 157 | def __init__(self, info, normalize=True, maybe_cache=None): 158 | self._size = info["size"] 159 | self._normalize = normalize 160 | self._path = info["path"] 161 | if maybe_cache 
is not None and maybe_cache._path == self._path: 162 | assert self._size == maybe_cache._size 163 | self._embeddings = maybe_cache._embeddings 164 | else: 165 | self._embeddings = self.load_embedding_dict(self._path) 166 | 167 | @property 168 | def size(self): 169 | return self._size 170 | 171 | def load_embedding_dict(self, path): 172 | print(f"Loading word embeddings from {path}...") 173 | default_embedding = np.zeros(self.size) 174 | embedding_dict = collections.defaultdict(lambda:default_embedding) 175 | if len(path) > 0: 176 | vocab_size = None 177 | with open(path) as f: 178 | for i, line in enumerate(f.readlines()): 179 | word_end = line.find(" ") 180 | word = line[:word_end] 181 | embedding = np.fromstring(line[word_end + 1:], np.float32, sep=" ") 182 | assert len(embedding) == self.size 183 | embedding_dict[word] = embedding 184 | if vocab_size is not None: 185 | assert vocab_size == len(embedding_dict) 186 | print(f"Done loading word embeddings.") 187 | return embedding_dict 188 | 189 | def __getitem__(self, key): 190 | embedding = self._embeddings[key] 191 | if self._normalize: 192 | embedding = self.normalize(embedding) 193 | return embedding 194 | 195 | def normalize(self, v): 196 | norm = np.linalg.norm(v) 197 | if norm > 0: 198 | return v / norm 199 | else: 200 | return v 201 | 202 | class CustomLSTMCell(tf.contrib.rnn.RNNCell): 203 | def __init__(self, num_units, batch_size, dropout): 204 | self._num_units = num_units 205 | self._dropout = dropout 206 | self._dropout_mask = tf.nn.dropout(tf.ones([batch_size, self.output_size]), dropout) 207 | self._initializer = self._block_orthonormal_initializer([self.output_size] * 3) 208 | initial_cell_state = tf.get_variable("lstm_initial_cell_state", [1, self.output_size]) 209 | initial_hidden_state = tf.get_variable("lstm_initial_hidden_state", [1, self.output_size]) 210 | self._initial_state = tf.contrib.rnn.LSTMStateTuple(initial_cell_state, initial_hidden_state) 211 | 212 | @property 213 | def state_size(self): 214 | return tf.contrib.rnn.LSTMStateTuple(self.output_size, self.output_size) 215 | 216 | @property 217 | def output_size(self): 218 | return self._num_units 219 | 220 | @property 221 | def initial_state(self): 222 | return self._initial_state 223 | 224 | def __call__(self, inputs, state, scope=None): 225 | """Long short-term memory cell (LSTM) with coupled input and forget gates.""" 226 | with tf.variable_scope(scope or type(self).__name__, reuse=tf.AUTO_REUSE): # "CustomLSTMCell" 227 | c, h = state 228 | h *= self._dropout_mask 229 | concat = projection(tf.concat([inputs, h], 1), 3 * self.output_size, initializer=self._initializer) 230 | i, j, o = tf.split(concat, num_or_size_splits=3, axis=1) 231 | i = tf.sigmoid(i) 232 | new_c = (1 - i) * c + i * tf.tanh(j) # forget gate is tied to the input gate: forget = 1 - i 233 | new_h = tf.tanh(new_c) * tf.sigmoid(o) 234 | new_state = tf.contrib.rnn.LSTMStateTuple(new_c, new_h) 235 | return new_h, new_state 236 | 237 | def _orthonormal_initializer(self, scale=1.0): 238 | def _initializer(shape, dtype=tf.float32, partition_info=None): 239 | M1 = np.random.randn(shape[0], shape[0]).astype(np.float32) 240 | M2 = np.random.randn(shape[1], shape[1]).astype(np.float32) 241 | Q1, R1 = np.linalg.qr(M1) 242 | Q2, R2 = np.linalg.qr(M2) 243 | Q1 = Q1 * np.sign(np.diag(R1)) 244 | Q2 = Q2 * np.sign(np.diag(R2)) 245 | n_min = min(shape[0], shape[1]) 246 | params = np.dot(Q1[:, :n_min], Q2[:n_min, :]) * scale 247 | return params 248 | return _initializer 249 | 250 | def _block_orthonormal_initializer(self, output_sizes): 251 | def _initializer(shape, dtype=np.float32,
partition_info=None): 252 | assert len(shape) == 2 253 | assert sum(output_sizes) == shape[1] 254 | initializer = self._orthonormal_initializer() 255 | params = np.concatenate([initializer([shape[0], o], dtype, partition_info) for o in output_sizes], 1) 256 | return params 257 | return _initializer 258 | --------------------------------------------------------------------------------
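
A minimal sketch of how util.batch_gather behaves, assuming TensorFlow 1.x and that util.py is importable; the toy tensors below are illustrative and not taken from the repository. The helper flattens a [batch, seqlen, emb] tensor and gathers per-example rows by offsetting each example's indices into the flattened array.

import numpy as np
import tensorflow as tf
import util

emb = tf.constant(np.arange(24, dtype=np.float32).reshape(2, 3, 4))  # [batch=2, seqlen=3, emb=4]
indices = tf.constant([[0, 2], [1, 1]])                              # per-example row indices
gathered = util.batch_gather(emb, indices)                           # -> shape [2, 2, 4]

with tf.Session() as sess:
    print(sess.run(gathered))  # rows 0 and 2 of example 0; row 1 (twice) of example 1

Offsetting each example's indices by seqlen before a single tf.gather avoids a Python-level loop over the batch, which is the point of the helper.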