├── LICENSE
├── README.md
├── config.json
├── data
│   ├── data.py
│   ├── utils.py
│   └── vocab.py
├── decode.py
├── models
│   ├── loss.py
│   └── model.py
├── requirements.txt
├── test.py
├── train.py
└── utils.py
/LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution."
61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 
122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 
179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Pointer-Generator Network 2 | 3 | This repository contains the PyTorch implementation of the Pointer-Generator Network for text summarization, presented in [Get To The Point: Summarization with Pointer-Generator Networks (See et al., 2017)](https://arxiv.org/abs/1704.04368). 4 | 5 | While the original paper trains the model on an English dataset, this project aims to build a __*Korean*__ summarization model. 6 | Thus, we additionally incorporate Korean preprocessing & tokenization techniques to adapt the model to Korean. 7 | 8 | Most of the code is implemented from scratch, but we also referred to the following repositories. 9 | Any direct references are mentioned explicitly on the corresponding lines of code. 10 | * https://github.com/abisee/pointer-generator - the original author's implementation in TensorFlow 11 | * https://github.com/atulkum/pointer_summarizer 12 | * https://github.com/rohithreddy024/Text-Summarizer-Pytorch 13 | 14 | Note that the overall pipeline relies on `pytorch-lightning`. 15 | 16 | 17 | ## Requirements 18 | 19 | ### Packages 20 | ``` 21 | torch==1.5.1 22 | pytorch-lightning==1.0.3 23 | fasttext==0.9.2 24 | ``` 25 | 26 | ``` 27 | pip install -r requirements.txt 28 | ``` 29 | 30 | ### Mecab tokenizer 31 | The model requires an additional installation of the Mecab tokenizer provided by the konlpy package. 32 | The guide to installing Mecab can be found at this link: https://konlpy.org/en/latest/install/. 33 | 34 | ## How to run 35 | ### Prepare data 36 | 37 | Download the dataset at [this link](https://drive.google.com/drive/folders/1mHxqjg4jAVVwkGhRWRbdaUPELlj3kQSM?usp=sharing), which is a human-annotated abstractive summarization dataset published by the [National Institute of Korean Language](https://corpus.korean.go.kr/). The dataset is arbitrarily split into train, dev, and test. 38 | 39 | ``` 40 | data 41 | ├── nikl_train.pkl 42 | ├── nikl_dev.pkl 43 | └── nikl_test.pkl 44 | ``` 45 | 46 | 47 | ### Run training 48 | First, set up the desired model configurations in `config.json`.
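The config is consumed with attribute-style access throughout the code (e.g. `config.model.args.hidden_dim`). A minimal sketch of such a loader, assuming `easydict` (pinned in `requirements.txt`) is what backs the project's actual `config_parser` in `utils.py` (not shown here):

```
import json

from easydict import EasyDict

# Read config.json so nested keys support attribute access,
# matching expressions like config.model.args.hidden_dim in model.py.
with open('config.json', 'r', encoding='utf-8') as f:
    config = EasyDict(json.load(f))

print(config.model.args.hidden_dim)  # -> 256 with the shipped config
```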
49 | 50 | To begin training your model, run: 51 | ``` 52 | python train.py 53 | ``` 54 | 55 | Details on optional command-line arguments are specified below: 56 | ``` 57 | Pointer-generator network 58 | 59 | optional arguments: 60 | -h, --help show this help message and exit 61 | -cp CONFIG_PATH, --config-path CONFIG_PATH 62 | path to config file 63 | -m MODEL_PATH, --model-path MODEL_PATH 64 | path to load model in case of resuming training from an existing checkpoint 65 | --load-vocab whether to load pre-built vocab file 66 | --stop-with {loss,r1,r2,rl} 67 | validation evaluation metric to perform early stopping 68 | -e EXP_NAME, --exp-name EXP_NAME 69 | suffix to specify experiment name 70 | -d DEVICE, --device DEVICE 71 | gpu device number to use. if cpu, set this argument to -1 72 | -n NOTE, --note NOTE note to append to result output file name 73 | ``` 74 | Running the file will create a subdirectory in `logs` with the experiment name. 75 | All checkpoints, test set predictions, the constructed vocab file, tensorboard logs, and hyperparameter configurations will be saved in this directory. 76 | 77 | ### Run evaluation 78 | 79 | ``` 80 | python test.py --model-path $PATH-TO-CHECKPOINT 81 | ``` 82 | 83 | This will report the ROUGE scores on the command line and save the predicted outputs in `.tsv` format in the experiment directory from which the checkpoint was loaded. 84 | 85 | -------------------------------------------------------------------------------- /config.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "PGN", 3 | "n_gpu": 1, 4 | 5 | "data": { 6 | "vocab_size": 50000, 7 | "vocab_min_freq": 1, 8 | "src_max_train": 400, 9 | "src_max_test": 400, 10 | "tgt_max_train": 100, 11 | "tgt_max_test": 100 12 | }, 13 | 14 | "data_loader": { 15 | "data_dir": "data/", 16 | "batch_size": { 17 | "train": 64, 18 | "val": 64, 19 | "test": 1 20 | }, 21 | "shuffle": true, 22 | "num_workers": 8 23 | }, 24 | 25 | "model": { 26 | "type": "PointerGenerator", 27 | "use_pretrained": true, 28 | "pretrained": "fasttext", 29 | "args": { 30 | "embed_dim": 300, 31 | "hidden_dim": 256 32 | } 33 | }, 34 | 35 | "optimizer": { 36 | "type": "Adagrad", 37 | "args": { 38 | "lr": 0.15, 39 | "lr_init_accum": 0.1 40 | } 41 | }, 42 | 43 | "loss": { 44 | "args": { 45 | "use_coverage": true, 46 | "cov_weight": 1.0, 47 | "pad_id": 0 48 | } 49 | }, 50 | 51 | "trainer": { 52 | "epochs": 10, 53 | "max_iter": 600000, 54 | "max_grad_norm": 2.0, 55 | "save_freq": 5, 56 | "verbose": true 57 | }, 58 | 59 | "decode": { 60 | "args": { 61 | "beam_size": 2, 62 | "min_dec_steps": 100, 63 | "num_return_seq": 1 64 | } 65 | }, 66 | 67 | "path": { 68 | "vocab": "vocab.json", 69 | "train": "../data/nikl/nikl_train.pkl", 70 | "val": "../data/nikl/nikl_dev.pkl", 71 | "test": "../data/nikl/nikl_test.pkl" 72 | } 73 | } -------------------------------------------------------------------------------- /data/data.py: -------------------------------------------------------------------------------- 1 | from collections import Counter 2 | 3 | import gluonnlp as nlp 4 | import torch 5 | from konlpy.tag import Mecab 6 | from torch.utils.data import DataLoader 7 | from torch.utils.data import Dataset 8 | 9 | from data.utils import collate_tokens 10 | from data.utils import load_dataset 11 | from data.vocab import PAD_TOKEN, START_DECODING, STOP_DECODING, UNK_TOKEN, Vocab 12 | 13 | 14 | class TextDataset(Dataset): 15 | """ 16 | Args: 17 | txt: list of text samples 18 | max_len: max sequence length 19 | """ 20 | 21 | def __init__(self, txt, max_len):
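# The vocab slot stays empty here on purpose: build_dataset() attaches a
# single Vocab that is shared between the source and target datasets.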
22 | self.tokenizer = self.load_tokenizer('mecab') 23 | self.dataset = self.build_dataset(txt) 24 | self.max_len = max_len 25 | self.vocab = None 26 | 27 | def __getitem__(self, index): 28 | txt = self.dataset[index] 29 | tokens = self.tokenize(txt) 30 | tokens = self.truncate(tokens) 31 | length = len(tokens) 32 | return tokens, length 33 | 34 | def __len__(self): 35 | return len(self.dataset) 36 | 37 | def build_vocab(self, vocab_size, min_freq, specials): 38 | counter = Counter() 39 | for t in self.dataset: 40 | tokens = self.tokenize(t) 41 | counter.update(tokens) 42 | vocab = Vocab.from_counter(counter=counter, 43 | vocab_size=vocab_size, 44 | min_freq=min_freq, 45 | specials=specials) 46 | return vocab 47 | 48 | def build_dataset(self, txt): 49 | txt = list(map(self.preprocess, txt)) 50 | return txt 51 | 52 | def load_tokenizer(self, which): 53 | """Loads mecab tokenizer""" 54 | m = Mecab() 55 | tokenizer = m.morphs 56 | return tokenizer 57 | 58 | def preprocess(self, text): 59 | # TODO: add preprocessing functions relevant to korean & our dataset 60 | text = text.lower() 61 | return text 62 | 63 | def truncate(self, tokens): 64 | if len(tokens) > self.max_len: 65 | return tokens[:self.max_len] 66 | else: 67 | return tokens 68 | 69 | def tokenize(self, text): 70 | """Converts text string to list of tokens""" 71 | tokens = self.tokenizer(text) 72 | return tokens 73 | 74 | 75 | class SummDataset(Dataset): 76 | """ 77 | Args: 78 | src: (TextDataset) source dataset 79 | tgt: (TextDataset) target dataset 80 | """ 81 | 82 | def __init__(self, vocab, src, tgt=None): 83 | if tgt is not None: 84 | assert len(src) == len(tgt), "Source and target must contain the same number of examples" 85 | self.src, self.tgt = src, tgt 86 | self.vocab = vocab 87 | 88 | def __getitem__(self, index): 89 | src, src_len = self.src[index] 90 | if self.tgt is not None: 91 | tgt, tgt_len = self.tgt[index] 92 | else: 93 | tgt, tgt_len = None, None 94 | return src, src_len, tgt, tgt_len 95 | 96 | def __len__(self): 97 | """Returns size of the dataset""" 98 | return len(self.src) 99 | 100 | 101 | class Batch: 102 | def __init__(self, data, vocab, max_decode): 103 | src, src_len, tgt, tgt_len = list(zip(*data)) 104 | self.vocab = vocab 105 | self.pad_id = self.vocab.pad() 106 | self.max_decode = max_decode 107 | 108 | # Encoder info 109 | self.enc_input, self.enc_len, self.enc_pad_mask = None, None, None 110 | # Additional info for pointer-generator network 111 | self.enc_input_ext, self.max_oov_len, self.src_oovs = None, None, None 112 | # Decoder info 113 | self.dec_input, self.dec_target, self.dec_len, self.dec_pad_mask = None, None, None, None 114 | 115 | # Build batch inputs 116 | self.init_encoder_seq(src, src_len) 117 | self.init_decoder_seq(tgt, tgt_len) 118 | 119 | # Save original strings 120 | self.src_text = src 121 | self.tgt_text = tgt 122 | 123 | def init_encoder_seq(self, src, src_len): 124 | src_ids = [self.vocab.tokens2ids(s) for s in src] 125 | 126 | self.enc_input = collate_tokens(values=src_ids, 127 | pad_idx=self.pad_id) 128 | self.enc_len = torch.LongTensor(src_len) 129 | self.enc_pad_mask = (self.enc_input == self.pad_id) 130 | 131 | # Save additional info for pointer-generator 132 | # Determine max number of source text OOVs in this batch 133 | src_ids_ext, oovs = zip(*[self.vocab.source2ids_ext(s) for s in src]) 134 | # Store the version of the encoder batch that uses article OOV ids 135 | self.enc_input_ext = collate_tokens(values=src_ids_ext, 136 | pad_idx=self.pad_id) 137 | self.max_oov_len 
= max([len(oov) for oov in oovs]) 138 | # Store source text OOVs themselves 139 | self.src_oovs = oovs 140 | 141 | def init_decoder_seq(self, tgt, tgt_len): 142 | tgt_ids = [self.vocab.tokens2ids(t) for t in tgt] 143 | tgt_ids_ext = [self.vocab.target2ids_ext(t, oov) for t, oov in zip(tgt, self.src_oovs)] 144 | 145 | # create decoder inputs 146 | dec_input, _ = zip(*[self.get_decoder_input_target(t, self.max_decode) for t in tgt_ids]) 147 | 148 | self.dec_input = collate_tokens(values=dec_input, 149 | pad_idx=self.pad_id, 150 | pad_to_length=self.max_decode) 151 | 152 | # create decoder targets using extended vocab 153 | _, dec_target = zip(*[self.get_decoder_input_target(t, self.max_decode) for t in tgt_ids_ext]) 154 | 155 | self.dec_target = collate_tokens(values=dec_target, 156 | pad_idx=self.pad_id, 157 | pad_to_length=self.max_decode) 158 | 159 | self.dec_len = torch.LongTensor(tgt_len) 160 | self.dec_pad_mask = (self.dec_input == self.pad_id) 161 | 162 | def get_decoder_input_target(self, tgt, max_len): 163 | dec_input = [self.vocab.start()] + tgt 164 | dec_target = tgt + [self.vocab.stop()] 165 | # truncate inputs longer than max length 166 | if len(dec_input) > max_len: 167 | dec_input = dec_input[:max_len] 168 | dec_target = dec_target[:max_len] 169 | assert len(dec_input) == len(dec_target) 170 | return dec_input, dec_target 171 | 172 | def __len__(self): 173 | return self.enc_input.size(0) 174 | 175 | def __str__(self): 176 | batch_info = { 177 | 'src_text': self.src_text, 178 | 'tgt_text': self.tgt_text, 179 | 'enc_input': self.enc_input, # [B x L] 180 | 'enc_input_ext': self.enc_input_ext, # [B x L] 181 | 'enc_len': self.enc_len, # [B] 182 | 'enc_pad_mask': self.enc_pad_mask, # [B x L] 183 | 'src_oovs': self.src_oovs, # list of length B 184 | 'max_oov_len': self.max_oov_len, # single int value 185 | 'dec_input': self.dec_input, # [B x T] 186 | 'dec_target': self.dec_target, # [B x T] 187 | 'dec_len': self.dec_len, # [B] 188 | 'dec_pad_mask': self.dec_pad_mask, # [B x T] 189 | } 190 | return str(batch_info) 191 | 192 | def to(self, device): 193 | self.enc_input = self.enc_input.to(device) 194 | self.enc_input_ext = self.enc_input_ext.to(device) 195 | self.enc_len = self.enc_len.to(device) 196 | self.enc_pad_mask = self.enc_pad_mask.to(device) 197 | self.dec_input = self.dec_input.to(device) 198 | self.dec_target = self.dec_target.to(device) 199 | self.dec_len = self.dec_len.to(device) 200 | self.dec_pad_mask = self.dec_pad_mask.to(device) 201 | return self 202 | 203 | 204 | def build_dataset(data_path, config, is_train, vocab=None, load_vocab=None): 205 | args = config.data 206 | if is_train: 207 | src_txt, tgt_txt = load_dataset(data_path) 208 | src_train = TextDataset(src_txt, args.src_max_train) 209 | tgt_train = TextDataset(tgt_txt, args.tgt_max_train) 210 | if load_vocab is not None: 211 | vocab = Vocab.from_json(load_vocab) 212 | else: 213 | vocab = src_train.build_vocab(vocab_size=args.vocab_size, 214 | min_freq=args.vocab_min_freq, 215 | specials=[PAD_TOKEN, 216 | UNK_TOKEN, 217 | START_DECODING, 218 | STOP_DECODING]) 219 | dataset = SummDataset(src=src_train, 220 | tgt=tgt_train, 221 | vocab=vocab) 222 | return dataset, vocab 223 | 224 | else: 225 | assert vocab is not None 226 | src_txt, tgt_txt = load_dataset(data_path) 227 | src_test = TextDataset(src_txt, args.src_max_test) 228 | tgt_test = TextDataset(tgt_txt, args.tgt_max_test) 229 | dataset = SummDataset(src=src_test, 230 | tgt=tgt_test, 231 | vocab=vocab) 232 | return dataset 233 | 234 | 235 | def 
build_dataloader(dataset, vocab, batch_size, max_decode, is_train, num_workers): 236 | shuffle = is_train 237 | data_loader = DataLoader(dataset, 238 | batch_size=batch_size, 239 | shuffle=shuffle, 240 | collate_fn=lambda data, v=vocab, t=max_decode: Batch(data=data, 241 | vocab=v, 242 | max_decode=t), 243 | num_workers=num_workers) 244 | return data_loader 245 | -------------------------------------------------------------------------------- /data/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pickle 3 | import torch 4 | import random 5 | import json 6 | import pandas as pd 7 | from tqdm import tqdm 8 | 9 | 10 | def load_dataset(data_dir): 11 | print(f"Loading dataset from {data_dir}") 12 | if data_dir.endswith('.json'): 13 | src, tgt = load_json(data_dir) 14 | elif data_dir.endswith('.txt'): 15 | with open(data_dir, 'r') as f: 16 | data = [l.split('\t') for l in f.readlines()] 17 | src, tgt = list(zip(*data)) 18 | elif data_dir.endswith('.csv'): 19 | src, tgt = load_csv(data_dir) 20 | else: 21 | raise ValueError(f'Unsupported data format: {data_dir}') 22 | return src, tgt 23 | 24 | 25 | def split_pkl(data_path): 26 | data_dir = os.path.dirname(data_path) 27 | with open(data_path, 'rb') as f: 28 | data = pickle.load(f) 29 | news_keys = list(data.keys()) 30 | random.shuffle(news_keys) 31 | 32 | train_keys = news_keys[150:] 33 | dev_keys = news_keys[:100] 34 | test_keys = news_keys[100:150] 35 | 36 | train_dict = {k: data[k] for k in train_keys} 37 | dev_dict = {k: data[k] for k in dev_keys} 38 | test_dict = {k: data[k] for k in test_keys} 39 | 40 | data_paths = [os.path.join(data_dir, f'nikl_{d}.pkl') for d in ['train', 'dev', 'test']] 41 | data_dicts = [train_dict, dev_dict, test_dict] 42 | 43 | for p, d in zip(data_paths, data_dicts): 44 | with open(p, 'wb') as f: 45 | pickle.dump(d, f) 46 | print(f'File saved as {p}') 47 | 48 | 49 | def load_json(data_dir): 50 | import json 51 | with open(data_dir, 'r', encoding='utf-8') as f: 52 | data = json.load(f) 53 | src = [ex['content'] for ex in data] 54 | tgt = [ex['bot_summary'] for ex in data] 55 | return src, tgt 56 | 57 | 58 | def load_csv(data_dir): 59 | sep = '\t' if data_dir.endswith('.tsv') else ',' 60 | import pandas as pd 61 | try: 62 | df = pd.read_csv(data_dir, sep=sep, 63 | header=0, encoding='utf-8') 64 | except Exception: 65 | try: 66 | sep = '\t' 67 | df = pd.read_csv(data_dir, sep=sep, 68 | header=0, encoding='utf-8') 69 | except UnicodeDecodeError: 70 | df = pd.read_csv(data_dir, sep=',', 71 | header=0, encoding='ISO-8859-1') 72 | print(df.head()) 73 | if 'abstractive' in df.columns: 74 | src = list(df['contents'].values) 75 | tgt = list(df['abstractive'].values) 76 | 77 | elif 'bot_summary' in df.columns: 78 | df['content'] = df['content'].astype(str) 79 | df['bot_summary'] = df['bot_summary'].astype(str) 80 | 81 | src = list(df['content'].values) 82 | tgt = list(df['bot_summary'].values) 83 | 84 | else: 85 | raise KeyError(f'No known source/summary columns in {list(df.columns)}') 86 | return src, tgt 87 | 88 | 89 | def collate_tokens(values, pad_idx, left_pad=False, 90 | pad_to_length=None): 91 | # Simplified version of `collate_tokens` from fairseq.data.data_utils 92 | """Convert a list of 1d tensors into a padded 2d tensor.""" 93 | values = list(map(torch.LongTensor, values)) 94 | size = max(v.size(0) for v in values) 95 | size = size if pad_to_length is None else max(size, pad_to_length) 96 | res = values[0].new(len(values), size).fill_(pad_idx) 97 | 98 | def copy_tensor(src, dst): 99 | assert dst.numel() == src.numel() 100 |
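# copy_ works in place, so res keeps its own dtype/device; with left_pad
# the caller hands in the tail slice of the row instead of the head.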
dst.copy_(src) 101 | 102 | for i, v in enumerate(values): 103 | copy_tensor(v, res[i][size - len(v):] if left_pad else res[i][:len(v)]) 104 | return res -------------------------------------------------------------------------------- /data/vocab.py: -------------------------------------------------------------------------------- 1 | # Most of this file is copied from 2 | # https://github.com/abisee/pointer-generator/blob/master/data.py 3 | # https://github.com/atulkum/pointer_summarizer/blob/master/data_util/data.py 4 | 5 | import json 6 | 7 | PAD_TOKEN = '[PAD]' # This has a vocab id, which is used to pad the encoder input, decoder input and target sequence 8 | UNK_TOKEN = '[UNK]' # This has a vocab id, which is used to represent out-of-vocabulary words 9 | START_DECODING = '[START]' # This has a vocab id, which is used at the start of every decoder input sequence 10 | STOP_DECODING = '[STOP]' # This has a vocab id, which is used at the end of untruncated target sequences 11 | 12 | 13 | class Vocab(object): 14 | """ 15 | Vocabulary class for mapping between words and ids (integers) 16 | """ 17 | 18 | def __init__(self): 19 | self._word_to_id = {} 20 | self._id_to_word = [] 21 | self._count = 0 22 | 23 | @classmethod 24 | def from_json(cls, vocab_file): 25 | vocab = cls() 26 | with open(vocab_file, 'r') as f: 27 | vocab._word_to_id = json.load(f) 28 | # sort by id ascending so that list index == word id 29 | vocab._id_to_word = [w for w, _ in sorted(vocab._word_to_id.items(), 30 | key=lambda item: item[1])] 31 | vocab._count = len(vocab._id_to_word) 32 | vocab.specials = [w for w in (PAD_TOKEN, UNK_TOKEN, START_DECODING, STOP_DECODING) if w in vocab._word_to_id] 33 | return vocab 34 | 35 | @classmethod 36 | def from_counter(cls, counter, vocab_size, specials, min_freq): 37 | vocab = cls() 38 | word_and_freq = sorted(counter.items(), key=lambda tup: tup[0]) 39 | word_and_freq.sort(key=lambda tup: tup[1], reverse=True) 40 | 41 | for w in specials: 42 | vocab._word_to_id[w] = vocab._count 43 | vocab._id_to_word += [w] 44 | vocab._count += 1 45 | 46 | for word, freq in word_and_freq: 47 | if freq < min_freq or vocab._count == vocab_size: 48 | break 49 | vocab._word_to_id[word] = vocab._count 50 | vocab._id_to_word += [word] 51 | vocab._count += 1 52 | return vocab 53 | 54 | def save(self, fpath): 55 | with open(fpath, 'w', encoding='utf-8') as f: 56 | json.dump(self._word_to_id, f, ensure_ascii=False, indent=4) 57 | print(f'vocab file saved as {fpath}') 58 | 59 | def __len__(self): 60 | """Returns size of the vocabulary.""" 61 | return self._count 62 | 63 | def word2id(self, word): 64 | """Returns the id (integer) of a word (string).
Returns [UNK] id if word is OOV.""" 65 | unk_id = self.unk() 66 | return self._word_to_id.get(word, unk_id) 67 | 68 | def id2word(self, word_id): 69 | """Returns the word (string) corresponding to an id (integer).""" 70 | if not 0 <= word_id < len(self._id_to_word): 71 | raise ValueError(f'Id not found in vocab: {word_id}') 72 | return self._id_to_word[word_id] 73 | 74 | def size(self): 75 | """Returns the total size of the vocabulary.""" 76 | return self._count 77 | 78 | def pad(self): 79 | """Helper to get index of pad symbol""" 80 | return self._word_to_id[PAD_TOKEN] 81 | 82 | def unk(self): 83 | """Helper to get index of unk symbol""" 84 | return self._word_to_id[UNK_TOKEN] 85 | 86 | def start(self): 87 | return self._word_to_id[START_DECODING] 88 | 89 | def stop(self): 90 | return self._word_to_id[STOP_DECODING] 91 | 92 | def extend(self, oovs): 93 | extended_vocab = self._id_to_word + list(oovs) 94 | return extended_vocab 95 | 96 | def tokens2ids(self, tokens): 97 | ids = [self.word2id(t) for t in tokens] 98 | return ids 99 | 100 | def source2ids_ext(self, src_tokens): 101 | """Maps source tokens to ids if in vocab, extended vocab ids if oov. 102 | 103 | Args: 104 | src_tokens: list of source text tokens 105 | 106 | Returns: 107 | ids: list of source text token ids 108 | oovs: list of oovs in source text 109 | """ 110 | ids = [] 111 | oovs = [] 112 | unk_id = self.unk() 113 | for t in src_tokens: 114 | t_id = self.word2id(t) 115 | if t_id == unk_id: 116 | if t not in oovs: 117 | oovs.append(t) 118 | ids.append(self.size() + oovs.index(t)) 119 | else: 120 | ids.append(t_id) 121 | return ids, oovs 122 | 123 | def target2ids_ext(self, tgt_tokens, oovs): 124 | """Maps target text to ids, using extended vocab (vocab + oovs). 125 | 126 | Args: 127 | tgt_tokens: list of target text tokens 128 | oovs: list of oovs from source text (copy mechanism) 129 | 130 | Returns: 131 | ids: list of target text token ids 132 | """ 133 | ids = [] 134 | unk_id = self.unk() 135 | for t in tgt_tokens: 136 | t_id = self.word2id(t) 137 | if t_id == unk_id: 138 | if t in oovs: 139 | ids.append(self.size() + oovs.index(t)) 140 | else: 141 | ids.append(unk_id) 142 | else: 143 | ids.append(t_id) 144 | return ids 145 | 146 | def outputids2words(self, ids, src_oovs): 147 | """Maps output ids to words 148 | 149 | Args: 150 | ids: list of ids 151 | src_oovs: list of oov words 152 | 153 | Returns: 154 | words: list of words mapped from ids 155 | 156 | """ 157 | words = [] 158 | extended_vocab = self.extend(src_oovs) 159 | for i in ids: 160 | try: 161 | w = self.id2word(i) # might be oov 162 | except ValueError as e: 163 | assert src_oovs is not None, "Error: model produced a word ID that isn't in the vocabulary."
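# id2word refused the id, so it must address the extended vocab:
# positions >= len(vocab) map to this example's source-text OOVs.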
164 | try: 165 | w = extended_vocab[i] 166 | except IndexError as e: 167 | raise ValueError(f'Error: model produced word ID {i} \ 168 | but this example only has {len(src_oovs)} article OOVs') 169 | words.append(w) 170 | return words 171 | 172 | -------------------------------------------------------------------------------- /decode.py: -------------------------------------------------------------------------------- 1 | import re 2 | 3 | 4 | class Hypothesis(object): 5 | def __init__(self, tokens, log_probs, hidden_state, cell_state, coverage): 6 | self.tokens = tokens 7 | self.log_probs = log_probs 8 | self.hidden_state = hidden_state 9 | self.cell_state = cell_state 10 | self.coverage = coverage 11 | 12 | def extend(self, token, log_prob, hidden_state, cell_state, coverage): 13 | return Hypothesis(tokens=self.tokens + [token], 14 | log_probs=self.log_probs + [log_prob], 15 | hidden_state=hidden_state, 16 | cell_state=cell_state, 17 | coverage=coverage) 18 | 19 | @property 20 | def latest_token(self): 21 | return self.tokens[-1] 22 | 23 | @property 24 | def avg_log_prob(self): 25 | return sum(self.log_probs) / len(self.tokens) 26 | 27 | 28 | def postprocess(tokens, bpe=False, 29 | skip_special_tokens=True, 30 | clean_up_tokenization_spaces=True): 31 | if skip_special_tokens: 32 | tokens = [t for t in tokens if not is_special(t)] 33 | if bpe: 34 | # Assumption: BPE pieces mark word starts with the SentencePiece '▁' 35 | # symbol (as in KoBERT), so word boundaries are recovered from it. 36 | out_string = ''.join(tokens).replace('▁', ' ').strip() 37 | else: 38 | out_string = ' '.join(tokens) 39 | if clean_up_tokenization_spaces: 40 | out_string = clean_up_tokenization(out_string) 41 | return out_string 42 | 43 | 44 | def is_special(token): 45 | res = re.search(r"\[[A-Z]+\]", token) 46 | if res is None: 47 | return False 48 | return token == res.group() 49 | 50 | 51 | def clean_up_tokenization(out_string): 52 | """ 53 | Reference : transformers.tokenization_utils_base 54 | Clean up a list of simple English tokenization artifacts like spaces before punctuations and abbreviated forms. 55 | 56 | Args: 57 | out_string (:obj:`str`): The text to clean up. 58 | 59 | Returns: 60 | :obj:`str`: The cleaned-up string. 61 | """ 62 | out_string = ( 63 | out_string.replace(" .", ".") 64 | .replace(" ?", "?") 65 | .replace(" !", "!") 66 | .replace(" ,", ",") 67 | .replace(" ' ", "'") 68 | .replace(" n't", "n't") 69 | .replace(" 'm", "'m") 70 | .replace(" 's", "'s") 71 | .replace(" 've", "'ve") 72 | .replace(" 're", "'re") 73 | ) 74 | return out_string -------------------------------------------------------------------------------- /models/loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class Loss(nn.Module): 7 | """ 8 | Computes nll loss (Eq. (6)), coverage loss (Eq. (12)), 9 | and the composite loss function that combines the two (Eq. (13)). 10 | """ 11 | def __init__(self, args): 12 | super().__init__() 13 | self.use_coverage = args.use_coverage 14 | self.cov_weight = args.cov_weight # hyperparameter lambda in Eq. (13) 15 | self.pad_id = args.pad_id 16 | 17 | def nll_loss(self, output, target): 18 | """ 19 | Negative log likelihood of target word - Eq.
(6) 20 | Args: 21 | output: predicted probs from each timestep [B x V_x x T] 22 | target: answer ids using extended vocab [B x T] 23 | 24 | Returns: 25 | loss: nll loss value; averaged over batch & timestep 26 | """ 27 | output = torch.log(output + 1e-12) # eps guards against log(0) on zero-prob oov slots 28 | loss = F.nll_loss(output, target, 29 | ignore_index=self.pad_id, 30 | reduction='mean') 31 | return loss 32 | 33 | def cov_loss(self, attn_dist, coverage, dec_pad_mask, dec_len): 34 | """ 35 | Coverage loss at timestep t - Eq. (12) 36 | Args: 37 | attn_dist: attention distribution from all timesteps [B x L x T] 38 | coverage: sum of previous attn dist's from all timesteps [B x L x T] 39 | dec_pad_mask: target sequence padding masks [PAD] -> True [B x T] 40 | dec_len: target sequence lengths [B] 41 | 42 | Returns: 43 | loss: coverage loss value; averaged over batch & timestep 44 | """ 45 | min_val = torch.min(attn_dist, coverage) # [B x L x T] 46 | loss = torch.sum(min_val, dim=1) # [B x T] 47 | 48 | # ignore loss from [PAD] tokens 49 | loss = loss.masked_fill_( 50 | dec_pad_mask, 51 | 0.0 52 | ) 53 | avg_loss = torch.sum(loss) / torch.sum(dec_len) 54 | return avg_loss 55 | 56 | def forward(self, output, batch): 57 | """ 58 | Eq. (13) - Composite loss 59 | Args: 60 | output: a dictionary of model outputs with the following keys 61 | - final_dist 62 | - attn_dist 63 | - coverage 64 | batch: `Batch` instance 65 | 66 | Returns: 67 | loss: final composite loss value 68 | """ 69 | final_dist = output['final_dist'] 70 | dec_target = batch.dec_target 71 | nll_loss = self.nll_loss(output=final_dist, target=dec_target) 72 | 73 | attn_dist = output['attn_dist'] 74 | coverage = output['coverage'] 75 | dec_pad_mask = batch.dec_pad_mask 76 | dec_len = batch.dec_len 77 | cov_loss = self.cov_loss(attn_dist, coverage, dec_pad_mask, dec_len) 78 | return nll_loss, cov_loss 79 | 80 | 81 | def build_criterion(config): 82 | criterion = Loss(args=config.loss.args) 83 | return criterion -------------------------------------------------------------------------------- /models/model.py: -------------------------------------------------------------------------------- 1 | import pytorch_lightning as pl 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from torch.nn.utils.rnn import pack_padded_sequence 6 | from torch.nn.utils.rnn import pad_packed_sequence 7 | from torch.optim import Adagrad 8 | 9 | from decode import Hypothesis 10 | from decode import postprocess 11 | from models.loss import Loss 12 | 13 | from rouge import Rouge 14 | 15 | """ 16 | : Models 17 | : 1) Encoder 18 | : 2) Attention module 19 | : 3) Decoder w/ attention 20 | : 4) Pointer-generator network 21 | """ 22 | 23 | 24 | class Encoder(nn.Module): 25 | """ 26 | Single-layer bidirectional LSTM 27 | B : batch size 28 | E : embedding size 29 | H : encoder hidden state dimension 30 | L : sequence length 31 | """ 32 | 33 | def __init__(self, input_dim, hidden_dim): 34 | super().__init__() 35 | self.lstm = nn.LSTM(input_size=input_dim, hidden_size=hidden_dim, 36 | num_layers=1, bidirectional=True, batch_first=True) 37 | self.reduce_h = nn.Linear(hidden_dim * 2, hidden_dim, bias=True) 38 | self.reduce_c = nn.Linear(hidden_dim * 2, hidden_dim, bias=True) 39 | 40 | def forward(self, src, src_lens): 41 | """ 42 | Args: 43 | src: source token embeddings [B x L x E] 44 | src_lens: source text length [B] 45 | 46 | Returns: 47 | enc_hidden: sequence of encoder hidden states [B x L x 2H] 48 | (final_h, final_c): Tuple for decoder state initialization [B x H] 49 | """ 50 | 51 | #
Pack the sequence into a PackedSequence object to feed to the LSTM. 52 | x = pack_padded_sequence(src, src_lens, batch_first=True, enforce_sorted=False) 53 | 54 | # Get outputs from the LSTM 55 | output, (h, c) = self.lstm(x) # [B x L x 2H], [2 x B x H], [2 x B x H] 56 | enc_hidden, _ = pad_packed_sequence(output, batch_first=True) 57 | 58 | # Concatenate bidirectional lstm states 59 | h = torch.cat((h[0], h[1]), dim=-1) # [B x 2H] 60 | c = torch.cat((c[0], c[1]), dim=-1) # [B x 2H] 61 | 62 | # Project to decoder hidden state size 63 | final_hidden = torch.relu(self.reduce_h(h)) # [B x H] 64 | final_cell = torch.relu(self.reduce_c(c)) # [B x H] 65 | 66 | return enc_hidden, (final_hidden, final_cell) 67 | 68 | 69 | class Attention(nn.Module): 70 | """ 71 | Attention mechanism based on Bahdanau et al. (2015) - Eq. (1)(2) 72 | augmented with Coverage mechanism - Eq. (11) 73 | B : batch size 74 | L : source text length 75 | H : encoder hidden state dimension 76 | """ 77 | 78 | def __init__(self, hidden_dim, use_coverage): 79 | super().__init__() 80 | # Eq. (1) 81 | self.v = nn.Linear(hidden_dim * 2, 1, bias=False) # v 82 | self.enc_proj = nn.Linear(hidden_dim * 2, hidden_dim * 2, bias=False) # W_h 83 | self.dec_proj = nn.Linear(hidden_dim, hidden_dim * 2, bias=True) # W_s, b_attn 84 | 85 | self.use_coverage = use_coverage 86 | if self.use_coverage: 87 | # Additional parameter for coverage vector; w_c in Eq. (11) 88 | self.w_c = nn.Linear(1, hidden_dim * 2, bias=False) 89 | 90 | def forward(self, dec_input, coverage, enc_hidden, enc_pad_mask): 91 | """ 92 | Args: 93 | dec_input: decoder hidden state [B x H] 94 | coverage: coverage vector [B x L] 95 | enc_hidden: encoder hidden states [B x L x 2H] 96 | enc_pad_mask: encoder padding masks [B x L] 97 | 98 | Returns: 99 | attn_dist: attention dist'n over src tokens [B x L] 100 | """ 101 | 102 | # Eq. (1) 103 | enc_feature = self.enc_proj(enc_hidden) # [B x L x 2H] 104 | dec_feature = self.dec_proj(dec_input) # [B x 2H] 105 | dec_feature = dec_feature.unsqueeze(1) # [B x 1 x 2H] 106 | scores = enc_feature + dec_feature # [B x L x 2H] 107 | 108 | if self.use_coverage: 109 | # Eq. (11) 110 | coverage = coverage.unsqueeze(-1) # [B x L x 1] 111 | cov_feature = self.w_c(coverage) # [B x L x 2H] 112 | scores = scores + cov_feature 113 | 114 | scores = torch.tanh(scores) # [B x L x 2H] 115 | scores = self.v(scores) # [B x L x 1] 116 | scores = scores.squeeze(-1) # [B x L] 117 | 118 | # Don't attend over padding; fill '-inf' where enc_pad_mask == True 119 | if enc_pad_mask is not None: 120 | scores = scores.float().masked_fill_( 121 | enc_pad_mask, 122 | float('-inf') 123 | ).type_as(scores) # FP16 support: cast to float and back 124 | 125 | # Eq. (2) 126 | attn_dist = F.softmax(scores, dim=-1) # [B x L] 127 | 128 | return attn_dist 129 | 130 | 131 | class AttnDecoder(nn.Module): 132 | """ 133 | Single-layer unidirectional LSTM with attention for a single timestep - Eq. (3)(4) 134 | B : batch size 135 | E : embedding size 136 | H : decoder hidden state dimension 137 | V : vocab size 138 | """ 139 | 140 | def __init__(self, input_dim, hidden_dim, vocab_size, use_coverage): 141 | super().__init__() 142 | self.hidden_dim = hidden_dim 143 | self.lstm = nn.LSTMCell(input_size=input_dim, hidden_size=hidden_dim) 144 | self.attention = Attention(hidden_dim, use_coverage) 145 | # Eq. 
(4) 146 | self.v = nn.Linear(hidden_dim * 3, hidden_dim, bias=True) # V, b 147 | self.v_out = nn.Linear(hidden_dim, vocab_size, bias=True) # V', b' 148 | 149 | def forward(self, dec_input, prev_h, prev_c, enc_hidden, enc_pad_mask, coverage): 150 | """ 151 | Args: 152 | dec_input: decoder input embedding at timestep t [B x E] 153 | prev_h: decoder hidden state from prev timestep [B x H] 154 | prev_c: decoder cell state from prev timestep [B x H] 155 | enc_hidden: encoder hidden states [B x L x 2H] 156 | enc_pad_mask: encoder masks for attn computation [B x L] 157 | coverage: coverage vector at timestep t - Eq. (10) [B x L] 158 | 159 | Returns: 160 | vocab_dist: predicted vocab dist'n at timestep t [B x V] 161 | attn_dist: attention dist'n at timestep t [B x L] 162 | context_vec: context vector at timestep t [B x 2H] 163 | hidden: hidden state at timestep t [B x H] 164 | cell: cell state at timestep t [B x H] 165 | """ 166 | 167 | # Get this step's decoder hidden state 168 | hidden, cell = self.lstm(dec_input, (prev_h, prev_c)) # [B x H], [B x H] 169 | 170 | # Compute attention distribution over enc states 171 | attn_dist = self.attention(dec_input=hidden, 172 | coverage=coverage, 173 | enc_hidden=enc_hidden, 174 | enc_pad_mask=enc_pad_mask) # [B x L] 175 | 176 | # Eq. (3) - Sum weighted enc hidden states to make context vector 177 | # The context vector is used later to compute generation probability 178 | context_vec = torch.bmm(attn_dist.unsqueeze(1), enc_hidden) # [B x 1 x 2H] 179 | context_vec = torch.sum(context_vec, dim=1) # [B x 2H] 180 | 181 | # Eq. (4) 182 | output = self.v(torch.cat([hidden, context_vec], dim=-1)) # [B x 3H] -> [B x H] 183 | output = self.v_out(output) # [B x V] 184 | vocab_dist = F.softmax(output, dim=-1) # [B x V] 185 | return vocab_dist, attn_dist, context_vec, hidden, cell 186 | 187 | 188 | class PointerGenerator(nn.Module): 189 | """ 190 | 2.2. Pointer-generator network 191 | - Computes generation probability p_gen - Eq. (8) 192 | - Computes prob dist'n over extended vocabulary - Eq. (9) 193 | 2.3. Coverage mechanism 194 | - Computes coverage vector for coverage loss - Eq. (10) 195 | 196 | B : batch size 197 | E : decoder embedding size 198 | H : encoder, decoder hidden state dimension 199 | L : source text length 200 | V : vocab size 201 | V_x : extended vocab size 202 | """ 203 | 204 | def __init__(self, config, vocab): 205 | super().__init__() 206 | self.config = config 207 | self.use_pretrained = config.model.use_pretrained 208 | self.vocab = vocab 209 | 210 | if self.use_pretrained: 211 | emb_vecs = self.load_embeddings(config.model.pretrained) 212 | self.embedding = nn.Embedding.from_pretrained(emb_vecs, 213 | freeze=False, 214 | padding_idx=self.vocab.pad()) 215 | embed_dim = self.embedding.embedding_dim 216 | else: 217 | embed_dim = config.model.args.embed_dim 218 | self.embedding = nn.Embedding(len(vocab), embed_dim, 219 | padding_idx=self.vocab.pad()) 220 | 221 | hidden_dim = config.model.args.hidden_dim 222 | self.encoder = Encoder(input_dim=embed_dim, 223 | hidden_dim=hidden_dim) 224 | self.decoder = AttnDecoder(input_dim=embed_dim, 225 | hidden_dim=hidden_dim, 226 | vocab_size=len(vocab), 227 | use_coverage=config.loss.args.use_coverage) 228 | 229 | # Parameters specific to PGN - Eq. 
(8) 230 | self.w_h = nn.Linear(hidden_dim * 2, 1, bias=False) 231 | self.w_s = nn.Linear(hidden_dim, 1, bias=False) 232 | self.w_x = nn.Linear(embed_dim, 1, bias=True) 233 | 234 | # Hyper-parameters used during decoding at inference time 235 | self.beam_size = config.decode.args.beam_size 236 | self.min_dec_steps = config.decode.args.min_dec_steps 237 | self.num_return_seq = config.decode.args.num_return_seq 238 | 239 | def load_embeddings(self, which='fasttext'): 240 | num_oov = 0 241 | num_in_vocab = 0 242 | emb_vecs = [] 243 | emb_size = 300 244 | import fasttext.util 245 | fasttext.util.download_model('ko', if_exists='ignore') 246 | ft = fasttext.load_model('cc.ko.300.bin') 247 | for w in self.vocab._id_to_word: 248 | if ft.get_word_id(w) == -1: # out of ft vocab 249 | w_emb = torch.rand([emb_size, 1]) 250 | nn.init.kaiming_normal_(w_emb, mode='fan_out') 251 | num_oov += 1 252 | else: 253 | w_emb = torch.tensor(ft.get_word_vector(w)) 254 | num_in_vocab += 1 255 | emb_vecs.append(w_emb) 256 | emb_vecs = list(map(lambda x: x.squeeze(), emb_vecs)) 257 | emb_vecs = torch.stack(emb_vecs) 258 | 259 | num_total = num_oov + num_in_vocab 260 | print(f"Loaded embeddings from {which}: {num_in_vocab} out of {num_total} are initialized from {which}") 261 | return emb_vecs 262 | 263 | def forward(self, enc_input, enc_input_ext, enc_pad_mask, enc_len, 264 | dec_input, max_oov_len): 265 | """ 266 | Predict summary using reference summary as decoder inputs (teacher forcing) 267 | Args: 268 | enc_input: source text id sequence [B x L] 269 | enc_input_ext: source text id seq w/ extended vocab [B x L] 270 | enc_pad_mask: source text padding mask. [PAD] -> True [B x L] 271 | enc_len: source text length [B] 272 | dec_input: target text id sequence [B x T] 273 | max_oov_len: max number of oovs in src [1] 274 | 275 | Returns: 276 | final_dists: predicted dist'n using extended vocab [B x V_x x T] 277 | attn_dists: attn dist'n from each t [B x L x T] 278 | coverages: coverage vectors from each t [B x L x T] 279 | """ 280 | 281 | # Build source text representations from encoder 282 | enc_emb = self.embedding(enc_input) # [B x L x E] 283 | enc_hidden, (h, c) = self.encoder(enc_emb, enc_len) # [B x L x 2H] 284 | 285 | # Outputs required for loss computation 286 | # 1. cross-entropy (negative log-likelihood) loss - Eq. (6) 287 | final_dists = [] 288 | 289 | # 2. coverage loss - Eq. (12) 290 | attn_dists = [] 291 | coverages = [] 292 | 293 | # Initialize decoder inputs 294 | dec_emb = self.embedding(dec_input) # [B x T x E] 295 | cov = torch.zeros_like(enc_input).float() # [B x L] 296 | 297 | for t in range(self.config.data.tgt_max_train): 298 | input_t = dec_emb[:, t, :] # Decoder input at this timestep 299 | vocab_dist, attn_dist, context_vec, h, c = self.decoder(dec_input=input_t, 300 | prev_h=h, 301 | prev_c=c, 302 | enc_hidden=enc_hidden, 303 | enc_pad_mask=enc_pad_mask, 304 | coverage=cov) 305 | # Eq. (10) - Compute coverage vector; 306 | # sum of attn dist over all prev decoder timesteps 307 | cov = cov + attn_dist 308 | 309 | # Eq. (8) - Compute generation probability p_gen 310 | context_feat = self.w_h(context_vec) # [B x 1] 311 | decoder_feat = self.w_s(h) # [B x 1] 312 | input_feat = self.w_x(input_t) # [B x 1] 313 | gen_feat = context_feat + decoder_feat + input_feat 314 | p_gen = torch.sigmoid(gen_feat) # [B x 1] 315 | 316 | # Eq. 
(9) - Compute prob dist'n over extended vocabulary 317 | vocab_dist = p_gen * vocab_dist # [B x V] 318 | weighted_attn_dist = (1.0 - p_gen) * attn_dist # [B x L] 319 | 320 | # Concat some zeros to each vocab dist, 321 | # to hold probs for oov words that appeared in source text 322 | batch_size = vocab_dist.size(0) 323 | extra_zeros = torch.zeros((batch_size, max_oov_len), 324 | device=vocab_dist.device) 325 | extended_vocab_dist = torch.cat([vocab_dist, extra_zeros], dim=-1) # [B x V_x] 326 | 327 | final_dist = extended_vocab_dist.scatter_add(dim=-1, 328 | index=enc_input_ext, 329 | src=weighted_attn_dist) 330 | # Save outputs for loss computation 331 | final_dists.append(final_dist) 332 | attn_dists.append(attn_dist) 333 | coverages.append(cov) 334 | 335 | final_dists = torch.stack(final_dists, dim=-1) # [B x V_x x T] 336 | attn_dists = torch.stack(attn_dists, dim=-1) # [B x L x T] 337 | coverages = torch.stack(coverages, dim=-1) # [B x L x T] 338 | 339 | return { 340 | 'final_dist': final_dists, 341 | 'attn_dist': attn_dists, 342 | 'coverage': coverages 343 | } 344 | 345 | def inference(self, enc_input, enc_input_ext, enc_pad_mask, enc_len, 346 | src_oovs, max_oov_len): 347 | """ 348 | Predict summary using previous timestep's decoder output as this step's decoder input + beam search 349 | Args: 350 | enc_input: source text id sequence [B x L] 351 | enc_input_ext: source text id seq w/ extended vocab [B x L] 352 | enc_pad_mask: source text padding mask. [PAD] -> True [B x L] 353 | enc_len: source text length [B] 354 | src_oovs: list of source text oovs of each sample [B] 355 | max_oov_len: max number of oovs in src [1] 356 | 357 | Returns: 358 | results: dictionary with 'generated_summary' 359 | """ 360 | # Build source text representation from encoder 361 | enc_emb = self.embedding(enc_input) # [B x L x E] 362 | enc_hidden, (h, c) = self.encoder(enc_emb, enc_len) # [B x L x 2H] 363 | 364 | # Initialize decoder input 365 | cov = torch.zeros_like(enc_input).float() # [B x L] 366 | 367 | # Initialize hypotheses 368 | batch_size = enc_input.size(0) 369 | # All samples start with a single hypothesis ([START]) 370 | hyps = [ 371 | Hypothesis(tokens=[self.vocab.start()], 372 | log_probs=[0.0], 373 | hidden_state=h, 374 | cell_state=c, 375 | coverage=cov) 376 | for _ in range(batch_size) 377 | ] 378 | results = [] # finished hypotheses (those that have emitted the [STOP] token) 379 | 380 | for steps in range(self.config.data.tgt_max_test): 381 | # Prepare decoder inputs (= previously generated tokens) for this step 382 | # K : number of hypotheses (we want top-K outputs) 383 | dec_input = [self.filter_unk(hyp.latest_token) for hyp in hyps] 384 | dec_input = torch.tensor(dec_input, 385 | dtype=torch.long, 386 | device=enc_input.device) # [K] 387 | dec_emb = self.embedding(dec_input) # [K x E] 388 | h = torch.cat([hyp.hidden_state for hyp in hyps], dim=0) # [1 x H] -> [K x H] 389 | c = torch.cat([hyp.cell_state for hyp in hyps], dim=0) # [1 x H] -> [K x H] 390 | coverages = torch.cat([hyp.coverage for hyp in hyps], dim=0) # [1 x L] -> [K x L] 391 | enc_hiddens = torch.cat([enc_hidden for _ in hyps], dim=0) # [1 x L x 2H] -> [K x L x 2H] 392 | enc_pad_masks = torch.cat([enc_pad_mask for _ in hyps], dim=0) 393 | 394 | vocab_dist, attn_dist, context_vec, h, c = self.decoder(dec_input=dec_emb, 395 | prev_h=h, prev_c=c, 396 | enc_hidden=enc_hiddens, 397 | enc_pad_mask=enc_pad_masks, 398 | coverage=coverages) 399 | 400 | # Eq. 
(10) - Compute coverage vector; 401 | # sum of attn dist over all prev decoder timesteps 402 | cov = coverages + attn_dist # accumulate each hypothesis's running coverage [K x L] 403 | 404 | # Eq. (8) - Compute generation probability p_gen 405 | context_feat = self.w_h(context_vec) # [K x 1] 406 | decoder_feat = self.w_s(h) # [K x 1] 407 | input_feat = self.w_x(dec_emb) # [K x 1] 408 | gen_feat = context_feat + decoder_feat + input_feat 409 | p_gen = torch.sigmoid(gen_feat) # [K x 1] 410 | 411 | # Eq. (9) - Compute prob dist'n over extended vocabulary 412 | vocab_dist = p_gen * vocab_dist # [K x V] 413 | weighted_attn_dist = (1.0 - p_gen) * attn_dist # [K x L] 414 | 415 | # Concat some zeros to each vocab dist, 416 | # to hold probs for oov words that appeared in source text 417 | batch_size = vocab_dist.size(0) 418 | extra_zeros = torch.zeros((batch_size, max_oov_len), 419 | device=vocab_dist.device) 420 | extended_vocab_dist = torch.cat([vocab_dist, extra_zeros], dim=-1) # [K x V_x] 421 | final_dist = extended_vocab_dist.scatter_add(dim=-1, 422 | index=enc_input_ext.expand_as(weighted_attn_dist), # [1 x L] -> [K x L] 423 | src=weighted_attn_dist) 424 | 425 | # Find top-2k most probable token ids and update hypotheses 426 | log_probs = torch.log(final_dist) 427 | topk_probs, topk_ids = torch.topk(log_probs, 428 | k=self.beam_size * 2, 429 | dim=-1) 430 | 431 | all_hyps = [] 432 | num_orig_hyps = 1 if steps == 0 else len(hyps) 433 | for i in range(num_orig_hyps): 434 | h_i = hyps[i] 435 | hidden_state_i = h[i].unsqueeze(0) 436 | cell_state_i = c[i].unsqueeze(0) 437 | coverage_i = cov[i].unsqueeze(0) 438 | 439 | for j in range(self.beam_size * 2): 440 | # Update existing hypothesis with predicted token 441 | if topk_ids[i, j].item() == self.vocab.unk(): 442 | pass 443 | else: 444 | new_hyp = h_i.extend(token=topk_ids[i, j].item(), 445 | log_prob=topk_probs[i, j].item(), 446 | hidden_state=hidden_state_i, 447 | cell_state=cell_state_i, 448 | coverage=coverage_i) 449 | all_hyps.append(new_hyp) 450 | 451 | # Find k most probable hypotheses among 2k candidates 452 | hyps = [] 453 | for hyp in self.sort_hyps(all_hyps): 454 | if hyp.latest_token == self.vocab.stop(): 455 | if steps >= self.min_dec_steps: 456 | results.append(hyp) 457 | else: 458 | # save for next step 459 | hyps.append(hyp) 460 | if len(hyps) == self.beam_size or len(results) == self.beam_size: 461 | break 462 | 463 | if len(results) == self.beam_size: 464 | break 465 | 466 | # Reached max decode steps but not enough results 467 | if len(results) < self.num_return_seq: 468 | results = results + hyps[:self.num_return_seq - len(results)] 469 | 470 | sorted_results = self.sort_hyps(results) 471 | best_hyps = sorted_results[:self.num_return_seq] 472 | 473 | # Map token ids to words 474 | hyp_words = [self.vocab.outputids2words(hyp.tokens, src_oovs[0]) for hyp in best_hyps] 475 | 476 | # Concatenate words to strings 477 | if self.config.model.use_pretrained and self.config.model.pretrained == 'kobert': 478 | bpe = True 479 | else: 480 | bpe = False 481 | hyp_results = [postprocess(words, bpe=bpe, 482 | skip_special_tokens=True, 483 | clean_up_tokenization_spaces=True) 484 | for words in hyp_words] 485 | results = {'generated_summary': hyp_results} 486 | return results 487 | 488 | def sort_hyps(self, hyps): 489 | """Sort hypotheses according to their log probability.""" 490 | return sorted(hyps, key=lambda h: h.avg_log_prob, reverse=True) 491 | 492 | def filter_unk(self, idx): 493 | return idx if idx < self.vocab.size() else self.vocab.unk() 494 | 495 | 496 | class SummarizationModel(pl.LightningModule): 497 | def __init__(self, config, vocab): 498 |
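"""Lightning wrapper that wires together the pointer-generator model, the composite loss (Eq. (13)), and ROUGE-based validation."""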
498 |         super().__init__()
499 |         self.config = config
500 |         self.vocab = vocab
501 |         self.model = PointerGenerator(config, vocab)
502 |         self.criterion = Loss(args=config.loss.args)
503 |         self.num_step = 0
504 |         self.cov_weight = config.loss.args.cov_weight
505 |         self.rouge = Rouge()
506 | 
507 |     def training_step(self, batch, batch_idx):
508 |         output = self.model.forward(enc_input=batch.enc_input,
509 |                                     enc_input_ext=batch.enc_input_ext,
510 |                                     enc_pad_mask=batch.enc_pad_mask,
511 |                                     enc_len=batch.enc_len,
512 |                                     dec_input=batch.dec_input,
513 |                                     max_oov_len=batch.max_oov_len)
514 | 
515 |         nll_loss, cov_loss = self.criterion(output=output,
516 |                                             batch=batch)
517 |         loss = nll_loss + self.cov_weight * cov_loss
518 |         # self.log('train_loss', loss, on_step=True, on_epoch=False, prog_bar=True, logger=True)
519 |         self.logger.log_metrics({'train_loss': loss,
520 |                                  'train/nll_loss': nll_loss}, self.num_step)
521 |         if self.cov_weight > 0:
522 |             self.logger.log_metrics({'train/cov_loss': cov_loss}, self.num_step)
523 |         self.num_step += 1
524 |         return loss
525 | 
526 |     def validation_step(self, batch, batch_idx):
527 |         output = self.model.forward(enc_input=batch.enc_input,
528 |                                     enc_input_ext=batch.enc_input_ext,
529 |                                     enc_pad_mask=batch.enc_pad_mask,
530 |                                     enc_len=batch.enc_len,
531 |                                     dec_input=batch.dec_input,
532 |                                     max_oov_len=batch.max_oov_len)
533 |         nll_loss, cov_loss = self.criterion(output=output,
534 |                                             batch=batch)
535 |         loss = nll_loss + self.cov_weight * cov_loss
536 |         self.log('val_loss', loss, on_step=True, on_epoch=False, prog_bar=False, logger=True)
537 |         self.logger.log_metrics({'val_loss': loss}, self.num_step)
538 | 
539 |         result = self.test_step(batch, batch_idx)
540 |         scores = self.rouge.get_scores(result['generated_summary'],
541 |                                        result['gold_summary'], avg=True)
542 |         rouge_1 = scores['rouge-1']['f'] * 100.0
543 |         rouge_2 = scores['rouge-2']['f'] * 100.0
544 |         rouge_l = scores['rouge-l']['f'] * 100.0
545 |         pred = {
546 |             'rouge_1': rouge_1,
547 |             'rouge_2': rouge_2,
548 |             'rouge_l': rouge_l,
549 |             'val_loss': loss
550 |         }
551 |         return pred
552 | 
553 |     def validation_epoch_end(self, validation_step_outputs):
554 |         rouge_1, rouge_2, rouge_l, val_loss = [], [], [], []
555 |         for pred in validation_step_outputs:
556 |             rouge_1.append(pred['rouge_1'])
557 |             rouge_2.append(pred['rouge_2'])
558 |             rouge_l.append(pred['rouge_l'])
559 |             val_loss.append(pred['val_loss'])
560 |         rouge_1_avg = sum(rouge_1) / len(rouge_1)
561 |         rouge_2_avg = sum(rouge_2) / len(rouge_2)
562 |         rouge_l_avg = sum(rouge_l) / len(rouge_l)
563 |         val_loss_avg = sum(val_loss) / len(val_loss)
564 |         results = {
565 |             'rouge_1_avg': rouge_1_avg,
566 |             'rouge_2_avg': rouge_2_avg,
567 |             'rouge_l_avg': rouge_l_avg,
568 |             'val_loss_avg': val_loss_avg,
569 |         }
570 |         return results
571 | 
572 |     def test_step(self, batch, batch_idx):
573 |         result = self.model.inference(enc_input=batch.enc_input,
574 |                                       enc_input_ext=batch.enc_input_ext,
575 |                                       enc_pad_mask=batch.enc_pad_mask,
576 |                                       enc_len=batch.enc_len,
577 |                                       src_oovs=batch.src_oovs,
578 |                                       max_oov_len=batch.max_oov_len)
579 |         result['source'] = [''.join(w) for w in batch.src_text]
580 |         result['gold_summary'] = [''.join(w) for w in batch.tgt_text]
581 |         return result
582 | 
583 |     def configure_optimizers(self):
584 |         args = self.config.optimizer.args
585 |         lr = args.lr
586 |         lr_init_accum = args.lr_init_accum
587 |         params = self.parameters()
588 |         optimizer = Adagrad(
589 |             params=params,
590 |             lr=lr,
591 |             initial_accumulator_value=lr_init_accum,
592 |         )
593 |         return optimizer
594 | 
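Note: the beam-search loop in `inference` above relies on a `Hypothesis` container whose definition does not appear in this listing. As a minimal sketch, assuming only the attributes and methods the code actually touches (`tokens`, `log_probs`, `hidden_state`, `cell_state`, `coverage`, `latest_token`, `avg_log_prob`, `extend`), it could look like the following; the repo's real class may differ in details:

class Hypothesis:
    """One partial summary on the beam (sketch; not the repo's actual code)."""

    def __init__(self, tokens, log_probs, hidden_state, cell_state, coverage):
        self.tokens = tokens              # token ids decoded so far
        self.log_probs = log_probs        # log prob of each decoded token
        self.hidden_state = hidden_state  # decoder hidden state h [1 x H]
        self.cell_state = cell_state      # decoder cell state c [1 x H]
        self.coverage = coverage          # coverage vector [1 x L]

    @property
    def latest_token(self):
        return self.tokens[-1]

    @property
    def avg_log_prob(self):
        # Length-normalized score used by sort_hyps, so longer
        # summaries are not penalized just for having more terms
        return sum(self.log_probs) / len(self.tokens)

    def extend(self, token, log_prob, hidden_state, cell_state, coverage):
        # Return a NEW hypothesis grown by one decoded token; the parent
        # is left untouched so it can spawn up to 2k sibling candidates
        return Hypothesis(tokens=self.tokens + [token],
                          log_probs=self.log_probs + [log_prob],
                          hidden_state=hidden_state,
                          cell_state=cell_state,
                          coverage=coverage)

Returning a fresh object from `extend` matters: the 2k candidate extensions spawned from one parent must not share a mutable token list, or pruning the beam would corrupt surviving hypotheses.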
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | gluonnlp==0.9.1
2 | fasttext==0.9.2
3 | pytorch_lightning==1.0.3
4 | scipy==1.5.1
5 | pybind11==2.5.0
6 | pyrouge==0.1.3
7 | easydict==1.9
8 | gensim==3.8.3
9 | konlpy==0.5.2
10 | requests==2.24.0
11 | setproctitle==1.1.10
12 | torch==1.5.1+cu101
13 | tqdm==4.47.0
14 | mxnet==1.6.0
15 | pandas==1.0.5
16 | numpy==1.18.5
17 | unicode==2.7
18 | 
--------------------------------------------------------------------------------
/test.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import sys
3 | 
4 | import pytorch_lightning as pl
5 | 
6 | from data.data import Vocab
7 | from data.data import build_dataloader
8 | from data.data import build_dataset
9 | from eval_rouge import report_rouge
10 | from models.model import SummarizationModel
11 | from utils import *
12 | 
13 | 
14 | def main(config):
15 |     # load vocab
16 |     log_path, _ = os.path.split(config.model_path)
17 |     vocab = Vocab.from_json(os.path.join(log_path, config.path.vocab))
18 | 
19 |     print(f"Vocab size : {len(vocab)}")
20 | 
21 |     test_data = build_dataset(
22 |         data_path=config.path.test,
23 |         config=config,
24 |         is_train=False,
25 |         vocab=vocab
26 |     )
27 | 
28 |     test_loader = build_dataloader(
29 |         dataset=test_data,
30 |         vocab=vocab,
31 |         batch_size=1,
32 |         max_decode=config.data.tgt_max_test,
33 |         is_train=False,
34 |         num_workers=config.data_loader.num_workers,
35 |     )
36 |     # config.model.use_pretrained = False
37 |     model = SummarizationModel.load_from_checkpoint(
38 |         checkpoint_path=config.model_path,
39 |         config=config,
40 |         vocab=vocab,
41 |     )
42 |     model.freeze()
43 |     model.eval()
44 | 
45 |     if config.device == -1:
46 |         gpus = None
47 |     else:
48 |         gpus = [config.device]
49 | 
50 |     trainer = pl.Trainer(
51 |         gpus=gpus,
52 |     )
53 | 
54 |     test_outputs = trainer.test(
55 |         model=model,
56 |         test_dataloaders=test_loader,
57 |         ckpt_path=config.model_path,
58 |         verbose=False
59 |     )
60 | 
61 |     output_name = generate_output_name(config)
62 |     output_fname = os.path.join(log_path, output_name)
63 |     write_output(test_loader=test_loader,
64 |                  test_outputs=test_outputs,
65 |                  fname=output_fname)
66 |     report_rouge(output_fname)
67 | 
68 | 
69 | if __name__ == "__main__":
70 |     args = argparse.ArgumentParser(description='Pointer-generator network')
71 |     args.add_argument(
72 |         '-cp', '--config-path',
73 |         default='config.json',
74 |         type=str,
75 |         help='path to config file'
76 |     )
77 | 
78 |     args.add_argument(
79 |         '-m', '--model-path',
80 |         default=None,
81 |         type=str,
82 |         help='path to load model'
83 |     )
84 | 
85 |     args.add_argument(
86 |         '-d', '--device',
87 |         default=0,
88 |         type=int,
89 |     )
90 | 
91 |     args.add_argument(
92 |         '-n', '--note',
93 |         default='',
94 |         type=str,
95 |         help='note to append to result output file name'
96 |     )
97 | 
98 |     sys.path.append(
99 |         os.path.dirname(os.path.abspath(os.path.dirname("__file__")))
100 |     )
101 | 
102 |     config = config_parser(args.parse_args())
103 |     main(config)
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import sys
4 | 
5 | import pytorch_lightning as pl
6 | from pytorch_lightning import loggers as pl_loggers
7 | from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
8 | from data.data import Vocab
9 | from data.data import build_dataloader
10 | from data.data import build_dataset
11 | from models.model import SummarizationModel
12 | from utils import *
13 | 
14 | 
15 | def main(config):
16 |     # fix random seeds for reproducibility
17 |     SEED = 123
18 |     pl.seed_everything(SEED)
19 | 
20 |     # generate experiment name
21 |     exp_name = generate_exp_name(config)
22 | 
23 |     # Train
24 |     # 1. Load dataset, build or load vocab file
25 |     # use a pre-built vocab file if one exists
26 |     if os.path.exists(config.path.vocab) and config.load_vocab:
27 |         load_vocab = config.path.vocab
28 |     else:
29 |         load_vocab = None
30 | 
31 |     # 1-1. Training dataset
32 |     # If validation set is not specified, take 10% of training data as validation set
33 |     split_dev = 0.1 if config.path.val == '' else 0.0
34 |     (train_data, val_data), vocab = build_dataset(
35 |         data_path=config.path.train,
36 |         load_vocab=load_vocab,
37 |         config=config,
38 |         is_train=True,
39 |         split_dev=split_dev,
40 |     )
41 | 
42 |     os.makedirs(f'logs/{exp_name}', exist_ok=True)  # save vocab file with the experiment logs
43 |     vocab.save(os.path.join(f'logs/{exp_name}', config.path.vocab))
44 | 
45 |     train_loader = build_dataloader(
46 |         dataset=train_data,
47 |         vocab=vocab,
48 |         batch_size=config.data_loader.batch_size.train,
49 |         max_decode=config.data.tgt_max_train,
50 |         is_train=True,
51 |         num_workers=config.data_loader.num_workers,
52 |     )
53 | 
54 |     # 1-2. Load validation dataset
55 |     # For validation, we use NIKL
56 |     if val_data is None:
57 |         val_data = build_dataset(
58 |             data_path=config.path.val,
59 |             config=config,
60 |             is_train=False,
61 |             vocab=vocab
62 |         )
63 | 
64 |     val_loader = build_dataloader(
65 |         dataset=val_data,
66 |         vocab=vocab,
67 |         batch_size=config.data_loader.batch_size.val,
68 |         max_decode=config.data.tgt_max_test,
69 |         is_train=False,
70 |         num_workers=config.data_loader.num_workers,
71 |     )
72 | 
73 |     # 2. Build model instance
74 | 
75 |     model = SummarizationModel(
76 |         config=config,
77 |         vocab=vocab,
78 |     )
79 | 
80 |     # 3. Set logger, trainer
81 |     tb_logger = pl_loggers.TensorBoardLogger(
82 |         save_dir='logs/',
83 |         name=exp_name,
84 |     )
85 |     tb_logger.log_hyperparams(config)
86 | 
87 |     if config.stop_with == 'loss':
88 |         monitor = 'val_loss_avg'
89 |         mode = 'min'
90 |     else:
91 |         which_rouge = config.stop_with[-1]  # 1, 2 or l
92 |         monitor = f'rouge_{which_rouge}_avg'
93 |         mode = 'max'
94 | 
95 |     filepath = '{epoch}-{' + monitor + ':.2f}'
96 |     checkpoint_callback = ModelCheckpoint(
97 |         filepath=os.path.join(f'logs/{exp_name}', filepath),
98 |         verbose=False,
99 |         monitor=monitor,
100 |         mode=mode,
101 |         save_top_k=5,
102 |     )
103 | 
104 |     # early stopping
105 |     early_stop_callback = EarlyStopping(
106 |         monitor=monitor,
107 |         min_delta=0.00,
108 |         patience=10,
109 |         verbose=True,
110 |         mode=mode,
111 |     )
112 | 
113 |     if config.device == -1:
114 |         gpus = None
115 |     else:
116 |         gpus = [config.device]
117 | 
118 |     trainer = pl.Trainer(
119 |         logger=tb_logger,
120 |         callbacks=[early_stop_callback],
121 |         gpus=gpus,
122 |         resume_from_checkpoint=config.model_path,
123 |         max_epochs=config.trainer.epochs,
124 |         checkpoint_callback=checkpoint_callback,
125 |         gradient_clip_val=config.trainer.max_grad_norm,
126 |         log_every_n_steps=500,
127 |     )
128 | 
129 |     # 4. Train!
130 |     trainer.fit(model, train_loader, val_loader)
131 | 
132 |     # 5. Evaluation
133 |     test_data = build_dataset(
134 |         data_path=config.path.test,
135 |         config=config,
136 |         is_train=False,
137 |         vocab=vocab
138 |     )
139 | 
140 |     test_loader = build_dataloader(
141 |         dataset=test_data,
142 |         vocab=vocab,
143 |         batch_size=1,
144 |         max_decode=config.data.tgt_max_test,
145 |         is_train=False,
146 |         num_workers=config.data_loader.num_workers,
147 |     )
148 | 
149 |     test_outputs = trainer.test(model, test_loader)
150 |     output_name = generate_output_name(config)
151 |     write_output(test_loader=test_loader,
152 |                  test_outputs=test_outputs,
153 |                  fname=os.path.join(f'logs/{exp_name}', output_name))
154 | 
155 | 
156 | if __name__ == "__main__":
157 |     args = argparse.ArgumentParser(description='Pointer-generator network')
158 |     args.add_argument(
159 |         '-cp', '--config-path',
160 |         default='config.json',
161 |         type=str,
162 |         help='path to config file'
163 |     )
164 | 
165 |     args.add_argument(
166 |         '--mds',
167 |         default=None,
168 |         type=str,
169 |         help='Multi-News labeling method to employ; if None, the NIKL dataset is used.'
170 |     )
171 | 
172 |     args.add_argument(
173 |         '-m', '--model-path',
174 |         default=None,
175 |         type=str,
176 |         help='path to load model'
177 |     )
178 | 
179 |     args.add_argument(
180 |         '--load-vocab',
181 |         action='store_true',
182 |         default=False,
183 |         help='whether to load a pre-built vocab file'
184 |     )
185 | 
186 |     args.add_argument(
187 |         '--stop-with',
188 |         default='rl',
189 |         type=str,
190 |         choices=['loss', 'r1', 'r2', 'rl'],
191 |         help='validation metric used for early stopping'
192 |     )
193 | 
194 |     args.add_argument(
195 |         '-e', '--exp-name',
196 |         default='',
197 |         type=str,
198 |         help='suffix to specify experiment name'
199 |     )
200 | 
201 |     args.add_argument(
202 |         '-d', '--device',
203 |         default=-1,
204 |         type=int,
205 |         help='GPU device number to use; set to -1 to run on CPU'
206 |     )
207 | 
208 |     args.add_argument(
209 |         '-n', '--note',
210 |         default='',
211 |         type=str,
212 |         help='note to append to result output file name'
213 |     )
214 | 
215 |     sys.path.append(
216 |         os.path.dirname(os.path.abspath(os.path.dirname("__file__")))
217 |     )
218 | 
219 |     config = config_parser(args.parse_args())
220 |     main(config)
221 | 
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | from datetime import datetime
4 | 
5 | import easydict
6 | import torch
7 | from setproctitle import setproctitle
8 | 
9 | 
10 | def config_parser(args=None, interpreter=None):
11 |     if interpreter:
12 |         with open(interpreter, 'rb') as f:
13 |             config = dict(json.load(f))
14 |         config = easydict.EasyDict({**config})
15 |     else:
16 |         with open(args.config_path, 'rb') as f:
17 |             config = dict(json.load(f))
18 |         config = easydict.EasyDict({**config, **vars(args)})
19 |     # CLI flags (if any) take precedence over config file values
20 |     if args is not None:
21 |         for k, v in vars(args).items():
22 |             config[k] = v
23 |     return config
24 | 
25 | 
26 | def write_output(test_loader, test_outputs, fname):
27 |     save_dir, _ = os.path.split(fname)
28 |     if not os.path.exists(save_dir):
29 |         os.makedirs(save_dir)  # create intermediate directories if needed
30 |     header = ['id', 'src', 'tgt', 'pred']
31 |     with open(fname, 'w') as f:
32 |         print('\t'.join(header), file=f)
33 |         for batch_idx, batch in enumerate(test_loader):
34 |             top_k_outputs = test_outputs[batch_idx]['generated_summary']
35 |             for k, output in enumerate(top_k_outputs, 1):
36 |                 data = [str(batch_idx) + '_' + str(k), batch.src_text[0], batch.tgt_text[0], output]  # batch size is 1 at test time
37 |                 print('\t'.join(data), file=f)
38 |     print(f"Predicted outputs saved as {fname}")
39 | 
40 | 
41 | def generate_exp_name(config):
42 |     exp_name = datetime.now().strftime("%m-%d-%H:%M:%S")
43 |     if len(config.exp_name):
44 |         exp_name = exp_name + f'-{config.exp_name}'
45 |     # set process title to exp name
46 |     setproctitle(exp_name)
47 |     print(f'Experiment results saved in logs/{exp_name}')
48 |     return exp_name
49 | 
50 | 
51 | def generate_output_name(config):
52 |     if config.model_path is not None:
53 |         basename = os.path.basename(config.model_path)
54 |         model_name, _ = os.path.splitext(basename)
55 |     else:
56 |         model_name = 'best'
57 |     if len(config.note):
58 |         model_name += f'-{config.note}'
59 |     basename = os.path.basename(config.path.test)
60 |     filename, _ = os.path.splitext(basename)
61 |     output_name = f'pred-{filename}-{model_name}.tsv'
62 |     return output_name
--------------------------------------------------------------------------------
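A note on configuration: config.json is not reproduced in this listing, but the attribute paths the scripts access (config.path.*, config.data.*, config.data_loader.*, config.loss.args.cov_weight, config.optimizer.args.*, config.trainer.*, config.model.*) imply its shape. The following is an illustrative reconstruction only; the key names are taken from the code above, while every value (including the file paths) is a placeholder rather than the repo's actual default:

import easydict

# Illustrative shape of the parsed config; all values are placeholders.
config = easydict.EasyDict({
    'path': {'train': 'data/train.jsonl',   # placeholder paths
             'val': '',                     # '' => split 10% of train for validation
             'test': 'data/test.jsonl',
             'vocab': 'vocab.json'},
    'data': {'tgt_max_train': 100, 'tgt_max_test': 120},
    'data_loader': {'batch_size': {'train': 16, 'val': 16},
                    'num_workers': 4},
    'loss': {'args': {'cov_weight': 1.0}},
    'optimizer': {'args': {'lr': 0.15, 'lr_init_accum': 0.1}},
    'trainer': {'epochs': 50, 'max_grad_norm': 2.0},
    'model': {'use_pretrained': False, 'pretrained': ''},
})

print(config.data.tgt_max_test)  # EasyDict allows attribute-style access

Fields supplied on the command line (model_path, device, load_vocab, stop_with, exp_name, note) are merged on top of the file by config_parser, so they need not appear in the JSON itself.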