├── .gitignore
├── LICENSE
├── README.md
├── config.py
├── convert_to_parent_reference.py
├── data_utils.py
├── decoders.py
├── decorators.py
├── encoders.py
├── eval_parent.py
├── eval_rouge_meteor_rep.py
├── generate_beam_search.py
├── generate_greedy_sampl.py
├── model_utils.py
├── models.py
├── scripts
│   ├── convert2parent_dev.sh
│   ├── eval_dev.sh
│   ├── generate_beam_search.sh
│   └── train_large_copy_cyc.sh
├── train.py
├── train_helper.py
└── transformer_xlm.py

/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 | 
6 | # C extensions
7 | *.so
8 | 
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | pip-wheel-metadata/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 | 
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 | 
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 | 
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | *.py,cover
51 | .hypothesis/
52 | .pytest_cache/
53 | 
54 | # Translations
55 | *.mo
56 | *.pot
57 | 
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 | 
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 | 
68 | # Scrapy stuff:
69 | .scrapy
70 | 
71 | # Sphinx documentation
72 | docs/_build/
73 | 
74 | # PyBuilder
75 | target/
76 | 
77 | # Jupyter Notebook
78 | .ipynb_checkpoints
79 | 
80 | # IPython
81 | profile_default/
82 | ipython_config.py
83 | 
84 | # pyenv
85 | .python-version
86 | 
87 | # pipenv
88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
91 | # install all needed dependencies.
92 | #Pipfile.lock
93 | 
94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
95 | __pypackages__/
96 | 
97 | # Celery stuff
98 | celerybeat-schedule
99 | celerybeat.pid
100 | 
101 | # SageMath parsed files
102 | *.sage.py
103 | 
104 | # Environments
105 | .env
106 | .venv
107 | env/
108 | venv/
109 | ENV/
110 | env.bak/
111 | venv.bak/
112 | 
113 | # Spyder project settings
114 | .spyderproject
115 | .spyproject
116 | 
117 | # Rope project settings
118 | .ropeproject
119 | 
120 | # mkdocs documentation
121 | /site
122 | 
123 | # mypy
124 | .mypy_cache/
125 | .dmypy.json
126 | dmypy.json
127 | 
128 | # Pyre type checker
129 | .pyre/
130 | 
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 | 
3 | Copyright (c) 2020 Mingda Chen
4 | 
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # WikiTableT
2 | 
3 | Code, data, and pretrained models for the paper "[Generating Wikipedia Article Sections from Diverse Data Sources](https://arxiv.org/abs/2012.14919)"
4 | 
5 | **Note: we refer to the section data as hyperlink data in both the processed json files and the codebase.**
6 | 
7 | ## Resources
8 | 
9 | - [WikiTableT dataset](https://drive.google.com/file/d/1HRpnKLI6vZusB8NoR0cgUYeoD8aX5q2r/view?usp=sharing)
10 | - [multi-bleu and METEOR score](https://drive.google.com/drive/folders/1FJjvMldeZrJnQd-iVXJ3KGFBLEvsndNY?usp=sharing)
11 | - [Trained models (base+copy+cyc (trained on 500k instances) and large+copy+cyc (trained on the full dataset))](https://drive.google.com/drive/folders/1L8kzbWVwufnJXtMAmoB1slPez7mqsVzE?usp=sharing)
12 | - [BPE code and vocab](https://drive.google.com/file/d/1PN_0lHLBCbBDHnJC3CTsdkC1poVBSG7M/view?usp=sharing) (We used https://github.com/rsennrich/subword-nmt)
13 | - [Data for computing the PARENT scores](https://drive.google.com/file/d/1VjyqChwuzAhUcP1Me8Ay_UemZwL8qNNS/view?usp=sharing)
14 | 
15 | ## Dependencies
16 | 
17 | - Python 3.7
18 | - PyTorch 1.5.1
19 | - NLTK
20 | - [py-rouge](https://github.com/Diego999/py-rouge)
21 | - [entmax](https://github.com/deep-spin/entmax)
22 | 
23 | ## Usage
24 | 
25 | To train a new model, you may use a command similar to ``scripts/train_large_copy_cyc.sh``. A minimal sketch of such a run is shown below.
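The following is a minimal sketch, not the exact recipe in ``scripts/train_large_copy_cyc.sh``: the data and BPE paths are placeholders, and the flags shown are only a subset of the options defined in ``config.py``.

```bash
# Hypothetical paths; point these at the downloaded WikiTableT files.
python train.py \
    --train_path "data/train_*.jsonl" \
    --dev_path "data/dev.jsonl" \
    --wikidata_path "data/wikidata.jsonl" \
    --infobox_path "data/infobox.jsonl" \
    --bpe_codes bpe/codes \
    --bpe_vocab bpe/vocab \
    --use_copy yes \
    --true_cyclic_loss 1.0 \
    --fake_cyclic_loss 1.0 \
    --batch_size 20 \
    --n_epoch 5 \
    --save_dir experiments/large_copy_cyc
```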
26 | 
27 | To perform beam search generation using a trained model, you may use a command similar to ``scripts/generate_beam_search.sh``. The process generates 4 files, including references; 2 of them are tokenized using NLTK for the convenience of later evaluation steps.
28 | 
29 | If you want to generate your own version of reference data when computing the PARENT scores, use a command similar to ``scripts/convert2parent_dev.sh``.
30 | 
31 | Once you have the generated file, you may evaluate it against the reference using the command ``scripts/eval_dev.sh REF_FILE_PATH GEN_FILE_PATH``. Please make sure that you are using the tokenized files.
32 | 
33 | ## Acknowledgement
34 | 
35 | Part of the code in this repository is adapted from the following repositories:
36 | 
37 | - https://github.com/huggingface/transformers
38 | - https://github.com/facebookresearch/XLM
39 | - https://github.com/google-research/language/tree/master/language/table_text_eval
40 | - https://github.com/OpenNMT/OpenNMT-py
--------------------------------------------------------------------------------
/config.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | 
3 | UNK_IDX = 0
4 | UNK_WORD = "UUUNKKK"
5 | BOS_IDX = 1
6 | EOS_IDX = 2
7 | BOC_IDX = 3
8 | BOV_IDX = 4
9 | BOQK_IDX = 5
10 | BOQV_IDX = 6
11 | SUBSEC_IDX = 7
12 | 
13 | EOC_IDX = 8
14 | MASK_IDX = 9
15 | 
16 | MAX_VALUE_LEN = 20
17 | MAX_GEN_LEN = 200
18 | 
19 | METEOR_JAR = 'evaluation/meteor-1.5.jar'
20 | METEOR_DATA = 'evaluation/data/paraphrase-en.gz'
21 | MULTI_BLEU_PERL = 'evaluation/multi-bleu.perl'
22 | RESOURCE_LINK = 'https://drive.google.com/drive/folders/1FJjvMldeZrJnQd-iVXJ3KGFBLEvsndNY?usp=sharing'
23 | 
24 | 
25 | def str2bool(v):
26 |     return v.lower() in ('yes', 'true', 't', '1', 'y')
27 | 
28 | 
29 | def get_base_parser():
30 |     parser = argparse.ArgumentParser(
31 |         description='WikiTableT using PyTorch')
32 |     parser.register('type', 'bool', str2bool)
33 | 
34 |     basic_group = parser.add_argument_group('basics')
35 |     # Basics
36 |     basic_group.add_argument('--debug', type="bool", default=False,
37 |                              help='activation of debug mode (default: False)')
38 |     basic_group.add_argument('--auto_disconnect', type="bool", default=True,
39 |                              help='for slurm (default: True)')
40 |     basic_group.add_argument('--save_prefix', type=str, default="experiments",
41 |                              help='saving path prefix')
42 |     basic_group.add_argument('--gen_prefix', type=str, default="gen",
43 |                              help='generation saving file prefix')
44 |     basic_group.add_argument('--gen_dir', type=str, default="gen",
45 |                              help='generation saving path directory')
46 | 
47 |     data_group = parser.add_argument_group('data')
48 |     # Data file
49 |     data_group.add_argument('--train_path', type=str, default=None,
50 |                             help='data file')
51 |     data_group.add_argument('--vocab_file', type=str, default=None,
52 |                             help='vocabulary file')
53 |     data_group.add_argument('--bpe_codes', type=str, default=None,
54 |                             help='bpe code file')
55 |     data_group.add_argument('--bpe_vocab', type=str, default=None,
56 |                             help='bpe vocabulary file')
57 |     data_group.add_argument('--dev_path', type=str, default=None,
58 |                             help='data file')
59 |     data_group.add_argument('--test_path', type=str, default=None,
60 |                             help='data file')
61 |     data_group.add_argument('--wikidata_path', type=str, default=None,
62 |                             help='data file')
63 |     data_group.add_argument('--infobox_path', type=str, default=None,
64 |                             help='data file')
65 |     data_group.add_argument('--max_num_value', type=int, default=None,
66 |                             help='max number of values per cell')
number of values per cell') 67 | 68 | config_group = parser.add_argument_group('model_configs') 69 | config_group.add_argument('-lr', '--learning_rate', 70 | dest='lr', 71 | type=float, 72 | default=1e-3, 73 | help='learning rate') 74 | config_group.add_argument('-dp', '--dropout', 75 | dest='dp', 76 | type=float, 77 | default=0.0, 78 | help='dropout rate') 79 | config_group.add_argument('-lratio', '--logloss_ratio', 80 | dest='lratio', 81 | type=float, 82 | default=1.0, 83 | help='ratio of log loss') 84 | config_group.add_argument('--eps', 85 | type=float, 86 | default=1e-5, 87 | help='safty for avoiding numerical issues') 88 | config_group.add_argument('-edim', '--embed_dim', 89 | dest='edim', 90 | type=int, default=300, 91 | help='size of embedding') 92 | config_group.add_argument('-gclip', '--grad_clip', 93 | dest='gclip', 94 | type=float, default=1.0, 95 | help='gradient clipping threshold') 96 | 97 | # recurrent neural network detail 98 | config_group.add_argument('-ensize', '--encoder_size', 99 | dest='ensize', 100 | type=int, default=512, 101 | help='encoder hidden size') 102 | config_group.add_argument('-desize', '--decoder_size', 103 | dest='desize', 104 | type=int, default=512, 105 | help='decoder hidden size') 106 | config_group.add_argument('-elayer', '--encoder_num_layer', 107 | dest='elayer', 108 | type=int, default=3, 109 | help='number of encoder layer') 110 | config_group.add_argument('-dlayer', '--decoder_num_layer', 111 | dest='dlayer', 112 | type=int, default=3, 113 | help='number of decoder layer') 114 | config_group.add_argument('-asize', '--attn_size', 115 | dest='asize', 116 | type=int, default=100, 117 | help='size of attention') 118 | config_group.add_argument('-bwdelayer', '--bwd_encoder_num_layer', 119 | dest='bwdelayer', 120 | type=int, default=2, 121 | help='number of encoder layer for backward models') 122 | config_group.add_argument('-bwdnhead', '--bwd_num_head', 123 | dest='bwdnhead', 124 | type=int, default=4, 125 | help='number of attention heads for backward models') 126 | config_group.add_argument('-bwddlayer', '--bwd_decoder_num_layer', 127 | dest='bwddlayer', 128 | type=int, default=2, 129 | help='number of decoder layer for backward models') 130 | 131 | # transformer 132 | config_group.add_argument('-act_fn', '--activation_function', 133 | dest='act_fn', 134 | type=str, default="gelu", 135 | help='types of activation function used in transformer model') 136 | config_group.add_argument('-nhead', '--num_head', 137 | dest='nhead', 138 | type=int, default=4, 139 | help='number of attention heads') 140 | 141 | # optimization 142 | config_group.add_argument('--l2', type=float, default=0., 143 | help='l2 regularization') 144 | config_group.add_argument('-wstep', '--warmup_steps', 145 | dest='wstep', type=int, default=0, 146 | help='learning rate warmup steps') 147 | config_group.add_argument('-lm', '--label_smoothing', 148 | dest='lm', type=float, default=0.0, 149 | help='label smoothing') 150 | config_group.add_argument('-bwd_lm', '--backward_label_smoothing', 151 | dest='bwd_lm', type=float, default=0.0, 152 | help='label smoothing') 153 | config_group.add_argument('-gcs', '--gradient_accumulation_steps', 154 | dest='gcs', type=int, default=1, 155 | help='gradient accumulation steps') 156 | config_group.add_argument('-tloss', '--true_cyclic_loss', 157 | dest='tloss', type=float, default=1.0, 158 | help='cyclic loss based on reference input') 159 | config_group.add_argument('-floss', '--fake_cyclic_loss', 160 | dest='floss', type=float, default=1.0, 161 | 
162 | 
163 |     setup_group = parser.add_argument_group('train_setup')
164 |     # train detail
165 |     setup_group.add_argument('--model_file', type=str, default=None,
166 |                              help='model save path')
167 |     setup_group.add_argument('--save_dir', type=str, default=None,
168 |                              help='model save path')
169 |     basic_group.add_argument('--encoder_type',
170 |                              type=str, default="transformer",
171 |                              help='types of encoder')
172 |     basic_group.add_argument('--decoder_type',
173 |                              type=str, default="ctransformer",
174 |                              help='types of decoder')
175 |     setup_group.add_argument('--n_epoch', type=int, default=5,
176 |                              help='number of epochs')
177 |     setup_group.add_argument('--max_gen_len', type=int, default=200,
178 |                              help='maximum length for generation')
179 |     setup_group.add_argument('--min_gen_len', type=int, default=0,
180 |                              help='minimum length for generation')
181 |     setup_group.add_argument('--max_encoder_len', type=int, default=512,
182 |                              help='maximum input length for encoder')
183 |     setup_group.add_argument('--max_decoder_len', type=int, default=512,
184 |                              help='maximum input length for decoder')
185 |     setup_group.add_argument('--max_train_txt_len', type=int, default=500,
186 |                              help='maximum length for target text during training')
187 |     setup_group.add_argument('--top_p', type=float, default=None,
188 |                              help='generation sampling')
189 |     setup_group.add_argument('--top_k', type=int, default=None,
190 |                              help='generation sampling')
191 |     setup_group.add_argument('--batch_size', type=int, default=20,
192 |                              help='batch size')
193 |     setup_group.add_argument('--eval_batch_size', type=int, default=50,
194 |                              help='evaluation batch size')
195 |     setup_group.add_argument('--opt', type=str, default='adam',
196 |                              choices=['sadam', 'adam', 'sgd', 'rmsprop'],
197 |                              help='types of optimizer: adam (default), \
198 |                              sgd, rmsprop')
199 |     setup_group.add_argument('--filter_ngram', type=int, default=0,
200 |                              help='filter ngram during beam search')
201 |     setup_group.add_argument('--trigram_blocking', type="bool", default=False,
202 |                              help='use trigram blocking during beam search')
203 |     setup_group.add_argument('--beam_size', type=int, default=10,
204 |                              help='size for beam search')
205 |     setup_group.add_argument('--return_wikidata', type="bool", default=False,
206 |                              help='whether to mask predict wikidata')
207 |     setup_group.add_argument('--return_hyperlink', type="bool", default=True,
208 |                              help='whether to mask predict hyperlink')
209 |     setup_group.add_argument('--return_titles', type="bool", default=True,
210 |                              help='whether to mask predict titles')
211 |     setup_group.add_argument('--input_wikidata', type="bool", default=True,
212 |                              help='whether to input wikidata')
213 |     setup_group.add_argument('--input_hyperlink', type="bool", default=True,
214 |                              help='whether to input hyperlink')
215 |     setup_group.add_argument('--use_copy', type="bool", default=False,
216 |                              help='whether to use copy mechanism')
217 |     setup_group.add_argument('--use_fyl', type="bool", default=False,
218 |                              help='whether to use FY loss')
219 |     setup_group.add_argument('-force_copy', '--force_copy',
220 |                              dest='force_copy', type="bool", default=False,
221 |                              help='whether to force copy from source in the copy mechanism')
222 |     setup_group.add_argument('-share_decoder_embedding', '--share_decoder_embedding',
223 |                              dest='share_decoder_embedding', type="bool", default=False,
224 |                              help='whether to share embeddings in decoders')
225 | 
226 |     misc_group = parser.add_argument_group('misc')
227 |     # misc
228 |
misc_group.add_argument('--print_every', type=int, default=500, 229 | help='print training details after \ 230 | this number of iterations') 231 | misc_group.add_argument('--eval_every', type=int, default=5000, 232 | help='evaluate model after \ 233 | this number of iterations') 234 | misc_group.add_argument('--save_every', type=int, default=2000, 235 | help='save model after \ 236 | this number of iterations') 237 | return parser 238 | -------------------------------------------------------------------------------- /convert_to_parent_reference.py: -------------------------------------------------------------------------------- 1 | import json 2 | import nltk 3 | import sys 4 | 5 | from glob import glob 6 | from tqdm import tqdm 7 | 8 | input_data_path = sys.argv[1] 9 | infobox_path = sys.argv[2] 10 | wikidata_path = sys.argv[3] 11 | output_path = sys.argv[4] 12 | 13 | dup = 0 14 | trunc_max_n_data = 0 15 | max_n_keys = 0 16 | max_n_data = 0 17 | max_n_qual = 0 18 | max_datatype_count = 0 19 | wikidata = {} 20 | if wikidata_path is not None: 21 | with open(wikidata_path) as fp: 22 | for nline, dataline in tqdm(enumerate(fp)): 23 | if dataline.strip(): 24 | datajson = json.loads(dataline.strip()) 25 | datalist = "" 26 | datapos = [] 27 | datatype = [] 28 | datamask = [] 29 | datatype_count = 0 30 | max_n_keys = max(max_n_keys, len(datajson["wikidata_details"])) 31 | for key in datajson["wikidata_details"]: 32 | all_data_for_key = datajson["wikidata_details"][key] 33 | 34 | for l in all_data_for_key: 35 | 36 | temp_value = " ".join( 37 | nltk.word_tokenize(key.replace("@@ ", ""))) 38 | prop_value = temp_value 39 | 40 | temp_value = " ".join( 41 | nltk.word_tokenize(l["data"].replace("@@ ", ""))) 42 | prop_value += "|||" + temp_value 43 | 44 | if "qualifiers" in l: 45 | max_n_qual = max(max_n_qual, len(l["qualifiers"])) 46 | for qual_key in l["qualifiers"]: 47 | temp_value = " ".join( 48 | nltk.word_tokenize( 49 | qual_key.replace("@@ ", ""))) 50 | prop_value += " " + temp_value 51 | temp_value = " ".join( 52 | nltk.word_tokenize( 53 | l["qualifiers"][qual_key].replace("@@ ", ""))) 54 | prop_value += " " + temp_value 55 | 56 | if prop_value.strip(): 57 | datalist += "\t" + " ".join(prop_value.split()) 58 | 59 | wikidata[datajson["wikidata_name"]] = datalist.strip() 60 | print("loaded #wikidata entries: {}".format(len(wikidata))) 61 | 62 | infobox = {} 63 | with open(infobox_path) as fp: 64 | for nline, dataline in enumerate(fp): 65 | if nline and nline % 50000 == 0: 66 | print("loading infobox #line: {}".format(nline)) 67 | if dataline.strip(): 68 | datajson = json.loads(dataline.strip()) 69 | 70 | datalist = "" 71 | max_n_keys = max(max_n_keys, len(datajson["infobox"])) 72 | for key in datajson["infobox"]: 73 | all_data_for_key = datajson["infobox"][key] 74 | max_n_data = max(max_n_data, len(all_data_for_key)) 75 | 76 | prop_value = " ".join(nltk.word_tokenize(key.replace("@@ ", ""))) 77 | 78 | temp_value = " ".join(nltk.word_tokenize(all_data_for_key.replace("@@ ", ""))) 79 | prop_value += "|||" + temp_value 80 | datalist += "\t" + prop_value 81 | 82 | if datalist: 83 | infobox[datajson["title"]] = datalist.strip() 84 | 85 | 86 | all_train_data = [] 87 | n_train_has_wikidata = 0 88 | max_datatype_count = 0 89 | max_datalist_len = 0 90 | max_st_len = 0 91 | max_dt_len = 0 92 | n_skip = 0 93 | n_break = 0 94 | fp_out = open(output_path, "w") 95 | for train_file in glob(input_data_path): 96 | with open(train_file) as fp: 97 | for nline, dataline in enumerate(fp): 98 | if nline and nline % 
100000 == 0:
99 |                 print("loading input file #line: {}".format(nline))
100 |             if dataline.strip():
101 |                 datajson = json.loads(dataline.strip())
102 | 
103 |                 wikidata_datalist = ""
104 |                 infobox_datalist = ""
105 |                 if datajson["doc_title"] in wikidata:
106 |                     wikidata_datalist = wikidata[datajson["doc_title"]]
107 |                     n_train_has_wikidata += 1
108 |                 if datajson["doc_title"] in infobox:
109 |                     infobox_datalist = infobox[datajson["doc_title"]]
110 | 
111 | 
112 |                 templist = "document title"
113 |                 doc_datalist = templist
114 | 
115 |                 templist = " ".join(nltk.word_tokenize(datajson["doc_title"]))
116 |                 doc_datalist += "|||" + templist
117 | 
118 |                 datalist = ""
119 |                 masklist = []
120 |                 datapos = []
121 |                 datatype = []
122 | 
123 |                 for d in datajson["data"]:
124 |                     templist = " ".join(nltk.word_tokenize(d[0].replace("@@ ", "")))
125 |                     curr_datalist = "\t" + templist
126 | 
127 |                     templist = " ".join(nltk.word_tokenize(d[1].replace("@@ ", "")))
128 |                     curr_datalist += "|||" + templist
129 | 
130 |                     datalist += curr_datalist
131 | 
132 |                 templist = "section title"
133 |                 section_datalist = templist
134 | 
135 |                 templist = " ".join(nltk.word_tokenize(datajson["sec_title"][0].replace("@@ ", "")))
136 |                 section_datalist += "|||" + templist
137 |                 for s in datajson["sec_title"][1:]:
138 | 
139 |                     templist = " ".join(nltk.word_tokenize(s.replace("@@ ", "")))
140 |                     section_datalist += "\t" + "section title" + "|||" + templist
141 | 
142 |                 write_line = ""
143 |                 if doc_datalist.strip():
144 |                     write_line += doc_datalist.strip()
145 |                 if section_datalist.strip():
146 |                     write_line += "\t" + section_datalist.strip()
147 |                 if wikidata_datalist.strip():
148 |                     write_line += "\t" + wikidata_datalist.strip()
149 |                 if datalist.strip():
150 |                     write_line += "\t" + datalist.strip()
151 |                 if infobox_datalist.strip():
152 |                     write_line += "\t" + infobox_datalist.strip()
153 |                 fp_out.write(write_line + "\n")
154 | fp_out.close()
--------------------------------------------------------------------------------
/data_utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import pickle
3 | import json
4 | import statistics
5 | 
6 | import numpy as np
7 | 
8 | from glob import glob
9 | from subword_nmt.apply_bpe import BPE, read_vocabulary
10 | from decorators import auto_init_args
11 | from config import UNK_IDX, UNK_WORD, BOS_IDX, EOS_IDX, \
12 |     BOC_IDX, BOV_IDX, BOQK_IDX, BOQV_IDX, SUBSEC_IDX, \
13 |     EOC_IDX, MAX_VALUE_LEN, MASK_IDX
14 | 
15 | 
16 | class DataHolder:
17 |     @auto_init_args
18 |     def __init__(self, train_data, dev_data, test_data, vocab, bpe):
19 |         self.inv_vocab = {i: w for w, i in vocab.items()}
20 | 
21 | 
22 | class DataProcessor:
23 |     @auto_init_args
24 |     def __init__(self, train_path, dev_path, test_path,
25 |                  wikidata_path, infobox_path,
26 |                  bpe_vocab, bpe_codes, experiment):
27 |         self.expe = experiment
28 |         bpe_vocab_dict = read_vocabulary(open(bpe_vocab), 50)
29 |         self.bpe = BPE(open(bpe_codes), -1, "@@", bpe_vocab_dict, None)
30 | 
31 |     def process(self):
32 | 
33 |         vocab = {UNK_WORD: UNK_IDX, "<bos>": BOS_IDX, "<eos>": EOS_IDX,
34 |                  "<boc>": BOC_IDX, "<bov>": BOV_IDX, "<boqk>": BOQK_IDX,
35 |                  "<boqv>": BOQV_IDX, "<subsec>": SUBSEC_IDX, "<eoc>": EOC_IDX,
36 |                  "<mask>": MASK_IDX}
37 | 
38 |         with open(self.bpe_vocab) as fp:
39 |             for line in fp:
40 |                 w, f = line.strip().split()
41 |                 if int(f) > 50:
42 |                     vocab[w] = len(vocab)
43 | 
44 |         self.expe.log.info("vocab size: {}".format(len(vocab)))
45 | 
46 |         train_data, dev_data, test_data = self._load_sent_and_data(
47 |             self.train_path, self.dev_path, self.test_path,
48 |
self.wikidata_path, self.infobox_path, vocab) 49 | 50 | def cal_stats(data): 51 | data_unk_count = 0 52 | data_total_count = 0 53 | text_unk_count = 0 54 | text_total_count = 0 55 | leng_text = [] 56 | leng_hyp_data = [] 57 | leng_wiki_data = [] 58 | leng_sec_data = [] 59 | leng_doc_data = [] 60 | for d in data: 61 | assert len(d["hyperlink_data"]) == 5, \ 62 | "d['hyperlink_data'] = {} != 5"\ 63 | .format(len(d["hyperlink_data"])) 64 | leng_text.append(len(d["text"])) 65 | leng_hyp_data.append(len(d["hyperlink_data"][0])) 66 | leng_wiki_data.append(len(d["wikidata"][0])) 67 | leng_sec_data.append(len(d["sec_title"][0])) 68 | leng_doc_data.append(len(d["doc_title"][0])) 69 | for word in d["hyperlink_data"][0]: 70 | if word == UNK_IDX: 71 | data_unk_count += 1 72 | data_total_count += 1 73 | for word in d["wikidata"][0]: 74 | if word == UNK_IDX: 75 | data_unk_count += 1 76 | data_total_count += 1 77 | for word in d["text"]: 78 | if word == UNK_IDX: 79 | text_unk_count += 1 80 | text_total_count += 1 81 | return (data_unk_count, data_total_count, 82 | data_unk_count / data_total_count * 100 83 | if data_total_count else 0), \ 84 | (text_unk_count, text_total_count, 85 | text_unk_count / text_total_count * 100 86 | if text_total_count else 0), \ 87 | (len(leng_hyp_data), max(leng_hyp_data), 88 | min(leng_hyp_data), sum(leng_hyp_data) / len(leng_hyp_data) 89 | if len(leng_hyp_data) else 0, 90 | statistics.median(leng_hyp_data) 91 | ), \ 92 | (len(leng_wiki_data), max(leng_wiki_data), 93 | min(leng_wiki_data), 94 | sum(leng_wiki_data) / len(leng_wiki_data) 95 | if len(leng_wiki_data) else 0, 96 | statistics.median(leng_wiki_data)), \ 97 | (len(leng_sec_data), 98 | max(leng_sec_data), min(leng_sec_data), 99 | sum(leng_sec_data) / len(leng_sec_data) 100 | if len(leng_sec_data) else 0, 101 | statistics.median(leng_sec_data)), \ 102 | (len(leng_doc_data), max(leng_doc_data), 103 | min(leng_doc_data), sum(leng_doc_data) / len(leng_doc_data) 104 | if len(leng_doc_data) else 0, 105 | statistics.median(leng_doc_data)), \ 106 | (len(leng_text), max(leng_text), 107 | min(leng_text), sum(leng_text) / len(leng_text) 108 | if len(leng_text) else 0, 109 | statistics.median(leng_text) 110 | ) 111 | 112 | data_unk_stats, text_unk_stats, \ 113 | data_hyp_len_stats, text_wiki_len_stats, \ 114 | text_sec_len_stats, leng_doc_data, \ 115 | text_len_stats = cal_stats(train_data) 116 | self.expe.log.info("#train hyp data: {}, max len: {}, " 117 | "min len: {}, avg len: {:.2f}, median len: {:.2f}" 118 | .format(*data_hyp_len_stats)) 119 | 120 | self.expe.log.info("#train wiki data: {}, max len: {}, " 121 | "min len: {}, avg len: {:.2f}, median len: {:.2f}" 122 | .format(*text_wiki_len_stats)) 123 | 124 | self.expe.log.info("#train sec data: {}, max len: {}, " 125 | "min len: {}, avg len: {:.2f}, median len: {:.2f}" 126 | .format(*text_sec_len_stats)) 127 | 128 | self.expe.log.info("#train doc data: {}, max len: {}, " 129 | "min len: {}, avg len: {:.2f}, median len: {:.2f}" 130 | .format(*leng_doc_data)) 131 | 132 | self.expe.log.info("#train text: {}, max len: {}, " 133 | "min len: {}, avg len: {:.2f}, median len: {:.2f}" 134 | .format(*text_len_stats)) 135 | 136 | self.expe.log.info( 137 | "#train data unk: {}, {}, {:.4f}%".format(*data_unk_stats)) 138 | self.expe.log.info( 139 | "#train text unk: {}, {}, {:.4f}%".format(*text_unk_stats)) 140 | 141 | self.expe.log.info("*" * 50) 142 | 143 | data_unk_stats, text_unk_stats, \ 144 | data_hyp_len_stats, text_wiki_len_stats, \ 145 | text_sec_len_stats, leng_doc_data, \ 
146 |             text_len_stats = cal_stats(dev_data)
147 | 
148 |         self.expe.log.info("#dev hyp data: {}, max len: {}, "
149 |                            "min len: {}, avg len: {:.2f}, median len: {:.2f}"
150 |                            .format(*data_hyp_len_stats))
151 | 
152 |         self.expe.log.info("#dev wiki data: {}, max len: {}, "
153 |                            "min len: {}, avg len: {:.2f}, median len: {:.2f}"
154 |                            .format(*text_wiki_len_stats))
155 | 
156 |         self.expe.log.info("#dev sec data: {}, max len: {}, "
157 |                            "min len: {}, avg len: {:.2f}, median len: {:.2f}"
158 |                            .format(*text_sec_len_stats))
159 | 
160 |         self.expe.log.info("#dev doc data: {}, max len: {}, "
161 |                            "min len: {}, avg len: {:.2f}, median len: {:.2f}"
162 |                            .format(*leng_doc_data))
163 | 
164 |         self.expe.log.info("#dev text: {}, max len: {}, "
165 |                            "min len: {}, avg len: {:.2f}, median len: {:.2f}"
166 |                            .format(*text_len_stats))
167 | 
168 |         self.expe.log.info(
169 |             "#dev data unk: {}, {}, {:.4f}%".format(*data_unk_stats))
170 |         self.expe.log.info(
171 |             "#dev text unk: {}, {}, {:.4f}%".format(*text_unk_stats))
172 | 
173 |         self.expe.log.info("*" * 50)
174 |         data = DataHolder(
175 |             train_data=np.array(train_data),
176 |             dev_data=np.array(dev_data),
177 |             test_data=None,
178 |             vocab=vocab,
179 |             bpe=self.bpe)
180 | 
181 |         return data
182 | 
183 |     def _load_sent_and_data(self, train_path, dev_path,
184 |                             test_path, wikidata_path, infobox_path, vocab):
185 |         infobox = {}
186 |         trunc_max_n_data = 0
187 |         max_n_keys = 0
188 |         max_n_data = 0
189 |         max_datatype_count = 0
190 |         n_skip = 0
191 | 
192 |         if infobox_path is not None:
193 |             with open(infobox_path) as fp:
194 |                 for nline, dataline in enumerate(fp):
195 |                     if nline and nline % 50000 == 0:
196 |                         self.expe.log.info(
197 |                             "loading infobox #line: {}".format(nline))
198 |                     if dataline.strip():
199 |                         datajson = json.loads(dataline.strip())
200 | 
201 |                         datalist = []
202 |                         datapos = []
203 |                         datatype = []
204 |                         datamask = []
205 |                         datasrc = []
206 |                         datatype_count = 0
207 | 
208 |                         max_n_keys = max(max_n_keys, len(datajson["infobox"]))
209 |                         for key in datajson["infobox"]:
210 |                             all_data_for_key = datajson["infobox"][key]
211 |                             max_n_data = max(max_n_data, len(all_data_for_key))
212 |                             trunc_max_n_data = max(trunc_max_n_data,
213 |                                                    len(all_data_for_key))
214 | 
215 |                             prop_value = [vocab["<boc>"]]
216 |                             prop_mask = [vocab["<boc>"]]
217 |                             prop_src = ["<boc>"]
218 | 
219 |                             temp_value = [vocab.get(w, 0) for w in
220 |                                           key.split()[:MAX_VALUE_LEN]]
221 |                             prop_value += temp_value
222 |                             prop_mask += [vocab["<mask>"]] * len(temp_value)
223 |                             prop_src += key.split()[:MAX_VALUE_LEN]
224 | 
225 |                             prop_value += [vocab["<bov>"]]
226 |                             prop_mask += [vocab["<bov>"]]
227 |                             prop_src += ["<bov>"]
228 | 
229 |                             temp_value = [vocab.get(w, 0) for w in
230 |                                           all_data_for_key.split()
231 |                                           [:MAX_VALUE_LEN]]
232 |                             prop_value += temp_value
233 |                             prop_mask += [vocab["<mask>"]] * len(temp_value)
234 |                             prop_src += \
235 |                                 all_data_for_key.split()[:MAX_VALUE_LEN]
236 | 
237 |                             if len(datalist) + len(prop_value) > 350 or \
238 |                                     datatype_count > 300:
239 |                                 continue
240 |                             datamask += prop_mask
241 |                             datalist += prop_value
242 |                             datapos += list(range(len(prop_value)))
243 |                             datatype += [datatype_count] * len(prop_value)
244 |                             datasrc += prop_src
245 |                             datatype_count += 1
246 |                             max_datatype_count = \
247 |                                 max(max_datatype_count, datatype_count)
248 |                         assert len(datalist) == len(datamask), \
249 |                             "{} != {}".format(len(datalist), len(datamask))
250 |                         assert len(datalist) == len(datasrc), \
251 |                             "{} != {}".format(len(datalist), len(datasrc))
252 |                         if datalist:
253 |                             infobox[datajson["title"]] = \
254 |                                 (datalist, datapos, datatype,
255 |                                  datamask, datasrc)
256 |                         else:
257 |                             n_skip += 1
258 |         self.expe.log.info("loaded #infobox entries: {}, "
259 |                            "truncated max #data: {}, "
260 |                            "non-truncated max #data: {}, "
261 |                            "max #keys: {}, "
262 |                            "max datatype count: {}, "
263 |                            "#skip: {}"
264 |                            .format(len(infobox), trunc_max_n_data,
265 |                                    max_n_data, max_n_keys,
266 |                                    max_datatype_count - 1,
267 |                                    n_skip))
268 | 
269 |         dup = 0
270 |         trunc_max_n_data = 0
271 |         max_n_keys = 0
272 |         max_n_data = 0
273 |         max_n_qual = 0
274 |         max_datatype_count = 0
275 |         wikidata = {}
276 |         n_skip = 0
277 |         if wikidata_path is not None:
278 |             with open(wikidata_path) as fp:
279 |                 for nline, dataline in enumerate(fp):
280 |                     if nline and nline % 100000 == 0:
281 |                         self.expe.log.info(
282 |                             "loading wikidata #line: {}".format(nline))
283 |                     if dataline.strip():
284 |                         datajson = json.loads(dataline.strip())
285 |                         datalist = []
286 |                         datapos = []
287 |                         datatype = []
288 |                         datamask = []
289 |                         datasrc = []
290 |                         datatype_count = 0
291 | 
292 |                         max_n_keys = max(max_n_keys,
293 |                                          len(datajson["wikidata_details"]))
294 |                         for key in datajson["wikidata_details"]:
295 |                             all_data_for_key = \
296 |                                 datajson["wikidata_details"][key]
297 |                             max_n_data = max(max_n_data, len(all_data_for_key))
298 |                             if self.expe.config.max_num_value is not None:
299 |                                 all_data_for_key = \
300 |                                     all_data_for_key[:self.expe.config.max_num_value]
301 |                             trunc_max_n_data = max(trunc_max_n_data,
302 |                                                    len(all_data_for_key))
303 |                             for l in all_data_for_key:
304 |                                 prop_value = [vocab["<boc>"]]
305 |                                 prop_mask = [vocab["<boc>"]]
306 |                                 prop_src = ["<boc>"]
307 | 
308 |                                 temp_value = [vocab.get(w, 0) for w in
309 |                                               key.split()[:MAX_VALUE_LEN]]
310 |                                 prop_value += temp_value
311 |                                 prop_mask += [vocab["<mask>"]] * \
312 |                                     len(temp_value)
313 |                                 prop_src += key.split()[:MAX_VALUE_LEN]
314 | 
315 |                                 prop_value += [vocab["<bov>"]]
316 |                                 prop_mask += [vocab["<bov>"]]
317 |                                 prop_src += ["<bov>"]
318 | 
319 |                                 temp_value = \
320 |                                     [vocab.get(w, 0) for w in
321 |                                      l["data"].split()[:MAX_VALUE_LEN]]
322 |                                 prop_value += temp_value
323 |                                 prop_mask += [vocab["<mask>"]] * \
324 |                                     len(temp_value)
325 |                                 prop_src += l["data"].split()[:MAX_VALUE_LEN]
326 | 
327 |                                 if "qualifiers" in l:
328 |                                     max_n_qual = max(max_n_qual,
329 |                                                      len(l["qualifiers"]))
330 |                                     for qual_key in l["qualifiers"]:
331 |                                         prop_value += [vocab["<boqk>"]]
332 |                                         prop_mask += [vocab["<boqk>"]]
333 |                                         prop_src += ["<boqk>"]
334 | 
335 |                                         temp_value = \
336 |                                             [vocab.get(w, 0) for w in
337 |                                              qual_key.split()[:MAX_VALUE_LEN]]
338 |                                         prop_value += temp_value
339 |                                         prop_mask += [vocab["<mask>"]] * \
340 |                                             len(temp_value)
341 |                                         prop_src += \
342 |                                             qual_key.split()[:MAX_VALUE_LEN]
343 | 
344 |                                         prop_value += [vocab["<boqv>"]]
345 |                                         prop_mask += [vocab["<boqv>"]]
346 |                                         prop_src += ["<boqv>"]
347 | 
348 |                                         temp_value = \
349 |                                             [vocab.get(w, 0) for w in
350 |                                              l["qualifiers"][qual_key]
351 |                                              .split()[:MAX_VALUE_LEN]]
352 |                                         prop_value += temp_value
353 |                                         prop_mask += [vocab["<mask>"]] * \
354 |                                             len(temp_value)
355 |                                         prop_src += l["qualifiers"][qual_key]\
356 |                                             .split()[:MAX_VALUE_LEN]
357 | 
358 |                                 prop_value += [vocab["<eoc>"]]
359 |                                 prop_mask += [vocab["<eoc>"]]
360 |                                 prop_src += ["<eoc>"]
361 |                                 if len(datalist) + len(prop_value) > 350 \
362 |                                         or datatype_count > 300:
363 |                                     continue
364 |                                 datamask += prop_mask
365 |                                 datalist += prop_value
366 |                                 datapos += list(range(len(prop_value)))
367 |                                 datatype += [datatype_count] * len(prop_value)
368 |                                 datasrc += prop_src
369 |                                 datatype_count += 1
370 |                                 max_datatype_count = \
371 |                                     max(max_datatype_count, datatype_count)
372 |                         assert len(datalist) == len(datamask), \
373 |                             "{} != {}".format(len(datalist), len(datamask))
374 |                         assert len(datalist) == len(datasrc), \
375 |                             "{} != {}".format(len(datalist), len(datasrc))
376 | 
377 |                         if datalist:
378 |                             wikidata[datajson["wikidata_name"]] = \
379 |                                 (datalist, datapos, datatype,
380 |                                  datamask, datasrc)
381 |                         else:
382 |                             n_skip += 1
383 |         self.expe.log.info("loaded #wikidata entries: {}, "
384 |                            "found #duplicate entries: {}, "
385 |                            "truncated max #data: {}, "
386 |                            "non-truncated max #data: {}, "
387 |                            "max #qual: {}, max #keys: {}, "
388 |                            "max datatype count: {}, "
389 |                            "#skip: {}"
390 |                            .format(len(wikidata), dup, trunc_max_n_data,
391 |                                    max_n_data, max_n_qual,
392 |                                    max_n_keys, max_datatype_count - 1,
393 |                                    n_skip))
394 | 
395 |         all_train_data = []
396 |         n_train_has_wikidata = 0
397 |         n_train_has_infobox = 0
398 |         n_train_has_infobox_wikidata = 0
399 | 
400 |         max_datatype_count = 0
401 |         max_datalist_len = 0
402 |         max_st_len = 0
403 |         max_dt_len = 0
404 |         n_skip = 0
405 |         n_break = 0
406 |         for train_file in glob(train_path):
407 |             with open(train_file) as fp:
408 |                 for nline, dataline in enumerate(fp):
409 |                     if nline and nline % 100000 == 0:
410 |                         self.expe.log.info(
411 |                             "loading train file #line: {}".format(nline))
412 |                     if dataline.strip():
413 |                         datajson = json.loads(dataline.strip())
414 | 
415 |                         src_vocab = {UNK_WORD: 0}
416 |                         wikidata_datalist = []
417 |                         wikidata_datapos = []
418 |                         wikidata_datatype = []
419 |                         wikidata_datamask = []
420 |                         wikidata_datasrc = []
421 | 
422 |                         infobox_datalist = []
423 |                         infobox_datapos = []
424 |                         infobox_datatype = []
425 |                         infobox_datamask = []
426 |                         infobox_datasrc = []
427 | 
428 |                         datatype_count = -1
429 | 
430 |                         if datajson["doc_title"] in wikidata:
431 |                             # if wikidata_path is not None:
432 |                             wikidata_datalist, wikidata_datapos, \
433 |                                 wikidata_datatype, wikidata_datamask, \
434 |                                 wikidata_datasrc = \
435 |                                 wikidata[datajson["doc_title"]]
436 |                             datatype_count = wikidata_datatype[-1]
437 | 
438 |                             n_train_has_wikidata += 1
439 | 
440 |                         if datajson["doc_title"] in infobox:
441 |                             infobox_datalist, infobox_datapos, \
442 |                                 infobox_datatype, infobox_datamask, \
443 |                                 infobox_datasrc = \
444 |                                 infobox[datajson["doc_title"]]
445 | 
446 |                             if datatype_count != -1:
447 |                                 n_train_has_infobox_wikidata += 1
448 | 
449 |                             datatype_count += 1
450 |                             infobox_datatype = \
451 |                                 [idx + datatype_count
452 |                                  for idx in infobox_datatype]
453 |                             datatype_count = infobox_datatype[-1]
454 |                             n_train_has_infobox += 1
455 | 
456 |                         wikidata_datasrc_ids = []
457 |                         for w in wikidata_datasrc:
458 |                             if w not in src_vocab:
459 |                                 src_vocab[w] = len(src_vocab)
460 |                             wikidata_datasrc_ids.append(src_vocab[w])
461 | 
462 |                         infobox_datasrc_ids = []
463 |                         for w in infobox_datasrc:
464 |                             if w not in src_vocab:
465 |                                 src_vocab[w] = len(src_vocab)
466 |                             infobox_datasrc_ids.append(src_vocab[w])
467 | 
468 |                         datatype_count += 1
469 |                         doc_datalist = [vocab["<boc>"]]
470 |                         doc_masklist = [vocab["<boc>"]]
471 |                         if "<boc>" not in src_vocab:
472 |                             src_vocab["<boc>"] = len(src_vocab)
473 |                         doc_datasrc = [src_vocab["<boc>"]]
474 | 
475 |                         templist = []
476 |                         for w in self.bpe.process_line("document title")\
477 |                                 .split()[:MAX_VALUE_LEN]:
478 |                             if w not in src_vocab:
479 |                                 src_vocab[w] = len(src_vocab)
480 |                             templist.append(vocab.get(w, 0))
481 |                             doc_datasrc.append(src_vocab[w])
482 | 
483 |                         doc_datalist += templist
484 |                         doc_masklist += [vocab["<mask>"]] * len(templist)
485 | 
486 |                         doc_datalist += [vocab["<bov>"]]
487 |                         doc_masklist += [vocab["<bov>"]]
488 |                         if "<bov>" not in src_vocab:
489 |                             src_vocab["<bov>"] = len(src_vocab)
490 |                         doc_datasrc += [src_vocab["<bov>"]]
491 | 
492 |                         templist = []
493 |                         for w in datajson["doc_title_bpe"]\
494 |                                 .split()[:MAX_VALUE_LEN]:
495 |                             if w not in src_vocab:
496 |                                 src_vocab[w] = len(src_vocab)
497 |                             templist.append(vocab.get(w, 0))
498 |                             doc_datasrc.append(src_vocab[w])
499 |                         doc_datalist += templist
500 |                         doc_masklist += [vocab["<mask>"]] * len(templist)
501 | 
502 |                         doc_datalist += [vocab["<eoc>"]]
503 |                         doc_masklist += [vocab["<eoc>"]]
504 |                         if "<eoc>" not in src_vocab:
505 |                             src_vocab["<eoc>"] = len(src_vocab)
506 |                         doc_datasrc += [src_vocab["<eoc>"]]
507 |                         doc_datapos = list(range(len(doc_datalist)))
508 |                         doc_datatype = [datatype_count] * len(doc_datalist)
509 | 
510 |                         datalist = []
511 |                         masklist = []
512 |                         datapos = []
513 |                         datatype = []
514 |                         datasrc = []
515 | 
516 |                         datatype_count += 1
517 |                         for d in datajson["data"]:
518 |                             curr_datalist = [vocab["<boc>"]]
519 |                             curr_masklist = [vocab["<boc>"]]
520 |                             curr_datasrc = [src_vocab["<boc>"]]
521 | 
522 |                             templist = []
523 |                             for w in d[0].split()[:MAX_VALUE_LEN]:
524 |                                 if w not in src_vocab:
525 |                                     src_vocab[w] = len(src_vocab)
526 |                                 templist.append(vocab.get(w, 0))
527 |                                 curr_datasrc.append(src_vocab[w])
528 |                             curr_datalist += templist
529 |                             curr_masklist += [vocab["<mask>"]] * len(templist)
530 | 
531 |                             curr_datalist += [vocab["<bov>"]]
532 |                             curr_masklist += [vocab["<bov>"]]
533 |                             curr_datasrc += [src_vocab["<bov>"]]
534 | 
535 |                             templist = []
536 |                             for w in d[1].split()[:MAX_VALUE_LEN]:
537 |                                 if w not in src_vocab:
538 |                                     src_vocab[w] = len(src_vocab)
539 |                                 templist.append(vocab.get(w, 0))
540 |                                 curr_datasrc.append(src_vocab[w])
541 |                             curr_datalist += templist
542 |                             curr_masklist += [vocab["<mask>"]] * len(templist)
543 | 
544 |                             curr_datalist += [vocab["<eoc>"]]
545 |                             curr_masklist += [vocab["<eoc>"]]
546 |                             if "<eoc>" not in src_vocab:
547 |                                 src_vocab["<eoc>"] = len(src_vocab)
548 |                             curr_datasrc += [src_vocab["<eoc>"]]
549 | 
550 |                             if len(doc_datalist) + len(curr_datalist) + \
551 |                                     len(datalist) + len(wikidata_datalist) + \
552 |                                     len(infobox_datalist) > 500:
553 |                                 continue
554 |                             if datatype_count + 1 >= 499:
555 |                                 continue
556 |                             curr_datapos = list(range(len(curr_datalist)))
557 |                             curr_datatype = \
558 |                                 [datatype_count] * len(curr_datalist)
559 | 
560 |                             datatype_count += 1
561 | 
562 |                             masklist += curr_masklist
563 |                             datalist += curr_datalist
564 |                             datapos += curr_datapos
565 |                             datatype += curr_datatype
566 |                             datasrc += curr_datasrc
567 | 
568 |                         section_datalist, section_datapos, \
569 |                             section_datatype, section_masklist, \
570 |                             section_datasrc = [], [], [], [], []
571 |                         if datajson["sec_title"]:
572 |                             section_masklist = [vocab["<boc>"]]
573 |                             section_datalist = [vocab["<boc>"]]
574 |                             section_datasrc = [src_vocab["<boc>"]]
575 | 
576 |                             templist = []
577 |                             for w in self.bpe.process_line("section title")\
578 |                                     .split()[:MAX_VALUE_LEN]:
579 |                                 if w not in src_vocab:
580 |                                     src_vocab[w] = len(src_vocab)
581 |                                 templist.append(vocab.get(w, 0))
582 |                                 section_datasrc.append(src_vocab[w])
583 |                             section_datalist += templist
584 |                             section_masklist += [vocab["<mask>"]] * \
585 |                                 len(templist)
586 | 
587 |                             section_masklist += [vocab["<bov>"]]
588 |                             section_datalist += [vocab["<bov>"]]
589 |                             section_datasrc += [src_vocab["<bov>"]]
590 | 
591 |                             templist = []
592 |                             for w in datajson["sec_title"][0]\
593 |                                     .split()[:MAX_VALUE_LEN]:
594 |                                 if w not in src_vocab:
595 |                                     src_vocab[w] = len(src_vocab)
596 |                                 templist.append(vocab.get(w, 0))
597 |                                 section_datasrc.append(src_vocab[w])
598 | 
599 |                             if "<subsec>" not in src_vocab:
600 |                                 src_vocab["<subsec>"] = len(src_vocab)
601 |                             section_datalist += templist
602 |                             section_masklist += [vocab["<mask>"]] * \
603 |                                 len(templist)
604 |                             for s in datajson["sec_title"][1:]:
605 | 
606 |                                 section_datalist += [vocab["<subsec>"]]
607 |                                 section_masklist += [vocab["<subsec>"]]
608 |                                 section_datasrc += [src_vocab["<subsec>"]]
[src_vocab[""]] 609 | 610 | templist = [] 611 | for w in s.split()[:MAX_VALUE_LEN]: 612 | if w not in src_vocab: 613 | src_vocab[w] = len(src_vocab) 614 | templist.append(vocab.get(w, 0)) 615 | section_datasrc.append(src_vocab[w]) 616 | section_datalist += templist 617 | section_masklist += [vocab[""]] * \ 618 | len(templist) 619 | 620 | section_datalist += [vocab[""]] 621 | section_masklist += [vocab[""]] 622 | section_datasrc += [src_vocab[""]] 623 | section_datapos = \ 624 | list(range(len(section_datalist))) 625 | section_datatype = \ 626 | [datatype_count] * len(section_datalist) 627 | 628 | if len(section_datalist) + len(doc_datalist) + \ 629 | len(datalist) + len(wikidata_datalist) + \ 630 | len(infobox_datalist) > 1000 \ 631 | or datatype_count >= 500: 632 | n_skip += 1 633 | continue 634 | max_datatype_count = \ 635 | max(max_datatype_count, datatype_count) 636 | max_datalist_len = \ 637 | max(max_datalist_len, 638 | len(datalist) + len(wikidata_datalist) + len(doc_datalist) 639 | ) 640 | max_st_len = max(max_st_len, len(section_datalist)) 641 | max_dt_len = max(max_dt_len, len(doc_datalist)) 642 | assert len(doc_datalist) == \ 643 | len(doc_masklist), "{} != {}".format( 644 | len(doc_datalist), len(doc_masklist)) 645 | assert len(doc_datalist) == len(doc_datasrc), \ 646 | "{} != {}".format( 647 | len(doc_datalist), len(doc_datasrc)) 648 | assert len(section_datalist) == len(section_masklist),\ 649 | "{} != {}".format( 650 | len(section_datalist), len(section_masklist)) 651 | assert len(section_datalist) == len(section_datasrc), \ 652 | "{} != {}".format( 653 | len(section_datalist), len(section_datasrc)) 654 | assert len(datalist) == len(masklist), \ 655 | "{} != {}".format(len(datalist), len(masklist)) 656 | assert len(datalist) == len(datasrc), \ 657 | "{} != {}".format(len(datalist), len(datasrc)) 658 | all_train_data.append( 659 | {"idx": len(all_train_data), 660 | "doc_title": (doc_datalist, doc_datapos, 661 | doc_datatype, doc_masklist, 662 | doc_datasrc), 663 | "sec_title": (section_datalist, 664 | section_datapos, section_datatype, 665 | section_masklist, section_datasrc), 666 | "wikidata": (wikidata_datalist, wikidata_datapos, 667 | wikidata_datatype, wikidata_datamask, 668 | wikidata_datasrc_ids), 669 | "hyperlink_data": (datalist, datapos, 670 | datatype, masklist, datasrc), 671 | "infobox": (infobox_datalist, infobox_datapos, 672 | infobox_datatype, infobox_datamask, 673 | infobox_datasrc_ids), 674 | "src_vocab": src_vocab, 675 | "inv_src_vocab": {i: str(w) for w, i 676 | in src_vocab.items()}, 677 | "text_src_vocab": [src_vocab.get(w, 0) for w 678 | in datajson["text"].split()[:self.expe.config.max_train_txt_len]], 679 | "text": [vocab.get(w, 0) for w 680 | in datajson["text"].split()[:self.expe.config.max_train_txt_len]] 681 | } 682 | ) 683 | 684 | self.expe.log.info( 685 | "loaded #train: {}, #has wikidata: {} ({:.2f}%), " 686 | "#has infobox: {} ({:.2f}%), #has infobox&wikidata: {} ({:.2f}%)" 687 | .format(len(all_train_data), n_train_has_wikidata, 688 | n_train_has_wikidata / len(all_train_data) * 100 689 | if len(all_train_data) else 0, 690 | n_train_has_infobox, 691 | n_train_has_infobox / len(all_train_data) * 100 692 | if len(all_train_data) else 0, 693 | n_train_has_infobox_wikidata, 694 | n_train_has_infobox_wikidata / len(all_train_data) * 100 695 | if len(all_train_data) else 0) 696 | ) 697 | self.expe.log.info( 698 | "#skip: {}, max datatype count: {}, " 699 | "max datalist len: {}, " 700 | "max sec title len: {}, " 701 | "max doc title len: {}" 702 | 
703 |                     max_datalist_len, max_st_len, max_dt_len)
704 |         )
705 | 
706 |         all_dev_data = []
707 |         n_dev_has_wikidata = 0
708 |         n_dev_has_infobox = 0
709 |         n_dev_has_infobox_wikidata = 0
710 | 
711 |         max_datatype_count = 0
712 |         max_datalist_len = 0
713 |         max_st_len = 0
714 |         max_dt_len = 0
715 |         n_skip = 0
716 |         for train_file in glob(dev_path):
717 |             with open(train_file) as fp: #pylint: disable=C0103
718 |                 for nline, dataline in enumerate(fp):
719 |                     if nline and nline % 100000 == 0:
720 |                         self.expe.log.info("loading dev file #line: {}".format(nline))
721 |                     if dataline.strip():
722 |                         datajson = json.loads(dataline.strip())
723 | 
724 |                         src_vocab = {UNK_WORD: 0}
725 |                         wikidata_datalist = []
726 |                         wikidata_datapos = []
727 |                         wikidata_datatype = []
728 |                         wikidata_datamask = []
729 |                         wikidata_datasrc = []
730 | 
731 |                         infobox_datalist = []
732 |                         infobox_datapos = []
733 |                         infobox_datatype = []
734 |                         infobox_datamask = []
735 |                         infobox_datasrc = []
736 | 
737 |                         datatype_count = -1
738 | 
739 |                         if datajson["doc_title"] in wikidata:
740 |                             # if wikidata_path is not None:
741 |                             wikidata_datalist, wikidata_datapos, \
742 |                                 wikidata_datatype, wikidata_datamask, \
743 |                                 wikidata_datasrc = \
744 |                                 wikidata[datajson["doc_title"]]
745 |                             datatype_count = wikidata_datatype[-1]
746 | 
747 |                             n_dev_has_wikidata += 1
748 | 
749 |                         if datajson["doc_title"] in infobox:
750 |                             infobox_datalist, infobox_datapos, \
751 |                                 infobox_datatype, infobox_datamask, \
752 |                                 infobox_datasrc = \
753 |                                 infobox[datajson["doc_title"]]
754 | 
755 |                             if datatype_count != -1:
756 |                                 n_dev_has_infobox_wikidata += 1
757 | 
758 |                             datatype_count += 1
759 |                             infobox_datatype = \
760 |                                 [idx + datatype_count for idx
761 |                                  in infobox_datatype]
762 |                             datatype_count = infobox_datatype[-1]
763 |                             n_dev_has_infobox += 1
764 | 
765 |                         wikidata_datasrc_ids = []
766 |                         for w in wikidata_datasrc:
767 |                             if w not in src_vocab:
768 |                                 src_vocab[w] = len(src_vocab)
769 |                             wikidata_datasrc_ids.append(src_vocab[w])
770 | 
771 |                         infobox_datasrc_ids = []
772 |                         for w in infobox_datasrc:
773 |                             if w not in src_vocab:
774 |                                 src_vocab[w] = len(src_vocab)
775 |                             infobox_datasrc_ids.append(src_vocab[w])
776 | 
777 |                         datatype_count += 1
778 |                         doc_datalist = [vocab["<boc>"]]
779 |                         doc_masklist = [vocab["<boc>"]]
780 |                         if "<boc>" not in src_vocab:
781 |                             src_vocab["<boc>"] = len(src_vocab)
782 |                         doc_datasrc = [src_vocab["<boc>"]]
783 | 
784 |                         templist = []
785 |                         for w in self.bpe.process_line("document title")\
786 |                                 .split()[:MAX_VALUE_LEN]:
787 |                             if w not in src_vocab:
788 |                                 src_vocab[w] = len(src_vocab)
789 |                             templist.append(vocab.get(w, 0))
790 |                             doc_datasrc.append(src_vocab[w])
791 | 
792 |                         doc_datalist += templist
793 |                         doc_masklist += [vocab["<mask>"]] * len(templist)
794 | 
795 |                         doc_datalist += [vocab["<bov>"]]
796 |                         doc_masklist += [vocab["<bov>"]]
797 |                         if "<bov>" not in src_vocab:
798 |                             src_vocab["<bov>"] = len(src_vocab)
799 |                         doc_datasrc += [src_vocab["<bov>"]]
800 | 
801 |                         templist = []
802 |                         for w in datajson["doc_title_bpe"]\
803 |                                 .split()[:MAX_VALUE_LEN]:
804 |                             if w not in src_vocab:
805 |                                 src_vocab[w] = len(src_vocab)
806 |                             templist.append(vocab.get(w, 0))
807 |                             doc_datasrc.append(src_vocab[w])
808 |                         doc_datalist += templist
809 |                         doc_masklist += [vocab["<mask>"]] * len(templist)
810 | 
811 |                         doc_datalist += [vocab["<eoc>"]]
812 |                         doc_masklist += [vocab["<eoc>"]]
813 |                         if "<eoc>" not in src_vocab:
814 |                             src_vocab["<eoc>"] = len(src_vocab)
815 |                         doc_datasrc += [src_vocab["<eoc>"]]
816 |                         doc_datapos = list(range(len(doc_datalist)))
817 |                         doc_datatype = [datatype_count] * len(doc_datalist)
818 | 
819 |                         datalist = []
820 |                         masklist = []
821 |                         datapos = []
822 |                         datatype = []
823 |                         datasrc = []
824 | 
825 |                         datatype_count += 1
826 |                         for d in datajson["data"]:
827 |                             curr_datalist = [vocab["<boc>"]]
828 |                             curr_masklist = [vocab["<boc>"]]
829 |                             curr_datasrc = [src_vocab["<boc>"]]
830 | 
831 |                             templist = []
832 |                             for w in d[0].split()[:MAX_VALUE_LEN]:
833 |                                 if w not in src_vocab:
834 |                                     src_vocab[w] = len(src_vocab)
835 |                                 templist.append(vocab.get(w, 0))
836 |                                 curr_datasrc.append(src_vocab[w])
837 |                             curr_datalist += templist
838 |                             curr_masklist += [vocab["<mask>"]] * len(templist)
839 | 
840 |                             curr_datalist += [vocab["<bov>"]]
841 |                             curr_masklist += [vocab["<bov>"]]
842 |                             curr_datasrc += [src_vocab["<bov>"]]
843 | 
844 |                             templist = []
845 |                             for w in d[1].split()[:MAX_VALUE_LEN]:
846 |                                 if w not in src_vocab:
847 |                                     src_vocab[w] = len(src_vocab)
848 |                                 templist.append(vocab.get(w, 0))
849 |                                 curr_datasrc.append(src_vocab[w])
850 |                             curr_datalist += templist
851 |                             curr_masklist += [vocab["<mask>"]] * len(templist)
852 | 
853 |                             curr_datalist += [vocab["<eoc>"]]
854 |                             curr_masklist += [vocab["<eoc>"]]
855 |                             if "<eoc>" not in src_vocab:
856 |                                 src_vocab["<eoc>"] = len(src_vocab)
857 |                             curr_datasrc += [src_vocab["<eoc>"]]
858 | 
859 |                             if len(doc_datalist) + len(curr_datalist) + \
860 |                                     len(datalist) + len(wikidata_datalist) + \
861 |                                     len(infobox_datalist) > 500:
862 |                                 continue
863 |                             if datatype_count + 1 >= 499:
864 |                                 continue
865 |                             curr_datapos = list(range(len(curr_datalist)))
866 |                             curr_datatype = \
867 |                                 [datatype_count] * len(curr_datalist)
868 | 
869 |                             datatype_count += 1
870 | 
871 |                             masklist += curr_masklist
872 |                             datalist += curr_datalist
873 |                             datapos += curr_datapos
874 |                             datatype += curr_datatype
875 |                             datasrc += curr_datasrc
876 | 
877 |                         section_datalist, section_datapos, section_datatype, \
878 |                             section_masklist, section_datasrc = \
879 |                             [], [], [], [], []
880 |                         if datajson["sec_title"]:
881 |                             section_masklist = [vocab["<boc>"]]
882 |                             section_datalist = [vocab["<boc>"]]
883 |                             section_datasrc = [src_vocab["<boc>"]]
884 | 
885 |                             templist = []
886 |                             for w in self.bpe.process_line("section title")\
887 |                                     .split()[:MAX_VALUE_LEN]:
888 |                                 if w not in src_vocab:
889 |                                     src_vocab[w] = len(src_vocab)
890 |                                 templist.append(vocab.get(w, 0))
891 |                                 section_datasrc.append(src_vocab[w])
892 |                             section_datalist += templist
893 |                             section_masklist += \
894 |                                 [vocab["<mask>"]] * len(templist)
895 | 
896 |                             section_masklist += [vocab["<bov>"]]
897 |                             section_datalist += [vocab["<bov>"]]
898 |                             section_datasrc += [src_vocab["<bov>"]]
899 | 
900 |                             templist = []
901 |                             for w in datajson["sec_title"][0]\
902 |                                     .split()[:MAX_VALUE_LEN]:
903 |                                 if w not in src_vocab:
904 |                                     src_vocab[w] = len(src_vocab)
905 |                                 templist.append(vocab.get(w, 0))
906 |                                 section_datasrc.append(src_vocab[w])
907 | 
908 |                             if "<subsec>" not in src_vocab:
909 |                                 src_vocab["<subsec>"] = len(src_vocab)
910 |                             section_datalist += templist
911 |                             section_masklist += \
912 |                                 [vocab["<mask>"]] * len(templist)
913 |                             for s in datajson["sec_title"][1:]:
914 | 
915 |                                 section_datalist += [vocab["<subsec>"]]
916 |                                 section_masklist += [vocab["<subsec>"]]
917 |                                 section_datasrc += [src_vocab["<subsec>"]]
918 | 
919 |                                 templist = []
920 |                                 for w in s.split()[:MAX_VALUE_LEN]:
921 |                                     if w not in src_vocab:
922 |                                         src_vocab[w] = len(src_vocab)
923 |                                     templist.append(vocab.get(w, 0))
924 |                                     section_datasrc.append(src_vocab[w])
925 |                                 section_datalist += templist
926 |                                 section_masklist += \
927 |                                     [vocab["<mask>"]] * len(templist)
928 | 
929 |                             section_datalist += [vocab["<eoc>"]]
930 |                             section_masklist += [vocab["<eoc>"]]
931 |                             section_datasrc += [src_vocab["<eoc>"]]
932 |                             section_datapos = \
933 |                                 list(range(len(section_datalist)))
934 |                             section_datatype = \
935 |                                 [datatype_count] *
len(section_datalist) 936 | 937 | max_datatype_count = \ 938 | max(max_datatype_count, datatype_count) 939 | max_datalist_len = \ 940 | max(max_datalist_len, 941 | len(datalist) + len(wikidata_datalist) + len(doc_datalist) 942 | ) 943 | max_st_len = max(max_st_len, len(section_datalist)) 944 | max_dt_len = max(max_dt_len, len(doc_datalist)) 945 | assert len(doc_datalist) == len(doc_masklist), \ 946 | "{} != {}".format( 947 | len(doc_datalist), len(doc_masklist)) 948 | assert len(doc_datalist) == len(doc_datasrc), \ 949 | "{} != {}".format( 950 | len(doc_datalist), len(doc_datasrc)) 951 | assert len(section_datalist) == len(section_masklist),\ 952 | "{} != {}".format( 953 | len(section_datalist), len(section_masklist)) 954 | assert len(section_datalist) == len(section_datasrc), \ 955 | "{} != {}".format( 956 | len(section_datalist), len(section_datasrc)) 957 | assert len(datalist) == len(masklist), \ 958 | "{} != {}".format(len(datalist), len(masklist)) 959 | assert len(datalist) == len(datasrc), \ 960 | "{} != {}".format(len(datalist), len(datasrc)) 961 | all_dev_data.append( 962 | {"idx": len(all_dev_data), 963 | "doc_title": (doc_datalist, doc_datapos, 964 | doc_datatype, doc_masklist, 965 | doc_datasrc), 966 | "sec_title": (section_datalist, section_datapos, 967 | section_datatype, section_masklist, 968 | section_datasrc), 969 | "wikidata": (wikidata_datalist, wikidata_datapos, 970 | wikidata_datatype, wikidata_datamask, 971 | wikidata_datasrc_ids), 972 | "hyperlink_data": (datalist, datapos, datatype, 973 | masklist, datasrc), 974 | "infobox": (infobox_datalist, infobox_datapos, 975 | infobox_datatype, infobox_datamask, 976 | infobox_datasrc_ids), 977 | "src_vocab": src_vocab, 978 | "inv_src_vocab": {i: str(w) for w, i 979 | in src_vocab.items()}, 980 | "text_src_vocab": [src_vocab.get(w, 0) for w 981 | in datajson["text"].split()[:self.expe.config.max_train_txt_len]], 982 | "text": [vocab.get(w, 0) for w 983 | in datajson["text"].split()[:self.expe.config.max_train_txt_len]], 984 | "tok_text": " ".join( 985 | datajson.get("tokenized_text", "")), 986 | "untok_text": datajson["text"].replace("@@ ", "") 987 | } 988 | ) 989 | 990 | self.expe.log.info( 991 | "loaded #dev: {}, #has wikidata: {} ({:.2f}%), " 992 | "#has infobox: {} ({:.2f}%), #has infobox&wikidata: {} ({:.2f}%)" 993 | .format(len(all_dev_data), n_dev_has_wikidata, 994 | n_dev_has_wikidata / len(all_dev_data) * 100 995 | if len(all_dev_data) else 0, 996 | n_dev_has_infobox, 997 | n_dev_has_infobox / len(all_dev_data) * 100 998 | if len(all_dev_data) else 0, 999 | n_dev_has_infobox_wikidata, 1000 | n_dev_has_infobox_wikidata / len(all_dev_data) * 100 1001 | if len(all_dev_data) else 0) 1002 | ) 1003 | self.expe.log.info( 1004 | "#skip: {}, max datatype count: {}, " 1005 | "max datalist len: {}, " 1006 | "max sec title len: {}, " 1007 | "max doc title len: {}" 1008 | .format(n_skip, max_datatype_count, 1009 | max_datalist_len, max_st_len, max_dt_len) 1010 | ) 1011 | 1012 | all_test_data = [] 1013 | self.expe.log.info("loaded #test: {}".format(len(all_test_data))) 1014 | del wikidata 1015 | del infobox 1016 | return all_train_data, all_dev_data, all_test_data 1017 | 1018 | 1019 | class Minibatcher: 1020 | @auto_init_args 1021 | def __init__(self, data, save_dir, log, verbose, vocab_size, 1022 | filename, is_eval, batch_size, vocab, 1023 | return_wikidata, return_hyperlink, 1024 | input_wikidata, input_hyperlink, *args, **kwargs): 1025 | self._reset() 1026 | self.load(filename) 1027 | 1028 | def __len__(self): 1029 | return 
len(self.idx_pool) - self.init_pointer 1030 | 1031 | def save(self, filename="minibatcher.ckpt"): 1032 | path = os.path.join(self.save_dir, filename) 1033 | pickle.dump([self.pointer, self.idx_pool], open(path, "wb")) 1034 | if self.verbose: 1035 | self.log.info("minibatcher saved to: {}".format(path)) 1036 | 1037 | def load(self, filename="minibatcher.ckpt"): 1038 | if self.save_dir is not None: 1039 | path = os.path.join(self.save_dir, filename) 1040 | else: 1041 | path = None 1042 | if self.save_dir is not None and os.path.exists(path): 1043 | self.init_pointer, self.idx_pool = pickle.load(open(path, "rb")) 1044 | self.pointer = self.init_pointer 1045 | if self.verbose: 1046 | self.log.info("loaded minibatcher from {}, init pointer: {}" 1047 | .format(path, self.init_pointer)) 1048 | else: 1049 | if self.verbose: 1050 | self.log.info("no minibatcher found at {}".format(path)) 1051 | 1052 | def _reset(self): 1053 | self.pointer = 0 1054 | self.init_pointer = 0 1055 | idx_list = np.arange(len(self.data)) 1056 | if not self.is_eval: 1057 | np.random.shuffle(idx_list) 1058 | self.idx_pool = [idx_list[i: i + self.batch_size] 1059 | for i in range(0, len(self.data), self.batch_size)] 1060 | 1061 | def _pad(self, data): 1062 | max_text_len = max([len(d["text"]) for d in data]) 1063 | max_src_vocab_size = max([len(d["src_vocab"]) for d in data]) 1064 | if self.input_wikidata and self.input_hyperlink: 1065 | max_input_data_len = \ 1066 | max([len(d["sec_title"][0]) + 1067 | len(d["doc_title"][0]) + 1068 | len(d["hyperlink_data"][0]) + 1069 | len(d["wikidata"][0]) + 1070 | len(d["infobox"][0]) 1071 | for d in data]) 1072 | elif self.input_wikidata: 1073 | max_input_data_len = \ 1074 | max([len(d["sec_title"][0]) + 1075 | len(d["doc_title"][0]) + 1076 | len(d["wikidata"][0]) + 1077 | len(d["infobox"][0]) 1078 | for d in data]) 1079 | elif self.input_hyperlink: 1080 | max_input_data_len = \ 1081 | max([len(d["sec_title"][0]) + 1082 | len(d["doc_title"][0]) + 1083 | len(d["hyperlink_data"][0]) 1084 | for d in data]) 1085 | 1086 | if self.return_wikidata and self.return_hyperlink: 1087 | max_output_data_len = \ 1088 | max([len(d["sec_title"][0]) + 1089 | len(d["doc_title"][0]) + 1090 | len(d["hyperlink_data"][0]) + 1091 | len(d["wikidata"][0]) + 1092 | len(d["infobox"][0]) 1093 | for d in data]) 1094 | elif self.return_wikidata: 1095 | max_output_data_len = \ 1096 | max([len(d["sec_title"][0]) + 1097 | len(d["doc_title"][0]) + 1098 | len(d["wikidata"][0]) + 1099 | len(d["infobox"][0]) 1100 | for d in data]) 1101 | elif self.return_hyperlink: 1102 | max_output_data_len = \ 1103 | max([len(d["sec_title"][0]) + 1104 | len(d["doc_title"][0]) + 1105 | len(d["hyperlink_data"][0]) 1106 | for d in data]) 1107 | 1108 | input_data = \ 1109 | np.zeros((len(data), max_input_data_len)).astype("float32") 1110 | input_data_mask = \ 1111 | np.zeros((len(data), max_input_data_len)).astype("float32") 1112 | input_data_pos = \ 1113 | np.zeros((len(data), max_input_data_len)).astype("float32") 1114 | input_data_type = \ 1115 | np.zeros((len(data), max_input_data_len)).astype("float32") 1116 | input_if_hyp = \ 1117 | np.zeros((len(data), max_input_data_len)).astype("float32") 1118 | input_data_src_vocab = \ 1119 | np.zeros((len(data), max_input_data_len)).astype("float32") 1120 | input_data_src_tgt_vocab_map = \ 1121 | np.full((len(data), max_src_vocab_size), -1).astype("float32") 1122 | 1123 | tgt_inp_data = \ 1124 | np.zeros((len(data), max_output_data_len)).astype("float32") 1125 | tgt_inp_data_mask = \ 1126 | 
np.zeros((len(data), max_output_data_len)).astype("float32") 1127 | tgt_inp_data_pos = \ 1128 | np.zeros((len(data), max_output_data_len)).astype("float32") 1129 | tgt_inp_data_type = \ 1130 | np.zeros((len(data), max_output_data_len)).astype("float32") 1131 | tgt_inp_data_if_hyp = \ 1132 | np.zeros((len(data), max_output_data_len)).astype("float32") 1133 | 1134 | tgt_out_data = \ 1135 | np.zeros((len(data), max_output_data_len)).astype("float32") 1136 | tgt_out_data_mask = \ 1137 | np.zeros((len(data), max_output_data_len)).astype("float32") 1138 | 1139 | tgt_input = \ 1140 | np.zeros((len(data), max_text_len + 1)).astype("float32") 1141 | 1142 | tgt_label = \ 1143 | np.zeros((len(data), max_text_len + 1)).astype("float32") 1144 | tgt_label_src_vocab = \ 1145 | np.zeros((len(data), max_text_len + 1)).astype("float32") 1146 | tgt_mask = \ 1147 | np.zeros((len(data), max_text_len + 1)).astype("float32") 1148 | 1149 | def get_hyp_only(d, idx): 1150 | return d["hyperlink_data"][idx] + \ 1151 | d["doc_title"][idx] + \ 1152 | d["sec_title"][idx] 1153 | 1154 | def get_wikidata_only(d, idx): 1155 | return d["doc_title"][idx] + \ 1156 | d["sec_title"][idx] + \ 1157 | d["wikidata"][idx] + \ 1158 | d["infobox"][idx] 1159 | 1160 | def get_all(d, idx): 1161 | return d["hyperlink_data"][idx] + \ 1162 | d["doc_title"][idx] + \ 1163 | d["sec_title"][idx] + \ 1164 | d["wikidata"][idx] + \ 1165 | d["infobox"][idx] 1166 | 1167 | if self.input_wikidata and self.input_hyperlink: 1168 | get_inp_d = get_all 1169 | elif self.input_wikidata: 1170 | get_inp_d = get_wikidata_only 1171 | elif self.input_hyperlink: 1172 | get_inp_d = get_hyp_only 1173 | else: 1174 | raise ValueError( 1175 | "input_wikidata and input_hyperlink cannot both be False!") 1176 | 1177 | if self.return_wikidata and self.return_hyperlink: 1178 | get_ret_d = get_all 1179 | elif self.return_wikidata: 1180 | get_ret_d = get_wikidata_only 1181 | elif self.return_hyperlink: 1182 | get_ret_d = get_hyp_only 1183 | else: 1184 | raise ValueError( 1185 | "return_wikidata and return_hyperlink cannot both be False!") 1186 | 1187 | for i, d in enumerate(data): 1188 | input_data[i, :len(get_inp_d(d, 0))] = \ 1189 | np.asarray(list(get_inp_d(d, 0))).astype("float32") 1190 | input_data_mask[i, :len(get_inp_d(d, 0))] = 1. 
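# The remaining per-example fields (positions, types, and source-side vocab ids) are copied into the padded arrays the same way below; slots past each example's length stay zero and are screened out by the mask.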
1191 | input_data_pos[i, :len(get_inp_d(d, 1))] = \ 1192 | np.asarray(list(get_inp_d(d, 1))).astype("float32") 1193 | input_data_type[i, :len(get_inp_d(d, 2))] = \ 1194 | np.asarray(list(get_inp_d(d, 2))).astype("float32") 1195 | input_data_src_vocab[i, :len(get_inp_d(d, 4))] = \ 1196 | np.asarray(list(get_inp_d(d, 4))).astype("float32") 1197 | if self.input_hyperlink: 1198 | input_if_hyp[i, :len(d["hyperlink_data"][2])] = 1.0 1199 | 1200 | for word, ids in d["src_vocab"].items(): 1201 | input_data_src_tgt_vocab_map[i, ids] = self.vocab.get(word, 0) 1202 | 1203 | assert sum(input_data_src_tgt_vocab_map[i][len(d["src_vocab"]):]) == \ 1204 | -1 * len(input_data_src_tgt_vocab_map[i][len(d["src_vocab"]):]),\ 1205 | "{} != {}".format( 1206 | sum(input_data_src_tgt_vocab_map[i][len(d["src_vocab"]):]), 1207 | -1 * len(input_data_src_tgt_vocab_map[i][len(d["src_vocab"]):])) 1208 | assert sum((input_data_src_tgt_vocab_map[i][:len(d["src_vocab"])] != -1)) == \ 1209 | len(input_data_src_tgt_vocab_map[i][:len(d["src_vocab"])]), \ 1210 | "{} != {}".format( 1211 | sum((input_data_src_tgt_vocab_map[i][:len(d["src_vocab"])] != -1)), 1212 | len(input_data_src_tgt_vocab_map[i][:len(d["src_vocab"])])) 1213 | 1214 | tgt_mask_data = np.asarray(list(get_ret_d(d, 3))).astype("float32") 1215 | tgt_orig_data = np.asarray(list(get_ret_d(d, 0))).astype("float32") 1216 | mask_idx = tgt_mask_data == MASK_IDX 1217 | tgt_recon_data = np.where(mask_idx, tgt_orig_data, UNK_IDX) 1218 | tgt_out_data[i, :len(get_ret_d(d, 3))] = tgt_recon_data 1219 | tgt_out_data_mask[i, :len(get_ret_d(d, 3))] = mask_idx 1220 | 1221 | tgt_inp_data[i, :len(get_ret_d(d, 3))] = tgt_mask_data 1222 | tgt_inp_data_mask[i, :len(get_ret_d(d, 0))] = 1. 1223 | tgt_inp_data_pos[i, :len(get_ret_d(d, 1))] = \ 1224 | np.asarray(list(get_ret_d(d, 1))).astype("float32") 1225 | tgt_inp_data_type[i, :len(get_ret_d(d, 2))] = \ 1226 | np.asarray(list(get_ret_d(d, 2))).astype("float32") 1227 | 1228 | if self.return_hyperlink: 1229 | tgt_inp_data_if_hyp[i, :len(d["hyperlink_data"][2])] = 1.0 1230 | 1231 | tgt_input[i, :len(d["text"]) + 1] = \ 1232 | np.asarray([BOS_IDX] + list(d["text"])).astype("float32") 1233 | 1234 | tgt_label[i, :len(d["text"]) + 1] = \ 1235 | np.asarray(list(d["text"]) + [EOS_IDX]).astype("float32") 1236 | tgt_label_src_vocab[i, :len(d["text_src_vocab"]) + 1] = \ 1237 | np.asarray(list(d["text_src_vocab"]) + [0]).astype("float32") 1238 | tgt_mask[i, :len(d["text"]) + 1] = 1. 
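# Everything below is returned as one flat list; the final element keeps the original example indices so generated outputs can later be aligned back to their data entries.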
1239 | 1240 | return [input_data, input_data_mask, input_data_pos, input_data_type, 1241 | input_if_hyp, input_data_src_vocab, 1242 | input_data_src_tgt_vocab_map, 1243 | tgt_inp_data, tgt_inp_data_mask, tgt_inp_data_pos, 1244 | tgt_inp_data_type, tgt_inp_data_if_hyp, 1245 | tgt_out_data, tgt_out_data_mask, 1246 | tgt_input, tgt_label, tgt_mask, tgt_label_src_vocab, 1247 | [d["idx"] for d in data]] 1248 | 1249 | def __iter__(self): 1250 | return self 1251 | 1252 | def __next__(self): 1253 | if self.pointer == len(self.idx_pool): 1254 | self._reset() 1255 | raise StopIteration() 1256 | 1257 | idx = self.idx_pool[self.pointer] 1258 | data = self.data[idx] 1259 | self.pointer += 1 1260 | return self._pad(data) 1261 | -------------------------------------------------------------------------------- /decoders.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | from config import UNK_IDX, UNK_WORD, BOS_IDX, EOS_IDX, \ 4 | BOC_IDX, BOV_IDX, BOQK_IDX, BOQV_IDX, SUBSEC_IDX, EOC_IDX 5 | from transformer_xlm import CacheTransformer 6 | 7 | 8 | class decoder_base(nn.Module): 9 | def __init__(self, vocab_size, embed_dim, dropout): 10 | super(decoder_base, self).__init__() 11 | self.dropout = nn.Dropout(dropout) 12 | self.embed = nn.Embedding(vocab_size, embed_dim) 13 | 14 | 15 | class ctransformer(CacheTransformer): 16 | def __init__(self, vocab_size, encoder_size, 17 | decoder_size, embed_dim, dropout, nlayer, 18 | nhead, act_fn, max_len, share_embedding, 19 | use_copy, use_entmax, *args, **kwargs): 20 | super(ctransformer, self).__init__( 21 | n_words=vocab_size, 22 | bos_index=BOS_IDX, 23 | eos_index=EOS_IDX, 24 | pad_index=UNK_IDX, 25 | emb_dim=embed_dim, 26 | n_heads=nhead, 27 | n_layers=nlayer, 28 | max_len=max_len, 29 | share_embedding=share_embedding, 30 | use_entmax=use_entmax, 31 | dropout=dropout, 32 | attention_dropout=dropout, 33 | use_copy=use_copy) 34 | 35 | def forward(self, encoder_outputs, encoder_mask, 36 | tgts, src_map, *args, **kwargs): 37 | logits = self.fwd(x=tgts, src_enc=encoder_outputs, 38 | src_mask=encoder_mask, src_map=src_map) 39 | return logits 40 | 41 | def generate(self, encoder_outputs, encoder_mask, max_len, 42 | min_len, top_k, top_p, src_map, src_tgt_vocab_map, 43 | *args, **kwargs): 44 | return self._generate( 45 | src_enc=encoder_outputs, src_mask=encoder_mask, 46 | max_len=max_len, min_len=min_len, 47 | top_p=top_p, src_map=src_map, src_tgt_vocab_map=src_tgt_vocab_map) 48 | 49 | def generate_beam(self, encoder_output, encoder_mask, 50 | beam_size, length_penalty=0.0, 51 | early_stopping=False, trigram_blocking=False, 52 | bos_index=1, eos_index=2, pad_index=1, 53 | min_len=0, max_len=100, return_all=False, 54 | src_map=None, src_tgt_vocab_map=None): 55 | return self._generate_beam( 56 | src_enc=encoder_output, src_mask=encoder_mask, 57 | beam_size=beam_size, length_penalty=length_penalty, 58 | early_stopping=early_stopping, min_len=min_len, 59 | max_len=max_len, trigram_blocking=trigram_blocking, 60 | return_all=return_all, src_map=src_map, 61 | src_tgt_vocab_map=src_tgt_vocab_map) 62 | 63 | 64 | class ie_mask_transformer(decoder_base): 65 | def __init__(self, vocab_size, encoder_size, 66 | decoder_size, embed_dim, dropout, nlayer, 67 | nhead, act_fn, *args, **kwargs): 68 | super(ie_mask_transformer, self).__init__( 69 | vocab_size, embed_dim, dropout) 70 | 71 | # Define layers 72 | assert encoder_size == decoder_size, \ 73 | "encoder size: {} != decoder size: {}"\ 74 | .format(encoder_size, 
decoder_size) 75 | self.pos_embed = nn.Embedding(200, embed_dim) 76 | self.type_embed = nn.Embedding(500, embed_dim) 77 | self.if_hyp_embed = nn.Embedding(2, embed_dim) 78 | self.layers = nn.ModuleList( 79 | [nn.TransformerDecoderLayer( 80 | d_model=decoder_size, 81 | nhead=nhead, 82 | dim_feedforward=decoder_size * 4, 83 | activation=act_fn) for n in range(nlayer)]) 84 | self.hid2vocab = nn.Linear(decoder_size, vocab_size) 85 | 86 | def forward(self, encoder_outputs, encoder_mask, inp, inp_mask, inp_pos, 87 | inp_type, inp_if_hyp, *args, **kwargs): 88 | 89 | tgt_vecs = self.embed(inp.long()) + \ 90 | self.pos_embed(inp_pos.long()) + \ 91 | self.type_embed(inp_type.long()) + \ 92 | self.if_hyp_embed(inp_if_hyp.long()) 93 | output = self.dropout(tgt_vecs.transpose(0, 1)) 94 | encoder_outputs = encoder_outputs.transpose(0, 1) 95 | encoder_mask = (1 - encoder_mask).bool() 96 | inp_mask = (1 - inp_mask).bool() 97 | for layer in self.layers: 98 | output = layer( 99 | tgt=output, 100 | memory=encoder_outputs, 101 | tgt_key_padding_mask=inp_mask, 102 | memory_key_padding_mask=encoder_mask) 103 | 104 | logits = self.hid2vocab(output) 105 | return logits.transpose(0, 1) 106 | -------------------------------------------------------------------------------- /decorators.py: -------------------------------------------------------------------------------- 1 | import inspect 2 | import pickle 3 | import os 4 | 5 | 6 | def auto_init_args(init): 7 | def new_init(self, *args, **kwargs): 8 | arg_dict = inspect.signature(init).parameters 9 | arg_names = list(arg_dict.keys())[1:] # skip self 10 | proc_names = set() 11 | for name, arg in zip(arg_names, args): 12 | setattr(self, name, arg) 13 | proc_names.add(name) 14 | for name, arg in kwargs.items(): 15 | setattr(self, name, arg) 16 | proc_names.add(name) 17 | remain_names = set(arg_names) - proc_names 18 | if len(remain_names): 19 | for name in remain_names: 20 | setattr(self, name, arg_dict[name].default) 21 | init(self, *args, **kwargs) 22 | 23 | return new_init 24 | 25 | 26 | def auto_init_pytorch(init): 27 | def new_init(self, *args, **kwargs): 28 | init(self, *args, **kwargs) 29 | self.apply(self.init_weights) 30 | self.opt = self.init_optimizer( 31 | self.expe.config.opt, 32 | self.expe.config.lr, 33 | self.expe.config.l2) 34 | 35 | if not self.expe.config.resume: 36 | self.to(self.device) 37 | self.expe.log.info( 38 | "transferred model to {}".format(self.device)) 39 | 40 | self.expe.log.info("#all parameters: {}, #trainable parameters: {}" 41 | .format(self.count_all_parameters(), 42 | self.count_trainable_parameters())) 43 | return new_init 44 | 45 | 46 | class lazy_execute: 47 | @auto_init_args 48 | def __init__(self, func_name): 49 | pass 50 | 51 | def __call__(self, fn): 52 | func_name = self.func_name 53 | 54 | def new_fn(self, *args, **kwargs): 55 | file_name = kwargs.pop('file_name') 56 | if os.path.isfile(file_name): 57 | return getattr(self, func_name)(file_name) 58 | else: 59 | data = fn(self, *args, **kwargs) 60 | 61 | self.expe.log.info("saving to {}" 62 | .format(file_name)) 63 | with open(file_name, "wb+") as fp: 64 | pickle.dump(data, fp, protocol=-1) 65 | return data 66 | return new_fn 67 | -------------------------------------------------------------------------------- /encoders.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from entmax import entmax_bisect 5 | 6 | 7 | class encoder_base(nn.Module): 8 | def 
__init__(self, vocab_size, type_vocab_size, 9 | embed_dim, dropout, max_pos_len, 10 | *args, **kwargs): 11 | super(encoder_base, self).__init__() 12 | self.dropout = nn.Dropout(dropout) 13 | self.embed = nn.Embedding(vocab_size, embed_dim) 14 | self.type_embed = nn.Embedding(type_vocab_size, embed_dim) 15 | self.pos_embed = nn.Embedding(max_pos_len, embed_dim) 16 | self.if_hyp_embed = nn.Embedding(2, embed_dim) 17 | 18 | 19 | class transformer(encoder_base): 20 | """encoder module for the input linearized table""" 21 | def __init__(self, vocab_size, type_vocab_size, embed_dim, dropout, 22 | nlayer, nhead, hidden_size, act_fn, 23 | max_pos_len, *args, **kwargs): 24 | super(transformer, self).__init__( 25 | vocab_size, type_vocab_size, embed_dim, dropout, max_pos_len) 26 | self.layers = nn.ModuleList([nn.TransformerEncoderLayer( 27 | d_model=hidden_size, 28 | nhead=nhead, 29 | dim_feedforward=hidden_size * 4, 30 | dropout=dropout, 31 | activation=act_fn) for n in range(nlayer)]) 32 | 33 | def forward(self, data, data_mask, data_pos, data_type, if_hyp): 34 | data_vec = self.embed(data.long()) 35 | data_pos_vec = self.pos_embed(data_pos.long()) 36 | data_type_vec = self.type_embed(data_type.long()) 37 | data_if_hyp_vec = self.if_hyp_embed(if_hyp.long()) 38 | output = self.dropout( 39 | data_vec + data_pos_vec + data_type_vec + data_if_hyp_vec)\ 40 | .transpose(0, 1) 41 | data_mask = (1 - data_mask).bool() 42 | for layer in self.layers: 43 | output = layer(src=output, src_key_padding_mask=data_mask) 44 | return output.transpose(0, 1) 45 | 46 | 47 | class ie_transformer(nn.Module): 48 | """encoder module for the backward model in cyclic training""" 49 | def __init__(self, vocab_size, embed_dim, dropout, 50 | nlayer, nhead, hidden_size, act_fn, max_len, *args, **kwargs): 51 | super(ie_transformer, self).__init__() 52 | self.dropout = nn.Dropout(dropout) 53 | self.embed = nn.Embedding(vocab_size, embed_dim) 54 | self.pos_embed = nn.Embedding(max_len, embed_dim) 55 | self.layers = nn.ModuleList([nn.TransformerEncoderLayer( 56 | d_model=hidden_size, 57 | nhead=nhead, 58 | dim_feedforward=hidden_size * 4, 59 | dropout=dropout, 60 | activation=act_fn) for n in range(nlayer)]) 61 | 62 | def forward(self, data, data_mask): 63 | pos_ids = torch.arange(data.size(1), device=data.device)\ 64 | .unsqueeze(0).expand(data.size(0), -1) 65 | data_vec = self.embed(data.long()) 66 | data_pos_vec = self.pos_embed(pos_ids.long()) 67 | output = self.dropout(data_vec + data_pos_vec).transpose(0, 1) 68 | data_mask = (1 - data_mask).bool() 69 | for layer in self.layers: 70 | output = layer(src=output, src_key_padding_mask=data_mask) 71 | 72 | return output.transpose(0, 1) 73 | 74 | def softmax_forward(self, logits, data_mask): 75 | """logits are softmaxed""" 76 | pos_ids = torch.arange(data_mask.size(1), device=data_mask.device)\ 77 | .unsqueeze(0).expand(data_mask.size(0), -1) 78 | 79 | data_vec = torch.matmul(logits, self.embed.weight) 80 | 81 | data_pos_vec = self.pos_embed(pos_ids.long()) 82 | output = self.dropout(data_vec + data_pos_vec).transpose(0, 1) 83 | data_mask = (1 - data_mask).bool() 84 | for layer in self.layers: 85 | output = layer(src=output, src_key_padding_mask=data_mask) 86 | return output.transpose(0, 1) 87 | 88 | def mix_forward(self, logits, data_mask): 89 | """logits have not been softmaxed""" 90 | pos_ids = torch.arange(data_mask.size(1), device=data_mask.device)\ 91 | .unsqueeze(0).expand(data_mask.size(0), -1) 92 | 93 | data_vec = torch.matmul(F.softmax(logits, -1), self.embed.weight) 94 | 95 
| data_pos_vec = self.pos_embed(pos_ids.long()) 96 | output = self.dropout(data_vec + data_pos_vec).transpose(0, 1) 97 | data_mask = (1 - data_mask).bool() 98 | for layer in self.layers: 99 | output = layer(src=output, src_key_padding_mask=data_mask) 100 | return output.transpose(0, 1) 101 | 102 | def entmax_forward(self, logits, data_mask): 103 | pos_ids = torch.arange(data_mask.size(1), device=data_mask.device)\ 104 | .unsqueeze(0).expand(data_mask.size(0), -1) 105 | 106 | bs, sq_len, dim = logits.shape 107 | logits = entmax_bisect(logits.reshape(bs * sq_len, dim), 1.2) 108 | data_vec = torch.matmul(logits.reshape(bs, sq_len, dim), 109 | self.embed.weight) 110 | 111 | data_pos_vec = self.pos_embed(pos_ids.long()) 112 | output = self.dropout(data_vec + data_pos_vec).transpose(0, 1) 113 | data_mask = (1 - data_mask).bool() 114 | for layer in self.layers: 115 | output = layer(src=output, src_key_padding_mask=data_mask) 116 | return output.transpose(0, 1) 117 | -------------------------------------------------------------------------------- /eval_parent.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | r"""Script to compute the PARENT metric. 16 | 17 | The <reference_file> and <generation_file> should contain references and 18 | generations, respectively, one per line. The <table_file> should contain the 19 | ground truth tables corresponding to these in each line. Multiple references 20 | should be separated by <TAB>s on the same line. 21 | 22 | There are two formats supported for the tables: 23 | 1. For tables similar to those in WikiBio, with pairs of attributes and values: 24 | attribute_1|||value_1<TAB>attribute_2|||value_2<TAB>... 25 | 2. For tables similar to WebNLG with triples of (head, relation, tail): 26 | head_1|||relation_1|||tail_1<TAB>head_2|||relation_2|||tail_2<TAB>... 27 | 28 | The default implementations for computing the entailment probability and the 29 | table recall provided in this script can handle both the cases above.
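For instance, one line of the <table_file> in the (attribute, value) format
could look like this (with a literal tab character in place of <TAB>):
name|||michael dahlquist<TAB>birth date|||december 22 1965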
30 | https://github.com/google-research/language/tree/master/language/table_text_eval 31 | """ 32 | 33 | from __future__ import absolute_import 34 | from __future__ import division 35 | from __future__ import print_function 36 | 37 | import collections 38 | import io 39 | import json 40 | import logging 41 | import math 42 | import argparse 43 | 44 | logging.basicConfig( 45 | level=logging.DEBUG, 46 | format='%(asctime)s %(levelname)s: %(message)s', 47 | datefmt='%m-%d %H:%M') 48 | 49 | parser = argparse.ArgumentParser() 50 | 51 | parser.add_argument('--references', type=str, default=None) 52 | parser.add_argument('--generations', type=str, default=None) 53 | parser.add_argument('--tables', type=str, default=None) 54 | parser.add_argument('--cooccurrence_counts', type=str, default=None) 55 | parser.add_argument('--entailment_fn', type=str, default="overlap") 56 | parser.add_argument('--lambda_weight', type=float, default=None) 57 | parser.add_argument('--smoothing', type=float, default=0.00001) 58 | 59 | FLAGS = parser.parse_args() 60 | 61 | def _text_reader(text_file, multiple=False): 62 | """Yields lines from the text file. 63 | 64 | Performs lowercasing and white-space tokenization on each line before 65 | returning. 66 | 67 | Args: 68 | text_file: String filename. 69 | multiple: Whether multiple references / generations are expected in a line. 70 | """ 71 | with io.open(text_file) as f: 72 | for line in f: 73 | if multiple: 74 | yield [item.lower().split() for item in line.strip().split("\t")] 75 | else: 76 | yield line.strip().lower().split() 77 | 78 | 79 | def _table_reader(table_file): 80 | """Yields tables from the table file. 81 | 82 | Tables are parsed into a list of tuples with tokenized entries. 83 | 84 | Args: 85 | table_file: String filename. 86 | """ 87 | with io.open(table_file) as f: 88 | for line in f: 89 | entries = line.lower().split("\t") 90 | table = [ 91 | [member.split() for member in entry.split("|||")] for entry in entries if entry.strip() 92 | ] 93 | yield table 94 | 95 | 96 | def cooccur_probability_fn(counts): 97 | """Returns function for computing entailment probability. 98 | 99 | Args: 100 | counts: Dict mapping unigrams / bigrams (joined using "|||") to their 101 | counts. 102 | 103 | Returns: 104 | Function handle to compute entailment probability. 105 | """ 106 | 107 | def _cooccur_probability(ngram, table): 108 | """Returns probability of ngram being entailed by the table. 109 | 110 | Uses the co-occurrence counts given along with the lexical 111 | entailment model described in: 112 | 113 | Glickman, Oren, Ido Dagan, and Moshe Koppel. 114 | "A lexical alignment model for probabilistic textual entailment." 115 | Machine Learning Challenges. 116 | Springer, Berlin, Heidelberg, 2006. 287-298. 117 | 118 | E.g.: 119 | >>> _cooccur_probability(["michael", "dahlquist"], 120 | [(["name"], ["michael", "dahlquist"])]) 121 | >>> 1.0 122 | 123 | Args: 124 | ngram: List of tokens. 125 | table: List of either (attribute, value) pairs or (head, relation, tail) 126 | triples. Each member of the pair / triple is assumed to already be 127 | tokenized into a list of strings. 128 | 129 | Returns: 130 | prob: Float probability of ngram being entailed by the table. 131 | """ 132 | table_toks = set() 133 | for item in table: 134 | if len(item) == 2: 135 | # attribute, value 136 | table_toks.add("_".join(item[0])) 137 | table_toks.update(item[1]) 138 | else: 139 | # head, relation, tail 140 | table_toks.update(item[0] + ["_".join(item[1])] + item[2]) 141 | probability = 1. 
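# For each ngram token that is not itself in the table, multiply in the
# highest co-occurrence probability of that token given any table token;
# tokens that already appear in the table contribute a factor of 1.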
142 | for xtok in ngram: 143 | if xtok in table_toks: 144 | continue 145 | max_p = 0. 146 | for btok in table_toks: 147 | if btok not in counts: 148 | continue 149 | p = float(counts.get(btok + "|||" + xtok, 0.)) / counts[btok] 150 | if p > max_p: 151 | max_p = p 152 | probability *= max_p 153 | return math.pow(probability, 1. / len(ngram)) 154 | 155 | return _cooccur_probability 156 | 157 | 158 | def overlap_probability(ngram, table, smoothing=0.0, stopwords=None): 159 | """Returns the probability that the given n-gram overlaps with the table. 160 | 161 | A simple implementation which checks how many tokens in the n-gram are also 162 | among the values in the table. For tables with (attribute, value) pairs only 163 | the `value` field is considered. For tables with (head, relation, tail) triples a 164 | concatenation of `head` and `tail` is considered. 165 | 166 | E.g.: 167 | >>> overlap_probability(["michael", "dahlquist"], 168 | [(["name"], ["michael", "dahlquist"])]) 169 | >>> 1.0 170 | 171 | Args: 172 | ngram: List of tokens. 173 | table: List of either (attribute, value) pairs or (head, relation, tail) 174 | triples. Each member of the pair / triple is assumed to already be 175 | tokenized into a list of strings. 176 | smoothing: (Optional) Float parameter for laplace smoothing. 177 | stopwords: (Optional) List of stopwords to ignore (assign P = 1). 178 | 179 | Returns: 180 | prob: Float probability of ngram being entailed by the table. 181 | """ 182 | if len(table[0]) == 2: 183 | table_values = set([tok for _, value in table for tok in value]) 184 | else: 185 | table_values = set([tok for head, _, tail in table for tok in head + tail]) 186 | overlap = 0 187 | for token in ngram: 188 | if stopwords is not None and token in stopwords: 189 | overlap += 1 190 | continue 191 | if token in table_values: 192 | overlap += 1 193 | return float(overlap + smoothing) / float(len(ngram) + smoothing) 194 | 195 | 196 | def _mention_probability(table_entry, sentence, smoothing=0.00001): 197 | """Returns the probability that the table entry is mentioned in the sentence. 198 | 199 | A simple implementation which checks the longest common subsequence between 200 | the table entry and the sentence. For tables with (attribute, value) pairs 201 | only the `value` is considered. For tables with (head, relation, tail) triples 202 | a concatenation of the `head` and `tail` is considered. 203 | 204 | E.g.: 205 | >>> _mention_probability((["name"], ["michael", "dahlquist"]), 206 | ["michael", "dahlquist", "was", "a", "drummer"]) 207 | >>> 1.0 208 | 209 | Args: 210 | table_entry: Tuple of either (attribute, value) or (head, relation, tail). 211 | Each member of the tuple is assumed to already be tokenized into a list of 212 | strings. 213 | sentence: List of tokens. 214 | smoothing: Float parameter for laplace smoothing. 215 | 216 | Returns: 217 | prob: Float probability of entry being in sentence. 218 | """ 219 | if len(table_entry) == 2: 220 | value = table_entry[1] 221 | else: 222 | value = table_entry[0] + table_entry[2] 223 | overlap = _len_lcs(value, sentence) 224 | return float(overlap + smoothing) / float(len(value) + smoothing) 225 | 226 | 227 | def _len_lcs(x, y): 228 | """Returns the length of the Longest Common Subsequence between two seqs.
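E.g.:
>>> _len_lcs(["michael", "dahlquist", "drummer"], ["michael", "drummer"])
>>> 2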
229 | 230 | Source: http://www.algorithmist.com/index.php/Longest_Common_Subsequence 231 | 232 | Args: 233 | x: sequence of words 234 | y: sequence of words 235 | 236 | Returns: 237 | integer: Length of LCS between x and y 238 | """ 239 | table = _lcs(x, y) 240 | n, m = len(x), len(y) 241 | return table[n, m] 242 | 243 | 244 | def _lcs(x, y): 245 | """Computes the length of the LCS between two seqs. 246 | 247 | The implementation below uses a dynamic programming algorithm and runs 248 | in O(nm) time where n = len(x) and m = len(y). 249 | Source: http://www.algorithmist.com/index.php/Longest_Common_Subsequence 250 | 251 | Args: 252 | x: collection of words 253 | y: collection of words 254 | 255 | Returns: 256 | Dictionary mapping each coordinate (i, j) to the LCS length of x[:i] and y[:j] 257 | """ 258 | n, m = len(x), len(y) 259 | table = dict() 260 | for i in range(n + 1): 261 | for j in range(m + 1): 262 | if i == 0 or j == 0: 263 | table[i, j] = 0 264 | elif x[i - 1] == y[j - 1]: 265 | table[i, j] = table[i - 1, j - 1] + 1 266 | else: 267 | table[i, j] = max(table[i - 1, j], table[i, j - 1]) 268 | return table 269 | 270 | 271 | def _ngrams(sequence, order): 272 | """Yields all ngrams of given order in sequence.""" 273 | assert order >= 1 274 | for n in range(order, len(sequence) + 1): 275 | yield tuple(sequence[n - order: n]) 276 | 277 | 278 | def _ngram_counts(sequence, order): 279 | """Returns count of all ngrams of given order in sequence.""" 280 | if len(sequence) < order: 281 | return collections.Counter() 282 | return collections.Counter(_ngrams(sequence, order)) 283 | 284 | 285 | def parent(predictions, 286 | references, 287 | tables, 288 | lambda_weight=0.5, 289 | smoothing=0.00001, 290 | max_order=4, 291 | entailment_fn=overlap_probability, 292 | mention_fn=_mention_probability): 293 | """Metric for comparing predictions to references given tables. 294 | 295 | Args: 296 | predictions: An iterator over tokenized predictions. 297 | Each prediction is a list. 298 | references: An iterator over lists of tokenized references. 299 | Each prediction can have multiple references. 300 | tables: An iterator over the tables. Each table is a list of tuples, where a 301 | tuple can either be (attribute, value) pair or (head, relation, tail) 302 | triple. The members of the tuples are assumed to be themselves tokenized 303 | lists of strings. E.g. 304 | `[(["name"], ["michael", "dahlquist"]), 305 | (["birth", "date"], ["december", "22", "1965"])]` 306 | is one table in the (attribute, value) format with two entries. 307 | lambda_weight: Float weight in [0, 1] to multiply table recall. 308 | smoothing: Float value used to replace zero values of precision and recall. 309 | max_order: Maximum order of the ngrams to use. 310 | entailment_fn: A python function for computing the probability that an 311 | ngram is entailed by the table. Its signature should match that of 312 | `overlap_probability` above. 313 | mention_fn: A python function for computing the probability that a 314 | table entry is mentioned in the text. Its signature should 315 | match that of `_mention_probability` above. 316 | 317 | Returns: 318 | precision: Average precision of all predictions. 319 | recall: Average recall of all predictions. 320 | f1: Average F-scores of all predictions. 321 | all_f_scores: List of all F-scores for each item.
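E.g. (an illustrative call on a single example):
>>> parent([["michael", "dahlquist", "was", "a", "drummer"]],
[[["michael", "dahlquist", "played", "drums"]]],
[[(["name"], ["michael", "dahlquist"])]])
returns the four values above computed over this one prediction.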
322 | """ 323 | precisions, recalls, all_f_scores = [], [], [] 324 | reference_recalls, table_recalls = [], [] 325 | all_lambdas = [] 326 | for prediction, list_of_references, table in zip( 327 | predictions, references, tables): 328 | c_prec, c_rec, c_f = [], [], [] 329 | ref_rec, table_rec = [], [] 330 | for reference in list_of_references: 331 | # Weighted ngram precisions and recalls for each order. 332 | ngram_prec, ngram_rec = [], [] 333 | for order in range(1, max_order + 1): 334 | # Collect n-grams and their entailment probabilities. 335 | pred_ngram_counts = _ngram_counts(prediction, order) 336 | pred_ngram_weights = {ngram: entailment_fn(ngram, table) 337 | for ngram in pred_ngram_counts} 338 | ref_ngram_counts = _ngram_counts(reference, order) 339 | ref_ngram_weights = {ngram: entailment_fn(ngram, table) 340 | for ngram in ref_ngram_counts} 341 | 342 | # Precision. 343 | numerator, denominator = 0., 0. 344 | for ngram, count in pred_ngram_counts.items(): 345 | denominator += count 346 | prob_ngram_in_ref = min( 347 | 1., float(ref_ngram_counts.get(ngram, 0) / count)) 348 | numerator += count * ( 349 | prob_ngram_in_ref + 350 | (1. - prob_ngram_in_ref) * pred_ngram_weights[ngram]) 351 | if denominator == 0.: 352 | # Set precision to 0. 353 | ngram_prec.append(0.0) 354 | else: 355 | ngram_prec.append(numerator / denominator) 356 | 357 | # Recall. 358 | numerator, denominator = 0., 0. 359 | for ngram, count in ref_ngram_counts.items(): 360 | prob_ngram_in_pred = min( 361 | 1., float(pred_ngram_counts.get(ngram, 0) / count)) 362 | denominator += count * ref_ngram_weights[ngram] 363 | numerator += count * ref_ngram_weights[ngram] * prob_ngram_in_pred 364 | if denominator == 0.: 365 | # Set recall to 1. 366 | ngram_rec.append(1.0) 367 | else: 368 | ngram_rec.append(numerator / denominator) 369 | 370 | # Compute recall against table fields. 371 | table_mention_probs = [mention_fn(entry, prediction) 372 | for entry in table] 373 | table_rec.append(sum(table_mention_probs) / len(table)) 374 | 375 | # Smoothing. 376 | for order in range(1, max_order): 377 | if ngram_prec[order] == 0.: 378 | ngram_prec[order] = smoothing 379 | if ngram_rec[order] == 0.: 380 | ngram_rec[order] = smoothing 381 | 382 | # Compute geometric averages of precision and recall for all orders. 383 | w = 1. / max_order 384 | if any(prec == 0. for prec in ngram_prec): 385 | c_prec.append(0.) 386 | else: 387 | sp = (w * math.log(p_i) for p_i in ngram_prec) 388 | c_prec.append(math.exp(math.fsum(sp))) 389 | if any(rec == 0. for rec in ngram_rec): 390 | ref_rec.append(smoothing) 391 | else: 392 | sr = [w * math.log(r_i) for r_i in ngram_rec] 393 | ref_rec.append(math.exp(math.fsum(sr))) 394 | 395 | # Combine reference and table recalls. 396 | if table_rec[-1] == 0.: 397 | table_rec[-1] = smoothing 398 | if ref_rec[-1] == 0. or table_rec[-1] == 0.: 399 | c_rec.append(0.) 400 | else: 401 | if lambda_weight is None: 402 | lw = sum([mention_fn(entry, reference) for entry in table 403 | ]) / len(table) 404 | lw = 1. - lw 405 | else: 406 | lw = lambda_weight 407 | all_lambdas.append(lw) 408 | c_rec.append( 409 | math.exp((1. - lw) * math.log(ref_rec[-1]) + 410 | (lw) * math.log(table_rec[-1]))) 411 | 412 | # F-score. 413 | c_f.append((2. * c_prec[-1] * c_rec[-1]) / 414 | (c_prec[-1] + c_rec[-1] + 1e-8)) 415 | 416 | # Get index of best F-score. 
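# Keep the precision, recall, and F-score of the reference that achieves
# the best F-score for this example (ties go to the first reference).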
417 | max_i = max(enumerate(c_f), key=lambda x: x[1])[0] 418 | precisions.append(c_prec[max_i]) 419 | recalls.append(c_rec[max_i]) 420 | all_f_scores.append(c_f[max_i]) 421 | reference_recalls.append(ref_rec[max_i]) 422 | table_recalls.append(table_rec[max_i]) 423 | 424 | avg_precision = sum(precisions) / len(precisions) 425 | avg_recall = sum(recalls) / len(recalls) 426 | avg_f_score = sum(all_f_scores) / len(all_f_scores) 427 | 428 | return avg_precision, avg_recall, avg_f_score, all_f_scores 429 | 430 | 431 | def main(): 432 | reference_it = _text_reader(FLAGS.references, multiple=True) 433 | generation_it = _text_reader(FLAGS.generations) 434 | table_it = _table_reader(FLAGS.tables) 435 | 436 | if FLAGS.entailment_fn == "cooccurrence": 437 | assert FLAGS.cooccurrence_counts is not None 438 | logging.info("Reading %s...", FLAGS.cooccurrence_counts) 439 | with open(FLAGS.cooccurrence_counts) as f: 440 | cooccur_counts = json.load(f) 441 | entail_method = cooccur_probability_fn(cooccur_counts) 442 | else: 443 | entail_method = overlap_probability 444 | 445 | precision, recall, f_score, all_f = parent( 446 | generation_it, 447 | reference_it, 448 | table_it, 449 | lambda_weight=FLAGS.lambda_weight, 450 | smoothing=FLAGS.smoothing, 451 | entailment_fn=entail_method) 452 | 453 | logging.info("Evaluated %d examples.", len(all_f)) 454 | logging.info("Precision = %.4f Recall = %.4f F-score = %.4f", 455 | precision, recall, f_score) 456 | 457 | if __name__ == "__main__": 458 | main() 459 | -------------------------------------------------------------------------------- /eval_rouge_meteor_rep.py: -------------------------------------------------------------------------------- 1 | from train_helper import run_multi_bleu, Meteor 2 | from collections import Counter 3 | from tqdm import tqdm 4 | 5 | import statistics 6 | import rouge 7 | import argparse 8 | 9 | parser = argparse.ArgumentParser() 10 | 11 | parser.add_argument('--references', type=str, default=None) 12 | parser.add_argument('--generations', type=str, default=None) 13 | 14 | args = parser.parse_args() 15 | 16 | gen_path = args.generations 17 | ref_path = args.references 18 | 19 | rouge_eval = rouge.Rouge(metrics=['rouge-n', 'rouge-l'], 20 | max_n=2, 21 | limit_length=True, 22 | length_limit=5000, 23 | length_limit_type='words', 24 | apply_avg=False, 25 | apply_best=False, 26 | alpha=0.5, # Default F1_score 27 | weight_factor=1.2, 28 | stemming=True) 29 | 30 | max_rep_n = 4 31 | min_rep_n = 2 32 | intra_rep = {"mean": 0.0, "median": 0.0, "min": 0.0, 33 | "max": 0.0, "std": 0.0, "count_ratio": 0.0} 34 | intra_rep_per_gram = {} 35 | for i in range(1, max_rep_n + 1): 36 | intra_rep_per_gram[i] = {"mean": 0.0, "median": 0.0, 37 | "min": 0.0, "max": 0.0, 38 | "std": 0.0, "count_ratio": 0.0} 39 | 40 | n_dev = 0 41 | stats = {"rouge1": 0, "rouge2": 0, "rougel": 0, "meteor": 0} 42 | meteor = Meteor() 43 | with open(gen_path) as gen_fp, open(ref_path) as ref_fp: 44 | for nline, (gen_line, ref_line) in tqdm(enumerate(zip(gen_fp, ref_fp))): 45 | if gen_line.strip(): 46 | curr_intra_rep = Counter() 47 | curr_intra_rep_per_gram = \ 48 | {i: Counter() for i in range(1, max_rep_n + 1)} 49 | 50 | n_dev += 1 51 | gen_tok = gen_line.strip().split() 52 | for i in range(len(gen_tok)): 53 | for l in range(1, min(max_rep_n + 1, len(gen_tok[i:]) + 1)): 54 | if l >= min_rep_n: 55 | curr_intra_rep[" ".join(gen_tok[i: i + l])] += 1 56 | curr_intra_rep_per_gram[len(gen_tok[i: i + l])][" ".join(gen_tok[i: i + l])] += 1 57 | 58 | curr_intra_rep_count = [] 59 | 
curr_total_ngram = len(curr_intra_rep) 60 | curr_total_rep_ngram_count = 0 61 | curr_total_rep_ngram_type = 0 62 | for p, c in curr_intra_rep.items(): 63 | if c >= 3: 64 | curr_intra_rep_count.append(c) 65 | curr_total_rep_ngram_count += c 66 | curr_total_rep_ngram_type += 1 67 | 68 | curr_intra_rep_total_count_per_gram = {} 69 | curr_intra_rep_total_type_per_gram = {} 70 | curr_intra_rep_count_per_gram = {} 71 | curr_total_ngram_per_gram = {} 72 | curr_total_rep_ngram_count_per_gram = {} 73 | curr_total_rep_ngram_type_per_gram = {} 74 | 75 | for i in range(1, max_rep_n + 1): 76 | curr_intra_rep_total_count_per_gram[i] = sum(curr_intra_rep_per_gram[i].values()) 77 | curr_intra_rep_total_type_per_gram[i] = len(curr_intra_rep_per_gram[i]) 78 | 79 | curr_intra_rep_count_per_gram[i] = [] 80 | curr_total_ngram_per_gram[i] = len(curr_intra_rep_per_gram[i]) 81 | curr_total_rep_ngram_count_per_gram[i] = 0 82 | curr_total_rep_ngram_type_per_gram[i] = 0 83 | for p, c in curr_intra_rep_per_gram[i].items(): 84 | if c >= 3: 85 | curr_intra_rep_count_per_gram[i].append(c) 86 | curr_total_rep_ngram_count_per_gram[i] += c 87 | curr_total_rep_ngram_type_per_gram[i] += 1 88 | 89 | intra_rep_per_gram[i]["max"] += max(curr_intra_rep_count_per_gram[i]) if len(curr_intra_rep_count_per_gram[i]) else 0 90 | intra_rep_per_gram[i]["min"] += min(curr_intra_rep_count_per_gram[i]) if len(curr_intra_rep_count_per_gram[i]) else 0 91 | intra_rep_per_gram[i]["median"] += statistics.median(curr_intra_rep_count_per_gram[i]) if len(curr_intra_rep_count_per_gram[i]) else 0 92 | intra_rep_per_gram[i]["mean"] += statistics.mean(curr_intra_rep_count_per_gram[i]) if len(curr_intra_rep_count_per_gram[i]) else 0 93 | intra_rep_per_gram[i]["std"] += statistics.stdev(curr_intra_rep_count_per_gram[i]) if len(curr_intra_rep_count_per_gram[i]) > 1 else 0 94 | intra_rep_per_gram[i]["count_ratio"] += curr_total_rep_ngram_count_per_gram[i] / curr_intra_rep_total_count_per_gram[i] if curr_intra_rep_total_count_per_gram[i] else 0 95 | 96 | curr_intra_rep_total_count = sum(curr_intra_rep.values()) 97 | curr_intra_rep_total_type = len(curr_intra_rep) 98 | 99 | intra_rep["max"] += max(curr_intra_rep_count) if len(curr_intra_rep_count) else 0 100 | intra_rep["min"] += min(curr_intra_rep_count) if len(curr_intra_rep_count) else 0 101 | intra_rep["median"] += statistics.median(curr_intra_rep_count) if len(curr_intra_rep_count) else 0 102 | intra_rep["mean"] += statistics.mean(curr_intra_rep_count) if len(curr_intra_rep_count) else 0 103 | intra_rep["std"] += statistics.stdev(curr_intra_rep_count) if len(curr_intra_rep_count) > 1 else 0 104 | intra_rep["count_ratio"] += curr_total_rep_ngram_count / curr_intra_rep_total_count if curr_intra_rep_total_count else 0 105 | 106 | ms = meteor._score(gen_line.strip(), [ref_line.strip()]) 107 | rouge_scores = rouge_eval.get_scores([gen_line.strip()], [ref_line.strip()]) 108 | stats['rouge1'] += rouge_scores['rouge-1'][0]['f'][0] 109 | stats['rouge2'] += rouge_scores['rouge-2'][0]['f'][0] 110 | stats['rougel'] += rouge_scores['rouge-l'][0]['f'][0] 111 | stats['meteor'] += ms 112 | 113 | dev_rouge1 = stats['rouge1'] / n_dev * 100 114 | dev_rouge2 = stats['rouge2'] / n_dev * 100 115 | dev_rougel = stats['rougel'] / n_dev * 100 116 | dev_meteor = stats['meteor'] / n_dev * 100 117 | 118 | intra_rep["max"] = intra_rep["max"] / n_dev 119 | intra_rep["min"] = intra_rep["min"] / n_dev 120 | intra_rep["median"] = intra_rep["median"] / n_dev 121 | intra_rep["mean"] = intra_rep["mean"] / n_dev 122 | intra_rep["std"] = 
intra_rep["std"] / n_dev 123 | intra_rep["count_ratio"] = intra_rep["count_ratio"] / n_dev * 100 124 | 125 | 126 | dev_bleu_score = run_multi_bleu(gen_path, ref_path) 127 | 128 | print("REP -", ", ".join(["{} : {:.2f}".format(k, v) for k, v 129 | in sorted(intra_rep.items())])) 130 | 131 | print("NOTE: for REP, We report count ratio.") 132 | 133 | print("=== REP ngram breakdown ===") 134 | for i in range(1, max_rep_n + 1): 135 | print("*** {}-gram ***".format(i)) 136 | print(", ".join(["{} : {:.2f}".format(k, v / n_dev if k != "count_ratio" else v / n_dev * 100) 137 | for k, v in sorted(intra_rep_per_gram[i].items())])) 138 | 139 | print("BLEU: {:.2f}".format(dev_bleu_score)) 140 | print("REOUG-1: {:.2f}".format(dev_rouge1)) 141 | print("ROUGE-2: {:.2f}".format(dev_rouge2)) 142 | print("ROUGE-L: {:.2f}".format(dev_rougel)) 143 | print("METEOR: {:.2f}".format(dev_meteor)) 144 | -------------------------------------------------------------------------------- /generate_beam_search.py: -------------------------------------------------------------------------------- 1 | import train_helper 2 | import data_utils 3 | import config 4 | import models 5 | import torch 6 | import sys 7 | import nltk 8 | import os 9 | 10 | from train_helper import run_multi_bleu 11 | from config import EOS_IDX 12 | from tqdm import tqdm 13 | 14 | BEST_DEV_BLEU = TEST_BLEU = 0 15 | 16 | 17 | def run(e): 18 | global BEST_DEV_BLEU, TEST_BLEU 19 | 20 | checkpoint = torch.load(e.config.model_file, 21 | map_location=lambda storage, loc: storage) 22 | e.log.info("loaded from: {}".format(e.config.model_file)) 23 | 24 | class dummy_exp: 25 | pass 26 | model_exp = dummy_exp() 27 | model_exp.log = e.log 28 | checkpoint["config"].debug = False 29 | checkpoint["config"].resume = True 30 | 31 | model_exp.config = checkpoint["config"] 32 | model_exp.experiment_dir = e.config.gen_dir \ 33 | if e.config.gen_dir else e.experiment_dir 34 | model_exp.config.top_p = e.config.top_p 35 | model_exp.config.max_gen_len = e.config.max_gen_len 36 | model_exp.config.min_gen_len = e.config.min_gen_len 37 | for name in dir(e.config): 38 | if name.startswith("__"): 39 | continue 40 | if name not in dir(model_exp.config): 41 | value = getattr(e.config, name) 42 | e.log.info("update {} to {}".format(name, value)) 43 | setattr(model_exp.config, name, value) 44 | 45 | e.log.info("*" * 25 + " DATA PREPARATION " + "*" * 25) 46 | data_processor = data_utils.DataProcessor( 47 | train_path=e.config.train_path, 48 | dev_path=e.config.dev_path, 49 | test_path=e.config.test_path, 50 | wikidata_path=e.config.wikidata_path, 51 | infobox_path=e.config.infobox_path, 52 | bpe_vocab=e.config.bpe_vocab, 53 | bpe_codes=e.config.bpe_codes, 54 | experiment=model_exp) 55 | data = data_processor.process() 56 | 57 | e.log.info("*" * 25 + " DATA PREPARATION " + "*" * 25) 58 | e.log.info("*" * 25 + " MODEL INITIALIZATION " + "*" * 25) 59 | 60 | model = models.BasicCyclicAttnSplitMask( 61 | vocab_size=len(data.vocab), 62 | type_vocab_size=500, 63 | embed_dim=model_exp.config.edim, 64 | iter_per_epoch=100, 65 | use_entmax=False, 66 | experiment=model_exp) 67 | 68 | model.load(checkpointed_state_dict=checkpoint["state_dict"]) 69 | 70 | e.log.info(model) 71 | e.log.info("*" * 25 + " MODEL INITIALIZATION " + "*" * 25) 72 | 73 | dev_eval = train_helper.SplitMaskEvaluator( 74 | model=model, 75 | data=data.dev_data, 76 | inv_vocab=data.inv_vocab, 77 | vocab=data.vocab, 78 | return_wikidata=e.config.return_wikidata, 79 | return_hyperlink=e.config.return_hyperlink, 80 | 
input_wikidata=e.config.input_wikidata, 81 | input_hyperlink=e.config.input_hyperlink, 82 | eval_batch_size=e.config.eval_batch_size, 83 | experiment=model_exp) 84 | 85 | model.eval() 86 | gen_fn = e.config.gen_prefix 87 | output_path = e.config.gen_dir \ 88 | if e.config.gen_dir else e.experiment_dir 89 | 90 | if not os.path.isdir(output_path): 91 | print("make dirs", output_path) 92 | os.makedirs(output_path) 93 | print("e.config.max_gen_len", e.config.max_gen_len, 94 | "e.config.min_gen_len", e.config.min_gen_len) 95 | all_gen = {} 96 | for nbatch, (input_data, input_data_mask, input_data_pos, 97 | input_data_type, input_if_hyp, 98 | input_data_src_vocab, input_data_src_tgt_vocab_map, 99 | tgt_inp_data, tgt_inp_data_mask, tgt_inp_data_pos, 100 | tgt_inp_data_type, tgt_inp_data_if_hyp, 101 | tgt_out_data, tgt_out_data_mask, 102 | tgt_input, tgt_label, tgt_mask, 103 | tgt_src_vocab, batch_idx) in tqdm( 104 | enumerate(dev_eval.data_iterator), 105 | total=len(dev_eval.data_iterator)): 106 | if nbatch and nbatch % (len(dev_eval.data_iterator) // 10 + 1) == 0: 107 | e.log.info("evaluating progress: {}/{} = {:.2f} %" 108 | .format(nbatch, 109 | len(dev_eval.data_iterator), 110 | nbatch / (len(dev_eval.data_iterator) + 1) * 100) 111 | ) 112 | with torch.no_grad(): 113 | data, data_mask, data_pos, \ 114 | data_type, data_if_hyp, data_src_vocab,\ 115 | data_src_tgt_vocab_map = \ 116 | model.to_tensors(input_data, input_data_mask, 117 | input_data_pos, input_data_type, 118 | input_if_hyp, 119 | input_data_src_vocab, 120 | input_data_src_tgt_vocab_map) 121 | data_vec = model.encode( 122 | data, data_mask, data_pos, data_type, data_if_hyp) 123 | 124 | batch_gen, batch_nll, batch_len = model.decode.generate_beam( 125 | encoder_output=data_vec, 126 | encoder_mask=data_mask, 127 | beam_size=e.config.beam_size, 128 | trigram_blocking=e.config.trigram_blocking, 129 | min_len=e.config.min_gen_len, 130 | max_len=e.config.max_gen_len, 131 | src_map=data_src_vocab, 132 | src_tgt_vocab_map=data_src_tgt_vocab_map) 133 | 134 | for batch_gen_, batch_null_, batch_len_, idx in \ 135 | zip(batch_gen, batch_nll, batch_len, batch_idx): 136 | curr_gen = [] 137 | for i in batch_gen_[1:batch_len_]: 138 | if i == EOS_IDX: 139 | break 140 | if i >= len(dev_eval.inv_vocab): 141 | curr_gen.append( 142 | dev_eval.data_iterator.data[idx]["inv_src_vocab"][int(i) - len(dev_eval.inv_vocab)] 143 | ) 144 | else: 145 | curr_gen.append(dev_eval.inv_vocab[int(i)]) 146 | all_gen[idx] = " ".join(curr_gen)\ 147 | .replace("@@ ", "").replace("@@", "") 148 | 149 | assert len(all_gen) == len(dev_eval.data_iterator.data), \ 150 | "{} != {}".format(len(all_gen), len(dev_eval.data_iterator.data)) 151 | file_name = os.path.join(output_path, gen_fn + ".txt") 152 | ref_file_name = os.path.join(output_path, gen_fn + "_ref.txt") 153 | 154 | file_name_untok = os.path.join(output_path, gen_fn + "_untok.txt") 155 | ref_file_name_untok = os.path.join(output_path, gen_fn + "_untok_ref.txt") 156 | 157 | gen_len_list = [] 158 | with open(file_name, "w") as fp, open(ref_file_name, "w") as fp2, \ 159 | open(file_name_untok, "w") as fpu, \ 160 | open(ref_file_name_untok, "w") as fpu2: 161 | for hyp_idx, ref in zip(sorted(all_gen), 162 | sorted(dev_eval.data_iterator.data, 163 | key=lambda x: x["idx"])): 164 | assert hyp_idx == ref["idx"], \ 165 | "hyp_idx={} != ref[\"idx\"]={}".format(hyp_idx, ref["idx"]) 166 | hyp = all_gen[hyp_idx] 167 | fp2.write(ref["tok_text"] + "\n") 168 | fpu2.write(ref["untok_text"] + "\n") 169 | if hyp: 170 | tok_hyp = 
nltk.word_tokenize(hyp) 171 | gen_len_list.append(len(tok_hyp)) 172 | tok_hyp = " ".join(tok_hyp) 173 | fp.write(tok_hyp + "\n") 174 | fpu.write(hyp + "\n") 175 | else: 176 | gen_len_list.append(0) 177 | fp.write("\n") 178 | fpu.write("\n") 179 | bleu_score = run_multi_bleu(file_name, ref_file_name) 180 | e.log.info("generated sentences saved to: {}".format(file_name)) 181 | 182 | e.log.info( 183 | "#Data: {}, bleu: {:.3f}, loss: {:.3f}, gloss: {:.3f}, " 184 | "floss: {:.3f}, tloss: {:.3f}, " 185 | "avg gen len: {:.2f}" 186 | .format(len(all_gen), bleu_score, dev_eval.eval_stats["loss"], 187 | dev_eval.eval_stats["gen_loss"], 188 | dev_eval.eval_stats["fake_cyclic_loss"], 189 | dev_eval.eval_stats["true_cyclic_loss"], 190 | sum(gen_len_list) / len(gen_len_list) 191 | ) 192 | ) 193 | dev_eval.eval_stats.reset() 194 | 195 | 196 | if __name__ == '__main__': 197 | 198 | PARSED_CONFIG = config.get_base_parser().parse_args() 199 | 200 | def exit_handler(*args): 201 | print(PARSED_CONFIG) 202 | print("best dev bleu: {:.4f}, test bleu: {:.4f}" 203 | .format(BEST_DEV_BLEU, TEST_BLEU)) 204 | sys.exit() 205 | 206 | train_helper.register_exit_handler(exit_handler) 207 | 208 | with train_helper.Experiment(PARSED_CONFIG, 209 | PARSED_CONFIG.save_prefix, 210 | forced_debug=True) as exp: 211 | 212 | exp.log.info("*" * 25 + " ARGS " + "*" * 25) 213 | exp.log.info(PARSED_CONFIG) 214 | exp.log.info("*" * 25 + " ARGS " + "*" * 25) 215 | 216 | run(exp) 217 | -------------------------------------------------------------------------------- /generate_greedy_sampl.py: -------------------------------------------------------------------------------- 1 | import train_helper 2 | import data_utils 3 | import config 4 | import models 5 | import torch 6 | import sys 7 | import os 8 | 9 | BEST_DEV_BLEU = TEST_BLEU = 0 10 | 11 | 12 | def run(e): 13 | global BEST_DEV_BLEU, TEST_BLEU 14 | 15 | checkpoint = torch.load(e.config.model_file, 16 | map_location=lambda storage, loc: storage) 17 | e.log.info("loaded from: {}".format(e.config.model_file)) 18 | 19 | class dummy_exp: 20 | pass 21 | model_exp = dummy_exp() 22 | model_exp.log = e.log 23 | checkpoint["config"].debug = False 24 | checkpoint["config"].resume = True 25 | # rev_model_exp = copy.deepcopy(experiment) 26 | model_exp.config = checkpoint["config"] 27 | model_exp.experiment_dir = \ 28 | e.config.gen_dir if e.config.gen_dir else e.experiment_dir 29 | model_exp.config.top_p = e.config.top_p 30 | model_exp.config.max_gen_len = e.config.max_gen_len 31 | model_exp.config.min_gen_len = e.config.min_gen_len 32 | for name in dir(e.config): 33 | if name.startswith("__"): 34 | continue 35 | if name not in dir(model_exp.config): 36 | value = getattr(e.config, name) 37 | e.log.info("update {} to {}".format(name, value)) 38 | setattr(model_exp.config, name, value) 39 | 40 | e.log.info("*" * 25 + " DATA PREPARATION " + "*" * 25) 41 | data_processor = data_utils.DataProcessor( 42 | train_path=e.config.train_path, 43 | dev_path=e.config.dev_path, 44 | test_path=e.config.test_path, 45 | wikidata_path=e.config.wikidata_path, 46 | infobox_path=e.config.infobox_path, 47 | bpe_vocab=e.config.bpe_vocab, 48 | bpe_codes=e.config.bpe_codes, 49 | experiment=model_exp) 50 | data = data_processor.process() 51 | 52 | e.log.info("*" * 25 + " DATA PREPARATION " + "*" * 25) 53 | e.log.info("*" * 25 + " MODEL INITIALIZATION " + "*" * 25) 54 | 55 | model = models.BasicCyclicAttnSplitMask( 56 | vocab_size=len(data.vocab), 57 | type_vocab_size=500, 58 | embed_dim=model_exp.config.edim, 59 | 
iter_per_epoch=100, 60 | use_entmax=False, 61 | experiment=model_exp) 62 | 63 | model.load(checkpointed_state_dict=checkpoint["state_dict"]) 64 | 65 | e.log.info(model) 66 | e.log.info("*" * 25 + " MODEL INITIALIZATION " + "*" * 25) 67 | 68 | dev_eval = train_helper.Evaluator( 69 | model=model, 70 | data=data.dev_data, 71 | inv_vocab=data.inv_vocab, 72 | vocab=data.vocab, 73 | eval_batch_size=e.config.eval_batch_size, 74 | return_wikidata=model_exp.config.return_wikidata, 75 | return_hyperlink=model_exp.config.return_hyperlink, 76 | input_wikidata=e.config.input_wikidata, 77 | input_hyperlink=e.config.input_hyperlink, 78 | experiment=model_exp) 79 | 80 | if not os.path.isdir(model_exp.experiment_dir): 81 | print("make dirs", model_exp.experiment_dir) 82 | os.makedirs(model_exp.experiment_dir) 83 | dev_eval.evaluate(e.config.gen_prefix) 84 | 85 | 86 | if __name__ == '__main__': 87 | 88 | PARSED_CONFIG = config.get_base_parser().parse_args() 89 | 90 | def exit_handler(*args): 91 | print(PARSED_CONFIG) 92 | print("best dev bleu: {:.4f}, test bleu: {:.4f}" 93 | .format(BEST_DEV_BLEU, TEST_BLEU)) 94 | sys.exit() 95 | 96 | train_helper.register_exit_handler(exit_handler) 97 | 98 | with train_helper.Experiment(PARSED_CONFIG, 99 | PARSED_CONFIG.save_prefix, 100 | forced_debug=True) as exp: 101 | 102 | exp.log.info("*" * 25 + " ARGS " + "*" * 25) 103 | exp.log.info(PARSED_CONFIG) 104 | exp.log.info("*" * 25 + " ARGS " + "*" * 25) 105 | 106 | run(exp) 107 | -------------------------------------------------------------------------------- /model_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import numpy as np 4 | import torch.nn.functional as F 5 | from torch.optim.lr_scheduler import LambdaLR 6 | from entmax import entmax_bisect 7 | 8 | 9 | def init_embedding(input_embedding): 10 | """ 11 | Initialize embedding 12 | """ 13 | bias = np.sqrt(3.0 / input_embedding.size(1)) 14 | nn.init.uniform_(input_embedding, -bias, bias) 15 | 16 | 17 | def init_linear(input_linear): 18 | """ 19 | Initialize linear transformation 20 | """ 21 | bias = np.sqrt(6.0 / (input_linear.weight.size(0) + input_linear.weight.size(1))) 22 | nn.init.uniform_(input_linear.weight, -bias, bias) 23 | if input_linear.bias is not None: 24 | input_linear.bias.data.zero_() 25 | 26 | 27 | class LabelSmoothingLoss(nn.Module): 28 | def __init__(self, classes, smoothing=0.1, dim=-1): 29 | super(LabelSmoothingLoss, self).__init__() 30 | self.confidence = 1.0 - smoothing 31 | self.smoothing = smoothing 32 | self.cls = classes 33 | self.dim = dim 34 | 35 | def forward(self, pred, target): 36 | pred = pred.log_softmax(dim=self.dim) 37 | with torch.no_grad(): 38 | true_dist = F.one_hot(target.long(), self.cls) 39 | true_dist = true_dist * self.confidence 40 | true_dist = true_dist + self.smoothing / (self.cls - 1) 41 | return torch.sum(-true_dist * pred, dim=self.dim) 42 | 43 | 44 | def get_linear_schedule_with_warmup( 45 | optimizer, num_warmup_steps, num_training_steps, last_epoch=-1): 46 | """ Create a schedule with a learning rate that decreases linearly after 47 | linearly increasing during a warmup period. 
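E.g. (an illustrative sketch; `opt` stands for any torch optimizer and the
step counts are arbitrary):
>>> scheduler = get_linear_schedule_with_warmup(opt, 4000, 100000)
Call scheduler.step() once after every optimizer update.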
48 | """ 49 | def lr_lambda(current_step): 50 | if current_step < num_warmup_steps: 51 | return float(current_step) / float(max(1, num_warmup_steps)) 52 | return max(0.0, float(num_training_steps - current_step) / float(max(1, num_training_steps - num_warmup_steps))) 53 | 54 | return LambdaLR(optimizer, lr_lambda, last_epoch) 55 | 56 | 57 | def top_k_top_p_filtering(logits, top_k=0, top_p=0.0, need_softmax=True, filter_value=-float('Inf')): 58 | """ Filter a distribution of logits using top-k and/or nucleus (top-p) filtering 59 | Args: 60 | logits: logits distribution shape (vocabulary size) 61 | top_k >0: keep only top k tokens with highest probability (top-k filtering). 62 | top_p >0.0: keep the top tokens with cumulative probability >= top_p (nucleus filtering). 63 | Nucleus filtering is described in Holtzman et al. (http://arxiv.org/abs/1904.09751) 64 | """ 65 | # assert logits.dim() == 1 # batch size 1 for now - could be updated for more but the code would be less clear 66 | top_k = min(top_k, logits.size(-1)) # Safety check 67 | if top_k > 0: 68 | # Remove all tokens with a probability less than the last token of the top-k 69 | indices_to_remove = \ 70 | logits < torch.topk(logits, top_k, dim=-1)[0]\ 71 | .min(-1)[0].unsqueeze(-1) 72 | logits[indices_to_remove] = filter_value 73 | 74 | if top_p > 0.0 and top_p < 1.0: 75 | sorted_logits, sorted_indices = \ 76 | torch.sort(logits, dim=-1, descending=True) 77 | if need_softmax: 78 | softmax_sorted_logits = F.softmax(sorted_logits, dim=-1) 79 | else: 80 | softmax_sorted_logits = sorted_logits 81 | cumulative_probs = torch.cumsum(softmax_sorted_logits, dim=-1) 82 | 83 | # Remove tokens with cumulative probability above the threshold 84 | sorted_indices_to_remove = cumulative_probs > top_p 85 | sorted_indices_to_remove[:, 0] = False 86 | # Shift the indices to the right to keep also the first token above the threshold 87 | # sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone() 88 | # sorted_indices_to_remove[..., 0] = 0 89 | 90 | # indices_to_remove = sorted_indices[sorted_indices_to_remove] 91 | # unsorted 92 | sorted_logits[sorted_indices_to_remove] = filter_value 93 | logits = sorted_logits.gather(-1, torch.sort(sorted_indices, descending=False)[1]) 94 | return logits 95 | 96 | 97 | class CopyGenerator(nn.Module): 98 | """An implementation of pointer-generator networks 99 | :cite:`DBLP:journals/corr/SeeLM17`. 100 | These networks consider copying words 101 | directly from the source sequence. 102 | The copy generator is an extended version of the standard 103 | generator that computes three values. 104 | * :math:`p_{softmax}` the standard softmax over `tgt_dict` 105 | * :math:`p(z)` the probability of copying a word from 106 | the source 107 | * :math:`p_{copy}` the probility of copying a particular word. 108 | taken from the attention distribution directly. 109 | The model returns a distribution over the extend dictionary, 110 | computed as 111 | :math:`p(w) = p(z=1) p_{copy}(w) + p(z=0) p_{softmax}(w)` 112 | .. 
mermaid:: 113 | graph BT 114 | A[input] 115 | S[src_map] 116 | B[softmax] 117 | BB[switch] 118 | C[attn] 119 | D[copy] 120 | O[output] 121 | A --> B 122 | A --> BB 123 | S --> D 124 | C --> D 125 | D --> O 126 | B --> O 127 | BB --> O 128 | Args: 129 | input_size (int): size of input representation 130 | output_size (int): size of output vocabulary 131 | pad_idx (int) 132 | based on an implementation from 133 | https://github.com/OpenNMT/OpenNMT-py 134 | """ 135 | 136 | def __init__(self, input_size, use_entmax, pad_idx=0): 137 | super(CopyGenerator, self).__init__() 138 | # self.linear = nn.Linear(input_size, output_size) 139 | self.linear_copy = nn.Linear(input_size, 1) 140 | self.pad_idx = pad_idx 141 | self.use_entmax = use_entmax 142 | 143 | def _bottle(self, _v): 144 | return _v.view(-1, _v.size(2)) 145 | 146 | def _unbottle(self, _v, batch_size): 147 | return _v.view(-1, batch_size, _v.size(1)) 148 | 149 | def forward(self, hidden, orig_prob, attn, src_map): 150 | """ 151 | Compute a distribution over the target dictionary 152 | extended by the dynamic dictionary implied by copying 153 | source words. 154 | Args: 155 | hidden (FloatTensor): hidden outputs ``(batch x tlen, input_size)`` 156 | attn (FloatTensor): attn for each ``(batch x tlen, slen)`` 157 | src_map (FloatTensor): 158 | A sparse indicator matrix mapping each source word to 159 | its index in the "extended" vocab. 160 | ``(src_len, batch, extra_words)`` 161 | """ 162 | 163 | # CHECKS 164 | # batch_by_tlen, _ = hidden.size() 165 | # batch_by_tlen_, slen = attn.size() 166 | onehot_src_map = \ 167 | F.one_hot(src_map.long(), torch.max(src_map).long() + 1) 168 | batch, slen, cvocab = onehot_src_map.size() 169 | 170 | if self.use_entmax: 171 | prob = entmax_bisect(orig_prob, 1.2) 172 | else: 173 | prob = torch.softmax(orig_prob, 1) 174 | 175 | # Probability of copying p(z=1) batch. 176 | p_copy = torch.sigmoid(self.linear_copy(hidden)) 177 | # Probability of not copying: p_{word}(w) * (1 - p(z)) 178 | out_prob = torch.mul(prob, 1 - p_copy) 179 | mul_attn = torch.mul(attn, p_copy) 180 | copy_prob = torch.bmm( 181 | mul_attn.view(batch, -1, slen), # batch size x tgt len x src len 182 | onehot_src_map.float()) # batch size x src len x cvocab 183 | copy_prob = copy_prob.contiguous().view(-1, cvocab) 184 | return out_prob, copy_prob 185 | 186 | 187 | class CopyGeneratorLoss(nn.Module): 188 | """Copy generator criterion. 189 | 190 | based on an implementation from 191 | https://github.com/OpenNMT/OpenNMT-py 192 | """ 193 | def __init__(self, vocab_size, force_copy, unk_index=0, 194 | ignore_index=0, eps=1e-10): 195 | super(CopyGeneratorLoss, self).__init__() 196 | self.force_copy = force_copy 197 | self.eps = eps 198 | self.vocab_size = vocab_size 199 | self.ignore_index = ignore_index 200 | self.unk_index = unk_index 201 | 202 | def _bottle(self, _v): 203 | return _v.view(-1, _v.size(2)) 204 | 205 | def _unbottle(self, _v, batch_size): 206 | return _v.view(-1, batch_size, _v.size(1)) 207 | 208 | def forward(self, scores, align, target, src_tgt_map, label_smoothing): 209 | """ 210 | Args: 211 | scores (FloatTensor): ``(batch_size*tgt_len)`` x dynamic vocab size 212 | whose sum along dim 1 is less than or equal to 1, i.e. cols 213 | softmaxed.
214 | src_tgt_map: batch size x extended vocab size 215 | ([b, src vocab idx] = tgt vocab idx) 216 | align (LongTensor): ``(batch_size x tgt_len)`` 217 | target (LongTensor): ``(batch_size x tgt_len)`` 218 | """ 219 | bs, sqlen = align.shape 220 | flat_align = align.reshape(-1) 221 | flat_target = target.reshape(-1) 222 | 223 | if label_smoothing: 224 | out_prob, copy_prob = scores 225 | 226 | scores, copy_mask = collapse_copy_scores( 227 | torch.cat([out_prob, copy_prob], 1), 228 | src_tgt_map, self.vocab_size) 229 | 230 | label_mask = copy_mask 231 | 232 | confidence = 1 - label_smoothing 233 | smoothing = label_smoothing / label_mask.sum(1, keepdim=True) 234 | 235 | tgt_labels = torch.zeros_like(scores) 236 | copy_labels = torch.zeros_like(scores) 237 | 238 | tgt_labels.scatter_(1, flat_target.unsqueeze(1).long(), 1) 239 | 240 | copy_ix = flat_align.unsqueeze(1) + self.vocab_size 241 | copy_labels.scatter_(1, copy_ix.long(), 1) 242 | non_copy = flat_align == self.unk_index 243 | if not self.force_copy: 244 | non_copy = non_copy | (flat_target != self.unk_index) 245 | 246 | final_labels = torch.where( 247 | non_copy.unsqueeze(1), tgt_labels, copy_labels 248 | ) 249 | 250 | final_labels = final_labels * (confidence - smoothing) + smoothing 251 | final_labels = final_labels * label_mask 252 | 253 | # final_labels = final_labels * label_mask 254 | loss = torch.sum(- (scores + self.eps).log() * final_labels, dim=1) 255 | else: 256 | scores = torch.cat(scores, 1) 257 | # probabilities assigned by the model to the gold targets 258 | vocab_probs = scores.gather( 259 | 1, flat_target.unsqueeze(1).long()).squeeze(1) 260 | 261 | # probability of tokens copied from source 262 | copy_ix = flat_align.unsqueeze(1) + self.vocab_size 263 | copy_tok_probs = scores.gather(1, copy_ix.long()).squeeze(1) 264 | # Set scores for unk to 0 and add eps 265 | copy_tok_probs[flat_align == self.unk_index] = 0 266 | copy_tok_probs = copy_tok_probs + self.eps # to avoid -inf logs 267 | 268 | # find the indices in which you do not use the copy mechanism 269 | non_copy = flat_align == self.unk_index 270 | if not self.force_copy: 271 | non_copy = non_copy | (flat_target != self.unk_index) 272 | 273 | probs = torch.where( 274 | non_copy, copy_tok_probs + vocab_probs, copy_tok_probs 275 | ) 276 | # just NLLLoss; can the module be incorporated? 277 | loss = -(probs + self.eps).log() 278 | # Drop padding. 279 | loss[flat_target == self.ignore_index] = 0 280 | return loss 281 | 282 | 283 | class CopyGeneratorLossCompute(nn.Module): 284 | """Copy Generator Loss Computation. 285 | 286 | based on an implementation from 287 | https://github.com/OpenNMT/OpenNMT-py 288 | """ 289 | def __init__(self, criterion, generator, tgt_vocab, normalize_by_length): 290 | super(CopyGeneratorLossCompute, self).__init__() 291 | self.criterion = criterion 292 | self.generator = generator 293 | self.tgt_vocab = tgt_vocab 294 | self.normalize_by_length = normalize_by_length 295 | 296 | def _bottle(self, _v): 297 | return _v.view(-1, _v.size(2)) 298 | 299 | def _unbottle(self, _v, batch_size): 300 | return _v.view(-1, batch_size, _v.size(1)) 301 | 302 | def compute_loss(self, batch, output, target, copy_attn, align, 303 | std_attn=None, coverage_attn=None): 304 | """Compute the loss. 305 | The args must match :func:`self._make_shard_state()`. 306 | Args: 307 | batch: the current batch. 308 | output: the predict output from the model. 309 | target: the validate target to compare output with. 310 | copy_attn: the copy attention value. 
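        The criterion above looks the gold token up in two regions of the
        dynamic vocabulary: the first ``vocab_size`` columns hold generation
        probabilities and column ``vocab_size + j`` holds the probability of
        copying source slot ``j``. A toy sketch (made-up numbers; illustrative
        only)::

            import torch

            vocab_size = 5
            # 5 generation columns followed by 3 copy columns
            scores = torch.tensor([[0.05, 0.30, 0.05, 0.10, 0.10, 0.25, 0.10, 0.05]])
            target = torch.tensor([1])   # gold id in the target vocab
            align = torch.tensor([0])    # gold slot in the source vocab (0 = unk, no copy)
            vocab_probs = scores.gather(1, target.unsqueeze(1)).squeeze(1)               # 0.30
            copy_probs = scores.gather(1, (align + vocab_size).unsqueeze(1)).squeeze(1)  # 0.25
            # with align == unk_index the copy probability is zeroed, so the
            # loss falls back to the generation probability alone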
311 |                 align: the align info.
312 |         """
313 |         target = target.view(-1)
314 |         align = align.view(-1)
315 |         scores = self.generator(
316 |             self._bottle(output), self._bottle(copy_attn), batch.src_map
317 |         )
318 |         loss = self.criterion(scores, align, target)
319 | 
320 |         # this part looks like it belongs in CopyGeneratorLoss
321 |         if self.normalize_by_length:
322 |             # Compute Loss as NLL divided by seq length
323 |             tgt_lens = batch.tgt[:, :, 0].ne(self.padding_idx).sum(0).float()
324 |             # Compute Total Loss per sequence in batch
325 |             loss = loss.view(-1, batch.batch_size).sum(0)
326 |             # Divide by length of each sequence and sum
327 |             loss = torch.div(loss, tgt_lens).sum()
328 |         else:
329 |             loss = loss.sum()
330 | 
331 |         return loss
332 | 
333 | 
334 | def collapse_copy_scores(
335 |         scores, src_tgt_vocab_map, vocab_size,
336 |         keep_src_vocab_unk=True):
337 |     """
338 |     Given scores from an expanded dictionary
339 |     corresponding to a batch, sums together copies
40 |     with a dictionary word when it is ambiguous.
341 | 
342 |     src_tgt_vocab_map: batch size x src tgt vocab map size
343 |     scores: (batch size * seq len) x dynamic vocab size
344 | 
345 |     based on an implementation from
346 |     https://github.com/OpenNMT/OpenNMT-py
347 |     """
348 |     batch_size = src_tgt_vocab_map.shape[0]
349 |     batch_size_by_seq_len = scores.shape[0]
350 |     assert batch_size_by_seq_len % batch_size == 0, \
351 |         batch_size_by_seq_len % batch_size
352 | 
353 |     seq_len = batch_size_by_seq_len // batch_size
354 |     offset = vocab_size
355 | 
356 |     fill = src_tgt_vocab_map[:, 1:].unsqueeze(1)\
357 |         .expand(-1, seq_len, -1).reshape(batch_size * seq_len, -1)
358 |     pad = torch.ones(batch_size_by_seq_len,
359 |                      scores.shape[1] - fill.shape[1]).to(fill.device)
360 |     padded_fill = torch.cat([pad, fill], 1)
361 |     scores[padded_fill == -1] = 0
362 | 
363 |     non_neg_src_tgt_vocab_map = src_tgt_vocab_map.clone()
364 |     non_neg_src_tgt_vocab_map[non_neg_src_tgt_vocab_map == -1] = 0
365 | 
366 |     blank = (offset + torch.arange(1, non_neg_src_tgt_vocab_map.shape[1])
367 |              .unsqueeze(0).expand(batch_size_by_seq_len, -1)).long()
368 |     blank = blank.to(scores.device)
369 |     fill = non_neg_src_tgt_vocab_map[:, 1:].long().unsqueeze(1)\
370 |         .expand(-1, seq_len, -1).reshape(batch_size * seq_len, -1)
371 | 
372 |     add_scores = torch.zeros_like(scores)
373 |     indexed_scores = scores.gather(1, blank)
374 |     add_scores.scatter_(1, fill, indexed_scores)
375 |     if keep_src_vocab_unk:
376 |         add_scores[:, 0] = 0
377 |     scores = scores + add_scores
378 | 
379 |     scores_mask = torch.ones_like(scores)
380 |     scores_mask.scatter_(1, blank, 0.0)
381 | 
382 |     if keep_src_vocab_unk:
383 |         scores_mask[padded_fill == 0] = 1
384 |     scores = scores * scores_mask
385 | 
386 |     return scores, scores_mask
387 | 
-------------------------------------------------------------------------------- /models.py: --------------------------------------------------------------------------------
1 | import os
2 | 
3 | import torch
4 | import torch.nn.functional as F
5 | import torch.nn as nn
6 | 
7 | import model_utils
8 | import encoders
9 | import decoders
10 | 
11 | from decorators import auto_init_args, auto_init_pytorch
12 | 
13 | 
14 | class Base(nn.Module):
15 |     def __init__(self, iter_per_epoch, experiment):
16 |         super(Base, self).__init__()
17 |         self.expe = experiment
18 |         self.iter_per_epoch = iter_per_epoch
19 |         self.eps = self.expe.config.eps
20 |         self.expe.log.info("use_entmax: {}"
21 |                            .format(self.expe.config.use_entmax))
22 |         if torch.cuda.is_available():
23 |             self.device = torch.device('cuda')
24 | else: 25 | self.device = torch.device('cpu') 26 | 27 | def init_weights(self, module): 28 | """ Initialize the weights """ 29 | if isinstance(module, (nn.Linear, nn.Embedding)): 30 | # Slightly different from the TF version 31 | # which uses truncated_normal for initialization 32 | # cf https://github.com/pytorch/pytorch/pull/5617 33 | module.weight.data.normal_(mean=0.0, std=0.02) 34 | if isinstance(module, nn.Linear) and module.bias is not None: 35 | module.bias.data.zero_() 36 | 37 | def to_tensor(self, inputs): 38 | if torch.is_tensor(inputs): 39 | return inputs.clone().detach().to(self.device) 40 | else: 41 | return torch.tensor(inputs, device=self.device) 42 | 43 | def to_tensors(self, *inputs): 44 | return [self.to_tensor(inputs_) if inputs_ is not None and inputs_.size 45 | else None for inputs_ in inputs] 46 | 47 | def count_trainable_parameters(self): 48 | return sum(p.numel() for p in self.parameters() if p.requires_grad) 49 | 50 | def count_all_parameters(self): 51 | return sum(p.numel() for p in self.parameters()) 52 | 53 | def optimize(self, loss, update_param): 54 | loss.backward() 55 | if update_param: 56 | if self.expe.config.gclip is not None: 57 | torch.nn.utils.clip_grad_norm_( 58 | self.parameters(), self.expe.config.gclip) 59 | self.opt.step() 60 | if self.expe.config.wstep: 61 | self.scheduler.step() 62 | self.opt.zero_grad() 63 | 64 | def init_optimizer(self, opt_type, learning_rate, weight_decay): 65 | if opt_type.lower() == "adam": 66 | optimizer = torch.optim.Adam 67 | elif opt_type.lower() == "rmsprop": 68 | optimizer = torch.optim.RMSprop 69 | elif opt_type.lower() == "sgd": 70 | optimizer = torch.optim.SGD 71 | else: 72 | raise NotImplementedError("invalid optimizer: {}".format(opt_type)) 73 | 74 | opt = optimizer( 75 | params=filter( 76 | lambda p: p.requires_grad, self.parameters() 77 | ), 78 | weight_decay=weight_decay, 79 | lr=learning_rate) 80 | 81 | if self.expe.config.wstep: 82 | self.scheduler = \ 83 | model_utils.get_linear_schedule_with_warmup( 84 | opt, self.expe.config.wstep, 85 | self.expe.config.n_epoch * self.iter_per_epoch) 86 | self.expe.log.info( 87 | "training with learning rate scheduler - " 88 | "iterations per epoch: {}, total epochs: {}" 89 | .format(self.iter_per_epoch, self.expe.config.n_epoch)) 90 | return opt 91 | 92 | def save(self, dev_bleu, test_bleu, epoch, iteration=None, name="best"): 93 | save_path = os.path.join(self.expe.experiment_dir, name + ".ckpt") 94 | checkpoint = { 95 | "dev_bleu": dev_bleu, 96 | "test_bleu": test_bleu, 97 | "epoch": epoch, 98 | "iteration": iteration, 99 | "state_dict": self.state_dict(), 100 | "opt_state_dict": self.opt.state_dict(), 101 | "config": self.expe.config 102 | } 103 | if self.expe.config.wstep: 104 | checkpoint["lr_scheduler_state_dict"] = self.scheduler.state_dict() 105 | torch.save(checkpoint, save_path) 106 | self.expe.log.info("model saved to {}".format(save_path)) 107 | 108 | def load(self, checkpointed_state_dict=None, name="best", path=None): 109 | if checkpointed_state_dict is None: 110 | base_path = self.expe.experiment_dir if path is None else path 111 | save_path = os.path.join(base_path, name + ".ckpt") 112 | checkpoint = torch.load(save_path, 113 | map_location=lambda storage, loc: storage) 114 | self.load_state_dict(checkpoint['state_dict']) 115 | self.opt.load_state_dict(checkpoint.get("opt_state_dict")) 116 | if self.expe.config.wstep: 117 | self.scheduler.load_state_dict( 118 | checkpoint["lr_scheduler_state_dict"]) 119 | self.expe.log.info("model loaded from 
{}".format(save_path)) 120 | self.to(self.device) 121 | for state in self.opt.state.values(): 122 | for k, v in state.items(): 123 | if isinstance(v, torch.Tensor): 124 | state[k] = v.to(self.device) 125 | self.expe.log.info("transferred model to {}".format(self.device)) 126 | return checkpoint.get('epoch', 0), \ 127 | checkpoint.get('iteration', 0), \ 128 | checkpoint.get('dev_bleu', 0), \ 129 | checkpoint.get('test_bleu', 0) 130 | else: 131 | self.load_state_dict(checkpointed_state_dict) 132 | self.expe.log.info("model loaded from checkpoint.") 133 | self.to(self.device) 134 | self.expe.log.info("transferred model to {}".format(self.device)) 135 | 136 | 137 | class BasicCyclicAttnSplitMask(Base): 138 | @auto_init_pytorch 139 | @auto_init_args 140 | def __init__( 141 | self, vocab_size, type_vocab_size, 142 | embed_dim, iter_per_epoch, use_entmax, experiment): 143 | super(BasicCyclicAttnSplitMask, self).__init__( 144 | iter_per_epoch, experiment) 145 | self.encode = getattr(encoders, self.expe.config.encoder_type)( 146 | embed_dim=embed_dim, 147 | nlayer=self.expe.config.elayer, 148 | nhead=self.expe.config.nhead, 149 | hidden_size=self.expe.config.ensize, 150 | dropout=self.expe.config.dp, 151 | vocab_size=vocab_size, 152 | type_vocab_size=type_vocab_size, 153 | max_pos_len=self.expe.config.max_encoder_len, 154 | act_fn=self.expe.config.act_fn) 155 | 156 | self.decode = getattr(decoders, self.expe.config.decoder_type)( 157 | embed_dim=embed_dim, 158 | encoder_size=self.expe.config.ensize, 159 | decoder_size=self.expe.config.desize, 160 | dropout=self.expe.config.dp, 161 | nlayer=self.expe.config.dlayer, 162 | nhead=self.expe.config.nhead, 163 | type_vocab_size=type_vocab_size, 164 | vocab_size=vocab_size, 165 | max_len=self.expe.config.max_decoder_len, 166 | act_fn=self.expe.config.act_fn, 167 | use_copy=self.expe.config.use_copy, 168 | use_entmax=use_entmax, 169 | share_embedding=self.expe.config.share_decoder_embedding) 170 | 171 | self.cyclic_encode = encoders.ie_transformer( 172 | embed_dim=embed_dim, 173 | nlayer=self.expe.config.bwdelayer, 174 | nhead=self.expe.config.bwdnhead, 175 | hidden_size=self.expe.config.ensize, 176 | dropout=self.expe.config.dp, 177 | vocab_size=vocab_size, 178 | type_vocab_size=type_vocab_size, 179 | max_len=self.expe.config.max_decoder_len, 180 | act_fn=self.expe.config.act_fn) 181 | 182 | self.cyclic_decode = decoders.ie_mask_transformer( 183 | embed_dim=embed_dim, 184 | encoder_size=self.expe.config.ensize, 185 | decoder_size=self.expe.config.desize, 186 | attn_size=self.expe.config.asize, 187 | dropout=self.expe.config.dp, 188 | nlayer=self.expe.config.bwddlayer, 189 | nhead=self.expe.config.bwdnhead, 190 | max_len=self.expe.config.max_encoder_len, 191 | vocab_size=vocab_size, 192 | act_fn=self.expe.config.act_fn) 193 | 194 | if self.expe.config.use_copy: 195 | self.loss_fn = model_utils.CopyGeneratorLoss( 196 | vocab_size=vocab_size, 197 | force_copy=self.expe.config.force_copy) 198 | 199 | def forward( 200 | self, data, data_mask, data_pos, data_type, data_if_hyp, 201 | data_src_vocab, data_src_tgt_vocab_map, 202 | tgt_input_data, tgt_input_data_mask, tgt_input_data_pos, 203 | tgt_input_data_type, tgt_input_data_if_hyp, 204 | tgt_output_data, tgt_output_mask, 205 | tgt_input, tgt_label, tgt_mask, tgt_src_vocab): 206 | 207 | data, data_mask, data_pos, data_type, data_if_hyp, \ 208 | data_src_vocab, data_src_tgt_vocab_map, \ 209 | tgt_input_data, tgt_input_data_mask, tgt_input_data_pos, \ 210 | tgt_input_data_type, tgt_input_if_hyp, \ 211 | 
tgt_output_data, tgt_output_mask, \ 212 | tgt_input, tgt_label, tgt_mask, tgt_src_vocab = \ 213 | self.to_tensors(data, data_mask, data_pos, data_type, 214 | data_if_hyp, data_src_vocab, 215 | data_src_tgt_vocab_map, 216 | tgt_input_data, tgt_input_data_mask, 217 | tgt_input_data_pos, tgt_input_data_type, 218 | tgt_input_data_if_hyp, 219 | tgt_output_data, tgt_output_mask, 220 | tgt_input, tgt_label, tgt_mask, tgt_src_vocab) 221 | data_vec = self.encode( 222 | data, data_mask, data_pos, data_type, data_if_hyp) 223 | 224 | pred_probs, _ = self.decode( 225 | encoder_outputs=data_vec, 226 | encoder_mask=data_mask, 227 | tgts=tgt_input, 228 | src_map=data_src_vocab) 229 | 230 | if self.expe.config.use_copy: 231 | loss = self.loss_fn( 232 | scores=pred_probs, 233 | align=tgt_src_vocab, 234 | target=tgt_label, 235 | src_tgt_map=data_src_tgt_vocab_map, 236 | label_smoothing=self.expe.config.lm) 237 | batch_size, seq_len = tgt_mask.shape 238 | flat_tgt_mask = tgt_mask.reshape(-1) 239 | loss = loss * flat_tgt_mask 240 | gloss = loss.reshape(batch_size, seq_len).sum(1) / \ 241 | (tgt_mask.reshape(batch_size, seq_len)).sum(1) 242 | 243 | elif self.expe.config.lm: 244 | loss_fn = model_utils.LabelSmoothingLoss( 245 | classes=self.vocab_size, 246 | smoothing=self.expe.config.lm, 247 | dim=-1) 248 | loss = loss_fn(pred_probs, tgt_label.long()) 249 | loss = loss * tgt_mask 250 | gloss = loss.sum(1) / tgt_mask.sum(1) 251 | else: 252 | batch_size, seq_len, vocab_size = pred_probs.shape 253 | flat_tgt_mask = tgt_mask.reshape(-1) 254 | flat_pred_probs = pred_probs.reshape( 255 | batch_size * seq_len, vocab_size) 256 | tgt = tgt_label.reshape(-1) 257 | gloss = F.cross_entropy( 258 | flat_pred_probs, tgt.long(), 259 | reduction="none") 260 | gloss = gloss * flat_tgt_mask 261 | gloss = gloss.reshape(batch_size, seq_len).sum(1) / tgt_mask.sum(1) 262 | gloss = gloss.mean(0) 263 | 264 | if self.expe.config.floss: 265 | if self.expe.config.use_copy: 266 | floss_tgt_vec = self.cyclic_encode.softmax_forward( 267 | model_utils.collapse_copy_scores( 268 | scores=torch.cat(pred_probs, 1), 269 | src_tgt_vocab_map=data_src_tgt_vocab_map, 270 | vocab_size=self.vocab_size, 271 | keep_src_vocab_unk=False)[0] 272 | [:, :self.vocab_size].reshape( 273 | batch_size, seq_len, self.vocab_size), 274 | tgt_mask) 275 | else: 276 | floss_tgt_vec = self.cyclic_encode.mix_forward( 277 | pred_probs, tgt_mask) 278 | 279 | floss_pred_probs = self.cyclic_decode( 280 | encoder_outputs=floss_tgt_vec, 281 | encoder_mask=tgt_mask, 282 | inp=tgt_input_data, 283 | inp_mask=tgt_input_data_mask, 284 | inp_pos=tgt_input_data_pos, 285 | inp_type=tgt_input_data_type, 286 | inp_if_hyp=tgt_input_if_hyp) 287 | 288 | batch_size, seq_len, vocab_size = floss_pred_probs.shape 289 | if self.expe.config.bwd_lm: 290 | loss_fn = model_utils.LabelSmoothingLoss( 291 | classes=self.vocab_size, 292 | smoothing=self.expe.config.bwd_lm, 293 | dim=-1) 294 | loss = loss_fn(floss_pred_probs, tgt_output_data.long()) 295 | loss = loss * tgt_output_mask 296 | floss = loss.sum(1) / tgt_output_mask.sum(1) 297 | else: 298 | floss_tgt_mask = tgt_output_mask.reshape(-1) 299 | floss_pred_probs = floss_pred_probs.reshape( 300 | batch_size * seq_len, vocab_size) 301 | floss_tgt = tgt_output_data.reshape(-1) 302 | floss = F.cross_entropy( 303 | floss_pred_probs, floss_tgt.long(), 304 | reduction="none") 305 | floss = floss * floss_tgt_mask 306 | floss = floss.reshape(batch_size, seq_len).sum(1) / \ 307 | (floss_tgt_mask.reshape(batch_size, seq_len)).sum(1) 308 | floss = 
floss.mean(0) 309 | else: 310 | floss = torch.zeros_like(gloss) 311 | 312 | if self.expe.config.tloss: 313 | tloss_tgt_vec = self.cyclic_encode(tgt_input, tgt_mask) 314 | 315 | tloss_pred_probs = self.cyclic_decode( 316 | encoder_outputs=tloss_tgt_vec, 317 | encoder_mask=tgt_mask, 318 | inp=tgt_input_data, 319 | inp_mask=tgt_input_data_mask, 320 | inp_pos=tgt_input_data_pos, 321 | inp_type=tgt_input_data_type, 322 | inp_if_hyp=tgt_input_if_hyp) 323 | 324 | batch_size, seq_len, vocab_size = tloss_pred_probs.shape 325 | if self.expe.config.bwd_lm: 326 | loss_fn = model_utils.LabelSmoothingLoss( 327 | classes=self.vocab_size, 328 | smoothing=self.expe.config.bwd_lm, dim=-1) 329 | loss = loss_fn(tloss_pred_probs, tgt_output_data.long()) 330 | loss = loss * tgt_output_mask 331 | tloss = loss.sum(1) / tgt_output_mask.sum(1) 332 | else: 333 | tloss_tgt_mask = tgt_output_mask.reshape(-1) 334 | tloss_pred_probs = tloss_pred_probs.reshape( 335 | batch_size * seq_len, vocab_size) 336 | tloss_tgt = tgt_output_data.reshape(-1) 337 | tloss = F.cross_entropy( 338 | tloss_pred_probs, tloss_tgt.long(), 339 | reduction="none") 340 | tloss = tloss * tloss_tgt_mask 341 | tloss = tloss.reshape(batch_size, seq_len).sum(1) / \ 342 | (tloss_tgt_mask.reshape(batch_size, seq_len)).sum(1) 343 | tloss = tloss.mean(0) 344 | else: 345 | tloss = torch.zeros_like(gloss) 346 | loss = gloss + self.expe.config.floss * floss + \ 347 | self.expe.config.tloss * tloss 348 | return loss, gloss, floss, tloss 349 | 350 | def greedy_decode( 351 | self, data, data_mask, data_pos, 352 | data_type, data_if_hyp, data_src_vocab, 353 | data_src_tgt_vocab_map, 354 | max_len, min_len, top_p, top_k): 355 | self.eval() 356 | data, data_mask, data_pos, data_type, \ 357 | data_if_hyp, data_src_vocab, \ 358 | data_src_tgt_vocab_map = \ 359 | self.to_tensors( 360 | data, data_mask, data_pos, 361 | data_type, data_if_hyp, data_src_vocab, data_src_tgt_vocab_map) 362 | data_vec = self.encode( 363 | data, data_mask, data_pos, data_type, data_if_hyp) 364 | 365 | return self.decode.generate( 366 | encoder_outputs=data_vec, 367 | encoder_mask=data_mask, 368 | max_len=max_len, 369 | min_len=min_len, 370 | top_p=top_p, 371 | top_k=top_k, 372 | src_map=data_src_vocab, 373 | src_tgt_vocab_map=data_src_tgt_vocab_map) 374 | -------------------------------------------------------------------------------- /scripts/convert2parent_dev.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | python convert_to_parent_reference.py \ 4 | dev-trim-shuf.json \ 5 | infobox.json.devtest \ 6 | wikidata.json.devtest \ 7 | parent_dev.txt 8 | -------------------------------------------------------------------------------- /scripts/eval_dev.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | python eval_rouge_meteor_rep.py --references $1 --generations $2 4 | python eval_parent.py --tables parent_references/parent_dev.txt --references $1 --generations $2 5 | -------------------------------------------------------------------------------- /scripts/generate_beam_search.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | python generate_beam_search.py \ 4 | --gen_dir generations \ 5 | --gen_prefix gen_dev_beam_search_size5 \ 6 | --model_file large+copy+cyc.ckpt \ 7 | --train_path dev-trim-shuf.json \ 8 | --dev_path dev-trim-shuf.json \ 9 | --wikidata_path wikidata.json.devtest \ 10 | --infobox_path 
infobox.json.devtest \ 11 | --bpe_vocab cased_30k.vocab \ 12 | --bpe_codes cased_30k.codes \ 13 | --eval_batch_size 10 \ 14 | --max_gen_len 300 \ 15 | --min_gen_len 100 \ 16 | --beam_size 5 17 | -------------------------------------------------------------------------------- /scripts/train_large_copy_cyc.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | python train.py \ 4 | --debug 0 \ 5 | --auto_disconnect 1 \ 6 | --save_prefix large_copy_cyc \ 7 | --decoder_type ctransformer \ 8 | --encoder_type transformer \ 9 | --n_epoch 5 \ 10 | --train_path train-trim-shuf.json \ 11 | --dev_path dev-trim-shuf.json \ 12 | --wikidata_path wikidata.json \ 13 | --infobox_path infobox.json \ 14 | --bpe_vocab cased_30k.vocab \ 15 | --bpe_codes cased_30k.codes \ 16 | --batch_size 4 \ 17 | --max_num_value 10 \ 18 | --eval_batch_size 20 \ 19 | --use_copy 1 \ 20 | --max_gen_len 150 \ 21 | --min_gen_len 50 \ 22 | --warmup_steps 2000 \ 23 | --learning_rate 3e-4 \ 24 | --gradient_accumulation_steps 50 \ 25 | --encoder_num_layer 1 \ 26 | --decoder_num_layer 12 \ 27 | --num_head 8 \ 28 | --bwd_num_head 8 \ 29 | --true_cyclic_loss 1.0 \ 30 | --fake_cyclic_loss 0.1 \ 31 | --dropout 0.1 \ 32 | --l2 0.0 \ 33 | --input_hyperlink 1 \ 34 | --input_wikidata 1 \ 35 | --embed_dim 1024 \ 36 | --encoder_size 1024 \ 37 | --decoder_size 1024 \ 38 | --print_every 100 \ 39 | --save_every 1000 \ 40 | --eval_every 2000 41 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import sys 2 | 3 | import train_helper 4 | import data_utils 5 | import config 6 | 7 | import models 8 | 9 | BEST_DEV_BLEU = TEST_BLEU = 0 10 | 11 | 12 | def run(e): 13 | global BEST_DEV_BLEU, TEST_BLEU 14 | 15 | e.log.info("*" * 25 + " DATA PREPARATION " + "*" * 25) 16 | data_processor = data_utils.DataProcessor( 17 | train_path=e.config.train_path, 18 | dev_path=e.config.dev_path, 19 | test_path=e.config.test_path, 20 | wikidata_path=e.config.wikidata_path, 21 | infobox_path=e.config.infobox_path, 22 | bpe_vocab=e.config.bpe_vocab, 23 | bpe_codes=e.config.bpe_codes, 24 | experiment=e) 25 | data = data_processor.process() 26 | 27 | train_batch = data_utils.Minibatcher( 28 | data=data.train_data, 29 | batch_size=e.config.batch_size, 30 | save_dir=e.experiment_dir, 31 | filename="minibatcher.ckpt", 32 | log=e.log, 33 | is_eval=False, 34 | vocab_size=len(data.vocab), 35 | vocab=data.vocab, 36 | return_wikidata=e.config.return_wikidata, 37 | return_hyperlink=e.config.return_hyperlink, 38 | input_wikidata=e.config.input_wikidata, 39 | input_hyperlink=e.config.input_hyperlink, 40 | verbose=True) 41 | 42 | e.log.info("*" * 25 + " DATA PREPARATION " + "*" * 25) 43 | e.log.info("*" * 25 + " MODEL INITIALIZATION " + "*" * 25) 44 | 45 | model = models.BasicCyclicAttnSplitMask( 46 | vocab_size=len(data.vocab), 47 | type_vocab_size=500, 48 | embed_dim=e.config.edim, 49 | iter_per_epoch=len(train_batch.idx_pool) // e.config.gcs, 50 | use_entmax=False, 51 | experiment=e) 52 | 53 | start_epoch = true_it = 0 54 | if e.config.resume: 55 | start_epoch, _, BEST_DEV_BLEU, TEST_BLEU = \ 56 | model.load(name="latest") 57 | e.log.info( 58 | "resumed from previous checkpoint: start epoch: {}, " 59 | "iteration: {}, best dev bleu: {:.3f}, test bleu: {:.3f}." 
60 | .format(start_epoch, true_it, BEST_DEV_BLEU, TEST_BLEU)) 61 | 62 | e.log.info(model) 63 | e.log.info("*" * 25 + " MODEL INITIALIZATION " + "*" * 25) 64 | 65 | dev_eval = train_helper.Evaluator( 66 | model=model, 67 | data=data.dev_data, 68 | inv_vocab=data.inv_vocab, 69 | vocab=data.vocab, 70 | eval_batch_size=e.config.eval_batch_size, 71 | return_wikidata=e.config.return_wikidata, 72 | return_hyperlink=e.config.return_hyperlink, 73 | input_wikidata=e.config.input_wikidata, 74 | input_hyperlink=e.config.input_hyperlink, 75 | experiment=e) 76 | 77 | e.log.info("Training start ...") 78 | train_stats = train_helper.Tracker( 79 | ["loss", "gen_loss", "cyclic_loss", "reference_loss"]) 80 | 81 | for epoch in range(start_epoch, e.config.n_epoch): 82 | for it, (input_data, input_data_mask, input_data_pos, 83 | input_data_type, input_if_hyp, 84 | input_data_src_vocab, input_data_src_tgt_vocab_map, 85 | tgt_inp_data, tgt_inp_data_mask, 86 | tgt_inp_data_pos, tgt_inp_data_type, tgt_inp_data_if_hyp, 87 | tgt_out_data, tgt_out_data_mask, 88 | tgt_input, tgt_label, tgt_mask, tgt_src_vocab, _) in \ 89 | enumerate(train_batch): 90 | model.train() 91 | curr_it = train_batch.init_pointer + it + \ 92 | 1 + epoch * len(train_batch.idx_pool) 93 | true_it = curr_it // e.config.gcs 94 | full_division = ((curr_it % e.config.gcs) == 0) or \ 95 | (curr_it % len(train_batch.idx_pool) == 0) 96 | 97 | loss, gloss, floss, tloss = \ 98 | model(input_data, input_data_mask, input_data_pos, 99 | input_data_type, input_if_hyp, 100 | input_data_src_vocab, input_data_src_tgt_vocab_map, 101 | tgt_inp_data, tgt_inp_data_mask, 102 | tgt_inp_data_pos, tgt_inp_data_type, tgt_inp_data_if_hyp, 103 | tgt_out_data, tgt_out_data_mask, 104 | tgt_input, tgt_label, tgt_mask, tgt_src_vocab) 105 | model.optimize(loss / e.config.gcs, update_param=full_division) 106 | train_stats.update( 107 | {"loss": loss, "gen_loss": gloss, "cyclic_loss": floss, 108 | "reference_loss": tloss}, len(input_data)) 109 | 110 | if e.config.auto_disconnect and full_division: 111 | if e.elapsed_time > 3.5: 112 | e.log.info("elapsed time: {:.3}(h), " 113 | "automatically exiting the program..." 
114 | .format(e.elapsed_time)) 115 | train_batch.save() 116 | model.save( 117 | dev_bleu=BEST_DEV_BLEU, 118 | test_bleu=TEST_BLEU, 119 | iteration=true_it, 120 | epoch=epoch, 121 | name="latest") 122 | sys.exit() 123 | if ((true_it + 1) % e.config.print_every == 0 or 124 | (curr_it + 1) % len(train_batch.idx_pool) == 0) \ 125 | and full_division: 126 | curr_lr = model.scheduler.get_last_lr()[0] if e.config.wstep \ 127 | else e.config.lr 128 | summarization = train_stats.summarize( 129 | "epoch: {}, it: {} (max: {}), lr: {:.4e}" 130 | .format(epoch, it, len(train_batch), curr_lr)) 131 | e.log.info(summarization) 132 | train_stats.reset() 133 | 134 | if ((true_it + 1) % e.config.eval_every == 0 or 135 | curr_it % len(train_batch.idx_pool) == 0) \ 136 | and full_division: 137 | 138 | train_batch.save() 139 | model.save( 140 | dev_bleu=BEST_DEV_BLEU, 141 | test_bleu=TEST_BLEU, 142 | iteration=true_it, 143 | epoch=epoch, 144 | name="latest") 145 | 146 | e.log.info("*" * 25 + " DEV SET EVALUATION " + "*" * 25) 147 | 148 | dev_bleu = dev_eval.evaluate("gen_dev") 149 | 150 | e.log.info("*" * 25 + " DEV SET EVALUATION " + "*" * 25) 151 | 152 | if BEST_DEV_BLEU < dev_bleu: 153 | BEST_DEV_BLEU = dev_bleu 154 | 155 | model.save( 156 | dev_bleu=BEST_DEV_BLEU, 157 | test_bleu=TEST_BLEU, 158 | iteration=true_it, 159 | epoch=epoch) 160 | e.log.info("best dev bleu: {:.4f}, test bleu: {:.4f}" 161 | .format(BEST_DEV_BLEU, TEST_BLEU)) 162 | train_stats.reset() 163 | if ((true_it + 1) % e.config.save_every == 0 or \ 164 | curr_it % len(train_batch.idx_pool) == 0) \ 165 | and full_division: 166 | train_batch.save() 167 | model.save( 168 | dev_bleu=BEST_DEV_BLEU, 169 | test_bleu=TEST_BLEU, 170 | iteration=true_it, 171 | epoch=epoch, 172 | name="latest") 173 | 174 | train_batch.save() 175 | model.save( 176 | dev_bleu=BEST_DEV_BLEU, 177 | test_bleu=TEST_BLEU, 178 | iteration=true_it, 179 | epoch=epoch + 1, 180 | name="latest") 181 | 182 | time_per_epoch = (e.elapsed_time / (epoch - start_epoch + 1)) 183 | time_in_need = time_per_epoch * (e.config.n_epoch - epoch - 1) 184 | e.log.info("elapsed time: {:.2f}(h), " 185 | "time per epoch: {:.2f}(h), " 186 | "time needed to finish: {:.2f}(h)" 187 | .format(e.elapsed_time, time_per_epoch, time_in_need)) 188 | train_stats.reset() 189 | 190 | 191 | if __name__ == '__main__': 192 | 193 | PARSED_CONFIG = config.get_base_parser().parse_args() 194 | 195 | def exit_handler(*args): 196 | print(PARSED_CONFIG) 197 | print("best dev bleu: {:.4f}, test bleu: {:.4f}" 198 | .format(BEST_DEV_BLEU, TEST_BLEU)) 199 | sys.exit() 200 | 201 | train_helper.register_exit_handler(exit_handler) 202 | 203 | with train_helper.Experiment(PARSED_CONFIG, 204 | PARSED_CONFIG.save_prefix) as exp: 205 | 206 | exp.log.info("*" * 25 + " ARGS " + "*" * 25) 207 | exp.log.info(PARSED_CONFIG) 208 | exp.log.info("*" * 25 + " ARGS " + "*" * 25) 209 | 210 | run(exp) 211 | -------------------------------------------------------------------------------- /train_helper.py: -------------------------------------------------------------------------------- 1 | # pylint: disable=no-member 2 | import subprocess 3 | import data_utils 4 | import threading 5 | import argparse 6 | import logging 7 | import signal 8 | import torch 9 | import time 10 | import nltk 11 | import os 12 | 13 | from config import get_base_parser, MULTI_BLEU_PERL, \ 14 | EOS_IDX, RESOURCE_LINK, METEOR_DATA, METEOR_JAR 15 | from decorators import auto_init_args 16 | 17 | 18 | def register_exit_handler(exit_handler): 19 | import atexit 20 | 21 | 
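    # This helper wires one callback to normal interpreter exit (atexit) as
    # well as SIGTERM/SIGINT, so an interrupted or preempted run still gets a
    # chance to log results and checkpoint. A minimal usage sketch, mirroring
    # the exit_handler in train.py (handler body is made up; illustrative only):
    #
    #     def handler(*args):              # signal handlers receive (signum, frame)
    #         print("shutting down...")
    #     register_exit_handler(handler)   # Ctrl-C / SIGTERM now run handler first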
atexit.register(exit_handler) 22 | signal.signal(signal.SIGTERM, exit_handler) 23 | signal.signal(signal.SIGINT, exit_handler) 24 | 25 | 26 | def run_multi_bleu(input_file, reference_file): 27 | bleu_output = subprocess.check_output( 28 | "./{} {} < {}".format(MULTI_BLEU_PERL, reference_file, input_file), 29 | stderr=subprocess.STDOUT, shell=True).decode('utf-8') 30 | bleu = float( 31 | bleu_output.strip().split("\n")[-1] 32 | .split(",")[0].split("=")[1][1:]) 33 | return bleu 34 | 35 | 36 | class Tracker: 37 | @auto_init_args 38 | def __init__(self, names): 39 | assert len(names) > 0 40 | self.reset() 41 | 42 | def __getitem__(self, name): 43 | return self.values.get(name, 0) / self.counter if self.counter else 0 44 | 45 | def __len__(self): 46 | return len(self.names) 47 | 48 | def reset(self): 49 | self.values = dict({name: 0. for name in self.names}) 50 | self.counter = 0 51 | self.create_time = time.time() 52 | 53 | def update(self, named_values, count): 54 | """ 55 | named_values: dictionary with each item as name: value 56 | """ 57 | self.counter += count 58 | for name, value in named_values.items(): 59 | self.values[name] += value.item() * count 60 | 61 | def summarize(self, output=""): 62 | if output: 63 | output += ", " 64 | for name in self.names: 65 | output += "{}: {:.3f}, ".format( 66 | name, self.values[name] / self.counter if self.counter else 0) 67 | output += "elapsed time: {:.1f}(s)".format( 68 | time.time() - self.create_time) 69 | return output 70 | 71 | @property 72 | def stats(self): 73 | return {n: v / self.counter if self.counter else 0 74 | for n, v in self.values.items()} 75 | 76 | 77 | class Experiment: 78 | @auto_init_args 79 | def __init__(self, config, experiments_prefix, 80 | forced_debug=False, logfile_name="log"): 81 | """Create a new Experiment instance. 82 | 83 | Modified based on: https://github.com/ex4sperans/mag 84 | 85 | Args: 86 | logfile_name: str, naming for log file. 
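        A minimal usage sketch mirroring train.py (illustrative only;
        ``parsed_config`` would come from config.get_base_parser().parse_args())::

            with Experiment(parsed_config, parsed_config.save_prefix) as exp:
                exp.log.info("experiment dir: {}".format(exp.experiment_dir))
                exp.log.info("elapsed: {:.2f}(h)".format(exp.elapsed_time))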
This can be useful to 87 | separate logs for different runs on the same experiment 88 | experiments_prefix: str, a prefix to the path where 89 | experiment will be saved 90 | """ 91 | 92 | # get all defaults 93 | all_defaults = {} 94 | for key in vars(config): 95 | all_defaults[key] = get_base_parser().get_default(key) 96 | 97 | self.default_config = all_defaults 98 | 99 | config.resume = False 100 | if not config.debug and not forced_debug: 101 | if os.path.isdir(self.experiment_dir): 102 | print("log exists: {}".format(self.experiment_dir)) 103 | config.resume = True 104 | 105 | print(config) 106 | self._makedir() 107 | 108 | # self._make_misc_dir() 109 | 110 | def _makedir(self): 111 | os.makedirs(self.experiment_dir, exist_ok=True) 112 | 113 | def _make_misc_dir(self): 114 | os.makedirs(self.config.vocab_file, exist_ok=True) 115 | 116 | @property 117 | def log_file(self): 118 | return os.path.join(self.experiment_dir, self.logfile_name) 119 | 120 | @property 121 | def experiment_dir(self): 122 | if self.config.debug or self.forced_debug: 123 | return "./" 124 | else: 125 | # get namespace for each group of args 126 | arg_g = dict() 127 | for group in get_base_parser()._action_groups: 128 | group_d = {a.dest: self.default_config.get(a.dest, None) 129 | for a in group._group_actions} 130 | arg_g[group.title] = argparse.Namespace(**group_d) 131 | 132 | # skip default value 133 | identifier = "" 134 | for key, value in sorted(vars(arg_g["model_configs"]).items()): 135 | if getattr(self.config, key) != value: 136 | identifier += key + str(getattr(self.config, key)) 137 | return os.path.join(self.experiments_prefix, identifier) 138 | 139 | def register_directory(self, dirname): 140 | directory = os.path.join(self.experiment_dir, dirname) 141 | os.makedirs(directory, exist_ok=True) 142 | setattr(self, dirname, directory) 143 | 144 | def _register_existing_directories(self): 145 | for item in os.listdir(self.experiment_dir): 146 | fullpath = os.path.join(self.experiment_dir, item) 147 | if os.path.isdir(fullpath): 148 | setattr(self, item, fullpath) 149 | 150 | def __enter__(self): 151 | for handler in logging.root.handlers[:]: 152 | logging.root.removeHandler(handler) 153 | if self.config.debug or self.forced_debug: 154 | logging.basicConfig( 155 | level=logging.DEBUG, 156 | format='%(asctime)s %(levelname)s: %(message)s', 157 | datefmt='%m-%d %H:%M') 158 | else: 159 | print("log saving to", self.log_file) 160 | logging.basicConfig( 161 | filename=self.log_file, 162 | filemode='a+', level=logging.INFO, 163 | format='%(asctime)s %(levelname)s: %(message)s', 164 | datefmt='%m-%d %H:%M') 165 | 166 | self.log = logging.getLogger() 167 | self.start_time = time.time() 168 | return self 169 | 170 | def __exit__(self, *args): 171 | logging.shutdown() 172 | 173 | @property 174 | def elapsed_time(self): 175 | return (time.time() - self.start_time) / 3600 176 | 177 | 178 | class Evaluator: 179 | def __init__(self, model, eval_batch_size, data, inv_vocab, vocab, 180 | return_wikidata, return_hyperlink, 181 | input_wikidata, input_hyperlink, experiment): 182 | self.model = model 183 | self.expe = experiment 184 | self.inv_vocab = inv_vocab 185 | 186 | self.data_iterator = data_utils.Minibatcher( 187 | batch_size=eval_batch_size, 188 | data=data, 189 | is_eval=True, 190 | save_dir=None, 191 | verbose=False, 192 | vocab_size=len(inv_vocab), 193 | vocab=vocab, 194 | return_wikidata=return_wikidata, 195 | return_hyperlink=return_hyperlink, 196 | input_wikidata=input_wikidata, 197 | 
input_hyperlink=input_hyperlink, 198 | filename="devtesteval_minibatcher.ckpt", 199 | log=self.expe.log) 200 | self.eval_stats = Tracker( 201 | ["loss", "gen_loss", "fake_cyclic_loss", "true_cyclic_loss"]) 202 | 203 | def evaluate(self, gen_fn): 204 | self.model.eval() 205 | all_gen = {} 206 | self.expe.log.info("max gen len: {}, min gen len: {}, top p: {}" 207 | .format(self.expe.config.max_gen_len, 208 | self.expe.config.min_gen_len, 209 | self.expe.config.top_p)) 210 | for nbatch, (input_data, input_data_mask, input_data_pos, 211 | input_data_type, input_if_hyp, input_data_src_vocab, 212 | input_data_src_tgt_vocab_map, 213 | tgt_inp_data, tgt_inp_data_mask, tgt_inp_data_pos, 214 | tgt_inp_data_type, tgt_inp_data_if_hyp, 215 | tgt_out_data, tgt_out_data_mask, 216 | tgt_input, tgt_label, tgt_mask, 217 | tgt_src_vocab, batch_idx) in \ 218 | enumerate(self.data_iterator): 219 | if nbatch and nbatch % (len(self.data_iterator) // 10 + 1) == 0: 220 | self.expe.log.info( 221 | "evaluating progress: {}/{} = {:.2f} %" 222 | .format(nbatch, len(self.data_iterator), 223 | nbatch / (len(self.data_iterator) + 1) * 100) 224 | ) 225 | with torch.no_grad(): 226 | batch_gen, _ = self.model.greedy_decode( 227 | input_data, input_data_mask, input_data_pos, 228 | input_data_type, input_if_hyp, 229 | input_data_src_vocab, input_data_src_tgt_vocab_map, 230 | self.expe.config.max_gen_len, self.expe.config.min_gen_len, 231 | self.expe.config.top_p, self.expe.config.top_k) 232 | 233 | assert len(batch_gen) == len(batch_idx), \ 234 | "len(batch_gen)={} != len(batch_idx)={}"\ 235 | .format(len(batch_gen), len(batch_idx)) 236 | for gen, idx in zip(batch_gen, batch_idx): 237 | curr_gen = [] 238 | for i in gen: 239 | if i == EOS_IDX: 240 | break 241 | if int(i) >= len(self.inv_vocab): 242 | curr_gen.append( 243 | self.data_iterator.data[idx]["inv_src_vocab"] 244 | [int(i) - len(self.inv_vocab)]) 245 | else: 246 | curr_gen.append(self.inv_vocab[int(i)]) 247 | all_gen[idx] = " ".join(curr_gen)\ 248 | .replace("@@ ", "").replace("@@", "") 249 | assert len(all_gen) == len(self.data_iterator.data), \ 250 | "{} != {}".format(len(all_gen), len(self.data_iterator.data)) 251 | file_name = os.path.join(self.expe.experiment_dir, gen_fn + ".txt") 252 | ref_file_name = \ 253 | os.path.join(self.expe.experiment_dir, gen_fn + "_ref.txt") 254 | 255 | file_name_untok = \ 256 | os.path.join(self.expe.experiment_dir, gen_fn + "_untok.txt") 257 | ref_file_name_untok = \ 258 | os.path.join(self.expe.experiment_dir, gen_fn + "_untok_ref.txt") 259 | 260 | gen_len_list = [] 261 | with open(file_name, "w") as fp, open(ref_file_name, "w") as fp2, \ 262 | open(file_name_untok, "w") as fpu, \ 263 | open(ref_file_name_untok, "w") as fpu2: 264 | for hyp_idx, ref in zip(sorted(all_gen), 265 | sorted(self.data_iterator.data, 266 | key=lambda x: x["idx"])): 267 | assert hyp_idx == ref["idx"], \ 268 | "hyp_idx={} != ref[\"idx\"]={}".format(hyp_idx, ref["idx"]) 269 | hyp = all_gen[hyp_idx] 270 | fp2.write(ref["tok_text"] + "\n") 271 | fpu2.write(ref["untok_text"] + "\n") 272 | if hyp: 273 | tok_hyp = nltk.word_tokenize(hyp) 274 | gen_len_list.append(len(tok_hyp)) 275 | tok_hyp = " ".join(tok_hyp) 276 | fp.write(tok_hyp + "\n") 277 | fpu.write(hyp + "\n") 278 | else: 279 | gen_len_list.append(0) 280 | fp.write("\n") 281 | fpu.write("\n") 282 | 283 | bleu_score = run_multi_bleu(file_name, ref_file_name) 284 | self.expe.log.info( 285 | "generated sentences saved to: {}".format(file_name)) 286 | 287 | self.expe.log.info( 288 | "#Data: {}, bleu: {:.3f}, 
loss: {:.3f}, " 289 | "gloss: {:.3f}, floss: {:.3f}, tloss: {:.3f}, " 290 | "avg gen len: {:.1f}" 291 | .format(len(all_gen), bleu_score, self.eval_stats["loss"], 292 | self.eval_stats["gen_loss"], 293 | self.eval_stats["fake_cyclic_loss"], 294 | self.eval_stats["true_cyclic_loss"], 295 | sum(gen_len_list) / len(gen_len_list)) 296 | ) 297 | self.eval_stats.reset() 298 | return bleu_score 299 | 300 | 301 | def enc(s): 302 | return s.encode('utf-8') 303 | 304 | 305 | def dec(s): 306 | return s.decode('utf-8') 307 | 308 | 309 | class Meteor: 310 | def __init__(self): 311 | self.meteor_cmd = ['java', '-jar', '-Xmx2G', METEOR_JAR, 312 | '-', '-', '-stdio', '-l', 'en', '-norm', '-a', 313 | METEOR_DATA] 314 | for file in [METEOR_JAR, METEOR_DATA]: 315 | assert os.path.isfile(file), \ 316 | "{} not exsit! Please download it from {}"\ 317 | .format(file, RESOURCE_LINK) 318 | self.meteor_p = subprocess.Popen( 319 | self.meteor_cmd, 320 | cwd=os.path.dirname(os.path.abspath(__file__)), 321 | stdin=subprocess.PIPE, 322 | stdout=subprocess.PIPE, 323 | stderr=subprocess.PIPE) 324 | # Used to guarantee thread safety 325 | self.lock = threading.Lock() 326 | 327 | def compute_score(self, gts, res): 328 | assert(gts.keys() == res.keys()) 329 | imgIds = gts.keys() 330 | scores = [] 331 | 332 | eval_line = 'EVAL' 333 | self.lock.acquire() 334 | for i in imgIds: 335 | assert(len(res[i]) == 1) 336 | stat = self._stat(res[i][0], gts[i]) 337 | eval_line += ' ||| {}'.format(stat) 338 | 339 | self.meteor_p.stdin.write(enc('{}\n'.format(eval_line))) 340 | self.meteor_p.stdin.flush() 341 | for i in range(0, len(imgIds)): 342 | scores.append(dec(float(self.meteor_p.stdout.readline().strip()))) 343 | score = float(dec(self.meteor_p.stdout.readline().strip())) 344 | self.lock.release() 345 | 346 | return score, scores 347 | 348 | def _stat(self, hypothesis_str, reference_list): 349 | # SCORE ||| reference 1 words ||| reference n words ||| hypothesis words 350 | hypothesis_str = hypothesis_str.replace('|||','').replace(' ',' ') 351 | score_line = ' ||| '.join(('SCORE', ' ||| '.join(reference_list), hypothesis_str)) 352 | self.meteor_p.stdin.write(enc(score_line + "\n")) 353 | self.meteor_p.stdin.flush() 354 | return dec(self.meteor_p.stdout.readline()).strip() 355 | 356 | def _score(self, hypothesis_str, reference_list): 357 | # self.lock.acquire() 358 | with self.lock: 359 | # SCORE ||| reference 1 words ||| reference n words ||| hypothesis words 360 | hypothesis_str = hypothesis_str.replace('|||','').replace(' ',' ') 361 | score_line = ' ||| '.join(('SCORE', ' ||| '.join(reference_list), hypothesis_str)) 362 | self.meteor_p.stdin.write(enc(score_line + "\n")) 363 | self.meteor_p.stdin.flush() 364 | stats = dec(self.meteor_p.stdout.readline().strip()) 365 | eval_line = 'EVAL ||| {}'.format(stats) 366 | # EVAL ||| stats 367 | self.meteor_p.stdin.write(enc('{}\n'.format(eval_line))) 368 | self.meteor_p.stdin.flush() 369 | score = float(dec(self.meteor_p.stdout.readline()).strip()) 370 | # bug fix: there are two values returned by the jar file, one average, and one all, so do it twice 371 | # thanks for Andrej for pointing this out 372 | score = float(dec(self.meteor_p.stdout.readline().strip())) 373 | # self.lock.release() 374 | return score 375 | 376 | def __del__(self): 377 | with self.lock: 378 | if self.meteor_p: 379 | self.meteor_p.stdin.close() 380 | self.meteor_p.kill() 381 | self.meteor_p.wait() 382 | self.meteor_p = None 383 | -------------------------------------------------------------------------------- 
/transformer_xlm.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2019-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # 7 | 8 | import math 9 | import itertools 10 | import torch 11 | import torch.nn as nn 12 | import torch.nn.functional as F 13 | import model_utils 14 | from copy import deepcopy 15 | from entmax import entmax_bisect 16 | 17 | 18 | N_MAX_POSITIONS = 512 # maximum input sequence length 19 | 20 | 21 | def Embedding(num_embeddings, embedding_dim, padding_idx=None): 22 | m = nn.Embedding(num_embeddings, embedding_dim, padding_idx=padding_idx) 23 | nn.init.normal_(m.weight, mean=0, std=embedding_dim ** -0.5) 24 | if padding_idx is not None: 25 | nn.init.constant_(m.weight[padding_idx], 0) 26 | return m 27 | 28 | 29 | def Linear(in_features, out_features, bias=True): 30 | m = nn.Linear(in_features, out_features, bias) 31 | # nn.init.normal_(m.weight, mean=0, std=1) 32 | # nn.init.xavier_uniform_(m.weight) 33 | # nn.init.constant_(m.bias, 0.) 34 | return m 35 | 36 | 37 | class PredLayer(nn.Module): 38 | """ 39 | Prediction layer (cross_entropy or adaptive_softmax). 40 | """ 41 | def __init__(self, emb_dim, n_words): 42 | super().__init__() 43 | self.n_words = n_words 44 | self.proj = Linear(emb_dim, n_words, bias=True) 45 | 46 | def forward(self, x, y, get_scores=False): 47 | """ 48 | Compute the loss, and optionally the scores. 49 | """ 50 | scores = self.proj(x).view(-1, self.n_words) 51 | loss = F.cross_entropy(scores, y, reduction='mean') 52 | 53 | return scores, loss 54 | 55 | def get_scores(self, x): 56 | """ 57 | Compute scores. 58 | """ 59 | return self.proj(x) 60 | 61 | 62 | class MultiHeadAttention(nn.Module): 63 | 64 | NEW_ID = itertools.count() 65 | 66 | def __init__(self, n_heads, dim, dropout): 67 | super().__init__() 68 | self.layer_id = next(MultiHeadAttention.NEW_ID) 69 | self.dim = dim 70 | self.n_heads = n_heads 71 | self.dropout = dropout 72 | assert self.dim % self.n_heads == 0 73 | 74 | self.q_lin = Linear(dim, dim) 75 | self.k_lin = Linear(dim, dim) 76 | self.v_lin = Linear(dim, dim) 77 | self.out_lin = Linear(dim, dim) 78 | 79 | def forward(self, input, mask, kv=None, cache=None): 80 | """ 81 | Self-attention (if kv is None) or attention over source sentence (provided by kv). 
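        Example of expected shapes (toy sizes; illustrative only)::

            import torch

            attn = MultiHeadAttention(n_heads=8, dim=512, dropout=0.1)
            x = torch.randn(2, 7, 512)    # (bs, qlen, dim)
            mask = torch.ones(2, 7)       # non-causal padding mask, (bs, klen)
            out, weights = attn(x, mask)  # self-attention, since kv is None
            # out: (2, 7, 512); weights: (2, 8, 7, 7) = (bs, n_heads, qlen, klen)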
82 | """ 83 | # Input is (bs, qlen, dim) 84 | # Mask is (bs, klen) (non-causal) or (bs, klen, klen) 85 | bs, qlen, dim = input.size() 86 | if kv is None: 87 | klen = qlen if cache is None else cache['slen'] + qlen 88 | else: 89 | klen = kv.size(1) 90 | assert dim == self.dim, 'Dimensions do not match: %s input vs %s configured' % (dim, self.dim) 91 | n_heads = self.n_heads 92 | dim_per_head = dim // n_heads 93 | mask_reshape = (bs, 1, qlen, klen) if mask.dim() == 3 else (bs, 1, 1, klen) 94 | 95 | def shape(x): 96 | """ projection """ 97 | return x.view(bs, -1, self.n_heads, dim_per_head).transpose(1, 2) 98 | 99 | def unshape(x): 100 | """ compute context """ 101 | return x.transpose(1, 2).contiguous().view(bs, -1, self.n_heads * dim_per_head) 102 | 103 | q = shape(self.q_lin(input)) # (bs, n_heads, qlen, dim_per_head) 104 | if kv is None: 105 | k = shape(self.k_lin(input)) # (bs, n_heads, qlen, dim_per_head) 106 | v = shape(self.v_lin(input)) # (bs, n_heads, qlen, dim_per_head) 107 | elif cache is None or self.layer_id not in cache: 108 | k = v = kv 109 | k = shape(self.k_lin(k)) # (bs, n_heads, qlen, dim_per_head) 110 | v = shape(self.v_lin(v)) # (bs, n_heads, qlen, dim_per_head) 111 | 112 | if cache is not None: 113 | if self.layer_id in cache: 114 | if kv is None: 115 | k_, v_ = cache[self.layer_id] 116 | k = torch.cat([k_, k], dim=2) # (bs, n_heads, klen, dim_per_head) 117 | v = torch.cat([v_, v], dim=2) # (bs, n_heads, klen, dim_per_head) 118 | else: 119 | k, v = cache[self.layer_id] 120 | cache[self.layer_id] = (k, v) 121 | 122 | q = q / math.sqrt(dim_per_head) # (bs, n_heads, qlen, dim_per_head) 123 | scores = torch.matmul(q, k.transpose(2, 3)) # (bs, n_heads, qlen, klen) 124 | mask = (mask == 0).view(mask_reshape).expand_as(scores) # (bs, n_heads, qlen, klen) 125 | scores.masked_fill_(mask, -float('inf')) # (bs, n_heads, qlen, klen) 126 | 127 | weights = F.softmax(scores.float(), dim=-1).type_as(scores) # (bs, n_heads, qlen, klen) 128 | drop_weights = F.dropout(weights, p=self.dropout, training=self.training) # (bs, n_heads, qlen, klen) 129 | context = torch.matmul(drop_weights, v) # (bs, n_heads, qlen, dim_per_head) 130 | context = unshape(context) # (bs, qlen, dim) 131 | 132 | return self.out_lin(context), weights 133 | 134 | 135 | class TransformerFFN(nn.Module): 136 | 137 | def __init__(self, in_dim, dim_hidden, out_dim, dropout, gelu_activation): 138 | super().__init__() 139 | self.dropout = dropout 140 | self.lin1 = Linear(in_dim, dim_hidden) 141 | self.lin2 = Linear(dim_hidden, out_dim) 142 | self.act = F.gelu 143 | # 144 | # def forward(self, input): 145 | # x = self.lin1(input) 146 | # x = self.act(x) 147 | # x = self.lin2(x) 148 | # x = F.dropout(x, p=self.dropout, training=self.training) 149 | # return x 150 | 151 | def forward(self, input): 152 | x = self.lin1(input) 153 | x = F.dropout(self.act(x), p=self.dropout, training=self.training) 154 | x = self.lin2(x) 155 | x = F.dropout(x, p=self.dropout, training=self.training) 156 | return x 157 | 158 | 159 | class CacheTransformer(nn.Module): 160 | def __init__(self, n_words, bos_index, eos_index, 161 | pad_index, emb_dim, n_heads, n_layers, 162 | dropout, share_embedding, attention_dropout, 163 | max_len, use_copy, use_entmax): 164 | """ 165 | Transformer model decoder. 
166 | """ 167 | super().__init__() 168 | 169 | # encoder / decoder, output layer 170 | self.is_encoder = False 171 | self.is_decoder = not self.is_encoder 172 | 173 | self.n_words = n_words 174 | self.bos_index = bos_index 175 | self.eos_index = eos_index 176 | self.pad_index = pad_index 177 | self.use_copy = use_copy 178 | self.use_entmax = use_entmax 179 | 180 | # model parameters 181 | self.dim = emb_dim # 512 by default 182 | self.hidden_dim = self.dim * 4 # 2048 by default 183 | self.n_heads = n_heads # 8 by default 184 | self.n_layers = n_layers 185 | self.dropout = dropout 186 | self.attention_dropout = attention_dropout 187 | assert self.dim % self.n_heads == 0, \ 188 | 'transformer dim must be a multiple of n_heads' 189 | 190 | # embeddings 191 | self.position_embeddings = nn.Embedding(max_len, self.dim) 192 | self.embeddings = nn.Embedding(self.n_words, self.dim) 193 | 194 | # transformer layers 195 | self.attentions = nn.ModuleList() 196 | self.layer_norm1 = nn.ModuleList() 197 | self.ffns = nn.ModuleList() 198 | self.layer_norm2 = nn.ModuleList() 199 | if self.is_decoder: 200 | self.layer_norm15 = nn.ModuleList() 201 | self.encoder_attn = nn.ModuleList() 202 | 203 | for layer_id in range(self.n_layers): 204 | self.attentions.append(MultiHeadAttention( 205 | self.n_heads, self.dim, dropout=self.attention_dropout)) 206 | self.layer_norm1.append(nn.LayerNorm(self.dim, eps=1e-12)) 207 | if self.is_decoder: 208 | self.layer_norm15.append(nn.LayerNorm(self.dim, eps=1e-12)) 209 | self.encoder_attn.append( 210 | MultiHeadAttention(self.n_heads, self.dim, 211 | dropout=self.attention_dropout)) 212 | 213 | self.ffns.append( 214 | TransformerFFN(self.dim, self.hidden_dim, 215 | self.dim, dropout=self.dropout, 216 | gelu_activation=F.gelu)) 217 | self.layer_norm2.append(nn.LayerNorm(self.dim, eps=1e-12)) 218 | 219 | # output layer 220 | # if self.with_output: 221 | self.pred_layer = PredLayer(emb_dim, n_words) 222 | if share_embedding: 223 | self.pred_layer.proj.weight = self.embeddings.weight 224 | if use_copy: 225 | self.copy_generator = \ 226 | model_utils.CopyGenerator(self.dim, use_entmax) 227 | 228 | def _generate_square_subsequent_mask(self, seq): 229 | mask = (torch.triu(torch.ones(seq, seq)) == 1).transpose(0, 1) 230 | return mask 231 | 232 | def fwd(self, x, src_enc=None, src_mask=None, 233 | positions=None, cache=None, src_map=None): 234 | """ 235 | Inputs: 236 | `x` LongTensor(slen, bs), containing word indices 237 | `lengths` LongTensor(bs), containing the length of each sentence 238 | `causal` Boolean, if True, the attention is only done over previous hidden states 239 | `positions` LongTensor(slen, bs), containing word positions 240 | `langs` LongTensor(slen, bs), containing language IDs 241 | 242 | src_enc: bs x seq len x dim 243 | x: bs x seq len 244 | """ 245 | bs, slen = x.size() 246 | if src_enc is not None: 247 | assert self.is_decoder 248 | assert src_enc.size(0) == bs 249 | 250 | attn_mask = self._generate_square_subsequent_mask(slen).to(x.device) 251 | attn_mask = attn_mask.unsqueeze(0).expand(bs, slen, slen) 252 | 253 | # positions 254 | if positions is None: 255 | # positions = x.new(slen).long() 256 | positions = torch.arange(slen).unsqueeze(0).to(x.device).long() 257 | else: 258 | assert positions.size() == (slen, bs) 259 | positions = positions.transpose(0, 1) 260 | 261 | # do not recompute cached elements 262 | if cache is not None: 263 | _slen = slen - cache['slen'] 264 | x = x[:, -_slen:] 265 | positions = positions[:, -_slen:] 266 | attn_mask = attn_mask[:, 
-_slen:] 267 | 268 | # embeddings 269 | tensor = self.embeddings(x.long()) 270 | tensor = tensor + self.position_embeddings(positions).expand_as(tensor) 271 | tensor = F.dropout(tensor, p=self.dropout, training=self.training) 272 | 273 | # transformer layers 274 | for i in range(self.n_layers): 275 | 276 | # self attention 277 | attn, _ = self.attentions[i](tensor, attn_mask, cache=cache) 278 | attn = F.dropout(attn, p=self.dropout, training=self.training) 279 | tensor = tensor + attn 280 | tensor = self.layer_norm1[i](tensor) 281 | 282 | # encoder attention (for decoder only) 283 | if self.is_decoder and src_enc is not None: 284 | attn, attn_weight = \ 285 | self.encoder_attn[i](tensor, src_mask, 286 | kv=src_enc, cache=cache) 287 | attn = F.dropout(attn, p=self.dropout, training=self.training) 288 | tensor = tensor + attn 289 | tensor = self.layer_norm15[i](tensor) 290 | 291 | tensor = tensor + self.ffns[i](tensor) 292 | tensor = self.layer_norm2[i](tensor) 293 | 294 | # update cache length 295 | if cache is not None: 296 | cache['slen'] += tensor.size(1) 297 | 298 | pred_prob = self.pred_layer.get_scores(tensor) 299 | 300 | if self.use_copy: 301 | assert src_map is not None 302 | copy_attn = attn_weight[:, 0, :, :].contiguous() 303 | pred_prob = self.copy_generator( 304 | hidden=self.copy_generator._bottle(tensor), 305 | orig_prob=self.copy_generator._bottle(pred_prob), 306 | attn=self.copy_generator._bottle(copy_attn), 307 | src_map=src_map) 308 | 309 | return pred_prob, tensor 310 | 311 | def _generate(self, src_enc, src_mask, max_len=200, min_len=0, top_p=None, src_map=None, src_tgt_vocab_map=None): 312 | """ 313 | Decode a sentence given initial start. 314 | `x`: 315 | - LongTensor(bs, slen) 316 | W1 W2 W3 317 | W1 W2 W3 W4 318 | `lengths`: 319 | - LongTensor(bs) [5, 6] 320 | `positions`: 321 | - False, for regular "arange" positions (LM) 322 | - True, to reset positions from the new generation (MT) 323 | `langs`: 324 | - must be None if the model only supports one language 325 | - lang_id if only one language is involved (LM) 326 | - (lang_id1, lang_id2) if two languages are involved (MT) 327 | """ 328 | 329 | # input batch 330 | bs = len(src_mask) 331 | assert src_enc.size(0) == bs 332 | 333 | # generated sentences 334 | generated = src_mask.new(bs, max_len) # upcoming output 335 | generated.fill_(self.pad_index) # fill upcoming ouput with 336 | generated[:, 0].fill_(self.bos_index) # we use for everywhere 337 | 338 | # current position / max lengths / length of generated sentences / unfinished sentences 339 | cur_len = 1 340 | # gen_len = torch.ones(bs).to(src_mask.device).long() 341 | unfinished_sents = torch.ones(bs).to(src_mask.device).long() #src_len.clone().fill_(1) 342 | all_scores = torch.zeros(bs).to(src_mask.device) 343 | # cache compute states 344 | cache = {'slen': 0} 345 | 346 | while cur_len < max_len: 347 | 348 | # compute word scores 349 | tensor, _ = self.fwd( 350 | x=generated[:, :cur_len] if not self.use_copy else generated[:, :cur_len].masked_fill(generated[:, :cur_len].gt(self.n_words - 1), 0), 351 | src_enc=src_enc, 352 | src_mask=src_mask, 353 | cache=cache, 354 | src_map=src_map 355 | ) 356 | if self.use_copy: 357 | tensor = torch.cat(tensor, 1) 358 | scores, _ = model_utils.collapse_copy_scores( 359 | scores=tensor, 360 | src_tgt_vocab_map=src_tgt_vocab_map, 361 | vocab_size=self.n_words) 362 | scores[:, self.n_words] = 0.0 363 | else: 364 | assert tensor.size() == (bs, 1, self.n_words), (cur_len, max_len, src_enc.size(), tensor.size(), (1, bs, 
self.n_words))
365 |             scores = tensor[:, -1, :]  # (bs, n_words)
366 |             # scores = self.pred_layer.get_scores(tensor)  # (bs, n_words)
367 | 
368 |             scores[:, 0] = -float('Inf') if not self.use_copy else 0
369 |             scores[:, self.pad_index] = -float('Inf') if not self.use_copy else 0
370 |             scores[:, self.bos_index] = -float('Inf') if not self.use_copy else 0
371 | 
372 |             if cur_len < min_len:
373 |                 scores[:, self.eos_index] = -float('Inf') if not self.use_copy else 0
374 | 
375 |             # select next words: sample or greedy
376 |             if top_p:
377 |                 if self.use_copy:
378 |                     next_words = torch.multinomial(model_utils.top_k_top_p_filtering(scores, top_k=0, top_p=top_p if top_p else 0.0, filter_value=0.0, need_softmax=False), 1).squeeze(1)
379 |                     next_scores = (scores + 1e-10).log().gather(1, next_words.unsqueeze(1)).squeeze(1)
380 |                 else:
381 |                     next_words = torch.multinomial(F.softmax(model_utils.top_k_top_p_filtering(scores, top_k=0, top_p=top_p if top_p else 0.0), dim=1), 1).squeeze(1)
382 |                     next_scores = scores.log_softmax(1).gather(1, next_words.unsqueeze(1)).squeeze(1)
383 |             else:
384 |                 if self.use_copy:
385 |                     next_scores, next_words = (scores + 1e-10).log().max(1)
386 |                 elif self.use_entmax:
387 |                     next_scores, next_words = (entmax_bisect(scores, 1.2) + 1e-10).log().max(1)
388 |                 else:
389 |                     next_scores, next_words = scores.log_softmax(1).max(1)
390 |             assert next_words.size() == (bs,)
391 | 
392 |             # update generations / lengths / finished sentences / current length
393 |             generated[:, cur_len] = next_words * unfinished_sents + self.pad_index * (1 - unfinished_sents)
394 |             all_scores = all_scores + next_scores * unfinished_sents.float()
395 |             # gen_len.add_(unfinished_sents)
396 |             unfinished_sents.mul_(next_words.ne(self.eos_index).long())
397 |             cur_len = cur_len + 1
398 | 
399 |             # stop when there is an EOS in each sentence, or if we exceed the maximum length
400 |             if unfinished_sents.max() == 0:
401 |                 break
402 | 
403 |         # add EOS to unfinished sentences
404 |         if cur_len == max_len:
405 |             generated[:, -1].masked_fill_(unfinished_sents.bool(), self.eos_index)
406 | 
407 |         # sanity check
408 |         assert (generated == self.eos_index).sum() == bs
409 | 
410 |         return generated[:, 1:cur_len].cpu().numpy(), all_scores.cpu().numpy()  # , gen_len
411 | 
412 | 
413 |     def _generate_beam(self, src_enc, src_mask, beam_size, length_penalty=0.0, early_stopping=False, min_len=0, max_len=200, trigram_blocking=False, return_all=False, src_map=None, src_tgt_vocab_map=None):
414 |         """
415 |         Decode a sentence given initial start.
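        The loop below flattens all beams into the batch dimension and scores
        every continuation at once; a toy sketch of the index bookkeeping
        (made-up sizes; illustrative only)::

            import torch

            beam_size, n_words = 2, 5
            flat = torch.randn(1, beam_size * n_words)  # (bs, beam_size * n_words)
            next_scores, next_ids = torch.sort(flat, dim=1, descending=True)
            beam_id = next_ids[0, 0] // n_words         # which beam to extend
            word_id = next_ids[0, 0] % n_words          # which token to append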
413 |     def _generate_beam(self, src_enc, src_mask, beam_size, length_penalty=0.0, early_stopping=False, min_len=0, max_len=200, trigram_blocking=False, return_all=False, src_map=None, src_tgt_vocab_map=None):
414 |         """
415 |         Beam-search decoding given encoder states.
416 |         `src_enc` / `src_mask`:
417 |             - encoder output and source mask; expanded to beam size internally
418 |         `beam_size`:
419 |             - number of hypotheses kept per sentence
420 |         `length_penalty` / `early_stopping`:
421 |             - hypothesis scoring options (see BeamHypotheses below)
422 |         `min_len` / `max_len`:
423 |             - bounds on the generated length
424 |         `trigram_blocking`:
425 |             - if True, a hypothesis never repeats one of its own trigrams
426 |         `return_all`:
427 |             - if True, return all finished hypotheses instead of the best one
428 |         `src_map` / `src_tgt_vocab_map`:
429 |             - source-to-vocabulary maps used by the copy mechanism
430 |         """
431 |
432 |         # check inputs
433 |         assert src_enc.size(0) == src_mask.size(0)
434 |         assert beam_size >= 1
435 |
436 |         # batch size / number of words
437 |         bs = len(src_mask)
438 |         n_words = self.n_words if not self.use_copy else self.n_words + src_tgt_vocab_map.shape[1]
439 |
440 |         # expand to beam size the source latent representations / source lengths
441 |         src_enc = src_enc.unsqueeze(1).expand((bs, beam_size) + src_enc.shape[1:]).contiguous().view((bs * beam_size,) + src_enc.shape[1:])
442 |         src_mask = src_mask.unsqueeze(1).expand((bs, beam_size) + src_mask.shape[1:]).contiguous().view((bs * beam_size,) + src_mask.shape[1:])
443 |         if src_tgt_vocab_map is not None:
444 |             src_tgt_vocab_map = src_tgt_vocab_map.unsqueeze(1).expand((bs, beam_size) + src_tgt_vocab_map.shape[1:]).contiguous().view((bs * beam_size,) + src_tgt_vocab_map.shape[1:])
445 |         if src_map is not None:
446 |             src_map = src_map.unsqueeze(1).expand((bs, beam_size) + src_map.shape[1:]).contiguous().view((bs * beam_size,) + src_map.shape[1:])
447 |         # src_len = src_len.unsqueeze(1).expand(bs, beam_size).contiguous().view(-1)
448 |
449 |         # generated sentences (batch with beam current hypotheses)
450 |         generated = src_enc.new(bs * beam_size, max_len)  # upcoming output
451 |         generated.fill_(self.pad_index)          # fill upcoming output with <PAD>
452 |         generated[:, 0].fill_(self.bos_index)    # we use <BOS> as the first token everywhere
453 |
454 |         # generated hypotheses
455 |         generated_hyps = [BeamHypotheses(beam_size, max_len, length_penalty, early_stopping) for _ in range(bs)]
456 |         trigram_set = [set() for _ in range(bs * beam_size)]
457 |
458 |         # scores for each sentence in the beam
459 |         beam_scores = src_enc.new(bs, beam_size).fill_(0)
460 |         beam_scores[:, 1:] = -1e9
461 |         beam_scores = beam_scores.view(-1)
462 |
463 |         # current position
464 |         cur_len = 1
465 |
466 |         # cache compute states
467 |         cache = {'slen': 0}
468 |
469 |         # done sentences
470 |         done = [False for _ in range(bs)]
471 |
472 |         while cur_len < max_len:
473 |
474 |             # compute word scores
475 |             tensor, _ = self.fwd(
476 |                 x=generated[:, :cur_len] if not self.use_copy else generated[:, :cur_len].masked_fill(generated[:, :cur_len].gt(self.n_words - 1), 0),
477 |                 src_enc=src_enc,
478 |                 src_mask=src_mask,
479 |                 cache=cache,
480 |                 src_map=src_map
481 |             )
482 |             if self.use_copy:
483 |                 tensor = torch.cat(tensor, 1)
484 |                 scores, _ = model_utils.collapse_copy_scores(
485 |                     scores=tensor,
486 |                     src_tgt_vocab_map=src_tgt_vocab_map,
487 |                     vocab_size=self.n_words)
488 |                 scores[:, self.n_words] = 0
489 |             else:
490 |                 assert tensor.size() == (bs * beam_size, 1, self.n_words)
491 |                 scores = tensor[:, -1, :]  # (bs * beam_size, n_words)
492 |
493 |             scores[:, 0] = -float('Inf') if not self.use_copy else 0
494 |             scores[:, self.pad_index] = -float('Inf') if not self.use_copy else 0
495 |             scores[:, self.bos_index] = -float('Inf') if not self.use_copy else 0
496 |
497 |             if cur_len < min_len:
498 |                 scores[:, self.eos_index] = -float('Inf') if not self.use_copy else 0
499 |
500 |             if self.use_copy:
501 |                 scores = (scores + 1e-10).log()
502 |             elif self.use_entmax:
503 |                 scores = torch.log(entmax_bisect(scores, 1.2) + 1e-10)
504 |             else:
505 |                 scores = F.log_softmax(scores, dim=-1)  # (bs * beam_size, n_words)
506 |
507 |             assert scores.size() == (bs * beam_size, n_words), (scores.shape, (bs * beam_size, n_words))
508 |
509 |             # select next words with scores
510 |             _scores = scores + beam_scores[:, None].expand_as(scores)  # (bs * beam_size, n_words)
511 |             _scores = _scores.view(bs, beam_size * n_words)  # (bs, beam_size * n_words)
512 |
513 |             next_scores, next_words = torch.sort(_scores, dim=1, descending=True)
514 |             assert next_scores.size() == next_words.size() == (bs, n_words * beam_size)
515 |
516 |             # next batch beam content
517 |             # list of (bs * beam_size) tuple(next hypothesis score, next word, current position in the batch)
518 |             next_batch_beam = []
519 |
520 |             # for each sentence
521 |             for sent_id in range(bs):
522 |
523 |                 # if we are done with this sentence
524 |                 done[sent_id] = done[sent_id] or generated_hyps[sent_id].is_done(next_scores[sent_id].max().item())
525 |                 if done[sent_id]:
526 |                     next_batch_beam.extend([(0, self.pad_index, 0)] * beam_size)  # pad the batch
527 |                     continue
528 |
529 |                 # next sentence beam content
530 |                 next_sent_beam = []
531 |                 n_add = 0
532 |
533 |                 # next words for this sentence
534 |                 for idx, value in zip(next_words[sent_id], next_scores[sent_id]):
535 |
536 |                     # get beam and word IDs
537 |                     beam_id = idx // n_words
538 |                     word_id = idx % n_words
539 |
540 |                     if trigram_blocking and cur_len > 2:
541 |                         trigram = tuple(generated[sent_id * beam_size + beam_id, cur_len-2:cur_len].tolist() + [word_id.item()])
542 |                         if trigram in trigram_set[sent_id * beam_size + beam_id]:
543 |                             continue
544 |                     # end of sentence, or next word
545 |                     if word_id == self.eos_index or cur_len + 1 == max_len:
546 |                         n_add += 1
547 |                         generated_hyps[sent_id].add(generated[sent_id * beam_size + beam_id, :cur_len].clone(), value.item())
548 |                     else:
549 |                         next_sent_beam.append((value, word_id, sent_id * beam_size + beam_id))
550 |                         if trigram_blocking and cur_len > 2:
551 |                             trigram_set[sent_id * beam_size + beam_id].add(trigram)
552 |
553 |                     # the beam for next step is full
554 |                     if len(next_sent_beam) == beam_size or (cur_len + 1 == max_len and n_add == beam_size):
555 |                         break
556 |
557 |                 # update next beam content
558 |                 assert len(next_sent_beam) == (0 if cur_len + 1 == max_len else beam_size)
559 |                 if len(next_sent_beam) == 0:
560 |                     next_sent_beam = [(0, self.pad_index, 0)] * beam_size  # pad the batch
561 |                 next_batch_beam.extend(next_sent_beam)
562 |                 assert len(next_batch_beam) == beam_size * (sent_id + 1)
563 |
564 |             # sanity check / prepare next batch
565 |             assert len(next_batch_beam) == bs * beam_size
566 |             beam_scores = beam_scores.new([x[0] for x in next_batch_beam])
567 |             beam_words = generated.new([x[1] for x in next_batch_beam])
568 |             beam_idx = generated.new([x[2] for x in next_batch_beam]).long()
569 |
570 |             # re-order batch and internal states
571 |             trigram_set = [deepcopy(trigram_set[x[2]]) for x in next_batch_beam]
572 |             generated = generated[beam_idx, :]
573 |             generated[:, cur_len] = beam_words
574 |             for k in cache.keys():
575 |                 if k != 'slen':
576 |                     cache[k] = (cache[k][0][beam_idx], cache[k][1][beam_idx])
577 |
578 |             # update current length
579 |             cur_len = cur_len + 1
580 |
581 |             # stop when we are done with each sentence
582 |             if all(done):
583 |                 break
584 |
585 |         if return_all:
586 |             return generated_hyps
587 |
588 |         # select the best hypotheses
589 |         tgt_len = src_enc.new(bs).long()
590 |         best = []
591 |         best_scores = []
592 |
593 |         for i, hypotheses in enumerate(generated_hyps):
594 |             best_score, best_hyp = max(hypotheses.hyp, key=lambda x: x[0])
595 |             tgt_len[i] = len(best_hyp) + 1  # +1 for the <EOS> symbol
596 |             best.append(best_hyp)
597 |             best_scores.append(best_score)
598 |
599 |         # generate target batch
600 |         decoded = src_enc.new(tgt_len.max().item(), bs).fill_(self.pad_index)
601 |         for i, hypo in enumerate(best):
602 |             decoded[:tgt_len[i] - 1, i] = hypo
603 |             decoded[tgt_len[i] - 1, i] = self.eos_index
604 |
605 |         # sanity check
606 |         assert (decoded == self.eos_index).sum() == bs
607 |
608 |         return decoded.transpose(0, 1).cpu().numpy(), best_scores, tgt_len.cpu().numpy()
609 |
610 |
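The `trigram_blocking` option in `_generate_beam` keeps one set of committed trigrams per beam and skips any candidate word that would recreate a trigram already in that beam. The standalone helper below restates the criterion; it recomputes the trigrams from the prefix instead of maintaining the set incrementally, and is illustrative only, not part of the codebase.

```python
def repeats_trigram(prefix_ids, candidate_id):
    """True if appending candidate_id to prefix_ids would repeat a trigram
    already present in the prefix -- the criterion enforced per beam when
    trigram_blocking=True."""
    if len(prefix_ids) < 2:
        return False
    new_trigram = tuple(prefix_ids[-2:]) + (candidate_id,)
    seen = {tuple(prefix_ids[i:i + 3]) for i in range(len(prefix_ids) - 2)}
    return new_trigram in seen

# with prefix [1, 2, 3, 1, 2], candidate 3 would recreate (1, 2, 3) and is blocked:
assert repeats_trigram([1, 2, 3, 1, 2], 3) is True
assert repeats_trigram([1, 2, 3, 1, 2], 4) is False
```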
611 | class BeamHypotheses(object):
612 |
613 |     def __init__(self, n_hyp, max_len, length_penalty, early_stopping):
614 |         """
615 |         Initialize n-best list of hypotheses.
616 |         """
617 |         self.max_len = max_len - 1  # ignoring <BOS>
618 |         self.length_penalty = length_penalty
619 |         self.early_stopping = early_stopping
620 |         self.n_hyp = n_hyp
621 |         self.hyp = []
622 |         self.worst_score = 1e9
623 |
624 |     def __len__(self):
625 |         """
626 |         Number of hypotheses in the list.
627 |         """
628 |         return len(self.hyp)
629 |
630 |     def add(self, hyp, sum_logprobs):
631 |         """
632 |         Add a new hypothesis to the list.
633 |         """
634 |         score = sum_logprobs / len(hyp) ** self.length_penalty
635 |         if len(self) < self.n_hyp or score > self.worst_score:
636 |             self.hyp.append((score, hyp))
637 |             if len(self) > self.n_hyp:
638 |                 sorted_scores = sorted([(s, idx) for idx, (s, _) in enumerate(self.hyp)])
639 |                 del self.hyp[sorted_scores[0][1]]
640 |                 self.worst_score = sorted_scores[1][0]
641 |             else:
642 |                 self.worst_score = min(score, self.worst_score)
643 |
644 |     def is_done(self, best_sum_logprobs):
645 |         """
646 |         If there are enough hypotheses and none of the hypotheses still being
647 |         generated can become better than the worst one in the heap, then we are
648 |         done with this sentence.
649 |         """
650 |         if len(self) < self.n_hyp:
651 |             return False
652 |         elif self.early_stopping:
653 |             return True
654 |         else:
655 |             return self.worst_score >= best_sum_logprobs / self.max_len ** self.length_penalty
656 |
--------------------------------------------------------------------------------
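To make the scoring in `BeamHypotheses` concrete: `add` normalizes a hypothesis's total log-probability by `len(hyp) ** length_penalty`, and `is_done` compares the worst kept score against the best normalized score an unfinished candidate could still reach. A toy usage follows; plain Python lists stand in for the token tensors the decoder actually passes, and the numbers are invented for illustration.

```python
from transformer_xlm import BeamHypotheses

# keep the 2 best of three finished hypotheses, length_penalty=1.0 (average log-prob)
hyps = BeamHypotheses(n_hyp=2, max_len=10, length_penalty=1.0, early_stopping=False)
hyps.add([5, 6, 7], sum_logprobs=-3.0)      # score -3.0 / 3 = -1.0
hyps.add([5, 6], sum_logprobs=-2.4)         # score -2.4 / 2 = -1.2
hyps.add([5, 6, 7, 8], sum_logprobs=-3.6)   # score -0.9: evicts the -1.2 hypothesis
assert len(hyps) == 2 and hyps.worst_score == -1.0

# is_done: a candidate with total log-prob -18.0 could at best average
# -18.0 / (max_len - 1) = -2.0, worse than the current worst (-1.0), so done.
assert hyps.is_done(best_sum_logprobs=-18.0)
```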