├── .gitignore ├── LICENSE ├── README.md ├── __init__.py ├── evaluation_utils ├── __init__.py ├── bleu.py ├── evaluators.py ├── nmt_evaluation_utils.py ├── rouge.py └── sentence_simplification.py ├── multitask ├── __init__.py ├── bandits.py ├── multitask_autoMR_model.py ├── multitask_base_model.py ├── multitask_utils.py ├── sharing_dicts_utils.py └── soft_sharing_utils.py ├── pointer_model ├── __init__.py ├── attention_decoder.py ├── attention_utils.py ├── batcher.py ├── beam_search.py ├── data.py ├── decode.py ├── model.py ├── pg_decoder.py ├── pg_decoder_test.py └── policy_gradient_utils.py ├── run.py └── utils ├── __init__.py ├── misc_utils.py ├── model_utils.py ├── modified_rnn_cell_wrappers.py └── rnn_cell_utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | MANIFEST 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | .pytest_cache/ 49 | 50 | # Translations 51 | *.mo 52 | *.pot 53 | 54 | # Django stuff: 55 | *.log 56 | local_settings.py 57 | db.sqlite3 58 | 59 | # Flask stuff: 60 | instance/ 61 | .webassets-cache 62 | 63 | # Scrapy stuff: 64 | .scrapy 65 | 66 | # Sphinx documentation 67 | docs/_build/ 68 | 69 | # PyBuilder 70 | target/ 71 | 72 | # Jupyter Notebook 73 | .ipynb_checkpoints 74 | 75 | # pyenv 76 | .python-version 77 | 78 | # celery beat schedule file 79 | celerybeat-schedule 80 | 81 | # SageMath parsed files 82 | *.sage.py 83 | 84 | # Environments 85 | .env 86 | .venv 87 | env/ 88 | venv/ 89 | ENV/ 90 | env.bak/ 91 | venv.bak/ 92 | 93 | # Spyder project settings 94 | .spyderproject 95 | .spyproject 96 | 97 | # Rope project settings 98 | .ropeproject 99 | 100 | # mkdocs documentation 101 | /site 102 | 103 | # mypy 104 | .mypy_cache/ 105 | 106 | # my own todos 107 | TODOs.md 108 | _*.sh 109 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2018 Han Guo 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 
14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | This repository contains the code for our COLING 2018 paper: 2 | 3 | *[Dynamic Multi-Level Multi-Task Learning for Sentence Simplification](https://arxiv.org/abs/1806.07304)*. 4 | 5 | # Data Preprocessing 6 | Please follow the instructions from [Zhang et al. 2017](https://github.com/XingxingZhang/dress) for downloading the pre-processed dataset. 7 | To build the .bin files, please follow the instructions from [See et al. 2017](https://github.com/abisee/pointer-generator) or [here](https://github.com/abisee/cnn-dailymail). 8 | 9 | # Evaluation Set-Up 10 | * Please follow the instructions from [Zhang et al. 2017](https://github.com/XingxingZhang/dress) for setting up the evaluation system. 11 | * FKGL implementations can be found [in this repo](https://github.com/mmautner/readability). 12 | * Modify the corresponding directories in `evaluation_utils/sentence_simplification.py`. 13 | * Please note that evaluation metrics are calculated at the corpus level. 14 | 15 | 16 | # Dependencies 17 | Python 2.7 18 | TensorFlow 1.4 19 | 20 | # Usage 21 | ```bash 22 | CUDA_VISIBLE_DEVICES="GPU_ID" python run.py \ 23 | --mode "string" \ 24 | --vocab_path "/path/to/vocab/file" \ 25 | --train_data_dirs "/path/to/training/data_1,/path/to/training/data_2,/path/to/training/data_3" \ 26 | --val_data_dir "/path/to/validation/data_1" \ 27 | --decode_data_dir "/path/to/decode/data_1" \ 28 | --eval_source_dir "/path/to/validation/data_1.source" \ 29 | --eval_target_dir "/path/to/validation/data_1.target" \ 30 | --max_enc_steps "int" --max_dec_steps "int" --batch_size "int" --steps_per_eval "int" \ 31 | --log_root "/path/to/log/root/" --exp_name "string" [--autoMR] \ 32 | --lr "float" --beam_size "int" --soft_sharing_coef "float" --mixing_ratios "mr_1,mr_2" \ 33 | --decode_ckpt_file "/path/to/ckpt" --decode_output_file "/path/to/file" 34 | 35 | ``` 36 | Pretrained models can be found [here](https://drive.google.com/file/d/1MJ6kq8nGfPcQaTZMreavkMET-BlG93Ij/view?usp=sharing). 
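For illustration, a concrete invocation might look like the example below. Every value shown (GPU id, mode string, data paths, and hyperparameters) is a placeholder rather than a setting shipped with this repository; substitute your own.
```bash
# Hypothetical example -- all paths and values are illustrative.
CUDA_VISIBLE_DEVICES="0" python run.py \
    --mode train \
    --vocab_path "/data/vocab" \
    --train_data_dirs "/data/wikilarge/train.bin,/data/wikismall/train.bin,/data/newsela/train.bin" \
    --val_data_dir "/data/wikilarge/val.bin" \
    --eval_source_dir "/data/wikilarge/val.source" \
    --eval_target_dir "/data/wikilarge/val.target" \
    --max_enc_steps 100 --max_dec_steps 100 --batch_size 32 --steps_per_eval 1000 \
    --log_root "/experiments/" --exp_name "multitask_wikilarge" \
    --lr 0.15 --beam_size 4 --soft_sharing_coef 1e-4 --mixing_ratios "0.3,0.3"
```
Two details that follow from the code: `--mixing_ratios` are applied over a base of 10 training steps (`MIXING_RATIOS_BASE` in `multitask/multitask_base_model.py`), so each ratio should be a multiple of 0.1; and passing `--autoMR` replaces this fixed schedule with the multi-armed bandit task selector implemented in `multitask/bandits.py` and `multitask/multitask_autoMR_model.py`.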
37 | 38 | # Citation 39 | ``` 40 | @inproceedings{guo2018dynamic, 41 | title = {Dynamic Multi-Level Multi-Task Learning for Sentence Simplification}, 42 | author = {Han Guo and Ramakanth Pasunuru and Mohit Bansal}, 43 | booktitle = {Proceedings of the 27th International Conference on Computational Linguistics (COLING 2018)}, 44 | year = {2018} 45 | } 46 | ``` 47 | -------------------------------------------------------------------------------- /__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HanGuo97/MultitaskSimplification/2632e7bdb5fd53c32092468662fefd8ea6c1dc5d/__init__.py -------------------------------------------------------------------------------- /evaluation_utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HanGuo97/MultitaskSimplification/2632e7bdb5fd53c32092468662fefd8ea6c1dc5d/evaluation_utils/__init__.py -------------------------------------------------------------------------------- /evaluation_utils/bleu.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Google Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """Python implementation of BLEU and smooth-BLEU. 17 | 18 | This module provides a Python implementation of BLEU and smooth-BLEU. 19 | Smooth BLEU is computed following the method outlined in the paper: 20 | Chin-Yew Lin, Franz Josef Och. ORANGE: a method for evaluating automatic 21 | evaluation metrics for machine translation. COLING 2004. 22 | 23 | https://github.com/tensorflow/nmt 24 | """ 25 | 26 | import collections 27 | import math 28 | 29 | 30 | def _get_ngrams(segment, max_order): 31 | """Extracts all n-grams upto a given maximum order from an input segment. 32 | 33 | Args: 34 | segment: text segment from which n-grams will be extracted. 35 | max_order: maximum length in tokens of the n-grams returned by this 36 | methods. 37 | 38 | Returns: 39 | The Counter containing all n-grams upto max_order in segment 40 | with a count of how many times each n-gram occurred. 41 | """ 42 | ngram_counts = collections.Counter() 43 | for order in range(1, max_order + 1): 44 | for i in range(0, len(segment) - order + 1): 45 | ngram = tuple(segment[i:i+order]) 46 | ngram_counts[ngram] += 1 47 | return ngram_counts 48 | 49 | 50 | def compute_bleu(reference_corpus, translation_corpus, max_order=4, 51 | smooth=False): 52 | """Computes BLEU score of translated segments against one or more references. 53 | 54 | Args: 55 | reference_corpus: list of lists of references for each translation. Each 56 | reference should be tokenized into a list of tokens. 57 | translation_corpus: list of translations to score. Each translation 58 | should be tokenized into a list of tokens. 
59 | max_order: Maximum n-gram order to use when computing BLEU score. 60 | smooth: Whether or not to apply Lin et al. 2004 smoothing. 61 | 62 | Returns: 63 | 3-Tuple with the BLEU score, n-gram precisions, geometric mean of n-gram 64 | precisions and brevity penalty. 65 | """ 66 | matches_by_order = [0] * max_order 67 | possible_matches_by_order = [0] * max_order 68 | reference_length = 0 69 | translation_length = 0 70 | for (references, translation) in zip(reference_corpus, 71 | translation_corpus): 72 | reference_length += min(len(r) for r in references) 73 | translation_length += len(translation) 74 | 75 | merged_ref_ngram_counts = collections.Counter() 76 | for reference in references: 77 | merged_ref_ngram_counts |= _get_ngrams(reference, max_order) 78 | translation_ngram_counts = _get_ngrams(translation, max_order) 79 | overlap = translation_ngram_counts & merged_ref_ngram_counts 80 | for ngram in overlap: 81 | matches_by_order[len(ngram)-1] += overlap[ngram] 82 | for order in range(1, max_order+1): 83 | possible_matches = len(translation) - order + 1 84 | if possible_matches > 0: 85 | possible_matches_by_order[order-1] += possible_matches 86 | 87 | precisions = [0] * max_order 88 | for i in range(0, max_order): 89 | if smooth: 90 | precisions[i] = ((matches_by_order[i] + 1.) / 91 | (possible_matches_by_order[i] + 1.)) 92 | else: 93 | if possible_matches_by_order[i] > 0: 94 | precisions[i] = (float(matches_by_order[i]) / 95 | possible_matches_by_order[i]) 96 | else: 97 | precisions[i] = 0.0 98 | 99 | if min(precisions) > 0: 100 | p_log_sum = sum((1. / max_order) * math.log(p) for p in precisions) 101 | geo_mean = math.exp(p_log_sum) 102 | else: 103 | geo_mean = 0 104 | 105 | ratio = float(translation_length) / reference_length 106 | 107 | if ratio > 1.0: 108 | bp = 1. 109 | else: 110 | bp = math.exp(1 - 1. / ratio) 111 | 112 | bleu = geo_mean * bp 113 | 114 | return (bleu, precisions, bp, ratio, translation_length, reference_length) 115 | -------------------------------------------------------------------------------- /evaluation_utils/evaluators.py: -------------------------------------------------------------------------------- 1 | from evaluation_utils import sentence_simplification 2 | SUPPORTED_TASKS = ["WikiLarge", "WikiSmall", "Newsela"] 3 | 4 | 5 | def evaluate(mode, 6 | gen_file, 7 | ref_file=None, 8 | execute_dir=None, 9 | source_file=None, 10 | evaluation_task=None, 11 | deanonymize_file=True): 12 | """ 13 | Evaluate the model on validation set 14 | 15 | Args: 16 | gen_file: model outputs 17 | ref_file: reference file 18 | execute_dir: directory to `ducrush` perl evaluation folder 19 | or directory to `JOSHUA` program directory 20 | source_file: directory to WikiLarge evaluation source 21 | evaluation_task: task to run evaluation 22 | """ 23 | if mode not in ["val", "test"]: 24 | raise ValueError("Unsupported mode ", mode) 25 | 26 | if evaluation_task not in SUPPORTED_TASKS: 27 | raise ValueError("%s is not supported" % evaluation_task) 28 | 29 | scores = sentence_simplification.evaluate( 30 | mode=mode, 31 | gen_file=gen_file, 32 | ref_file=ref_file, 33 | execute_dir=execute_dir, 34 | source_file=source_file, 35 | evaluation_task=evaluation_task, 36 | deanonymize_file=deanonymize_file) 37 | 38 | 39 | return scores 40 | -------------------------------------------------------------------------------- /evaluation_utils/nmt_evaluation_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Google Inc. 
All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """Utility for evaluating various tasks, e.g., translation & summarization. 17 | https://github.com/tensorflow/nmt""" 18 | 19 | import codecs 20 | import io 21 | import os 22 | import re 23 | import subprocess 24 | 25 | import tensorflow as tf 26 | 27 | from evaluation_utils import bleu 28 | from evaluation_utils import rouge 29 | 30 | 31 | __all__ = ["evaluate"] 32 | 33 | 34 | def evaluate(ref_file, trans_file, metric, subword_option=None): 35 | """Pick a metric and evaluate depending on task.""" 36 | # BLEU scores for translation task 37 | if metric.lower() == "bleu": 38 | evaluation_score = _bleu(ref_file, trans_file, 39 | subword_option=subword_option) 40 | # ROUGE scores for summarization tasks 41 | elif metric.lower() == "rouge": 42 | evaluation_score = _rouge(ref_file, trans_file, 43 | subword_option=subword_option) 44 | elif metric.lower() == "accuracy": 45 | evaluation_score = _accuracy(ref_file, trans_file) 46 | else: 47 | raise ValueError("Unknown metric %s" % metric) 48 | 49 | return evaluation_score 50 | 51 | 52 | def _clean(sentence, subword_option): 53 | """Clean and handle BPE or SPM outputs.""" 54 | sentence = sentence.strip() 55 | 56 | # BPE 57 | if subword_option == "bpe": 58 | sentence = re.sub("@@ ", "", sentence) 59 | 60 | # SPM 61 | elif subword_option == "spm": 62 | sentence = u"".join(sentence.split()).replace(u"\u2581", u" ").lstrip() 63 | 64 | return sentence 65 | 66 | 67 | # Follow //transconsole/localization/machine_translation/metrics/bleu_calc.py 68 | def _bleu(ref_file, trans_file, subword_option=None): 69 | """Compute BLEU scores and handling BPE.""" 70 | max_order = 4 71 | smooth = False 72 | 73 | ref_files = [ref_file] 74 | reference_text = [] 75 | for reference_filename in ref_files: 76 | with codecs.getreader("utf-8")( 77 | tf.gfile.GFile(reference_filename, "rb")) as fh: 78 | reference_text.append(fh.readlines()) 79 | 80 | per_segment_references = [] 81 | for references in zip(*reference_text): 82 | reference_list = [] 83 | for reference in references: 84 | reference = _clean(reference, subword_option) 85 | reference_list.append(reference.split(" ")) 86 | per_segment_references.append(reference_list) 87 | 88 | translations = [] 89 | with codecs.getreader("utf-8")(tf.gfile.GFile(trans_file, "rb")) as fh: 90 | for line in fh: 91 | line = _clean(line, subword_option=None) 92 | translations.append(line.split(" ")) 93 | 94 | # bleu_score, precisions, bp, ratio, translation_length, reference_length 95 | bleu_score, _, _, _, _, _ = bleu.compute_bleu( 96 | per_segment_references, translations, max_order, smooth) 97 | return 100 * bleu_score 98 | 99 | 100 | def _rouge(ref_file, summarization_file, subword_option=None): 101 | """Compute ROUGE scores and handling BPE.""" 102 | 103 | references = [] 104 | with io.open(ref_file, 'r', encoding='utf8') as fh: 105 | for 
line in fh: 106 | references.append(_clean(line, subword_option)) 107 | 108 | hypotheses = [] 109 | with io.open(summarization_file, 'r', encoding='utf8') as fh: 110 | for line in fh: 111 | hypotheses.append(_clean(line, subword_option=None)) 112 | 113 | rouge_score_map = rouge.rouge(hypotheses, references) 114 | return 100 * rouge_score_map["rouge_l/f_score"] 115 | 116 | 117 | def _accuracy(label_file, pred_file): 118 | """Compute accuracy, each line contains a label.""" 119 | 120 | with codecs.getreader("utf-8")(tf.gfile.GFile(label_file, "rb")) as label_fh: 121 | with codecs.getreader("utf-8")(tf.gfile.GFile(pred_file, "rb")) as pred_fh: 122 | count = 0.0 123 | match = 0.0 124 | for label in label_fh: 125 | label = label.strip() 126 | pred = pred_fh.readline().strip() 127 | if label == pred: 128 | match += 1 129 | count += 1 130 | return 100 * match / count 131 | 132 | 133 | def _moses_bleu(multi_bleu_script, tgt_test, trans_file, subword_option=None): 134 | """Compute BLEU scores using Moses multi-bleu.perl script.""" 135 | 136 | # TODO(thangluong): perform rewrite using python 137 | # BPE 138 | if subword_option == "bpe": 139 | debpe_tgt_test = tgt_test + ".debpe" 140 | if not os.path.exists(debpe_tgt_test): 141 | # TODO(thangluong): not use shell=True, can be a security hazard 142 | subprocess.call("cp %s %s" % 143 | (tgt_test, debpe_tgt_test), shell=True) 144 | subprocess.call("sed s/@@ //g %s" % (debpe_tgt_test), 145 | shell=True) 146 | tgt_test = debpe_tgt_test 147 | elif subword_option == "spm": 148 | despm_tgt_test = tgt_test + ".despm" 149 | if not os.path.exists(debpe_tgt_test): 150 | subprocess.call("cp %s %s" % (tgt_test, despm_tgt_test)) 151 | subprocess.call("sed s/ //g %s" % (despm_tgt_test)) 152 | subprocess.call(u"sed s/^\u2581/g %s" % (despm_tgt_test)) 153 | subprocess.call(u"sed s/\u2581/ /g %s" % (despm_tgt_test)) 154 | tgt_test = despm_tgt_test 155 | cmd = "%s %s < %s" % (multi_bleu_script, tgt_test, trans_file) 156 | 157 | # subprocess 158 | # TODO(thangluong): not use shell=True, can be a security hazard 159 | bleu_output = subprocess.check_output(cmd, shell=True) 160 | 161 | # extract BLEU score 162 | m = re.search("BLEU = (.+?),", bleu_output) 163 | bleu_score = float(m.group(1)) 164 | 165 | return bleu_score 166 | -------------------------------------------------------------------------------- /evaluation_utils/rouge.py: -------------------------------------------------------------------------------- 1 | """ROUGE metric implementation. 2 | 3 | Copy from tf_seq2seq/seq2seq/metrics/rouge.py. 4 | This is a modified and slightly extended verison of 5 | https://github.com/miso-belica/sumy/blob/dev/sumy/evaluation/rouge.py. 6 | 7 | https://github.com/tensorflow/nmt 8 | """ 9 | 10 | from __future__ import absolute_import 11 | from __future__ import division 12 | from __future__ import print_function 13 | from __future__ import unicode_literals 14 | 15 | import itertools 16 | import numpy as np 17 | 18 | #pylint: disable=C0103 19 | 20 | 21 | def _get_ngrams(n, text): 22 | """Calcualtes n-grams. 
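    (Annotation, not part of the original docstring.) Illustrative call: with
    n=2 and text=["the", "cat", "sat"], this returns the set
    {("the", "cat"), ("cat", "sat")}.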
23 | 24 | Args: 25 | n: which n-grams to calculate 26 | text: An array of tokens 27 | 28 | Returns: 29 | A set of n-grams 30 | """ 31 | ngram_set = set() 32 | text_length = len(text) 33 | max_index_ngram_start = text_length - n 34 | for i in range(max_index_ngram_start + 1): 35 | ngram_set.add(tuple(text[i:i + n])) 36 | return ngram_set 37 | 38 | 39 | def _split_into_words(sentences): 40 | """Splits multiple sentences into words and flattens the result""" 41 | return list(itertools.chain(*[_.split(" ") for _ in sentences])) 42 | 43 | 44 | def _get_word_ngrams(n, sentences): 45 | """Calculates word n-grams for multiple sentences. 46 | """ 47 | assert len(sentences) > 0 48 | assert n > 0 49 | 50 | words = _split_into_words(sentences) 51 | return _get_ngrams(n, words) 52 | 53 | 54 | def _len_lcs(x, y): 55 | """ 56 | Returns the length of the Longest Common Subsequence between sequences x 57 | and y. 58 | Source: http://www.algorithmist.com/index.php/Longest_Common_Subsequence 59 | 60 | Args: 61 | x: sequence of words 62 | y: sequence of words 63 | 64 | Returns 65 | integer: Length of LCS between x and y 66 | """ 67 | table = _lcs(x, y) 68 | n, m = len(x), len(y) 69 | return table[n, m] 70 | 71 | 72 | def _lcs(x, y): 73 | """ 74 | Computes the length of the longest common subsequence (lcs) between two 75 | strings. The implementation below uses a DP programming algorithm and runs 76 | in O(nm) time where n = len(x) and m = len(y). 77 | Source: http://www.algorithmist.com/index.php/Longest_Common_Subsequence 78 | 79 | Args: 80 | x: collection of words 81 | y: collection of words 82 | 83 | Returns: 84 | Table of dictionary of coord and len lcs 85 | """ 86 | n, m = len(x), len(y) 87 | table = dict() 88 | for i in range(n + 1): 89 | for j in range(m + 1): 90 | if i == 0 or j == 0: 91 | table[i, j] = 0 92 | elif x[i - 1] == y[j - 1]: 93 | table[i, j] = table[i - 1, j - 1] + 1 94 | else: 95 | table[i, j] = max(table[i - 1, j], table[i, j - 1]) 96 | return table 97 | 98 | 99 | def _recon_lcs(x, y): 100 | """ 101 | Returns the Longest Subsequence between x and y. 102 | Source: http://www.algorithmist.com/index.php/Longest_Common_Subsequence 103 | 104 | Args: 105 | x: sequence of words 106 | y: sequence of words 107 | 108 | Returns: 109 | sequence: LCS of x and y 110 | """ 111 | i, j = len(x), len(y) 112 | table = _lcs(x, y) 113 | 114 | def _recon(i, j): 115 | """private recon calculation""" 116 | if i == 0 or j == 0: 117 | return [] 118 | elif x[i - 1] == y[j - 1]: 119 | return _recon(i - 1, j - 1) + [(x[i - 1], i)] 120 | elif table[i - 1, j] > table[i, j - 1]: 121 | return _recon(i - 1, j) 122 | else: 123 | return _recon(i, j - 1) 124 | 125 | recon_tuple = tuple(map(lambda x: x[0], _recon(i, j))) 126 | return recon_tuple 127 | 128 | 129 | def rouge_n(evaluated_sentences, reference_sentences, n=2): 130 | """ 131 | Computes ROUGE-N of two text collections of sentences. 132 | Sourece: http://research.microsoft.com/en-us/um/people/cyl/download/ 133 | papers/rouge-working-note-v1.3.1.pdf 134 | 135 | Args: 136 | evaluated_sentences: The sentences that have been picked by the summarizer 137 | reference_sentences: The sentences from the referene set 138 | n: Size of ngram. Defaults to 2. 
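    Illustrative example (annotation, not part of the original docstring):
    with n=1, evaluated_sentences=["the cat sat"] and
    reference_sentences=["the cat was sad"], the unigram sets are
    {the, cat, sat} and {the, cat, was, sad}; the overlap is 2, so
    precision = 2/3, recall = 2/4, and the returned F1 is roughly 0.571.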
139 | 140 | Returns: 141 | A tuple (f1, precision, recall) for ROUGE-N 142 | 143 | Raises: 144 | ValueError: raises exception if a param has len <= 0 145 | """ 146 | if len(evaluated_sentences) <= 0 or len(reference_sentences) <= 0: 147 | raise ValueError("Collections must contain at least 1 sentence.") 148 | 149 | evaluated_ngrams = _get_word_ngrams(n, evaluated_sentences) 150 | reference_ngrams = _get_word_ngrams(n, reference_sentences) 151 | reference_count = len(reference_ngrams) 152 | evaluated_count = len(evaluated_ngrams) 153 | 154 | # Gets the overlapping ngrams between evaluated and reference 155 | overlapping_ngrams = evaluated_ngrams.intersection(reference_ngrams) 156 | overlapping_count = len(overlapping_ngrams) 157 | 158 | # Handle edge case. This isn't mathematically correct, but it's good enough 159 | if evaluated_count == 0: 160 | precision = 0.0 161 | else: 162 | precision = overlapping_count / evaluated_count 163 | 164 | if reference_count == 0: 165 | recall = 0.0 166 | else: 167 | recall = overlapping_count / reference_count 168 | 169 | f1_score = 2.0 * ((precision * recall) / (precision + recall + 1e-8)) 170 | 171 | # return overlapping_count / reference_count 172 | return f1_score, precision, recall 173 | 174 | 175 | def _f_p_r_lcs(llcs, m, n): 176 | """ 177 | Computes the LCS-based F-measure score 178 | Source: http://research.microsoft.com/en-us/um/people/cyl/download/papers/ 179 | rouge-working-note-v1.3.1.pdf 180 | 181 | Args: 182 | llcs: Length of LCS 183 | m: number of words in reference summary 184 | n: number of words in candidate summary 185 | 186 | Returns: 187 | Float. LCS-based F-measure score 188 | """ 189 | r_lcs = llcs / m 190 | p_lcs = llcs / n 191 | beta = p_lcs / (r_lcs + 1e-12) 192 | num = (1 + (beta**2)) * r_lcs * p_lcs 193 | denom = r_lcs + ((beta**2) * p_lcs) 194 | f_lcs = num / (denom + 1e-12) 195 | return f_lcs, p_lcs, r_lcs 196 | 197 | 198 | def rouge_l_sentence_level(evaluated_sentences, reference_sentences): 199 | """ 200 | Computes ROUGE-L (sentence level) of two text collections of sentences. 201 | http://research.microsoft.com/en-us/um/people/cyl/download/papers/ 202 | rouge-working-note-v1.3.1.pdf 203 | 204 | Calculated according to: 205 | R_lcs = LCS(X,Y)/m 206 | P_lcs = LCS(X,Y)/n 207 | F_lcs = ((1 + beta^2)*R_lcs*P_lcs) / (R_lcs + (beta^2) * P_lcs) 208 | 209 | where: 210 | X = reference summary 211 | Y = Candidate summary 212 | m = length of reference summary 213 | n = length of candidate summary 214 | 215 | Args: 216 | evaluated_sentences: The sentences that have been picked by the summarizer 217 | reference_sentences: The sentences from the referene set 218 | 219 | Returns: 220 | A float: F_lcs 221 | 222 | Raises: 223 | ValueError: raises exception if a param has len <= 0 224 | """ 225 | if len(evaluated_sentences) <= 0 or len(reference_sentences) <= 0: 226 | raise ValueError("Collections must contain at least 1 sentence.") 227 | reference_words = _split_into_words(reference_sentences) 228 | evaluated_words = _split_into_words(evaluated_sentences) 229 | m = len(reference_words) 230 | n = len(evaluated_words) 231 | lcs = _len_lcs(evaluated_words, reference_words) 232 | return _f_p_r_lcs(lcs, m, n) 233 | 234 | 235 | def _union_lcs(evaluated_sentences, reference_sentence): 236 | """ 237 | Returns LCS_u(r_i, C) which is the LCS score of the union longest common 238 | subsequence between reference sentence ri and candidate summary C. 
For example 239 | if r_i= w1 w2 w3 w4 w5, and C contains two sentences: c1 = w1 w2 w6 w7 w8 and 240 | c2 = w1 w3 w8 w9 w5, then the longest common subsequence of r_i and c1 is 241 | "w1 w2" and the longest common subsequence of r_i and c2 is "w1 w3 w5". The 242 | union longest common subsequence of r_i, c1, and c2 is "w1 w2 w3 w5" and 243 | LCS_u(r_i, C) = 4/5. 244 | 245 | Args: 246 | evaluated_sentences: The sentences that have been picked by the summarizer 247 | reference_sentence: One of the sentences in the reference summaries 248 | 249 | Returns: 250 | float: LCS_u(r_i, C) 251 | 252 | ValueError: 253 | Raises exception if a param has len <= 0 254 | """ 255 | if len(evaluated_sentences) <= 0: 256 | raise ValueError("Collections must contain at least 1 sentence.") 257 | 258 | lcs_union = set() 259 | reference_words = _split_into_words([reference_sentence]) 260 | combined_lcs_length = 0 261 | for eval_s in evaluated_sentences: 262 | evaluated_words = _split_into_words([eval_s]) 263 | lcs = set(_recon_lcs(reference_words, evaluated_words)) 264 | combined_lcs_length += len(lcs) 265 | lcs_union = lcs_union.union(lcs) 266 | 267 | union_lcs_count = len(lcs_union) 268 | union_lcs_value = union_lcs_count / combined_lcs_length 269 | return union_lcs_value 270 | 271 | 272 | def rouge_l_summary_level(evaluated_sentences, reference_sentences): 273 | """ 274 | Computes ROUGE-L (summary level) of two text collections of sentences. 275 | http://research.microsoft.com/en-us/um/people/cyl/download/papers/ 276 | rouge-working-note-v1.3.1.pdf 277 | 278 | Calculated according to: 279 | R_lcs = SUM(1, u)[LCS(r_i,C)]/m 280 | P_lcs = SUM(1, u)[LCS(r_i,C)]/n 281 | F_lcs = ((1 + beta^2)*R_lcs*P_lcs) / (R_lcs + (beta^2) * P_lcs) 282 | 283 | where: 284 | SUM(i,u) = SUM from i through u 285 | u = number of sentences in reference summary 286 | C = Candidate summary made up of v sentences 287 | m = number of words in reference summary 288 | n = number of words in candidate summary 289 | 290 | Args: 291 | evaluated_sentences: The sentences that have been picked by the summarizer 292 | reference_sentence: One of the sentences in the reference summaries 293 | 294 | Returns: 295 | A float: F_lcs 296 | 297 | Raises: 298 | ValueError: raises exception if a param has len <= 0 299 | """ 300 | if len(evaluated_sentences) <= 0 or len(reference_sentences) <= 0: 301 | raise ValueError("Collections must contain at least 1 sentence.") 302 | 303 | # total number of words in reference sentences 304 | m = len(_split_into_words(reference_sentences)) 305 | 306 | # total number of words in evaluated sentences 307 | n = len(_split_into_words(evaluated_sentences)) 308 | 309 | union_lcs_sum_across_all_references = 0 310 | for ref_s in reference_sentences: 311 | union_lcs_sum_across_all_references += _union_lcs(evaluated_sentences, 312 | ref_s) 313 | return _f_p_r_lcs(union_lcs_sum_across_all_references, m, n) 314 | 315 | 316 | def rouge(hypotheses, references): 317 | """Calculates average rouge scores for a list of hypotheses and 318 | references""" 319 | 320 | # Filter out hyps that are of 0 length 321 | # hyps_and_refs = zip(hypotheses, references) 322 | # hyps_and_refs = [_ for _ in hyps_and_refs if len(_[0]) > 0] 323 | # hypotheses, references = zip(*hyps_and_refs) 324 | 325 | # Calculate ROUGE-1 F1, precision, recall scores 326 | rouge_1 = [ 327 | rouge_n([hyp], [ref], 1) for hyp, ref in zip(hypotheses, references) 328 | ] 329 | rouge_1_f, rouge_1_p, rouge_1_r = map(np.mean, zip(*rouge_1)) 330 | 331 | # Calculate ROUGE-2 F1, 
precision, recall scores 332 | rouge_2 = [ 333 | rouge_n([hyp], [ref], 2) for hyp, ref in zip(hypotheses, references) 334 | ] 335 | rouge_2_f, rouge_2_p, rouge_2_r = map(np.mean, zip(*rouge_2)) 336 | 337 | # Calculate ROUGE-L F1, precision, recall scores 338 | rouge_l = [ 339 | rouge_l_sentence_level([hyp], [ref]) 340 | for hyp, ref in zip(hypotheses, references) 341 | ] 342 | rouge_l_f, rouge_l_p, rouge_l_r = map(np.mean, zip(*rouge_l)) 343 | 344 | return { 345 | "rouge_1/f_score": rouge_1_f, 346 | "rouge_1/r_score": rouge_1_r, 347 | "rouge_1/p_score": rouge_1_p, 348 | "rouge_2/f_score": rouge_2_f, 349 | "rouge_2/r_score": rouge_2_r, 350 | "rouge_2/p_score": rouge_2_p, 351 | "rouge_l/f_score": rouge_l_f, 352 | "rouge_l/r_score": rouge_l_r, 353 | "rouge_l/p_score": rouge_l_p, 354 | } 355 | -------------------------------------------------------------------------------- /evaluation_utils/sentence_simplification.py: -------------------------------------------------------------------------------- 1 | """Utility functions simplification evaluations""" 2 | from __future__ import print_function 3 | 4 | import os 5 | import sys 6 | import torchfile 7 | import subprocess 8 | from readability import Readability 9 | 10 | """ 11 | try: 12 | reload(sys) 13 | sys.setdefaultencoding('utf8') 14 | except NameError: 15 | # raise EnvironmentError("This file only supports python2") 16 | pass 17 | """ 18 | reload(sys) 19 | sys.setdefaultencoding('utf8') 20 | 21 | 22 | 23 | 24 | def _replace_ner(sentence, ner_dict): 25 | """Replace the Named Entities in a sentence 26 | 27 | Args: 28 | sentence: str, sentences to be processed 29 | ner_dict: dictionary of {NER_tag: word} or an empty list 30 | 31 | Returns: 32 | processed sentence 33 | """ 34 | if isinstance(ner_dict, (list, tuple)): 35 | # the map is empty, no NER in the sentence 36 | return sentence 37 | 38 | def replace_fn(token): 39 | # for compatability between python2 and 3 40 | # upper because the NER are upper-based 41 | if token.upper().encode() in ner_dict.keys(): 42 | # lower case replaced words 43 | return ner_dict[token.upper().encode()].decode().lower() 44 | else: 45 | return token 46 | 47 | return " ".join(map(replace_fn, sentence.split())) 48 | 49 | 50 | def _deanonymize_file(file, ner_map_file, mode): 51 | if not os.path.exists(file): 52 | raise ValueError("file %s does not exist" % file) 53 | if not os.path.exists(file): 54 | raise ValueError("NER_Map %s does not exist" % ner_map_file) 55 | 56 | if mode not in ["train", "valid", "test"]: 57 | raise ValueError( 58 | "mode must be in `valid` for `test`, saw ", mode) 59 | 60 | # read in unprocessed file 61 | with open(file) as f: 62 | raw_outputs = f.readlines() 63 | raw_outputs = [d.strip() for d in raw_outputs] 64 | 65 | # read in NER_Map 66 | ner_maps = torchfile.load(ner_map_file) 67 | # for compatability between python2 and 3 68 | ner_map = ner_maps[mode.encode(encoding="utf-8")] 69 | 70 | # process sentences 71 | deanonymized_outputs = [] 72 | if not len(raw_outputs) == len(ner_map): 73 | raise ValueError("raw_outputs and ner_map shape mismatch") 74 | for raw_output, ner_dict in zip(raw_outputs, ner_map): 75 | deanonymized_output = _replace_ner(raw_output, ner_dict) 76 | deanonymized_outputs.append(deanonymized_output) 77 | 78 | deanonymized_file = file + "_deanonymized" 79 | with open(deanonymized_file, "w") as f: 80 | f.write("\n".join(deanonymized_outputs)) 81 | 82 | return deanonymized_file 83 | 84 | 85 | def run_BLEU(JOSHUA_dir, output_dir, reference_dir, num_references=8): 86 | 
joshua_output = subprocess.check_output(""" 87 | export JAVA_HOME=/usr/lib/jvm/java-1.8.0 88 | export JOSHUA=%s 89 | export LC_ALL=en_US.UTF-8 90 | export LANG=en_US.UTF-8 91 | $JOSHUA/bin/bleu %s %s %d 92 | """ % (JOSHUA_dir, output_dir, reference_dir, num_references), shell=True) 93 | 94 | score = float(joshua_output.split("BLEU = ")[1]) 95 | return 100 * score 96 | 97 | 98 | def run_FKGL(output_dir): 99 | with open(output_dir) as f: 100 | output = f.readlines() 101 | output = [d.lower().strip() for d in output] 102 | 103 | output_final = " ".join(output) 104 | rd = Readability(output_final) 105 | score = rd.FleschKincaidGradeLevel() 106 | return score 107 | 108 | 109 | def run_SARI(JOSHUA_dir, output_dir, reference_dir, source_dir, 110 | num_references=8): 111 | if num_references == 8: 112 | executable = None # Instruction: PLEASE CHANGE THE DIRECTORIES HERE 113 | elif num_references == 1: 114 | executable = None # Instruction: PLEASE CHANGE THE DIRECTORIES HERE 115 | else: 116 | raise ValueError("num_references must be 8 or 1") 117 | 118 | joshua_output = subprocess.check_output(""" 119 | export JAVA_HOME=/usr/lib/jvm/java-1.8.0 120 | export JOSHUA=%s 121 | export LC_ALL=en_US.UTF-8 122 | export LANG=en_US.UTF-8 123 | %s %s %s %s 124 | """ % (JOSHUA_dir, executable, output_dir, 125 | reference_dir, source_dir), shell=True) 126 | 127 | score = float(joshua_output.split("STAR = ")[1]) 128 | return 100 * score 129 | 130 | 131 | def evaluate(mode, 132 | gen_file, 133 | ref_file=None, 134 | execute_dir=None, 135 | source_file=None, 136 | evaluation_task=None, 137 | deanonymize_file=True): 138 | 139 | # Instruction: PLEASE CHANGE THE DIRECTORIES HERE 140 | JOSHUA_dir = None 141 | WikiLarge_NER_MAP_FILE = None 142 | WikiSmall_NER_MAP_FILE = None 143 | Newsela_NER_MAP_FILE = None 144 | 145 | if evaluation_task in ["WikiLarge"]: 146 | ner_map_file = WikiLarge_NER_MAP_FILE 147 | num_test_references = 8 148 | elif evaluation_task in ["WikiSmall"]: 149 | ner_map_file = WikiSmall_NER_MAP_FILE 150 | num_test_references = 1 151 | elif evaluation_task in ["Newsela"]: 152 | ner_map_file = Newsela_NER_MAP_FILE 153 | num_test_references = 1 154 | 155 | if not execute_dir: 156 | execute_dir = JOSHUA_dir 157 | if mode == "val": 158 | num_references = 1 159 | else: 160 | num_references = num_test_references 161 | if deanonymize_file: 162 | gen_file = _deanonymize_file( 163 | file=gen_file, mode="test", 164 | ner_map_file=ner_map_file) 165 | 166 | bleu = run_BLEU( 167 | JOSHUA_dir=execute_dir, 168 | output_dir=gen_file, 169 | reference_dir=ref_file, 170 | num_references=num_references) 171 | fkgl = run_FKGL( 172 | output_dir=gen_file) 173 | sari = run_SARI( 174 | JOSHUA_dir=execute_dir, 175 | output_dir=gen_file, 176 | reference_dir=ref_file, 177 | source_dir=source_file, 178 | num_references=num_references) 179 | 180 | scores = {"BLEU": bleu, 181 | "FKGL": fkgl, 182 | "SARI": sari} 183 | 184 | return scores 185 | -------------------------------------------------------------------------------- /multitask/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HanGuo97/MultitaskSimplification/2632e7bdb5fd53c32092468662fefd8ea6c1dc5d/multitask/__init__.py -------------------------------------------------------------------------------- /multitask/bandits.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | from __future__ import print_function 3 | from __future__ import 
absolute_import 4 | 5 | import os 6 | import pickle 7 | import warnings 8 | import numpy as np 9 | from namedlist import namedlist 10 | 11 | Q_Entry = namedlist("Q_Entry", ("Value", "Count")) 12 | 13 | 14 | def softmax(X, theta=1.0, axis=None): 15 | """Compute the softmax of each element along an axis of X. 16 | https://nolanbconaway.github.io/blog/2017/softmax-numpy 17 | """ 18 | # make X at least 2d 19 | y = np.atleast_2d(X) 20 | # find axis 21 | if axis is None: 22 | axis = next(j[0] for j in enumerate(y.shape) if j[1] > 1) 23 | # multiply y against the theta parameter, 24 | y = y * float(theta) 25 | # subtract the max for numerical stability 26 | y = y - np.expand_dims(np.max(y, axis=axis), axis) 27 | # exponentiate y 28 | y = np.exp(y) 29 | # take the sum along the specified axis 30 | ax_sum = np.expand_dims(np.sum(y, axis=axis), axis) 31 | # finally: divide elementwise 32 | p = y / ax_sum 33 | # flatten if X was 1D 34 | if len(X.shape) == 1: 35 | p = p.flatten() 36 | 37 | return p 38 | 39 | 40 | def gradient_bandit(old_Q, reward, alpha): 41 | new_Q = old_Q + alpha * (reward - old_Q) 42 | return new_Q 43 | 44 | 45 | def convert_to_one_hot(action_id, action_space): 46 | return np.eye(action_space)[action_id] 47 | 48 | 49 | def boltzmann_exploration(Q_values, temperature=1.0): 50 | # for numerical stability, add 1e-7 51 | Q_probs = softmax(Q_values, theta=1 / (temperature + 1e-7)) 52 | action = np.random.choice(len(Q_probs), p=Q_probs) 53 | return action, Q_probs 54 | 55 | 56 | class MultiArmedBanditSelector(object): 57 | def __init__(self, 58 | num_actions, 59 | initial_weight, 60 | update_rate_fn, 61 | reward_shaping_fn, 62 | initial_temperature=1.0, 63 | temperature_anneal_rate=None): 64 | """ 65 | Args: 66 | update_rate_fn: fn(step) --> Real 67 | a function that takes `step` as input, and produce 68 | real value, the gradent bandit update rate. 69 | Common functions include: 70 | 1. (constant update) lambda step: CONSTANT 71 | 2. (average of entire history): lambda step: 1 / (step + 1) 72 | 73 | reward_shaping_fn: fn(reward, histories) --> Real 74 | a function that takes current and histories of rewards 75 | as inputs and produce real value, the reward to be fed into 76 | the bandits algorithm 77 | Common functions include: 78 | 1. lambda reward, hist: reward / CONSTANT 79 | 2. lambda reward, hist: [reward - mean(hist)] / std(hist) 80 | """ 81 | if not callable(update_rate_fn): 82 | raise TypeError("`update_rate_fn` must be callable") 83 | if not callable(reward_shaping_fn): 84 | raise TypeError("`reward_shaping_fn` must be callable") 85 | 86 | self._Q_entries = [ 87 | # intial Count = 1 because of `initial_weight` 88 | Q_Entry(Value=initial_weight, Count=1) 89 | for _ in range(num_actions)] 90 | self._num_actions = num_actions 91 | self._update_rate_fn = update_rate_fn 92 | self._reward_shaping_fn = reward_shaping_fn 93 | 94 | self._temperature = initial_temperature 95 | self._temperature_anneal_rate = temperature_anneal_rate 96 | 97 | self._sample_histories = [] 98 | self._update_histories = [] 99 | 100 | 101 | def sample(self, step=0): 102 | temperature_coef = ( # tau x rate^step 103 | np.power(self._temperature_anneal_rate, step) 104 | if self._temperature_anneal_rate is not None else 1.) 
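        # Annotation (not in the original source): the effective softmax
        # temperature below is initial_temperature * anneal_rate**step. For
        # example, with initial_temperature=1.0 and
        # temperature_anneal_rate=0.9999, after 10,000 steps the temperature
        # is roughly 0.9999**10000 ~= 0.37, so the Boltzmann distribution over
        # Q-values becomes sharper and task selection gradually shifts from
        # exploration toward exploiting the highest-Q arm.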
105 | 106 | chosen_action, Q_probs = boltzmann_exploration( 107 | Q_values=np.asarray(self.arm_weights), 108 | temperature=self._temperature * temperature_coef) 109 | 110 | self._sample_histories.append([Q_probs, chosen_action]) 111 | 112 | return chosen_action, Q_probs 113 | 114 | def update(self, reward, chosen_arm): 115 | # uses sampling, set weights = 1 116 | if not isinstance(chosen_arm, int): 117 | raise ValueError("chosen_arm must be integers") 118 | if not chosen_arm < self._num_actions: 119 | raise ValueError("chosen_arm out of range") 120 | 121 | step_size = self._update_rate_fn( 122 | self._Q_entries[chosen_arm].Count) 123 | shaped_reward = self._reward_shaping_fn( 124 | reward, self.reward_histories) 125 | 126 | new_Q = gradient_bandit(reward=shaped_reward, alpha=step_size, 127 | old_Q=self._Q_entries[chosen_arm].Value) 128 | 129 | self._Q_entries[chosen_arm].Count += 1 130 | self._Q_entries[chosen_arm].Value = new_Q 131 | self._update_histories.append([reward, chosen_arm, shaped_reward]) 132 | 133 | 134 | @property 135 | def arm_weights(self): 136 | return [Q.Value for Q in self._Q_entries] 137 | 138 | @property 139 | def step_counts(self): 140 | return np.sum([Q.Count for Q in self._Q_entries]) 141 | 142 | @property 143 | def reward_histories(self): 144 | # at the start, the update_histories is empty 145 | # to avoid nan, we will force set this to 0 146 | if len(self._update_histories) == 0: 147 | return [0.0] 148 | 149 | return [hist[0] for hist in self._update_histories] 150 | 151 | def save(self, file_dir): 152 | with open(file_dir + "._Q_entries", "wb") as f: 153 | pickle.dump(self._Q_entries, f, pickle.HIGHEST_PROTOCOL) 154 | 155 | with open(file_dir + "._sample_histories", "wb") as f: 156 | pickle.dump(self._sample_histories, f, pickle.HIGHEST_PROTOCOL) 157 | 158 | with open(file_dir + "._update_histories", "wb") as f: 159 | pickle.dump(self._update_histories, f, pickle.HIGHEST_PROTOCOL) 160 | 161 | print("INFO: Successfully Saved MABSelector to ", file_dir) 162 | 163 | def load(self, file_dir): 164 | warnings.warn("num_actions are *NOT* checked") 165 | for suffix in ["._Q_entries", 166 | "._sample_histories", 167 | "._update_histories"]: 168 | if not os.path.exists(file_dir + suffix): 169 | raise ValueError("%s File not exist ", suffix) 170 | 171 | with open(file_dir + "._Q_entries", "rb") as f: 172 | Q_values = pickle.load(f) 173 | 174 | with open(file_dir + "._sample_histories", "rb") as f: 175 | sample_histories = pickle.load(f) 176 | 177 | with open(file_dir + "._update_histories", "rb") as f: 178 | update_histories = pickle.load(f) 179 | 180 | self._Q_entries = Q_values 181 | self._sample_histories = sample_histories 182 | self._update_histories = update_histories 183 | 184 | print("INFO: Successfully Loaded %s from %s" % 185 | (self.__class__.__name__, file_dir)) 186 | -------------------------------------------------------------------------------- /multitask/multitask_autoMR_model.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import os 6 | from multitask import bandits 7 | from multitask.multitask_base_model import MultitaskBaseModel 8 | 9 | 10 | class MultitaskAutoMRModel(MultitaskBaseModel): 11 | """ 12 | Multitask model with automatic task selection 13 | 14 | Build a TaskSelector object that keeps track of 15 | previous val-loss when running on task S, and 16 | choose which task to run by 
sampling from: 17 | 18 | S_t ~ P(S_t | history of validation loss) 19 | 20 | where P is modeled as a boltzmann distribution 21 | 22 | P(S | history) = softmax(history) 23 | 24 | and S is kept constant until new validation loss 25 | is available, that is, every 10 or so steps 26 | 27 | Q Score should be negative loss thus lower is better 28 | high initial Q for being "optimistic under uncertainty" 29 | 30 | """ 31 | def _build_models(self, 32 | names, 33 | selector_Q_initial, 34 | alpha=0.3, 35 | temperature_anneal_rate=None, 36 | *args, **kargs): 37 | 38 | self._task_selector_actions = names 39 | self._TaskSelector = bandits.MultiArmedBanditSelector( 40 | num_actions=len(names), 41 | initial_weight=selector_Q_initial, 42 | update_rate_fn=lambda step: alpha, # constant update 43 | reward_shaping_fn=lambda reward, hist: reward, # no shaping 44 | temperature_anneal_rate=temperature_anneal_rate) 45 | print("Initial TaskSelector Q_score: %.1f, " 46 | "and temperature anneal rate %.5f" 47 | % (selector_Q_initial, 48 | self._TaskSelector._temperature_anneal_rate or 1.0)) 49 | 50 | # initial task will be main task 51 | self._current_task_index = 0 52 | 53 | # normal building models 54 | return super(MultitaskAutoMRModel, self)._build_models( 55 | names=names, *args, **kargs) 56 | 57 | def update_TaskSelector_Q_values(self, Q_score): 58 | 59 | self._TaskSelector.update( 60 | reward=Q_score, 61 | chosen_arm=self._current_task_index) 62 | 63 | # sample a new task to run 64 | self._current_task_index, _ = ( 65 | self._TaskSelector.sample(step=self.global_step)) 66 | 67 | # print info 68 | print("\n\n\n") 69 | print("New Q_score: %.3f" % Q_score) 70 | print("ChosenTask: %d" % self._current_task_index) 71 | for idx, val in enumerate(self._TaskSelector.arm_weights): 72 | print("%s/Expected_Q_Value: %.3f" 73 | % (self._task_selector_actions[idx], val)) 74 | print("\n\n\n") 75 | 76 | def _task_selector(self, step): 77 | # override parent method 78 | # step argument is kept for compatability 79 | return self._current_task_index 80 | 81 | def save_selector(self): 82 | # additionally save the selector 83 | selector_dir = os.path.join(self._logdir, "mab_selector.pkl") 84 | self._TaskSelector.save(selector_dir) 85 | 86 | def load_selector(self): 87 | # additionally restore the selector 88 | selector_dir = os.path.join(self._logdir, "mab_selector.pkl") 89 | # if not exist, skip this 90 | if os.path.exists(selector_dir): 91 | self._TaskSelector.load(selector_dir) 92 | 93 | def save_session(self): 94 | self.save_selector() 95 | return super(MultitaskAutoMRModel, self).save_session() 96 | 97 | 98 | def initialize_or_restore_session(self, *args, **kargs): 99 | self.load_selector() 100 | return super(MultitaskAutoMRModel, self).initialize_or_restore_session(*args, **kargs) 101 | -------------------------------------------------------------------------------- /multitask/multitask_base_model.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | from __future__ import absolute_import 3 | from __future__ import division 4 | 5 | import os 6 | import numpy as np 7 | import tensorflow as tf 8 | 9 | from utils import misc_utils 10 | from multitask import multitask_utils 11 | from multitask import sharing_dicts_utils 12 | from pointer_model.batcher import Batcher 13 | 14 | 15 | MIXING_RATIOS_BASE = 10 16 | tf.logging.set_verbosity(tf.logging.INFO) 17 | 18 | 19 | class MultitaskBaseModel(object): 20 | """Multitask Model""" 21 | def __init__(self, 22 | 
names, 23 | all_hparams, 24 | mixing_ratios, 25 | model_creators, 26 | logdir, 27 | soft_sharing_coef=None, 28 | data_generators=None, 29 | val_data_generator=None, 30 | *args, **kargs): 31 | # check lengths match 32 | if not len(names) == len(all_hparams): 33 | raise ValueError("names and all_hparams size mismatch") 34 | if not len(names) == len(model_creators): 35 | raise ValueError("names and model_creators size mismatch") 36 | if data_generators and not ( 37 | len(names) == len(data_generators) and 38 | isinstance(data_generators, MultitaskBatcher)): 39 | raise ValueError("names and data_generators shape mismatch or " 40 | "data_generators is not MultitaskBatcher") 41 | 42 | # check mixing ratios and MTL 43 | if len(names) == 1 and mixing_ratios is not None: 44 | raise ValueError("if running single model, set mixing_ratios None") 45 | if mixing_ratios is not None: 46 | if len(names) != len(mixing_ratios) + 1: 47 | raise ValueError("names and mixing_ratios + 1 size mismatch") 48 | checked_mixing_ratio = [ 49 | _assert_mixing_ratio_compatability(mr) for mr in mixing_ratios] 50 | print("With Base %d Scaled mixing batch ratios are " % 51 | MIXING_RATIOS_BASE, checked_mixing_ratio) 52 | 53 | if not soft_sharing_coef or soft_sharing_coef < 1e-6: 54 | raise ValueError("soft_sharing_coef too small") 55 | 56 | # misc check 57 | if not all([callable(mc) for mc in model_creators]): 58 | raise TypeError("Expect model_creator to be callable") 59 | misc_utils.assert_all_same(all_hparams, attr="batch_size") 60 | 61 | 62 | 63 | if len(names) == 1: 64 | sharing_dicts = [ 65 | sharing_dicts_utils.sharing_dict_soft] 66 | # won't be used anyway 67 | soft_sharing_params = [ 68 | sharing_dicts_utils.Layered_Shared_Params] 69 | else: 70 | sharing_dicts = [ 71 | sharing_dicts_utils.sharing_dict_soft, 72 | sharing_dicts_utils.sharing_dict_soft] 73 | soft_sharing_params = [ 74 | sharing_dicts_utils.Layered_Shared_Params, 75 | sharing_dicts_utils.E1D2_Shared_Params] 76 | 77 | # make sure sharing dictionaries are the same 78 | misc_utils.assert_all_same(sharing_dicts) 79 | if len(names) != 1: 80 | # sharing dicts and soft-sharing params for main models 81 | # in decode models or baseine, only one sharing_dict 82 | sharing_dicts = [sharing_dicts[0]] + sharing_dicts 83 | # main model's soft-sharing params should be 84 | # the union of two soft-sharing params 85 | # and only one in decode models 86 | soft_sharing_params = [misc_utils.union_lists( 87 | soft_sharing_params)] + soft_sharing_params 88 | 89 | # create MTL scopes 90 | MTL_scope = multitask_utils.MTLScope(names, sharing_dicts) 91 | 92 | # build models 93 | graph = tf.Graph() 94 | with graph.as_default(): 95 | # global step shared across all models 96 | global_step = tf.Variable( 97 | 0, name='global_step', trainable=False) 98 | models, steps = self._build_models( 99 | names=names, 100 | MTL_scope=MTL_scope, 101 | all_hparams=all_hparams, 102 | global_step=global_step, 103 | model_creators=model_creators, 104 | soft_sharing_coef=soft_sharing_coef, 105 | soft_sharing_params=soft_sharing_params, 106 | *args, **kargs) 107 | 108 | saver = tf.train.Saver(max_to_keep=20) 109 | 110 | save_path = None 111 | summary_dir = None 112 | summary_writer = None 113 | if logdir is not None: 114 | # e.g. 
model-113000.meta 115 | save_path = os.path.join(logdir, "model") 116 | summary_dir = os.path.join(logdir, "summaries") 117 | summary_writer = tf.summary.FileWriter(summary_dir) 118 | 119 | if not len(names) == len(models): 120 | raise ValueError("built `models` have mismatch shape, names") 121 | 122 | 123 | self._sess = None 124 | self._graph = graph 125 | self._steps = steps 126 | self._names = names 127 | self._models = models 128 | self._MTL_scope = MTL_scope 129 | self._all_hparams = all_hparams 130 | self._global_step = global_step 131 | self._data_generators = data_generators 132 | self._val_data_generator = val_data_generator 133 | 134 | self._mixing_ratios = mixing_ratios 135 | self._sharing_dicts = sharing_dicts 136 | self._soft_sharing_coef = soft_sharing_coef 137 | self._soft_sharing_params = soft_sharing_params 138 | 139 | self._saver = saver 140 | self._logdir = logdir 141 | self._save_path = save_path 142 | self._summary_dir = summary_dir 143 | self._summary_writer = summary_writer 144 | 145 | 146 | def _build_models(self, 147 | names, 148 | MTL_scope, 149 | all_hparams, 150 | global_step, 151 | model_creators, 152 | soft_sharing_coef, 153 | soft_sharing_params, 154 | vocab, 155 | # kept for compatability 156 | *args, **kargs): 157 | models = [] 158 | steps = {"GlobalStep": 0} 159 | for name, hparams, model_creator, soft_sharing_param in \ 160 | zip(names, all_hparams, model_creators, soft_sharing_params): 161 | 162 | print("Creating %s \t%s Model" % (model_creator.__name__, name)) 163 | # this returns a object with scopes as attributes 164 | scope = MTL_scope.get_scopes_object(name) 165 | model = model_creator( 166 | hparams, vocab, 167 | global_step=global_step, 168 | name=name, scope=scope, 169 | soft_sharing_coef=soft_sharing_coef, 170 | soft_sharing_params=soft_sharing_param) 171 | 172 | with tf.variable_scope(name): 173 | model.build_graph() 174 | models.append(model) 175 | steps[name] = 0 176 | 177 | # reuse variables 178 | # actually, not necessary 179 | MTL_scope.reuse_all_shared_variables() 180 | 181 | return models, steps 182 | 183 | 184 | 185 | def initialize_or_restore_session(self, ckpt_file=None): 186 | """Initialize or restore session 187 | 188 | Args: 189 | ckpt_file: directory to specific checkpoints 190 | """ 191 | # restore from lastest_checkpoint or specific file 192 | with self._graph.as_default(): 193 | self._sess = tf.Session( 194 | graph=self._graph, config=misc_utils.get_config()) 195 | self._sess.run(tf.global_variables_initializer()) 196 | 197 | if self._logdir or ckpt_file: 198 | # restore from lastest_checkpoint or specific file if provided 199 | misc_utils.load_ckpt(saver=self._saver, 200 | sess=self._sess, 201 | ckpt_dir=self._logdir, 202 | ckpt_file=ckpt_file) 203 | return 204 | 205 | 206 | def _run_train_step(self, batch, model_idx): 207 | # when running non-major task 208 | # the regularized model is the major task 209 | if model_idx != 0: 210 | reg_model_idx = 0 211 | 212 | # when running the auxiliary model 213 | # the soft-shared parameters should be those 214 | # that used between this pair model main-aux parameters 215 | # e.g. for SNLI vs. 
CNNDM , use SNLI's soft-params 216 | filtering_fn = lambda name: ( 217 | name in self._soft_sharing_params[model_idx]) 218 | 219 | else: # when model_idx == 0 220 | # for 3-way models 221 | # when running second or third model 222 | # reg_model is the first model 223 | # when running the first model 224 | # the reg_model is either second or third model 225 | reg_model_idx = 2 if self._steps[self._names[0]] % 2 else 1 226 | 227 | # when running the main model 228 | # the soft-shared parameters should be those 229 | # that used between this pair model main-aux parameters 230 | # e.g. for CNNDN vs. SNLI, use SNLI's soft-params 231 | filtering_fn = lambda name: ( 232 | name in self._soft_sharing_params[reg_model_idx]) 233 | 234 | 235 | return self._models[model_idx].run_train_step( 236 | sess=self._sess, batch=batch, 237 | reg_model_name=self._names[reg_model_idx], 238 | reg_filtering_fn=filtering_fn, 239 | all_scopes=self._MTL_scope.all_scopes) 240 | 241 | def run_train_step(self): 242 | model_idx = self._task_selector(self.global_step) 243 | model_name = self._names[model_idx] 244 | 245 | # get data batch 246 | data_batch = self._data_generators.next_batch(model_idx) 247 | # run one step 248 | train_step_info = self._run_train_step(data_batch, model_idx) 249 | # increment train step 250 | self._steps[model_name] += 1 251 | self._steps["GlobalStep"] += 1 252 | # print info and write summary 253 | # and return the loss for debug usages 254 | return self._log_train_step_info(train_step_info) 255 | 256 | def _log_train_step_info(self, train_step_info): 257 | # log statistics 258 | loss = train_step_info["loss"] 259 | summaries = train_step_info["summaries"] 260 | train_step = self._steps["GlobalStep"] 261 | self._summary_writer.add_summary(summaries, train_step) 262 | 263 | if train_step % 100 == 0: 264 | self._summary_writer.flush() 265 | 266 | if not np.isfinite(loss): 267 | self.save_session() 268 | raise Exception("Loss is not finite. 
Stopping.") 269 | 270 | # print statistics 271 | step_msg = "loss: %f step %d " % (loss, train_step) 272 | 273 | for key, val in self._steps.items(): 274 | step_msg += "%s %d " % (key, val) 275 | 276 | tf.logging.info(step_msg) 277 | return loss 278 | 279 | def run_eval_step(self, model_idx=0): 280 | # usually only use the main model 281 | model_name = self._names[model_idx] 282 | # get data batch 283 | val_data_batch = self._val_data_generator.next_batch(model_idx) 284 | # run one step 285 | val_step_info = self._models[model_idx].run_eval_step( 286 | sess=self._sess, batch=val_data_batch) 287 | # NLL loss not included 288 | val_nll_loss = val_step_info["nll_loss"] 289 | 290 | return val_nll_loss 291 | 292 | def _task_selector(self, step): 293 | if self._mixing_ratios is None: 294 | return 0 295 | 296 | return _task_selector_for_three( 297 | self._mixing_ratios[0], self._mixing_ratios[1], step) 298 | 299 | def save_session(self): 300 | self._saver.save(self._sess, 301 | save_path=self._save_path, 302 | global_step=self.global_step) 303 | 304 | @property 305 | def global_step(self): 306 | return self._steps["GlobalStep"] 307 | 308 | @property 309 | def sess(self): 310 | return self._sess 311 | 312 | @property 313 | def graph(self): 314 | return self._graph 315 | 316 | @property 317 | def logdir(self): 318 | return self._logdir 319 | 320 | def run_encoder(self, sess, batch, model_idx=0): 321 | return self._models[model_idx].run_encoder(sess, batch) 322 | 323 | def decode_onestep(self, 324 | sess, 325 | batch, 326 | latest_tokens, 327 | enc_states, 328 | dec_init_states, 329 | prev_coverage, 330 | model_idx=0): 331 | return self._models[model_idx].decode_onestep(sess, batch, 332 | latest_tokens, enc_states, dec_init_states, prev_coverage) 333 | 334 | 335 | class MultitaskBatcher(object): 336 | """Decorator for Batcher for multiple models""" 337 | def __init__(self, data_paths, vocabs, hps, single_pass): 338 | if not len(data_paths) == len(vocabs): 339 | raise ValueError("data_paths and vocabs size mismatch") 340 | 341 | batchers = [] 342 | for data_path, vocab in zip(data_paths, vocabs): 343 | batcher = Batcher(data_path, vocab, hps, single_pass) 344 | batchers.append(batcher) 345 | 346 | self._vocabs = vocabs 347 | self._batchers = batchers 348 | self._data_paths = data_paths 349 | 350 | def next_batch(self, batcher_idx=0): 351 | return self._batchers[batcher_idx].next_batch() 352 | 353 | def __len__(self): 354 | return len(self._batchers) 355 | 356 | 357 | # ================================================ 358 | # some utility functions 359 | # ================================================ 360 | def _task_selector_for_three(mixing_ratio_1, mixing_ratio_2, step): 361 | if mixing_ratio_1 <= 0.01: 362 | raise ValueError("mr_1 too small") 363 | 364 | if mixing_ratio_2 <= 0.01: 365 | raise ValueError("mr_2 too small") 366 | 367 | left_over = step % MIXING_RATIOS_BASE 368 | task_one_boundary = (MIXING_RATIOS_BASE - 369 | int(MIXING_RATIOS_BASE * (mixing_ratio_1 + mixing_ratio_2))) 370 | 371 | task_two_boundary = (MIXING_RATIOS_BASE - 372 | int(MIXING_RATIOS_BASE * mixing_ratio_2)) 373 | 374 | if 0 <= left_over < task_one_boundary: 375 | return 0 376 | elif task_one_boundary <= left_over < task_two_boundary: 377 | return 1 378 | else: 379 | return 2 380 | 381 | 382 | def _assert_mixing_ratio_compatability(mixing_ratio): 383 | if not isinstance(mixing_ratio, (int, float)): 384 | raise TypeError("%s should be int or float, found " 385 | % mixing_ratio, type(mixing_ratio)) 386 | result = 
mixing_ratio * MIXING_RATIOS_BASE 387 | if not int(result) == result: 388 | raise ValueError("%s x %s = %s are not integers" % 389 | (mixing_ratio, MIXING_RATIOS_BASE, result)) 390 | 391 | if result <= 0.01: 392 | raise ValueError("MixingRatio too small") 393 | 394 | return int(result) 395 | -------------------------------------------------------------------------------- /multitask/multitask_utils.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import tensorflow as tf 6 | SHARED_SCOPE_PREFIX = "Shared" 7 | SCOPE_LIST = ['WordEmb', 8 | 'EncoderFW', 9 | 'Projection', 10 | 'Attention', 11 | 'EncoderBW', 12 | 'Decoder', 13 | 'Pointer'] 14 | 15 | 16 | 17 | class MTLOutOfRangeError(tf.errors.OutOfRangeError): 18 | """Wraps tf.errors.OutOfRangeError with model_idx""" 19 | def __init__(self, node_def, op, message, model_idx): 20 | super(MTLOutOfRangeError, self).__init__( 21 | node_def=node_def, op=op, message=message) 22 | self.model_idx = model_idx 23 | 24 | 25 | class MTLScope(object): 26 | def __init__(self, 27 | model_names=None, 28 | sharing_dicts=None): 29 | if not isinstance(model_names, (list, tuple)): 30 | raise TypeError("model_names should be list of dicts") 31 | if not isinstance(sharing_dicts, (list, tuple)): 32 | raise TypeError("sharing_dicts should be list of dicts") 33 | if not len(model_names) == len(sharing_dicts): 34 | raise ValueError("model_names and sharing_dicts shape mismatch") 35 | 36 | # make sure variables that are shared used the same scope instance 37 | # by tracking previously created scopes 38 | 39 | # all_scopes keep track of all variable scopes, and store scopes 40 | # in the format e.g. { 41 | # ModelA_Enc_layer1: variable_scope, 42 | # ModelA_Enc_layer1: variable_scope 43 | # ModelB_Enc_layer2: ... } 44 | all_scopes = {} 45 | 46 | # model_scopes keep track of variables scopes for each model 47 | # in the format e.g. { 48 | # ModelA:{Encoder: [variable_scope, variable_scope ...] }, 49 | # ModelB: {...}} 50 | model_scopes = {} 51 | 52 | for var_name in SCOPE_LIST: 53 | for model_name, sharing_dict in zip(model_names, sharing_dicts): 54 | # set ModelName 55 | model_scopes.setdefault(model_name, 56 | {"Model": create_scope(model_name)}) 57 | 58 | # raise exception if the sharing dict is in wrong format 59 | if var_name not in sharing_dict.keys(): 60 | raise ValueError( 61 | "model %s variable %s not in sharing_dict" % 62 | (model_name, var_name)) 63 | 64 | # ensure will_share is a tuple of True or False 65 | if not is_sequence(sharing_dict[var_name]): 66 | this_is_sequence = False 67 | will_share_these_var = [sharing_dict[var_name]] 68 | else: 69 | this_is_sequence = True 70 | will_share_these_var = sharing_dict[var_name] 71 | 72 | for idx, will_shar_this_var in enumerate(will_share_these_var): 73 | if will_shar_this_var: 74 | # e.g. 
"Shared_Encoder_1" and "Shared_Encoder" 75 | (scope_name_without_idx, 76 | scope_name_maybe_with_idx) = ( 77 | get_scope_names_tuple(SHARED_SCOPE_PREFIX, 78 | var_name, idx, this_is_sequence)) 79 | 80 | # when sharing variables, get previou created ones 81 | # if scopes with the same same was created, otherwise 82 | # create a new one 83 | if scope_name_maybe_with_idx in all_scopes.keys(): 84 | # reuse previous shared scope 85 | scope = all_scopes[scope_name_maybe_with_idx] 86 | else: # create a new one 87 | scope = create_scope(scope_name_maybe_with_idx) 88 | all_scopes[scope_name_maybe_with_idx] = scope 89 | 90 | else: 91 | # e.g. "ModelA_Encoder_1" and "ModelA_Encoder" 92 | (scope_name_without_idx, 93 | scope_name_maybe_with_idx) = ( 94 | get_scope_names_tuple( 95 | model_name, var_name, idx, this_is_sequence)) 96 | 97 | # not sharing variables 98 | scope = create_scope(scope_name_maybe_with_idx) 99 | all_scopes[scope_name_maybe_with_idx] = scope 100 | 101 | # model_scopes are indexed by: 102 | # 1. model_name (dictionary) 103 | # 2. var_name (dictionary) 104 | # 3. index (list) 105 | if this_is_sequence: 106 | model_scopes[model_name].setdefault(var_name, []) 107 | model_scopes[model_name][var_name].append(scope) 108 | if model_scopes[model_name][var_name].index(scope) != idx: 109 | raise ValueError("Index messed up ", 110 | model_scopes[model_name][var_name]) 111 | else: 112 | model_scopes[model_name][var_name] = scope 113 | 114 | 115 | self.all_scopes = all_scopes 116 | self.model_scopes = model_scopes 117 | 118 | def print_scope_names(self): 119 | for model_name, scopes in self.model_scopes.items(): 120 | msg = "%s" % model_name 121 | for var, scope in scopes.items(): 122 | if is_sequence(scope): 123 | info = ["\t%s %s" % (s.name, s.reuse) for s in scope] 124 | info = "".join(info) 125 | else: 126 | info = "\t%s %s" % (scope.name, scope.reuse) 127 | 128 | msg += "\n%s:\t %s" % (var, info) 129 | print(msg + "\n") 130 | 131 | def reuse_all_shared_variables(self): 132 | for scope in self.all_scopes.values(): 133 | if SHARED_SCOPE_PREFIX in scope.name: 134 | scope.reuse_variables() 135 | 136 | def get_scopes_list(self): 137 | return list(self.model_scopes.values()) 138 | 139 | def get_scopes(self): 140 | raise Exception("Please switch to get_scopes_list") 141 | 142 | def get_scopes_object(self, model_name): 143 | if model_name not in self.model_scopes.keys(): 144 | raise ValueError( 145 | "model_name %s not in model_names" % model_name) 146 | 147 | return type("", (object,), self.model_scopes[model_name])() 148 | 149 | 150 | def create_scope(name): 151 | with tf.variable_scope(name) as scope: 152 | pass 153 | return scope 154 | 155 | 156 | def is_sequence(X): 157 | if isinstance(X, (list, tuple)): 158 | return True 159 | return False 160 | 161 | 162 | def get_scope_names_tuple(name_prefix, var_name, idx, this_is_sequence): 163 | scope_name_without_idx = "_".join([name_prefix, var_name]) 164 | scope_name_with_idx = "_".join([name_prefix, var_name, str(idx)]) 165 | 166 | if not this_is_sequence: 167 | return [scope_name_without_idx, scope_name_without_idx] 168 | else: 169 | return [scope_name_without_idx, scope_name_with_idx] -------------------------------------------------------------------------------- /multitask/sharing_dicts_utils.py: -------------------------------------------------------------------------------- 1 | sharing_dict_soft = { 2 | 'Attention': False, 3 | 'Decoder': [False, False], 4 | 'EncoderBW': [False, False], 5 | 'EncoderFW': [False, False], 6 | 'Pointer': False, 7 | 
'Projection': False, 8 | 'WordEmb': False} 9 | 10 | 11 | Attention_Params = [ 12 | 'Newsela_Attention/memory_kernel/kernel:0', 13 | 'Newsela_Attention/input_kernel/kernel:0', 14 | 'Newsela_Attention/query_kernel/kernel:0', 15 | # 'Newsela_Attention/coverage_kernel/kernel:0', 16 | 'Newsela_Attention/attention_v:0', 17 | 'Newsela_Attention/output_kernel/kernel:0', 18 | 19 | 'WikiSmall_Attention/memory_kernel/kernel:0', 20 | 'WikiSmall_Attention/input_kernel/kernel:0', 21 | 'WikiSmall_Attention/query_kernel/kernel:0', 22 | # 'WikiSmall_Attention/coverage_kernel/kernel:0', 23 | 'WikiSmall_Attention/attention_v:0', 24 | 'WikiSmall_Attention/output_kernel/kernel:0', 25 | 26 | 'WikiLarge_Attention/memory_kernel/kernel:0', 27 | 'WikiLarge_Attention/input_kernel/kernel:0', 28 | 'WikiLarge_Attention/query_kernel/kernel:0', 29 | # 'WikiLarge_Attention/coverage_kernel/kernel:0', 30 | 'WikiLarge_Attention/attention_v:0', 31 | 'WikiLarge_Attention/output_kernel/kernel:0', 32 | 33 | 'SNLI_Attention/memory_kernel/kernel:0', 34 | 'SNLI_Attention/input_kernel/kernel:0', 35 | 'SNLI_Attention/query_kernel/kernel:0', 36 | # 'SNLI_Attention/coverage_kernel/kernel:0', 37 | 'SNLI_Attention/attention_v:0', 38 | 'SNLI_Attention/output_kernel/kernel:0', 39 | 40 | 'PP_Attention/memory_kernel/kernel:0', 41 | 'PP_Attention/input_kernel/kernel:0', 42 | 'PP_Attention/query_kernel/kernel:0', 43 | # 'PP_Attention/coverage_kernel/kernel:0', 44 | 'PP_Attention/attention_v:0', 45 | 'PP_Attention/output_kernel/kernel:0'] 46 | 47 | 48 | Encoder_LowerLayer_Params = [ 49 | 'Newsela_EncoderBW_0/lstm_cell/kernel:0', 50 | 'Newsela_EncoderFW_0/lstm_cell/kernel:0', 51 | 'WikiSmall_EncoderBW_0/lstm_cell/kernel:0', 52 | 'WikiSmall_EncoderFW_0/lstm_cell/kernel:0', 53 | 'WikiLarge_EncoderBW_0/lstm_cell/kernel:0', 54 | 'WikiLarge_EncoderFW_0/lstm_cell/kernel:0', 55 | 'PP_EncoderBW_0/lstm_cell/kernel:0', 56 | 'PP_EncoderFW_0/lstm_cell/kernel:0', 57 | 'SNLI_EncoderBW_0/lstm_cell/kernel:0', 58 | 'SNLI_EncoderFW_0/lstm_cell/kernel:0'] 59 | 60 | 61 | Encoder_HigherLayer_Params = [ 62 | 'Newsela_EncoderBW_1/lstm_cell/kernel:0', 63 | 'Newsela_EncoderFW_1/lstm_cell/kernel:0', 64 | 'WikiSmall_EncoderBW_1/lstm_cell/kernel:0', 65 | 'WikiSmall_EncoderFW_1/lstm_cell/kernel:0', 66 | 'WikiLarge_EncoderBW_1/lstm_cell/kernel:0', 67 | 'WikiLarge_EncoderFW_1/lstm_cell/kernel:0', 68 | 'PP_EncoderBW_1/lstm_cell/kernel:0', 69 | 'PP_EncoderFW_1/lstm_cell/kernel:0', 70 | 'SNLI_EncoderBW_1/lstm_cell/kernel:0', 71 | 'SNLI_EncoderFW_1/lstm_cell/kernel:0'] 72 | 73 | Decoder_HigherLayer_Params = [ 74 | 'Newsela_Decoder_0/lstm_cell/kernel:0', 75 | 'WikiSmall_Decoder_0/lstm_cell/kernel:0', 76 | 'WikiLarge_Decoder_0/lstm_cell/kernel:0', 77 | 'PP_Decoder_0/lstm_cell/kernel:0', 78 | 'SNLI_Decoder_0/lstm_cell/kernel:0'] 79 | 80 | Decoder_LowerLayer_Params = [ 81 | 'Newsela_Decoder_1/lstm_cell/kernel:0', 82 | 'WikiSmall_Decoder_1/lstm_cell/kernel:0', 83 | 'WikiLarge_Decoder_1/lstm_cell/kernel:0', 84 | 'PP_Decoder_1/lstm_cell/kernel:0', 85 | 'SNLI_Decoder_1/lstm_cell/kernel:0'] 86 | 87 | 88 | Layered_Shared_Params = ( 89 | Attention_Params + 90 | Encoder_HigherLayer_Params + 91 | Decoder_HigherLayer_Params) 92 | 93 | E1D2_Shared_Params = ( 94 | Encoder_LowerLayer_Params + 95 | Decoder_LowerLayer_Params) 96 | -------------------------------------------------------------------------------- /multitask/soft_sharing_utils.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | from __future__ import 
print_function 3 | from __future__ import absolute_import 4 | 5 | import tensorflow as tf 6 | 7 | 8 | def _varname_to_plname(varname): 9 | # e.g. Shared_AttentionScope/Wa:0 10 | # should become Shared_AttentionScope_Wa_pl 11 | plname = varname.replace("/", "_") 12 | plname = plname.replace(":0", "") 13 | plname = "_".join([plname, "pl"]) 14 | return plname 15 | 16 | 17 | def get_regularzation_loss(filtering_fn, scope, coef=0.0001): 18 | """calculate regularization loss as soft-sharing constraint""" 19 | if not callable(filtering_fn): 20 | raise TypeError( 21 | "Expected `filtering_fn` to be callable, found ", 22 | type(filtering_fn)) 23 | 24 | with tf.variable_scope(scope.Model.name): 25 | shared_vars = [ 26 | v for v in tf.trainable_variables() if filtering_fn(v.name)] 27 | 28 | shared_var_pls = [ 29 | tf.placeholder(tf.float32, 30 | shape=shared_var.get_shape(), 31 | name=_varname_to_plname(shared_var.name)) 32 | for shared_var in shared_vars] 33 | 34 | 35 | if not len(shared_vars) == len(shared_var_pls): 36 | raise ValueError( 37 | "shared_vars and shared_var_pls have different lengths") 38 | 39 | # total regularization loss 40 | reg_loss = 0 41 | # mapping from placeholder name to var name 42 | reg_pl_names_dict = {} 43 | for var_pl, var in zip(shared_var_pls, shared_vars): 44 | reg_loss += tf.nn.l2_loss(var_pl - var) 45 | reg_pl_names_dict[var_pl.name] = var.name 46 | 47 | reg_loss = tf.multiply(coef, reg_loss) 48 | 49 | return reg_loss, reg_pl_names_dict 50 | 51 | 52 | def calc_regularization_loss(filtering_fn, 53 | reg_pl_names_dict, 54 | reg_model_name, 55 | feed_dict, 56 | sess, 57 | all_scopes=None): 58 | """Calculate regularization loss 59 | 60 | Args: 61 | filtering_fn: 62 | callable(reg_param_name) --> boolean 63 | whether to add regularization loss on this param 64 | if False, then reg_placeholder will be filled 65 | with same param effectively making regulaization loss 0 66 | reg_pl_names_dict: 67 | dictionary mapping placeholder_name to param_name 68 | reg_model_name: 69 | name of the model to be regularized w.r.t 70 | feed_dict: 71 | feed_dict to be used in sess.run 72 | all_scopes: 73 | all parameter scopes, used to check whether the 74 | new parameter name is valid 75 | 76 | """ 77 | if not callable(filtering_fn): 78 | raise TypeError("`filtering_fn` should be callable, found ", 79 | type(filtering_fn).__name__) 80 | 81 | for reg_pl_name, reg_param_name in reg_pl_names_dict.items(): 82 | # decide whether a parameter is to be softly shared 83 | # for those not softly shared at this timestep, 84 | # but still have a placeholder for parameters to be shared 85 | # we just let the parameters to be regularized bt itself, 86 | # or setting Loss = || param_i - param_i || 87 | # which effectively means Loss = 0 88 | # there are obviously better ways to approach this 89 | # e.g. 
using conditional graph, but this is a bit tricky 90 | # to implement and not much speed gain anyway 91 | if filtering_fn(reg_param_name): 92 | changed_reg_param_name = "_".join( 93 | [reg_model_name] + reg_param_name.split("_")[1:]) 94 | 95 | else: 96 | # this will make reg_loss == 0 97 | changed_reg_param_name = reg_param_name 98 | 99 | # just to make sure the new name is within scopes 100 | if (all_scopes and changed_reg_param_name.split("/")[0] 101 | not in all_scopes): 102 | raise ValueError("%s not in all scopes" 103 | % changed_reg_param_name.split("/")[0]) 104 | 105 | # add regularization terms into feed_dict 106 | feed_dict[reg_pl_name] = sess.run(changed_reg_param_name) 107 | 108 | return feed_dict 109 | -------------------------------------------------------------------------------- /pointer_model/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HanGuo97/MultitaskSimplification/2632e7bdb5fd53c32092468662fefd8ea6c1dc5d/pointer_model/__init__.py -------------------------------------------------------------------------------- /pointer_model/attention_decoder.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 | # Modifications Copyright 2017 Abigail See 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # ============================================================================== 16 | 17 | """This file defines the decoder 18 | https://github.com/abisee/pointer-generator""" 19 | 20 | import tensorflow as tf 21 | from tensorflow.python.ops import variable_scope 22 | from tensorflow.python.ops import array_ops 23 | from tensorflow.python.ops import nn_ops 24 | from tensorflow.python.ops import math_ops 25 | 26 | # Note: this function is based on tf.contrib.legacy_seq2seq_attention_decoder, which is now outdated. 27 | # In the future, it would make more sense to write variants on the 28 | # attention mechanism using the new seq2seq library for tensorflow 1.0: 29 | # https://www.tensorflow.org/api_guides/python/contrib.seq2seq#Attention 30 | 31 | 32 | def attention_decoder(scope, decoder_inputs, initial_state, encoder_states, enc_padding_mask, cell, initial_state_attention=False, pointer_gen=True, use_coverage=False, prev_coverage=None): 33 | """ 34 | Args: 35 | decoder_inputs: A list of 2D Tensors [batch_size x input_size]. 36 | initial_state: 2D Tensor [batch_size x cell.state_size]. 37 | encoder_states: 3D Tensor [batch_size x attn_length x attn_size]. 38 | enc_padding_mask: 2D Tensor [batch_size x attn_length] containing 1s and 0s; indicates which of the encoder locations are padding (0) or a real token (1). 39 | cell: rnn_cell.RNNCell defining the cell function and size. 40 | initial_state_attention: 41 | Note that this attention decoder passes each decoder input through a linear layer with the previous step's context vector to get a modified version of the input. 
If initial_state_attention is False, on the first decoder step the "previous context vector" is just a zero vector. If initial_state_attention is True, we use initial_state to (re)calculate the previous step's context vector. We set this to False for train/eval mode (because we call attention_decoder once for all decoder steps) and True for decode mode (because we call attention_decoder once for each decoder step). 42 | pointer_gen: boolean. If True, calculate the generation probability p_gen for each decoder step. 43 | use_coverage: boolean. If True, use coverage mechanism. 44 | prev_coverage: 45 | If not None, a tensor with shape (batch_size, attn_length). The previous step's coverage vector. This is only not None in decode mode when using coverage. 46 | 47 | Returns: 48 | outputs: A list of the same length as decoder_inputs of 2D Tensors of 49 | shape [batch_size x cell.output_size]. The output vectors. 50 | state: The final state of the decoder. A tensor shape [batch_size x cell.state_size]. 51 | attn_dists: A list containing tensors of shape (batch_size,attn_length). 52 | The attention distributions for each decoder step. 53 | p_gens: List of scalars. The values of p_gen for each decoder step. Empty list if pointer_gen=False. 54 | coverage: Coverage vector on the last step computed. None if use_coverage=False. 55 | """ 56 | with variable_scope.variable_scope(scope.Attention): 57 | # if this line fails, it's because the batch size isn't defined 58 | batch_size = encoder_states.get_shape()[0].value 59 | # if this line fails, it's because the attention length isn't defined 60 | attn_size = encoder_states.get_shape()[2].value 61 | 62 | # Reshape encoder_states (need to insert a dim) 63 | # now is shape (batch_size, attn_len, 1, attn_size) 64 | encoder_states = tf.expand_dims(encoder_states, axis=2) 65 | 66 | # To calculate attention, we calculate 67 | # v^T tanh(W_h h_i + W_s s_t + b_attn) 68 | # where h_i is an encoder state, and s_t a decoder state. 69 | # attn_vec_size is the length of the vectors v, b_attn, (W_h h_i) and (W_s s_t). 70 | # We set it to be equal to the size of the encoder states. 71 | attention_vec_size = attn_size 72 | 73 | # Get the weight matrix W_h and apply it to each encoder state to get 74 | # (W_h h_i), the encoder features 75 | W_h = variable_scope.get_variable( 76 | "memory_kernel", [1, 1, attn_size, attention_vec_size]) 77 | # shape (batch_size,attn_length,1,attention_vec_size) 78 | encoder_features = nn_ops.conv2d( 79 | encoder_states, W_h, [1, 1, 1, 1], "SAME") 80 | 81 | # Get the weight vectors v and w_c (w_c is for coverage) 82 | v = variable_scope.get_variable("attention_v", [attention_vec_size]) 83 | if use_coverage: 84 | with variable_scope.variable_scope(scope.Pointer): 85 | w_c = variable_scope.get_variable( 86 | "coverage_kernel", [1, 1, 1, attention_vec_size]) 87 | 88 | if prev_coverage is not None: # for beam search mode with coverage 89 | # reshape from (batch_size, attn_length) to (batch_size, attn_len, 90 | # 1, 1) 91 | prev_coverage = tf.expand_dims(tf.expand_dims(prev_coverage, 2), 3) 92 | 93 | def attention(decoder_state, coverage=None): 94 | """Calculate the context vector and attention distribution from the decoder state. 95 | 96 | Args: 97 | decoder_state: state of the decoder 98 | coverage: Optional. Previous timestep's coverage vector, shape (batch_size, attn_len, 1, 1). 99 | 100 | Returns: 101 | context_vector: weighted sum of encoder_states 102 | attn_dist: attention distribution 103 | coverage: new coverage vector. 
shape (batch_size, attn_len, 1, 1) 104 | """ 105 | 106 | # Pass the decoder state through a linear layer (this is W_s 107 | # s_t + b_attn in the paper) 108 | # shape (batch_size, attention_vec_size) 109 | decoder_features = linear( 110 | decoder_state, attention_vec_size, True, 111 | kernel_name="query_kernel", bias_name="query_bias") 112 | # reshape to (batch_size, 1, 1, attention_vec_size) 113 | decoder_features = tf.expand_dims( 114 | tf.expand_dims(decoder_features, 1), 1) 115 | 116 | def masked_attention(e): 117 | """Take softmax of e then apply enc_padding_mask and re-normalize""" 118 | attn_dist = nn_ops.softmax( 119 | e) # take softmax. shape (batch_size, attn_length) 120 | attn_dist *= enc_padding_mask # apply mask 121 | masked_sums = tf.reduce_sum( 122 | attn_dist, axis=1) # shape (batch_size) 123 | # re-normalize 124 | return attn_dist / tf.reshape(masked_sums, [-1, 1]) 125 | 126 | if use_coverage and coverage is not None: # non-first step of coverage 127 | # Multiply coverage vector by w_c to get coverage_features. 128 | # c has shape (batch_size, attn_length, 1, 129 | # attention_vec_size) 130 | coverage_features = nn_ops.conv2d( 131 | coverage, w_c, [1, 1, 1, 1], "SAME") 132 | 133 | # Calculate v^T tanh(W_h h_i + W_s s_t + w_c c_i^t + 134 | # b_attn) 135 | # shape (batch_size,attn_length) 136 | e = math_ops.reduce_sum( 137 | v * math_ops.tanh(encoder_features + decoder_features + coverage_features), [2, 3]) 138 | 139 | # Calculate attention distribution 140 | attn_dist = masked_attention(e) 141 | 142 | # Update coverage vector 143 | coverage += array_ops.reshape(attn_dist, 144 | [batch_size, -1, 1, 1]) 145 | else: 146 | # Calculate v^T tanh(W_h h_i + W_s s_t + b_attn) 147 | e = math_ops.reduce_sum( 148 | v * math_ops.tanh(encoder_features + decoder_features), [2, 3]) # calculate e 149 | 150 | # Calculate attention distribution 151 | attn_dist = masked_attention(e) 152 | 153 | if use_coverage: # first step of training 154 | coverage = tf.expand_dims(tf.expand_dims( 155 | attn_dist, 2), 2) # initialize coverage 156 | 157 | # Calculate the context vector from attn_dist and 158 | # encoder_states 159 | # shape (batch_size, attn_size). 160 | # print("attn_dist", attn_dist) 161 | # print("attn_dist reshaped", array_ops.reshape(attn_dist, [batch_size, -1, 1, 1])) 162 | # print("encoder_states", encoder_states) 163 | # print("attn_dist reshaped", math_ops.reduce_sum(array_ops.reshape( 164 | # attn_dist, [batch_size, -1, 1, 1]) * encoder_states, [1, 2])) 165 | 166 | context_vector = math_ops.reduce_sum(array_ops.reshape( 167 | attn_dist, [batch_size, -1, 1, 1]) * encoder_states, [1, 2]) 168 | context_vector = array_ops.reshape( 169 | context_vector, [-1, attn_size]) 170 | 171 | return context_vector, attn_dist, coverage 172 | 173 | outputs = [] 174 | attn_dists = [] 175 | p_gens = [] 176 | states = initial_state 177 | coverage = prev_coverage # initialize coverage to None or whatever was passed in 178 | context_vector = array_ops.zeros([batch_size, attn_size]) 179 | # Ensure the second shape of attention vectors is set. 
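For reference, the masked_attention helper above takes a softmax over the raw attention scores, zeroes out padded positions with enc_padding_mask, and re-normalizes so the distribution sums to 1 over real tokens only. A minimal NumPy sketch of that computation, with made-up scores and mask (not the model's real tensors):

```python
import numpy as np

# Made-up scores for one example over 4 encoder positions; the last one is padding.
e = np.array([[2.0, 1.0, 0.5, -1.0]])
enc_padding_mask = np.array([[1., 1., 1., 0.]])

attn_dist = np.exp(e) / np.exp(e).sum(axis=1, keepdims=True)  # plain softmax
attn_dist *= enc_padding_mask                                  # zero out the padded slot
attn_dist /= attn_dist.sum(axis=1, keepdims=True)              # re-normalize to sum to 1

print(attn_dist)  # a valid distribution over the three real tokens
```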
180 | context_vector.set_shape([None, attn_size]) 181 | if initial_state_attention: # true in decode mode 182 | # Re-calculate the context vector from the previous step so that we 183 | # can pass it through a linear layer with this step's input to get 184 | # a modified version of the input 185 | # in decode mode, this is what updates the coverage vector 186 | context_vector, _, coverage = attention(initial_state[-1].h, coverage) 187 | 188 | for i, inp in enumerate(decoder_inputs): 189 | # tf.logging.info( 190 | # "Adding attention_decoder timestep %i of %i", i, len(decoder_inputs)) 191 | if i > 0: 192 | [scope.Decoder[i].reuse_variables() 193 | for i in range(len(scope.Decoder))] 194 | scope.Attention.reuse_variables() 195 | scope.Pointer.reuse_variables() 196 | 197 | # Merge input and previous attentions into one vector x of the same 198 | # size as inp 199 | input_size = inp.get_shape().with_rank(2)[1] 200 | if input_size.value is None: 201 | raise ValueError( 202 | "Could not infer input size from input: %s" % inp.name) 203 | 204 | with variable_scope.variable_scope(scope.Attention, 205 | reuse=i > 0 or scope.Attention.reuse): 206 | x = linear([inp] + [context_vector], input_size, True, 207 | kernel_name="input_kernel", bias_name="input_bias") 208 | 209 | # Run the decoder RNN cell. cell_output = decoder state 210 | with variable_scope.variable_scope("decoder_scope"): 211 | cell_output, states = cell(x, states) 212 | state = states[-1] 213 | 214 | # Run the attention mechanism. 215 | with tf.variable_scope(scope.Attention, 216 | reuse=initial_state_attention or i > 0 or scope.Attention.reuse): 217 | if i == 0 and initial_state_attention: # always true in decode mode 218 | # you need this because you've already run the initial 219 | # attention(...) call 220 | 221 | context_vector, attn_dist, _ = attention( 222 | cell_output, coverage) # don't allow coverage to update 223 | else: 224 | context_vector, attn_dist, coverage = attention( 225 | cell_output, coverage) 226 | 227 | attn_dists.append(attn_dist) 228 | 229 | # Calculate p_gen 230 | if pointer_gen: 231 | with variable_scope.variable_scope(scope.Pointer, 232 | reuse=(i != 0) or scope.Pointer.reuse): 233 | p_gen = linear( 234 | [context_vector, state.c, state.h, x], 1, True, 235 | kernel_name="pgen_kernel", bias_name="pgen_bias") # a scalar 236 | p_gen = tf.sigmoid(p_gen) 237 | p_gens.append(p_gen) 238 | 239 | # Concatenate the cell_output (= decoder state) and the context vector, and pass them through a linear layer 240 | # This is V[s_t, h*_t] + b in the paper 241 | with variable_scope.variable_scope(scope.Attention, 242 | reuse=i > 0 or scope.Attention.reuse): 243 | output = linear([cell_output] + [context_vector], 244 | cell.output_size, True, 245 | kernel_name="output_kernel", bias_name="output_bias") 246 | outputs.append(output) 247 | 248 | # If using coverage, reshape it 249 | if coverage is not None: 250 | coverage = array_ops.reshape(coverage, [batch_size, -1]) 251 | 252 | return outputs, states, attn_dists, p_gens, coverage 253 | 254 | 255 | def linear(args, output_size, bias, bias_start=0.0, 256 | kernel_name="Matrix", bias_name="Bias"): 257 | """Linear map: sum_i(args[i] * W[i]), where W[i] is a variable. 258 | 259 | Args: 260 | args: a 2D Tensor or a list of 2D, batch x n, Tensors. 261 | output_size: int, second dimension of W[i]. 262 | bias: boolean, whether to add a bias term or not. 263 | bias_start: starting value to initialize the bias; 0 by default. 
264 | scope: VariableScope for the created subgraph; defaults to "Linear". 265 | 266 | Returns: 267 | A 2D Tensor with shape [batch x output_size] equal to 268 | sum_i(args[i] * W[i]), where W[i]s are newly created matrices. 269 | 270 | Raises: 271 | ValueError: if some of the arguments has unspecified or wrong shape. 272 | """ 273 | if args is None or (isinstance(args, (list, tuple)) and not args): 274 | raise ValueError("`args` must be specified") 275 | if not isinstance(args, (list, tuple)): 276 | args = [args] 277 | 278 | # Calculate the total size of arguments on dimension 1. 279 | total_arg_size = 0 280 | shapes = [a.get_shape().as_list() for a in args] 281 | for shape in shapes: 282 | if len(shape) != 2: 283 | raise ValueError( 284 | "Linear is expecting 2D arguments: %s" % str(shapes)) 285 | if not shape[1]: 286 | raise ValueError( 287 | "Linear expects shape[1] of arguments: %s" % str(shapes)) 288 | else: 289 | total_arg_size += shape[1] 290 | 291 | # Now the computation. 292 | matrix = tf.get_variable(kernel_name, [total_arg_size, output_size]) 293 | if len(args) == 1: 294 | res = tf.matmul(args[0], matrix) 295 | else: 296 | res = tf.matmul(tf.concat(axis=1, values=args), matrix) 297 | if not bias: 298 | return res 299 | bias_term = tf.get_variable( 300 | bias_name, [output_size], initializer=tf.constant_initializer(bias_start)) 301 | return res + bias_term 302 | -------------------------------------------------------------------------------- /pointer_model/attention_utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | https://github.com/tensorflow/tensorflow/blob/r1.5/tensorflow/contrib/seq2seq/python/ops/attention_wrapper.py 3 | """ 4 | from __future__ import absolute_import 5 | from __future__ import division 6 | from __future__ import print_function 7 | 8 | import collections 9 | import functools 10 | import math 11 | 12 | import numpy as np 13 | 14 | from tensorflow.python.framework import dtypes 15 | from tensorflow.python.framework import ops 16 | from tensorflow.python.framework import tensor_shape 17 | from tensorflow.python.layers import base as layers_base 18 | from tensorflow.python.layers import core as layers_core 19 | from tensorflow.python.ops import array_ops 20 | from tensorflow.python.ops import check_ops 21 | from tensorflow.python.ops import clip_ops 22 | from tensorflow.python.ops import functional_ops 23 | from tensorflow.python.ops import init_ops 24 | from tensorflow.python.ops import math_ops 25 | from tensorflow.python.ops import nn_ops 26 | from tensorflow.python.ops import random_ops 27 | from tensorflow.python.ops import rnn_cell_impl 28 | from tensorflow.python.ops import tensor_array_ops 29 | from tensorflow.python.ops import variable_scope 30 | from tensorflow.python.util import nest 31 | 32 | 33 | def _bahdanau_score(processed_query, keys, normalize): 34 | """Implements Bahdanau-style (additive) scoring function. 35 | This attention has two forms. The first is Bhandanau attention, 36 | as described in: 37 | Dzmitry Bahdanau, Kyunghyun Cho, Yoshua Bengio. 38 | "Neural Machine Translation by Jointly Learning to Align and Translate." 39 | ICLR 2015. https://arxiv.org/abs/1409.0473 40 | The second is the normalized form. This form is inspired by the 41 | weight normalization article: 42 | Tim Salimans, Diederik P. Kingma. 43 | "Weight Normalization: A Simple Reparameterization to Accelerate 44 | Training of Deep Neural Networks." 
45 | https://arxiv.org/abs/1602.07868 46 | To enable the second form, set `normalize=True`. 47 | Args: 48 | processed_query: Tensor, shape `[batch_size, num_units]` to compare to keys. 49 | keys: Processed memory, shape `[batch_size, max_time, num_units]`. 50 | normalize: Whether to normalize the score function. 51 | Returns: 52 | A `[batch_size, max_time]` tensor of unnormalized score values. 53 | """ 54 | dtype = processed_query.dtype 55 | # Get the number of hidden units from the trailing dimension of keys 56 | num_units = keys.shape[2].value or array_ops.shape(keys)[2] 57 | # Reshape from [batch_size, ...] to [batch_size, 1, ...] for broadcasting. 58 | processed_query = array_ops.expand_dims(processed_query, 1) 59 | v = variable_scope.get_variable( 60 | "attention_v", [num_units], dtype=dtype) 61 | if normalize: 62 | # Scalar used in weight normalization 63 | g = variable_scope.get_variable( 64 | "attention_g", dtype=dtype, 65 | initializer=math.sqrt((1. / num_units))) 66 | # Bias added prior to the nonlinearity 67 | b = variable_scope.get_variable( 68 | "attention_b", [num_units], dtype=dtype, 69 | initializer=init_ops.zeros_initializer()) 70 | # normed_v = g * v / ||v|| 71 | normed_v = g * v * math_ops.rsqrt( 72 | math_ops.reduce_sum(math_ops.square(v))) 73 | return math_ops.reduce_sum( 74 | normed_v * math_ops.tanh(keys + processed_query + b), [2]) 75 | else: 76 | return math_ops.reduce_sum(v * math_ops.tanh(keys + processed_query), [2]) 77 | -------------------------------------------------------------------------------- /pointer_model/batcher.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 | # Modifications Copyright 2017 Abigail See 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # ============================================================================== 16 | 17 | """This file contains code to process data into batches 18 | https://github.com/abisee/pointer-generator""" 19 | from builtins import range 20 | from random import shuffle 21 | from threading import Thread 22 | from pointer_model import data 23 | import time 24 | import numpy as np 25 | import tensorflow as tf 26 | from six.moves import queue as Queue 27 | 28 | 29 | class Example(object): 30 | """Class representing a train/val/test example for text summarization.""" 31 | 32 | def __init__(self, article, abstract_sentences, vocab, hps): 33 | """Initializes the Example, performing tokenization and truncation to produce the encoder, decoder and target sequences, which are stored in self. 34 | 35 | Args: 36 | article: source text; a string. each token is separated by a single space. 37 | abstract_sentences: list of strings, one per abstract sentence. In each sentence, each token is separated by a single space. 
38 | vocab: Vocabulary object 39 | hps: hyperparameters 40 | """ 41 | self.hps = hps 42 | 43 | # Get ids of special tokens 44 | start_decoding = vocab.word2id(data.START_DECODING) 45 | stop_decoding = vocab.word2id(data.STOP_DECODING) 46 | 47 | # Process the article 48 | article_words = article.split() 49 | if len(article_words) > hps.max_enc_steps: 50 | article_words = article_words[:hps.max_enc_steps] 51 | # store the length after truncation but before padding 52 | self.enc_len = len(article_words) 53 | # list of word ids; OOVs are represented by the id for UNK token 54 | self.enc_input = [vocab.word2id(w) for w in article_words] 55 | 56 | # Process the abstract 57 | abstract = ' '.join(abstract_sentences) # string 58 | abstract_words = abstract.split() # list of strings 59 | # list of word ids; OOVs are represented by the id for UNK token 60 | abs_ids = [vocab.word2id(w) for w in abstract_words] 61 | 62 | # Get the decoder input sequence and target sequence 63 | self.dec_input, self.target = self.get_dec_inp_targ_seqs( 64 | abs_ids, hps.max_dec_steps, start_decoding, stop_decoding) 65 | self.dec_len = len(self.dec_input) 66 | 67 | # If using pointer-generator mode, we need to store some extra info 68 | if hps.pointer_gen: 69 | # Store a version of the enc_input where in-article OOVs are 70 | # represented by their temporary OOV id; also store the in-article 71 | # OOVs words themselves 72 | self.enc_input_extend_vocab, self.article_oovs = data.article2ids( 73 | article_words, vocab) 74 | 75 | # Get a verison of the reference summary where in-article OOVs are 76 | # represented by their temporary article OOV id 77 | abs_ids_extend_vocab = data.abstract2ids( 78 | abstract_words, vocab, self.article_oovs) 79 | 80 | # Overwrite decoder target sequence so it uses the temp article OOV 81 | # ids 82 | _, self.target = self.get_dec_inp_targ_seqs( 83 | abs_ids_extend_vocab, hps.max_dec_steps, start_decoding, stop_decoding) 84 | 85 | # Store the original strings 86 | self.original_article = article 87 | self.original_abstract = abstract 88 | self.original_abstract_sents = abstract_sentences 89 | 90 | def get_dec_inp_targ_seqs(self, sequence, max_len, start_id, stop_id): 91 | """Given the reference summary as a sequence of tokens, return the input sequence for the decoder, and the target sequence which we will use to calculate loss. The sequence will be truncated if it is longer than max_len. The input sequence must start with the start_id and the target sequence must end with the stop_id (but not if it's been truncated). 
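For illustration, a tiny standalone function that mirrors the logic described in this docstring, with made-up token ids (START=2, STOP=3); the asserts show the untruncated and truncated cases:

```python
def dec_inp_targ(sequence, max_len, start_id=2, stop_id=3):
    # Mirrors get_dec_inp_targ_seqs above; start_id/stop_id values are made up.
    inp = [start_id] + sequence[:]
    target = sequence[:]
    if len(inp) > max_len:    # truncate; target gets no stop token
        inp, target = inp[:max_len], target[:max_len]
    else:                     # room left; target ends with the stop token
        target.append(stop_id)
    return inp, target

assert dec_inp_targ([5, 6, 7], max_len=5) == ([2, 5, 6, 7], [5, 6, 7, 3])
assert dec_inp_targ([5, 6, 7, 8, 9], max_len=4) == ([2, 5, 6, 7], [5, 6, 7, 8])
```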
92 | 93 | Args: 94 | sequence: List of ids (integers) 95 | max_len: integer 96 | start_id: integer 97 | stop_id: integer 98 | 99 | Returns: 100 | inp: sequence length <=max_len starting with start_id 101 | target: sequence same length as input, ending with stop_id only if there was no truncation 102 | """ 103 | inp = [start_id] + sequence[:] 104 | target = sequence[:] 105 | if len(inp) > max_len: # truncate 106 | inp = inp[:max_len] 107 | target = target[:max_len] # no end_token 108 | else: # no truncation 109 | target.append(stop_id) # end token 110 | assert len(inp) == len(target) 111 | return inp, target 112 | 113 | def pad_decoder_inp_targ(self, max_len, pad_id): 114 | """Pad decoder input and target sequences with pad_id up to max_len.""" 115 | while len(self.dec_input) < max_len: 116 | self.dec_input.append(pad_id) 117 | while len(self.target) < max_len: 118 | self.target.append(pad_id) 119 | 120 | def pad_encoder_input(self, max_len, pad_id): 121 | """Pad the encoder input sequence with pad_id up to max_len.""" 122 | while len(self.enc_input) < max_len: 123 | self.enc_input.append(pad_id) 124 | if self.hps.pointer_gen: 125 | while len(self.enc_input_extend_vocab) < max_len: 126 | self.enc_input_extend_vocab.append(pad_id) 127 | 128 | 129 | class Batch(object): 130 | """Class representing a minibatch of train/val/test examples for text summarization.""" 131 | 132 | def __init__(self, example_list, hps, vocab): 133 | """Turns the example_list into a Batch object. 134 | 135 | Args: 136 | example_list: List of Example objects 137 | hps: hyperparameters 138 | vocab: Vocabulary object 139 | """ 140 | self.pad_id = vocab.word2id( 141 | data.PAD_TOKEN) # id of the PAD token used to pad sequences 142 | # initialize the input to the encoder 143 | self.init_encoder_seq(example_list, hps) 144 | # initialize the input and targets for the decoder 145 | self.init_decoder_seq(example_list, hps) 146 | self.store_orig_strings(example_list) # store the original strings 147 | 148 | def init_encoder_seq(self, example_list, hps): 149 | """Initializes the following: 150 | self.enc_batch: 151 | numpy array of shape (batch_size, <=max_enc_steps) containing integer ids (all OOVs represented by UNK id), padded to length of longest sequence in the batch 152 | self.enc_lens: 153 | numpy array of shape (batch_size) containing integers. The (truncated) length of each encoder input sequence (pre-padding). 154 | self.enc_padding_mask: 155 | numpy array of shape (batch_size, <=max_enc_steps), containing 1s and 0s. 1s correspond to real tokens in enc_batch and target_batch; 0s correspond to padding. 156 | 157 | If hps.pointer_gen, additionally initializes the following: 158 | self.max_art_oovs: 159 | maximum number of in-article OOVs in the batch 160 | self.art_oovs: 161 | list of list of in-article OOVs (strings), for each example in the batch 162 | self.enc_batch_extend_vocab: 163 | Same as self.enc_batch, but in-article OOVs are represented by their temporary article OOV number. 164 | """ 165 | # Determine the maximum length of the encoder input sequence in this 166 | # batch 167 | max_enc_seq_len = max([ex.enc_len for ex in example_list]) 168 | 169 | # Pad the encoder input sequences up to the length of the longest 170 | # sequence 171 | for ex in example_list: 172 | ex.pad_encoder_input(max_enc_seq_len, self.pad_id) 173 | 174 | # Initialize the numpy arrays 175 | # Note: our enc_batch can have different length (second dimension) for 176 | # each batch because we use dynamic_rnn for the encoder. 
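As a concrete illustration of the arrays init_encoder_seq builds, here is a toy batch of two examples with encoder lengths 3 and 5; the pad id of 1 is an assumption (it matches the [PAD] id implied by the token order in data.py):

```python
import numpy as np

enc_lens = np.array([3, 5])                          # pre-padding lengths
enc_batch = np.array([[11, 12, 13,  1,  1],          # padded to the longest length in the batch
                      [21, 22, 23, 24, 25]])
enc_padding_mask = np.array([[1., 1., 1., 0., 0.],   # 1 = real token, 0 = padding
                             [1., 1., 1., 1., 1.]])
```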
177 | self.enc_batch = np.zeros( 178 | (hps.batch_size, max_enc_seq_len), dtype=np.int32) 179 | self.enc_lens = np.zeros((hps.batch_size), dtype=np.int32) 180 | self.enc_padding_mask = np.zeros( 181 | (hps.batch_size, max_enc_seq_len), dtype=np.float32) 182 | 183 | # Fill in the numpy arrays 184 | for i, ex in enumerate(example_list): 185 | self.enc_batch[i, :] = ex.enc_input[:] 186 | self.enc_lens[i] = ex.enc_len 187 | for j in range(ex.enc_len): 188 | self.enc_padding_mask[i][j] = 1 189 | 190 | # For pointer-generator mode, need to store some extra info 191 | if hps.pointer_gen: 192 | # Determine the max number of in-article OOVs in this batch 193 | self.max_art_oovs = max([len(ex.article_oovs) 194 | for ex in example_list]) 195 | # Store the in-article OOVs themselves 196 | self.art_oovs = [ex.article_oovs for ex in example_list] 197 | # Store the version of the enc_batch that uses the article OOV ids 198 | self.enc_batch_extend_vocab = np.zeros( 199 | (hps.batch_size, max_enc_seq_len), dtype=np.int32) 200 | for i, ex in enumerate(example_list): 201 | self.enc_batch_extend_vocab[ 202 | i, :] = ex.enc_input_extend_vocab[:] 203 | 204 | def init_decoder_seq(self, example_list, hps): 205 | """Initializes the following: 206 | self.dec_batch: 207 | numpy array of shape (batch_size, max_dec_steps), containing integer ids as input for the decoder, padded to max_dec_steps length. 208 | self.target_batch: 209 | numpy array of shape (batch_size, max_dec_steps), containing integer ids for the target sequence, padded to max_dec_steps length. 210 | self.dec_padding_mask: 211 | numpy array of shape (batch_size, max_dec_steps), containing 1s and 0s. 1s correspond to real tokens in dec_batch and target_batch; 0s correspond to padding. 212 | """ 213 | # Pad the inputs and targets 214 | for ex in example_list: 215 | ex.pad_decoder_inp_targ(hps.max_dec_steps, self.pad_id) 216 | 217 | # Initialize the numpy arrays. 218 | # Note: our decoder inputs and targets must be the same length for each 219 | # batch (second dimension = max_dec_steps) because we do not use a 220 | # dynamic_rnn for decoding. However I believe this is possible, or will 221 | # soon be possible, with Tensorflow 1.0, in which case it may be best 222 | # to upgrade to that. 223 | self.dec_batch = np.zeros( 224 | (hps.batch_size, hps.max_dec_steps), dtype=np.int32) 225 | self.target_batch = np.zeros( 226 | (hps.batch_size, hps.max_dec_steps), dtype=np.int32) 227 | self.dec_padding_mask = np.zeros( 228 | (hps.batch_size, hps.max_dec_steps), dtype=np.float32) 229 | 230 | # Fill in the numpy arrays 231 | for i, ex in enumerate(example_list): 232 | self.dec_batch[i, :] = ex.dec_input[:] 233 | self.target_batch[i, :] = ex.target[:] 234 | for j in range(ex.dec_len): 235 | self.dec_padding_mask[i][j] = 1 236 | 237 | def store_orig_strings(self, example_list): 238 | """Store the original article and abstract strings in the Batch object""" 239 | self.original_articles = [ 240 | ex.original_article for ex in example_list] # list of lists 241 | self.original_abstracts = [ 242 | ex.original_abstract for ex in example_list] # list of lists 243 | self.original_abstracts_sents = [ 244 | ex.original_abstract_sents for ex in example_list] # list of list of lists 245 | 246 | 247 | class Batcher(object): 248 | """A class to generate minibatches of data. 
Buckets examples together based on length of the encoder sequence.""" 249 | 250 | BATCH_QUEUE_MAX = 100 # max number of batches the batch_queue can hold 251 | 252 | def __init__(self, data_path, vocab, hps, single_pass): 253 | """Initialize the batcher. Start threads that process the data into batches. 254 | 255 | Args: 256 | data_path: tf.Example filepattern. 257 | vocab: Vocabulary object 258 | hps: hyperparameters 259 | single_pass: If True, run through the dataset exactly once (useful for when you want to run evaluation on the dev or test set). Otherwise generate random batches indefinitely (useful for training). 260 | """ 261 | self._data_path = data_path 262 | self._vocab = vocab 263 | self._hps = hps 264 | self._single_pass = single_pass 265 | 266 | # Initialize a queue of Batches waiting to be used, and a queue of 267 | # Examples waiting to be batched 268 | self._batch_queue = Queue.Queue(self.BATCH_QUEUE_MAX) 269 | self._example_queue = Queue.Queue( 270 | self.BATCH_QUEUE_MAX * self._hps.batch_size) 271 | 272 | # Different settings depending on whether we're in single_pass mode or 273 | # not 274 | if single_pass: 275 | # just one thread, so we read through the dataset just once 276 | self._num_example_q_threads = 1 277 | self._num_batch_q_threads = 1 # just one thread to batch examples 278 | # only load one batch's worth of examples before bucketing; this 279 | # essentially means no bucketing 280 | self._bucketing_cache_size = 1 281 | # this will tell us when we're finished reading the dataset 282 | self._finished_reading = False 283 | else: 284 | self._num_example_q_threads = 1 # num threads to fill example queue 285 | self._num_batch_q_threads = 1 # num threads to fill batch queue 286 | # how many batches-worth of examples to load into cache before 287 | # bucketing 288 | self._bucketing_cache_size = 100 289 | 290 | # Start the threads that load the queues 291 | self._example_q_threads = [] 292 | for _ in range(self._num_example_q_threads): 293 | self._example_q_threads.append( 294 | Thread(target=self.fill_example_queue)) 295 | self._example_q_threads[-1].daemon = True 296 | self._example_q_threads[-1].start() 297 | self._batch_q_threads = [] 298 | for _ in range(self._num_batch_q_threads): 299 | self._batch_q_threads.append(Thread(target=self.fill_batch_queue)) 300 | self._batch_q_threads[-1].daemon = True 301 | self._batch_q_threads[-1].start() 302 | 303 | # Start a thread that watches the other threads and restarts them if 304 | # they're dead 305 | if not single_pass: # We don't want a watcher in single_pass mode because the threads shouldn't run forever 306 | self._watch_thread = Thread(target=self.watch_threads) 307 | self._watch_thread.daemon = True 308 | self._watch_thread.start() 309 | 310 | def next_batch(self): 311 | """Return a Batch from the batch queue. 312 | 313 | If mode='decode' then each batch contains a single example repeated beam_size-many times; this is necessary for beam search. 314 | 315 | Returns: 316 | batch: a Batch object, or None if we're in single_pass mode and we've exhausted the dataset. 317 | """ 318 | # If the batch queue is empty, print a warning 319 | if self._batch_queue.qsize() == 0: 320 | tf.logging.warning('Bucket input queue is empty when calling next_batch. 
Bucket queue size: %i, Input queue size: %i', 321 | self._batch_queue.qsize(), self._example_queue.qsize()) 322 | if self._single_pass and self._finished_reading: 323 | tf.logging.info( 324 | "Finished reading dataset in single_pass mode.") 325 | return None 326 | 327 | batch = self._batch_queue.get() # get the next Batch 328 | return batch 329 | 330 | def fill_example_queue(self): 331 | """Reads data from file and processes into Examples which are then placed into the example queue.""" 332 | 333 | input_gen = self.text_generator( 334 | data.example_generator(self._data_path, self._single_pass)) 335 | 336 | while True: 337 | try: 338 | # read the next example from file. article and abstract are 339 | # both strings. 340 | try: 341 | (article, abstract) = input_gen.next() 342 | except AttributeError: 343 | (article, abstract) = next(input_gen) 344 | except StopIteration: # if there are no more examples: 345 | tf.logging.info( 346 | "The example generator for this example queue filling thread has exhausted data.") 347 | if self._single_pass: 348 | tf.logging.info( 349 | "single_pass mode is on, so we've finished reading dataset. This thread is stopping.") 350 | self._finished_reading = True 351 | break 352 | else: 353 | raise Exception( 354 | "single_pass mode is off but the example generator is out of data; error.") 355 | 356 | # Use the and tags in abstract to get a list of sentences. 357 | abstract_sentences = [sent.strip() 358 | for sent in data.abstract2sents(abstract)] 359 | # Process into an Example. 360 | example = Example(article, abstract_sentences, 361 | self._vocab, self._hps) 362 | # place the Example in the example queue. 363 | self._example_queue.put(example) 364 | 365 | def fill_batch_queue(self): 366 | """Takes Examples out of example queue, sorts them by encoder sequence length, processes into Batches and places them in the batch queue. 367 | 368 | In decode mode, makes batches that each contain a single example repeated. 369 | """ 370 | while True: 371 | if self._hps.mode != 'decode': 372 | # Get bucketing_cache_size-many batches of Examples into a 373 | # list, then sort 374 | inputs = [] 375 | for _ in range(self._hps.batch_size * self._bucketing_cache_size): 376 | inputs.append(self._example_queue.get()) 377 | # sort by length of encoder sequence 378 | inputs = sorted(inputs, key=lambda inp: inp.enc_len) 379 | 380 | # Group the sorted Examples into batches, optionally shuffle 381 | # the batches, and place in the batch queue. 382 | batches = [] 383 | for i in range(0, len(inputs), self._hps.batch_size): 384 | batches.append(inputs[i:i + self._hps.batch_size]) 385 | if not self._single_pass: 386 | shuffle(batches) 387 | for b in batches: # each b is a list of Example objects 388 | self._batch_queue.put(Batch(b, self._hps, self._vocab)) 389 | 390 | else: # beam search decode mode 391 | ex = self._example_queue.get() 392 | b = [ex for _ in range(self._hps.batch_size)] 393 | self._batch_queue.put(Batch(b, self._hps, self._vocab)) 394 | 395 | def watch_threads(self): 396 | """Watch example queue and batch queue threads and restart if dead.""" 397 | while True: 398 | time.sleep(60) 399 | for idx, t in enumerate(self._example_q_threads): 400 | if not t.is_alive(): # if the thread is dead 401 | tf.logging.error( 402 | 'Found example queue thread dead. 
Restarting.') 403 | new_t = Thread(target=self.fill_example_queue) 404 | self._example_q_threads[idx] = new_t 405 | new_t.daemon = True 406 | new_t.start() 407 | for idx, t in enumerate(self._batch_q_threads): 408 | if not t.is_alive(): # if the thread is dead 409 | tf.logging.error( 410 | 'Found batch queue thread dead. Restarting.') 411 | new_t = Thread(target=self.fill_batch_queue) 412 | self._batch_q_threads[idx] = new_t 413 | new_t.daemon = True 414 | new_t.start() 415 | 416 | def text_generator(self, example_generator): 417 | """Generates article and abstract text from tf.Example. 418 | 419 | Args: 420 | example_generator: a generator of tf.Examples from file. See data.example_generator""" 421 | while True: 422 | try: 423 | e = example_generator.next() # e is a tf.Example 424 | except AttributeError: 425 | e = next(example_generator) # e is a tf.Example 426 | try: 427 | # the article text was saved under the key 'article' in the 428 | # data files 429 | article_text = e.features.feature[ 430 | 'article'].bytes_list.value[0] 431 | # the abstract text was saved under the key 'abstract' in the 432 | # data files 433 | abstract_text = e.features.feature[ 434 | 'abstract'].bytes_list.value[0] 435 | except ValueError: 436 | tf.logging.error( 437 | 'Failed to get article or abstract from example') 438 | continue 439 | if len(article_text) == 0: # See https://github.com/abisee/pointer-generator/issues/1 440 | tf.logging.warning( 441 | 'Found an example with empty article text. Skipping it.') 442 | else: 443 | yield (article_text, abstract_text) 444 | -------------------------------------------------------------------------------- /pointer_model/beam_search.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 | # Modifications Copyright 2017 Abigail See 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # ============================================================================== 16 | 17 | """This file contains code to run beam search decoding 18 | https://github.com/abisee/pointer-generator""" 19 | 20 | import tensorflow as tf 21 | import numpy as np 22 | import data 23 | 24 | 25 | class Hypothesis(object): 26 | """Class to represent a hypothesis during beam search. Holds all the information needed for the hypothesis.""" 27 | 28 | def __init__(self, tokens, log_probs, state, attn_dists, p_gens, coverage): 29 | """Hypothesis constructor. 30 | 31 | Args: 32 | tokens: List of integers. The ids of the tokens that form the summary so far. 33 | log_probs: List, same length as tokens, of floats, giving the log probabilities of the tokens so far. 34 | state: Current state of the decoder, a LSTMStateTuple. 35 | attn_dists: List, same length as tokens, of numpy arrays with shape (attn_length). These are the attention distributions so far. 36 | p_gens: List, same length as tokens, of floats, or None if not using pointer-generator model. 
The values of the generation probability so far. 37 | coverage: Numpy array of shape (attn_length), or None if not using coverage. The current coverage vector. 38 | """ 39 | self.tokens = tokens 40 | self.log_probs = log_probs 41 | self.state = state 42 | self.attn_dists = attn_dists 43 | self.p_gens = p_gens 44 | self.coverage = coverage 45 | 46 | def extend(self, token, log_prob, state, attn_dist, p_gen, coverage): 47 | """Return a NEW hypothesis, extended with the information from the latest step of beam search. 48 | 49 | Args: 50 | token: Integer. Latest token produced by beam search. 51 | log_prob: Float. Log prob of the latest token. 52 | state: Current decoder state, a LSTMStateTuple. 53 | attn_dist: Attention distribution from latest step. Numpy array shape (attn_length). 54 | p_gen: Generation probability on latest step. Float. 55 | coverage: Latest coverage vector. Numpy array shape (attn_length), or None if not using coverage. 56 | Returns: 57 | New Hypothesis for next step. 58 | """ 59 | return Hypothesis(tokens=self.tokens + [token], 60 | log_probs=self.log_probs + [log_prob], 61 | state=state, 62 | attn_dists=self.attn_dists + [attn_dist], 63 | p_gens=self.p_gens + [p_gen], 64 | coverage=coverage) 65 | 66 | @property 67 | def latest_token(self): 68 | return self.tokens[-1] 69 | 70 | @property 71 | def log_prob(self): 72 | # the log probability of the hypothesis so far is the sum of the log 73 | # probabilities of the tokens so far 74 | return sum(self.log_probs) 75 | 76 | @property 77 | def avg_log_prob(self): 78 | # normalize log probability by number of tokens (otherwise longer 79 | # sequences always have lower probability) 80 | return self.log_prob / len(self.tokens) 81 | 82 | 83 | def run_beam_search(sess, model, vocab, batch, FLAGS): 84 | """Performs beam search decoding on the given example. 85 | 86 | Args: 87 | sess: a tf.Session 88 | model: a seq2seq model 89 | vocab: Vocabulary object 90 | batch: Batch object that is the same example repeated across the batch 91 | 92 | Returns: 93 | best_hyp: Hypothesis object; the best hypothesis found by beam search. 94 | """ 95 | # Run the encoder to get the encoder hidden states and decoder initial 96 | # state 97 | enc_states, dec_in_state = model.run_encoder(sess, batch) 98 | # dec_in_state is a LSTMStateTuple 99 | # enc_states has shape [batch_size, <=max_enc_steps, 2*hidden_dim]. 
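A quick illustration of why hypotheses are ranked by avg_log_prob rather than by total log probability: without length normalization, longer hypotheses would always score below shorter ones. Toy numbers:

```python
short_hyp = [-0.5, -0.6]              # 2 tokens, total log prob = -1.1
long_hyp = [-0.5, -0.6, -0.4, -0.3]   # 4 tokens, total log prob = -1.8

print(sum(short_hyp), sum(long_hyp))           # -1.1 vs -1.8: the short hypothesis wins on totals
print(sum(short_hyp) / len(short_hyp),         # -0.55
      sum(long_hyp) / len(long_hyp))           # -0.45: the long hypothesis wins after normalization
```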
100 | 101 | # Initialize beam_size-many hyptheses 102 | hyps = [Hypothesis(tokens=[vocab.word2id(data.START_DECODING)], 103 | log_probs=[0.0], 104 | state=dec_in_state, 105 | attn_dists=[], 106 | p_gens=[], 107 | # zero vector of length attention_length 108 | coverage=np.zeros([batch.enc_batch.shape[1]]) 109 | ) for _ in xrange(FLAGS.beam_size)] 110 | # this will contain finished hypotheses (those that have emitted the 111 | # [STOP] token) 112 | results = [] 113 | 114 | steps = 0 115 | while steps < FLAGS.max_dec_steps and len(results) < FLAGS.beam_size: 116 | # latest token produced by each hypothesis 117 | latest_tokens = [h.latest_token for h in hyps] 118 | # change any in-article temporary OOV ids to [UNK] id, so that we can 119 | # lookup word embeddings 120 | latest_tokens = [t if t in xrange(vocab.size()) else vocab.word2id( 121 | data.UNKNOWN_TOKEN) for t in latest_tokens] 122 | # list of current decoder states of the hypotheses 123 | states = [h.state for h in hyps] 124 | # list of coverage vectors (or None) 125 | prev_coverage = [h.coverage for h in hyps] 126 | 127 | # Run one step of the decoder to get the new info 128 | (topk_ids, topk_log_probs, new_states, attn_dists, p_gens, new_coverage) = model.decode_onestep(sess=sess, 129 | batch=batch, 130 | latest_tokens=latest_tokens, 131 | enc_states=enc_states, 132 | dec_init_states=states, 133 | prev_coverage=prev_coverage) 134 | 135 | # Extend each hypothesis and collect them all in all_hyps 136 | all_hyps = [] 137 | # On the first step, we only had one original hypothesis (the initial 138 | # hypothesis). On subsequent steps, all original hypotheses are 139 | # distinct. 140 | num_orig_hyps = 1 if steps == 0 else len(hyps) 141 | for i in xrange(num_orig_hyps): 142 | h, new_state, attn_dist, p_gen, new_coverage_i = hyps[i], new_states[i], attn_dists[ 143 | i], p_gens[i], new_coverage[i] # take the ith hypothesis and new decoder state info 144 | # for each of the top 2*beam_size hyps: 145 | for j in xrange(FLAGS.beam_size * 2): 146 | # Extend the ith hypothesis with the jth option 147 | new_hyp = h.extend(token=topk_ids[i, j], 148 | log_prob=topk_log_probs[i, j], 149 | state=new_state, 150 | attn_dist=attn_dist, 151 | p_gen=p_gen, 152 | coverage=new_coverage_i) 153 | all_hyps.append(new_hyp) 154 | 155 | # Filter and collect any hypotheses that have produced the end token. 156 | hyps = [] # will contain hypotheses for the next step 157 | for h in sort_hyps(all_hyps): # in order of most likely h 158 | # if stop token is reached... 159 | if h.latest_token == vocab.word2id(data.STOP_DECODING): 160 | # If this hypothesis is sufficiently long, put in results. 161 | # Otherwise discard. 162 | if steps >= FLAGS.min_dec_steps: 163 | results.append(h) 164 | else: # hasn't reached stop token, so continue to extend this hypothesis 165 | hyps.append(h) 166 | if len(hyps) == FLAGS.beam_size or len(results) == FLAGS.beam_size: 167 | # Once we've collected beam_size-many hypotheses for the next 168 | # step, or beam_size-many complete hypotheses, stop. 
169 | break 170 | 171 | steps += 1 172 | 173 | # At this point, either we've got beam_size results, or we've reached 174 | # maximum decoder steps 175 | 176 | if len(results) == 0: # if we don't have any complete results, add all current hypotheses (incomplete summaries) to results 177 | results = hyps 178 | 179 | # Sort hypotheses by average log probability 180 | hyps_sorted = sort_hyps(results) 181 | 182 | # Return the hypothesis with highest average log prob 183 | return hyps_sorted[0] 184 | 185 | 186 | def sort_hyps(hyps): 187 | """Return a list of Hypothesis objects, sorted by descending average log probability""" 188 | return sorted(hyps, key=lambda h: h.avg_log_prob, reverse=True) 189 | -------------------------------------------------------------------------------- /pointer_model/data.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 | # Modifications Copyright 2017 Abigail See 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # ============================================================================== 16 | 17 | """This file contains code to read the train/eval/test data from file and process it, and read the vocab data from file and process it 18 | https://github.com/abisee/pointer-generator""" 19 | from __future__ import print_function 20 | 21 | import glob 22 | import random 23 | import struct 24 | import csv 25 | from tensorflow.core.example import example_pb2 26 | 27 | # <s> and </s> are used in the data files to segment the abstracts into sentences. They don't receive vocab ids. 28 | SENTENCE_START = '<s>' 29 | SENTENCE_END = '</s>' 30 | 31 | PAD_TOKEN = '[PAD]' # This has a vocab id, which is used to pad the encoder input, decoder input and target sequence 32 | UNKNOWN_TOKEN = '[UNK]' # This has a vocab id, which is used to represent out-of-vocabulary words 33 | START_DECODING = '[START]' # This has a vocab id, which is used at the start of every decoder input sequence 34 | STOP_DECODING = '[STOP]' # This has a vocab id, which is used at the end of untruncated target sequences 35 | 36 | # Note: none of <s>, </s>, [PAD], [UNK], [START], [STOP] should appear in the vocab file. 37 | 38 | 39 | class Vocab(object): 40 | """Vocabulary class for mapping between words and ids (integers)""" 41 | 42 | def __init__(self, vocab_file, max_size): 43 | """Creates a vocab of up to max_size words, reading from the vocab_file. If max_size is 0, reads the entire vocab file. 44 | 45 | Args: 46 | vocab_file: path to the vocab file, which is assumed to contain "<word> <frequency>" on each line, sorted with most frequent word first. This code doesn't actually use the frequencies, though. 47 | max_size: integer. The maximum size of the resulting Vocabulary.""" 48 | self._word_to_id = {} 49 | self._id_to_word = {} 50 | self._count = 0 # keeps track of total number of words in the Vocab 51 | 52 | # [UNK], [PAD], [START] and [STOP] get the ids 0,1,2,3. 
53 | for w in [UNKNOWN_TOKEN, PAD_TOKEN, START_DECODING, STOP_DECODING]: 54 | self._word_to_id[w] = self._count 55 | self._id_to_word[self._count] = w 56 | self._count += 1 57 | 58 | # Read the vocab file and add words up to max_size 59 | with open(vocab_file, 'r') as vocab_f: 60 | for line in vocab_f: 61 | pieces = line.split() 62 | if len(pieces) != 2: 63 | print( 'Warning: incorrectly formatted line in vocabulary file: %s\n' % line) 64 | continue 65 | w = pieces[0] 66 | if w in [SENTENCE_START, SENTENCE_END, UNKNOWN_TOKEN, PAD_TOKEN, START_DECODING, STOP_DECODING]: 67 | raise Exception('<s>, </s>, [UNK], [PAD], [START] and [STOP] shouldn\'t be in the vocab file, but %s is' % w) 68 | if w in self._word_to_id: 69 | raise Exception('Duplicated word in vocabulary file: %s' % w) 70 | self._word_to_id[w] = self._count 71 | self._id_to_word[self._count] = w 72 | self._count += 1 73 | if max_size != 0 and self._count >= max_size: 74 | print( "max_size of vocab was specified as %i; we now have %i words. Stopping reading." % (max_size, self._count)) 75 | break 76 | 77 | print( "Finished constructing vocabulary of %i total words. Last word added: %s" % (self._count, self._id_to_word[self._count-1])) 78 | 79 | def word2id(self, word): 80 | """Returns the id (integer) of a word (string). Returns [UNK] id if word is OOV.""" 81 | if word not in self._word_to_id: 82 | return self._word_to_id[UNKNOWN_TOKEN] 83 | return self._word_to_id[word] 84 | 85 | def id2word(self, word_id): 86 | """Returns the word (string) corresponding to an id (integer).""" 87 | if word_id not in self._id_to_word: 88 | raise ValueError('Id not found in vocab: %d' % word_id) 89 | return self._id_to_word[word_id] 90 | 91 | def size(self): 92 | """Returns the total size of the vocabulary""" 93 | return self._count 94 | 95 | def write_metadata(self, fpath): 96 | """Writes metadata file for Tensorboard word embedding visualizer as described here: 97 | https://www.tensorflow.org/get_started/embedding_viz 98 | 99 | Args: 100 | fpath: place to write the metadata file 101 | """ 102 | print( "Writing word embedding metadata file to %s..." % (fpath)) 103 | with open(fpath, "w") as f: 104 | fieldnames = ['word'] 105 | writer = csv.DictWriter(f, delimiter="\t", fieldnames=fieldnames) 106 | for i in xrange(self.size()): 107 | writer.writerow({"word": self._id_to_word[i]}) 108 | 109 | 110 | def example_generator(data_path, single_pass): 111 | """Generates tf.Examples from data files. 112 | 113 | Binary data format: <length><blob>. <length> represents the byte size 114 | of <blob>. <blob> is serialized tf.Example proto. The tf.Example contains 115 | the tokenized article text and summary. 116 | 117 | Args: 118 | data_path: 119 | Path to tf.Example data files. Can include wildcards, e.g. if you have several training data chunk files train_001.bin, train_002.bin, etc, then pass data_path=train_* to access them all. 120 | single_pass: 121 | Boolean. If True, go through the dataset exactly once, generating examples in the order they appear, then return. Otherwise, generate random examples indefinitely. 122 | 123 | Yields: 124 | Deserialized tf.Example. 
125 | """ 126 | while True: 127 | filelist = glob.glob(data_path) # get the list of datafiles 128 | assert filelist, ('Error: Empty filelist at %s' % data_path) # check filelist isn't empty 129 | if single_pass: 130 | filelist = sorted(filelist) 131 | else: 132 | random.shuffle(filelist) 133 | for f in filelist: 134 | reader = open(f, 'rb') 135 | while True: 136 | len_bytes = reader.read(8) 137 | if not len_bytes: break # finished reading this file 138 | str_len = struct.unpack('q', len_bytes)[0] 139 | example_str = struct.unpack('%ds' % str_len, reader.read(str_len))[0] 140 | yield example_pb2.Example.FromString(example_str) 141 | if single_pass: 142 | print( "example_generator completed reading all datafiles. No more data.") 143 | break 144 | 145 | 146 | def article2ids(article_words, vocab): 147 | """Map the article words to their ids. Also return a list of OOVs in the article. 148 | 149 | Args: 150 | article_words: list of words (strings) 151 | vocab: Vocabulary object 152 | 153 | Returns: 154 | ids: 155 | A list of word ids (integers); OOVs are represented by their temporary article OOV number. If the vocabulary size is 50k and the article has 3 OOVs, then these temporary OOV numbers will be 50000, 50001, 50002. 156 | oovs: 157 | A list of the OOV words in the article (strings), in the order corresponding to their temporary article OOV numbers.""" 158 | ids = [] 159 | oovs = [] 160 | unk_id = vocab.word2id(UNKNOWN_TOKEN) 161 | for w in article_words: 162 | i = vocab.word2id(w) 163 | if i == unk_id: # If w is OOV 164 | if w not in oovs: # Add to list of OOVs 165 | oovs.append(w) 166 | oov_num = oovs.index(w) # This is 0 for the first article OOV, 1 for the second article OOV... 167 | ids.append(vocab.size() + oov_num) # This is e.g. 50000 for the first article OOV, 50001 for the second... 168 | else: 169 | ids.append(i) 170 | return ids, oovs 171 | 172 | 173 | def abstract2ids(abstract_words, vocab, article_oovs): 174 | """Map the abstract words to their ids. In-article OOVs are mapped to their temporary OOV numbers. 175 | 176 | Args: 177 | abstract_words: list of words (strings) 178 | vocab: Vocabulary object 179 | article_oovs: list of in-article OOV words (strings), in the order corresponding to their temporary article OOV numbers 180 | 181 | Returns: 182 | ids: List of ids (integers). In-article OOV words are mapped to their temporary OOV numbers. Out-of-article OOV words are mapped to the UNK token id.""" 183 | ids = [] 184 | unk_id = vocab.word2id(UNKNOWN_TOKEN) 185 | for w in abstract_words: 186 | i = vocab.word2id(w) 187 | if i == unk_id: # If w is an OOV word 188 | if w in article_oovs: # If w is an in-article OOV 189 | vocab_idx = vocab.size() + article_oovs.index(w) # Map to its temporary article OOV number 190 | ids.append(vocab_idx) 191 | else: # If w is an out-of-article OOV 192 | ids.append(unk_id) # Map to the UNK token id 193 | else: 194 | ids.append(i) 195 | return ids 196 | 197 | 198 | def outputids2words(id_list, vocab, article_oovs): 199 | """Maps output ids to words, including mapping in-article OOVs from their temporary ids to the original OOV string (applicable in pointer-generator mode). 
200 | 201 | Args: 202 | id_list: list of ids (integers) 203 | vocab: Vocabulary object 204 | article_oovs: list of OOV words (strings) in the order corresponding to their temporary article OOV ids (that have been assigned in pointer-generator mode), or None (in baseline mode) 205 | 206 | Returns: 207 | words: list of words (strings) 208 | """ 209 | words = [] 210 | for i in id_list: 211 | try: 212 | w = vocab.id2word(i) # might be [UNK] 213 | except ValueError as e: # w is OOV 214 | assert article_oovs is not None, "Error: model produced a word ID that isn't in the vocabulary. This should not happen in baseline (no pointer-generator) mode" 215 | article_oov_idx = i - vocab.size() 216 | try: 217 | w = article_oovs[article_oov_idx] 218 | except IndexError as e: # i doesn't correspond to an article oov 219 | raise ValueError('Error: model produced word ID %i which corresponds to article OOV %i but this example only has %i article OOVs' % (i, article_oov_idx, len(article_oovs))) 220 | words.append(w) 221 | return words 222 | 223 | 224 | def abstract2sents(abstract): 225 | """Splits abstract text from datafile into list of sentences. 226 | 227 | Args: 228 | abstract: string containing <s> and </s> tags for starts and ends of sentences 229 | 230 | Returns: 231 | sents: List of sentence strings (no tags)""" 232 | cur = 0 233 | sents = [] 234 | while True: 235 | try: 236 | start_p = abstract.index(SENTENCE_START, cur) 237 | end_p = abstract.index(SENTENCE_END, start_p + 1) 238 | cur = end_p + len(SENTENCE_END) 239 | sents.append(abstract[start_p+len(SENTENCE_START):end_p]) 240 | except ValueError as e: # no more sentences 241 | return sents 242 | 243 | 244 | def show_art_oovs(article, vocab): 245 | """Returns the article string, highlighting the OOVs by placing __underscores__ around them""" 246 | unk_token = vocab.word2id(UNKNOWN_TOKEN) 247 | words = article.split(' ') 248 | words = [("__%s__" % w) if vocab.word2id(w)==unk_token else w for w in words] 249 | out_str = ' '.join(words) 250 | return out_str 251 | 252 | 253 | def show_abs_oovs(abstract, vocab, article_oovs): 254 | """Returns the abstract string, highlighting the article OOVs with __underscores__. 255 | 256 | If a list of article_oovs is provided, non-article OOVs are differentiated like !!__this__!!. 257 | 258 | Args: 259 | abstract: string 260 | vocab: Vocabulary object 261 | article_oovs: list of words (strings), or None (in baseline mode) 262 | """ 263 | unk_token = vocab.word2id(UNKNOWN_TOKEN) 264 | words = abstract.split(' ') 265 | new_words = [] 266 | for w in words: 267 | if vocab.word2id(w) == unk_token: # w is oov 268 | if article_oovs is None: # baseline mode 269 | new_words.append("__%s__" % w) 270 | else: # pointer-generator mode 271 | if w in article_oovs: 272 | new_words.append("__%s__" % w) 273 | else: 274 | new_words.append("!!__%s__!!" % w) 275 | else: # w is in-vocab word 276 | new_words.append(w) 277 | out_str = ' '.join(new_words) 278 | return out_str 279 | -------------------------------------------------------------------------------- /pointer_model/decode.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 | # Modifications Copyright 2017 Abigail See 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # ============================================================================== 16 | 17 | """https://github.com/abisee/pointer-generator""" 18 | import os 19 | import tensorflow as tf 20 | import data 21 | import pyrouge 22 | import logging 23 | import warnings 24 | from pointer_model import beam_search 25 | 26 | from utils import misc_utils 27 | 28 | SECS_UNTIL_NEW_CKPT = 60 # max number of seconds before loading new checkpoint 29 | 30 | 31 | class BeamSearchDecoder(object): 32 | """Beam search decoder.""" 33 | 34 | def __init__(self, model, batcher, vocab, ckpt_dir, decode_dir, FLAGS): 35 | """Initialize decoder. 36 | 37 | Args: 38 | model: a Seq2SeqAttentionModel object. 39 | batcher: a Batcher object. 40 | vocab: Vocabulary object 41 | ckpt_dir: directory to checkpoints 42 | decode_dir: directory to decoding outputs 43 | """ 44 | self._model = model 45 | self._batcher = batcher 46 | self._vocab = vocab 47 | self._ckpt_dir = ckpt_dir 48 | self._decode_dir = decode_dir 49 | self._FLAGS = FLAGS 50 | 51 | 52 | def reset_batcher(self, batcher): 53 | self._batcher = batcher 54 | 55 | def build_graph(self, sess): 56 | self._sess = sess 57 | self._saver = tf.train.Saver() 58 | 59 | def decode(self, ckpt_file=None): 60 | FLAGS = self._FLAGS 61 | 62 | # load latest checkpoint 63 | misc_utils.load_ckpt(self._saver, self._sess, self._ckpt_dir, ckpt_file) 64 | 65 | counter = 0 66 | f = open(self._decode_dir, "w") 67 | while True: 68 | batch = self._batcher.next_batch() # 1 example repeated across batch 69 | if batch is None: # finished decoding dataset in single_pass mode 70 | tf.logging.info( 71 | "Decoder has finished reading dataset for single_pass.") 72 | break 73 | 74 | original_article = batch.original_articles[0] # string 75 | original_abstract = batch.original_abstracts[0] # string 76 | original_abstract_sents = batch.original_abstracts_sents[ 77 | 0] # list of strings 78 | 79 | article_withunks = data.show_art_oovs( 80 | original_article, self._vocab) # string 81 | abstract_withunks = data.show_abs_oovs(original_abstract, self._vocab, (batch.art_oovs[ 82 | 0] if FLAGS.pointer_gen else None)) # string 83 | 84 | # Run beam search to get best Hypothesis 85 | best_hyp = beam_search.run_beam_search( 86 | self._sess, self._model, self._vocab, batch, FLAGS) 87 | 88 | # Extract the output ids from the hypothesis and convert back to 89 | # words 90 | output_ids = [int(t) for t in best_hyp.tokens[1:]] 91 | decoded_words = data.outputids2words( 92 | output_ids, self._vocab, (batch.art_oovs[0] if FLAGS.pointer_gen else None)) 93 | 94 | # Remove the [STOP] token from decoded_words, if necessary 95 | try: 96 | # index of the (first) [STOP] symbol 97 | fst_stop_idx = decoded_words.index(data.STOP_DECODING) 98 | decoded_words = decoded_words[:fst_stop_idx] 99 | except ValueError: 100 | decoded_words = decoded_words 101 | 102 | # write ref summary and decoded summary to file, to eval with 103 | # pyrouge later 104 | # self.write_for_rouge(original_abstract_sents, decoded_words, counter) 105 | processed = self.depreciated_processing(decoded_words) 106 | 
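# `processed` is a single plain-text sentence (see depreciated_processing
# defined below); one line is written per decoded example, so the output
# file can later be compared against the reference file by the evaluation code.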
f.write(processed + "\n") 107 | 108 | counter += 1 109 | if counter % 100 == 0: 110 | print("%d sentences decoded" % counter) 111 | 112 | f.close() 113 | 114 | 115 | def depreciated_processing(self, decoded_words): 116 | """Depreciated function, will be removed later""" 117 | # First, divide decoded output into sentences 118 | decoded_sents = [] 119 | while len(decoded_words) > 0: 120 | try: 121 | # fst_period_idx = decoded_words.index(".") 122 | fst_period_idx = len(decoded_words) 123 | except ValueError: # there is text remaining that doesn't end in "." 124 | fst_period_idx = len(decoded_words) 125 | # sentence up to and including the period 126 | sent = decoded_words[:fst_period_idx + 1] 127 | decoded_words = decoded_words[ 128 | fst_period_idx + 1:] # everything else 129 | decoded_sents.append(' '.join(sent)) 130 | 131 | # pyrouge calls a perl script that puts the data into HTML files. 132 | # Therefore we need to make our output HTML safe. 133 | decoded_sents = [make_html_safe(w) for w in decoded_sents] 134 | if not len(decoded_sents) == 1: 135 | warnings.warn("Found multiple decoded_sents in " 136 | "`depreciated_processing`: ", decoded_sents) 137 | #raise ValueError( 138 | # "Found multiple decoded_sents in `depreciated_processing`: ", 139 | # decoded_sents) 140 | return decoded_sents[0] 141 | 142 | 143 | def make_html_safe(s): 144 | """Replace any angled brackets in string s to avoid interfering with HTML attention visualizer.""" 145 | s.replace("<", "<") 146 | s.replace(">", ">") 147 | return s 148 | 149 | 150 | def rouge_eval(ref_dir, dec_dir): 151 | """Evaluate the files in ref_dir and dec_dir with pyrouge, returning results_dict""" 152 | r = pyrouge.Rouge155() 153 | r.model_filename_pattern = '#ID#_reference.txt' 154 | r.system_filename_pattern = '(\d+)_decoded.txt' 155 | r.model_dir = ref_dir 156 | r.system_dir = dec_dir 157 | logging.getLogger('global').setLevel( 158 | logging.WARNING) # silence pyrouge logging 159 | rouge_results = r.convert_and_evaluate() 160 | return r.output_to_dict(rouge_results) 161 | 162 | 163 | def rouge_log(results_dict, dir_to_write): 164 | """Log ROUGE results to screen and write to file. 
165 | 166 | Args: 167 | results_dict: the dictionary returned by pyrouge 168 | dir_to_write: the directory where we will write the results to""" 169 | log_str = "" 170 | for x in ["1", "2", "l"]: 171 | log_str += "\nROUGE-%s:\n" % x 172 | for y in ["f_score", "recall", "precision"]: 173 | key = "rouge_%s_%s" % (x, y) 174 | key_cb = key + "_cb" 175 | key_ce = key + "_ce" 176 | val = results_dict[key] 177 | val_cb = results_dict[key_cb] 178 | val_ce = results_dict[key_ce] 179 | log_str += "%s: %.4f with confidence interval (%.4f, %.4f)\n" % ( 180 | key, val, val_cb, val_ce) 181 | tf.logging.info(log_str) # log to screen 182 | results_file = os.path.join(dir_to_write, "ROUGE_results.txt") 183 | tf.logging.info("Writing final ROUGE results to %s...", results_file) 184 | with open(results_file, "w") as f: 185 | f.write(log_str) 186 | 187 | 188 | def correct_unk(sentences): 189 | new_sentences = [] 190 | for sent in sentences: 191 | new_sent = sent.replace("", "UNK").replace("unk", "UNK") 192 | new_sentences.append(new_sent) 193 | return new_sentences 194 | 195 | 196 | -------------------------------------------------------------------------------- /pointer_model/pg_decoder_test.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | # import tensorflow.contrib.eager as tfe 3 | # tfe.enable_eager_execution() 4 | from tensorflow.python.ops import array_ops 5 | from tensorflow.python.ops import math_ops 6 | from tensorflow.python.ops import rnn_cell_impl 7 | from tensorflow.python.ops import variable_scope as vs 8 | 9 | from TFLibrary.SPG import pg_decoder 10 | from TFLibrary.utils import test_utils 11 | from TFLibrary.utils import attention_utils 12 | _zero_state_tensors = rnn_cell_impl._zero_state_tensors 13 | 14 | print("Using Tensorflow ", tf.__version__) 15 | 16 | 17 | # Set Up ######################################## 18 | class Scope(object): 19 | Decoder = test_utils.create_scope("decoder") 20 | Attention = test_utils.create_scope("attention") 21 | Pointer = test_utils.create_scope("pointer") 22 | Projection = test_utils.create_scope("projection") 23 | 24 | 25 | class HPS(object): 26 | batch_size = 32 27 | enc_seq_length = 20 28 | dec_seq_length = 10 29 | embedding_size = 256 30 | num_units = 256 31 | vocab_size = 100 32 | OOVs = 100 33 | 34 | 35 | class FakeVocab(object): 36 | def size(self): 37 | return HPS.vocab_size 38 | 39 | 40 | def test(): 41 | _test_seq2seq_attention_pointer() 42 | 43 | 44 | def _test_seq2seq_attention_pointer(): 45 | # Set Up 46 | # ================================================== 47 | cell = tf.nn.rnn_cell.MultiRNNCell([ 48 | tf.nn.rnn_cell.BasicLSTMCell(HPS.num_units) for _ in range(2)]) 49 | encoder_states = test_utils.random_tensor( 50 | [HPS.batch_size, HPS.enc_seq_length, HPS.num_units]) 51 | decoder_inputs = test_utils.random_tensor( 52 | [HPS.batch_size, HPS.dec_seq_length, HPS.embedding_size]) 53 | initial_state = cell.zero_state(HPS.batch_size, tf.float32) 54 | enc_padding_mask = test_utils.random_integers( 55 | low=0, high=2, dtype=tf.float32, 56 | shape=[HPS.batch_size, HPS.enc_seq_length]) 57 | start_tokens = test_utils.random_integers( 58 | low=0, high=HPS.vocab_size + HPS.OOVs, shape=[HPS.batch_size]) 59 | embeddings = test_utils.random_tensor( 60 | [HPS.vocab_size, HPS.embedding_size]) 61 | enc_batch_extended_vocab = test_utils.random_integers( 62 | low=0, high=(HPS.vocab_size + HPS.OOVs), 63 | shape=[HPS.batch_size, HPS.enc_seq_length]) 64 | 65 | 66 | # Test Seq2Seq with Attention 67 | 
# ======================================================= 68 | 69 | # Actual Outputs 70 | # ------------------------------------------------------- 71 | (final_dists, final_cell_state, attn_dists, p_gens, 72 | coverage, sampled_tokens, decoder_outputs_ta, 73 | debug_variables, final_loop_state) = ( 74 | pg_decoder.policy_gradient_pointer_attention_decoder( 75 | cell=cell, 76 | scope=Scope, 77 | memory=encoder_states, 78 | decoder_inputs=decoder_inputs, 79 | initial_state=initial_state, 80 | enc_padding_mask=enc_padding_mask, 81 | # token 82 | UNK_token=0, 83 | start_tokens=start_tokens, 84 | embeddings=embeddings, 85 | vocab_size=HPS.vocab_size, 86 | num_source_OOVs=HPS.OOVs, 87 | enc_batch_extended_vocab=enc_batch_extended_vocab, 88 | # flags 89 | reinforce=False)) 90 | 91 | # cell_input is cell inputs 92 | (sampled_tokens_history, 93 | outputs_history, alignments_history, p_gens_history, 94 | coverages_history, logits_history, vocab_dists_history, 95 | final_dists_history, coverage, cell_input) = final_loop_state 96 | 97 | # Expected Outputs 98 | # ------------------------------------------------------- 99 | (all_cell_output, 100 | all_next_cell_state, 101 | all_cell_inputs, 102 | all_attention, 103 | all_context, 104 | all_alignments) = _attention_rnn_cell( 105 | cell=cell, 106 | scope=Scope, 107 | num_units=HPS.num_units, 108 | batch_size=HPS.batch_size, 109 | inputs=decoder_inputs, 110 | memory=encoder_states, 111 | mask=enc_padding_mask, 112 | query_layer=debug_variables["query_kernel"], 113 | memory_layer=debug_variables["memory_kernel"], 114 | input_layer=debug_variables["input_kernel"], 115 | attention_layer=debug_variables["output_kernel"], 116 | initial_cell_state=initial_state) 117 | 118 | 119 | # Check differences 120 | # ------------------------------------------------------- 121 | sess = tf.Session() 122 | sess.__enter__() 123 | tf.global_variables_initializer().run() 124 | 125 | cell_outputs_diff = ( 126 | tf.stack(all_attention) - decoder_outputs_ta.stack()) 127 | last_cell_state_diff = ( 128 | tf.stack(all_next_cell_state[-1]) - tf.stack(final_cell_state)) 129 | alignments_diff = ( 130 | tf.transpose(all_alignments, perm=[1, 0, 2]) - tf.stack(attn_dists)) 131 | 132 | test_utils.tensor_is_zero(sess, cell_outputs_diff, "CellOutputsDiff") 133 | test_utils.tensor_is_zero(sess, last_cell_state_diff, "LastCellStateDiff") 134 | test_utils.tensor_is_zero(sess, alignments_diff, "AlignmentsDiff") 135 | 136 | 137 | # Test p_gens 138 | # ======================================================= 139 | 140 | def _stack_and_transpose(X): 141 | X = array_ops.stack(X) 142 | X = array_ops.transpose(X, perm=[1, 0, 2]) 143 | return X 144 | 145 | def _calculate_p_gens(contexts, cell_states, cell_inputs, pgen_layer): 146 | p_gens = array_ops.concat([ 147 | _stack_and_transpose(contexts), 148 | _stack_and_transpose([s[-1].c for s in cell_states]), 149 | _stack_and_transpose([s[-1].h for s in cell_states]), 150 | _stack_and_transpose(cell_inputs)], axis=-1) 151 | p_gens = pgen_layer(p_gens) 152 | return p_gens 153 | 154 | # Expected Outputs 155 | # ------------------------------------------------------- 156 | pgens = _calculate_p_gens( 157 | all_context, 158 | all_next_cell_state, 159 | all_cell_inputs, 160 | debug_variables["pgen_kernel"]) 161 | 162 | pgens_diff = pgens - p_gens 163 | test_utils.tensor_is_zero(sess, pgens_diff, "PgensDIff") 164 | 165 | 166 | 167 | # Test Vocab Distribution 168 | # ======================================================= 169 | 170 | # Expected Outputs 171 | # 
------------------------------------------------------- 172 | all_attn_dists = all_alignments 173 | all_logits = map(lambda X: 174 | debug_variables["logits_kernel"](X), all_attention) 175 | all_vocab_dists = map(lambda X: tf.nn.softmax(X), all_logits) 176 | all_vocab_dists = list(all_vocab_dists) 177 | diff_vocab_dists = tf.stack(all_vocab_dists) - vocab_dists_history.stack() 178 | test_utils.tensor_is_zero(sess, diff_vocab_dists, "DiffVocabDists") 179 | 180 | # Test Final Distribution 181 | # ======================================================= 182 | 183 | # Expected Outputs 184 | # ------------------------------------------------------- 185 | fake_self = test_utils.DictClass({ 186 | "p_gens": tf.unstack(pgens, axis=1), 187 | "_vocab": FakeVocab(), 188 | "_max_art_oovs": HPS.OOVs, 189 | "_hps": test_utils.DictClass({"batch_size": HPS.batch_size}), 190 | "_enc_batch_extend_vocab": enc_batch_extended_vocab}) 191 | # print("self.p_gens ", fake_self.p_gens) 192 | # print("self._vocab.size() ", fake_self._vocab.size()) 193 | # print("self._max_art_oovs ", fake_self._max_art_oovs) 194 | # print("self._hps.batch_size ", fake_self._hps.batch_size) 195 | # print("self._enc_batch_extend_vocab ", fake_self._enc_batch_extend_vocab) 196 | 197 | expected_final_dist = _calc_final_dist( 198 | fake_self, all_vocab_dists, all_attn_dists) 199 | final_dist_diff = _stack_and_transpose(expected_final_dist) - final_dists 200 | test_utils.tensor_is_zero(sess, final_dist_diff, "DiffFinalDists") 201 | 202 | 203 | # Test pg_decoder._calc_final_dist 204 | # ------------------------------------------------------- 205 | actual_final_dist2 = [ 206 | pg_decoder._calc_final_dist( 207 | vocab_dist=vd, 208 | attn_dist=ad, 209 | p_gen=pg, 210 | batch_size=fake_self._hps.batch_size, 211 | vocab_size=fake_self._vocab.size(), 212 | num_source_OOVs=fake_self._max_art_oovs, 213 | enc_batch_extended_vocab=fake_self._enc_batch_extend_vocab) 214 | 215 | for vd, ad, pg in zip(all_vocab_dists, 216 | all_attn_dists, 217 | fake_self.p_gens)] 218 | 219 | final_dist_diff2 = (_stack_and_transpose(expected_final_dist) - 220 | _stack_and_transpose(actual_final_dist2)) 221 | test_utils.tensor_is_zero(sess, final_dist_diff2, "DiffFinalDists") 222 | 223 | 224 | 225 | def _attention_cell(cell, 226 | scope, 227 | input_layer, 228 | attention_layer, 229 | attention_mechanism, 230 | inputs, cell_state, attention): 231 | """More flexible attention wrapper""" 232 | with vs.variable_scope(scope, reuse=True): 233 | # compute cell inputs 234 | cell_inputs = input_layer( 235 | array_ops.concat([inputs, attention], -1)) 236 | 237 | with vs.variable_scope("cell", reuse=True): 238 | # run cell 239 | cell_output, next_cell_state = cell(cell_inputs, cell_state) 240 | 241 | 242 | with vs.variable_scope(scope, reuse=True): 243 | # Computes attention and alignments 244 | alignments, _ = attention_mechanism(cell_output, state=None) 245 | expanded_alignments = array_ops.expand_dims(alignments, 1) 246 | context = math_ops.matmul(expanded_alignments, 247 | attention_mechanism.values) 248 | context = array_ops.squeeze(context, [1]) 249 | 250 | # compute attention output 251 | if attention_layer is not None: 252 | attention = attention_layer( 253 | array_ops.concat([cell_output, context], 1)) 254 | else: 255 | attention = context 256 | 257 | 258 | return (cell_output, # direct outputs from cell 259 | next_cell_state, # next cell state 260 | cell_inputs, # attention concatenated cell inputs 261 | attention, # linear projection of [cell_outputs; context] 262 | 
context, # convext vector 263 | alignments) # attention distribution 264 | 265 | 266 | def _attention_rnn_cell(cell, 267 | scope, 268 | num_units, 269 | batch_size, 270 | inputs, 271 | memory, 272 | mask, 273 | query_layer, 274 | memory_layer, 275 | input_layer, 276 | attention_layer, 277 | initial_cell_state): 278 | attention_mechanism = attention_utils.BahdanauAttentionTester( 279 | num_units=num_units, 280 | memory=memory, 281 | mask=mask, 282 | query_layer=query_layer, 283 | memory_layer=memory_layer, 284 | scope=scope.Attention) 285 | 286 | next_cell_state = initial_cell_state 287 | context = _zero_state_tensors( 288 | num_units, batch_size, tf.float32) 289 | sequence_length = inputs.get_shape()[1].value 290 | 291 | all_cell_output = [] 292 | all_next_cell_state = [] 293 | all_cell_inputs = [] 294 | all_attention = [] 295 | all_context = [] 296 | all_alignments = [] 297 | 298 | for time in range(sequence_length): 299 | (cell_output, 300 | next_cell_state, 301 | cell_inputs, 302 | attention, 303 | context, 304 | alignments) = _attention_cell(cell=cell, 305 | scope=scope.Attention, 306 | input_layer=input_layer, 307 | attention_layer=attention_layer, 308 | attention_mechanism=attention_mechanism, 309 | inputs=inputs[:, time, :], 310 | cell_state=next_cell_state, 311 | attention=context) 312 | all_cell_output.append(cell_output) 313 | all_next_cell_state.append(next_cell_state) 314 | all_cell_inputs.append(cell_inputs) 315 | all_attention.append(attention) 316 | all_context.append(context) 317 | all_alignments.append(alignments) 318 | 319 | return (all_cell_output, 320 | all_next_cell_state, 321 | all_cell_inputs, 322 | all_attention, 323 | all_context, 324 | all_alignments) 325 | 326 | 327 | def _calc_final_dist(self, vocab_dists, attn_dists): 328 | with tf.variable_scope("Projection"): 329 | # Multiply vocab dists by p_gen and attention dists by (1-p_gen) 330 | vocab_dists = [ 331 | p_gen * dist for (p_gen, dist) in zip(self.p_gens, vocab_dists)] 332 | attn_dists = [(1 - p_gen) * dist for (p_gen, dist) 333 | in zip(self.p_gens, attn_dists)] 334 | 335 | # Concatenate some zeros to each vocabulary dist, to hold the 336 | # probabilities for in-article OOV words 337 | # the maximum (over the batch) size of the extended vocabulary 338 | extended_vsize = self._vocab.size() + self._max_art_oovs 339 | extra_zeros = tf.zeros((self._hps.batch_size, self._max_art_oovs)) 340 | # list length max_dec_steps of shape (batch_size, extended_vsize) 341 | vocab_dists_extended = [ 342 | tf.concat(axis=1, values=[dist, extra_zeros]) for dist in vocab_dists] 343 | 344 | # Project the values in the attention distributions onto the appropriate entries in the final distributions 345 | # This means that if a_i = 0.1 and the ith encoder word is w, and w has index 500 in the vocabulary, then we add 0.1 onto the 500th entry of the final distribution 346 | # This is done for each decoder timestep. 
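# A tiny illustrative example (hypothetical numbers, not taken from the model):
# with batch_size=1, vocab_size=5 and one article OOV (extended_vsize=6),
# enc_batch_extend_vocab = [[2, 5, 2]] and an attention dist [[0.2, 0.5, 0.3]]
# scatter to [[0.0, 0.0, 0.5, 0.0, 0.0, 0.5]] -- the two attention weights on
# word id 2 are summed, and id 5 (the OOV slot) receives 0.5.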
347 | # This is fiddly; we use tf.scatter_nd to do the projection 348 | # shape (batch_size) 349 | batch_nums = tf.range(0, limit=self._hps.batch_size) 350 | batch_nums = tf.expand_dims(batch_nums, 1) # shape (batch_size, 1) 351 | attn_len = tf.shape(self._enc_batch_extend_vocab)[ 352 | 1] # number of states we attend over 353 | # shape (batch_size, attn_len) 354 | batch_nums = tf.tile(batch_nums, [1, attn_len]) 355 | # shape (batch_size, enc_t, 2) 356 | indices = tf.stack( 357 | (batch_nums, self._enc_batch_extend_vocab), axis=2) 358 | shape = [self._hps.batch_size, extended_vsize] 359 | # list length max_dec_steps (batch_size, extended_vsize) 360 | attn_dists_projected = [tf.scatter_nd( 361 | indices, copy_dist, shape) for copy_dist in attn_dists] 362 | 363 | # Add the vocab distributions and the copy distributions together to get the final distributions 364 | # final_dists is a list length max_dec_steps; each entry is a tensor shape (batch_size, extended_vsize) giving the final distribution for that decoder timestep 365 | # Note that for decoder timesteps and examples corresponding to a 366 | # [PAD] token, this is junk - ignore. 367 | final_dists = [vocab_dist + copy_dist for (vocab_dist, copy_dist) in zip( 368 | vocab_dists_extended, attn_dists_projected)] 369 | 370 | return final_dists 371 | 372 | 373 | if __name__ == "__main__": 374 | test() 375 | -------------------------------------------------------------------------------- /pointer_model/policy_gradient_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from tensorflow.python.ops import math_ops 3 | from tensorflow.python.ops import array_ops 4 | from tensorflow.python.framework import dtypes 5 | from pointer_model import data 6 | from evaluation_utils import bleu 7 | 8 | 9 | def cross_entropy_loss(p_model, targets, loss_mask, batch_size): 10 | # Calculate the loss per step 11 | # This is fiddly; we use tf.gather_nd to pick out the 12 | # probabilities of the gold target words 13 | # will be list length max_dec_steps containing shape 14 | # (batch_size) 15 | 16 | # this works when dec_seq_len is undetermined 17 | dec_seq_len = array_ops.shape(p_model)[1] 18 | batch_nums = math_ops.range(0, limit=batch_size) 19 | batch_nums = array_ops.expand_dims(batch_nums, 1) 20 | batch_nums = array_ops.tile(batch_nums, [1, dec_seq_len]) 21 | 22 | # time indices 23 | time_nums = math_ops.range(array_ops.shape(p_model)[1]) 24 | time_nums = array_ops.expand_dims(time_nums, 0) 25 | time_nums = array_ops.tile(time_nums, [batch_size, 1]) 26 | 27 | indices = array_ops.stack((batch_nums, time_nums, targets), axis=-1) 28 | gold_probs = array_ops.gather_nd(p_model, indices) 29 | raw_loss = - math_ops.log(gold_probs + 1e-6) 30 | 31 | # words per sentence 32 | tokens_per_seq = math_ops.reduce_sum(loss_mask, axis=-1) 33 | # masked loss 34 | masked_loss = raw_loss * loss_mask 35 | # average loss per sequence 36 | sequence_loss = math_ops.reduce_sum(masked_loss, axis=-1) / tokens_per_seq 37 | # avergae loss per batch 38 | loss = math_ops.reduce_mean(sequence_loss) 39 | 40 | return loss 41 | 42 | 43 | def negative_log_likelihood(actions_prob, 44 | target_actions, 45 | episode_masks, 46 | action_space, 47 | dtype=dtypes.float32, 48 | policy_multipliers=None): 49 | # exactly equal to `cross_entropy_loss` 50 | # but simpler 51 | if policy_multipliers is None: 52 | # broadcasting will do the rest of the jobs 53 | policy_multipliers = 1 54 | 55 | # calculate - p(x) log q(x) 56 | actions_log_prob = 
math_ops.log(actions_prob + 1e-6) 57 | target_actions_onehot = array_ops.one_hot( 58 | indices=target_actions, 59 | depth=action_space, dtype=dtype) 60 | nll = - target_actions_onehot * actions_log_prob 61 | 62 | # masked NLL, or nll * mask * policy_multipliers 63 | masked_policy_multipliers = array_ops.expand_dims( 64 | policy_multipliers * episode_masks, axis=2) 65 | scaled_masked_nll = nll * masked_policy_multipliers 66 | 67 | # sequence and batch NLL 68 | actions_per_episode = math_ops.reduce_sum(episode_masks, axis=-1) 69 | sequence_nll = math_ops.reduce_sum(scaled_masked_nll, axis=[1, 2]) 70 | sequence_nll = sequence_nll / actions_per_episode 71 | batch_nll = math_ops.reduce_mean(sequence_nll) 72 | return batch_nll 73 | 74 | 75 | def calc_bleu_rewards(sess, 76 | feed_dict, 77 | vocabulary, 78 | batch_OOVs, 79 | target_actions_pl, 80 | sampled_actions_pl, 81 | policy_multipliers_pl): 82 | fetches = sess.run( 83 | fetches={"target_actions": target_actions_pl, 84 | "sampled_actions": sampled_actions_pl}, 85 | feed_dict=feed_dict) 86 | 87 | rewards = [] 88 | for target, sampled in zip(fetches["target_actions"].tolist(), 89 | fetches["sampled_actions"].tolist()): 90 | target_actions = data.outputids2words( 91 | target, vocabulary, batch_OOVs) 92 | sampled_actions = data.outputids2words( 93 | sampled, vocabulary, batch_OOVs) 94 | 95 | reward, _, _, _, _, _ = bleu.compute_bleu( 96 | reference_corpus=[[target_actions]], 97 | translation_corpus=[sampled_actions], 98 | max_order=4, 99 | smooth=True) 100 | 101 | rewards.append(100 * reward) 102 | 103 | batch_size = fetches["sampled_actions"].shape[0] 104 | sequence_lengths = fetches["sampled_actions"].shape[1] 105 | 106 | if not len(rewards) == batch_size: 107 | raise ValueError("rewards lengths %d " 108 | "!= batch_size %d" % (len(rewards), batch_size)) 109 | 110 | # add time dimensions 111 | rewards = np.expand_dims(rewards, axis=1) 112 | # tile the time dimension 113 | rewards = np.tile(rewards, reps=sequence_lengths) 114 | feed_dict[policy_multipliers_pl] = rewards 115 | print("BatchRewards: ", np.mean(rewards)) 116 | 117 | return feed_dict 118 | -------------------------------------------------------------------------------- /run.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | from __future__ import absolute_import 3 | from __future__ import division 4 | 5 | import os 6 | import argparse 7 | import tensorflow as tf 8 | from datetime import datetime 9 | from collections import namedtuple 10 | 11 | from pointer_model.data import Vocab 12 | from pointer_model.decode import BeamSearchDecoder 13 | from pointer_model.model import SummarizationModel 14 | 15 | from utils import misc_utils 16 | from evaluation_utils.evaluators import evaluate 17 | from multitask.multitask_base_model import MultitaskBatcher 18 | from multitask.multitask_base_model import MultitaskBaseModel 19 | from multitask.multitask_autoMR_model import MultitaskAutoMRModel 20 | 21 | 22 | tf.logging.set_verbosity(tf.logging.INFO) 23 | NAMES = "Newsela,SNLI,PP" 24 | StepsPerVal = 10 25 | StepsPerCheckpoint = 1500 26 | AutoMRNumValBatches = 2 27 | AutoMRStepsPerUpdate = 10 28 | ValNLL_Normalizing_Constant = 2 29 | MultitaskBatcherArgs = namedtuple("MultitaskBatcherArgs", 30 | ("data_paths", "vocabs", "hps", "single_pass")) 31 | HParamsList = ['mode', 'lr', 'adagrad_init_acc', 'rand_unif_init_mag', 32 | 'trunc_norm_init_std', 'max_grad_norm', 'hidden_dim', 33 | 'emb_dim', 'batch_size', 'max_dec_steps', 
'max_enc_steps', 34 | 'coverage', 'cov_loss_wt', 'pointer_gen', 35 | # additionals 36 | "num_encoder_layers", "num_decoder_layers", "dropout_rate"] 37 | 38 | 39 | def add_arguments(): 40 | parser = argparse.ArgumentParser() 41 | # Hparams (default are good) 42 | parser.add_argument("--steps_per_eval", 43 | type=int, default=1500, 44 | help="number of steps for evaluation") 45 | parser.add_argument("--hidden_dim", 46 | type=int, default=256, 47 | help="dimension of RNN hidden states") 48 | parser.add_argument("--emb_dim", 49 | type=int, default=128, 50 | help="dimension of word embeddings") 51 | parser.add_argument("--batch_size", 52 | type=int, default=256, 53 | help="minibatch size") 54 | parser.add_argument("--max_enc_steps", 55 | type=int, default=None, 56 | help="max timesteps of encoder") 57 | parser.add_argument("--max_dec_steps", 58 | type=int, default=None, 59 | help="max timesteps of decoder") 60 | parser.add_argument("--min_dec_steps", 61 | type=int, default=1, 62 | help="Minimum sequence length of generated summary. " 63 | "Applies only for beam search decoding mode") 64 | parser.add_argument("--vocab_size", 65 | type=int, default=50000, 66 | help="Size of vocabulary") 67 | parser.add_argument("--rand_unif_init_mag", 68 | type=float, default=0.02, 69 | help="magnitude for lstm cells random uniform inititalization") 70 | parser.add_argument("--trunc_norm_init_std", 71 | type=float, default=1e-4, 72 | help="std of trunc norm init, used for initializing everything else") 73 | parser.add_argument("--max_grad_norm", 74 | type=float, default=2.0, 75 | help="for gradient clipping") 76 | parser.add_argument("--pointer_gen", 77 | type=bool, default=True, 78 | help="Use pointer-generator model") 79 | parser.add_argument("--coverage", 80 | type=bool, default=False, 81 | help="Use coverage mechanism.") 82 | parser.add_argument("--convert_to_coverage_model", 83 | type=bool, default=False, 84 | help="Convert a non-coverage model to a coverage model.") 85 | parser.add_argument("--cov_loss_wt", 86 | type=float, default=1.0, 87 | help="Weight of coverage loss (lambda in the paper)" 88 | "If zero, then no incentive to minimize coverage loss.") 89 | 90 | # Hyparams need to change 91 | parser.add_argument("--num_encoder_layers", 92 | type=int, default=2, 93 | help="number of layers") 94 | parser.add_argument("--num_decoder_layers", 95 | type=int, default=2, 96 | help="number of layers") 97 | parser.add_argument("--dropout_rate", 98 | type=float, default=None, 99 | help="dropout_rate = 1 - keep_prob") 100 | parser.add_argument("--lr", 101 | type=float, default=0.001, 102 | help="learning rate") 103 | parser.add_argument("--beam_size", 104 | type=int, default=None, 105 | help="beam size for beam search decoding.") 106 | parser.add_argument("--max_hours", 107 | type=int, default=None, 108 | help="number of hours before killing the model.") 109 | 110 | # model directories 111 | parser.add_argument("--mode", 112 | type=str, default=None, 113 | help="train or decode") 114 | 115 | parser.add_argument("--log_root", 116 | type=str, default=None, 117 | help="Root directory for all logging.") 118 | 119 | parser.add_argument("--exp_name", 120 | type=str, default=None, 121 | help="Name for experiment. 
Logs will be saved " 122 | "in a directory with this name, under log_root.") 123 | 124 | parser.add_argument("--vocab_path", 125 | type=str, default=None, 126 | help="path to vocabulary") 127 | 128 | parser.add_argument("--train_data_dirs", 129 | type=str, default=None, 130 | help="Comma-separated: " 131 | "path expression to tf.Example datafiles. ") 132 | 133 | parser.add_argument("--val_data_dir", 134 | type=str, default=None, 135 | help="path expression to tf.Example datafiles. ") 136 | 137 | # evaluation 138 | parser.add_argument("--eval_source_dir", 139 | type=str, default=None, 140 | help="Directory to the evaluation source") 141 | 142 | parser.add_argument("--eval_target_dir", 143 | type=str, default=None, 144 | help="Directory to the evaluation target") 145 | 146 | parser.add_argument("--eval_folder_dir", 147 | type=str, default=None, 148 | help="directory to the evaluation folder") 149 | 150 | # load models 151 | parser.add_argument("--load_ckpt_file", 152 | type=str, default=None, 153 | help="restore from specific checkpints") 154 | 155 | # decoding 156 | parser.add_argument("--decode_data_dir", 157 | type=str, default=None, 158 | help="directory to the file for decoding") 159 | 160 | parser.add_argument("--decode_ckpt_file", 161 | type=str, default=None, 162 | help="checkpoint files for decoding only") 163 | 164 | parser.add_argument("--decode_output_file", 165 | type=str, default=None, 166 | help="outputs of decoding") 167 | 168 | parser.add_argument("--names", 169 | type=str, default=NAMES) 170 | parser.add_argument("--mixing_ratios", 171 | type=str, default=None) 172 | parser.add_argument("--soft_sharing_coef", 173 | type=float, default=None) 174 | parser.add_argument("--autoMR", 175 | action="store_true", default=False) 176 | parser.add_argument("--reward_scaling_factor", 177 | type=float, default=ValNLL_Normalizing_Constant, 178 | help="reward scaling") 179 | parser.add_argument("--selector_alpha", 180 | type=float, default=0.3) 181 | 182 | 183 | 184 | FLAGS, unparsed = parser.parse_known_args() 185 | 186 | # convert comma-separated strings into lists 187 | FLAGS.names = FLAGS.names.split(",") 188 | FLAGS.mixing_ratios = ( 189 | [float(x) for x in FLAGS.mixing_ratios.split(",")] 190 | if FLAGS.mixing_ratios is not None else None) 191 | 192 | FLAGS.log_root = os.path.join(FLAGS.log_root, FLAGS.exp_name) 193 | if not os.path.exists(FLAGS.log_root): 194 | os.makedirs(FLAGS.log_root) 195 | 196 | if FLAGS.train_data_dirs is None: 197 | if FLAGS.mode != "decode": 198 | raise ValueError("train_data_dirs cannot be None") 199 | # else keep it None, since it doesnt matter 200 | else: 201 | # check compatability 202 | FLAGS.train_data_dirs = FLAGS.train_data_dirs.split(",") 203 | if not len(FLAGS.names) == len(FLAGS.train_data_dirs): 204 | raise ValueError("names and train_data_dirs not match") 205 | 206 | if (FLAGS.mixing_ratios is not None and 207 | len(FLAGS.names) != len(FLAGS.mixing_ratios) + 1): 208 | raise ValueError("names and mixing_ratios + 1 not match") 209 | 210 | if not FLAGS.soft_sharing_coef or FLAGS.soft_sharing_coef < 1e-6: 211 | raise ValueError("not really supported") 212 | 213 | 214 | if FLAGS.dropout_rate is not None: 215 | raise ValueError("Not supporting dropout") 216 | 217 | # Make a namedtuple hps 218 | hps_dict = {} 219 | for key, val in vars(FLAGS).items(): 220 | if key in HParamsList: 221 | hps_dict[key] = val 222 | hps = namedtuple("HParams", hps_dict.keys())(** hps_dict) 223 | return FLAGS, hps 224 | 225 | 226 | def _model_factory(name): 227 | model = 
SummarizationModel 228 | print("Task %s is using %s" % (name, model.__name__)) 229 | return model 230 | 231 | 232 | def setup_training(FLAGS, hps): 233 | """Does setup before starting training (run_training)""" 234 | 235 | # Setting up the Multitask Wrapper 236 | # ---------------------------------------- 237 | if FLAGS.autoMR: 238 | # for decode, we can still use this one 239 | # since both are essentially the same 240 | # except no auto-MR feature 241 | MultitaskModel = MultitaskAutoMRModel 242 | else: 243 | MultitaskModel = MultitaskBaseModel 244 | 245 | # Setting up the models and directories 246 | # ---------------------------------------- 247 | num_models = len(FLAGS.names) 248 | # train_dir is a folder, decode_dir is a file 249 | train_dir = os.path.join(FLAGS.log_root, "train") 250 | decode_dir = os.path.join(FLAGS.log_root, "decode") 251 | model_creators = [_model_factory(name) for name in FLAGS.names] 252 | if not os.path.exists(train_dir): 253 | os.makedirs(train_dir) 254 | 255 | # Setting up the batchers and data readers 256 | # ---------------------------------------- 257 | print("Loading Training Data from %s " % FLAGS.train_data_dirs) 258 | vocab = Vocab(FLAGS.vocab_path, FLAGS.vocab_size) 259 | train_batchers = MultitaskBatcher( 260 | data_paths=FLAGS.train_data_dirs, 261 | vocabs=[vocab for _ in range(num_models)], 262 | hps=hps, single_pass=False) 263 | # not using decode_model_hps which have batch-size = beam-size 264 | val_batchers = MultitaskBatcher( 265 | data_paths=[FLAGS.val_data_dir], 266 | vocabs=[vocab], hps=hps, single_pass=False) 267 | 268 | # Setting up the task selectors 269 | # ---------------------------------------- 270 | Q_initial = -1 271 | if FLAGS.reward_scaling_factor > 0.0: 272 | Q_initial = Q_initial / FLAGS.reward_scaling_factor 273 | tf.logging.info("Normalization %.2f" % FLAGS.reward_scaling_factor) 274 | 275 | # Build 276 | # ---------------------------------------- 277 | print("Mixing ratios are %s " % FLAGS.mixing_ratios) 278 | train_models = MultitaskModel( 279 | names=FLAGS.names, 280 | all_hparams=[hps for _ in range(num_models)], 281 | mixing_ratios=FLAGS.mixing_ratios, 282 | model_creators=model_creators, 283 | logdir=train_dir, 284 | soft_sharing_coef=FLAGS.soft_sharing_coef, 285 | data_generators=train_batchers, 286 | val_data_generator=val_batchers, 287 | vocab=vocab, 288 | selector_Q_initial=Q_initial, 289 | alpha=FLAGS.selector_alpha, 290 | temperature_anneal_rate=None) 291 | 292 | # Note this use a different decoder_batcher 293 | 294 | # The model is configured with max_dec_steps=1 because we only ever run 295 | # one step of the decoder at a time (to do beam search). Note that the 296 | # batcher is initialized with max_dec_steps equal to e.g. 100 because 297 | # the batches need to contain the full summaries 298 | 299 | # If in decode mode, set batch_size = beam_size 300 | # Reason: in decode mode, we decode one example at a time. 301 | # On each step, we have beam_size-many hypotheses in the beam, so we need 302 | # to make a batch of these hypotheses. 
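# Concretely: the decode-mode Batcher repeats one example beam_size times,
# so a single decode_onestep() call can advance all beam_size hypotheses
# for that example in one session run.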
303 | decode_model_hps = hps 304 | decode_model_hps = hps._replace( 305 | mode="decode")._replace(batch_size=FLAGS.beam_size) 306 | 307 | # we need to constantly re-initialize this generator 308 | # so save arguments as a namedtuple 309 | print("Loading Validation Data from %s " % FLAGS.val_data_dir) 310 | decode_batcher_args = MultitaskBatcherArgs( 311 | data_paths=[FLAGS.val_data_dir], 312 | vocabs=[vocab], 313 | hps=decode_model_hps, 314 | single_pass=True) 315 | 316 | decode_batchers = ( 317 | MultitaskBatcher(** decode_batcher_args._asdict())) 318 | 319 | # only for one model 320 | decode_models = MultitaskBaseModel( 321 | names=[FLAGS.names[0]], 322 | all_hparams=[decode_model_hps._replace(max_dec_steps=1)], 323 | mixing_ratios=None, 324 | model_creators=[model_creators[0]], 325 | logdir=train_dir, 326 | soft_sharing_coef=FLAGS.soft_sharing_coef, 327 | vocab=vocab) 328 | 329 | with decode_models.graph.as_default(): 330 | decoder = BeamSearchDecoder(model=decode_models, 331 | batcher=decode_batchers, 332 | vocab=vocab, 333 | ckpt_dir=train_dir, 334 | decode_dir=decode_dir, 335 | FLAGS=FLAGS) 336 | decode_sess = tf.Session(graph=decode_models.graph, 337 | config=misc_utils.get_config()) 338 | decoder.build_graph(decode_sess) 339 | 340 | try: 341 | # this is an infinite loop until interrupted 342 | run_training(FLAGS=FLAGS, 343 | models=train_models, 344 | decoder=decoder, 345 | decode_batcher_args=decode_batcher_args) 346 | 347 | except KeyboardInterrupt: 348 | tf.logging.info("Stopped...") 349 | 350 | 351 | def run_training(FLAGS, models, decoder, decode_batcher_args): 352 | tf.logging.info("Initializing ...") 353 | models.initialize_or_restore_session(ckpt_file=FLAGS.load_ckpt_file) 354 | 355 | start_time = datetime.now() 356 | tf.logging.info("Starting run_training at %s, will run " 357 | "for %s hours", start_time, FLAGS.max_hours) 358 | 359 | while True: 360 | with misc_utils.calculate_time("seconds for training step"): 361 | models.run_train_step() 362 | 363 | elapsed_hours = (datetime.now() - start_time).seconds // 3600 364 | if FLAGS.max_hours and elapsed_hours >= FLAGS.max_hours: 365 | models.save_session() 366 | break 367 | 368 | # update the val-loss as Q-values 369 | # define Q as negative val-loss 370 | if FLAGS.autoMR and models.global_step % AutoMRStepsPerUpdate == 0: 371 | total_val_loss = 0 372 | for _ in range(AutoMRNumValBatches): 373 | val_loss = models.run_eval_step() 374 | total_val_loss += val_loss 375 | 376 | # Q = negaative average val-loss 377 | scores = -float(total_val_loss) / float(AutoMRNumValBatches) 378 | # reward scaling 379 | if FLAGS.reward_scaling_factor > 0.0: 380 | scores = scores / float(FLAGS.reward_scaling_factor) 381 | # update the Q values 382 | models.update_TaskSelector_Q_values(scores) 383 | 384 | if models.global_step % StepsPerCheckpoint == 0: 385 | models.save_session() 386 | 387 | if (models.global_step != 0 and 388 | models.global_step % FLAGS.steps_per_eval == 0): 389 | # save checkpoints 390 | models.save_session() 391 | # run decode for calculating scores 392 | decoder.decode() 393 | # reset batcher from exhausted state 394 | decode_batchers = ( 395 | MultitaskBatcher(** decode_batcher_args._asdict())) 396 | decoder.reset_batcher(decode_batchers) 397 | # evaluate generated outputs and log results 398 | 399 | scores = evaluate( 400 | mode="val", 401 | gen_file=decoder._decode_dir, 402 | ref_file=FLAGS.eval_target_dir, 403 | execute_dir=FLAGS.eval_folder_dir, 404 | source_file=FLAGS.eval_source_dir, 405 | 
evaluation_task=FLAGS.names[0]) 406 | 407 | print(scores) 408 | 409 | 410 | def setup_and_run_decoding(FLAGS, hps): 411 | # raise ValueError("Pay attention to dropout is set or not") 412 | if os.path.exists(FLAGS.decode_output_file): 413 | raise ValueError("`decode_output_file` exists") 414 | 415 | decode_model_hps = hps 416 | decode_model_hps = hps._replace( 417 | mode="decode")._replace(batch_size=FLAGS.beam_size) 418 | train_dir = os.path.join(FLAGS.log_root, "train") 419 | model_creators = [_model_factory(name) for name in FLAGS.names] 420 | 421 | print("Loading Decoding Data from %s " % FLAGS.decode_data_dir) 422 | vocab = Vocab(FLAGS.vocab_path, FLAGS.vocab_size) 423 | decode_batchers = MultitaskBatcher( 424 | data_paths=[FLAGS.decode_data_dir], 425 | vocabs=[vocab], 426 | hps=decode_model_hps, 427 | single_pass=True) 428 | 429 | # only for one model 430 | decode_models = MultitaskBaseModel( 431 | names=[FLAGS.names[0]], 432 | all_hparams=[decode_model_hps._replace(max_dec_steps=1)], 433 | mixing_ratios=None, 434 | model_creators=[model_creators[0]], 435 | logdir=train_dir, 436 | soft_sharing_coef=FLAGS.soft_sharing_coef, 437 | # additional args 438 | vocab=vocab) 439 | 440 | with decode_models.graph.as_default(): 441 | decoder = BeamSearchDecoder(model=decode_models, 442 | batcher=decode_batchers, 443 | vocab=vocab, 444 | ckpt_dir=train_dir, 445 | decode_dir=FLAGS.decode_output_file, 446 | FLAGS=FLAGS) 447 | decode_sess = tf.Session(graph=decode_models.graph, 448 | config=misc_utils.get_config()) 449 | decoder.build_graph(decode_sess) 450 | 451 | # run decode for calculating scores 452 | decoder.decode(ckpt_file=FLAGS.decode_ckpt_file) 453 | 454 | scores = evaluate( 455 | mode="test", 456 | gen_file=decoder._decode_dir, 457 | ref_file=FLAGS.eval_target_dir, 458 | execute_dir=FLAGS.eval_folder_dir, 459 | source_file=FLAGS.eval_source_dir, 460 | evaluation_task=FLAGS.names[0]) 461 | 462 | print(scores) 463 | 464 | 465 | def main(unused_argv): 466 | tf.set_random_seed(111) 467 | FLAGS, hps = add_arguments() 468 | 469 | if hps.mode == 'train': 470 | print("creating training model...") 471 | setup_training(FLAGS, hps) 472 | 473 | elif hps.mode == 'decode': 474 | print("creating decoding model") 475 | setup_and_run_decoding(FLAGS, hps) 476 | 477 | else: 478 | raise ValueError(hps.mode) 479 | 480 | 481 | if __name__ == '__main__': 482 | tf.app.run() 483 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HanGuo97/MultitaskSimplification/2632e7bdb5fd53c32092468662fefd8ea6c1dc5d/utils/__init__.py -------------------------------------------------------------------------------- /utils/misc_utils.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | from __future__ import absolute_import 3 | from __future__ import division 4 | 5 | from time import time 6 | from copy import deepcopy 7 | from contextlib import contextmanager 8 | 9 | import tensorflow as tf 10 | from utils import modified_rnn_cell_wrappers as cell_wrappers 11 | 12 | FLAGS = tf.app.flags.FLAGS 13 | 14 | 15 | @contextmanager 16 | def calculate_time(tag): 17 | start_time = time() 18 | yield 19 | print("%s: " % tag, time() - start_time) 20 | 21 | 22 | def union_lists(lists): 23 | if not isinstance(lists, (list, tuple)): 24 | raise TypeError("lists Must be a list of list") 25 | 26 | new_list = 
deepcopy(lists[0]) 27 | for single_list in lists: 28 | if not isinstance(single_list, (list, tuple)): 29 | raise TypeError("single_list Must be a list") 30 | for list_item in single_list: 31 | if list_item not in new_list: 32 | new_list.append(list_item) 33 | return new_list 34 | 35 | 36 | def assert_all_same(items, attr=None): 37 | if not isinstance(items, (list, tuple)): 38 | raise TypeError("items should be list or tuple") 39 | 40 | if attr is not None: 41 | if not all(getattr(x, attr) == getattr(items[0], attr) for x in items): 42 | raise ValueError("items of %s not consistent between items" % attr) 43 | else: 44 | if not all(x == items[0] for x in items): 45 | raise ValueError("items not consistent between items") 46 | 47 | 48 | def get_config(): 49 | """Returns config for tf.session""" 50 | config = tf.ConfigProto(allow_soft_placement=True) 51 | config.gpu_options.allow_growth = True 52 | return config 53 | 54 | 55 | def load_ckpt(saver, sess, ckpt_dir=None, ckpt_file=None): 56 | if not ckpt_dir and not ckpt_file: 57 | return 58 | 59 | if not ckpt_file: 60 | ckpt_file = tf.train.latest_checkpoint(ckpt_dir) 61 | if ckpt_file is None: 62 | return 63 | 64 | saver.restore(sess, ckpt_file) 65 | tf.logging.info("Loaded checkpoint %s" % ckpt_file) 66 | return ckpt_file 67 | 68 | 69 | def concate_multi_rnn_cell_states(states, concat_fn, expand_fn): 70 | if not isinstance(states, (list, tuple)): 71 | raise TypeError( 72 | "states should be a list of beam_size, but saw ", type(states)) 73 | if not isinstance(states[0], (list, tuple)): 74 | raise TypeError( 75 | "each states[beam_id] should be a list " 76 | "of multi_rnn states, but saw ", type(states[0])) 77 | if not isinstance(states[0][0], tf.nn.rnn_cell.LSTMStateTuple): 78 | raise TypeError("each states[beam_id][layer_id] should be " 79 | "LSTMStateTuple, but saw ", type(states[0][0])) 80 | num_beams = len(states) 81 | num_layers = len(states[0]) 82 | concat_states = [] 83 | append_states = [ 84 | tf.nn.rnn_cell.LSTMStateTuple(c=[], h=[]) 85 | for _ in range(num_layers)] 86 | 87 | # append all cell and hidden states to lists 88 | for beam_id in range(num_beams): 89 | for layer_id in range(num_layers): 90 | append_states[layer_id].c.append(expand_fn(states[beam_id][layer_id].c)) 91 | append_states[layer_id].h.append(expand_fn(states[beam_id][layer_id].h)) 92 | 93 | # concatenate the list 94 | for layer_id in range(num_layers): 95 | concat_c = concat_fn(append_states[layer_id].c) 96 | concat_h = concat_fn(append_states[layer_id].h) 97 | concat_states.append(tf.nn.rnn_cell.LSTMStateTuple(c=concat_c, h=concat_h)) 98 | 99 | 100 | return cell_wrappers.to_MultiRNNLSTMStateTuple(concat_states) 101 | 102 | 103 | def split_multi_rnn_cell_states(states): 104 | if not isinstance(states, (list, tuple)): 105 | raise TypeError( 106 | "states should be a list of beam_size, but saw ", type(states)) 107 | if not isinstance(states[0], (list, tuple)): 108 | raise TypeError("each states[layer_id] should be " 109 | "LSTMStateTuple, but saw ", type(states[0])) 110 | num_layers = len(states) 111 | num_beams = states[0].c.shape[0] 112 | new_states = [[0 for _ in range(num_layers)] for _ in range(num_beams)] 113 | for layer_id in range(num_layers): 114 | for beam_id in range(num_beams): 115 | c = states[layer_id].c[beam_id, :] 116 | h = states[layer_id].h[beam_id, :] 117 | new_states[beam_id][layer_id] = tf.nn.rnn_cell.LSTMStateTuple(c=c, h=h) 118 | return new_states 119 | -------------------------------------------------------------------------------- 
/utils/model_utils.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 | from __future__ import absolute_import
3 | from __future__ import division
4 | 
5 | import tensorflow as tf
6 | 
7 | 
8 | def is_seqence(x):
9 |     if isinstance(x, (tuple, list)):
10 |         return True
11 |     return False
12 | 
13 | 
14 | def have_same_length(x, y):
15 |     if not is_seqence(x):
16 |         raise TypeError("x is not a sequence")
17 |     if not is_seqence(y):
18 |         raise TypeError("y is not a sequence")
19 | 
20 |     if not len(x) == len(y):
21 |         return False
22 |     return True
23 | 
24 | 
25 | def get_decoder_initial_state(encoder_states,
26 |                               num_layers,
27 |                               num_units,
28 |                               encoder_type="bi-directional",
29 |                               decoder_zero_state_fn=None,
30 |                               method="direct-pass",
31 |                               scope=None,
32 |                               dtype=tf.float32,
33 |                               initializer=None):
34 | 
35 |     raise NotImplementedError("not-used")
36 | 
37 |     if not is_seqence(encoder_states) or \
38 |             len(encoder_states) != num_layers:
39 |         raise ValueError("encoder_states must be sequences "
40 |                          "of encoder states, where len(states) "
41 |                          "equals the number of layers, found "
42 |                          "encoder_states to be %s, and "
43 |                          "len(states) %d != %d" % (type(encoder_states),
44 |                                                    len(encoder_states),
45 |                                                    num_layers))
46 | 
47 |     if encoder_type not in ["bi-directional", "uni-directional"]:
48 |         raise ValueError("%s not recognized" % encoder_type)
49 |     if method not in ["direct-pass", "linear-projection"]:
50 |         raise ValueError("%s not recognized" % method)
51 | 
52 | 
53 |     decoder_initial_states = []
54 |     if method == "linear-projection":
55 |         # linear-projection:
56 | 
57 |         # requirements:
58 |         # 1. bidirectional encoders
59 |         # 2. num_encoder_layers = num_decoder_layers
60 | 
61 |         # linearly project forward and backward cell states
62 |         # into one single state, applied per layer
63 | 
64 |         # each encoder_state should be a list of [fw_state, bw_state]
65 |         if not (is_seqence(encoder_states[0]) and len(encoder_states[0]) == 2):
66 |             raise ValueError("linear-projection is for bidirectional encoders")
67 | 
68 |         for layer_id in range(num_layers):
69 |             with tf.variable_scope(
70 |                     scope or "encoder_scope", reuse=layer_id != 0):
71 |                 # Define weights and biases to reduce
72 |                 # the cell and the hidden state
73 |                 w_reduce_c = tf.get_variable(
74 |                     'w_reduce_c', [num_units * 2, num_units],
75 |                     dtype=dtype, initializer=initializer)
76 |                 w_reduce_h = tf.get_variable(
77 |                     'w_reduce_h', [num_units * 2, num_units],
78 |                     dtype=dtype, initializer=initializer)
79 |                 bias_reduce_c = tf.get_variable(
80 |                     'bias_reduce_c', [num_units],
81 |                     dtype=dtype, initializer=initializer)
82 |                 bias_reduce_h = tf.get_variable(
83 |                     'bias_reduce_h', [num_units],
84 |                     dtype=dtype, initializer=initializer)
85 | 
86 |                 # Apply linear layer
87 |                 # Concatenation of fw and bw cell states
88 |                 cell_states = [encoder_states[0][layer_id],
89 |                                encoder_states[1][layer_id]]
90 |                 old_c = tf.concat(axis=1,
91 |                                   values=[st.c for st in cell_states])
92 |                 # Concatenation of fw and bw hidden states
93 |                 old_h = tf.concat(axis=1,
94 |                                   values=[st.h for st in cell_states])
95 |                 # Get new cell state from old cell state
96 |                 new_c = tf.nn.relu(
97 |                     tf.matmul(old_c, w_reduce_c) + bias_reduce_c)
98 |                 # Get new hidden state from old hidden state
99 |                 new_h = tf.nn.relu(
100 |                     tf.matmul(old_h, w_reduce_h) + bias_reduce_h)
101 | 
102 |                 # Collect the new cell and hidden states
103 |                 decoder_initial_states.append(
104 |                     tf.nn.rnn_cell.LSTMStateTuple(new_c, new_h))
105 | 
106 |     return decoder_initial_states
--------------------------------------------------------------------------------
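The `linear-projection` branch above follows the usual pointer-generator trick of concatenating forward and backward encoder LSTM states and projecting them back down to the decoder size. Since the function is marked `not-used`, here is a hedged, self-contained sketch of that reduction for a single layer; the function name, variable names, and shapes are illustrative assumptions, not part of this repository:

```python
import tensorflow as tf


def reduce_bidirectional_state(fw_state, bw_state, num_units,
                               scope="reduce_state"):
    """Project concatenated fw/bw LSTMStateTuples down to one
    LSTMStateTuple of size num_units (illustrative sketch only)."""
    with tf.variable_scope(scope):
        w_c = tf.get_variable("w_reduce_c", [num_units * 2, num_units])
        w_h = tf.get_variable("w_reduce_h", [num_units * 2, num_units])
        b_c = tf.get_variable("bias_reduce_c", [num_units])
        b_h = tf.get_variable("bias_reduce_h", [num_units])
        old_c = tf.concat([fw_state.c, bw_state.c], axis=1)  # [batch, 2 * num_units]
        old_h = tf.concat([fw_state.h, bw_state.h], axis=1)
        new_c = tf.nn.relu(tf.matmul(old_c, w_c) + b_c)      # [batch, num_units]
        new_h = tf.nn.relu(tf.matmul(old_h, w_h) + b_h)
    return tf.nn.rnn_cell.LSTMStateTuple(new_c, new_h)
```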
/utils/modified_rnn_cell_wrappers.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | from collections import namedtuple 6 | from tensorflow.python.framework import ops 7 | from tensorflow.python.ops import variable_scope as vs 8 | from tensorflow.python.util import nest 9 | from tensorflow.python.ops import array_ops 10 | from tensorflow.python.ops import rnn_cell 11 | 12 | # Hashable List 13 | NUM_LAYERS = 2 14 | MultiRNNLSTMStateTuple = namedtuple("MultiRNNLSTMStateTuple", 15 | ("Layer0", "Layer1")) 16 | 17 | 18 | def to_MultiRNNLSTMStateTuple(states): 19 | if isinstance(states, MultiRNNLSTMStateTuple): 20 | return states 21 | 22 | if not isinstance(states, (list, tuple)): 23 | raise TypeError( 24 | "Expected states to be list, found ", type(states)) 25 | 26 | if not len(states) == NUM_LAYERS: 27 | raise ValueError( 28 | "Only %d layers are supported now, found %d" 29 | % (NUM_LAYERS, len(states))) 30 | 31 | return MultiRNNLSTMStateTuple(* states) 32 | 33 | 34 | class SingleRNNCell(rnn_cell.RNNCell): 35 | """Cell Wrapper with Scope""" 36 | def __init__(self, cell, cell_scope): 37 | # use if in TF1.0 or afterwards 38 | super(SingleRNNCell, self).__init__() 39 | self._cell = cell 40 | self._cell_scope = cell_scope 41 | 42 | @property 43 | def state_size(self): 44 | return self._cell.state_size 45 | 46 | @property 47 | def output_size(self): 48 | return self._cell.output_size 49 | 50 | def zero_state(self, batch_size, dtype): 51 | with ops.name_scope(type(self).__name__ + "ZeroState", values=[batch_size]): 52 | return self._cell.zero_state(batch_size, dtype) 53 | 54 | def call(self, inputs, state): 55 | """Run this multi-layer cell on inputs, starting from state.""" 56 | # customized scope 57 | with vs.variable_scope(self._cell_scope or "Cell"): 58 | outputs, new_state = self._cell(inputs, state) 59 | 60 | return outputs, new_state 61 | 62 | 63 | class MultiRNNCell(rnn_cell.RNNCell): 64 | """RNN cell composed sequentially of multiple simple cells.""" 65 | 66 | def __init__(self, cells, state_is_tuple=True, cell_scopes=None): 67 | """Create a RNN cell composed sequentially of a number of RNNCells. 68 | Args: 69 | cells: list of RNNCells that will be composed in this order. 70 | state_is_tuple: If True, accepted and returned states are n-tuples, where 71 | `n = len(cells)`. If False, the states are all 72 | concatenated along the column axis. This latter behavior will soon be 73 | deprecated. 74 | Raises: 75 | ValueError: if cells is empty (not allowed), or at least one of the cells 76 | returns a state tuple but the flag `state_is_tuple` is `False`. 77 | """ 78 | # use if in TF1.0 or afterwards 79 | super(MultiRNNCell, self).__init__() 80 | if not cells: 81 | raise ValueError( 82 | "Must specify at least one cell for MultiRNNCell.") 83 | if not nest.is_sequence(cells): 84 | raise TypeError( 85 | "cells must be a list or tuple, but saw: %s." 
% cells)
86 |         if (not isinstance(cell_scopes, (tuple, list)) or
87 |                 len(cell_scopes) != len(cells)):
88 |             raise ValueError(
89 |                 "cell_scopes should be a list with the same length as cells")
90 |         if not len(cells) == NUM_LAYERS:
91 |             raise ValueError("Only two layer Cells are supported")
92 | 
93 |         self._cells = cells
94 |         self._state_is_tuple = state_is_tuple
95 |         self._cell_scopes = cell_scopes
96 |         if not state_is_tuple:
97 |             if any(nest.is_sequence(c.state_size) for c in self._cells):
98 |                 raise ValueError(
99 |                     "Some cells return tuples of states, but the flag "
100 |                     "state_is_tuple is not set. State sizes are: %s"
101 |                     % str([c.state_size for c in self._cells]))
102 | 
103 |     @property
104 |     def state_size(self):
105 |         if self._state_is_tuple:
106 |             state_size = tuple(cell.state_size for cell in self._cells)
107 |             return MultiRNNLSTMStateTuple(*state_size)
108 |         else:
109 |             return sum([cell.state_size for cell in self._cells])
110 | 
111 |     @property
112 |     def output_size(self):
113 |         return self._cells[-1].output_size
114 | 
115 |     def zero_state(self, batch_size, dtype):
116 |         # override the TF 0.12-style zero_state with a TF 1.4-style name scope
117 |         with ops.name_scope(type(self).__name__ + "ZeroState", values=[batch_size]):
118 |             if self._state_is_tuple:
119 |                 zero_state = tuple(cell.zero_state(batch_size, dtype)
120 |                                    for cell in self._cells)
121 |                 # wrap the list of per-layer states with a hashable tuple
122 |                 return MultiRNNLSTMStateTuple(*zero_state)
123 |             else:
124 |                 # We know here that state_size of each cell is not a tuple and
125 |                 # presumably does not contain TensorArrays or anything else
126 |                 # fancy
127 |                 return super(MultiRNNCell, self).zero_state(batch_size, dtype)
128 | 
129 |     def call(self, inputs, state):
130 |         """Run this multi-layer cell on inputs, starting from state."""
131 |         cur_state_pos = 0
132 |         cur_inp = inputs
133 |         new_states = []
134 |         for i, cell in enumerate(self._cells):
135 |             # customized scope
136 |             with vs.variable_scope(self._cell_scopes[i] or "Cell_%d" % i):
137 |                 if self._state_is_tuple:
138 |                     if not nest.is_sequence(state):
139 |                         raise ValueError(
140 |                             "Expected state to be a tuple of length %d, "
141 |                             "but received: %s" % (len(self.state_size), state))
142 |                     if not isinstance(state, MultiRNNLSTMStateTuple):
143 |                         raise TypeError(
144 |                             "Expected state to be MultiRNNLSTMStateTuple, "
145 |                             "found ", type(state))
146 |                     cur_state = state[i]
147 |                 else:
148 |                     cur_state = array_ops.slice(state, [0, cur_state_pos],
149 |                                                 [-1, cell.state_size])
150 |                     cur_state_pos += cell.state_size
151 |                 cur_inp, new_state = cell(cur_inp, cur_state)
152 |                 new_states.append(new_state)
153 | 
154 |         if self._state_is_tuple:
155 |             # wrap the list of per-layer states with a hashable tuple
156 |             new_states = tuple(new_states)
157 |             new_states = MultiRNNLSTMStateTuple(*new_states)
158 |         else:
159 |             # concatenate the per-layer states along the column axis
160 |             new_states = array_ops.concat(new_states, 1)
161 | 
162 |         return cur_inp, new_states
163 | 
--------------------------------------------------------------------------------
/utils/rnn_cell_utils.py:
--------------------------------------------------------------------------------
1 | """https://github.com/tensorflow/nmt"""
2 | from __future__ import print_function
3 | from __future__ import division
4 | from __future__ import absolute_import
5 | 
6 | import tensorflow as tf
7 | 
8 | 
9 | def _single_cell(unit_type,
10 |                  num_units,
11 |                  mode="train",
12 |                  dropout=None,
13 |                  residual_connection=False,
14 |                  *args, **kargs):
15 |     """Create an instance of a single RNN cell."""
16 | 
17 |     # Cell Type
18 |     if unit_type == "lstm":
19 |         single_cell = tf.nn.rnn_cell.BasicLSTMCell(
20 |             num_units, *args, **kargs)
21 | 
22 |     elif unit_type == "gru":
23 |         single_cell = tf.nn.rnn_cell.GRUCell(
24 |             num_units, *args, **kargs)
25 | 
26 |     elif unit_type == "layer_norm_lstm":
27 |         single_cell = tf.contrib.rnn.LayerNormBasicLSTMCell(
28 |             num_units, layer_norm=True, *args, **kargs)
29 | 
30 |     elif unit_type == "classical_lstm":
31 |         single_cell = tf.nn.rnn_cell.LSTMCell(
32 |             num_units, *args, **kargs)
33 | 
34 |     else:
35 |         raise ValueError("Unknown unit type %s!" % unit_type)
36 | 
37 |     # dropout (= 1 - keep_prob) is disabled during eval and infer
38 |     if dropout is not None:
39 |         dropout = dropout if mode == "train" else 0.0
40 |         single_cell = tf.nn.rnn_cell.DropoutWrapper(
41 |             cell=single_cell, input_keep_prob=(1.0 - dropout))
42 | 
43 |         print("Using dropout with keep probability %.2f" % (1.0 - dropout))
44 | 
45 |     # Residual
46 |     if residual_connection:
47 |         single_cell = tf.nn.rnn_cell.ResidualWrapper(single_cell)
48 | 
49 | 
50 |     return single_cell
51 | 
52 | 
53 | def _cell_list(unit_type,
54 |                num_units,
55 |                num_layers,
56 |                mode="train",
57 |                dropout=None,
58 |                num_residual_layers=0,
59 |                single_cell_fn=None,
60 |                *args, **kargs):
61 |     """Create a list of RNN cells."""
62 |     if not single_cell_fn:
63 |         single_cell_fn = _single_cell
64 | 
65 |     cell_list = []
66 |     for i in range(num_layers):
67 |         single_cell = single_cell_fn(
68 |             unit_type=unit_type,
69 |             num_units=num_units,
70 |             mode=mode,
71 |             dropout=dropout,
72 |             residual_connection=(i >= num_layers - num_residual_layers),
73 |             *args, **kargs)
74 |         cell_list.append(single_cell)
75 | 
76 |     return cell_list
77 | 
78 | 
79 | def create_rnn_cell(unit_type,
80 |                     num_units,
81 |                     num_layers,
82 |                     mode,
83 |                     dropout=None,
84 |                     num_residual_layers=0,
85 |                     single_cell_fn=None,
86 |                     cell_wrapper=None,
87 |                     cell_scope=None,
88 |                     *args, **kargs):
89 |     """Create multi-layer RNN cell.
90 | 
91 |     Args:
92 |       unit_type: string representing the unit type, e.g. "lstm".
93 |       num_units: the depth of each unit.
94 |       num_layers: number of cells.
95 |       num_residual_layers: Number of residual layers from top to bottom. For
96 |         example, if `num_layers=4` and `num_residual_layers=2`, the last 2 RNN
97 |         cells in the returned list will be wrapped with `ResidualWrapper`.
98 |       dropout: floating point value between 0.0 and 1.0:
99 |         the probability of dropout. This is ignored unless `mode == "train"`.
100 |       mode: either "train", "eval", or "infer".
101 |       single_cell_fn: allows for adding a customized cell.
102 |         When not specified, we default to `_single_cell`.
103 |     Returns:
104 |       An `RNNCell` instance.
105 |     """
106 |     cell_list = _cell_list(unit_type=unit_type,
107 |                            num_units=num_units,
108 |                            num_layers=num_layers,
109 |                            mode=mode,
110 |                            dropout=dropout,
111 |                            num_residual_layers=num_residual_layers,
112 |                            single_cell_fn=single_cell_fn,
113 |                            *args, **kargs)
114 | 
115 |     if cell_wrapper and not callable(cell_wrapper):
116 |         raise TypeError("Expected `cell_wrapper` to be callable, "
117 |                         "found ", type(cell_wrapper))
118 | 
119 |     if len(cell_list) == 1:  # Single layer.
120 |         if not cell_wrapper:
121 |             return cell_list[0]
122 |         return cell_wrapper(cell=cell_list[0], cell_scope=cell_scope)
123 | 
124 |     else:  # Multi layers
125 |         if not cell_wrapper:
126 |             return tf.nn.rnn_cell.MultiRNNCell(cell_list)
127 |         return cell_wrapper(cells=cell_list, cell_scopes=cell_scope)
128 | 
--------------------------------------------------------------------------------
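To tie the RNN-cell helpers together, here is a hedged usage sketch that builds a two-layer LSTM cell with train-time dropout and the scoped `MultiRNNCell` wrapper from `modified_rnn_cell_wrappers.py`; the import paths, scope names, and hyper-parameter values are assumptions for illustration, not values taken from this repository:

```python
import tensorflow as tf
from utils import rnn_cell_utils
from utils import modified_rnn_cell_wrappers as cell_wrappers

# Two-layer LSTM with dropout applied only in "train" mode; the custom
# wrapper pins each layer to an explicit variable scope (the layer names
# below are illustrative) so its variables live under predictable names.
decoder_cell = rnn_cell_utils.create_rnn_cell(
    unit_type="lstm",
    num_units=256,
    num_layers=2,                       # matches NUM_LAYERS expected by the wrapper
    mode="train",
    dropout=0.2,                        # DropoutWrapper keep_prob = 0.8
    cell_wrapper=cell_wrappers.MultiRNNCell,
    cell_scope=["decoder_layer_0", "decoder_layer_1"])

init_state = decoder_cell.zero_state(batch_size=16, dtype=tf.float32)
# init_state is a MultiRNNLSTMStateTuple(Layer0=LSTMStateTuple, Layer1=LSTMStateTuple)
```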