├── .gitignore ├── README.md ├── requirements.txt ├── run_experiments.sh ├── setup.py ├── src ├── abstractive │ ├── attn.py │ ├── beam.py │ ├── beam_search.py │ ├── cal_rouge.py │ ├── data_loader.py │ ├── decode_strategy.py │ ├── loss.py │ ├── model_builder.py │ ├── my_pyrouge.py │ ├── neural.py │ ├── optimizer.py │ ├── penalties.py │ ├── predictor_builder.py │ ├── trainer_builder.py │ ├── transformer_decoder.py │ └── transformer_encoder.py ├── others │ ├── __init__.py │ ├── distributed.py │ ├── logging.py │ ├── report_manager.py │ └── statistics.py └── train_abstractive.py └── test.sh /.gitignore: -------------------------------------------------------------------------------- 1 | # ignore following directories 2 | pyrouge 3 | results.* 4 | logs/ 5 | /data/ 6 | results/ 7 | models/ 8 | 9 | # Byte-compiled / optimized / DLL files 10 | __pycache__/ 11 | *.py[cod] 12 | *$py.class 13 | # C extensions 14 | *.so 15 | 16 | # Distribution / packaging 17 | .Python 18 | env/ 19 | build/ 20 | develop-eggs/ 21 | dist/ 22 | downloads/ 23 | eggs/ 24 | .eggs/ 25 | lib/ 26 | lib64/ 27 | parts/ 28 | sdist/ 29 | var/ 30 | *.egg-info/ 31 | .installed.cfg 32 | *.egg 33 | 34 | # PyInstaller 35 | # Usually these files are written by a python script from a template 36 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 37 | *.manifest 38 | *.spec 39 | 40 | # Installer logs 41 | pip-log.txt 42 | pip-delete-this-directory.txt 43 | 44 | # Unit test / coverage reports 45 | htmlcov/ 46 | .tox/ 47 | .coverage 48 | .coverage.* 49 | .cache 50 | nosetests.xml 51 | coverage.xml 52 | *,cover 53 | .hypothesis/ 54 | 55 | # Translations 56 | *.mo 57 | *.pot 58 | 59 | # Django stuff: 60 | *.log 61 | local_settings.py 62 | 63 | # Flask stuff: 64 | instance/ 65 | .webassets-cache 66 | 67 | # Scrapy stuff: 68 | .scrapy 69 | 70 | # Sphinx documentation 71 | docs/_build/ 72 | 73 | # PyBuilder 74 | target/ 75 | 76 | # IPython Notebook 77 | .ipynb_checkpoints 78 | 79 | # pyenv 80 | .python-version 81 | 82 | # celery beat schedule file 83 | celerybeat-schedule 84 | 85 | # dotenv 86 | .env 87 | 88 | # virtualenv 89 | venv/ 90 | ENV/ 91 | 92 | # Spyder project settings 93 | .spyderproject 94 | 95 | # Rope project settings 96 | .ropeproject 97 | 98 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Data Augmentation for Abstractive Query-Focused Multi-Document Summarization (AAAI 2021) 2 | 3 | This is the implementation of the paper [Data Augmentation for Abstractive Query-Focused Multi-Document Summarization](https://arxiv.org/pdf/2103.01863.pdf). 4 | 5 | ## Prerequisites 6 | 7 | - Python 3.6+ 8 | - [PyTorch 1.0](http://pytorch.org/) 9 | - Install all the required packages from the requirements.txt file: 10 | ``` 11 | pip install -r requirements.txt 12 | ``` 13 | - Download the processed datasets (in PyTorch format) and set up some folders and repos by running the following command: 14 | ``` 15 | python setup.py 16 | ``` 17 | If you face any issues downloading the datasets with the above code (setup.py), directly download the datasets from here: [wikisum](https://drive.google.com/uc?id=1AnqeUpLkO9MR3PH0V8q32A6PEPDEZ0td), [wikisum-query](https://drive.google.com/uc?id=1RdX-t3pznnyaGyrswFubAfoo9S9w9K5d), [qmds-cnn](https://drive.google.com/uc?id=1KXsvfnK6s6cnYQzD8ZOkXPdA6r5-quPK), [qmds-cnn-query](https://drive.google.com/uc?id=12i_3dikeJLsOj-SQGPmc4w9Is7fB-hT-).
Run the above command with the following argument to set up everything else except the datasets. 18 | ``` 19 | python setup.py --ignore_datasets 20 | ``` 21 | - If you face any issues with running ROUGE evaluation, check out this [link](https://poojithansl7.wordpress.com/2018/08/04/setting-up-rouge/). 22 | 23 | - Some of the code is borrowed from [hiersumm](https://github.com/nlpyang/hiersumm) and [ONMT](https://github.com/OpenNMT/OpenNMT-py). 24 | 25 | 26 | ## Usage 27 | 28 | To train the model: 29 | ``` 30 | DATASET=[CNNDM/WIKI] MODEL_TYPE=[hier/he/order/query/heq/heo/hero] bash run_experiments.sh 31 | ``` 32 | To test the model: 33 | ``` 34 | DATASET=[CNNDM/WIKI] MODEL_TYPE=[hier/he/order/query/heq/heo/hero] bash test.sh 35 | ``` 36 | 37 | A few points to note: 38 | - Various model types (`MODEL_TYPE`): 39 | - hier: Baseline model (Hierarchical Transformers) 40 | - he: HS w/ Hierarchical Encodings 41 | - order: HS w/ Ordering Component 42 | - query: HS w/ Query Encoding 43 | - heq: HS-Joint Model (Hierarchical Encodings + Query Encoding) 44 | - heo: HS-Joint Model (Hierarchical Encodings + Ordering Component) 45 | - hero: HS-Joint Model (all three components combined) 46 | - We tested our models on NVIDIA P100 16GB GPUs. Each experiment uses 4 GPUs. If you have fewer GPUs or less memory, set `BATCH_SIZE`, `VISIBLE_GPUS`, and `ACCUM_COUNT` accordingly. 47 | - Data, vocab, and model paths are set to default locations; set the `DATA_PATH`, `VOCAB_PATH`, and `MODEL_PATH` environment variables if you want to use different paths. 48 | 49 | ## Reference 50 | 51 | If you find this code helpful, please consider citing the following paper: 52 | 53 | @inproceedings{pasunuru2021data, 54 | title={Data Augmentation for Abstractive Query-Focused Multi-Document Summarization}, 55 | author={Pasunuru, Ramakanth and Celikyilmaz, Asli and Galley, Michel and Xiong, Chenyan and Zhang, Yizhe and Bansal, Mohit and Gao, Jianfeng}, 56 | booktitle={AAAI}, 57 | year={2021} 58 | } 59 | 60 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch==1.0.0 2 | sentencepiece==0.1.83 3 | numpy==1.17.1 4 | tensorboardX==1.8 5 | nltk==3.4.5 6 | gdown 7 | GitPython 8 | pyrouge==0.1.3 9 | rouge==0.3.2 10 | -------------------------------------------------------------------------------- /run_experiments.sh: -------------------------------------------------------------------------------- 1 | # Required environment variables: 2 | # DATASET: choose the dataset (CNNDM / WIKI) 3 | # MODEL_TYPE: choose the type of model (hier, he, order, query, heq, heo, hero) 4 | 5 | BATCH_SIZE=8000 6 | SEED=666 7 | TRAIN_STEPS=500000 8 | SAVE_CHECKPOINT_STEPS=5000 9 | REPORT_EVERY=100 10 | VISIBLE_GPUS="0,1,2,3" 11 | GPU_RANKS="0,1,2,3" 12 | WORLD_SIZE=4 13 | ACCUM_COUNT=2 14 | DROPOUT=0.1 15 | LABEL_SMOOTHING=0.1 16 | INTER_LAYERS="6,7" 17 | INTER_HEADS=8 18 | LR=1 19 | MAX_SAMPLES=500 20 | 21 | case $MODEL_TYPE in 22 | query|heq|hero) 23 | QUERY=True 24 | ;; 25 | hier|he|order|heo) 26 | QUERY=False 27 | ;; 28 | *) 29 | echo "Invalid option: ${MODEL_TYPE}" 30 | ;; 31 | esac 32 | 33 | case $DATASET in 34 | CNNDM) 35 | TRUNC_TGT_NTOKEN=100 36 | TRUNC_SRC_NTOKEN=200 37 | TRUNC_SRC_NBLOCK=8 38 | if [ $QUERY == "False" ]; then 39 | DATA_FOLDER_NAME=pytorch_qmdscnn 40 | else 41 | DATA_FOLDER_NAME=pytorch_qmdscnn_query 42 | fi 43 | if [ -z ${DATA_PATH+x} ]; then 44 | DATA_PATH="data/qmdscnn/${DATA_FOLDER_NAME}/CNNDM" 45 | fi 46 | if [ -z ${VOCAB_PATH+x} ]; then
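        # Note: "${VOCAB_PATH+x}" expands to "x" only when VOCAB_PATH is already set (even to an
        # empty string), so "-z ${VOCAB_PATH+x}" is true only when the variable is unset. Like the
        # DATA_PATH check above, this assigns the default vocab path only if the caller did not
        # export VOCAB_PATH before running the script.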
VOCAB_PATH="data/qmdscnn/${DATA_FOLDER_NAME}/spm.model" 48 | fi 49 | ;; 50 | WIKI) 51 | TRUNC_TGT_NTOKEN=400 52 | TRUNC_SRC_NTOKEN=100 53 | TRUNC_SRC_NBLOCK=24 54 | if [ $QUERY == "False" ]; then 55 | DATA_FOLDER_NAME=ranked_wiki_b40 56 | else 57 | DATA_FOLDER_NAME=ranked_wiki_b40_query 58 | fi 59 | if [ -z ${DATA_PATH+x} ]; then 60 | DATA_PATH="data/wikisum/${DATA_FOLDER_NAME}/WIKI" 61 | fi 62 | if [ -z ${VOCAB_PATH+x} ]; then 63 | VOCAB_PATH="data/wikisum/${DATA_FOLDER_NAME}/spm9998_3.model" 64 | fi 65 | ;; 66 | *) 67 | echo "Invalid option: ${DATASET}" 68 | 69 | esac 70 | 71 | # If model path not set 72 | if [ -z ${MODEL_PATH+x} ]; then 73 | MODEL_PATH="results/model-${DATASET}-${MODEL_TYPE}" 74 | fi 75 | 76 | 77 | python src/train_abstractive.py \ 78 | -mode train \ 79 | -batch_size $BATCH_SIZE \ 80 | -seed $SEED \ 81 | -train_steps $TRAIN_STEPS \ 82 | -save_checkpoint_steps $SAVE_CHECKPOINT_STEPS \ 83 | -report_every $REPORT_EVERY \ 84 | -trunc_tgt_ntoken $TRUNC_TGT_NTOKEN \ 85 | -trunc_src_ntoken $TRUNC_SRC_NTOKEN \ 86 | -trunc_src_nblock $TRUNC_SRC_NBLOCK \ 87 | -visible_gpus $VISIBLE_GPUS \ 88 | -gpu_ranks $GPU_RANKS \ 89 | -world_size $WORLD_SIZE \ 90 | -accum_count $ACCUM_COUNT \ 91 | -lr $LR \ 92 | -dec_dropout $DROPOUT \ 93 | -enc_dropout $DROPOUT \ 94 | -label_smoothing $LABEL_SMOOTHING \ 95 | -inter_layers $INTER_LAYERS \ 96 | -inter_heads $INTER_HEADS \ 97 | -hier \ 98 | -dataset $DATASET \ 99 | -model_type $MODEL_TYPE \ 100 | -query $QUERY \ 101 | -max_samples $MAX_SAMPLES \ 102 | -data_path $DATA_PATH \ 103 | -vocab_path $VOCAB_PATH \ 104 | -model_path $MODEL_PATH \ 105 | -result_path $MODEL_PATH/outputs \ 106 | 107 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | import os 2 | import logging 3 | import argparse 4 | import gdown 5 | import zipfile 6 | import git 7 | import subprocess 8 | 9 | logging.basicConfig(level = logging.INFO) 10 | 11 | def main(args): 12 | """ Module to set up the codebase """ 13 | # create the data folder 14 | if not os.path.exists("data"): 15 | logging.info("creating 'data' directory...") 16 | os.mkdir("data") 17 | if not os.path.exists("data/wikisum"): 18 | logging.info("creating 'data/wikisum' directory...") 19 | os.mkdir("data/wikisum") 20 | if not os.path.exists("data/qmdscnn"): 21 | logging.info("creating 'data/qmdscnn' directory...") 22 | os.mkdir("data/qmdscnn") 23 | # create the results folder 24 | if not os.path.exists("results"): 25 | logging.info("creating the 'results' directory...") 26 | os.mkdir("results") 27 | 28 | if not args.ignore_datasets: 29 | # download and unzip the wikisum dataset 30 | output_path = "data/wikisum/ranked_wiki_b40.zip" 31 | if not os.path.exists(output_path): 32 | logging.info("Downloading the encoded WikiSum dataset...") 33 | url = "https://drive.google.com/uc?id=1AnqeUpLkO9MR3PH0V8q32A6PEPDEZ0td&export=download" 34 | gdown.download(url, output_path, quiet=False) 35 | logging.info("Unzipping the data...") 36 | with zipfile.ZipFile(output_path, "r") as zip_ref: 37 | zip_ref.extractall("data/wikisum") 38 | 39 | output_path = "data/wikisum/ranked_wiki_b40_query.zip" 40 | if not os.path.exists(output_path): 41 | logging.info("Downloading the encoded WikiSum dataset with queries...") 42 | url = "https://drive.google.com/uc?id=1RdX-t3pznnyaGyrswFubAfoo9S9w9K5d&export=download" 43 | gdown.download(url, output_path, quiet=False) 44 | logging.info("Unzipping the data...") 45 | with
zipfile.ZipFile(output_path, "r") as zip_ref: 46 | zip_ref.extractall("data/wikisum") 47 | 48 | 49 | # download and unzip the QMDSCNN dataset 50 | output_path = "data/qmdscnn/pytorch_qmdscnn.zip" 51 | if not os.path.exists(output_path): 52 | logging.info("Downloading the encoded QMDSCNN dataset...") 53 | url = "https://drive.google.com/uc?id=1KXsvfnK6s6cnYQzD8ZOkXPdA6r5-quPK&export=download" 54 | gdown.download(url, output_path, quiet=False) 55 | logging.info("Unzipping the data...") 56 | with zipfile.ZipFile(output_path, "r") as zip_ref: 57 | zip_ref.extractall("data/qmdscnn") 58 | 59 | output_path = "data/qmdscnn/pytorch_qmdscnn_query.zip" 60 | if not os.path.exists(output_path): 61 | url = "https://drive.google.com/uc?id=12i_3dikeJLsOj-SQGPmc4w9Is7fB-hT-&export=download" 62 | gdown.download(url, output_path, quiet=False) 63 | logging.info("Unzipping the data...") 64 | with zipfile.ZipFile(output_path, "r") as zip_ref: 65 | zip_ref.extractall("data/qmdscnn") 66 | 67 | 68 | # download the pyrouge git repo 69 | if not os.path.exists("pyrouge"): 70 | repo_url = "https://github.com/andersjo/pyrouge.git" 71 | logging.info(f"Downloading repo: {repo_url}") 72 | git.Git(".").clone(repo_url) 73 | 74 | 75 | # set the ROUGE path 76 | rouge_path = os.path.join(os.getcwd(),"pyrouge/tools/ROUGE-1.5.5") 77 | subprocess.run(["pyrouge_set_rouge_path", f"{rouge_path}"]) 78 | 79 | 80 | 81 | 82 | 83 | 84 | 85 | 86 | 87 | if __name__ == "__main__": 88 | parser = argparse.ArgumentParser() 89 | parser.add_argument("--ignore_datasets", default=False, action="store_true") 90 | args = parser.parse_args() 91 | main(args) 92 | -------------------------------------------------------------------------------- /src/abstractive/attn.py: -------------------------------------------------------------------------------- 1 | """ Multi-Head Attention module 2 | Majority of code borrowed from https://github.com/nlpyang/hiersumm 3 | """ 4 | import math 5 | import torch 6 | import torch.nn as nn 7 | 8 | class MultiHeadedAttention(nn.Module): 9 | """ 10 | Multi-Head Attention module from 11 | "Attention is All You Need" 12 | :cite:`DBLP:journals/corr/VaswaniSPUJGKP17`. 13 | 14 | Similar to standard `dot` attention but uses 15 | multiple attention distributions simultaneously 16 | to select relevant items. 17 | 18 | .. mermaid:: 19 | 20 | graph BT 21 | A[key] 22 | B[value] 23 | C[query] 24 | O[output] 25 | subgraph Attn 26 | D[Attn 1] 27 | E[Attn 2] 28 | F[Attn N] 29 | end 30 | A --> D 31 | C --> D 32 | A --> E 33 | C --> E 34 | A --> F 35 | C --> F 36 | D --> O 37 | E --> O 38 | F --> O 39 | B --> O 40 | 41 | Also includes several additional tricks.
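        A minimal usage sketch (illustrative only; the 8-head / 512-dimension sizes below are
        assumed values, not constants of this module)::

            attn = MultiHeadedAttention(head_count=8, model_dim=512, dropout=0.1)
            # key / value / query: FloatTensors of shape [batch, seq_len, 512]
            output, attn_weights = attn(key, value, query, mask=None)
            # output: [batch, query_len, 512]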
42 | 43 | Args: 44 | head_count (int): number of parallel heads 45 | model_dim (int): the dimension of keys/values/queries, 46 | must be divisible by head_count 47 | dropout (float): dropout parameter 48 | """ 49 | 50 | def __init__(self, head_count, model_dim, dropout=0.1, use_final_linear=True): 51 | assert model_dim % head_count == 0 52 | self.dim_per_head = model_dim // head_count 53 | self.model_dim = model_dim 54 | 55 | super(MultiHeadedAttention, self).__init__() 56 | self.head_count = head_count 57 | 58 | self.linear_keys = nn.Linear(model_dim, 59 | head_count * self.dim_per_head) 60 | self.linear_values = nn.Linear(model_dim, 61 | head_count * self.dim_per_head) 62 | self.linear_query = nn.Linear(model_dim, 63 | head_count * self.dim_per_head) 64 | self.softmax = nn.Softmax(dim=-1) 65 | self.dropout = nn.Dropout(dropout) 66 | self.use_final_linear = use_final_linear 67 | if(self.use_final_linear): 68 | self.final_linear = nn.Linear(model_dim, model_dim) 69 | 70 | def forward(self, key, value, query, mask=None, 71 | layer_cache=None, type=None): 72 | """ 73 | Compute the context vector and the attention vectors. 74 | 75 | Args: 76 | key (`FloatTensor`): set of `key_len` 77 | key vectors `[batch, key_len, dim]` 78 | value (`FloatTensor`): set of `key_len` 79 | value vectors `[batch, key_len, dim]` 80 | query (`FloatTensor`): set of `query_len` 81 | query vectors `[batch, query_len, dim]` 82 | mask: binary mask indicating which keys have 83 | non-zero attention `[batch, query_len, key_len]` 84 | Returns: 85 | (`FloatTensor`, `FloatTensor`) : 86 | 87 | * output context vectors `[batch, query_len, dim]` 88 | * one of the attention vectors `[batch, query_len, key_len]` 89 | """ 90 | 91 | batch_size = key.size(0) 92 | dim_per_head = self.dim_per_head 93 | head_count = self.head_count 94 | key_len = key.size(1) 95 | query_len = query.size(1) 96 | 97 | def shape(x): 98 | """ projection """ 99 | return x.view(batch_size, -1, head_count, dim_per_head) \ 100 | .transpose(1, 2) 101 | 102 | def unshape(x): 103 | """ compute context """ 104 | return x.transpose(1, 2).contiguous() \ 105 | .view(batch_size, -1, head_count * dim_per_head) 106 | 107 | # 1) Project key, value, and query. 
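        #    When a layer_cache dict is supplied (incremental decoding), projections are reused:
        #    for "self" attention the current step's keys/values are appended to the cached ones,
        #    while for the context types the encoder memory is projected once and then read back
        #    from the cache on every later step. Without a cache, key/value/query are simply
        #    projected and reshaped to [batch, heads, len, dim_per_head].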
108 | if layer_cache is not None: 109 | if type == "self": 110 | query, key, value = self.linear_query(query),\ 111 | self.linear_keys(query),\ 112 | self.linear_values(query) 113 | 114 | key = shape(key) 115 | value = shape(value) 116 | 117 | if layer_cache is not None: 118 | device = key.device 119 | if layer_cache["self_keys"] is not None: 120 | key = torch.cat( 121 | (layer_cache["self_keys"].to(device), key), 122 | dim=2) 123 | if layer_cache["self_values"] is not None: 124 | value = torch.cat( 125 | (layer_cache["self_values"].to(device), value), 126 | dim=2) 127 | layer_cache["self_keys"] = key 128 | layer_cache["self_values"] = value 129 | elif type in ["context", "global_context", "local_context"]: 130 | query = self.linear_query(query) 131 | if layer_cache is not None: 132 | if type in ["context", "global_context"]: 133 | mkeys_str = "memory_keys" 134 | mvalues_str = "memory_values" 135 | elif type in ["local_context"]: 136 | mkeys_str = "local_memory_keys" 137 | mvalues_str = "local_memory_values" 138 | 139 | if layer_cache[mkeys_str] is None: 140 | key, value = self.linear_keys(key),\ 141 | self.linear_values(value) 142 | key = shape(key) 143 | value = shape(value) 144 | else: 145 | key, value = layer_cache[mkeys_str],\ 146 | layer_cache[mvalues_str] 147 | layer_cache[mkeys_str] = key 148 | layer_cache[mvalues_str] = value 149 | else: 150 | key, value = self.linear_keys(key),\ 151 | self.linear_values(value) 152 | key = shape(key) 153 | value = shape(value) 154 | else: 155 | key = self.linear_keys(key) 156 | value = self.linear_values(value) 157 | query = self.linear_query(query) 158 | key = shape(key) 159 | value = shape(value) 160 | 161 | query = shape(query) 162 | 163 | key_len = key.size(2) 164 | query_len = query.size(2) 165 | 166 | # 2) Calculate and scale scores. 167 | query = query / math.sqrt(dim_per_head) 168 | 169 | scores = torch.matmul(query, key.transpose(2, 3)) 170 | 171 | if mask is not None: 172 | mask = mask.unsqueeze(1).expand_as(scores) 173 | scores = scores.masked_fill(mask, -1e18) 174 | 175 | # 3) Apply attention dropout and compute context vectors. 
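        #    The softmax normalizes over the key dimension, so attn has shape
        #    [batch, heads, query_len, key_len]; dropout is applied to the attention weights only.
        #    When use_final_linear is False, the per-head contexts are returned without being
        #    concatenated back to model_dim.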
176 | 177 | attn = self.softmax(scores) 178 | 179 | 180 | drop_attn = self.dropout(attn) 181 | if(self.use_final_linear): 182 | context = unshape(torch.matmul(drop_attn, value)) 183 | output = self.final_linear(context) 184 | return output, attn 185 | else: 186 | context = torch.matmul(drop_attn, value) 187 | return context, attn 188 | 189 | 190 | 191 | 192 | 193 | class MultiHeadedPooling(nn.Module): 194 | def __init__(self, head_count, model_dim, dropout=0.1, use_final_linear=True): 195 | assert model_dim % head_count == 0 196 | self.dim_per_head = model_dim // head_count 197 | self.model_dim = model_dim 198 | super(MultiHeadedPooling, self).__init__() 199 | self.head_count = head_count 200 | self.linear_keys = nn.Linear(model_dim, 201 | head_count) 202 | self.linear_values = nn.Linear(model_dim, 203 | head_count * self.dim_per_head) 204 | self.softmax = nn.Softmax(dim=-1) 205 | self.dropout = nn.Dropout(dropout) 206 | if (use_final_linear): 207 | self.final_linear = nn.Linear(model_dim, model_dim) 208 | self.use_final_linear = use_final_linear 209 | 210 | def forward(self, key, value, mask=None): 211 | batch_size = key.size(0) 212 | dim_per_head = self.dim_per_head 213 | head_count = self.head_count 214 | 215 | def shape(x, dim=dim_per_head): 216 | """ projection """ 217 | return x.view(batch_size, -1, head_count, dim) \ 218 | .transpose(1, 2) 219 | 220 | def unshape(x, dim=dim_per_head): 221 | """ compute context """ 222 | return x.transpose(1, 2).contiguous() \ 223 | .view(batch_size, -1, head_count * dim) 224 | 225 | scores = self.linear_keys(key) 226 | value = self.linear_values(value) 227 | 228 | scores = shape(scores, 1).squeeze(-1) 229 | value = shape(value) 230 | 231 | 232 | if mask is not None: 233 | mask = mask.unsqueeze(1).expand_as(scores) 234 | scores = scores.masked_fill(mask, -1e18) 235 | 236 | # 3) Apply attention dropout and compute context vectors. 237 | attn = self.softmax(scores) 238 | drop_attn = self.dropout(attn) 239 | context = torch.sum((drop_attn.unsqueeze(-1) * value), -2) 240 | if (self.use_final_linear): 241 | context = unshape(context).squeeze(1) 242 | output = self.final_linear(context) 243 | return output 244 | else: 245 | return context 246 | 247 | 248 | class SelfAttention(nn.Module): 249 | 250 | def __init__(self, model_dim, dropout=0.1): 251 | super(SelfAttention, self).__init__() 252 | self.Va = nn.Linear(model_dim, 1, bias=False) 253 | self.Wa = nn.Linear(model_dim, model_dim) 254 | self.dropout = nn.Dropout(dropout) 255 | 256 | def forward(self, x, mask=None): 257 | b, t, n = x.size() 258 | 259 | proj = torch.tanh(self.Wa(x.view(b*t, n).contiguous())) 260 | scores = self.Va(proj) 261 | scores = scores.view(b,t).contiguous() 262 | 263 | if mask is not None: 264 | scores = scores.masked_fill(mask, -1e18) 265 | 266 | attn = torch.softmax(scores, -1) 267 | drop_attn = self.dropout(attn) 268 | 269 | context = torch.sum((drop_attn.unsqueeze(-1)*x), -2) 270 | 271 | return context, attn 272 | 273 | 274 | -------------------------------------------------------------------------------- /src/abstractive/beam.py: -------------------------------------------------------------------------------- 1 | """ 2 | Majority of code borrowed from https://github.com/OpenNMT/OpenNMT-py 3 | """ 4 | import torch 5 | from abstractive.penalties import PenaltyBuilder 6 | 7 | 8 | 9 | 10 | class Beam(object): 11 | """Class for managing the internals of the beam search process. 12 | 13 | Takes care of beams, back pointers, and scores. 
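    A rough per-sentence decoding loop (illustrative only: ``decoder_step`` is a placeholder
    for the model's own step function, and the penalty names depend on what
    :class:`penalties.PenaltyBuilder` supports)::

        scorer = GNMTGlobalScorer(alpha=0.9, beta=0.0,
                                  length_penalty='wu', coverage_penalty='none')
        beam = Beam(5, pad_id, bos_id, eos_id, n_best=1, cuda=True, global_scorer=scorer)
        while not beam.done:
            log_probs, attn = decoder_step(beam.current_predictions)
            beam.advance(log_probs, attn)
        scores, ks = beam.sort_finished(minimum=1)
        hyp, attn = beam.get_hyp(*ks[0])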
14 | 15 | Args: 16 | size (int): Number of beams to use. 17 | pad (int): Magic integer in output vocab. 18 | bos (int): Magic integer in output vocab. 19 | eos (int): Magic integer in output vocab. 20 | n_best (int): Don't stop until at least this many beams have 21 | reached EOS. 22 | cuda (bool): use gpu 23 | global_scorer (onmt.translate.GNMTGlobalScorer): Scorer instance. 24 | min_length (int): Shortest acceptable generation, not counting 25 | begin-of-sentence or end-of-sentence. 26 | stepwise_penalty (bool): Apply coverage penalty at every step. 27 | block_ngram_repeat (int): Block beams where 28 | ``block_ngram_repeat``-grams repeat. 29 | exclusion_tokens (set[int]): If a gram contains any of these 30 | token indices, it may repeat. 31 | """ 32 | 33 | def __init__(self, size, pad, bos, eos, 34 | n_best=1, cuda=False, 35 | global_scorer=None, 36 | min_length=0, 37 | stepwise_penalty=False, 38 | block_ngram_repeat=0, 39 | exclusion_tokens=set()): 40 | 41 | self.size = size 42 | self.tt = torch.cuda if cuda else torch 43 | 44 | # The score for each translation on the beam. 45 | self.scores = self.tt.FloatTensor(size).zero_() 46 | self.all_scores = [] 47 | 48 | # The backpointers at each time-step. 49 | self.prev_ks = [] 50 | 51 | # The outputs at each time-step. 52 | self.next_ys = [self.tt.LongTensor(size) 53 | .fill_(pad)] 54 | self.next_ys[0][0] = bos 55 | 56 | # Has EOS topped the beam yet. 57 | self._eos = eos 58 | self.eos_top = False 59 | 60 | # The attentions (matrix) for each time. 61 | self.attn = [] 62 | 63 | # Time and k pair for finished. 64 | self.finished = [] 65 | self.n_best = n_best 66 | 67 | # Information for global scoring. 68 | self.global_scorer = global_scorer 69 | self.global_state = {} 70 | 71 | # Minimum prediction length 72 | self.min_length = min_length 73 | 74 | # Apply Penalty at every step 75 | self.stepwise_penalty = stepwise_penalty 76 | self.block_ngram_repeat = block_ngram_repeat 77 | self.exclusion_tokens = exclusion_tokens 78 | 79 | @property 80 | def current_predictions(self): 81 | return self.next_ys[-1] 82 | 83 | @property 84 | def current_origin(self): 85 | """Get the backpointers for the current timestep.""" 86 | return self.prev_ks[-1] 87 | 88 | def advance(self, word_probs, attn_out): 89 | """ 90 | Given prob over words for every last beam `wordLk` and attention 91 | `attn_out`: Compute and update the beam search. 92 | 93 | Args: 94 | word_probs (FloatTensor): probs of advancing from the last step 95 | ``(K, words)`` 96 | attn_out (FloatTensor): attention at the last step 97 | 98 | Returns: 99 | bool: True if beam search is complete. 100 | """ 101 | 102 | num_words = word_probs.size(1) 103 | if self.stepwise_penalty: 104 | self.global_scorer.update_score(self, attn_out) 105 | # force the output to be longer than self.min_length 106 | cur_len = len(self.next_ys) 107 | if cur_len <= self.min_length: 108 | # assumes there are len(word_probs) predictions OTHER 109 | # than EOS that are greater than -1e20 110 | for k in range(len(word_probs)): 111 | word_probs[k][self._eos] = -1e20 112 | 113 | # Sum the previous scores. 114 | if len(self.prev_ks) > 0: 115 | beam_scores = word_probs + self.scores.unsqueeze(1) 116 | # Don't let EOS have children. 
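                # A beam whose last token is EOS is frozen by forcing its score to -1e20, so the
                # topk below never extends it; the n-gram blocking further down assigns -10e20 to
                # any beam that would repeat a block_ngram_repeat-gram, unless the gram overlaps
                # the exclusion tokens.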
117 | for i in range(self.next_ys[-1].size(0)): 118 | if self.next_ys[-1][i] == self._eos: 119 | beam_scores[i] = -1e20 120 | 121 | # Block ngram repeats 122 | if self.block_ngram_repeat > 0: 123 | le = len(self.next_ys) 124 | for j in range(self.next_ys[-1].size(0)): 125 | hyp, _ = self.get_hyp(le - 1, j) 126 | ngrams = set() 127 | fail = False 128 | gram = [] 129 | for i in range(le - 1): 130 | # Last n tokens, n = block_ngram_repeat 131 | gram = (gram + 132 | [hyp[i].item()])[-self.block_ngram_repeat:] 133 | # Skip the blocking if it is in the exclusion list 134 | if set(gram) & self.exclusion_tokens: 135 | continue 136 | if tuple(gram) in ngrams: 137 | fail = True 138 | ngrams.add(tuple(gram)) 139 | if fail: 140 | beam_scores[j] = -10e20 141 | else: 142 | beam_scores = word_probs[0] 143 | flat_beam_scores = beam_scores.view(-1) 144 | best_scores, best_scores_id = flat_beam_scores.topk(self.size, 0, 145 | True, True) 146 | 147 | self.all_scores.append(self.scores) 148 | self.scores = best_scores 149 | 150 | # best_scores_id is flattened beam x word array, so calculate which 151 | # word and beam each score came from 152 | prev_k = best_scores_id / num_words 153 | self.prev_ks.append(prev_k) 154 | self.next_ys.append((best_scores_id - prev_k * num_words)) 155 | self.attn.append(attn_out.index_select(0, prev_k)) 156 | self.global_scorer.update_global_state(self) 157 | 158 | for i in range(self.next_ys[-1].size(0)): 159 | if self.next_ys[-1][i] == self._eos: 160 | global_scores = self.global_scorer.score(self, self.scores) 161 | s = global_scores[i] 162 | self.finished.append((s, len(self.next_ys) - 1, i)) 163 | 164 | # End condition is when top-of-beam is EOS and no global score. 165 | if self.next_ys[-1][0] == self._eos: 166 | self.all_scores.append(self.scores) 167 | self.eos_top = True 168 | 169 | @property 170 | def done(self): 171 | return self.eos_top and len(self.finished) >= self.n_best 172 | 173 | def sort_finished(self, minimum=None): 174 | if minimum is not None: 175 | i = 0 176 | # Add from beam until we have minimum outputs. 177 | while len(self.finished) < minimum: 178 | global_scores = self.global_scorer.score(self, self.scores) 179 | s = global_scores[i] 180 | self.finished.append((s, len(self.next_ys) - 1, i)) 181 | i += 1 182 | 183 | self.finished.sort(key=lambda a: -a[0]) 184 | scores = [sc for sc, _, _ in self.finished] 185 | ks = [(t, k) for _, t, k in self.finished] 186 | return scores, ks 187 | 188 | def get_hyp(self, timestep, k): 189 | """Walk back to construct the full hypothesis.""" 190 | hyp, attn = [], [] 191 | for j in range(len(self.prev_ks[:timestep]) - 1, -1, -1): 192 | hyp.append(self.next_ys[j + 1][k]) 193 | attn.append(self.attn[j][k]) 194 | k = self.prev_ks[j][k] 195 | return hyp[::-1], torch.stack(attn[::-1]) 196 | 197 | 198 | class GNMTGlobalScorer(object): 199 | """NMT re-ranking. 200 | 201 | Args: 202 | alpha (float): Length parameter. 203 | beta (float): Coverage parameter. 204 | length_penalty (str): Length penalty strategy. 205 | coverage_penalty (str): Coverage penalty strategy. 206 | 207 | Attributes: 208 | alpha (float): See above. 209 | beta (float): See above. 210 | length_penalty (callable): See :class:`penalties.PenaltyBuilder`. 211 | coverage_penalty (callable): See :class:`penalties.PenaltyBuilder`. 212 | has_cov_pen (bool): See :class:`penalties.PenaltyBuilder`. 213 | has_len_pen (bool): See :class:`penalties.PenaltyBuilder`. 
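    In effect, ``score()`` returns roughly
    ``logprobs / length_penalty(len(beam.next_ys), alpha) - coverage_penalty(coverage, beta)``,
    where the coverage term is skipped when the penalty is instead applied stepwise during search.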
214 | """ 215 | 216 | 217 | def __init__(self, alpha, beta, length_penalty, coverage_penalty): 218 | 219 | self.alpha = alpha 220 | self.beta = beta 221 | penalty_builder = PenaltyBuilder(coverage_penalty, 222 | length_penalty) 223 | self.has_cov_pen = penalty_builder.has_cov_pen 224 | # Term will be subtracted from probability 225 | self.cov_penalty = penalty_builder.coverage_penalty 226 | 227 | self.has_len_pen = penalty_builder.has_len_pen 228 | # Probability will be divided by this 229 | self.length_penalty = penalty_builder.length_penalty 230 | 231 | 232 | 233 | def score(self, beam, logprobs): 234 | """Rescore a prediction based on penalty functions.""" 235 | len_pen = self.length_penalty(len(beam.next_ys), self.alpha) 236 | normalized_probs = logprobs / len_pen 237 | if not beam.stepwise_penalty: 238 | penalty = self.cov_penalty(beam.global_state["coverage"], 239 | self.beta) 240 | normalized_probs -= penalty 241 | 242 | return normalized_probs 243 | 244 | def update_score(self, beam, attn): 245 | """Update scores of a Beam that is not finished.""" 246 | if "prev_penalty" in beam.global_state.keys(): 247 | beam.scores.add_(beam.global_state["prev_penalty"]) 248 | penalty = self.cov_penalty(beam.global_state["coverage"] + attn, 249 | self.beta) 250 | beam.scores.sub_(penalty) 251 | 252 | def update_global_state(self, beam): 253 | """Keeps the coverage vector as sum of attentions.""" 254 | if len(beam.prev_ks) == 1: 255 | beam.global_state["prev_penalty"] = beam.scores.clone().fill_(0.0) 256 | beam.global_state["coverage"] = beam.attn[-1] 257 | self.cov_total = beam.attn[-1].sum(1) 258 | else: 259 | self.cov_total += torch.min(beam.attn[-1], 260 | beam.global_state['coverage']).sum(1) 261 | beam.global_state["coverage"] = beam.global_state["coverage"] \ 262 | .index_select(0, beam.prev_ks[-1]).add(beam.attn[-1]) 263 | 264 | prev_penalty = self.cov_penalty(beam.global_state["coverage"], 265 | self.beta) 266 | beam.global_state["prev_penalty"] = prev_penalty 267 | -------------------------------------------------------------------------------- /src/abstractive/beam_search.py: -------------------------------------------------------------------------------- 1 | """ Beam Search module 2 | Majority of code borrowed from https://github.com/OpenNMT/OpenNMT-py 3 | """ 4 | 5 | import torch 6 | 7 | from abstractive.decode_strategy import DecodeStrategy 8 | 9 | 10 | class BeamSearch(DecodeStrategy): 11 | """Generation beam search. 12 | 13 | Note that the attributes list is not exhaustive. Rather, it highlights 14 | tensors to document their shape. (Since the state variables' "batch" 15 | size decreases as beams finish, we denote this axis with a B rather than 16 | ``batch_size``). 17 | 18 | Args: 19 | beam_size (int): Number of beams to use (see base ``parallel_paths``). 20 | batch_size (int): See base. 21 | pad (int): See base. 22 | bos (int): See base. 23 | eos (int): See base. 24 | n_best (int): Don't stop until at least this many beams have 25 | reached EOS. 26 | mb_device (torch.device or str): See base ``device``. 27 | global_scorer (onmt.translate.GNMTGlobalScorer): Scorer instance. 28 | min_length (int): See base. 29 | max_length (int): See base. 30 | return_attention (bool): See base. 31 | block_ngram_repeat (int): See base. 32 | exclusion_tokens (set[int]): See base. 33 | memory_lengths (LongTensor): Lengths of encodings. Used for 34 | masking attentions. 35 | 36 | Attributes: 37 | top_beam_finished (ByteTensor): Shape ``(B,)``. 38 | _batch_offset (LongTensor): Shape ``(B,)``. 
39 | _beam_offset (LongTensor): Shape ``(batch_size x beam_size,)``. 40 | alive_seq (LongTensor): See base. 41 | topk_log_probs (FloatTensor): Shape ``(B x beam_size,)``. These 42 | are the scores used for the topk operation. 43 | select_indices (LongTensor or NoneType): Shape 44 | ``(B x beam_size,)``. This is just a flat view of the 45 | ``_batch_index``. 46 | topk_scores (FloatTensor): Shape 47 | ``(B, beam_size)``. These are the 48 | scores a sequence will receive if it finishes. 49 | topk_ids (LongTensor): Shape ``(B, beam_size)``. These are the 50 | word indices of the topk predictions. 51 | _batch_index (LongTensor): Shape ``(B, beam_size)``. 52 | _prev_penalty (FloatTensor or NoneType): Shape 53 | ``(B, beam_size)``. Initialized to ``None``. 54 | _coverage (FloatTensor or NoneType): Shape 55 | ``(1, B x beam_size, inp_seq_len)``. 56 | hypotheses (list[list[Tuple[Tensor]]]): Contains a tuple 57 | of score (float), sequence (long), and attention (float or None). 58 | """ 59 | 60 | def __init__(self, beam_size, batch_size, pad, bos, eos, n_best, mb_device, 61 | global_scorer, min_length, max_length, return_attention, 62 | block_ngram_repeat, exclusion_tokens, memory_lengths, 63 | stepwise_penalty, ratio): 64 | super(BeamSearch, self).__init__( 65 | pad, bos, eos, batch_size, mb_device, beam_size, min_length, 66 | block_ngram_repeat, exclusion_tokens, return_attention, 67 | max_length) 68 | # beam parameters 69 | self.global_scorer = global_scorer 70 | self.beam_size = beam_size 71 | self.n_best = n_best 72 | self.batch_size = batch_size 73 | self.ratio = ratio 74 | 75 | # result caching 76 | self.hypotheses = [[] for _ in range(batch_size)] 77 | 78 | # beam state 79 | self.top_beam_finished = torch.zeros([batch_size], dtype=torch.uint8) 80 | self.best_scores = torch.full([batch_size], -1e10, dtype=torch.float, 81 | device=mb_device) 82 | 83 | self._batch_offset = torch.arange(batch_size, dtype=torch.long) 84 | self._beam_offset = torch.arange( 85 | 0, batch_size * beam_size, step=beam_size, dtype=torch.long, 86 | device=mb_device) 87 | self.topk_log_probs = torch.tensor( 88 | [0.0] + [float("-inf")] * (beam_size - 1), device=mb_device 89 | ).repeat(batch_size) 90 | self.select_indices = None 91 | self._memory_lengths = memory_lengths 92 | 93 | # buffers for the topk scores and 'backpointer' 94 | self.topk_scores = torch.empty((batch_size, beam_size), 95 | dtype=torch.float, device=mb_device) 96 | self.topk_ids = torch.empty((batch_size, beam_size), dtype=torch.long, 97 | device=mb_device) 98 | self._batch_index = torch.empty([batch_size, beam_size], 99 | dtype=torch.long, device=mb_device) 100 | self.done = False 101 | # "global state" of the old beam 102 | self._prev_penalty = None 103 | self._coverage = None 104 | 105 | self._stepwise_cov_pen = ( 106 | stepwise_penalty and self.global_scorer.has_cov_pen) 107 | self._vanilla_cov_pen = ( 108 | not stepwise_penalty and self.global_scorer.has_cov_pen) 109 | self._cov_pen = self.global_scorer.has_cov_pen 110 | 111 | @property 112 | def current_predictions(self): 113 | return self.alive_seq[:, -1] 114 | 115 | @property 116 | def current_origin(self): 117 | return self.select_indices 118 | 119 | @property 120 | def current_backptr(self): 121 | # for testing 122 | return self.select_indices.view(self.batch_size, self.beam_size)\ 123 | .fmod(self.beam_size) 124 | 125 | def advance(self, log_probs, attn): 126 | vocab_size = log_probs.size(-1) 127 | 128 | # using integer division to get an integer _B without casting 129 | _B = 
log_probs.shape[0] // self.beam_size 130 | 131 | if self._stepwise_cov_pen and self._prev_penalty is not None: 132 | self.topk_log_probs += self._prev_penalty 133 | self.topk_log_probs -= self.global_scorer.cov_penalty( 134 | self._coverage + attn, self.global_scorer.beta).view( 135 | _B, self.beam_size) 136 | 137 | # force the output to be longer than self.min_length 138 | step = len(self) 139 | self.ensure_min_length(log_probs) 140 | 141 | # Multiply probs by the beam probability. 142 | log_probs += self.topk_log_probs.view(_B * self.beam_size, 1) 143 | 144 | self.block_ngram_repeats(log_probs) 145 | 146 | # if the sequence ends now, then the penalty is the current 147 | # length + 1, to include the EOS token 148 | length_penalty = self.global_scorer.length_penalty( 149 | step + 1, alpha=self.global_scorer.alpha) 150 | 151 | # Flatten probs into a list of possibilities. 152 | curr_scores = log_probs / length_penalty 153 | curr_scores = curr_scores.reshape(_B, self.beam_size * vocab_size) 154 | torch.topk(curr_scores, self.beam_size, dim=-1, 155 | out=(self.topk_scores, self.topk_ids)) 156 | 157 | # Recover log probs. 158 | # Length penalty is just a scalar. It doesn't matter if it's applied 159 | # before or after the topk. 160 | torch.mul(self.topk_scores, length_penalty, out=self.topk_log_probs) 161 | 162 | # Resolve beam origin and map to batch index flat representation. 163 | torch.div(self.topk_ids, vocab_size, out=self._batch_index) 164 | self._batch_index += self._beam_offset[:_B].unsqueeze(1) 165 | self.select_indices = self._batch_index.view(_B * self.beam_size) 166 | 167 | self.topk_ids.fmod_(vocab_size) # resolve true word ids 168 | 169 | # Append last prediction. 170 | self.alive_seq = torch.cat( 171 | [self.alive_seq.index_select(0, self.select_indices), 172 | self.topk_ids.view(_B * self.beam_size, 1)], -1) 173 | if self.return_attention or self._cov_pen: 174 | current_attn = attn.index_select(1, self.select_indices) 175 | if step == 1: 176 | self.alive_attn = current_attn 177 | # update global state (step == 1) 178 | if self._cov_pen: # coverage penalty 179 | self._prev_penalty = torch.zeros_like(self.topk_log_probs) 180 | self._coverage = current_attn 181 | else: 182 | self.alive_attn = self.alive_attn.index_select( 183 | 1, self.select_indices) 184 | self.alive_attn = torch.cat([self.alive_attn, current_attn], 0) 185 | # update global state (step > 1) 186 | if self._cov_pen: 187 | self._coverage = self._coverage.index_select( 188 | 1, self.select_indices) 189 | self._coverage += current_attn 190 | self._prev_penalty = self.global_scorer.cov_penalty( 191 | self._coverage, beta=self.global_scorer.beta).view( 192 | _B, self.beam_size) 193 | 194 | if self._vanilla_cov_pen: 195 | # shape: (batch_size x beam_size, 1) 196 | cov_penalty = self.global_scorer.cov_penalty( 197 | self._coverage, 198 | beta=self.global_scorer.beta) 199 | self.topk_scores -= cov_penalty.view(_B, self.beam_size) 200 | 201 | self.is_finished = self.topk_ids.eq(self.eos) 202 | self.ensure_max_length() 203 | 204 | def update_finished(self): 205 | # Penalize beams that finished. 
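        # _B_old is the number of batch entries still alive before this pruning pass. Beams that
        # just produced EOS have their running log-prob forced to -1e10 so the next topk never
        # extends them, and alive_seq is viewed as [_B_old, beam_size, step] so finished
        # hypotheses can be collected per batch entry.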
206 | _B_old = self.topk_log_probs.shape[0] 207 | step = self.alive_seq.shape[-1] # 1 greater than the step in advance 208 | self.topk_log_probs.masked_fill_(self.is_finished, -1e10) 209 | # on real data (newstest2017) with the pretrained transformer, 210 | # it's faster to not move this back to the original device 211 | self.is_finished = self.is_finished.to('cpu') 212 | self.top_beam_finished |= self.is_finished[:, 0].eq(1) 213 | predictions = self.alive_seq.view(_B_old, self.beam_size, step) 214 | attention = ( 215 | self.alive_attn.view( 216 | step - 1, _B_old, self.beam_size, self.alive_attn.size(-1)) 217 | if self.alive_attn is not None else None) 218 | non_finished_batch = [] 219 | for i in range(self.is_finished.size(0)): 220 | b = self._batch_offset[i] 221 | finished_hyp = self.is_finished[i].nonzero().view(-1) 222 | # Store finished hypotheses for this batch. 223 | for j in finished_hyp: 224 | if self.ratio > 0: 225 | s = self.topk_scores[i, j] / (step + 1) 226 | if self.best_scores[b] < s: 227 | self.best_scores[b] = s 228 | self.hypotheses[b].append(( 229 | self.topk_scores[i, j], 230 | predictions[i, j, 1:], # Ignore start_token. 231 | attention[:, i, j, :] 232 | if attention is not None else None)) ## changed this from original code 233 | # End condition is the top beam finished and we can return 234 | # n_best hypotheses. 235 | if self.ratio > 0: 236 | pred_len = self._memory_lengths[i] * self.ratio 237 | finish_flag = ((self.topk_scores[i, 0] / pred_len) 238 | <= self.best_scores[b]) or \ 239 | self.is_finished[i].all() 240 | else: 241 | finish_flag = self.top_beam_finished[i] != 0 242 | if finish_flag and len(self.hypotheses[b]) >= self.n_best: 243 | best_hyp = sorted( 244 | self.hypotheses[b], key=lambda x: x[0], reverse=True) 245 | for n, (score, pred, attn) in enumerate(best_hyp): 246 | if n >= self.n_best: 247 | break 248 | self.scores[b].append(score) 249 | self.predictions[b].append(pred) 250 | self.attention[b].append( 251 | attn if attn is not None else []) 252 | else: 253 | non_finished_batch.append(i) 254 | non_finished = torch.tensor(non_finished_batch) 255 | # If all sentences are translated, no need to go further. 256 | if len(non_finished) == 0: 257 | self.done = True 258 | return 259 | 260 | _B_new = non_finished.shape[0] 261 | # Remove finished batches for the next step. 
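        # Every piece of search state (batch offsets, running log-probs, backpointers,
        # sequences, scores and, when coverage is tracked, the coverage/penalty tensors) is
        # index_select-ed down to the surviving batch entries, so the effective batch size
        # shrinks as sentences finish decoding.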
262 | self.top_beam_finished = self.top_beam_finished.index_select( 263 | 0, non_finished) 264 | self._batch_offset = self._batch_offset.index_select(0, non_finished) 265 | non_finished = non_finished.to(self.topk_ids.device) 266 | self.topk_log_probs = self.topk_log_probs.index_select(0, 267 | non_finished) 268 | self._batch_index = self._batch_index.index_select(0, non_finished) 269 | self.select_indices = self._batch_index.view(_B_new * self.beam_size) 270 | self.alive_seq = predictions.index_select(0, non_finished) \ 271 | .view(-1, self.alive_seq.size(-1)) 272 | self.topk_scores = self.topk_scores.index_select(0, non_finished) 273 | self.topk_ids = self.topk_ids.index_select(0, non_finished) 274 | if self.alive_attn is not None: 275 | inp_seq_len = self.alive_attn.size(-1) 276 | self.alive_attn = attention.index_select(1, non_finished) \ 277 | .view(step - 1, _B_new * self.beam_size, inp_seq_len) 278 | if self._cov_pen: 279 | self._coverage = self._coverage \ 280 | .view(1, _B_old, self.beam_size, inp_seq_len) \ 281 | .index_select(1, non_finished) \ 282 | .view(1, _B_new * self.beam_size, inp_seq_len) 283 | if self._stepwise_cov_pen: 284 | self._prev_penalty = self._prev_penalty.index_select( 285 | 0, non_finished) 286 | -------------------------------------------------------------------------------- /src/abstractive/cal_rouge.py: -------------------------------------------------------------------------------- 1 | """ 2 | Majority of code borrowed from https://github.com/nlpyang/hiersumm 3 | """ 4 | import argparse 5 | import os 6 | import time 7 | # from multiprocess import Pool as Pool2 8 | from multiprocessing import Pool 9 | import abstractive.my_pyrouge as pyrouge 10 | import shutil 11 | import sys 12 | import codecs 13 | 14 | # from onmt.utils.logging import init_logger, logger 15 | 16 | def process(data): 17 | candidates, references, pool_id, rouge_dir = data 18 | cnt = len(candidates) 19 | current_time = time.strftime('%Y-%m-%d-%H-%M-%S', time.localtime()) 20 | tmp_dir = "rouge-tmp-{}-{}".format(current_time,pool_id) 21 | if not os.path.isdir(tmp_dir): 22 | os.mkdir(tmp_dir) 23 | os.mkdir(tmp_dir + "/candidate") 24 | os.mkdir(tmp_dir + "/reference") 25 | try: 26 | 27 | for i in range(cnt): 28 | if len(references[i]) < 1: 29 | continue 30 | with open(tmp_dir + "/candidate/cand.{}.txt".format(i), "w", 31 | encoding="utf-8") as f: 32 | f.write(candidates[i]) 33 | with open(tmp_dir + "/reference/ref.{}.txt".format(i), "w", 34 | encoding="utf-8") as f: 35 | f.write(references[i]) 36 | r = pyrouge.Rouge155(rouge_dir=rouge_dir) 37 | r.model_dir = tmp_dir + "/reference/" 38 | r.system_dir = tmp_dir + "/candidate/" 39 | r.model_filename_pattern = 'ref.#ID#.txt' 40 | r.system_filename_pattern = r'cand.(\d+).txt' 41 | rouge_results = r.convert_and_evaluate(split_sentences=True) 42 | print(rouge_results) 43 | results_dict = r.output_to_dict(rouge_results) 44 | finally: 45 | pass 46 | if os.path.isdir(tmp_dir): 47 | shutil.rmtree(tmp_dir) 48 | return results_dict 49 | 50 | 51 | 52 | 53 | def chunks(l, n): 54 | """Yield successive n-sized chunks from l.""" 55 | for i in range(0, len(l), n): 56 | yield l[i:i + n] 57 | 58 | def test_rouge(cand, ref,num_processes, rouge_dir=None): 59 | """Calculate ROUGE scores of sequences passed as an iterator 60 | e.g. 
a list of str, an open file, StringIO or even sys.stdin 61 | """ 62 | candidates = [line.strip() for line in cand] 63 | references = [line.strip() for line in ref] 64 | 65 | print(len(candidates)) 66 | print(len(references)) 67 | assert len(candidates) == len(references) 68 | if num_processes == 0: 69 | return process((candidates, references, 0, rouge_dir)) 70 | else: 71 | candidates_chunks = list(chunks(candidates, int(len(candidates)/num_processes))) 72 | references_chunks = list(chunks(references, int(len(references)/num_processes))) 73 | n_pool = len(candidates_chunks) 74 | arg_lst = [] 75 | for i in range(n_pool): 76 | arg_lst.append((candidates_chunks[i],references_chunks[i],i,rouge_dir)) 77 | pool = Pool(n_pool) 78 | results = pool.map(process,arg_lst) 79 | final_results = {} 80 | for i,r in enumerate(results): 81 | for k in r: 82 | if(k not in final_results): 83 | final_results[k] = r[k]*len(candidates_chunks[i]) 84 | else: 85 | final_results[k] += r[k] * len(candidates_chunks[i]) 86 | for k in final_results: 87 | final_results[k] = final_results[k]/len(candidates) 88 | return final_results 89 | 90 | 91 | 92 | def rouge_results_to_str(results_dict): 93 | return ">> ROUGE-F(1/2/3/l): {:.2f}/{:.2f}/{:.2f}\nROUGE-R(1/2/3/l): {:.2f}/{:.2f}/{:.2f}\n".format( 94 | results_dict["rouge_1_f_score"] * 100, 95 | results_dict["rouge_2_f_score"] * 100, 96 | # results_dict["rouge_3_f_score"] * 100, 97 | results_dict["rouge_l_f_score"] * 100, 98 | results_dict["rouge_1_recall"] * 100, 99 | results_dict["rouge_2_recall"] * 100, 100 | # results_dict["rouge_3_f_score"] * 100, 101 | results_dict["rouge_l_recall"] * 100 102 | 103 | # ,results_dict["rouge_su*_f_score"] * 100 104 | ) 105 | 106 | 107 | if __name__ == "__main__": 108 | # init_logger('test_rouge.log') 109 | parser = argparse.ArgumentParser() 110 | parser.add_argument('-c', type=str, default="candidate.txt", 111 | help='candidate file') 112 | parser.add_argument('-r', type=str, default="reference.txt", 113 | help='reference file') 114 | parser.add_argument('-p', type=int, default=1, 115 | help='number of processes') 116 | args = parser.parse_args() 117 | print(args.c) 118 | print(args.r) 119 | print(args.p) 120 | if args.c.upper() == "STDIN": 121 | candidates = sys.stdin 122 | else: 123 | candidates = codecs.open(args.c, encoding="utf-8") 124 | references = codecs.open(args.r, encoding="utf-8") 125 | 126 | results_dict = test_rouge(candidates, references,args.p) 127 | # return 0 128 | print(time.strftime('%H:%M:%S', time.localtime()) 129 | ) 130 | print(rouge_results_to_str(results_dict)) 131 | # logger.info(rouge_results_to_str(results_dict)) 132 | -------------------------------------------------------------------------------- /src/abstractive/data_loader.py: -------------------------------------------------------------------------------- 1 | """ 2 | Data loader 3 | Majority of code borrowed from https://github.com/nlpyang/hiersumm 4 | """ 5 | 6 | 7 | import gc 8 | import glob 9 | import random 10 | 11 | import torch 12 | 13 | from others.logging import logger 14 | 15 | def chunks(l, n): 16 | """Yield successive n-sized chunks from l.""" 17 | for i in range(0, len(l), n): 18 | yield l[i:i + n] 19 | 20 | 21 | class AbstractiveBatch(object): 22 | def _pad(self, data, height, width, pad_id): 23 | """ ? 
""" 24 | rtn_data = [d + [pad_id] * (width - len(d)) for d in data] 25 | rtn_length = [len(d) for d in data] 26 | rtn_data = rtn_data + [[pad_id] * width] * (height - len(data)) 27 | rtn_length = rtn_length + [0] * (height - len(data)) 28 | 29 | return rtn_data, rtn_length 30 | 31 | def __init__(self, data=None, hier=False, pad_id=None, device=None, is_test=False, shuffle_order=False): 32 | """Create a Batch from a list of examples.""" 33 | if data is not None: 34 | self.batch_size = len(data) 35 | src = [x[0] for x in data] 36 | tgt = [x[1] for x in data] 37 | 38 | if (hier): 39 | max_nblock = max([len(e) for e in src]) 40 | max_ntoken = max([max([len(p) for p in e]) for e in src]) 41 | _src = [self._pad(e, max_nblock, max_ntoken, pad_id) for e in src] 42 | # Adding the order parameter and shuffling the order of the 43 | if shuffle_order: 44 | para_order = [] 45 | for ind, x in enumerate(data): 46 | order = list(range(max_ntoken)) 47 | random.shuffle(order) 48 | para_order.append(order) 49 | tmp = src[ind][0] 50 | for idx,i in enumerate(order): 51 | tmp[i] = src[ind[0][idx]] 52 | src[ind][0] = tmp 53 | 54 | para_order = torch.tensor(para_order).transpose(0,1) 55 | setattr(self, 'para_order', para_order.to(device)) 56 | 57 | src = torch.stack([torch.tensor(e[0]) for e in _src]) 58 | 59 | 60 | else: 61 | _src = self._pad(src, width=max([len(d) for d in src]), height=len(src), pad_id=pad_id) 62 | src = torch.tensor(_src[0]) # batch_size, src_len 63 | 64 | setattr(self, 'src', src.to(device)) 65 | 66 | _tgt = self._pad(tgt, width=max([len(d) for d in tgt]), height=len(tgt), pad_id=pad_id) 67 | tgt = torch.tensor(_tgt[0]).transpose(0, 1) 68 | setattr(self, 'tgt', tgt.to(device)) 69 | 70 | if (is_test): 71 | tgt_str = [x[2] for x in data] 72 | setattr(self, 'tgt_str', tgt_str) 73 | 74 | ## adding query part 75 | if len(data[0][3])!=0: 76 | query = [x[3] for x in data] 77 | _query = self._pad(query, width=max([len(d) for d in query]), height=len(query), pad_id=pad_id) 78 | query = torch.tensor(_query[0]) # batch_size, q_len 79 | setattr(self, 'query', query.to(device)) 80 | 81 | 82 | 83 | def __len__(self): 84 | return self.batch_size 85 | 86 | 87 | 88 | 89 | def load_dataset(args, corpus_type, shuffle): 90 | """ 91 | Dataset generator. Don't do extra stuff here, like printing, 92 | because they will be postponed to the first loading time. 93 | Args: 94 | corpus_type: 'train' or 'valid' 95 | Returns: 96 | A list of dataset, the dataset(s) are lazily loaded. 97 | """ 98 | assert corpus_type in ["train", "valid", "test"] 99 | 100 | def _lazy_dataset_loader(pt_file, corpus_type): 101 | dataset = torch.load(pt_file) 102 | logger.info('Loading %s dataset from %s, number of examples: %d' % 103 | (corpus_type, pt_file, len(dataset))) 104 | return dataset 105 | 106 | # Sort the glob output by file name (by increasing indexes). 107 | pts = sorted(glob.glob(args.data_path + '.' + corpus_type + '.[0-9]*.pt')) 108 | if pts: 109 | if (shuffle): 110 | random.shuffle(pts) 111 | 112 | for pt in pts: 113 | yield _lazy_dataset_loader(pt, corpus_type) 114 | else: 115 | # Only one inputters.*Dataset, simple! 116 | pt = args.data_path + '.' 
+ corpus_type + '.pt' 117 | yield _lazy_dataset_loader(pt, corpus_type) 118 | 119 | 120 | class AbstractiveDataloader(object): 121 | def __init__(self, args, datasets, symbols, batch_size, 122 | device, shuffle, is_test): 123 | self.args = args 124 | self.datasets = datasets 125 | self.symbols = symbols 126 | self.batch_size = batch_size 127 | self.device = device 128 | self.shuffle = shuffle 129 | self.is_test = is_test 130 | self.cur_iter = self._next_dataset_iterator(datasets) 131 | assert self.cur_iter is not None 132 | 133 | def __iter__(self): 134 | dataset_iter = (d for d in self.datasets) 135 | while self.cur_iter is not None: 136 | for batch in self.cur_iter: 137 | yield batch 138 | self.cur_iter = self._next_dataset_iterator(dataset_iter) 139 | 140 | def _next_dataset_iterator(self, dataset_iter): 141 | try: 142 | # Drop the current dataset for decreasing memory 143 | if hasattr(self, "cur_dataset"): 144 | self.cur_dataset = None 145 | gc.collect() 146 | del self.cur_dataset 147 | gc.collect() 148 | 149 | self.cur_dataset = next(dataset_iter) 150 | except StopIteration: 151 | return None 152 | 153 | return AbstracticeIterator(args = self.args, 154 | dataset=self.cur_dataset, symbols=self.symbols, batch_size=self.batch_size, 155 | device=self.device, shuffle=self.shuffle, is_test=self.is_test) 156 | 157 | 158 | class AbstracticeIterator(object): 159 | def __init__(self, args, dataset, symbols, batch_size, device=None, is_test=False, 160 | shuffle=True): 161 | self.args = args 162 | self.batch_size, self.is_test, self.dataset = batch_size, is_test, dataset 163 | self.iterations = 0 164 | self.device = device 165 | self.shuffle = shuffle 166 | 167 | # self.secondary_sort_key = lambda x: len(x[0]) 168 | # self.secondary_sort_key = lambda x: sum([len(xi) for xi in x[0]]) 169 | # self.prime_sort_key = lambda x: len(x[1]) 170 | self.secondary_sort_key = lambda x: sum([len(xi) for xi in x[0]]) 171 | self.prime_sort_key = lambda x: len(x[1]) 172 | self._iterations_this_epoch = 0 173 | 174 | 175 | self.symbols = symbols 176 | 177 | def data(self): 178 | if self.shuffle: 179 | random.shuffle(self.dataset) 180 | xs = self.dataset 181 | return xs 182 | 183 | def preprocess(self, ex): 184 | 185 | sos_id = self.symbols['BOS'] 186 | eos_id = self.symbols['EOS'] 187 | eot_id = self.symbols['EOT'] 188 | eop_id = self.symbols['EOP'] 189 | eoq_id = self.symbols['EOQ'] 190 | src, tgt, tgt_str = ex['src'], ex['tgt'], ex['tgt_str'] 191 | 192 | """ adding query seperately""" 193 | """ **Test Pass**: N/A""" 194 | if self.args.query: 195 | query = ex['query'] 196 | # append the query to every paragraph 197 | if self.args.model_type in ['query', 'heq', 'hero']: 198 | src = [p for p in src] 199 | else: 200 | src = [ query + p for p in src] 201 | else: 202 | query = [] 203 | 204 | 205 | if (not self.args.hier): 206 | src = sum([p + [eop_id] for p in src], [])[:-1][:self.args.trunc_src_ntoken] + [ 207 | eos_id] 208 | 209 | return src, tgt, tgt_str, query 210 | 211 | src = [p[:self.args.trunc_src_ntoken] for p in src] 212 | 213 | return src[:self.args.trunc_src_nblock], tgt, tgt_str, query 214 | 215 | def simple_batch_size_fn(self, new, count): 216 | src, tgt = new[0], new[1] 217 | 218 | global max_src_in_batch, max_tgt_in_batch 219 | if count == 1: 220 | max_src_in_batch = 0 221 | if (self.args.hier): 222 | max_src_in_batch = max(max_src_in_batch, sum([len(p) for p in src])) 223 | else: 224 | max_src_in_batch = max(max_src_in_batch, len(src)) 225 | src_elements = count * max_src_in_batch 226 | return 
src_elements 227 | 228 | def get_batch(self, data, batch_size): 229 | """Yield elements from data in chunks of batch_size.""" 230 | minibatch, size_so_far = [], 0 231 | for ex in data: 232 | minibatch.append(ex) 233 | size_so_far = self.simple_batch_size_fn(ex, len(minibatch)) 234 | if size_so_far == batch_size: 235 | yield minibatch 236 | minibatch, size_so_far = [], 0 237 | elif size_so_far > batch_size: 238 | yield minibatch[:-1] 239 | minibatch, size_so_far = minibatch[-1:], self.simple_batch_size_fn(ex, 1) 240 | if minibatch: 241 | yield minibatch 242 | 243 | def batch_buffer(self, data, batch_size): 244 | minibatch, size_so_far = [], 0 245 | for ex in data: 246 | ex = self.preprocess(ex) 247 | minibatch.append(ex) 248 | size_so_far = self.simple_batch_size_fn(ex, len(minibatch)) 249 | if size_so_far == batch_size: 250 | yield minibatch 251 | minibatch, size_so_far = [], 0 252 | elif size_so_far > batch_size: 253 | yield minibatch[:-1] 254 | minibatch, size_so_far = minibatch[-1:], self.simple_batch_size_fn(ex, 1) 255 | if minibatch: 256 | yield minibatch 257 | 258 | def create_batches(self): 259 | """ Create batches """ 260 | data = self.data() 261 | for buffer in self.batch_buffer(data, self.batch_size * 100): 262 | if (self.args.mode != 'train'): 263 | p_batch = self.get_batch( 264 | sorted(sorted(buffer, key=self.prime_sort_key), key=self.secondary_sort_key), 265 | self.batch_size) 266 | else: 267 | p_batch = self.get_batch( 268 | sorted(sorted(buffer, key=self.secondary_sort_key), key=self.prime_sort_key), 269 | self.batch_size) 270 | 271 | p_batch = list(p_batch) 272 | 273 | if (self.shuffle): 274 | random.shuffle(p_batch) 275 | for b in p_batch: 276 | if(len(b)==0): 277 | continue 278 | yield b 279 | 280 | def __iter__(self): 281 | 282 | while True: 283 | self.batches = self.create_batches() 284 | for idx, minibatch in enumerate(self.batches): 285 | if self._iterations_this_epoch > idx: 286 | continue 287 | self.iterations += 1 288 | self._iterations_this_epoch += 1 289 | batch = AbstractiveBatch(minibatch, self.args.hier, self.symbols['PAD'], self.device, self.is_test, 290 | shuffle_order=False) 291 | 292 | yield batch 293 | return 294 | -------------------------------------------------------------------------------- /src/abstractive/decode_strategy.py: -------------------------------------------------------------------------------- 1 | """ 2 | Majority of code borrowed https://github.com/OpenNMT/OpenNMT-py 3 | """ 4 | 5 | import torch 6 | 7 | 8 | class DecodeStrategy(object): 9 | """Base class for generation strategies. 10 | 11 | Args: 12 | pad (int): Magic integer in output vocab. 13 | bos (int): Magic integer in output vocab. 14 | eos (int): Magic integer in output vocab. 15 | batch_size (int): Current batch size. 16 | device (torch.device or str): Device for memory bank (encoder). 17 | parallel_paths (int): Decoding strategies like beam search 18 | use parallel paths. Each batch is repeated ``parallel_paths`` 19 | times in relevant state tensors. 20 | min_length (int): Shortest acceptable generation, not counting 21 | begin-of-sentence or end-of-sentence. 22 | max_length (int): Longest acceptable sequence, not counting 23 | begin-of-sentence (presumably there has been no EOS 24 | yet if max_length is used as a cutoff). 25 | block_ngram_repeat (int): Block beams where 26 | ``block_ngram_repeat``-grams repeat. 27 | exclusion_tokens (set[int]): If a gram contains any of these 28 | tokens, it may repeat. 29 | return_attention (bool): Whether to work with attention too. 
If this 30 | is true, it is assumed that the decoder is attentional. 31 | 32 | Attributes: 33 | pad (int): See above. 34 | bos (int): See above. 35 | eos (int): See above. 36 | predictions (list[list[LongTensor]]): For each batch, holds a 37 | list of beam prediction sequences. 38 | scores (list[list[FloatTensor]]): For each batch, holds a 39 | list of scores. 40 | attention (list[list[FloatTensor or list[]]]): For each 41 | batch, holds a list of attention sequence tensors 42 | (or empty lists) having shape ``(step, inp_seq_len)`` where 43 | ``inp_seq_len`` is the length of the sample (not the max 44 | length of all inp seqs). 45 | alive_seq (LongTensor): Shape ``(B x parallel_paths, step)``. 46 | This sequence grows in the ``step`` axis on each call to 47 | :func:`advance()`. 48 | is_finished (ByteTensor or NoneType): Shape 49 | ``(B, parallel_paths)``. Initialized to ``None``. 50 | alive_attn (FloatTensor or NoneType): If tensor, shape is 51 | ``(step, B x parallel_paths, inp_seq_len)``, where ``inp_seq_len`` 52 | is the (max) length of the input sequence. 53 | min_length (int): See above. 54 | max_length (int): See above. 55 | block_ngram_repeat (int): See above. 56 | exclusion_tokens (set[int]): See above. 57 | return_attention (bool): See above. 58 | done (bool): See above. 59 | """ 60 | 61 | def __init__(self, pad, bos, eos, batch_size, device, parallel_paths, 62 | min_length, block_ngram_repeat, exclusion_tokens, 63 | return_attention, max_length): 64 | 65 | # magic indices 66 | self.pad = pad 67 | self.bos = bos 68 | self.eos = eos 69 | 70 | # result caching 71 | self.predictions = [[] for _ in range(batch_size)] 72 | self.scores = [[] for _ in range(batch_size)] 73 | self.attention = [[] for _ in range(batch_size)] 74 | 75 | self.alive_seq = torch.full( 76 | [batch_size * parallel_paths, 1], self.bos, 77 | dtype=torch.long, device=device) 78 | self.is_finished = torch.zeros( 79 | [batch_size, parallel_paths], 80 | dtype=torch.uint8, device=device) 81 | self.alive_attn = None 82 | 83 | self.min_length = min_length 84 | self.max_length = max_length 85 | self.block_ngram_repeat = block_ngram_repeat 86 | self.exclusion_tokens = exclusion_tokens 87 | self.return_attention = return_attention 88 | 89 | self.done = False 90 | 91 | def __len__(self): 92 | return self.alive_seq.shape[1] 93 | 94 | def ensure_min_length(self, log_probs): 95 | if len(self) <= self.min_length: 96 | log_probs[:, self.eos] = -1e20 97 | 98 | def ensure_max_length(self): 99 | # add one to account for BOS. Don't account for EOS because hitting 100 | # this implies it hasn't been found. 101 | if len(self) == self.max_length + 1: 102 | self.is_finished.fill_(1) 103 | 104 | def block_ngram_repeats(self, log_probs): 105 | cur_len = len(self) 106 | if self.block_ngram_repeat > 0 and cur_len > 1: 107 | for path_idx in range(self.alive_seq.shape[0]): 108 | # skip BOS 109 | hyp = self.alive_seq[path_idx, 1:] 110 | ngrams = set() 111 | fail = False 112 | gram = [] 113 | for i in range(cur_len - 1): 114 | # Last n tokens, n = block_ngram_repeat 115 | gram = (gram + [hyp[i].item()])[-self.block_ngram_repeat:] 116 | # skip the blocking if any token in gram is excluded 117 | if set(gram) & self.exclusion_tokens: 118 | continue 119 | if tuple(gram) in ngrams: 120 | fail = True 121 | ngrams.add(tuple(gram)) 122 | if fail: 123 | log_probs[path_idx] = -10e20 124 | 125 | def advance(self, log_probs, attn): 126 | """DecodeStrategy subclasses should override :func:`advance()`. 
127 | 128 | Advance is used to update ``self.alive_seq``, ``self.is_finished``, 129 | and, when appropriate, ``self.alive_attn``. 130 | """ 131 | 132 | raise NotImplementedError() 133 | 134 | def update_finished(self): 135 | """DecodeStrategy subclasses should override :func:`update_finished()`. 136 | 137 | ``update_finished`` is used to update ``self.predictions``, 138 | ``self.scores``, and other "output" attributes. 139 | """ 140 | 141 | raise NotImplementedError() 142 | -------------------------------------------------------------------------------- /src/abstractive/loss.py: -------------------------------------------------------------------------------- 1 | """ 2 | This file handles the details of the loss function during training. 3 | 4 | This includes: LossComputeBase and the standard NMTLossCompute, and 5 | sharded loss compute stuff. 6 | 7 | Majority of code borrowed from https://github.com/nlpyang/hiersumm 8 | and https://github.com/OpenNMT/OpenNMT-py 9 | 10 | """ 11 | from __future__ import division 12 | import torch 13 | import torch.nn as nn 14 | import torch.nn.functional as F 15 | 16 | from others.statistics import Statistics 17 | 18 | 19 | def build_loss_compute(generator,symbols, vocab_size, device, train=True,label_smoothing=0.0): 20 | compute = NMTLossCompute( 21 | generator, symbols, vocab_size, 22 | label_smoothing=label_smoothing if train else 0.0) 23 | compute.to(device) 24 | 25 | return compute 26 | 27 | 28 | 29 | class LossComputeBase(nn.Module): 30 | """ 31 | Class for managing efficient loss computation. Handles 32 | sharding next step predictions and accumulating mutiple 33 | loss computations 34 | 35 | 36 | Users can implement their own loss computation strategy by making 37 | subclass of this one. Users need to implement the _compute_loss() 38 | and make_shard_state() methods. 39 | 40 | Args: 41 | generator (:obj:`nn.Module`) : 42 | module that maps the output of the decoder to a 43 | distribution over the target vocabulary. 44 | tgt_vocab (:obj:`Vocab`) : 45 | torchtext vocab object representing the target output 46 | normalzation (str): normalize by "sents" or "tokens" 47 | """ 48 | 49 | def __init__(self, generator, pad_id): 50 | super(LossComputeBase, self).__init__() 51 | self.generator = generator 52 | self.padding_idx = pad_id 53 | 54 | 55 | 56 | def _make_shard_state(self, batch, output, attns=None): 57 | """ 58 | Make shard state dictionary for shards() to return iterable 59 | shards for efficient loss computation. Subclass must define 60 | this method to match its own _compute_loss() interface. 61 | Args: 62 | batch: the current batch. 63 | output: the predict output from the model. 64 | range_: the range of examples for computing, the whole 65 | batch or a trunc of it? 66 | attns: the attns dictionary returned from the model. 67 | """ 68 | return NotImplementedError 69 | 70 | def _compute_loss(self, batch, output, target, **kwargs): 71 | """ 72 | Compute the loss. Subclass must define this method. 73 | 74 | Args: 75 | 76 | batch: the current batch. 77 | output: the predict output from the model. 78 | target: the validate target to compare output with. 79 | **kwargs(optional): additional info for computing loss. 80 | """ 81 | return NotImplementedError 82 | 83 | def monolithic_compute_loss(self, batch, output): 84 | """ 85 | Compute the forward loss for the batch. 
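This runs a single, un-sharded pass over the batch and does not call ``backward()``, so it is mainly useful for validation; ``sharded_compute_loss`` below is the training counterpart.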
86 | 87 | Args: 88 | batch (batch): batch of labeled examples 89 | output (:obj:`FloatTensor`): 90 | output of decoder model `[tgt_len x batch x hidden]` 91 | attns (dict of :obj:`FloatTensor`) : 92 | dictionary of attention distributions 93 | `[tgt_len x batch x src_len]` 94 | Returns: 95 | :obj:`onmt.utils.Statistics`: loss statistics 96 | """ 97 | shard_state = self._make_shard_state(batch, output) 98 | _, batch_stats = self._compute_loss(batch, **shard_state) 99 | 100 | return batch_stats 101 | 102 | def sharded_compute_loss(self, batch, output, 103 | shard_size, 104 | normalization): 105 | """Compute the forward loss and backpropagate. Computation is done 106 | with shards and optionally truncation for memory efficiency. 107 | 108 | Also supports truncated BPTT for long sequences by taking a 109 | range in the decoder output sequence to back propagate in. 110 | Range is from `(cur_trunc, cur_trunc + trunc_size)`. 111 | 112 | Note sharding is an exact efficiency trick to relieve memory 113 | required for the generation buffers. Truncation is an 114 | approximate efficiency trick to relieve the memory required 115 | in the RNN buffers. 116 | 117 | Args: 118 | batch (batch) : batch of labeled examples 119 | output (:obj:`FloatTensor`) : 120 | output of decoder model `[tgt_len x batch x hidden]` 121 | c_attn (dict) : dictionary of attention related losses "discourse" and "coverage" 122 | `[tgt_len x batch x src_len]` 123 | cur_trunc (int) : starting position of truncation window 124 | trunc_size (int) : length of truncation window 125 | shard_size (int) : maximum number of examples in a shard 126 | normalization (int) : Loss is divided by this number 127 | 128 | Returns: 129 | :obj:`onmt.utils.Statistics`: validation loss statistics 130 | 131 | """ 132 | batch_stats = Statistics() 133 | shard_state = self._make_shard_state(batch, output) 134 | for shard in shards(shard_state, shard_size): 135 | loss, stats = self._compute_loss(batch, **shard) 136 | loss.div(float(normalization)).backward() 137 | batch_stats.update(stats) 138 | 139 | return batch_stats 140 | 141 | def _stats(self, loss, scores, target): 142 | """ 143 | Args: 144 | loss (:obj:`FloatTensor`): the loss computed by the loss criterion. 145 | scores (:obj:`FloatTensor`): a score for each possible output 146 | target (:obj:`FloatTensor`): true targets 147 | 148 | Returns: 149 | :obj:`onmt.utils.Statistics` : statistics for this batch. 150 | """ 151 | pred = scores.max(1)[1] 152 | non_padding = target.ne(self.padding_idx) 153 | num_correct = pred.eq(target) \ 154 | .masked_select(non_padding) \ 155 | .sum() \ 156 | .item() 157 | num_non_padding = non_padding.sum().item() 158 | return Statistics(loss.item(), num_non_padding, num_correct) 159 | 160 | def _bottle(self, _v): 161 | return _v.view(-1, _v.size(2)) 162 | 163 | def _unbottle(self, _v, batch_size): 164 | return _v.view(-1, batch_size, _v.size(1)) 165 | 166 | 167 | class LabelSmoothingLoss(nn.Module): 168 | """ 169 | With label smoothing, 170 | KL-divergence between q_{smoothed ground truth prob.}(w) 171 | and p_{prob. computed by model}(w) is minimized. 
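Concretely, the smoothed target distribution puts ``1 - label_smoothing`` mass on the gold token and ``label_smoothing / (tgt_vocab_size - 2)`` on every other non-padding token (see ``one_hot`` in ``__init__``).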
172 | """ 173 | def __init__(self, label_smoothing, tgt_vocab_size, ignore_index=-100): 174 | assert 0.0 < label_smoothing <= 1.0 175 | self.padding_idx = ignore_index 176 | super(LabelSmoothingLoss, self).__init__() 177 | 178 | smoothing_value = label_smoothing / (tgt_vocab_size - 2) 179 | one_hot = torch.full((tgt_vocab_size,), smoothing_value) 180 | one_hot[self.padding_idx] = 0 181 | self.register_buffer('one_hot', one_hot.unsqueeze(0)) 182 | 183 | self.confidence = 1.0 - label_smoothing 184 | 185 | def forward(self, output, target): 186 | """ 187 | output (FloatTensor): batch_size x n_classes 188 | target (LongTensor): batch_size 189 | """ 190 | model_prob = self.one_hot.repeat(target.size(0), 1) 191 | model_prob.scatter_(1, target.unsqueeze(1), self.confidence) 192 | model_prob.masked_fill_((target == self.padding_idx).unsqueeze(1), 0) 193 | 194 | return F.kl_div(output, model_prob, reduction='sum') 195 | 196 | 197 | class PTLossCompute(LossComputeBase): 198 | """ 199 | Pointer Networks Loss Computation 200 | """ 201 | def __init__(self): 202 | super(PTLossCompute, self).__init__(generator=None, pad_id=-100) 203 | self.criterion = nn.NLLLoss() 204 | 205 | def _make_shard_state(self, batch, output): 206 | return { 207 | "output": output, 208 | "target": batch.para_order 209 | } 210 | 211 | def _compute_loss(self, batch, output, target): 212 | bottled_output = self._bottle(output) 213 | scores = bottled_output 214 | gtruth =target.contiguous().view(-1) 215 | loss = self.criterion(scores, gtruth) 216 | return loss, stats 217 | 218 | 219 | 220 | class NMTLossCompute(LossComputeBase): 221 | """ 222 | Standard NMT Loss Computation. 223 | """ 224 | 225 | def __init__(self, generator, symbols, vocab_size, 226 | label_smoothing=0.0): 227 | super(NMTLossCompute, self).__init__(generator, symbols['PAD']) 228 | self.sparse = not isinstance(generator[1], nn.LogSoftmax) 229 | if label_smoothing > 0: 230 | self.criterion = LabelSmoothingLoss( 231 | label_smoothing, vocab_size, ignore_index=self.padding_idx 232 | ) 233 | else: 234 | self.criterion = nn.NLLLoss( 235 | ignore_index=self.padding_idx, reduction='sum' 236 | ) 237 | 238 | 239 | def _make_shard_state(self, batch, output): 240 | 241 | return { 242 | "output": output, 243 | "target": batch.tgt[ 1: ] 244 | } 245 | 246 | 247 | def _compute_loss(self, batch, output, target): 248 | bottled_output = self._bottle(output) 249 | if self.sparse: 250 | # for sparsemax loss, the loss function operates on the raw output 251 | # vector, not a probability vector. Hence it's only necessary to 252 | # apply the first part of the generator here. 253 | scores = self.generator[0](bottled_output) 254 | else: 255 | scores = self.generator(bottled_output) 256 | gtruth =target.contiguous().view(-1) 257 | 258 | loss = self.criterion(scores, gtruth) 259 | 260 | stats = self._stats(loss.clone(), scores, gtruth) 261 | 262 | return loss, stats 263 | 264 | 265 | def filter_shard_state(state, shard_size=None): 266 | """ ? """ 267 | for k, v in state.items(): 268 | if shard_size is None: 269 | yield k, v 270 | 271 | if v is not None: 272 | v_split = [] 273 | if isinstance(v, torch.Tensor): 274 | for v_chunk in torch.split(v, shard_size): 275 | v_chunk = v_chunk.data.clone() 276 | v_chunk.requires_grad = v.requires_grad 277 | v_split.append(v_chunk) 278 | yield k, (v, v_split) 279 | 280 | 281 | def shards(state, shard_size, eval_only=False): 282 | """ 283 | Args: 284 | state: A dictionary which corresponds to the output of 285 | *LossCompute._make_shard_state(). 
The values for 286 | those keys are Tensor-like or None. 287 | shard_size: The maximum size of the shards yielded by the model. 288 | eval_only: If True, only yield the state, nothing else. 289 | Otherwise, yield shards. 290 | 291 | Yields: 292 | Each yielded shard is a dict. 293 | 294 | Side effect: 295 | After the last shard, this function does back-propagation. 296 | """ 297 | if eval_only: 298 | yield filter_shard_state(state) 299 | else: 300 | # non_none: the subdict of the state dictionary where the values 301 | # are not None. 302 | non_none = dict(filter_shard_state(state, shard_size)) 303 | 304 | # Now, the iteration: 305 | # state is a dictionary of sequences of tensor-like but we 306 | # want a sequence of dictionaries of tensors. 307 | # First, unzip the dictionary into a sequence of keys and a 308 | # sequence of tensor-like sequences. 309 | keys, values = zip(*((k, [v_chunk for v_chunk in v_split]) 310 | for k, (_, v_split) in non_none.items())) 311 | 312 | # Now, yield a dictionary for each shard. The keys are always 313 | # the same. values is a sequence of length #keys where each 314 | # element is a sequence of length #shards. We want to iterate 315 | # over the shards, not over the keys: therefore, the values need 316 | # to be re-zipped by shard and then each shard can be paired 317 | # with the keys. 318 | for shard_tensors in zip(*values): 319 | yield dict(zip(keys, shard_tensors)) 320 | 321 | # Assumed backprop'd 322 | variables = [] 323 | 324 | for k, (v, v_split) in non_none.items(): 325 | if isinstance(v, torch.Tensor) and state[k].requires_grad: 326 | variables.extend(zip(torch.split(state[k], shard_size), 327 | [v_chunk.grad for v_chunk in v_split])) 328 | 329 | inputs, grads = zip(*variables) 330 | torch.autograd.backward(inputs, grads) 331 | 332 | -------------------------------------------------------------------------------- /src/abstractive/model_builder.py: -------------------------------------------------------------------------------- 1 | """ 2 | This file is for models creation, which consults options 3 | and creates each encoder and decoder accordingly. 
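``Summarizer`` selects its encoder from ``args.model_type`` (hier, he, order, query, heq, heo, hero below) and always pairs it with a ``TransformerDecoder`` on the target side.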
4 | Majority of code borrowed from https://github.com/nlpyang/hiersumm 5 | """ 6 | 7 | from abstractive.optimizer import Optimizer 8 | from abstractive.transformer_encoder import TransformerEncoder, TransformerInterEncoder, \ 9 | TransformerEncoderHE, TransformerEncoderOrder,\ 10 | TransformerEncoderQuery, TransformerEncoderHEQ, \ 11 | TransformerEncoderHEO, TransformerEncoderHERO 12 | from abstractive.transformer_decoder import TransformerDecoder 13 | 14 | 15 | 16 | import torch.nn as nn 17 | from torch.nn.init import xavier_uniform_ 18 | import torch 19 | 20 | def build_optim(args, model, checkpoint): 21 | """ Build optimizer """ 22 | optim = Optimizer( 23 | args.optim, args.lr, args.max_grad_norm, 24 | beta1=args.beta1, beta2=args.beta2, 25 | decay_method=args.decay_method, 26 | warmup_steps=args.warmup_steps, model_size=args.enc_hidden_size) 27 | 28 | optim.set_parameters(list(model.named_parameters())) 29 | 30 | if args.train_from != '' and not args.fine_tune: 31 | optim.optimizer.load_state_dict(checkpoint['optim']) 32 | if args.visible_gpus != '-1': 33 | for state in optim.optimizer.state.values(): 34 | for k, v in state.items(): 35 | if torch.is_tensor(v): 36 | state[k] = v.cuda() 37 | 38 | if (optim.method == 'adam') and (len(optim.optimizer.state) < 1): 39 | raise RuntimeError( 40 | "Error: loaded Adam optimizer from existing model" + 41 | " but optimizer state is empty") 42 | 43 | 44 | return optim 45 | 46 | 47 | def get_generator(dec_hidden_size, vocab_size, device): 48 | gen_func = nn.LogSoftmax(dim=-1) 49 | generator = nn.Sequential( 50 | nn.Linear(dec_hidden_size, vocab_size), 51 | gen_func 52 | ) 53 | generator.to(device) 54 | 55 | return generator 56 | 57 | 58 | class Summarizer(nn.Module): 59 | def __init__(self,args, word_padding_idx, vocab_size, device, checkpoint=None): 60 | self.args = args 61 | super(Summarizer, self).__init__() 62 | # self.spm = spm 63 | self.vocab_size = vocab_size 64 | self.device = device 65 | # src_dict = fields["src"].vocab 66 | # tgt_dict = fields["tgt"].vocab 67 | 68 | src_embeddings = torch.nn.Embedding(self.vocab_size, self.args.emb_size, padding_idx=word_padding_idx) 69 | tgt_embeddings = torch.nn.Embedding(self.vocab_size, self.args.emb_size, padding_idx=word_padding_idx) 70 | 71 | if (self.args.share_embeddings): 72 | tgt_embeddings.weight = src_embeddings.weight 73 | 74 | if self.args.model_type == 'hier': 75 | if (self.args.hier): 76 | self.encoder = TransformerInterEncoder(self.args.enc_layers, self.args.enc_hidden_size, self.args.heads, 77 | self.args.ff_size, self.args.enc_dropout, src_embeddings, inter_layers=self.args.inter_layers, inter_heads= self.args.inter_heads, device=device) 78 | else: 79 | self.encoder = TransformerEncoder(self.args.enc_layers, self.args.enc_hidden_size, self.args.heads, 80 | self.args.ff_size, 81 | self.args.enc_dropout, src_embeddings) 82 | 83 | 84 | elif self.args.model_type == 'he': 85 | self.encoder = TransformerEncoderHE(self.args.enc_layers, self.args.enc_hidden_size, self.args.heads, 86 | self.args.ff_size, self.args.enc_dropout, src_embeddings, 87 | inter_layers=self.args.inter_layers, inter_heads= self.args.inter_heads, 88 | device=device) 89 | 90 | elif self.args.model_type == 'order': 91 | self.encoder = TransformerEncoderOrder(self.args.enc_layers, self.args.enc_hidden_size, self.args.heads, 92 | self.args.ff_size, self.args.enc_dropout, src_embeddings, 93 | inter_layers=self.args.inter_layers, inter_heads= self.args.inter_heads, 94 | device=device) 95 | 96 | 97 | elif self.args.model_type == 
'query': 98 | self.encoder = TransformerEncoderQuery(self.args.enc_layers, self.args.enc_hidden_size, self.args.heads, 99 | self.args.ff_size, self.args.enc_dropout, src_embeddings, 100 | inter_layers=self.args.inter_layers, inter_heads= self.args.inter_heads, 101 | num_query_layers=self.args.query_layers, device=device) 102 | 103 | 104 | elif self.args.model_type == 'heq': 105 | self.encoder = TransformerEncoderHEQ(self.args.enc_layers, self.args.enc_hidden_size, self.args.heads, 106 | self.args.ff_size, self.args.enc_dropout, src_embeddings, 107 | inter_layers=self.args.inter_layers, inter_heads= self.args.inter_heads, 108 | device=device) 109 | 110 | elif self.args.model_type == 'heo': 111 | self.encoder = TransformerEncoderHEO(self.args.enc_layers, self.args.enc_hidden_size, self.args.heads, 112 | self.args.ff_size, self.args.enc_dropout, src_embeddings, 113 | inter_layers=self.args.inter_layers, inter_heads= self.args.inter_heads, 114 | device=device) 115 | 116 | 117 | elif self.args.model_type == 'hero': 118 | self.encoder = TransformerEncoderHERO(self.args.enc_layers, self.args.enc_hidden_size, self.args.heads, 119 | self.args.ff_size, self.args.enc_dropout, src_embeddings, 120 | inter_layers=self.args.inter_layers, inter_heads= self.args.inter_heads, 121 | device=device) 122 | 123 | 124 | 125 | 126 | self.decoder = TransformerDecoder( 127 | self.args.dec_layers, 128 | self.args.dec_hidden_size, heads=self.args.heads, 129 | d_ff=self.args.ff_size, dropout=self.args.dec_dropout, embeddings=tgt_embeddings, device=device) 130 | 131 | 132 | 133 | self.generator = get_generator(self.args.dec_hidden_size, self.vocab_size, device) 134 | if self.args.share_decoder_embeddings: 135 | self.generator[0].weight = self.decoder.embeddings.weight 136 | 137 | if checkpoint is not None: 138 | # checkpoint['model'] 139 | keys = list(checkpoint['model'].keys()) 140 | for k in keys: 141 | if ('a_2' in k): 142 | checkpoint['model'][k.replace('a_2', 'weight')] = checkpoint['model'][k] 143 | del (checkpoint['model'][k]) 144 | if ('b_2' in k): 145 | checkpoint['model'][k.replace('b_2', 'bias')] = checkpoint['model'][k] 146 | del (checkpoint['model'][k]) 147 | self.load_state_dict(checkpoint['model'], strict=True) 148 | else: 149 | for p in self.parameters(): 150 | if p.dim() > 1: 151 | xavier_uniform_(p) 152 | 153 | 154 | 155 | self.to(device) 156 | 157 | def forward(self, src, tgt, query=None, para_order=None): 158 | tgt = tgt[:-1] 159 | batch_size, n_blocks, num_tokens = src.shape 160 | 161 | if self.args.model_type in ['query', 'heq', 'hero']: 162 | src_features, mask_hier = self.encoder(src, query) 163 | else: 164 | src_features, mask_hier = self.encoder(src) 165 | dec_state = self.decoder.init_decoder_state(src, src_features) 166 | if (self.args.hier): 167 | decoder_outputs = self.decoder(tgt, src_features.view(n_blocks, num_tokens, batch_size, -1).contiguous(), dec_state, memory_masks=mask_hier) 168 | else: 169 | decoder_outputs = self.decoder(tgt, src_features.view(n_blocks, num_tokens, batch_size, -1).contiguous(), dec_state) 170 | 171 | return decoder_outputs 172 | 173 | 174 | -------------------------------------------------------------------------------- /src/abstractive/my_pyrouge.py: -------------------------------------------------------------------------------- 1 | """ 2 | Majority of code borrowed from https://github.com/nlpyang/hiersumm 3 | """ 4 | 5 | from __future__ import print_function, unicode_literals, division 6 | 7 | import os 8 | import re 9 | import codecs 10 | import platform 11 
| 12 | from subprocess import check_output 13 | from tempfile import mkdtemp 14 | from functools import partial 15 | 16 | try: 17 | from configparser import ConfigParser 18 | except ImportError: 19 | from ConfigParser import ConfigParser 20 | 21 | from pyrouge.utils import log 22 | from pyrouge.utils.file_utils import verify_dir 23 | 24 | 25 | class DirectoryProcessor: 26 | 27 | @staticmethod 28 | def process(input_dir, output_dir, function): 29 | """ 30 | Apply function to all files in input_dir and save the resulting ouput 31 | files in output_dir. 32 | 33 | """ 34 | if not os.path.exists(output_dir): 35 | os.makedirs(output_dir) 36 | logger = log.get_global_console_logger() 37 | logger.info("Processing files in {}.".format(input_dir)) 38 | input_file_names = os.listdir(input_dir) 39 | for input_file_name in input_file_names: 40 | input_file = os.path.join(input_dir, input_file_name) 41 | with codecs.open(input_file, "r", encoding="UTF-8") as f: 42 | input_string = f.read() 43 | output_string = function(input_string) 44 | output_file = os.path.join(output_dir, input_file_name) 45 | with codecs.open(output_file, "w", encoding="UTF-8") as f: 46 | f.write(output_string.lower()) 47 | logger.info("Saved processed files to {}.".format(output_dir)) 48 | 49 | 50 | class Rouge155(object): 51 | """ 52 | This is a wrapper for the ROUGE 1.5.5 summary evaluation package. 53 | This class is designed to simplify the evaluation process by: 54 | 55 | 1) Converting summaries into a format ROUGE understands. 56 | 2) Generating the ROUGE configuration file automatically based 57 | on filename patterns. 58 | 59 | This class can be used within Python like this: 60 | 61 | rouge = Rouge155() 62 | rouge.system_dir = 'test/systems' 63 | rouge.model_dir = 'test/models' 64 | 65 | # The system filename pattern should contain one group that 66 | # matches the document ID. 67 | rouge.system_filename_pattern = 'SL.P.10.R.11.SL062003-(\d+).html' 68 | 69 | # The model filename pattern has '#ID#' as a placeholder for the 70 | # document ID. If there are multiple model summaries, pyrouge 71 | # will use the provided regex to automatically match them with 72 | # the corresponding system summary. Here, [A-Z] matches 73 | # multiple model summaries for a given #ID#. 74 | rouge.model_filename_pattern = 'SL.P.10.R.[A-Z].SL062003-#ID#.html' 75 | 76 | rouge_output = rouge.evaluate() 77 | print(rouge_output) 78 | output_dict = rouge.output_to_dict(rouge_ouput) 79 | print(output_dict) 80 | -> {'rouge_1_f_score': 0.95652, 81 | 'rouge_1_f_score_cb': 0.95652, 82 | 'rouge_1_f_score_ce': 0.95652, 83 | 'rouge_1_precision': 0.95652, 84 | [...] 85 | 86 | 87 | To evaluate multiple systems: 88 | 89 | rouge = Rouge155() 90 | rouge.system_dir = '/PATH/TO/systems' 91 | rouge.model_dir = 'PATH/TO/models' 92 | for system_id in ['id1', 'id2', 'id3']: 93 | rouge.system_filename_pattern = \ 94 | 'SL.P/.10.R.{}.SL062003-(\d+).html'.format(system_id) 95 | rouge.model_filename_pattern = \ 96 | 'SL.P.10.R.[A-Z].SL062003-#ID#.html' 97 | rouge_output = rouge.evaluate(system_id) 98 | print(rouge_output) 99 | 100 | """ 101 | 102 | def __init__(self, rouge_dir=None, rouge_args=None): 103 | """ 104 | Create a Rouge155 object. 105 | 106 | rouge_dir: Directory containing Rouge-1.5.5.pl 107 | rouge_args: Arguments to pass through to ROUGE if you 108 | don't want to use the default pyrouge 109 | arguments. 
110 | 111 | """ 112 | self.log = log.get_global_console_logger() 113 | self.__set_dir_properties() 114 | self._config_file = None 115 | self._settings_file = self.__get_config_path() 116 | self.__set_rouge_dir(rouge_dir) 117 | self.args = self.__clean_rouge_args(rouge_args) 118 | self._system_filename_pattern = None 119 | self._model_filename_pattern = None 120 | 121 | def save_home_dir(self): 122 | config = ConfigParser() 123 | section = 'pyrouge settings' 124 | config.add_section(section) 125 | config.set(section, 'home_dir', self._home_dir) 126 | with open(self._settings_file, 'w') as f: 127 | config.write(f) 128 | self.log.info("Set ROUGE home directory to {}.".format(self._home_dir)) 129 | 130 | @property 131 | def settings_file(self): 132 | """ 133 | Path of the setttings file, which stores the ROUGE home dir. 134 | 135 | """ 136 | return self._settings_file 137 | 138 | @property 139 | def bin_path(self): 140 | """ 141 | The full path of the ROUGE binary (although it's technically 142 | a script), i.e. rouge_home_dir/ROUGE-1.5.5.pl 143 | 144 | """ 145 | if self._bin_path is None: 146 | raise Exception( 147 | "ROUGE path not set. Please set the ROUGE home directory " 148 | "and ensure that ROUGE-1.5.5.pl exists in it.") 149 | return self._bin_path 150 | 151 | @property 152 | def system_filename_pattern(self): 153 | """ 154 | The regular expression pattern for matching system summary 155 | filenames. The regex string. 156 | 157 | E.g. "SL.P.10.R.11.SL062003-(\d+).html" will match the system 158 | filenames in the SPL2003/system folder of the ROUGE SPL example 159 | in the "sample-test" folder. 160 | 161 | Currently, there is no support for multiple systems. 162 | 163 | """ 164 | return self._system_filename_pattern 165 | 166 | @system_filename_pattern.setter 167 | def system_filename_pattern(self, pattern): 168 | self._system_filename_pattern = pattern 169 | 170 | @property 171 | def model_filename_pattern(self): 172 | """ 173 | The regular expression pattern for matching model summary 174 | filenames. The pattern needs to contain the string "#ID#", 175 | which is a placeholder for the document ID. 176 | 177 | E.g. "SL.P.10.R.[A-Z].SL062003-#ID#.html" will match the model 178 | filenames in the SPL2003/system folder of the ROUGE SPL 179 | example in the "sample-test" folder. 180 | 181 | "#ID#" is a placeholder for the document ID which has been 182 | matched by the "(\d+)" part of the system filename pattern. 183 | The different model summaries for a given document ID are 184 | matched by the "[A-Z]" part. 185 | 186 | """ 187 | return self._model_filename_pattern 188 | 189 | @model_filename_pattern.setter 190 | def model_filename_pattern(self, pattern): 191 | self._model_filename_pattern = pattern 192 | 193 | @property 194 | def config_file(self): 195 | return self._config_file 196 | 197 | @config_file.setter 198 | def config_file(self, path): 199 | config_dir, _ = os.path.split(path) 200 | verify_dir(config_dir, "configuration file") 201 | self._config_file = path 202 | 203 | def split_sentences(self): 204 | """ 205 | ROUGE requires texts split into sentences. In case the texts 206 | are not already split, this method can be used. 
207 | 208 | """ 209 | from pyrouge.utils.sentence_splitter import PunktSentenceSplitter 210 | self.log.info("Splitting sentences.") 211 | ss = PunktSentenceSplitter() 212 | sent_split_to_string = lambda s: "\n".join(ss.split(s)) 213 | process_func = partial( 214 | DirectoryProcessor.process, function=sent_split_to_string) 215 | self.__process_summaries(process_func) 216 | 217 | @staticmethod 218 | def convert_summaries_to_rouge_format(input_dir, output_dir): 219 | """ 220 | Convert all files in input_dir into a format ROUGE understands 221 | and saves the files to output_dir. The input files are assumed 222 | to be plain text with one sentence per line. 223 | 224 | input_dir: Path of directory containing the input files. 225 | output_dir: Path of directory in which the converted files 226 | will be saved. 227 | 228 | """ 229 | DirectoryProcessor.process( 230 | input_dir, output_dir, Rouge155.convert_text_to_rouge_format) 231 | 232 | @staticmethod 233 | def convert_text_to_rouge_format(text, title="dummy title"): 234 | """ 235 | Convert a text to a format ROUGE understands. The text is 236 | assumed to contain one sentence per line. 237 | 238 | text: The text to convert, containg one sentence per line. 239 | title: Optional title for the text. The title will appear 240 | in the converted file, but doesn't seem to have 241 | any other relevance. 242 | 243 | Returns: The converted text as string. 244 | 245 | """ 246 | sentences = text.split("\n") 247 | sent_elems = [ 248 | "[{i}] " 249 | "{text}".format(i=i, text=sent) 250 | for i, sent in enumerate(sentences, start=1)] 251 | html = """ 252 | 253 | {title} 254 | 255 | 256 | {elems} 257 | 258 | """.format(title=title, elems="\n".join(sent_elems)) 259 | 260 | return html 261 | 262 | @staticmethod 263 | def write_config_static(system_dir, system_filename_pattern, 264 | model_dir, model_filename_pattern, 265 | config_file_path, system_id=None): 266 | """ 267 | Write the ROUGE configuration file, which is basically a list 268 | of system summary files and their corresponding model summary 269 | files. 270 | 271 | pyrouge uses regular expressions to automatically find the 272 | matching model summary files for a given system summary file 273 | (cf. docstrings for system_filename_pattern and 274 | model_filename_pattern). 275 | 276 | system_dir: Path of directory containing 277 | system summaries. 278 | system_filename_pattern: Regex string for matching 279 | system summary filenames. 280 | model_dir: Path of directory containing 281 | model summaries. 282 | model_filename_pattern: Regex string for matching model 283 | summary filenames. 284 | config_file_path: Path of the configuration file. 285 | system_id: Optional system ID string which 286 | will appear in the ROUGE output. 
287 | 288 | """ 289 | system_filenames = [f for f in os.listdir(system_dir)] 290 | system_models_tuples = [] 291 | 292 | system_filename_pattern = re.compile(system_filename_pattern) 293 | for system_filename in sorted(system_filenames): 294 | match = system_filename_pattern.match(system_filename) 295 | if match: 296 | id = match.groups(0)[0] 297 | model_filenames = [model_filename_pattern.replace('#ID#',id)] 298 | # model_filenames = Rouge155.__get_model_filenames_for_id( 299 | # id, model_dir, model_filename_pattern) 300 | system_models_tuples.append( 301 | (system_filename, sorted(model_filenames))) 302 | if not system_models_tuples: 303 | raise Exception( 304 | "Did not find any files matching the pattern {} " 305 | "in the system summaries directory {}.".format( 306 | system_filename_pattern.pattern, system_dir)) 307 | 308 | with codecs.open(config_file_path, 'w', encoding='utf-8') as f: 309 | f.write('') 310 | for task_id, (system_filename, model_filenames) in enumerate( 311 | system_models_tuples, start=1): 312 | 313 | eval_string = Rouge155.__get_eval_string( 314 | task_id, system_id, 315 | system_dir, system_filename, 316 | model_dir, model_filenames) 317 | f.write(eval_string) 318 | f.write("") 319 | 320 | def write_config(self, config_file_path=None, system_id=None): 321 | """ 322 | Write the ROUGE configuration file, which is basically a list 323 | of system summary files and their matching model summary files. 324 | 325 | This is a non-static version of write_config_file_static(). 326 | 327 | config_file_path: Path of the configuration file. 328 | system_id: Optional system ID string which will 329 | appear in the ROUGE output. 330 | 331 | """ 332 | if not system_id: 333 | system_id = 1 334 | if (not config_file_path) or (not self._config_dir): 335 | self._config_dir = mkdtemp() 336 | config_filename = "rouge_conf.xml" 337 | else: 338 | config_dir, config_filename = os.path.split(config_file_path) 339 | verify_dir(config_dir, "configuration file") 340 | self._config_file = os.path.join(self._config_dir, config_filename) 341 | Rouge155.write_config_static( 342 | self._system_dir, self._system_filename_pattern, 343 | self._model_dir, self._model_filename_pattern, 344 | self._config_file, system_id) 345 | self.log.info( 346 | "Written ROUGE configuration to {}".format(self._config_file)) 347 | 348 | def evaluate(self, system_id=1, rouge_args=None): 349 | """ 350 | Run ROUGE to evaluate the system summaries in system_dir against 351 | the model summaries in model_dir. The summaries are assumed to 352 | be in the one-sentence-per-line HTML format ROUGE understands. 353 | 354 | system_id: Optional system ID which will be printed in 355 | ROUGE's output. 356 | 357 | Returns: Rouge output as string. 358 | 359 | """ 360 | self.write_config(system_id=system_id) 361 | options = self.__get_options(rouge_args) 362 | command = [self._bin_path] + options 363 | self.log.info( 364 | "Running ROUGE with command {}".format(" ".join(command))) 365 | rouge_output = check_output(command).decode("UTF-8") 366 | return rouge_output 367 | 368 | def convert_and_evaluate(self, system_id=1, 369 | split_sentences=False, rouge_args=None): 370 | """ 371 | Convert plain text summaries to ROUGE format and run ROUGE to 372 | evaluate the system summaries in system_dir against the model 373 | summaries in model_dir. Optionally split texts into sentences 374 | in case they aren't already. 375 | 376 | This is just a convenience method combining 377 | convert_summaries_to_rouge_format() and evaluate(). 
378 | 379 | split_sentences: Optional argument specifying if 380 | sentences should be split. 381 | system_id: Optional system ID which will be printed 382 | in ROUGE's output. 383 | 384 | Returns: ROUGE output as string. 385 | 386 | """ 387 | if split_sentences: 388 | self.split_sentences() 389 | self.__write_summaries() 390 | rouge_output = self.evaluate(system_id, rouge_args) 391 | return rouge_output 392 | 393 | def output_to_dict(self, output): 394 | """ 395 | Convert the ROUGE output into python dictionary for further 396 | processing. 397 | 398 | """ 399 | #0 ROUGE-1 Average_R: 0.02632 (95%-conf.int. 0.02632 - 0.02632) 400 | pattern = re.compile( 401 | r"(\d+) (ROUGE-\S+) (Average_\w): (\d.\d+) " 402 | r"\(95%-conf.int. (\d.\d+) - (\d.\d+)\)") 403 | results = {} 404 | for line in output.split("\n"): 405 | match = pattern.match(line) 406 | if match: 407 | sys_id, rouge_type, measure, result, conf_begin, conf_end = \ 408 | match.groups() 409 | measure = { 410 | 'Average_R': 'recall', 411 | 'Average_P': 'precision', 412 | 'Average_F': 'f_score' 413 | }[measure] 414 | rouge_type = rouge_type.lower().replace("-", '_') 415 | key = "{}_{}".format(rouge_type, measure) 416 | results[key] = float(result) 417 | results["{}_cb".format(key)] = float(conf_begin) 418 | results["{}_ce".format(key)] = float(conf_end) 419 | return results 420 | 421 | ################################################################### 422 | # Private methods 423 | 424 | def __set_rouge_dir(self, home_dir=None): 425 | """ 426 | Verfify presence of ROUGE-1.5.5.pl and data folder, and set 427 | those paths. 428 | 429 | """ 430 | if not home_dir: 431 | self._home_dir = self.__get_rouge_home_dir_from_settings() 432 | else: 433 | self._home_dir = home_dir 434 | self.save_home_dir() 435 | self._bin_path = os.path.join(self._home_dir, 'ROUGE-1.5.5.pl') 436 | self.data_dir = os.path.join(self._home_dir, 'data') 437 | if not os.path.exists(self._bin_path): 438 | raise Exception( 439 | "ROUGE binary not found at {}. Please set the " 440 | "correct path by running pyrouge_set_rouge_path " 441 | "/path/to/rouge/home.".format(self._bin_path)) 442 | 443 | def __get_rouge_home_dir_from_settings(self): 444 | config = ConfigParser() 445 | with open(self._settings_file) as f: 446 | if hasattr(config, "read_file"): 447 | config.read_file(f) 448 | else: 449 | # use deprecated python 2.x method 450 | config.readfp(f) 451 | rouge_home_dir = config.get('pyrouge settings', 'home_dir') 452 | return rouge_home_dir 453 | 454 | @staticmethod 455 | def __get_eval_string( 456 | task_id, system_id, 457 | system_dir, system_filename, 458 | model_dir, model_filenames): 459 | """ 460 | ROUGE can evaluate several system summaries for a given text 461 | against several model summaries, i.e. there is an m-to-n 462 | relation between system and model summaries. The system 463 | summaries are listed in the tag and the model summaries 464 | in the tag. pyrouge currently only supports one system 465 | summary per text, i.e. it assumes a 1-to-n relation between 466 | system and model summaries. 467 | 468 | """ 469 | peer_elems = "
<P ID=\"{id}\">{name}</P>
".format( 470 | id=system_id, name=system_filename) 471 | 472 | model_elems = ["{name}".format( 473 | id=chr(65 + i), name=name) 474 | for i, name in enumerate(model_filenames)] 475 | 476 | model_elems = "\n\t\t\t".join(model_elems) 477 | eval_string = """ 478 | 479 | {model_root} 480 | {peer_root} 481 | 482 | 483 | 484 | {peer_elems} 485 | 486 | 487 | {model_elems} 488 | 489 | 490 | """.format( 491 | task_id=task_id, 492 | model_root=model_dir, model_elems=model_elems, 493 | peer_root=system_dir, peer_elems=peer_elems) 494 | return eval_string 495 | 496 | def __process_summaries(self, process_func): 497 | """ 498 | Helper method that applies process_func to the files in the 499 | system and model folders and saves the resulting files to new 500 | system and model folders. 501 | 502 | """ 503 | temp_dir = mkdtemp() 504 | new_system_dir = os.path.join(temp_dir, "system") 505 | os.mkdir(new_system_dir) 506 | new_model_dir = os.path.join(temp_dir, "model") 507 | os.mkdir(new_model_dir) 508 | self.log.info( 509 | "Processing summaries. Saving system files to {} and " 510 | "model files to {}.".format(new_system_dir, new_model_dir)) 511 | process_func(self._system_dir, new_system_dir) 512 | process_func(self._model_dir, new_model_dir) 513 | self._system_dir = new_system_dir 514 | self._model_dir = new_model_dir 515 | 516 | def __write_summaries(self): 517 | self.log.info("Writing summaries.") 518 | self.__process_summaries(self.convert_summaries_to_rouge_format) 519 | 520 | @staticmethod 521 | def __get_model_filenames_for_id(id, model_dir, model_filenames_pattern): 522 | pattern = re.compile(model_filenames_pattern.replace('#ID#', id)) 523 | model_filenames = [ 524 | f for f in os.listdir(model_dir) if pattern.match(f)] 525 | if not model_filenames: 526 | raise Exception( 527 | "Could not find any model summaries for the system" 528 | " summary with ID {}. Specified model filename pattern was: " 529 | "{}".format(id, model_filenames_pattern)) 530 | return model_filenames 531 | 532 | def __get_options(self, rouge_args=None): 533 | """ 534 | Get supplied command line arguments for ROUGE or use default 535 | ones. 536 | 537 | """ 538 | if self.args: 539 | options = self.args.split() 540 | elif rouge_args: 541 | options = rouge_args.split() 542 | else: 543 | options = [ 544 | '-e', self._data_dir, 545 | '-c', 95, 546 | # '-2', 547 | # '-1', 548 | # '-U', 549 | '-m', 550 | # '-v', 551 | '-r', 1000, 552 | '-n', 2, 553 | # '-w', 1.2, 554 | '-a', 555 | ] 556 | options = list(map(str, options)) 557 | 558 | 559 | 560 | 561 | options = self.__add_config_option(options) 562 | return options 563 | 564 | def __create_dir_property(self, dir_name, docstring): 565 | """ 566 | Generate getter and setter for a directory property. 567 | 568 | """ 569 | property_name = "{}_dir".format(dir_name) 570 | private_name = "_" + property_name 571 | setattr(self, private_name, None) 572 | 573 | def fget(self): 574 | return getattr(self, private_name) 575 | 576 | def fset(self, path): 577 | verify_dir(path, dir_name) 578 | setattr(self, private_name, path) 579 | 580 | p = property(fget=fget, fset=fset, doc=docstring) 581 | setattr(self.__class__, property_name, p) 582 | 583 | def __set_dir_properties(self): 584 | """ 585 | Automatically generate the properties for directories. 
586 | 587 | """ 588 | directories = [ 589 | ("home", "The ROUGE home directory."), 590 | ("data", "The path of the ROUGE 'data' directory."), 591 | ("system", "Path of the directory containing system summaries."), 592 | ("model", "Path of the directory containing model summaries."), 593 | ] 594 | for (dirname, docstring) in directories: 595 | self.__create_dir_property(dirname, docstring) 596 | 597 | def __clean_rouge_args(self, rouge_args): 598 | """ 599 | Remove enclosing quotation marks, if any. 600 | 601 | """ 602 | if not rouge_args: 603 | return 604 | quot_mark_pattern = re.compile('"(.+)"') 605 | match = quot_mark_pattern.match(rouge_args) 606 | if match: 607 | cleaned_args = match.group(1) 608 | return cleaned_args 609 | else: 610 | return rouge_args 611 | 612 | def __add_config_option(self, options): 613 | return options + [self._config_file] 614 | 615 | def __get_config_path(self): 616 | if platform.system() == "Windows": 617 | parent_dir = os.getenv("APPDATA") 618 | config_dir_name = "pyrouge" 619 | elif os.name == "posix": 620 | parent_dir = os.path.expanduser("~") 621 | config_dir_name = ".pyrouge" 622 | else: 623 | parent_dir = os.path.dirname(__file__) 624 | config_dir_name = "" 625 | config_dir = os.path.join(parent_dir, config_dir_name) 626 | if not os.path.exists(config_dir): 627 | os.makedirs(config_dir) 628 | return os.path.join(config_dir, 'settings.ini') 629 | 630 | 631 | if __name__ == "__main__": 632 | import argparse 633 | from utils.argparsers import rouge_path_parser 634 | 635 | parser = argparse.ArgumentParser(parents=[rouge_path_parser]) 636 | args = parser.parse_args() 637 | 638 | rouge = Rouge155(args.rouge_home) 639 | rouge.save_home_dir() 640 | -------------------------------------------------------------------------------- /src/abstractive/neural.py: -------------------------------------------------------------------------------- 1 | """ 2 | Majority of code borrowed from https://github.com/nlpyang/hiersumm 3 | """ 4 | 5 | import math 6 | import torch 7 | 8 | from torch import nn 9 | 10 | def tile(x, count, dim=0): 11 | """ 12 | Tiles x on dimension dim count times. 13 | """ 14 | perm = list(range(len(x.size()))) 15 | if dim != 0: 16 | perm[0], perm[dim] = perm[dim], perm[0] 17 | x = x.permute(perm).contiguous() 18 | out_size = list(x.size()) 19 | out_size[0] *= count 20 | batch = x.size(0) 21 | x = x.view(batch, -1) \ 22 | .transpose(0, 1) \ 23 | .repeat(count, 1) \ 24 | .transpose(0, 1) \ 25 | .contiguous() \ 26 | .view(*out_size) 27 | if dim != 0: 28 | x = x.permute(perm).contiguous() 29 | return x 30 | 31 | def sequence_mask(lengths, max_len=None): 32 | """ 33 | Creates a boolean mask from sequence lengths. 34 | """ 35 | batch_size = lengths.numel() 36 | max_len = max_len or lengths.max() 37 | return (torch.arange(0, max_len) 38 | .type_as(lengths) 39 | .repeat(batch_size, 1) 40 | .lt(lengths.unsqueeze(1))) 41 | 42 | class PositionalEncoding(nn.Module): 43 | """ 44 | Implements the sinusoidal positional encoding for 45 | non-recurrent neural networks. 
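Even embedding dimensions use ``sin(pos / 10000^(2i/dim))`` and odd dimensions use ``cos(pos / 10000^(2i/dim))``, which is exactly how ``pe`` is filled in ``__init__``.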
46 | 47 | Implementation based on "Attention Is All You Need" 48 | :cite:`DBLP:journals/corr/VaswaniSPUJGKP17` 49 | 50 | Args: 51 | dropout (float): dropout parameter 52 | dim (int): embedding size 53 | """ 54 | 55 | def __init__(self, dropout, dim, max_len=5000, buffer_name='pe'): 56 | pe = torch.zeros(max_len, dim) 57 | position = torch.arange(0, max_len).unsqueeze(1) 58 | div_term = torch.exp((torch.arange(0, dim, 2, dtype=torch.float) * 59 | -(math.log(10000.0) / dim))) 60 | pe[:, 0::2] = torch.sin(position.float() * div_term) 61 | pe[:, 1::2] = torch.cos(position.float() * div_term) 62 | pe = pe.unsqueeze(0) 63 | super(PositionalEncoding, self).__init__() 64 | self.register_buffer(buffer_name, pe) 65 | self.dropout = nn.Dropout(p=dropout) 66 | self.dim = dim 67 | 68 | def forward(self, emb, step=None): 69 | emb = emb * math.sqrt(self.dim) 70 | if (step): 71 | emb = emb + self.pe[:, step][:, None, :] 72 | 73 | else: 74 | emb = emb + self.pe[:, :emb.size(1)] 75 | emb = self.dropout(emb) 76 | return emb 77 | 78 | def get_emb(self, emb): 79 | return self.pe[:, :emb.size(1)] 80 | 81 | class PositionwiseFeedForward(nn.Module): 82 | """ A two-layer Feed-Forward-Network with residual layer norm. 83 | 84 | Args: 85 | d_model (int): the size of input for the first-layer of the FFN. 86 | d_ff (int): the hidden layer size of the second-layer 87 | of the FNN. 88 | dropout (float): dropout probability(0-1.0). 89 | """ 90 | 91 | def __init__(self, d_model, d_ff, dropout=0.1): 92 | super(PositionwiseFeedForward, self).__init__() 93 | self.w_1 = nn.Linear(d_model, d_ff) 94 | self.w_2 = nn.Linear(d_ff, d_model) 95 | self.layer_norm = nn.LayerNorm(d_model, eps=1e-6) 96 | self.dropout_1 = nn.Dropout(dropout) 97 | self.relu = nn.ReLU() 98 | self.dropout_2 = nn.Dropout(dropout) 99 | 100 | def forward(self, x): 101 | """ 102 | Layer definition. 
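Computes a pre-norm residual block: ``x + dropout_2(w_2(dropout_1(relu(w_1(layer_norm(x))))))``.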
103 | 104 | Args: 105 | input: [ batch_size, input_len, model_dim ] 106 | 107 | 108 | Returns: 109 | output: [ batch_size, input_len, model_dim ] 110 | """ 111 | inter = self.dropout_1(self.relu(self.w_1(self.layer_norm(x)))) 112 | output = self.dropout_2(self.w_2(inter)) 113 | return output + x 114 | -------------------------------------------------------------------------------- /src/abstractive/optimizer.py: -------------------------------------------------------------------------------- 1 | """ 2 | Optimizers class 3 | Majority of code borrowed from https://github.com/nlpyang/hiersumm 4 | """ 5 | import torch 6 | import torch.optim as optim 7 | from torch.nn.utils import clip_grad_norm_ 8 | 9 | 10 | 11 | def use_gpu(opt): 12 | """ 13 | Creates a boolean if gpu used 14 | """ 15 | return (hasattr(opt, 'gpu_ranks') and len(opt.gpu_ranks) > 0) or \ 16 | (hasattr(opt, 'gpu') and opt.gpu > -1) 17 | 18 | def build_optim(model, opt, checkpoint): 19 | """ Build optimizer """ 20 | saved_optimizer_state_dict = None 21 | optim = Optimizer( 22 | opt.optim, opt.learning_rate, opt.max_grad_norm, 23 | lr_decay=opt.learning_rate_decay, 24 | start_decay_steps=opt.start_decay_steps, 25 | decay_steps=opt.decay_steps, 26 | beta1=opt.adam_beta1, 27 | beta2=opt.adam_beta2, 28 | adagrad_accum=opt.adagrad_accumulator_init, 29 | decay_method=opt.decay_method, 30 | warmup_steps=opt.warmup_steps, 31 | model_size=opt.rnn_size) 32 | 33 | if opt.train_from: 34 | # optim = checkpoint['optim'] 35 | # We need to save a copy of optim.optimizer.state_dict() for setting 36 | # the, optimizer state later on in Stage 2 in this method, since 37 | # the method optim.set_parameters(model.parameters()) will overwrite 38 | # optim.optimizer, and with ith the values stored in 39 | # optim.optimizer.state_dict() 40 | saved_optimizer_state_dict = checkpoint['optim'] 41 | 42 | # Stage 1: 43 | # Essentially optim.set_parameters (re-)creates and optimizer using 44 | # model.paramters() as parameters that will be stored in the 45 | # optim.optimizer.param_groups field of the torch optimizer class. 46 | # Importantly, this method does not yet load the optimizer state, as 47 | # essentially it builds a new optimizer with empty optimizer state and 48 | # parameters from the model. 49 | optim.set_parameters(model.named_parameters()) 50 | 51 | if opt.train_from: 52 | # Stage 2: In this stage, which is only performed when loading an 53 | # optimizer from a checkpoint, we load the saved_optimizer_state_dict 54 | # into the re-created optimizer, to set the optim.optimizer.state 55 | # field, which was previously empty. For this, we use the optimizer 56 | # state saved in the "saved_optimizer_state_dict" variable for 57 | # this purpose. 58 | # See also: https://github.com/pytorch/pytorch/issues/2830 59 | optim.optimizer.load_state_dict(saved_optimizer_state_dict) 60 | # Convert back the state values to cuda type if applicable 61 | if use_gpu(opt): 62 | for state in optim.optimizer.state.values(): 63 | for k, v in state.items(): 64 | if torch.is_tensor(v): 65 | state[k] = v.cuda() 66 | 67 | # We want to make sure that indeed we have a non-empty optimizer state 68 | # when we loaded an existing model. 
This should be at least the case 69 | # for Adam, which saves "exp_avg" and "exp_avg_sq" state 70 | # (Exponential moving average of gradient and squared gradient values) 71 | if (optim.method == 'adam') and (len(optim.optimizer.state) < 1): 72 | raise RuntimeError( 73 | "Error: loaded Adam optimizer from existing model" + 74 | " but optimizer state is empty") 75 | 76 | return optim 77 | 78 | 79 | class MultipleOptimizer(object): 80 | """ Implement multiple optimizers needed for sparse adam """ 81 | 82 | def __init__(self, op): 83 | """ ? """ 84 | self.optimizers = op 85 | 86 | def zero_grad(self): 87 | """ ? """ 88 | for op in self.optimizers: 89 | op.zero_grad() 90 | 91 | def step(self): 92 | """ ? """ 93 | for op in self.optimizers: 94 | op.step() 95 | 96 | @property 97 | def state(self): 98 | """ ? """ 99 | return {k: v for op in self.optimizers for k, v in op.state.items()} 100 | 101 | def state_dict(self): 102 | """ ? """ 103 | return [op.state_dict() for op in self.optimizers] 104 | 105 | def load_state_dict(self, state_dicts): 106 | """ ? """ 107 | assert len(state_dicts) == len(self.optimizers) 108 | for i in range(len(state_dicts)): 109 | self.optimizers[i].load_state_dict(state_dicts[i]) 110 | 111 | 112 | class Optimizer(object): 113 | """ 114 | Controller class for optimization. Mostly a thin 115 | wrapper for `optim`, but also useful for implementing 116 | rate scheduling beyond what is currently available. 117 | Also implements necessary methods for training RNNs such 118 | as grad manipulations. 119 | 120 | Args: 121 | method (:obj:`str`): one of [sgd, adagrad, adadelta, adam] 122 | lr (float): learning rate 123 | lr_decay (float, optional): learning rate decay multiplier 124 | start_decay_steps (int, optional): step to start learning rate decay 125 | beta1, beta2 (float, optional): parameters for adam 126 | adagrad_accum (float, optional): initialization parameter for adagrad 127 | decay_method (str, option): custom decay options 128 | warmup_steps (int, option): parameter for `noam` decay 129 | model_size (int, option): parameter for `noam` decay 130 | 131 | We use the default parameters for Adam that are suggested by 132 | the original paper https://arxiv.org/pdf/1412.6980.pdf 133 | These values are also used by other established implementations, 134 | e.g. 
https://www.tensorflow.org/api_docs/python/tf/train/AdamOptimizer 135 | https://keras.io/optimizers/ 136 | Recently there are slightly different values used in the paper 137 | "Attention is all you need" 138 | https://arxiv.org/pdf/1706.03762.pdf, particularly the value beta2=0.98 139 | was used there however, beta2=0.999 is still arguably the more 140 | established value, so we use that here as well 141 | """ 142 | 143 | def __init__(self, method, learning_rate, max_grad_norm, 144 | lr_decay=1, start_decay_steps=None, decay_steps=None, 145 | beta1=0.9, beta2=0.999, 146 | adagrad_accum=0.0, 147 | decay_method=None, 148 | warmup_steps=4000, 149 | model_size=None): 150 | self.last_ppl = None 151 | self.learning_rate = learning_rate 152 | self.original_lr = learning_rate 153 | self.max_grad_norm = max_grad_norm 154 | self.method = method 155 | self.lr_decay = lr_decay 156 | self.start_decay_steps = start_decay_steps 157 | self.decay_steps = decay_steps 158 | self.start_decay = False 159 | self._step = 0 160 | self.betas = [beta1, beta2] 161 | self.adagrad_accum = adagrad_accum 162 | self.decay_method = decay_method 163 | self.warmup_steps = warmup_steps 164 | self.model_size = model_size 165 | 166 | def set_parameters(self, params): 167 | """ ? """ 168 | self.params = [] 169 | self.sparse_params = [] 170 | for k, p in params: 171 | if p.requires_grad: 172 | if self.method != 'sparseadam' or "embed" not in k: 173 | self.params.append(p) 174 | else: 175 | self.sparse_params.append(p) 176 | if self.method == 'sgd': 177 | self.optimizer = optim.SGD(self.params, lr=self.learning_rate) 178 | elif self.method == 'adagrad': 179 | self.optimizer = optim.Adagrad(self.params, lr=self.learning_rate) 180 | for group in self.optimizer.param_groups: 181 | for p in group['params']: 182 | self.optimizer.state[p]['sum'] = self.optimizer\ 183 | .state[p]['sum'].fill_(self.adagrad_accum) 184 | elif self.method == 'adadelta': 185 | self.optimizer = optim.Adadelta(self.params, lr=self.learning_rate) 186 | elif self.method == 'adam': 187 | self.optimizer = optim.Adam(self.params, lr=self.learning_rate, 188 | betas=self.betas, eps=1e-9) 189 | elif self.method == 'sparseadam': 190 | self.optimizer = MultipleOptimizer( 191 | [optim.Adam(self.params, lr=self.learning_rate, 192 | betas=self.betas, eps=1e-8), 193 | optim.SparseAdam(self.sparse_params, lr=self.learning_rate, 194 | betas=self.betas, eps=1e-8)]) 195 | else: 196 | raise RuntimeError("Invalid optim method: " + self.method) 197 | 198 | def _set_rate(self, learning_rate): 199 | self.learning_rate = learning_rate 200 | if self.method != 'sparseadam': 201 | self.optimizer.param_groups[0]['lr'] = self.learning_rate 202 | else: 203 | for op in self.optimizer.optimizers: 204 | op.param_groups[0]['lr'] = self.learning_rate 205 | 206 | def step(self): 207 | """Update the model parameters based on current gradients. 208 | 209 | Optionally, will employ gradient modification or update learning 210 | rate. 211 | """ 212 | self._step += 1 213 | 214 | # Decay method used in tensor2tensor. 
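# With "noam", the effective rate is
#   original_lr * model_size**(-0.5) * min(step**(-0.5), step * warmup_steps**(-1.5)),
# i.e. linear warm-up for `warmup_steps` steps followed by inverse-square-root decay.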
215 | if self.decay_method == "noam": 216 | self._set_rate( 217 | self.original_lr * 218 | ( self.model_size ** -0.5*min(self._step ** (-0.5), 219 | self._step * self.warmup_steps**(-1.5)))) 220 | else: 221 | if ((self.start_decay_steps is not None) and ( 222 | self._step >= self.start_decay_steps)): 223 | self.start_decay = True 224 | if self.start_decay: 225 | if ((self._step - self.start_decay_steps) 226 | % self.decay_steps == 0): 227 | self.learning_rate = self.learning_rate * self.lr_decay 228 | 229 | if self.method != 'sparseadam': 230 | self.optimizer.param_groups[0]['lr'] = self.learning_rate 231 | 232 | if self.max_grad_norm: 233 | clip_grad_norm_(self.params, self.max_grad_norm) 234 | self.optimizer.step() 235 | -------------------------------------------------------------------------------- /src/abstractive/penalties.py: -------------------------------------------------------------------------------- 1 | """ 2 | Majority of code borrowed https://github.com/OpenNMT/OpenNMT-py 3 | """ 4 | import torch 5 | 6 | 7 | class PenaltyBuilder(object): 8 | """Returns the Length and Coverage Penalty function for Beam Search. 9 | 10 | Args: 11 | length_pen (str): option name of length pen 12 | cov_pen (str): option name of cov pen 13 | 14 | Attributes: 15 | has_cov_pen (bool): Whether coverage penalty is None (applying it 16 | is a no-op). Note that the converse isn't true. Setting beta 17 | to 0 should force coverage length to be a no-op. 18 | has_len_pen (bool): Whether length penalty is None (applying it 19 | is a no-op). Note that the converse isn't true. Setting alpha 20 | to 1 should force length penalty to be a no-op. 21 | coverage_penalty (callable[[FloatTensor, float], FloatTensor]): 22 | Calculates the coverage penalty. 23 | length_penalty (callable[[int, float], float]): Calculates 24 | the length penalty. 25 | """ 26 | 27 | def __init__(self, cov_pen, length_pen): 28 | self.has_cov_pen = not self._pen_is_none(cov_pen) 29 | self.coverage_penalty = self._coverage_penalty(cov_pen) 30 | self.has_len_pen = not self._pen_is_none(length_pen) 31 | self.length_penalty = self._length_penalty(length_pen) 32 | 33 | @staticmethod 34 | def _pen_is_none(pen): 35 | return pen == "none" or pen is None 36 | 37 | def _coverage_penalty(self, cov_pen): 38 | if cov_pen == "wu": 39 | return self.coverage_wu 40 | elif cov_pen == "summary": 41 | return self.coverage_summary 42 | elif self._pen_is_none(cov_pen): 43 | return self.coverage_none 44 | else: 45 | raise NotImplementedError("No '{:s}' coverage penalty.".format( 46 | cov_pen)) 47 | 48 | def _length_penalty(self, length_pen): 49 | if length_pen == "wu": 50 | return self.length_wu 51 | elif length_pen == "avg": 52 | return self.length_average 53 | elif self._pen_is_none(length_pen): 54 | return self.length_none 55 | else: 56 | raise NotImplementedError("No '{:s}' length penalty.".format( 57 | length_pen)) 58 | 59 | # Below are all the different penalty terms implemented so far. 60 | # Subtract coverage penalty from topk log probs. 61 | # Divide topk log probs by length penalty. 62 | 63 | def coverage_wu(self, cov, beta=0.): 64 | """GNMT coverage re-ranking score. 65 | 66 | See "Google's Neural Machine Translation System" :cite:`wu2016google`. 67 | ``cov`` is expected to be sized ``(*, seq_len)``, where ``*`` is 68 | probably ``batch_size x beam_size`` but could be several 69 | dimensions like ``(batch_size, beam_size)``. If ``cov`` is attention, 70 | then the ``seq_len`` axis probably sums to (almost) 1. 
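The returned penalty is ``beta * sum(-log(min(cov, 1)))`` along the ``seq_len`` axis, so source positions whose accumulated attention is still below 1 are penalized.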
71 | """ 72 | 73 | penalty = -torch.min(cov, cov.clone().fill_(1.0)).log().sum(-1) 74 | return beta * penalty 75 | 76 | def coverage_summary(self, cov, beta=0.): 77 | """Our summary penalty.""" 78 | penalty = torch.max(cov, cov.clone().fill_(1.0)).sum(-1) 79 | penalty -= cov.size(-1) 80 | return beta * penalty 81 | 82 | def coverage_none(self, cov, beta=0.): 83 | """Returns zero as penalty""" 84 | none = torch.zeros((1,), device=cov.device, 85 | dtype=torch.float) 86 | if cov.dim() == 3: 87 | none = none.unsqueeze(0) 88 | return none 89 | 90 | def length_wu(self, cur_len, alpha=0.): 91 | """GNMT length re-ranking score. 92 | 93 | See "Google's Neural Machine Translation System" :cite:`wu2016google`. 94 | """ 95 | 96 | return ((5 + cur_len) / 6.0) ** alpha 97 | 98 | def length_average(self, cur_len, alpha=0.): 99 | """Returns the current sequence length.""" 100 | return cur_len 101 | 102 | def length_none(self, cur_len, alpha=0.): 103 | """Returns unmodified scores.""" 104 | return 1.0 105 | -------------------------------------------------------------------------------- /src/abstractive/predictor_builder.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | """ 3 | Translator Class and builder 4 | Majority of code borrowed from https://github.com/nlpyang/hiersumm 5 | and https://github.com/OpenNMT/OpenNMT-py 6 | """ 7 | from __future__ import print_function 8 | import codecs 9 | import os 10 | import math 11 | 12 | import torch 13 | 14 | from itertools import count 15 | 16 | from tensorboardX import SummaryWriter 17 | 18 | from abstractive.beam import GNMTGlobalScorer 19 | from abstractive.cal_rouge import test_rouge, rouge_results_to_str 20 | from abstractive.neural import tile 21 | from abstractive.beam_search import BeamSearch 22 | 23 | from rouge import FilesRouge 24 | 25 | import numpy as np 26 | 27 | 28 | def build_predictor(args, tokenizer, symbols, model, logger=None): 29 | scorer = GNMTGlobalScorer(alpha=args.alpha,beta=args.cov_beta,length_penalty='wu', coverage_penalty=args.coverage_penalty) 30 | translator = Translator(args, model, tokenizer, symbols, global_scorer=scorer, logger=logger) 31 | return translator 32 | 33 | 34 | class Translator(object): 35 | 36 | def __init__(self, 37 | args, 38 | model, 39 | vocab, 40 | symbols, 41 | n_best=1, 42 | global_scorer=None, 43 | logger=None, 44 | dump_beam=""): 45 | self.logger = logger 46 | self.cuda = args.visible_gpus != '-1' 47 | self.args = args 48 | 49 | self.model = model 50 | self.generator = self.model.generator 51 | self.vocab = vocab 52 | self.symbols = symbols 53 | self.start_token = symbols['BOS'] 54 | self.end_token = symbols['EOS'] 55 | self.pad_token = symbols['PAD'] 56 | 57 | 58 | self.block_ngram_repeat = self.args.block_ngram_repeat 59 | self.stepwise_penalty = self.args.stepwise_penalty 60 | 61 | self.n_best = n_best 62 | self.max_length = args.max_length 63 | self.global_scorer = global_scorer 64 | self.beam_size = args.beam_size 65 | self.min_length = args.min_length 66 | self.dump_beam = dump_beam 67 | 68 | self.beam_trace = self.dump_beam != "" 69 | self.beam_accum = None 70 | 71 | tensorboard_log_dir = self.args.model_path 72 | 73 | self.tensorboard_writer = SummaryWriter(tensorboard_log_dir, comment="Unmt") 74 | 75 | if self.beam_trace: 76 | self.beam_accum = { 77 | "predicted_ids": [], 78 | "beam_parent_ids": [], 79 | "scores": [], 80 | "log_probs": []} 81 | 82 | def _build_target_tokens(self, pred): 83 | # vocab = self.fields["tgt"].vocab 84 | 
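# Collects predicted ids up to (and excluding) the first end-of-sequence
# token; the surviving ids can then be mapped back to text, presumably via
# the SentencePiece vocab held in self.vocab.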
tokens = [] 85 | for tok in pred: 86 | tok = int(tok) 87 | tokens.append(tok) 88 | if tokens[-1] == self.end_token: 89 | tokens = tokens[:-1] 90 | break 91 | tokens = [t for t in tokens if t', ' ').replace(r' +', ' ').replace('', 'UNK').strip() 160 | gold_str = ' '.join(gold).replace('', '').replace('', '').replace('', ' ').replace(r' +', 161 | ' ').strip() 162 | pred_str = ' '.join(pred_str.split()) # remove extra white spaces 163 | gold_str = ' '.join(gold_str.split()) # remove extra white spaces 164 | 165 | 166 | gold_str = gold_str.lower() 167 | self.raw_can_out_file.write(' '.join(pred).strip() + '\n') 168 | self.raw_gold_out_file.write(' '.join(gold).strip() + '\n') 169 | self.can_out_file.write(pred_str + '\n') 170 | self.gold_out_file.write(gold_str + '\n') 171 | self.src_out_file.write(src.strip() + '\n') 172 | ct += 1 173 | if (ct > self.args.max_samples): 174 | break 175 | 176 | self.raw_can_out_file.flush() 177 | self.raw_gold_out_file.flush() 178 | self.can_out_file.flush() 179 | self.gold_out_file.flush() 180 | self.src_out_file.flush() 181 | if (ct > self.args.max_samples): 182 | break 183 | 184 | self.raw_can_out_file.close() 185 | self.raw_gold_out_file.close() 186 | self.can_out_file.close() 187 | self.gold_out_file.close() 188 | self.src_out_file.close() 189 | 190 | if(step!=-1 and self.args.report_rouge): 191 | rouges = self._report_rouge(gold_path, can_path) 192 | self.logger.info('Rouges at step %d \n%s'%(step,rouge_results_to_str(rouges))) 193 | if self.tensorboard_writer is not None: 194 | self.tensorboard_writer.add_scalar('test/rouge1-F', rouges['rouge_1_f_score'], step) 195 | self.tensorboard_writer.add_scalar('test/rouge2-F', rouges['rouge_2_f_score'], step) 196 | self.tensorboard_writer.add_scalar('test/rougeL-F', rouges['rouge_l_f_score'], step) 197 | 198 | ## write the results to a file 199 | res_path = self.args.result_path + '.%d.result'%step 200 | res_out_file = codecs.open(res_path, 'w', 'utf-8') 201 | res_out_file.write(rouge_results_to_str(rouges)) 202 | res_out_file.flush() 203 | res_out_file.close() 204 | 205 | return rouges 206 | 207 | 208 | def fast_rouge(self, step): 209 | self.logger.info("Calculating Rouge") 210 | 211 | gold_path = self.args.result_path + '.%d.gold'%step 212 | can_path = self.args.result_path + '.%d.candidate'%step 213 | 214 | if self.args.dataset in ["DUC2006", "DUC2007"]: 215 | ## give only one reference 216 | data = [] 217 | with open(gold_path, 'r') as f: 218 | for line in f.read().splitlines(): 219 | data.append(line.strip()) 220 | 221 | data = [d.split("")[0].strip() for d in data] 222 | with open(gold_path, 'w') as f: 223 | f.write("\n".join(data)) 224 | f.flush() 225 | 226 | print(8*"="+"DEBUG TEST FOR DUC"+8*"=") 227 | print(f"reference sample: {data[0]}") 228 | 229 | 230 | 231 | files_rouge = FilesRouge(can_path, gold_path) 232 | scores = files_rouge.get_scores(avg=True) 233 | 234 | rouges = {} 235 | rouges["rouge_l_f_score"] = scores["rouge-l"]["f"] 236 | rouges["rouge_2_f_score"] = scores["rouge-2"]["f"] 237 | rouges["rouge_1_f_score"] = scores["rouge-1"]["f"] 238 | self.logger.info(rouges) 239 | 240 | return rouges 241 | 242 | 243 | 244 | def _report_rouge(self, gold_path, can_path): 245 | self.logger.info("Calculating Rouge") 246 | candidates = codecs.open(can_path, encoding="utf-8") 247 | references = codecs.open(gold_path, encoding="utf-8") 248 | if self.args.rouge_path is None: 249 | results_dict = test_rouge(candidates, references, 1) 250 | else: 251 | results_dict = test_rouge(candidates, references, 0, 
rouge_dir=os.path.join(os.getcwd(),self.args.rouge_path)) 252 | return results_dict 253 | 254 | 255 | 256 | def translate_batch(self, batch, fast=False): 257 | """ 258 | Translate a batch of sentences. 259 | 260 | Mostly a wrapper around :obj:`Beam`. 261 | 262 | Args: 263 | batch (:obj:`Batch`): a batch from a dataset object 264 | data (:obj:`Dataset`): the dataset object 265 | fast (bool): enables fast beam search (may not support all features) 266 | 267 | Todo: 268 | Shouldn't need the original dataset. 269 | """ 270 | with torch.no_grad(): 271 | return self._fast_translate_batch( 272 | batch, 273 | self.max_length, 274 | min_length=self.min_length, 275 | n_best=self.n_best) 276 | 277 | def _fast_translate_batch(self, 278 | batch, 279 | max_length, 280 | min_length=0, 281 | n_best=1): 282 | assert not self.dump_beam 283 | 284 | beam_size = self.beam_size 285 | batch_size = batch.batch_size 286 | 287 | # Encoder forward. 288 | src = batch.src 289 | batch_size, n_blocks, num_tokens = src.shape 290 | 291 | if self.args.model_type in ['query', 'heq', 'hero']: 292 | src_features, mask_hier = self.model.encoder(src, batch.query) 293 | else: 294 | src_features, mask_hier = self.model.encoder(src) 295 | dec_states = self.model.decoder.init_decoder_state(src, src_features, with_cache=True) 296 | src_features = src_features.view(n_blocks, num_tokens, batch_size, -1).contiguous() 297 | device = src_features.device 298 | 299 | 300 | results = { 301 | "predictions": None, 302 | "scores": None, 303 | "attention": None, 304 | "batch": batch, 305 | "gold_score": [0] * batch_size} 306 | 307 | #(2) Repeat src objects `beam_size` times. 308 | # We use batch_size x beam_size 309 | 310 | dec_states.map_batch_fn( 311 | lambda state, dim: tile(state, beam_size, dim=dim)) 312 | 313 | 314 | src_features = tile(src_features, beam_size, dim=2) 315 | mask = tile(mask_hier, beam_size, dim=0) 316 | 317 | 318 | 319 | 320 | beam = BeamSearch( 321 | beam_size, 322 | n_best=self.n_best, 323 | batch_size=batch_size, 324 | global_scorer=self.global_scorer, 325 | pad=self.pad_token, 326 | eos=self.end_token, 327 | bos=self.start_token, 328 | min_length=self.min_length, 329 | ratio=0., 330 | max_length=self.max_length, 331 | mb_device=device, 332 | return_attention=False, 333 | stepwise_penalty=self.stepwise_penalty, 334 | block_ngram_repeat=self.block_ngram_repeat, 335 | exclusion_tokens=set([]), 336 | memory_lengths=mask) 337 | 338 | for step in range(max_length): 339 | decoder_input = beam.current_predictions.view(1, -1) 340 | if (self.args.hier): 341 | dec_out, dec_states, attn = self.model.decoder(decoder_input, src_features, dec_states, 342 | memory_masks=mask, 343 | step=step) 344 | else: 345 | dec_out, dec_states, attn = self.model.decoder(decoder_input, src_features, dec_states, 346 | step=step) 347 | 348 | # Generator forward. 
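# One beam-search step: the generator turns the decoder output into
# vocabulary log-probabilities, the context attention is reduced with a max
# over heads, and beam.advance() extends each surviving hypothesis.
# Finished hypotheses are collected by beam.update_finished(); the remaining
# ones re-index src_features, the memory mask and the decoder cache through
# select_indices before the next step.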
349 | log_probs = self.generator.forward(dec_out.squeeze(0)) 350 | vocab_size = log_probs.size(-1) 351 | attn = attn.transpose(1,2).transpose(0,1) 352 | attn = attn.max(2)[0] 353 | beam.advance(log_probs, attn) 354 | any_beam_is_finished = beam.is_finished.any() 355 | if any_beam_is_finished: 356 | beam.update_finished() 357 | if beam.done: 358 | break 359 | 360 | select_indices = beam.current_origin 361 | if any_beam_is_finished: 362 | src_features = src_features.index_select(2, select_indices) 363 | mask = mask.index_select(0, select_indices) 364 | 365 | dec_states.map_batch_fn( 366 | lambda state, dim: state.index_select(dim, select_indices)) 367 | 368 | results["scores"] = beam.scores 369 | results["predictions"] = beam.predictions 370 | results["attention"] = beam.attention 371 | 372 | return results 373 | 374 | 375 | -------------------------------------------------------------------------------- /src/abstractive/trainer_builder.py: -------------------------------------------------------------------------------- 1 | """ 2 | Trainer Builder Class 3 | and https://github.com/OpenNMT/OpenNMT-py 4 | """ 5 | 6 | 7 | from datetime import datetime 8 | 9 | import torch 10 | import os 11 | from glob import glob 12 | 13 | from abstractive.loss import build_loss_compute 14 | from abstractive.loss import PTLossCompute 15 | from tensorboardX import SummaryWriter 16 | 17 | from others import distributed 18 | from others.logging import logger 19 | from others.report_manager import ReportMgr 20 | from others.statistics import Statistics 21 | 22 | def _tally_parameters(model): 23 | n_params = sum([p.nelement() for p in model.parameters()]) 24 | enc = 0 25 | dec = 0 26 | for name, param in model.named_parameters(): 27 | if 'encoder' in name: 28 | enc += param.nelement() 29 | elif 'decoder' or 'generator' in name: 30 | dec += param.nelement() 31 | return n_params, enc, dec 32 | 33 | 34 | def build_trainer(args, device_id, model, symbols, vocab_size, 35 | optim): 36 | """ 37 | Simplify `Trainer` creation based on user `opt`s* 38 | Args: 39 | opt (:obj:`Namespace`): user options (usually from argument parsing) 40 | model (:obj:`onmt.models.NMTModel`): the model to train 41 | fields (dict): dict of fields 42 | optim (:obj:`onmt.utils.Optimizer`): optimizer used during training 43 | data_type (str): string describing the type of data 44 | e.g. 
"text", "img", "audio" 45 | model_saver(:obj:`onmt.models.ModelSaverBase`): the utility object 46 | used to save the model 47 | """ 48 | 49 | device = "cpu" if args.visible_gpus == '-1' else "cuda" 50 | 51 | train_loss = build_loss_compute( 52 | model.generator, symbols, vocab_size, device, train=True, label_smoothing=args.label_smoothing) 53 | valid_loss = build_loss_compute( 54 | model.generator, symbols, vocab_size, train=False, device=device) 55 | 56 | shard_size = args.max_generator_batches 57 | grad_accum_count = args.accum_count 58 | n_gpu = args.world_size 59 | if device_id >= 0: 60 | gpu_rank = int(args.gpu_ranks[device_id]) 61 | else: 62 | gpu_rank = 0 63 | n_gpu = 0 64 | 65 | tensorboard_log_dir = args.model_path 66 | 67 | writer = SummaryWriter(tensorboard_log_dir, comment="Unmt") 68 | 69 | 70 | report_manager = ReportMgr(args.report_every, start_time=-1, tensorboard_writer=writer) 71 | trainer = Trainer(args, model, train_loss, valid_loss, optim, 72 | shard_size, 73 | grad_accum_count, n_gpu, gpu_rank, report_manager) 74 | 75 | # print(tr) 76 | n_params, enc, dec = _tally_parameters(model) 77 | logger.info('encoder: %d' % enc) 78 | logger.info('decoder: %d' % dec) 79 | logger.info('* number of parameters: %d' % n_params) 80 | 81 | return trainer 82 | 83 | 84 | class Trainer(object): 85 | """ 86 | Class that controls the training process. 87 | 88 | Args: 89 | model(:py:class:`onmt.models.model.NMTModel`): translation model 90 | to train 91 | train_loss(:obj:`onmt.utils.loss.LossComputeBase`): 92 | training loss computation 93 | valid_loss(:obj:`onmt.utils.loss.LossComputeBase`): 94 | training loss computation 95 | optim(:obj:`onmt.utils.optimizers.Optimizer`): 96 | the optimizer responsible for update 97 | trunc_size(int): length of truncated back propagation through time 98 | shard_size(int): compute loss in shards of this size for efficiency 99 | data_type(string): type of the source input: [text|img|audio] 100 | norm_method(string): normalization methods: [sents|tokens] 101 | grad_accum_count(int): accumulate gradients this many times. 102 | report_manager(:obj:`onmt.utils.ReportMgrBase`): 103 | the object that creates reports, or None 104 | model_saver(:obj:`onmt.models.ModelSaverBase`): the saver is 105 | used to save a checkpoint. 106 | Thus nothing will be saved if this parameter is None 107 | """ 108 | 109 | def __init__(self, args, model, train_loss, valid_loss, optim, 110 | shard_size=32, grad_accum_count=1, n_gpu=1, gpu_rank=1,report_manager=None): 111 | # Basic attributes. 
112 | self.args = args 113 | self.model = model 114 | self.train_loss = train_loss 115 | self.valid_loss = valid_loss 116 | self.optim = optim 117 | self.shard_size = shard_size 118 | self.grad_accum_count = grad_accum_count 119 | self.n_gpu = n_gpu 120 | self.gpu_rank = gpu_rank 121 | self.report_manager = report_manager 122 | 123 | 124 | assert grad_accum_count > 0 125 | self.model.train() 126 | 127 | def train(self, train_iter_fct, train_steps, predictor, valid_iter_fct): 128 | logger.info('Start training...') 129 | 130 | step = self.optim._step + 1 131 | true_batchs = [] 132 | accum = 0 133 | normalization = 0 134 | train_iter = train_iter_fct() 135 | 136 | total_stats = Statistics() 137 | report_stats = Statistics() 138 | self._start_report_manager(start_time=total_stats.start_time) 139 | 140 | while step <= train_steps: 141 | 142 | reduce_counter = 0 143 | for i, batch in enumerate(train_iter): 144 | if self.n_gpu == 0 or (i % self.n_gpu == self.gpu_rank): 145 | #print(batch.src.shape) 146 | true_batchs.append(batch) 147 | num_tokens = batch.tgt[1:].ne( 148 | self.train_loss.padding_idx).sum() 149 | normalization += num_tokens.item() 150 | accum += 1 151 | if accum == self.grad_accum_count: 152 | reduce_counter += 1 153 | if self.n_gpu > 1: 154 | normalization = sum(distributed 155 | .all_gather_list 156 | (normalization)) 157 | 158 | self._gradient_accumulation( 159 | true_batchs, normalization, total_stats, 160 | report_stats) 161 | 162 | report_stats = self._maybe_report_training( 163 | step, train_steps, 164 | self.optim.learning_rate, 165 | report_stats) 166 | 167 | 168 | 169 | true_batchs = [] 170 | accum = 0 171 | normalization = 0 172 | if (step % self.args.save_checkpoint_steps == 0 and self.gpu_rank == 0): 173 | predictor.translate(valid_iter_fct(), step=0) 174 | rouge_scores = predictor.fast_rouge(step=0) 175 | self._save(step, score = rouge_scores[self.args.save_criteria]) 176 | self.model.train() 177 | 178 | step += 1 179 | if step > train_steps: 180 | break 181 | train_iter = train_iter_fct() 182 | 183 | #if self.args.dataset in ['DUC2006', 'DUC2007']: 184 | # self._save(step, score=1.0) 185 | 186 | return total_stats 187 | 188 | def validate(self, valid_iter): 189 | """ Validate model. 190 | valid_iter: validate data iterator 191 | Returns: 192 | :obj:`nmt.Statistics`: validation loss statistics 193 | """ 194 | # Set model in validating mode. 195 | self.model.eval() 196 | 197 | stats = Statistics() 198 | 199 | with torch.no_grad(): 200 | for batch in valid_iter: 201 | src = batch.src 202 | tgt = batch.tgt 203 | outputs, _, attn = self.model(src, tgt) 204 | 205 | batch_stats = self.valid_loss.monolithic_compute_loss( 206 | batch, outputs, attn) 207 | #print(batch_stats.ppl()) 208 | stats.update(batch_stats) 209 | return stats 210 | 211 | def _gradient_accumulation(self, true_batchs, normalization, total_stats, 212 | report_stats): 213 | if self.grad_accum_count > 1: 214 | self.model.zero_grad() 215 | 216 | for batch in true_batchs: 217 | 218 | src = batch.src 219 | tgt = batch.tgt 220 | 221 | if self.grad_accum_count == 1: 222 | self.model.zero_grad() 223 | 224 | if self.args.model_type in ['query', 'heq', 'hero']: 225 | query = batch.query 226 | outputs, _, attn = self.model(src, tgt, query) 227 | else: 228 | outputs, _, attn = self.model(src, tgt) 229 | 230 | # 3. Compute loss in shards for memory efficiency. 
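# Loss sharding (OpenNMT-style): sharded_compute_loss() cuts the decoder
# output along the time dimension into chunks of self.shard_size positions
# and runs the generator softmax + backward chunk by chunk, so the full
# (tgt_len x batch x vocab) logits tensor is never materialised at once.
# `normalization` is the count of non-padding target tokens accumulated
# since the last optimizer step (summed across GPUs in the training loop),
# so gradients end up scaled per target token no matter how many batches
# were accumulated.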
231 | batch_stats = self.train_loss.sharded_compute_loss( 232 | batch, outputs, self.shard_size, normalization) 233 | 234 | 235 | report_stats.n_src_words += src.nelement() 236 | 237 | total_stats.update(batch_stats) 238 | report_stats.update(batch_stats) 239 | 240 | # 4. Update the parameters and statistics. 241 | if self.grad_accum_count == 1: 242 | # Multi GPU gradient gather 243 | if self.n_gpu > 1: 244 | grads = [p.grad.data for p in self.model.parameters() 245 | if p.requires_grad 246 | and p.grad is not None] 247 | distributed.all_reduce_and_rescale_tensors( 248 | grads, float(1)) 249 | self.optim.step() 250 | 251 | # in case of multi step gradient accumulation, 252 | # update only after accum batches 253 | if self.grad_accum_count > 1: 254 | if self.n_gpu > 1: 255 | grads = [p.grad.data for p in self.model.parameters() 256 | if p.requires_grad 257 | and p.grad is not None] 258 | distributed.all_reduce_and_rescale_tensors( 259 | grads, float(1)) 260 | self.optim.step() 261 | 262 | def _save(self, step, score): 263 | real_model = self.model 264 | 265 | model_state_dict = real_model.state_dict() 266 | 267 | checkpoint = { 268 | 'model': model_state_dict, 269 | 'opt': self.args, 270 | 'optim': self.optim.optimizer.state_dict(), 271 | } 272 | checkpoint_path = os.path.join(self.args.model_path, 'model_step_%f_%d.pt' % (score, step)) 273 | logger.info("Saving checkpoint %s" % checkpoint_path) 274 | # checkpoint_path = '%s_step_%d.pt' % (FLAGS.model_path, step) 275 | if (not os.path.exists(checkpoint_path)): 276 | torch.save(checkpoint, checkpoint_path) 277 | # check number of checkpoints 278 | paths = glob(os.path.join(self.args.model_path, '*.pt')) 279 | if len(paths)>self.args.max_num_checkpoints: 280 | delete_step = int(paths[0].split("_")[-1].replace(".pt","")) 281 | delete_score = float(paths[0].split("_")[-2].replace(".pt","")) 282 | for path in paths[1:]: 283 | mstep = int(path.split("_")[-1].replace(".pt","")) 284 | mscore = float(path.split("_")[-2].replace(".pt","")) 285 | if mscore 1: 317 | return Statistics.all_gather_stats(stat) 318 | return stat 319 | 320 | def _maybe_report_training(self, step, num_steps, learning_rate, 321 | report_stats): 322 | """ 323 | Simple function to report training stats (if report_manager is set) 324 | see `onmt.utils.ReportManagerBase.report_training` for doc 325 | """ 326 | if self.report_manager is not None: 327 | return self.report_manager.report_training( 328 | step, num_steps, learning_rate, report_stats, 329 | multigpu=self.n_gpu > 1) 330 | 331 | def _report_step(self, learning_rate, step, train_stats=None, 332 | valid_stats=None): 333 | """ 334 | Simple function to report stats (if report_manager is set) 335 | see `onmt.utils.ReportManagerBase.report_step` for doc 336 | """ 337 | if self.report_manager is not None: 338 | return self.report_manager.report_step( 339 | learning_rate, step, train_stats=train_stats, 340 | valid_stats=valid_stats) 341 | 342 | def _maybe_save(self, step): 343 | """ 344 | Save the model if a model saver is set 345 | """ 346 | if self.model_saver is not None: 347 | self.model_saver.maybe_save(step) 348 | -------------------------------------------------------------------------------- /src/abstractive/transformer_decoder.py: -------------------------------------------------------------------------------- 1 | """ 2 | Implementation of "Attention is All You Need" 3 | Only Hierarchical Transformers modules are borrowed from https://github.com/nlpyang/hiersumm 4 | """ 5 | 6 | import torch 7 | import torch.nn as nn 8 | 
import numpy as np 9 | 10 | from abstractive.attn import MultiHeadedAttention,SelfAttention 11 | from abstractive.neural import PositionwiseFeedForward 12 | from abstractive.transformer_encoder import PositionalEncoding 13 | 14 | 15 | MAX_SIZE = 5000 16 | 17 | 18 | class DecoderState(object): 19 | """Interface for grouping together the current state of a recurrent 20 | decoder. In the simplest case just represents the hidden state of 21 | the model. But can also be used for implementing various forms of 22 | input_feeding and non-recurrent models. 23 | 24 | Modules need to implement this to utilize beam search decoding. 25 | """ 26 | def detach(self): 27 | """ Need to document this """ 28 | self.hidden = tuple([_.detach() for _ in self.hidden]) 29 | self.input_feed = self.input_feed.detach() 30 | 31 | def beam_update(self, idx, positions, beam_size): 32 | """ Need to document this """ 33 | for e in self._all: 34 | sizes = e.size() 35 | br = sizes[1] 36 | if len(sizes) == 3: 37 | sent_states = e.view(sizes[0], beam_size, br // beam_size, 38 | sizes[2])[:, :, idx] 39 | else: 40 | sent_states = e.view(sizes[0], beam_size, 41 | br // beam_size, 42 | sizes[2], 43 | sizes[3])[:, :, idx] 44 | 45 | sent_states.data.copy_( 46 | sent_states.data.index_select(1, positions)) 47 | 48 | def map_batch_fn(self, fn): 49 | raise NotImplementedError() 50 | 51 | 52 | 53 | class TransformerDecoderLayer(nn.Module): 54 | """ 55 | Args: 56 | d_model (int): the dimension of keys/values/queries in 57 | MultiHeadedAttention, also the input size of 58 | the first-layer of the PositionwiseFeedForward. 59 | heads (int): the number of heads for MultiHeadedAttention. 60 | d_ff (int): the second-layer of the PositionwiseFeedForward. 61 | dropout (float): dropout probability(0-1.0). 62 | self_attn_type (string): type of self-attention scaled-dot, average 63 | """ 64 | 65 | def __init__(self, d_model, heads, d_ff, dropout): 66 | super(TransformerDecoderLayer, self).__init__() 67 | 68 | 69 | self.self_attn = MultiHeadedAttention( 70 | heads, d_model, dropout=dropout) 71 | self.context_attn = MultiHeadedAttention( 72 | heads, d_model, dropout=dropout) 73 | self.feed_forward = PositionwiseFeedForward(d_model, d_ff, dropout) 74 | self.layer_norm_1 = nn.LayerNorm(d_model, eps=1e-6) 75 | self.layer_norm_2 = nn.LayerNorm(d_model, eps=1e-6) 76 | self.drop = nn.Dropout(dropout) 77 | mask = self._get_attn_subsequent_mask(MAX_SIZE) 78 | # Register self.mask as a buffer in TransformerDecoderLayer, so 79 | # it gets TransformerDecoderLayer's cuda behavior automatically. 
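# The registered buffer is an upper-triangular "subsequent" mask of shape
# (1, MAX_SIZE, MAX_SIZE): entry (i, j) is 1 when j > i, i.e. position i is
# blocked from attending to future positions. For size 4 it looks like
#
#     0 1 1 1
#     0 0 1 1
#     0 0 0 1
#     0 0 0 0
#
# and in forward() it is combined with the target padding mask via
# torch.gt(tgt_pad_mask + mask, 0) to form the decoder self-attention mask.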
80 | self.register_buffer('mask', mask) 81 | 82 | def forward(self, inputs, memory_bank, src_pad_mask, tgt_pad_mask, layer_cache=None, step=None, para_attn=None): 83 | """ 84 | Args: 85 | inputs (`FloatTensor`): `[batch_size x 1 x model_dim]` 86 | memory_bank (`FloatTensor`): `[batch_size x src_len x model_dim]` 87 | src_pad_mask (`LongTensor`): `[batch_size x 1 x src_len]` 88 | tgt_pad_mask (`LongTensor`): `[batch_size x 1 x 1]` 89 | 90 | Returns: 91 | (`FloatTensor`, `FloatTensor`, `FloatTensor`): 92 | 93 | * output `[batch_size x 1 x model_dim]` 94 | * attn `[batch_size x 1 x src_len]` 95 | * all_input `[batch_size x current_step x model_dim]` 96 | 97 | """ 98 | dec_mask = torch.gt(tgt_pad_mask + 99 | self.mask[:, :tgt_pad_mask.size(1), 100 | :tgt_pad_mask.size(1)], 0) 101 | input_norm = self.layer_norm_1(inputs) 102 | all_input = input_norm 103 | 104 | query,_ = self.self_attn(all_input, all_input, input_norm, 105 | mask=dec_mask, 106 | layer_cache=layer_cache, 107 | type="self") 108 | 109 | query = self.drop(query) + inputs 110 | 111 | query_norm = self.layer_norm_2(query) 112 | mid,attn = self.context_attn(memory_bank, memory_bank, query_norm, 113 | mask=src_pad_mask, 114 | layer_cache=layer_cache, 115 | type="context") 116 | 117 | if para_attn is not None: 118 | # para_attn size is batch x block_size 119 | # attn size is slength x batch 120 | batch_size = memory_bank.size(0) 121 | dim_per_head = self.context_attn.dim_per_head 122 | head_count = self.context_attn.head_count 123 | 124 | def shape(x): 125 | """ projection """ 126 | return x.view(batch_size, -1, head_count, dim_per_head) \ 127 | .transpose(1, 2) 128 | 129 | def unshape(x): 130 | """ compute context """ 131 | return x.transpose(1, 2).contiguous() \ 132 | .view(batch_size, -1, head_count * dim_per_head) 133 | 134 | 135 | if layer_cache is not None: 136 | value = layer_cache['memory_values'] 137 | else: 138 | value = self.context_attn.linear_values(memory_bank) 139 | value = shape(value) 140 | 141 | attn = attn * para_attn.unsqueeze(1).repeat(1,head_count,1,1) # multiply for one step 142 | # renormalize attention 143 | attn = attn / attn.sum(-1).unsqueeze(-1) 144 | drop_attn = self.context_attn.dropout(attn) 145 | 146 | mid = unshape(torch.matmul(drop_attn, value)) 147 | mid = self.context_attn.final_linear(mid) 148 | 149 | 150 | output = self.feed_forward(self.drop(mid) + query) 151 | 152 | return output, all_input, attn 153 | # return output 154 | 155 | def _get_attn_subsequent_mask(self, size): 156 | """ 157 | Get an attention mask to avoid using the subsequent info. 158 | 159 | Args: 160 | size: int 161 | 162 | Returns: 163 | (`LongTensor`): 164 | 165 | * subsequent_mask `[1 x size x size]` 166 | """ 167 | attn_shape = (1, size, size) 168 | subsequent_mask = np.triu(np.ones(attn_shape), k=1).astype('uint8') 169 | subsequent_mask = torch.from_numpy(subsequent_mask) 170 | return subsequent_mask 171 | 172 | 173 | 174 | 175 | class TransformerDecoder(nn.Module): 176 | def __init__(self, num_layers, d_model, heads, d_ff, dropout, embeddings, device): 177 | super(TransformerDecoder, self).__init__() 178 | self.device = device 179 | self.decoder_type = 'transformer' 180 | self.num_layers = num_layers 181 | self.embeddings = embeddings 182 | self.pos_emb = PositionalEncoding(dropout,self.embeddings.embedding_dim) 183 | self.transformer_layers = nn.ModuleList( 184 | [TransformerDecoderLayer(d_model, heads, d_ff, dropout) 185 | for _ in range(num_layers)]) 186 | 187 | # TransformerDecoder has its own attention mechanism. 
188 | # Set up a separated copy attention layer, if needed. 189 | self._copy = False 190 | self.layer_norm = nn.LayerNorm(d_model, eps=1e-6) 191 | 192 | def forward(self, tgt, memory_bank, state, memory_lengths=None, 193 | step=None, cache=None,memory_masks=None): 194 | """ 195 | See :obj:`onmt.modules.RNNDecoderBase.forward()` 196 | """ 197 | n_blocks, n_tokens, batch_size, _ = memory_bank.shape 198 | #print(memory_bank.shape) 199 | memory_bank = memory_bank.view(n_blocks*n_tokens, batch_size, -1).contiguous() 200 | src = state.src 201 | src_words = src.transpose(0, 1) 202 | tgt_words = tgt.transpose(0, 1) 203 | src_batch, src_len = src_words.size() 204 | tgt_batch, tgt_len = tgt_words.size() 205 | 206 | # Run the forward pass of the TransformerDecoder. 207 | # emb = self.embeddings(tgt, step=step) 208 | emb = self.embeddings(tgt) 209 | assert emb.dim() == 3 # len x batch x embedding_dim 210 | 211 | output = emb.transpose(0, 1).contiguous() 212 | output = self.pos_emb(output, step) 213 | 214 | src_memory_bank = memory_bank.transpose(0, 1).contiguous() 215 | padding_idx = self.embeddings.padding_idx 216 | tgt_pad_mask = tgt_words.data.eq(padding_idx).unsqueeze(1) \ 217 | .expand(tgt_batch, tgt_len, tgt_len) 218 | 219 | if (not memory_masks is None): 220 | src_len = memory_masks.size(-1) 221 | src_pad_mask = memory_masks.expand(src_batch, tgt_len, src_len) 222 | 223 | else: 224 | src_pad_mask = src_words.data.eq(padding_idx).unsqueeze(1) \ 225 | .expand(src_batch, tgt_len, src_len) 226 | 227 | #print(tgt.shape) 228 | for i in range(self.num_layers): 229 | output, all_input, attn \ 230 | = self.transformer_layers[i]( 231 | output, src_memory_bank, 232 | 1-src_pad_mask, tgt_pad_mask, 233 | layer_cache=state.cache["layer_{}".format(i)] 234 | if state.cache is not None else None, 235 | step=step) 236 | 237 | output = self.layer_norm(output) 238 | 239 | # Process the result and update the attentions. 
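# When training, _discourse_coverage_attn() below replaces the raw context
# attention (batch x heads x tgt_len x n_blocks*n_tokens) with a dict of two
# tensors:
#   attn["dis"] - per-block attention mass at each decoding step minus the
#                 mass already assigned at the previous step, apparently used
#                 as a discourse / ordering signal over paragraph blocks
#   attn["cov"] - an element-wise min between the current attention and the
#                 attention accumulated so far, in the spirit of the coverage
#                 loss of See et al. (2017)
# At inference time the raw attention is returned unchanged for the beam
# search to consume.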
240 | outputs = output.transpose(0, 1).contiguous() 241 | #print(attn.shape) 242 | if self.training: 243 | attn = self._discourse_coverage_attn(attn, n_blocks) 244 | 245 | return outputs, state, attn 246 | 247 | 248 | def _discourse_coverage_attn(self, attn, n_blocks): 249 | """ calculate the discouse and coverage attention tensors""" 250 | 251 | 252 | #attn["attn"] = attn.transpose(1, 2).transpose(0, 1) 253 | # calculate discourse tensor 254 | b, h, t, pp = attn.shape 255 | n_tokens = pp//n_blocks 256 | attn = attn.view(b, h, t, n_blocks, n_tokens).contiguous() 257 | attn_dis = attn.sum(-1) # b x h x t x n_blocks 258 | attn_dis = attn_dis - torch.cat((torch.zeros(b,h,1,n_blocks).to(self.device), attn_dis), 2)[:,:,:t,:] 259 | 260 | # calculate coverage tensor 261 | atnn = attn.transpose(1,2).transpose(0,1) # t x b x h x pp 262 | cov = torch.zeros_like(attn[0], requires_grad=True).to(self.device) 263 | attn_cov = [] 264 | for a in attn: 265 | attn_cov.append(torch.min(torch.cat((a.unsqueeze(-1), cov.unsqueeze(-1)),-1), -1)[0]) 266 | cov = cov + a 267 | attn_cov = torch.stack(attn_cov) 268 | 269 | attn = {} 270 | attn["dis"] = attn_dis.transpose(1,2).transpose(0,1) 271 | attn["cov"] = attn_cov.transpose(1,2).transpose(0,1) 272 | 273 | return attn 274 | 275 | 276 | 277 | def init_decoder_state(self, src, memory_bank, 278 | with_cache=False): 279 | """ Init decoder state """ 280 | if(src.dim()==3): 281 | src = src.view(src.size(0),-1).transpose(0,1) 282 | else: 283 | src = src.transpose(0,1) 284 | state = TransformerDecoderState(src) 285 | if with_cache: 286 | state._init_cache(memory_bank, self.num_layers) 287 | return state 288 | 289 | 290 | 291 | class TransformerDecoderState(DecoderState): 292 | """ Transformer Decoder state base class """ 293 | 294 | def __init__(self, src): 295 | """ 296 | Args: 297 | src (FloatTensor): a sequence of source words tensors 298 | with optional feature tensors, of size (len x batch). 299 | """ 300 | self.src = src 301 | self.previous_input = None 302 | self.previous_layer_inputs = None 303 | self.cache = None 304 | 305 | @property 306 | def _all(self): 307 | """ 308 | Contains attributes that need to be updated in self.beam_update(). 309 | """ 310 | if (self.previous_input is not None 311 | and self.previous_layer_inputs is not None): 312 | return (self.previous_input, 313 | self.previous_layer_inputs, 314 | self.src) 315 | else: 316 | return (self.src,) 317 | 318 | def detach(self): 319 | if self.previous_input is not None: 320 | self.previous_input = self.previous_input.detach() 321 | if self.previous_layer_inputs is not None: 322 | self.previous_layer_inputs = self.previous_layer_inputs.detach() 323 | self.src = self.src.detach() 324 | 325 | def update_state(self, new_input, previous_layer_inputs): 326 | state = TransformerDecoderState(self.src) 327 | state.previous_input = new_input 328 | state.previous_layer_inputs = previous_layer_inputs 329 | return state 330 | 331 | def _init_cache(self, memory_bank, num_layers): 332 | self.cache = {} 333 | batch_size = memory_bank.size(1) 334 | depth = memory_bank.size(-1) 335 | 336 | for l in range(num_layers): 337 | layer_cache = { 338 | "memory_keys": None, 339 | "memory_values": None, 340 | "local_memory_keys":None, 341 | "local_memory_values":None, 342 | } 343 | layer_cache["self_keys"] = None 344 | layer_cache["self_values"] = None 345 | self.cache["layer_{}".format(l)] = layer_cache 346 | 347 | def repeat_beam_size_times(self, beam_size): 348 | """ Repeat beam_size times along batch dimension. 
""" 349 | self.src = self.src.data.repeat(1, beam_size, 1) 350 | 351 | def map_batch_fn(self, fn): 352 | def _recursive_map(struct, batch_dim=0): 353 | for k, v in struct.items(): 354 | if v is not None: 355 | if isinstance(v, dict): 356 | _recursive_map(v) 357 | else: 358 | struct[k] = fn(v, batch_dim) 359 | 360 | self.src = fn(self.src, 1) 361 | if self.cache is not None: 362 | _recursive_map(self.cache) 363 | 364 | 365 | class PointerDecoder(nn.Module): 366 | def __init__(self, num_layers, d_model, heads, d_ff, dropout, embeddings, device): 367 | super(TransformerDecoder, self).__init__() 368 | self.device = device 369 | self.decoder_type = 'pointer' 370 | self.num_layers = num_layers 371 | self.embeddings = embeddings 372 | self.pos_emb = PositionalEncoding(dropout,self.embeddings.embedding_dim) 373 | self.transformer_layers = nn.ModuleList( 374 | [TransformerDecoderLayer(d_model, heads, d_ff, dropout) 375 | for _ in range(num_layers)]) 376 | 377 | # TransformerDecoder has its own attention mechanism. 378 | # Set up a separated copy attention layer, if needed. 379 | self._copy = False 380 | self.layer_norm = nn.LayerNorm(d_model, eps=1e-6) 381 | 382 | def forward(self, tgt, memory_bank, state, memory_lengths=None, 383 | step=None, cache=None,memory_masks=None): 384 | """ 385 | See :obj:`onmt.modules.RNNDecoderBase.forward()` 386 | """ 387 | batch_size, n_blocks, _ = memory_bank.shape 388 | src = state.src 389 | src_words = src.transpose(0, 1) 390 | tgt_words = tgt.transpose(0, 1) 391 | src_batch, src_len = src_words.size() 392 | tgt_batch, tgt_len = tgt_words.size() 393 | 394 | # Run the forward pass of the TransformerDecoder. 395 | # emb = self.embeddings(tgt, step=step) 396 | emb = memory_bank.index_select(0, tgt) 397 | emb = torch.cat((torch.zeros(batch_size, 1, self.d_model), emb), -1) 398 | 399 | assert emb.dim() == 3 # len x batch x embedding_dim 400 | 401 | emb = emb[:,:-1,:] 402 | 403 | 404 | output = emb.transpose(0, 1).contiguous() 405 | output = self.pos_emb(output, step) 406 | 407 | src_memory_bank = memory_bank 408 | padding_idx = self.embeddings.padding_idx 409 | tgt_pad_mask = tgt_words.data.eq(padding_idx).unsqueeze(1) \ 410 | .expand(tgt_batch, tgt_len, tgt_len) 411 | 412 | if (not memory_masks is None): 413 | src_len = memory_masks.size(-1) 414 | src_pad_mask = memory_masks.expand(src_batch, tgt_len, src_len) 415 | 416 | else: 417 | src_pad_mask = src_words.data.eq(padding_idx).unsqueeze(1) \ 418 | .expand(src_batch, tgt_len, src_len) 419 | 420 | #print(tgt.shape) 421 | for i in range(self.num_layers): 422 | output, all_input, attn \ 423 | = self.transformer_layers[i]( 424 | output, src_memory_bank, 425 | 1-src_pad_mask, tgt_pad_mask, 426 | layer_cache=state.cache["layer_{}".format(i)] 427 | if state.cache is not None else None, 428 | step=step) 429 | 430 | output = self.layer_norm(output) 431 | 432 | # Process the result and update the attentions. 
433 | outputs = output.transpose(0, 1).contiguous() 434 | 435 | return outputs, state, attn 436 | 437 | def init_decoder_state(self, src, memory_bank, 438 | with_cache=False): 439 | """ Init decoder state """ 440 | if(src.dim()==3): 441 | src = src.view(src.size(0),-1).transpose(0,1) 442 | else: 443 | src = src.transpose(0,1) 444 | state = TransformerDecoderState(src) 445 | if with_cache: 446 | state._init_cache(memory_bank, self.num_layers) 447 | return state 448 | 449 | 450 | -------------------------------------------------------------------------------- /src/others/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ramakanth-pasunuru/QmdsCnnIr/81f7bcb3f3791e6a4c83b91f3fa781c8b33d3fbc/src/others/__init__.py -------------------------------------------------------------------------------- /src/others/distributed.py: -------------------------------------------------------------------------------- 1 | """ Pytorch Distributed utils 2 | This piece of code was heavily inspired by the equivalent of Fairseq-py 3 | https://github.com/pytorch/fairseq 4 | Borrowed from https://github.com/nlpyang/hiersumm 5 | """ 6 | 7 | 8 | from __future__ import print_function 9 | 10 | import math 11 | import pickle 12 | import torch.distributed 13 | 14 | from others.logging import logger 15 | 16 | 17 | def is_master(gpu_ranks, device_id): 18 | return gpu_ranks[device_id] == 0 19 | 20 | 21 | def multi_init(device_id, world_size,gpu_ranks): 22 | print(gpu_ranks) 23 | dist_init_method = 'tcp://localhost:10000' 24 | dist_world_size = world_size 25 | torch.distributed.init_process_group( 26 | backend='nccl', init_method=dist_init_method, 27 | world_size=dist_world_size, rank=gpu_ranks[device_id]) 28 | gpu_rank = torch.distributed.get_rank() 29 | if not is_master(gpu_ranks, device_id): 30 | # print('not master') 31 | logger.disabled = True 32 | 33 | return gpu_rank 34 | 35 | 36 | 37 | def all_reduce_and_rescale_tensors(tensors, rescale_denom, 38 | buffer_size=10485760): 39 | """All-reduce and rescale tensors in chunks of the specified size. 40 | 41 | Args: 42 | tensors: list of Tensors to all-reduce 43 | rescale_denom: denominator for rescaling summed Tensors 44 | buffer_size: all-reduce chunk size in bytes 45 | """ 46 | # buffer size in bytes, determine equiv. 
# of elements based on data type 47 | buffer_t = tensors[0].new( 48 | math.ceil(buffer_size / tensors[0].element_size())).zero_() 49 | buffer = [] 50 | 51 | def all_reduce_buffer(): 52 | # copy tensors into buffer_t 53 | offset = 0 54 | for t in buffer: 55 | numel = t.numel() 56 | buffer_t[offset:offset+numel].copy_(t.view(-1)) 57 | offset += numel 58 | 59 | # all-reduce and rescale 60 | torch.distributed.all_reduce(buffer_t[:offset]) 61 | buffer_t.div_(rescale_denom) 62 | 63 | # copy all-reduced buffer back into tensors 64 | offset = 0 65 | for t in buffer: 66 | numel = t.numel() 67 | t.view(-1).copy_(buffer_t[offset:offset+numel]) 68 | offset += numel 69 | 70 | filled = 0 71 | for t in tensors: 72 | sz = t.numel() * t.element_size() 73 | if sz > buffer_size: 74 | # tensor is bigger than buffer, all-reduce and rescale directly 75 | torch.distributed.all_reduce(t) 76 | t.div_(rescale_denom) 77 | elif filled + sz > buffer_size: 78 | # buffer is full, all-reduce and replace buffer with grad 79 | all_reduce_buffer() 80 | buffer = [t] 81 | filled = sz 82 | else: 83 | # add tensor to buffer 84 | buffer.append(t) 85 | filled += sz 86 | 87 | if len(buffer) > 0: 88 | all_reduce_buffer() 89 | 90 | 91 | def all_gather_list(data, max_size=4096): 92 | """Gathers arbitrary data from all nodes into a list.""" 93 | world_size = torch.distributed.get_world_size() 94 | if not hasattr(all_gather_list, '_in_buffer') or \ 95 | max_size != all_gather_list._in_buffer.size(): 96 | all_gather_list._in_buffer = torch.cuda.ByteTensor(max_size) 97 | all_gather_list._out_buffers = [ 98 | torch.cuda.ByteTensor(max_size) 99 | for i in range(world_size) 100 | ] 101 | in_buffer = all_gather_list._in_buffer 102 | out_buffers = all_gather_list._out_buffers 103 | 104 | enc = pickle.dumps(data) 105 | enc_size = len(enc) 106 | if enc_size + 2 > max_size: 107 | raise ValueError( 108 | 'encoded data exceeds max_size: {}'.format(enc_size + 2)) 109 | assert max_size < 255*256 110 | in_buffer[0] = enc_size // 255 # this encoding works for max_size < 65k 111 | in_buffer[1] = enc_size % 255 112 | in_buffer[2:enc_size+2] = torch.ByteTensor(list(enc)) 113 | 114 | torch.distributed.all_gather(out_buffers, in_buffer.cuda()) 115 | 116 | results = [] 117 | for i in range(world_size): 118 | out_buffer = out_buffers[i] 119 | size = (255 * out_buffer[0].item()) + out_buffer[1].item() 120 | 121 | bytes_list = bytes(out_buffer[2:size+2].tolist()) 122 | result = pickle.loads(bytes_list) 123 | results.append(result) 124 | return results 125 | -------------------------------------------------------------------------------- /src/others/logging.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Borrowed from https://github.com/nlpyang/hiersumm 4 | """ 5 | from __future__ import absolute_import 6 | 7 | import logging 8 | 9 | logger = logging.getLogger() 10 | 11 | 12 | def init_logger(log_file=None, log_file_level=logging.NOTSET): 13 | log_format = logging.Formatter("[%(asctime)s %(levelname)s] %(message)s") 14 | logger = logging.getLogger() 15 | logger.setLevel(logging.INFO) 16 | 17 | console_handler = logging.StreamHandler() 18 | console_handler.setFormatter(log_format) 19 | logger.handlers = [console_handler] 20 | 21 | if log_file and log_file != '': 22 | file_handler = logging.FileHandler(log_file) 23 | file_handler.setLevel(log_file_level) 24 | file_handler.setFormatter(log_format) 25 | logger.addHandler(file_handler) 26 | 27 | return logger 28 | 
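# Illustrative usage sketch (added as an example; the log file name below is
# hypothetical). Running this module directly just exercises the logger:
# init_logger() replaces any existing handlers with a console handler and,
# when a path is given, also attaches a file handler at the requested level.
if __name__ == '__main__':
    example_logger = init_logger('example.log', log_file_level=logging.INFO)
    example_logger.info('logger initialised: console + example.log')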
-------------------------------------------------------------------------------- /src/others/report_manager.py: -------------------------------------------------------------------------------- 1 | """ 2 | Report manager utility 3 | Borrowed from https://github.com/nlpyang/hiersumm 4 | """ 5 | from __future__ import print_function 6 | import time 7 | from datetime import datetime 8 | 9 | from others.logging import logger 10 | from others.statistics import Statistics 11 | 12 | 13 | def build_report_manager(opt): 14 | if opt.tensorboard: 15 | from tensorboardX import SummaryWriter 16 | writer = SummaryWriter(opt.tensorboard_log_dir 17 | + datetime.now().strftime("/%b-%d_%H-%M-%S"), 18 | comment="Unmt") 19 | else: 20 | writer = None 21 | 22 | report_mgr = ReportMgr(opt.report_every, start_time=-1, 23 | tensorboard_writer=writer) 24 | return report_mgr 25 | 26 | 27 | class ReportMgrBase(object): 28 | """ 29 | Report Manager Base class 30 | Inherited classes should override: 31 | * `_report_training` 32 | * `_report_step` 33 | """ 34 | 35 | def __init__(self, report_every, start_time=-1.): 36 | """ 37 | Args: 38 | report_every(int): Report status every this many sentences 39 | start_time(float): manually set report start time. Negative values 40 | means that you will need to set it later or use `start()` 41 | """ 42 | self.report_every = report_every 43 | self.progress_step = 0 44 | self.start_time = start_time 45 | 46 | def start(self): 47 | self.start_time = time.time() 48 | 49 | def log(self, *args, **kwargs): 50 | logger.info(*args, **kwargs) 51 | 52 | def report_training(self, step, num_steps, learning_rate, 53 | report_stats, multigpu=False): 54 | """ 55 | This is the user-defined batch-level traing progress 56 | report function. 57 | 58 | Args: 59 | step(int): current step count. 60 | num_steps(int): total number of batches. 61 | learning_rate(float): current learning rate. 62 | report_stats(Statistics): old Statistics instance. 63 | Returns: 64 | report_stats(Statistics): updated Statistics instance. 
65 | """ 66 | if self.start_time < 0: 67 | raise ValueError("""ReportMgr needs to be started 68 | (set 'start_time' or use 'start()'""") 69 | 70 | if multigpu: 71 | report_stats = Statistics.all_gather_stats(report_stats) 72 | 73 | if step % self.report_every == 0: 74 | self._report_training( 75 | step, num_steps, learning_rate, report_stats) 76 | self.progress_step += 1 77 | return Statistics() 78 | 79 | def _report_training(self, *args, **kwargs): 80 | """ To be overridden """ 81 | raise NotImplementedError() 82 | 83 | def report_step(self, lr, step, train_stats=None, valid_stats=None): 84 | """ 85 | Report stats of a step 86 | 87 | Args: 88 | train_stats(Statistics): training stats 89 | valid_stats(Statistics): validation stats 90 | lr(float): current learning rate 91 | """ 92 | self._report_step( 93 | lr, step, train_stats=train_stats, valid_stats=valid_stats) 94 | 95 | def _report_step(self, *args, **kwargs): 96 | raise NotImplementedError() 97 | 98 | 99 | class ReportMgr(ReportMgrBase): 100 | def __init__(self, report_every, start_time=-1., tensorboard_writer=None): 101 | """ 102 | A report manager that writes statistics on standard output as well as 103 | (optionally) TensorBoard 104 | 105 | Args: 106 | report_every(int): Report status every this many sentences 107 | tensorboard_writer(:obj:`tensorboard.SummaryWriter`): 108 | The TensorBoard Summary writer to use or None 109 | """ 110 | super(ReportMgr, self).__init__(report_every, start_time) 111 | self.tensorboard_writer = tensorboard_writer 112 | 113 | def maybe_log_tensorboard(self, stats, prefix, learning_rate, step): 114 | if self.tensorboard_writer is not None: 115 | stats.log_tensorboard( 116 | prefix, self.tensorboard_writer, learning_rate, step) 117 | 118 | def _report_training(self, step, num_steps, learning_rate, 119 | report_stats): 120 | """ 121 | See base class method `ReportMgrBase.report_training`. 122 | """ 123 | report_stats.output(step, num_steps, 124 | learning_rate, self.start_time) 125 | 126 | # Log the progress using the number of batches on the x-axis. 127 | self.maybe_log_tensorboard(report_stats, 128 | "progress", 129 | learning_rate, 130 | step) 131 | report_stats = Statistics() 132 | 133 | return report_stats 134 | 135 | def _report_step(self, lr, step, train_stats=None, valid_stats=None): 136 | """ 137 | See base class method `ReportMgrBase.report_step`. 138 | """ 139 | if train_stats is not None: 140 | self.log('Train perplexity: %g' % train_stats.ppl()) 141 | self.log('Train accuracy: %g' % train_stats.accuracy()) 142 | 143 | self.maybe_log_tensorboard(train_stats, 144 | "train", 145 | lr, 146 | step) 147 | 148 | if valid_stats is not None: 149 | self.log('Validation perplexity: %g' % valid_stats.ppl()) 150 | self.log('Validation accuracy: %g' % valid_stats.accuracy()) 151 | 152 | self.maybe_log_tensorboard(valid_stats, 153 | "valid", 154 | lr, 155 | step) 156 | -------------------------------------------------------------------------------- /src/others/statistics.py: -------------------------------------------------------------------------------- 1 | """ 2 | code borrowed from https://github.com/nlpyang/hiersumm 3 | and https://github.com/OpenNMT/OpenNMT-py 4 | """ 5 | 6 | import math 7 | import sys 8 | import time 9 | 10 | 11 | from others.distributed import all_gather_list 12 | from others.logging import logger 13 | 14 | 15 | class Statistics(object): 16 | """ 17 | Accumulator for loss statistics. 
18 | Currently calculates: 19 | 20 | * accuracy 21 | * perplexity 22 | * elapsed time 23 | """ 24 | 25 | def __init__(self, loss=0, n_words=0, n_correct=0): 26 | self.loss = loss 27 | self.n_words = n_words 28 | self.n_correct = n_correct 29 | self.n_src_words = 0 30 | self.start_time = time.time() 31 | 32 | @staticmethod 33 | def all_gather_stats(stat, max_size=4096): 34 | """ 35 | Gather a `Statistics` object accross multiple process/nodes 36 | 37 | Args: 38 | stat(:obj:Statistics): the statistics object to gather 39 | accross all processes/nodes 40 | max_size(int): max buffer size to use 41 | 42 | Returns: 43 | `Statistics`, the update stats object 44 | """ 45 | stats = Statistics.all_gather_stats_list([stat], max_size=max_size) 46 | return stats[0] 47 | 48 | @staticmethod 49 | def all_gather_stats_list(stat_list, max_size=4096): 50 | from torch.distributed import get_rank 51 | 52 | """ 53 | Gather a `Statistics` list accross all processes/nodes 54 | 55 | Args: 56 | stat_list(list([`Statistics`])): list of statistics objects to 57 | gather accross all processes/nodes 58 | max_size(int): max buffer size to use 59 | 60 | Returns: 61 | our_stats(list([`Statistics`])): list of updated stats 62 | """ 63 | # Get a list of world_size lists with len(stat_list) Statistics objects 64 | all_stats = all_gather_list(stat_list, max_size=max_size) 65 | 66 | our_rank = get_rank() 67 | our_stats = all_stats[our_rank] 68 | for other_rank, stats in enumerate(all_stats): 69 | if other_rank == our_rank: 70 | continue 71 | for i, stat in enumerate(stats): 72 | our_stats[i].update(stat, update_n_src_words=True) 73 | return our_stats 74 | 75 | def update(self, stat, update_n_src_words=False): 76 | """ 77 | Update statistics by suming values with another `Statistics` object 78 | 79 | Args: 80 | stat: another statistic object 81 | update_n_src_words(bool): whether to update (sum) `n_src_words` 82 | or not 83 | 84 | """ 85 | self.loss += stat.loss 86 | self.n_words += stat.n_words 87 | self.n_correct += stat.n_correct 88 | 89 | if update_n_src_words: 90 | self.n_src_words += stat.n_src_words 91 | 92 | def accuracy(self): 93 | """ compute accuracy """ 94 | return 100 * (self.n_correct / self.n_words) 95 | 96 | def xent(self): 97 | """ compute cross entropy """ 98 | return self.loss / self.n_words 99 | 100 | def ppl(self): 101 | """ compute perplexity """ 102 | return math.exp(min(self.loss / self.n_words, 100)) 103 | 104 | def elapsed_time(self): 105 | """ compute elapsed time """ 106 | return time.time() - self.start_time 107 | 108 | def output(self, step, num_steps, learning_rate, start): 109 | """Write out statistics to stdout. 110 | 111 | Args: 112 | step (int): current step 113 | n_batch (int): total batches 114 | start (int): start time of step. 
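        Note: despite the name, `start` is expected to be a wall-clock
        timestamp as returned by time.time(); the trailing "%6.0f sec"
        field in the log line reports time.time() - start.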
115 | """ 116 | t = self.elapsed_time() 117 | logger.info( 118 | ("Step %2d/%5d; acc: %6.2f; ppl: %5.2f; xent: %4.2f; " + 119 | "lr: %7.5f; %3.0f/%3.0f tok/s; %6.0f sec") 120 | % (step, num_steps, 121 | self.accuracy(), 122 | self.ppl(), 123 | self.xent(), 124 | learning_rate, 125 | self.n_src_words / (t + 1e-5), 126 | self.n_words / (t + 1e-5), 127 | time.time() - start)) 128 | sys.stdout.flush() 129 | 130 | def log_tensorboard(self, prefix, writer, learning_rate, step): 131 | """ display statistics to tensorboard """ 132 | t = self.elapsed_time() 133 | writer.add_scalar(prefix + "/xent", self.xent(), step) 134 | writer.add_scalar(prefix + "/ppl", self.ppl(), step) 135 | writer.add_scalar(prefix + "/accuracy", self.accuracy(), step) 136 | writer.add_scalar(prefix + "/tgtper", self.n_words / t, step) 137 | writer.add_scalar(prefix + "/lr", learning_rate, step) 138 | -------------------------------------------------------------------------------- /src/train_abstractive.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | """ 3 | Main training workflow 4 | Majority of code borrowed from https://github.com/nlpyang/hiersumm 5 | """ 6 | from __future__ import division 7 | 8 | import argparse 9 | import glob 10 | import os 11 | import signal 12 | import time 13 | 14 | import sentencepiece 15 | 16 | from abstractive.model_builder import Summarizer 17 | from abstractive.trainer_builder import build_trainer 18 | from abstractive.predictor_builder import build_predictor 19 | from abstractive.data_loader import load_dataset 20 | import torch 21 | import random 22 | 23 | from abstractive import data_loader, model_builder 24 | from others import distributed 25 | from others.logging import init_logger, logger 26 | 27 | model_flags = [ 'emb_size', 'enc_hidden_size', 'dec_hidden_size', 'enc_layers', 'dec_layers', 'block_size', 'heads', 'ff_size', 'hier', 28 | 'inter_layers', 'inter_heads', 'block_size', 'attn_threshold'] 29 | 30 | 31 | 32 | 33 | def str2bool(v): 34 | if v.lower() in ('yes', 'true', 't', 'y', '1'): 35 | return True 36 | elif v.lower() in ('no', 'false', 'f', 'n', '0'): 37 | return False 38 | else: 39 | raise argparse.ArgumentTypeError('Boolean value expected.') 40 | 41 | 42 | 43 | 44 | 45 | def main(args): 46 | device = "cpu" if args.visible_gpus == '-1' else "cuda" 47 | device_id = 0 if device == "cuda" else -1 48 | init_logger(args.log_file) 49 | if (args.mode == 'train'): 50 | train(args, device_id) 51 | elif (args.mode == 'test'): 52 | if ".pt" in args.test_from: 53 | step = int(args.test_from.split('.')[-2].split('_')[-1]) 54 | # validate(args, device_id, args.test_from, step) 55 | test(args, args.test_from, step) 56 | else: ## if test_from only refers to the model path 57 | model_files = glob.glob(os.path.join(args.test_from,"*.pt")) 58 | best_path = None 59 | for filename in model_files: 60 | step = int(filename.split('.')[-2].split('_')[-1]) 61 | score = float(filename.split('.')[-2].split('_')[-2]) 62 | if best_path is None: 63 | best_path = (score, filename, step) 64 | else: 65 | if score > best_path[0]: 66 | best_path = (score, filename, step) 67 | 68 | test(args, best_path[1], best_path[2]) 69 | elif (args.mode == 'validate'): 70 | wait_and_validate(args, device_id) 71 | elif (args.mode == 'baseline'): 72 | baseline() 73 | elif (args.mode == 'print_flags'): 74 | print_flags() 75 | elif (args.mode == 'stats'): 76 | stats() 77 | 78 | 79 | def train(args,device_id): 80 | init_logger(args.log_file) 81 | 
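# Reproducibility / resume notes: the seed set below is applied to torch,
# Python's random module and (when a GPU is used) CUDA, with cudnn forced
# into deterministic mode; when resuming via args.train_from, the
# architecture flags listed in model_flags at the top of this file are
# copied back from the checkpoint's saved opt so the Summarizer is rebuilt
# with a matching configuration.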
logger.info(str(args)) 82 | 83 | device = "cpu" if args.visible_gpus == '-1' else "cuda" 84 | logger.info('Device ID %d' % device_id) 85 | logger.info('Device %s' % device) 86 | torch.manual_seed(args.seed) 87 | random.seed(args.seed) 88 | torch.backends.cudnn.deterministic = True 89 | 90 | if device_id >= 0: 91 | torch.cuda.set_device(device_id) 92 | torch.cuda.manual_seed(args.seed) 93 | 94 | if args.train_from != '': 95 | logger.info('Loading checkpoint from %s' % args.train_from) 96 | checkpoint = torch.load(args.train_from, 97 | map_location=None) 98 | opt = vars(checkpoint['opt']) 99 | for k in opt.keys(): 100 | if (k in model_flags): 101 | setattr(args, k, opt[k]) 102 | 103 | else: 104 | checkpoint = None 105 | 106 | spm = sentencepiece.SentencePieceProcessor() 107 | spm.Load(args.vocab_path) 108 | word_padding_idx = spm.PieceToId('') 109 | symbols = {'BOS': spm.PieceToId(''), 'EOS': spm.PieceToId(''), 'PAD': word_padding_idx, 110 | 'EOT': spm.PieceToId(''), 'EOP': spm.PieceToId('
<P>
111 | print(symbols) 112 | vocab_size = len(spm) 113 | vocab = spm 114 | 115 | def train_iter_fct(): 116 | return data_loader.AbstractiveDataloader(args, load_dataset(args, 'train', shuffle=True), symbols, args.batch_size, device, 117 | shuffle=True, is_test=False) 118 | 119 | model = Summarizer(args, word_padding_idx, vocab_size, device, checkpoint) 120 | optim = model_builder.build_optim(args, model, checkpoint) 121 | logger.info(model) 122 | trainer = build_trainer(args, device_id, model, symbols, vocab_size, optim) 123 | 124 | ################# 125 | # for validation performance during training 126 | 127 | def valid_iter_fct(): 128 | return data_loader.AbstractiveDataloader(args, load_dataset(args, 'valid', shuffle=False), symbols, args.valid_batch_size, device, 129 | shuffle=True, is_test=True) 130 | 131 | predictor = build_predictor(args, vocab, symbols, model, logger=logger) 132 | 133 | trainer.train(train_iter_fct, args.train_steps, predictor, valid_iter_fct) 134 | 135 | 136 | def wait_and_validate(args, device_id): 137 | timestep = 0 138 | if (args.test_all): 139 | cp_files = sorted(glob.glob(os.path.join(args.model_path, 'model_step_*.pt'))) 140 | cp_files.sort(key=os.path.getmtime) 141 | ppl_lst = [] 142 | for i, cp in enumerate(cp_files): 143 | step = int(cp.split('.')[-2].split('_')[-1]) 144 | ppl = validate(args, device_id, cp, step) 145 | ppl_lst.append((ppl, cp)) 146 | max_step = ppl_lst.index(min(ppl_lst)) 147 | if (i - max_step > 5): 148 | break 149 | ppl_lst = sorted(ppl_lst, key=lambda x: x[0])[:5] 150 | logger.info('PPL %s' % str(ppl_lst)) 151 | for pp, cp in ppl_lst: 152 | step = int(cp.split('.')[-2].split('_')[-1]) 153 | test(args, cp, step) 154 | else: 155 | while (True): 156 | cp_files = sorted(glob.glob(os.path.join(args.model_path, 'model_step_*.pt'))) 157 | cp_files.sort(key=os.path.getmtime) 158 | if (cp_files): 159 | cp = cp_files[-1] 160 | time_of_cp = os.path.getmtime(cp) 161 | if (not os.path.getsize(cp) > 0): 162 | time.sleep(60) 163 | continue 164 | if (time_of_cp > timestep): 165 | timestep = time_of_cp 166 | step = int(cp.split('.')[-2].split('_')[-1]) 167 | validate(args,device_id, cp, step) 168 | test(args,cp, step) 169 | 170 | cp_files = sorted(glob.glob(os.path.join(args.model_path, 'model_step_*.pt'))) 171 | cp_files.sort(key=os.path.getmtime) 172 | if (cp_files): 173 | cp = cp_files[-1] 174 | time_of_cp = os.path.getmtime(cp) 175 | if (time_of_cp > timestep): 176 | continue 177 | else: 178 | time.sleep(300) 179 | 180 | 181 | 182 | def validate(args, device_id, pt, step): 183 | device = "cpu" if args.visible_gpus == '-1' else "cuda" 184 | 185 | if (pt != ''): 186 | test_from = pt 187 | else: 188 | test_from = args.test_from 189 | logger.info('Loading checkpoint from %s' % test_from) 190 | checkpoint = torch.load(test_from, map_location=None) 191 | 192 | 193 | opt = vars(checkpoint['opt']) 194 | 195 | for k in opt.keys(): 196 | if (k in model_flags): 197 | setattr(args, k, opt[k]) 198 | print(args) 199 | 200 | spm = sentencepiece.SentencePieceProcessor() 201 | spm.Load(args.vocab_path) 202 | word_padding_idx = spm.PieceToId('<PAD>') 203 | symbols = {'BOS': spm.PieceToId('<S>'), 'EOS': spm.PieceToId('</S>'), 'PAD': word_padding_idx, 204 | 'EOT': spm.PieceToId('<T>'), 'EOP': spm.PieceToId('<P>'), 'EOQ': spm.PieceToId('<Q>')}
205 | 206 | vocab_size = len(spm) 207 | model = Summarizer(args, word_padding_idx, vocab_size, device, checkpoint) 208 | model.eval() 209 | 210 | valid_iter = data_loader.AbstractiveDataloader(args, load_dataset(args, 'valid', shuffle=False), symbols, 211 | args.batch_size, device, shuffle=False, is_test=False) 212 | 213 | trainer = build_trainer(args, device_id, model, symbols, vocab_size, None) 214 | stats = trainer.validate(valid_iter) 215 | trainer._report_step(0, step, valid_stats=stats) 216 | return stats.ppl() 217 | 218 | 219 | def test(args, pt, step): 220 | device = "cpu" if args.visible_gpus == '-1' else "cuda" 221 | 222 | if (pt != ''): 223 | test_from = pt 224 | else: 225 | test_from = args.test_from 226 | logger.info('Loading checkpoint from %s' % test_from) 227 | checkpoint = torch.load(test_from, map_location=device) 228 | opt = vars(checkpoint['opt']) 229 | 230 | for k in opt.keys(): 231 | if (k in model_flags): 232 | setattr(args, k, opt[k]) 233 | print(args) 234 | 235 | spm = sentencepiece.SentencePieceProcessor() 236 | spm.Load(args.vocab_path) 237 | word_padding_idx = spm.PieceToId('<PAD>') 238 | symbols = {'BOS': spm.PieceToId('<S>'), 'EOS': spm.PieceToId('</S>'), 'PAD': word_padding_idx, 239 | 'EOT': spm.PieceToId('<T>'), 'EOP': spm.PieceToId('<P>'), 'EOQ': spm.PieceToId('<Q>')}
240 | 241 | vocab_size = len(spm) 242 | vocab = spm 243 | model = Summarizer(args, word_padding_idx, vocab_size, device, checkpoint) 244 | model.eval() 245 | 246 | test_iter = data_loader.AbstractiveDataloader(args, load_dataset(args, 'test', shuffle=False), symbols, 247 | args.valid_batch_size, device, shuffle=False, is_test=True) 248 | predictor = build_predictor(args, vocab, symbols, model, logger=logger) 249 | predictor.translate(test_iter, step) 250 | 251 | 252 | 253 | 254 | def print_flags(args): 255 | checkpoint = torch.load(args.test_from, map_location=None) 256 | print(checkpoint['opt']) 257 | 258 | 259 | def run(args, device_id, error_queue): 260 | """ run process """ 261 | setattr(args, 'gpu_ranks', [int(i) for i in args.gpu_ranks]) 262 | 263 | try: 264 | gpu_rank = distributed.multi_init(device_id, args.world_size, args.gpu_ranks) 265 | print('gpu_rank %d' %gpu_rank) 266 | if gpu_rank != args.gpu_ranks[device_id]: 267 | raise AssertionError("An error occurred in \ 268 | Distributed initialization") 269 | train(args,device_id) 270 | except KeyboardInterrupt: 271 | pass # killed by parent, do nothing 272 | except Exception: 273 | # propagate exception to parent process, keeping original traceback 274 | import traceback 275 | error_queue.put((args.gpu_ranks[device_id], traceback.format_exc())) 276 | 277 | 278 | def multi_main(args): 279 | """ Spawns 1 process per GPU """ 280 | init_logger() 281 | 282 | nb_gpu = args.world_size 283 | mp = torch.multiprocessing.get_context('spawn') 284 | 285 | # Create a thread to listen for errors in the child processes. 286 | error_queue = mp.SimpleQueue() 287 | error_handler = ErrorHandler(error_queue) 288 | 289 | # Train with multiprocessing. 290 | procs = [] 291 | for i in range(nb_gpu): 292 | device_id = i 293 | 294 | procs.append(mp.Process(target=run, args=(args, 295 | device_id, error_queue), daemon=True)) 296 | procs[i].start() 297 | logger.info(" Starting process pid: %d " % procs[i].pid) 298 | error_handler.add_child(procs[i].pid) 299 | for p in procs: 300 | p.join() 301 | 302 | 303 | class ErrorHandler(object): 304 | """A class that listens for exceptions in children processes and propagates 305 | the tracebacks to the parent process.""" 306 | 307 | def __init__(self, error_queue): 308 | """ init error handler """ 309 | import signal 310 | import threading 311 | self.error_queue = error_queue 312 | self.children_pids = [] 313 | self.error_thread = threading.Thread( 314 | target=self.error_listener, daemon=True) 315 | self.error_thread.start() 316 | signal.signal(signal.SIGUSR1, self.signal_handler) 317 | 318 | def add_child(self, pid): 319 | """ error handler """ 320 | self.children_pids.append(pid) 321 | 322 | def error_listener(self): 323 | """ error listener """ 324 | (rank, original_trace) = self.error_queue.get() 325 | self.error_queue.put((rank, original_trace)) 326 | os.kill(os.getpid(), signal.SIGUSR1) 327 | 328 | def signal_handler(self, signalnum, stackframe): 329 | """ signal handler """ 330 | for pid in self.children_pids: 331 | os.kill(pid, signal.SIGINT) # kill children processes 332 | (rank, original_trace) = self.error_queue.get() 333 | msg = """\n\n-- Tracebacks above this line can probably 334 | be ignored --\n\n""" 335 | msg += original_trace 336 | raise Exception(msg) 337 | 338 | 339 | 340 | 341 | 342 | if __name__ == '__main__': 343 | parser = argparse.ArgumentParser() 344 | parser.add_argument('-log_file', default='', type=str) 345 | parser.add_argument('-mode', default='train', 
type=str) 346 | parser.add_argument('-visible_gpus', default='0', type=str) 347 | parser.add_argument('-data_path', default='/mnt/ram/MultiDocSumm/data/ranked_wiki_b40/WIKI', type=str) 348 | parser.add_argument('-model_path', default='models', type=str) 349 | parser.add_argument('-vocab_path', default='/mnt/ram/MultiDocSumm/data/ranked_wiki_b40/spm9998_3.model', type=str) 350 | parser.add_argument('-train_from', default='', type=str) 351 | 352 | parser.add_argument('-trunc_src_ntoken', default=500, type=int) 353 | parser.add_argument('-trunc_tgt_ntoken', default=200, type=int) 354 | 355 | parser.add_argument('-emb_size', default=256, type=int) 356 | parser.add_argument('-query_layers', default=1, type=int) 357 | parser.add_argument('-enc_layers', default=8, type=int) 358 | parser.add_argument('-dec_layers', default=1, type=int) 359 | parser.add_argument('-enc_dropout', default=0.1, type=float) 360 | parser.add_argument('-dec_dropout', default=0.1, type=float) 361 | parser.add_argument('-enc_hidden_size', default=256, type=int) 362 | parser.add_argument('-dec_hidden_size', default=256, type=int) 363 | parser.add_argument('-heads', default=8, type=int) 364 | parser.add_argument('-ff_size', default=1024, type=int) 365 | parser.add_argument("-hier", type=str2bool, nargs='?',const=True,default=True) 366 | parser.add_argument("-model_type", type=str, default="hier") 367 | parser.add_argument("-query", type=str2bool, default=False) 368 | parser.add_argument("-fine_tune", type=str2bool, default=False) 369 | 370 | 371 | parser.add_argument('-batch_size', default=10000, type=int) 372 | parser.add_argument('-valid_batch_size', default=10000, type=int) 373 | parser.add_argument('-optim', default='adam', type=str) 374 | parser.add_argument('-lr', default=3, type=float) 375 | parser.add_argument('-max_grad_norm', default=0, type=float) 376 | parser.add_argument('-seed', default=666, type=int) 377 | 378 | parser.add_argument('-train_steps', default=500000, type=int) 379 | parser.add_argument('-save_checkpoint_steps', default=5000, type=int) 380 | parser.add_argument('-max_num_checkpoints', default=3, type=int) 381 | parser.add_argument('-report_every', default=100, type=int) 382 | 383 | 384 | # multi-gpu 385 | parser.add_argument('-accum_count', default=1, type=int) 386 | parser.add_argument('-world_size', default=1, type=int) 387 | parser.add_argument('-gpu_ranks', default='0', type=str) 388 | 389 | # don't need to change flags 390 | parser.add_argument("-share_embeddings", type=str2bool, nargs='?',const=True,default=True) 391 | parser.add_argument("-share_decoder_embeddings", type=str2bool, nargs='?',const=True,default=True) 392 | parser.add_argument('-max_generator_batches', default=32, type=int) 393 | 394 | # flags for testing 395 | parser.add_argument("-test_all", type=str2bool, nargs='?',const=True,default=False) 396 | parser.add_argument('-test_from', default=None, type=str) 397 | parser.add_argument('-result_path', default='../../results', type=str) 398 | parser.add_argument('-alpha', default=0.4, type=float) 399 | parser.add_argument('-length_penalty', default='wu', type=str) 400 | parser.add_argument('-block_ngram_repeat', default=0, type=int) 401 | parser.add_argument('-coverage_penalty', default=None, type=str) 402 | parser.add_argument('-cov_beta', default=5, type=float) 403 | parser.add_argument('-attn_threshold', default=0, type=float) 404 | parser.add_argument('-stepwise_penalty', default=False, type=str2bool) 405 | parser.add_argument('-beam_size', default=5, type=int) 406 | 
parser.add_argument('-n_best', default=1, type=int) 407 | parser.add_argument('-max_length', default=250, type=int) 408 | parser.add_argument('-min_length', default=20, type=int) 409 | parser.add_argument("-report_rouge", type=str2bool, nargs='?',const=True,default=False) 410 | parser.add_argument('-save_criteria', type=str, default='rouge_l_f_score') 411 | parser.add_argument('-rouge_path',type=str,default=None) 412 | 413 | parser.add_argument('-dataset', default='WIKI', type=str) 414 | parser.add_argument('-max_samples', default=5, type=int) 415 | 416 | # flags for hier 417 | # flags.DEFINE_boolean('old_inter_att', False, 'old_inter_att') 418 | parser.add_argument('-inter_layers', default='6,7', type=str) 419 | 420 | parser.add_argument('-inter_heads', default=8, type=int) 421 | parser.add_argument('-trunc_src_nblock', default=24, type=int) 422 | 423 | # flags for graph 424 | 425 | 426 | # flags for learning 427 | parser.add_argument('-beta1', default=0.9, type=float) 428 | parser.add_argument('-beta2', default=0.998, type=float) 429 | parser.add_argument('-warmup_steps', default=8000, type=int) 430 | parser.add_argument('-decay_method', default='noam', type=str) 431 | parser.add_argument('-label_smoothing', default=0.1, type=float) 432 | parser.add_argument('-lambda_dis', default=0.0, type=float) 433 | parser.add_argument('-lambda_cov', default=0.0, type=float) 434 | 435 | args = parser.parse_args() 436 | args.gpu_ranks = [int(i) for i in args.gpu_ranks.split(',')] 437 | args.inter_layers = [int(i) for i in args.inter_layers.split(',')] 438 | 439 | os.environ["CUDA_VISIBLE_DEVICES"] = args.visible_gpus 440 | 441 | if not os.path.exists(args.model_path): 442 | os.mkdir(args.model_path) 443 | 444 | 445 | if args.mode=="train": 446 | args.log_file = os.path.join(args.model_path, 'log.txt') 447 | 448 | 449 | if(args.world_size>1): 450 | multi_main(args) 451 | else: 452 | main(args) 453 | 454 | -------------------------------------------------------------------------------- /test.sh: -------------------------------------------------------------------------------- 1 | # Required environment variables: 2 | # DATASET: choose the dataset (CNNDM / WIKI) 3 | # MODEL_TYPE: choose the type of model (hier, he, order, query, heq, heo, hero) 4 | 5 | BATCH_SIZE=8000 6 | VISIBLE_GPUS="0" 7 | GPU_RANKS="0" 8 | WORLD_SIZE=1 9 | MAX_SAMPLES=100000 10 | 11 | EXTRA="" 12 | 13 | case $MODEL_TYPE in 14 | query|heq|hero) 15 | QUERY=True 16 | ;; 17 | hier|he|order|heo) 18 | QUERY=False 19 | ;; 20 | *) 21 | echo "Invalid option: ${MODEL_TYPE}" 22 | ;; 23 | esac 24 | 25 | case $DATASET in 26 | CNNDM) 27 | TRUNC_TGT_NTOKEN=120 28 | TRUNC_SRC_NTOKEN=200 29 | TRUNC_SRC_NBLOCK=8 30 | MAX_LENGTH=250 31 | MIN_LENGTH=35 32 | EXTRA="-coverage_penalty summary -stepwise_penalty True -block_ngram_repeat 3" 33 | if [ $QUERY == "False" ]; then 34 | DATA_FOLDER_NAME=pytorch_qmdscnn 35 | else 36 | DATA_FOLDER_NAME=pytorch_qmdscnn_query 37 | fi 38 | if [ -z ${DATA_PATH+x} ]; then 39 | DATA_PATH="data/qmdscnn/${DATA_FOLDER_NAME}/CNNDM" 40 | fi 41 | if [ -z ${VOCAB_PATH+x} ]; then 42 | VOCAB_PATH="data/qmdscnn/${DATA_FOLDER_NAME}/spm.model" 43 | fi 44 | ;; 45 | WIKI) 46 | TRUNC_TGT_NTOKEN=400 47 | TRUNC_SRC_NTOKEN=100 48 | TRUNC_SRC_NBLOCK=40 49 | MAX_LENGTH=400 50 | MIN_LENGTH=200 51 | EXTRA="-alpha 0.4" 52 | if [ $QUERY == "False" ]; then 53 | DATA_FOLDER_NAME=ranked_wiki_b40 54 | else 55 | DATA_FOLDER_NAME=ranked_wiki_b40_query 56 | fi 57 | if [ -z ${DATA_PATH+x} ]; then 58 | DATA_PATH="data/wikisum/${DATA_FOLDER_NAME}/WIKI" 59 | 
fi 60 | if [ -z ${VOCAB_PATH+x} ]; then 61 | VOCAB_PATH="data/wikisum/${DATA_FOLDER_NAME}/spm9998_3.model" 62 | fi 63 | ;; 64 | *) 65 | echo "Invalid option: ${DATASET}" 66 | 67 | esac 68 | 69 | # If model path not set 70 | if [ -z ${MODEL_PATH+x} ]; then 71 | MODEL_PATH="results/model-${DATASET}-${MODEL_TYPE}" 72 | fi 73 | 74 | 75 | python src/train_abstractive.py \ 76 | -mode test \ 77 | -batch_size $BATCH_SIZE \ 78 | -trunc_tgt_ntoken $TRUNC_TGT_NTOKEN \ 79 | -trunc_src_ntoken $TRUNC_SRC_NTOKEN \ 80 | -trunc_src_nblock $TRUNC_SRC_NBLOCK \ 81 | -visible_gpus $VISIBLE_GPUS \ 82 | -gpu_ranks $GPU_RANKS \ 83 | -world_size $WORLD_SIZE \ 84 | -dataset $DATASET \ 85 | -model_type $MODEL_TYPE \ 86 | -query $QUERY \ 87 | -max_samples $MAX_SAMPLES \ 88 | -data_path $DATA_PATH \ 89 | -vocab_path $VOCAB_PATH \ 90 | -test_from $MODEL_PATH \ 91 | -result_path $MODEL_PATH/outputs \ 92 | -report_rouge \ 93 | -max_length $MAX_LENGTH \ 94 | -min_length $MIN_LENGTH \ 95 | $EXTRA 96 | --------------------------------------------------------------------------------