├── LICENSE ├── README.md ├── bert_data └── .gitignore ├── json_data └── cnndm_sample.train.0.json ├── logs └── .gitignore ├── models └── .gitignore ├── raw_data └── .gitignore ├── requirements.txt ├── results └── .gitignore ├── src ├── cal_rouge.py ├── distributed.py ├── models │ ├── __init__.py │ ├── adam.py │ ├── data_loader.py │ ├── decoder.py │ ├── encoder.py │ ├── loss.py │ ├── model_builder.py │ ├── neural.py │ ├── optimizers.py │ ├── predictor.py │ ├── reporter.py │ ├── reporter_ext.py │ ├── trainer.py │ └── trainer_ext.py ├── others │ ├── __init__.py │ ├── logging.py │ ├── pyrouge.py │ ├── tokenization.py │ └── utils.py ├── post_stats.py ├── prepro │ ├── __init__.py │ ├── data_builder.py │ ├── smart_common_words.txt │ └── utils.py ├── preprocess.py ├── train.py ├── train_abstractive.py ├── train_extractive.py └── translate │ ├── __init__.py │ ├── beam.py │ └── penalties.py └── urls ├── cnn_mapping_test.txt ├── cnn_mapping_train.txt ├── cnn_mapping_valid.txt ├── mapping_test.txt ├── mapping_train.txt └── mapping_valid.txt /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 Yang Liu 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # PreSumm 2 | 3 | **This code is for the EMNLP 2019 paper [Text Summarization with Pretrained Encoders](https://arxiv.org/abs/1908.08345)** 4 | 5 | **Updates Jan 22 2020**: Now you can **summarize raw text input!** Switch to the dev branch and use `-mode test_text` with `-text_src $RAW_SRC.TXT` to supply your text file. Please still use the master branch for normal training and evaluation; the dev branch should only be used for test_text mode. 6 | * For abstractive summarization use `-task abs`; for extractive summarization use `-task ext`. 7 | * Use `-test_from $PT_FILE$` to point to your model checkpoint file. 8 | * Format of the source text file: 9 | * For **abstractive summarization**, each line is a document. 10 | * If you want to do **extractive summarization**, please insert ` [CLS] [SEP] ` as your sentence boundaries. 11 | * There are example input files in the [raw_data directory](https://github.com/nlpyang/PreSumm/tree/dev/raw_data). 12 | * If you also have reference summaries aligned with your source input, please use `-text_tgt $RAW_TGT.TXT` to keep the order for evaluation.
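As an illustration of the expected input format, a line of `$RAW_SRC.TXT` might look like the hypothetical sketch below (the sentences are made up; the example files in the dev branch's `raw_data` directory are the authoritative reference):

```
# -task abs: one plain-text document per line
the quick brown fox jumped over the lazy dog . it was never seen again .

# -task ext: the same document with ` [CLS] [SEP] ` inserted between sentences
the quick brown fox jumped over the lazy dog . [CLS] [SEP] it was never seen again .
```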
Results on CNN/DailyMail (20/8/2019):

| Models | ROUGE-1 | ROUGE-2 | ROUGE-L |
| --- | --- | --- | --- |
| **Extractive** | | | |
| TransformerExt | 40.90 | 18.02 | 37.17 |
| BertSumExt | 43.23 | 20.24 | 39.63 |
| BertSumExt (large) | 43.85 | 20.34 | 39.90 |
| **Abstractive** | | | |
| TransformerAbs | 40.21 | 17.76 | 37.09 |
| BertSumAbs | 41.72 | 19.39 | 38.76 |
| BertSumExtAbs | 42.13 | 19.60 | 39.18 |
68 | 69 | **Python version**: This code is written for Python 3.6. 70 | 71 | **Package Requirements**: torch==1.1.0 pytorch_transformers tensorboardX multiprocess pyrouge 72 | 73 | 74 | 75 | **Updates**: To encode a text longer than 512 tokens (for example, 800), set `max_pos` to 800 during both preprocessing and training. 76 | 77 | 78 | Some code is borrowed from ONMT (https://github.com/OpenNMT/OpenNMT-py). 79 | 80 | ## Trained Models 81 | [CNN/DM BertExt](https://drive.google.com/open?id=1kKWoV0QCbeIuFt85beQgJ4v0lujaXobJ) 82 | 83 | [CNN/DM BertExtAbs](https://drive.google.com/open?id=1-IKVCtc4Q-BdZpjXc4s70_fRsWnjtYLr) 84 | 85 | [CNN/DM TransformerAbs](https://drive.google.com/open?id=1yLCqT__ilQ3mf5YUUCw9-UToesX5Roxy) 86 | 87 | [XSum BertExtAbs](https://drive.google.com/open?id=1H50fClyTkNprWJNh10HWdGEdDdQIkzsI) 88 | 89 | ## System Outputs 90 | 91 | [CNN/DM and XSum](https://drive.google.com/file/d/1kYA384UEAQkvmZ-yWZAfxw7htCbCwFzC) 92 | 93 | ## Data Preparation For XSum 94 | [Pre-processed data](https://drive.google.com/open?id=1BWBN1coTWGBqrWoOfRc5dhojPHhatbYs) 95 | 96 | 97 | ## Data Preparation For CNN/DailyMail 98 | ### Option 1: download the processed data 99 | 100 | [Pre-processed data](https://drive.google.com/open?id=1DN7ClZCCXsk2KegmC6t4ClBwtAf5galI) 101 | 102 | Unzip the zip file and put all `.pt` files into `bert_data`. 103 | 104 | ### Option 2: process the data yourself 105 | 106 | #### Step 1. Download Stories 107 | Download and unzip the `stories` directories from [here](http://cs.nyu.edu/~kcho/DMQA/) for both CNN and Daily Mail. Put all `.story` files in one directory (e.g. `../raw_stories`). 108 | 109 | #### Step 2. Download Stanford CoreNLP 110 | We will need Stanford CoreNLP to tokenize the data. Download it [here](https://stanfordnlp.github.io/CoreNLP/) and unzip it. Then add the following line to your bash_profile: 111 | ``` 112 | export CLASSPATH=/path/to/stanford-corenlp-full-2017-06-09/stanford-corenlp-3.8.0.jar 113 | ``` 114 | replacing `/path/to/` with the path to where you saved the `stanford-corenlp-full-2017-06-09` directory. 115 | 116 | #### Step 3. Sentence Splitting and Tokenization 117 | 118 | ``` 119 | python preprocess.py -mode tokenize -raw_path RAW_PATH -save_path TOKENIZED_PATH 120 | ``` 121 | 122 | * `RAW_PATH` is the directory containing story files (`../raw_stories`); `TOKENIZED_PATH` is the target directory to save the tokenized files (`../merged_stories_tokenized`). 123 | 124 | 125 | #### Step 4. Format to Simpler Json Files 126 | 127 | ``` 128 | python preprocess.py -mode format_to_lines -raw_path RAW_PATH -save_path JSON_PATH -n_cpus 1 -use_bert_basic_tokenizer false -map_path MAP_PATH 129 | ``` 130 | 131 | * `RAW_PATH` is the directory containing tokenized files (`../merged_stories_tokenized`); `JSON_PATH` is the target directory to save the generated json files (`../json_data/cnndm`); `MAP_PATH` is the directory containing the urls files (`../urls`). 132 | 133 | #### Step 5. Format to PyTorch Files 134 | ``` 135 | python preprocess.py -mode format_to_bert -raw_path JSON_PATH -save_path BERT_DATA_PATH -lower -n_cpus 1 -log_file ../logs/preprocess.log 136 | ``` 137 | 138 | * `JSON_PATH` is the directory containing json files (`../json_data`); `BERT_DATA_PATH` is the target directory to save the generated binary files (`../bert_data`). 139 | 140 | ## Model Training 141 | 142 | **First run: the first time you train, use a single GPU so the code can download the BERT model.
Use ``-visible_gpus -1``; after the download finishes, you can kill the process and rerun the code with multiple GPUs.** 143 | 144 | ### Extractive Setting 145 | 146 | ``` 147 | python train.py -task ext -mode train -bert_data_path BERT_DATA_PATH -ext_dropout 0.1 -model_path MODEL_PATH -lr 2e-3 -visible_gpus 0,1,2 -report_every 50 -save_checkpoint_steps 1000 -batch_size 3000 -train_steps 50000 -accum_count 2 -log_file ../logs/ext_bert_cnndm -use_interval true -warmup_steps 10000 -max_pos 512 148 | ``` 149 | 150 | ### Abstractive Setting 151 | 152 | #### TransformerAbs (baseline) 153 | ``` 154 | python train.py -mode train -accum_count 5 -batch_size 300 -bert_data_path BERT_DATA_PATH -dec_dropout 0.1 -log_file ../../logs/cnndm_baseline -lr 0.05 -model_path MODEL_PATH -save_checkpoint_steps 2000 -seed 777 -sep_optim false -train_steps 200000 -use_bert_emb true -use_interval true -warmup_steps 8000 -visible_gpus 0,1,2,3 -max_pos 512 -report_every 50 -enc_hidden_size 512 -enc_layers 6 -enc_ff_size 2048 -enc_dropout 0.1 -dec_layers 6 -dec_hidden_size 512 -dec_ff_size 2048 -encoder baseline -task abs 155 | ``` 156 | #### BertAbs 157 | ``` 158 | python train.py -task abs -mode train -bert_data_path BERT_DATA_PATH -dec_dropout 0.2 -model_path MODEL_PATH -sep_optim true -lr_bert 0.002 -lr_dec 0.2 -save_checkpoint_steps 2000 -batch_size 140 -train_steps 200000 -report_every 50 -accum_count 5 -use_bert_emb true -use_interval true -warmup_steps_bert 20000 -warmup_steps_dec 10000 -max_pos 512 -visible_gpus 0,1,2,3 -log_file ../logs/abs_bert_cnndm 159 | ``` 160 | #### BertExtAbs 161 | ``` 162 | python train.py -task abs -mode train -bert_data_path BERT_DATA_PATH -dec_dropout 0.2 -model_path MODEL_PATH -sep_optim true -lr_bert 0.002 -lr_dec 0.2 -save_checkpoint_steps 2000 -batch_size 140 -train_steps 200000 -report_every 50 -accum_count 5 -use_bert_emb true -use_interval true -warmup_steps_bert 20000 -warmup_steps_dec 10000 -max_pos 512 -visible_gpus 0,1,2,3 -log_file ../logs/abs_bert_cnndm -load_from_extractive EXT_CKPT 163 | ``` 164 | * `EXT_CKPT` is the saved `.pt` checkpoint of the extractive model (see the sketch below for a quick way to inspect a checkpoint).
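If you are not sure which saved `.pt` file to pass as `EXT_CKPT` (or to `-train_from` / `-test_from`), a quick sanity check is to open it and look at what it stores. The snippet below is only an illustrative sketch: the file name is hypothetical, and the keys shown (`model`, `optims`) are the ones read by `src/models/model_builder.py`; verify against your own checkpoint.

```python
import torch

# Hypothetical checkpoint name produced by -save_checkpoint_steps
ckpt = torch.load("../models/model_step_50000.pt", map_location="cpu")

# A PreSumm checkpoint is a plain dict: 'model' holds the state dict that
# load_state_dict() consumes, and 'optims' holds the optimizer states.
print(ckpt.keys())
print(list(ckpt["model"].keys())[:5])  # e.g. parameters under bert.model.* and ext_layer.*
```

Parameter names starting with `bert.model` belong to the BERT encoder; the remaining names belong to the extractive/abstractive layers on top, which is also how `build_optim_bert` and `build_optim_dec` split the parameters when `-sep_optim true` is used.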
165 | 166 | 167 | 168 | 169 | ## Model Evaluation 170 | ### CNN/DM 171 | ``` 172 | python train.py -task abs -mode validate -batch_size 3000 -test_batch_size 500 -bert_data_path BERT_DATA_PATH -log_file ../logs/val_abs_bert_cnndm -model_path MODEL_PATH -sep_optim true -use_interval true -visible_gpus 1 -max_pos 512 -max_length 200 -alpha 0.95 -min_length 50 -result_path ../logs/abs_bert_cnndm 173 | ``` 174 | ### XSum 175 | ``` 176 | python train.py -task abs -mode validate -batch_size 3000 -test_batch_size 500 -bert_data_path BERT_DATA_PATH -log_file ../logs/val_abs_bert_cnndm -model_path MODEL_PATH -sep_optim true -use_interval true -visible_gpus 1 -max_pos 512 -min_length 20 -max_length 100 -alpha 0.9 -result_path ../logs/abs_bert_cnndm 177 | ``` 178 | * `-mode` can be {`validate`, `test`}: `validate` will inspect the model directory and evaluate the model for each newly saved checkpoint, while `test` needs to be used with `-test_from`, indicating the checkpoint you want to use. 179 | * `MODEL_PATH` is the directory of saved checkpoints. 180 | * If you use `-mode validate` with `-test_all`, the system will load all saved checkpoints and select the top ones to generate summaries (this will take a while). 181 | 182 | -------------------------------------------------------------------------------- /bert_data/.gitignore: -------------------------------------------------------------------------------- 1 | * 2 | !.gitignore -------------------------------------------------------------------------------- /logs/.gitignore: -------------------------------------------------------------------------------- 1 | * 2 | !.gitignore -------------------------------------------------------------------------------- /models/.gitignore: -------------------------------------------------------------------------------- 1 | * 2 | !.gitignore -------------------------------------------------------------------------------- /raw_data/.gitignore: -------------------------------------------------------------------------------- 1 | * 2 | !.gitignore -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | multiprocess==0.70.9 2 | numpy==1.17.2 3 | pyrouge==0.1.3 4 | pytorch-transformers==1.2.0 5 | tensorboardX==1.9 6 | torch==1.1.0 -------------------------------------------------------------------------------- /results/.gitignore: -------------------------------------------------------------------------------- 1 | * 2 | !.gitignore -------------------------------------------------------------------------------- /src/cal_rouge.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import time 4 | # from multiprocess import Pool as Pool2 5 | from multiprocessing import Pool 6 | 7 | import shutil 8 | import sys 9 | import codecs 10 | 11 | # from onmt.utils.logging import init_logger, logger 12 | from others import pyrouge 13 | 14 | 15 | def process(data): 16 | candidates, references, pool_id = data 17 | cnt = len(candidates) 18 | current_time = time.strftime('%Y-%m-%d-%H-%M-%S', time.localtime()) 19 | tmp_dir = "rouge-tmp-{}-{}".format(current_time,pool_id) 20 | if not os.path.isdir(tmp_dir): 21 | os.mkdir(tmp_dir) 22 | os.mkdir(tmp_dir + "/candidate") 23 | os.mkdir(tmp_dir + "/reference") 24 | try: 25 | 26 | for i in range(cnt): 27 | if len(references[i]) < 1: 28 | continue 29 | with open(tmp_dir + "/candidate/cand.{}.txt".format(i), "w",
30 | encoding="utf-8") as f: 31 | f.write(candidates[i]) 32 | with open(tmp_dir + "/reference/ref.{}.txt".format(i), "w", 33 | encoding="utf-8") as f: 34 | f.write(references[i]) 35 | r = pyrouge.Rouge155() 36 | r.model_dir = tmp_dir + "/reference/" 37 | r.system_dir = tmp_dir + "/candidate/" 38 | r.model_filename_pattern = 'ref.#ID#.txt' 39 | r.system_filename_pattern = r'cand.(\d+).txt' 40 | rouge_results = r.convert_and_evaluate() 41 | print(rouge_results) 42 | results_dict = r.output_to_dict(rouge_results) 43 | finally: 44 | pass 45 | if os.path.isdir(tmp_dir): 46 | shutil.rmtree(tmp_dir) 47 | return results_dict 48 | 49 | 50 | 51 | 52 | def chunks(l, n): 53 | """Yield successive n-sized chunks from l.""" 54 | for i in range(0, len(l), n): 55 | yield l[i:i + n] 56 | 57 | def test_rouge(cand, ref,num_processes): 58 | """Calculate ROUGE scores of sequences passed as an iterator 59 | e.g. a list of str, an open file, StringIO or even sys.stdin 60 | """ 61 | candidates = [line.strip() for line in cand] 62 | references = [line.strip() for line in ref] 63 | 64 | print(len(candidates)) 65 | print(len(references)) 66 | assert len(candidates) == len(references) 67 | candidates_chunks = list(chunks(candidates, int(len(candidates)/num_processes))) 68 | references_chunks = list(chunks(references, int(len(references)/num_processes))) 69 | n_pool = len(candidates_chunks) 70 | arg_lst = [] 71 | for i in range(n_pool): 72 | arg_lst.append((candidates_chunks[i],references_chunks[i],i)) 73 | pool = Pool(n_pool) 74 | results = pool.map(process,arg_lst) 75 | final_results = {} 76 | for i,r in enumerate(results): 77 | for k in r: 78 | if(k not in final_results): 79 | final_results[k] = r[k]*len(candidates_chunks[i]) 80 | else: 81 | final_results[k] += r[k] * len(candidates_chunks[i]) 82 | for k in final_results: 83 | final_results[k] = final_results[k]/len(candidates) 84 | return final_results 85 | def rouge_results_to_str(results_dict): 86 | return ">> ROUGE-F(1/2/3/l): {:.2f}/{:.2f}/{:.2f}\nROUGE-R(1/2/3/l): {:.2f}/{:.2f}/{:.2f}\n".format( 87 | results_dict["rouge_1_f_score"] * 100, 88 | results_dict["rouge_2_f_score"] * 100, 89 | # results_dict["rouge_3_f_score"] * 100, 90 | results_dict["rouge_l_f_score"] * 100, 91 | results_dict["rouge_1_recall"] * 100, 92 | results_dict["rouge_2_recall"] * 100, 93 | # results_dict["rouge_3_f_score"] * 100, 94 | results_dict["rouge_l_recall"] * 100 95 | 96 | # ,results_dict["rouge_su*_f_score"] * 100 97 | ) 98 | 99 | 100 | if __name__ == "__main__": 101 | # init_logger('test_rouge.log') 102 | parser = argparse.ArgumentParser() 103 | parser.add_argument('-c', type=str, default="candidate.txt", 104 | help='candidate file') 105 | parser.add_argument('-r', type=str, default="reference.txt", 106 | help='reference file') 107 | parser.add_argument('-p', type=int, default=1, 108 | help='number of processes') 109 | args = parser.parse_args() 110 | print(args.c) 111 | print(args.r) 112 | print(args.p) 113 | if args.c.upper() == "STDIN": 114 | candidates = sys.stdin 115 | else: 116 | candidates = codecs.open(args.c, encoding="utf-8") 117 | references = codecs.open(args.r, encoding="utf-8") 118 | 119 | results_dict = test_rouge(candidates, references,args.p) 120 | # return 0 121 | print(time.strftime('%H:%M:%S', time.localtime()) 122 | ) 123 | print(rouge_results_to_str(results_dict)) 124 | # logger.info(rouge_results_to_str(results_dict)) -------------------------------------------------------------------------------- /src/distributed.py: 
-------------------------------------------------------------------------------- 1 | """ Pytorch Distributed utils 2 | This piece of code was heavily inspired by the equivalent of Fairseq-py 3 | https://github.com/pytorch/fairseq 4 | """ 5 | 6 | 7 | from __future__ import print_function 8 | 9 | import math 10 | import pickle 11 | 12 | import torch.distributed 13 | 14 | from others.logging import logger 15 | 16 | 17 | def is_master(gpu_ranks, device_id): 18 | return gpu_ranks[device_id] == 0 19 | 20 | 21 | def multi_init(device_id, world_size,gpu_ranks): 22 | print(gpu_ranks) 23 | dist_init_method = 'tcp://localhost:10000' 24 | dist_world_size = world_size 25 | torch.distributed.init_process_group( 26 | backend='nccl', init_method=dist_init_method, 27 | world_size=dist_world_size, rank=gpu_ranks[device_id]) 28 | gpu_rank = torch.distributed.get_rank() 29 | if not is_master(gpu_ranks, device_id): 30 | # print('not master') 31 | logger.disabled = True 32 | 33 | return gpu_rank 34 | 35 | 36 | 37 | def all_reduce_and_rescale_tensors(tensors, rescale_denom, 38 | buffer_size=10485760): 39 | """All-reduce and rescale tensors in chunks of the specified size. 40 | 41 | Args: 42 | tensors: list of Tensors to all-reduce 43 | rescale_denom: denominator for rescaling summed Tensors 44 | buffer_size: all-reduce chunk size in bytes 45 | """ 46 | # buffer size in bytes, determine equiv. # of elements based on data type 47 | buffer_t = tensors[0].new( 48 | math.ceil(buffer_size / tensors[0].element_size())).zero_() 49 | buffer = [] 50 | 51 | def all_reduce_buffer(): 52 | # copy tensors into buffer_t 53 | offset = 0 54 | for t in buffer: 55 | numel = t.numel() 56 | buffer_t[offset:offset+numel].copy_(t.view(-1)) 57 | offset += numel 58 | 59 | # all-reduce and rescale 60 | torch.distributed.all_reduce(buffer_t[:offset]) 61 | buffer_t.div_(rescale_denom) 62 | 63 | # copy all-reduced buffer back into tensors 64 | offset = 0 65 | for t in buffer: 66 | numel = t.numel() 67 | t.view(-1).copy_(buffer_t[offset:offset+numel]) 68 | offset += numel 69 | 70 | filled = 0 71 | for t in tensors: 72 | sz = t.numel() * t.element_size() 73 | if sz > buffer_size: 74 | # tensor is bigger than buffer, all-reduce and rescale directly 75 | torch.distributed.all_reduce(t) 76 | t.div_(rescale_denom) 77 | elif filled + sz > buffer_size: 78 | # buffer is full, all-reduce and replace buffer with grad 79 | all_reduce_buffer() 80 | buffer = [t] 81 | filled = sz 82 | else: 83 | # add tensor to buffer 84 | buffer.append(t) 85 | filled += sz 86 | 87 | if len(buffer) > 0: 88 | all_reduce_buffer() 89 | 90 | 91 | def all_gather_list(data, max_size=4096): 92 | """Gathers arbitrary data from all nodes into a list.""" 93 | world_size = torch.distributed.get_world_size() 94 | if not hasattr(all_gather_list, '_in_buffer') or \ 95 | max_size != all_gather_list._in_buffer.size(): 96 | all_gather_list._in_buffer = torch.cuda.ByteTensor(max_size) 97 | all_gather_list._out_buffers = [ 98 | torch.cuda.ByteTensor(max_size) 99 | for i in range(world_size) 100 | ] 101 | in_buffer = all_gather_list._in_buffer 102 | out_buffers = all_gather_list._out_buffers 103 | 104 | enc = pickle.dumps(data) 105 | enc_size = len(enc) 106 | if enc_size + 2 > max_size: 107 | raise ValueError( 108 | 'encoded data exceeds max_size: {}'.format(enc_size + 2)) 109 | assert max_size < 255*256 110 | in_buffer[0] = enc_size // 255 # this encoding works for max_size < 65k 111 | in_buffer[1] = enc_size % 255 112 | in_buffer[2:enc_size+2] = torch.ByteTensor(list(enc)) 113 | 114 | 
torch.distributed.all_gather(out_buffers, in_buffer.cuda()) 115 | 116 | results = [] 117 | for i in range(world_size): 118 | out_buffer = out_buffers[i] 119 | size = (255 * out_buffer[0].item()) + out_buffer[1].item() 120 | 121 | bytes_list = bytes(out_buffer[2:size+2].tolist()) 122 | result = pickle.loads(bytes_list) 123 | results.append(result) 124 | return results 125 | -------------------------------------------------------------------------------- /src/models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nlpyang/PreSumm/70b810e0f06d179022958dd35c1a3385fe87f28c/src/models/__init__.py -------------------------------------------------------------------------------- /src/models/adam.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | from torch.optim.optimizer import Optimizer 4 | 5 | 6 | class Adam(Optimizer): 7 | r"""Implements Adam algorithm. 8 | 9 | It has been proposed in `Adam: A Method for Stochastic Optimization`_. 10 | 11 | Arguments: 12 | params (iterable): iterable of parameters to optimize or dicts defining 13 | parameter groups 14 | lr (float, optional): learning rate (default: 1e-3) 15 | betas (Tuple[float, float], optional): coefficients used for computing 16 | running averages of gradient and its square (default: (0.9, 0.999)) 17 | eps (float, optional): term added to the denominator to improve 18 | numerical stability (default: 1e-8) 19 | weight_decay (float, optional): weight decay (L2 penalty) (default: 0) 20 | amsgrad (boolean, optional): whether to use the AMSGrad variant of this 21 | algorithm from the paper `On the Convergence of Adam and Beyond`_ 22 | (default: False) 23 | 24 | .. _Adam\: A Method for Stochastic Optimization: 25 | https://arxiv.org/abs/1412.6980 26 | .. _On the Convergence of Adam and Beyond: 27 | https://openreview.net/forum?id=ryQu7f-RZ 28 | """ 29 | 30 | def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, 31 | weight_decay=0, amsgrad=False): 32 | if not 0.0 <= lr: 33 | raise ValueError("Invalid learning rate: {}".format(lr)) 34 | if not 0.0 <= eps: 35 | raise ValueError("Invalid epsilon value: {}".format(eps)) 36 | if not 0.0 <= betas[0] < 1.0: 37 | raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) 38 | if not 0.0 <= betas[1] < 1.0: 39 | raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) 40 | defaults = dict(lr=lr, betas=betas, eps=eps, 41 | weight_decay=weight_decay, amsgrad=amsgrad) 42 | super(Adam, self).__init__(params, defaults) 43 | 44 | def __setstate__(self, state): 45 | super(Adam, self).__setstate__(state) 46 | for group in self.param_groups: 47 | group.setdefault('amsgrad', False) 48 | 49 | def step(self, closure=None): 50 | """Performs a single optimization step. 51 | Arguments: 52 | closure (callable, optional): A closure that reevaluates the model 53 | and returns the loss. 
54 | """ 55 | loss = None 56 | if closure is not None: 57 | loss = closure() 58 | 59 | 60 | for group in self.param_groups: 61 | for p in group['params']: 62 | if p.grad is None: 63 | continue 64 | grad = p.grad.data 65 | if grad.is_sparse: 66 | raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead') 67 | 68 | state = self.state[p] 69 | 70 | # State initialization 71 | if len(state) == 0: 72 | state['step'] = 0 73 | # Exponential moving average of gradient values 74 | state['next_m'] = torch.zeros_like(p.data) 75 | # Exponential moving average of squared gradient values 76 | state['next_v'] = torch.zeros_like(p.data) 77 | 78 | next_m, next_v = state['next_m'], state['next_v'] 79 | beta1, beta2 = group['betas'] 80 | 81 | # Decay the first and second moment running average coefficient 82 | # In-place operations to update the averages at the same time 83 | next_m.mul_(beta1).add_(1 - beta1, grad) 84 | next_v.mul_(beta2).addcmul_(1 - beta2, grad, grad) 85 | update = next_m / (next_v.sqrt() + group['eps']) 86 | 87 | # Just adding the square of the weights to the loss function is *not* 88 | # the correct way of using L2 regularization/weight decay with Adam, 89 | # since that will interact with the m and v parameters in strange ways. 90 | # 91 | # Instead we want to decay the weights in a manner that doesn't interact 92 | # with the m/v parameters. This is equivalent to adding the square 93 | # of the weights to the loss with plain (non-momentum) SGD. 94 | if group['weight_decay'] > 0.0: 95 | update += group['weight_decay'] * p.data 96 | 97 | lr_scheduled = group['lr'] 98 | 99 | update_with_lr = lr_scheduled * update 100 | p.data.add_(-update_with_lr) 101 | 102 | state['step'] += 1 103 | 104 | # step_size = lr_scheduled * math.sqrt(bias_correction2) / bias_correction1 105 | # No bias correction 106 | # bias_correction1 = 1 - beta1 ** state['step'] 107 | # bias_correction2 = 1 - beta2 ** state['step'] 108 | 109 | return loss -------------------------------------------------------------------------------- /src/models/data_loader.py: -------------------------------------------------------------------------------- 1 | import bisect 2 | import gc 3 | import glob 4 | import random 5 | 6 | import torch 7 | 8 | from others.logging import logger 9 | 10 | 11 | 12 | class Batch(object): 13 | def _pad(self, data, pad_id, width=-1): 14 | if (width == -1): 15 | width = max(len(d) for d in data) 16 | rtn_data = [d + [pad_id] * (width - len(d)) for d in data] 17 | return rtn_data 18 | 19 | def __init__(self, data=None, device=None, is_test=False): 20 | """Create a Batch from a list of examples.""" 21 | if data is not None: 22 | self.batch_size = len(data) 23 | pre_src = [x[0] for x in data] 24 | pre_tgt = [x[1] for x in data] 25 | pre_segs = [x[2] for x in data] 26 | pre_clss = [x[3] for x in data] 27 | pre_src_sent_labels = [x[4] for x in data] 28 | 29 | src = torch.tensor(self._pad(pre_src, 0)) 30 | tgt = torch.tensor(self._pad(pre_tgt, 0)) 31 | 32 | segs = torch.tensor(self._pad(pre_segs, 0)) 33 | mask_src = 1 - (src == 0) 34 | mask_tgt = 1 - (tgt == 0) 35 | 36 | 37 | clss = torch.tensor(self._pad(pre_clss, -1)) 38 | src_sent_labels = torch.tensor(self._pad(pre_src_sent_labels, 0)) 39 | mask_cls = 1 - (clss == -1) 40 | clss[clss == -1] = 0 41 | setattr(self, 'clss', clss.to(device)) 42 | setattr(self, 'mask_cls', mask_cls.to(device)) 43 | setattr(self, 'src_sent_labels', src_sent_labels.to(device)) 44 | 45 | 46 | setattr(self, 'src', src.to(device)) 47 | setattr(self, 
'tgt', tgt.to(device)) 48 | setattr(self, 'segs', segs.to(device)) 49 | setattr(self, 'mask_src', mask_src.to(device)) 50 | setattr(self, 'mask_tgt', mask_tgt.to(device)) 51 | 52 | 53 | if (is_test): 54 | src_str = [x[-2] for x in data] 55 | setattr(self, 'src_str', src_str) 56 | tgt_str = [x[-1] for x in data] 57 | setattr(self, 'tgt_str', tgt_str) 58 | 59 | def __len__(self): 60 | return self.batch_size 61 | 62 | 63 | 64 | 65 | def load_dataset(args, corpus_type, shuffle): 66 | """ 67 | Dataset generator. Don't do extra stuff here, like printing, 68 | because they will be postponed to the first loading time. 69 | 70 | Args: 71 | corpus_type: 'train' or 'valid' 72 | Returns: 73 | A list of dataset, the dataset(s) are lazily loaded. 74 | """ 75 | assert corpus_type in ["train", "valid", "test"] 76 | 77 | def _lazy_dataset_loader(pt_file, corpus_type): 78 | dataset = torch.load(pt_file) 79 | logger.info('Loading %s dataset from %s, number of examples: %d' % 80 | (corpus_type, pt_file, len(dataset))) 81 | return dataset 82 | 83 | # Sort the glob output by file name (by increasing indexes). 84 | pts = sorted(glob.glob(args.bert_data_path + '.' + corpus_type + '.[0-9]*.pt')) 85 | if pts: 86 | if (shuffle): 87 | random.shuffle(pts) 88 | 89 | for pt in pts: 90 | yield _lazy_dataset_loader(pt, corpus_type) 91 | else: 92 | # Only one inputters.*Dataset, simple! 93 | pt = args.bert_data_path + '.' + corpus_type + '.pt' 94 | yield _lazy_dataset_loader(pt, corpus_type) 95 | 96 | 97 | def abs_batch_size_fn(new, count): 98 | src, tgt = new[0], new[1] 99 | global max_n_sents, max_n_tokens, max_size 100 | if count == 1: 101 | max_size = 0 102 | max_n_sents=0 103 | max_n_tokens=0 104 | max_n_sents = max(max_n_sents, len(tgt)) 105 | max_size = max(max_size, max_n_sents) 106 | src_elements = count * max_size 107 | if (count > 6): 108 | return src_elements + 1e3 109 | return src_elements 110 | 111 | 112 | def ext_batch_size_fn(new, count): 113 | if (len(new) == 4): 114 | pass 115 | src, labels = new[0], new[4] 116 | global max_n_sents, max_n_tokens, max_size 117 | if count == 1: 118 | max_size = 0 119 | max_n_sents = 0 120 | max_n_tokens = 0 121 | max_n_sents = max(max_n_sents, len(src)) 122 | max_size = max(max_size, max_n_sents) 123 | src_elements = count * max_size 124 | return src_elements 125 | 126 | 127 | class Dataloader(object): 128 | def __init__(self, args, datasets, batch_size, 129 | device, shuffle, is_test): 130 | self.args = args 131 | self.datasets = datasets 132 | self.batch_size = batch_size 133 | self.device = device 134 | self.shuffle = shuffle 135 | self.is_test = is_test 136 | self.cur_iter = self._next_dataset_iterator(datasets) 137 | assert self.cur_iter is not None 138 | 139 | def __iter__(self): 140 | dataset_iter = (d for d in self.datasets) 141 | while self.cur_iter is not None: 142 | for batch in self.cur_iter: 143 | yield batch 144 | self.cur_iter = self._next_dataset_iterator(dataset_iter) 145 | 146 | 147 | def _next_dataset_iterator(self, dataset_iter): 148 | try: 149 | # Drop the current dataset for decreasing memory 150 | if hasattr(self, "cur_dataset"): 151 | self.cur_dataset = None 152 | gc.collect() 153 | del self.cur_dataset 154 | gc.collect() 155 | 156 | self.cur_dataset = next(dataset_iter) 157 | except StopIteration: 158 | return None 159 | 160 | return DataIterator(args = self.args, 161 | dataset=self.cur_dataset, batch_size=self.batch_size, 162 | device=self.device, shuffle=self.shuffle, is_test=self.is_test) 163 | 164 | 165 | class DataIterator(object): 166 | def 
__init__(self, args, dataset, batch_size, device=None, is_test=False, 167 | shuffle=True): 168 | self.args = args 169 | self.batch_size, self.is_test, self.dataset = batch_size, is_test, dataset 170 | self.iterations = 0 171 | self.device = device 172 | self.shuffle = shuffle 173 | 174 | self.sort_key = lambda x: len(x[1]) 175 | 176 | self._iterations_this_epoch = 0 177 | if (self.args.task == 'abs'): 178 | self.batch_size_fn = abs_batch_size_fn 179 | else: 180 | self.batch_size_fn = ext_batch_size_fn 181 | 182 | def data(self): 183 | if self.shuffle: 184 | random.shuffle(self.dataset) 185 | xs = self.dataset 186 | return xs 187 | 188 | 189 | 190 | 191 | 192 | 193 | def preprocess(self, ex, is_test): 194 | src = ex['src'] 195 | tgt = ex['tgt'][:self.args.max_tgt_len][:-1]+[2] 196 | src_sent_labels = ex['src_sent_labels'] 197 | segs = ex['segs'] 198 | if(not self.args.use_interval): 199 | segs=[0]*len(segs) 200 | clss = ex['clss'] 201 | src_txt = ex['src_txt'] 202 | tgt_txt = ex['tgt_txt'] 203 | 204 | end_id = [src[-1]] 205 | src = src[:-1][:self.args.max_pos - 1] + end_id 206 | segs = segs[:self.args.max_pos] 207 | max_sent_id = bisect.bisect_left(clss, self.args.max_pos) 208 | src_sent_labels = src_sent_labels[:max_sent_id] 209 | clss = clss[:max_sent_id] 210 | # src_txt = src_txt[:max_sent_id] 211 | 212 | 213 | 214 | if(is_test): 215 | return src, tgt, segs, clss, src_sent_labels, src_txt, tgt_txt 216 | else: 217 | return src, tgt, segs, clss, src_sent_labels 218 | 219 | def batch_buffer(self, data, batch_size): 220 | minibatch, size_so_far = [], 0 221 | for ex in data: 222 | if(len(ex['src'])==0): 223 | continue 224 | ex = self.preprocess(ex, self.is_test) 225 | if(ex is None): 226 | continue 227 | minibatch.append(ex) 228 | size_so_far = self.batch_size_fn(ex, len(minibatch)) 229 | if size_so_far == batch_size: 230 | yield minibatch 231 | minibatch, size_so_far = [], 0 232 | elif size_so_far > batch_size: 233 | yield minibatch[:-1] 234 | minibatch, size_so_far = minibatch[-1:], self.batch_size_fn(ex, 1) 235 | if minibatch: 236 | yield minibatch 237 | 238 | def batch(self, data, batch_size): 239 | """Yield elements from data in chunks of batch_size.""" 240 | minibatch, size_so_far = [], 0 241 | for ex in data: 242 | minibatch.append(ex) 243 | size_so_far = self.batch_size_fn(ex, len(minibatch)) 244 | if size_so_far == batch_size: 245 | yield minibatch 246 | minibatch, size_so_far = [], 0 247 | elif size_so_far > batch_size: 248 | yield minibatch[:-1] 249 | minibatch, size_so_far = minibatch[-1:], self.batch_size_fn(ex, 1) 250 | if minibatch: 251 | yield minibatch 252 | 253 | def create_batches(self): 254 | """ Create batches """ 255 | data = self.data() 256 | for buffer in self.batch_buffer(data, self.batch_size * 300): 257 | 258 | if (self.args.task == 'abs'): 259 | p_batch = sorted(buffer, key=lambda x: len(x[2])) 260 | p_batch = sorted(p_batch, key=lambda x: len(x[1])) 261 | else: 262 | p_batch = sorted(buffer, key=lambda x: len(x[2])) 263 | 264 | p_batch = self.batch(p_batch, self.batch_size) 265 | 266 | 267 | p_batch = list(p_batch) 268 | if (self.shuffle): 269 | random.shuffle(p_batch) 270 | for b in p_batch: 271 | if(len(b)==0): 272 | continue 273 | yield b 274 | 275 | def __iter__(self): 276 | while True: 277 | self.batches = self.create_batches() 278 | for idx, minibatch in enumerate(self.batches): 279 | # fast-forward if loaded from state 280 | if self._iterations_this_epoch > idx: 281 | continue 282 | self.iterations += 1 283 | self._iterations_this_epoch += 1 284 | batch = 
Batch(minibatch, self.device, self.is_test) 285 | 286 | yield batch 287 | return 288 | 289 | 290 | class TextDataloader(object): 291 | def __init__(self, args, datasets, batch_size, 292 | device, shuffle, is_test): 293 | self.args = args 294 | self.batch_size = batch_size 295 | self.device = device 296 | 297 | def data(self): 298 | if self.shuffle: 299 | random.shuffle(self.dataset) 300 | xs = self.dataset 301 | return xs 302 | 303 | def preprocess(self, ex, is_test): 304 | src = ex['src'] 305 | tgt = ex['tgt'][:self.args.max_tgt_len][:-1] + [2] 306 | src_sent_labels = ex['src_sent_labels'] 307 | segs = ex['segs'] 308 | if (not self.args.use_interval): 309 | segs = [0] * len(segs) 310 | clss = ex['clss'] 311 | src_txt = ex['src_txt'] 312 | tgt_txt = ex['tgt_txt'] 313 | 314 | end_id = [src[-1]] 315 | src = src[:-1][:self.args.max_pos - 1] + end_id 316 | segs = segs[:self.args.max_pos] 317 | max_sent_id = bisect.bisect_left(clss, self.args.max_pos) 318 | src_sent_labels = src_sent_labels[:max_sent_id] 319 | clss = clss[:max_sent_id] 320 | # src_txt = src_txt[:max_sent_id] 321 | 322 | if (is_test): 323 | return src, tgt, segs, clss, src_sent_labels, src_txt, tgt_txt 324 | else: 325 | return src, tgt, segs, clss, src_sent_labels 326 | 327 | def batch_buffer(self, data, batch_size): 328 | minibatch, size_so_far = [], 0 329 | for ex in data: 330 | if (len(ex['src']) == 0): 331 | continue 332 | ex = self.preprocess(ex, self.is_test) 333 | if (ex is None): 334 | continue 335 | minibatch.append(ex) 336 | size_so_far = simple_batch_size_fn(ex, len(minibatch)) 337 | if size_so_far == batch_size: 338 | yield minibatch 339 | minibatch, size_so_far = [], 0 340 | elif size_so_far > batch_size: 341 | yield minibatch[:-1] 342 | minibatch, size_so_far = minibatch[-1:], simple_batch_size_fn(ex, 1) 343 | if minibatch: 344 | yield minibatch 345 | 346 | def create_batches(self): 347 | """ Create batches """ 348 | data = self.data() 349 | for buffer in self.batch_buffer(data, self.batch_size * 300): 350 | if (self.args.task == 'abs'): 351 | p_batch = sorted(buffer, key=lambda x: len(x[2])) 352 | p_batch = sorted(p_batch, key=lambda x: len(x[1])) 353 | else: 354 | p_batch = sorted(buffer, key=lambda x: len(x[2])) 355 | p_batch = batch(p_batch, self.batch_size) 356 | 357 | p_batch = batch(p_batch, self.batch_size) 358 | 359 | p_batch = list(p_batch) 360 | if (self.shuffle): 361 | random.shuffle(p_batch) 362 | for b in p_batch: 363 | if (len(b) == 0): 364 | continue 365 | yield b 366 | 367 | def __iter__(self): 368 | while True: 369 | self.batches = self.create_batches() 370 | for idx, minibatch in enumerate(self.batches): 371 | # fast-forward if loaded from state 372 | if self._iterations_this_epoch > idx: 373 | continue 374 | self.iterations += 1 375 | self._iterations_this_epoch += 1 376 | batch = Batch(minibatch, self.device, self.is_test) 377 | 378 | yield batch 379 | return 380 | -------------------------------------------------------------------------------- /src/models/decoder.py: -------------------------------------------------------------------------------- 1 | """ 2 | Implementation of "Attention is All You Need" 3 | """ 4 | 5 | import torch 6 | import torch.nn as nn 7 | import numpy as np 8 | 9 | from models.encoder import PositionalEncoding 10 | from models.neural import MultiHeadedAttention, PositionwiseFeedForward, DecoderState 11 | 12 | MAX_SIZE = 5000 13 | 14 | 15 | class TransformerDecoderLayer(nn.Module): 16 | """ 17 | Args: 18 | d_model (int): the dimension of keys/values/queries in 19 | 
MultiHeadedAttention, also the input size of 20 | the first-layer of the PositionwiseFeedForward. 21 | heads (int): the number of heads for MultiHeadedAttention. 22 | d_ff (int): the second-layer of the PositionwiseFeedForward. 23 | dropout (float): dropout probability(0-1.0). 24 | self_attn_type (string): type of self-attention scaled-dot, average 25 | """ 26 | 27 | def __init__(self, d_model, heads, d_ff, dropout): 28 | super(TransformerDecoderLayer, self).__init__() 29 | 30 | 31 | self.self_attn = MultiHeadedAttention( 32 | heads, d_model, dropout=dropout) 33 | 34 | self.context_attn = MultiHeadedAttention( 35 | heads, d_model, dropout=dropout) 36 | self.feed_forward = PositionwiseFeedForward(d_model, d_ff, dropout) 37 | self.layer_norm_1 = nn.LayerNorm(d_model, eps=1e-6) 38 | self.layer_norm_2 = nn.LayerNorm(d_model, eps=1e-6) 39 | self.drop = nn.Dropout(dropout) 40 | mask = self._get_attn_subsequent_mask(MAX_SIZE) 41 | # Register self.mask as a buffer in TransformerDecoderLayer, so 42 | # it gets TransformerDecoderLayer's cuda behavior automatically. 43 | self.register_buffer('mask', mask) 44 | 45 | def forward(self, inputs, memory_bank, src_pad_mask, tgt_pad_mask, 46 | previous_input=None, layer_cache=None, step=None): 47 | """ 48 | Args: 49 | inputs (`FloatTensor`): `[batch_size x 1 x model_dim]` 50 | memory_bank (`FloatTensor`): `[batch_size x src_len x model_dim]` 51 | src_pad_mask (`LongTensor`): `[batch_size x 1 x src_len]` 52 | tgt_pad_mask (`LongTensor`): `[batch_size x 1 x 1]` 53 | 54 | Returns: 55 | (`FloatTensor`, `FloatTensor`, `FloatTensor`): 56 | 57 | * output `[batch_size x 1 x model_dim]` 58 | * attn `[batch_size x 1 x src_len]` 59 | * all_input `[batch_size x current_step x model_dim]` 60 | 61 | """ 62 | dec_mask = torch.gt(tgt_pad_mask + 63 | self.mask[:, :tgt_pad_mask.size(1), 64 | :tgt_pad_mask.size(1)], 0) 65 | input_norm = self.layer_norm_1(inputs) 66 | all_input = input_norm 67 | if previous_input is not None: 68 | all_input = torch.cat((previous_input, input_norm), dim=1) 69 | dec_mask = None 70 | 71 | query = self.self_attn(all_input, all_input, input_norm, 72 | mask=dec_mask, 73 | layer_cache=layer_cache, 74 | type="self") 75 | 76 | query = self.drop(query) + inputs 77 | 78 | query_norm = self.layer_norm_2(query) 79 | mid = self.context_attn(memory_bank, memory_bank, query_norm, 80 | mask=src_pad_mask, 81 | layer_cache=layer_cache, 82 | type="context") 83 | output = self.feed_forward(self.drop(mid) + query) 84 | 85 | return output, all_input 86 | # return output 87 | 88 | def _get_attn_subsequent_mask(self, size): 89 | """ 90 | Get an attention mask to avoid using the subsequent info. 91 | 92 | Args: 93 | size: int 94 | 95 | Returns: 96 | (`LongTensor`): 97 | 98 | * subsequent_mask `[1 x size x size]` 99 | """ 100 | attn_shape = (1, size, size) 101 | subsequent_mask = np.triu(np.ones(attn_shape), k=1).astype('uint8') 102 | subsequent_mask = torch.from_numpy(subsequent_mask) 103 | return subsequent_mask 104 | 105 | 106 | 107 | class TransformerDecoder(nn.Module): 108 | """ 109 | The Transformer decoder from "Attention is All You Need". 110 | 111 | 112 | .. mermaid:: 113 | 114 | graph BT 115 | A[input] 116 | B[multi-head self-attn] 117 | BB[multi-head src-attn] 118 | C[feed forward] 119 | O[output] 120 | A --> B 121 | B --> BB 122 | BB --> C 123 | C --> O 124 | 125 | 126 | Args: 127 | num_layers (int): number of encoder layers. 
128 | d_model (int): size of the model 129 | heads (int): number of heads 130 | d_ff (int): size of the inner FF layer 131 | dropout (float): dropout parameters 132 | embeddings (:obj:`onmt.modules.Embeddings`): 133 | embeddings to use, should have positional encodings 134 | attn_type (str): if using a seperate copy attention 135 | """ 136 | 137 | def __init__(self, num_layers, d_model, heads, d_ff, dropout, embeddings): 138 | super(TransformerDecoder, self).__init__() 139 | 140 | # Basic attributes. 141 | self.decoder_type = 'transformer' 142 | self.num_layers = num_layers 143 | self.embeddings = embeddings 144 | self.pos_emb = PositionalEncoding(dropout,self.embeddings.embedding_dim) 145 | 146 | 147 | # Build TransformerDecoder. 148 | self.transformer_layers = nn.ModuleList( 149 | [TransformerDecoderLayer(d_model, heads, d_ff, dropout) 150 | for _ in range(num_layers)]) 151 | 152 | self.layer_norm = nn.LayerNorm(d_model, eps=1e-6) 153 | 154 | def forward(self, tgt, memory_bank, state, memory_lengths=None, 155 | step=None, cache=None,memory_masks=None): 156 | """ 157 | See :obj:`onmt.modules.RNNDecoderBase.forward()` 158 | """ 159 | 160 | src_words = state.src 161 | tgt_words = tgt 162 | src_batch, src_len = src_words.size() 163 | tgt_batch, tgt_len = tgt_words.size() 164 | 165 | # Run the forward pass of the TransformerDecoder. 166 | # emb = self.embeddings(tgt, step=step) 167 | emb = self.embeddings(tgt) 168 | assert emb.dim() == 3 # len x batch x embedding_dim 169 | 170 | output = self.pos_emb(emb, step) 171 | 172 | src_memory_bank = memory_bank 173 | padding_idx = self.embeddings.padding_idx 174 | tgt_pad_mask = tgt_words.data.eq(padding_idx).unsqueeze(1) \ 175 | .expand(tgt_batch, tgt_len, tgt_len) 176 | 177 | if (not memory_masks is None): 178 | src_len = memory_masks.size(-1) 179 | src_pad_mask = memory_masks.expand(src_batch, tgt_len, src_len) 180 | 181 | else: 182 | src_pad_mask = src_words.data.eq(padding_idx).unsqueeze(1) \ 183 | .expand(src_batch, tgt_len, src_len) 184 | 185 | if state.cache is None: 186 | saved_inputs = [] 187 | 188 | for i in range(self.num_layers): 189 | prev_layer_input = None 190 | if state.cache is None: 191 | if state.previous_input is not None: 192 | prev_layer_input = state.previous_layer_inputs[i] 193 | output, all_input \ 194 | = self.transformer_layers[i]( 195 | output, src_memory_bank, 196 | src_pad_mask, tgt_pad_mask, 197 | previous_input=prev_layer_input, 198 | layer_cache=state.cache["layer_{}".format(i)] 199 | if state.cache is not None else None, 200 | step=step) 201 | if state.cache is None: 202 | saved_inputs.append(all_input) 203 | 204 | if state.cache is None: 205 | saved_inputs = torch.stack(saved_inputs) 206 | 207 | output = self.layer_norm(output) 208 | 209 | # Process the result and update the attentions. 210 | 211 | if state.cache is None: 212 | state = state.update_state(tgt, saved_inputs) 213 | 214 | return output, state 215 | 216 | def init_decoder_state(self, src, memory_bank, 217 | with_cache=False): 218 | """ Init decoder state """ 219 | state = TransformerDecoderState(src) 220 | if with_cache: 221 | state._init_cache(memory_bank, self.num_layers) 222 | return state 223 | 224 | 225 | 226 | class TransformerDecoderState(DecoderState): 227 | """ Transformer Decoder state base class """ 228 | 229 | def __init__(self, src): 230 | """ 231 | Args: 232 | src (FloatTensor): a sequence of source words tensors 233 | with optional feature tensors, of size (len x batch). 
234 | """ 235 | self.src = src 236 | self.previous_input = None 237 | self.previous_layer_inputs = None 238 | self.cache = None 239 | 240 | @property 241 | def _all(self): 242 | """ 243 | Contains attributes that need to be updated in self.beam_update(). 244 | """ 245 | if (self.previous_input is not None 246 | and self.previous_layer_inputs is not None): 247 | return (self.previous_input, 248 | self.previous_layer_inputs, 249 | self.src) 250 | else: 251 | return (self.src,) 252 | 253 | def detach(self): 254 | if self.previous_input is not None: 255 | self.previous_input = self.previous_input.detach() 256 | if self.previous_layer_inputs is not None: 257 | self.previous_layer_inputs = self.previous_layer_inputs.detach() 258 | self.src = self.src.detach() 259 | 260 | def update_state(self, new_input, previous_layer_inputs): 261 | state = TransformerDecoderState(self.src) 262 | state.previous_input = new_input 263 | state.previous_layer_inputs = previous_layer_inputs 264 | return state 265 | 266 | def _init_cache(self, memory_bank, num_layers): 267 | self.cache = {} 268 | 269 | for l in range(num_layers): 270 | layer_cache = { 271 | "memory_keys": None, 272 | "memory_values": None 273 | } 274 | layer_cache["self_keys"] = None 275 | layer_cache["self_values"] = None 276 | self.cache["layer_{}".format(l)] = layer_cache 277 | 278 | def repeat_beam_size_times(self, beam_size): 279 | """ Repeat beam_size times along batch dimension. """ 280 | self.src = self.src.data.repeat(1, beam_size, 1) 281 | 282 | def map_batch_fn(self, fn): 283 | def _recursive_map(struct, batch_dim=0): 284 | for k, v in struct.items(): 285 | if v is not None: 286 | if isinstance(v, dict): 287 | _recursive_map(v) 288 | else: 289 | struct[k] = fn(v, batch_dim) 290 | 291 | self.src = fn(self.src, 0) 292 | if self.cache is not None: 293 | _recursive_map(self.cache) 294 | 295 | 296 | 297 | -------------------------------------------------------------------------------- /src/models/encoder.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | import torch.nn as nn 5 | 6 | from models.neural import MultiHeadedAttention, PositionwiseFeedForward 7 | 8 | 9 | class Classifier(nn.Module): 10 | def __init__(self, hidden_size): 11 | super(Classifier, self).__init__() 12 | self.linear1 = nn.Linear(hidden_size, 1) 13 | self.sigmoid = nn.Sigmoid() 14 | 15 | def forward(self, x, mask_cls): 16 | h = self.linear1(x).squeeze(-1) 17 | sent_scores = self.sigmoid(h) * mask_cls.float() 18 | return sent_scores 19 | 20 | 21 | class PositionalEncoding(nn.Module): 22 | 23 | def __init__(self, dropout, dim, max_len=5000): 24 | pe = torch.zeros(max_len, dim) 25 | position = torch.arange(0, max_len).unsqueeze(1) 26 | div_term = torch.exp((torch.arange(0, dim, 2, dtype=torch.float) * 27 | -(math.log(10000.0) / dim))) 28 | pe[:, 0::2] = torch.sin(position.float() * div_term) 29 | pe[:, 1::2] = torch.cos(position.float() * div_term) 30 | pe = pe.unsqueeze(0) 31 | super(PositionalEncoding, self).__init__() 32 | self.register_buffer('pe', pe) 33 | self.dropout = nn.Dropout(p=dropout) 34 | self.dim = dim 35 | 36 | def forward(self, emb, step=None): 37 | emb = emb * math.sqrt(self.dim) 38 | if (step): 39 | emb = emb + self.pe[:, step][:, None, :] 40 | 41 | else: 42 | emb = emb + self.pe[:, :emb.size(1)] 43 | emb = self.dropout(emb) 44 | return emb 45 | 46 | def get_emb(self, emb): 47 | return self.pe[:, :emb.size(1)] 48 | 49 | 50 | class TransformerEncoderLayer(nn.Module): 51 | def 
__init__(self, d_model, heads, d_ff, dropout): 52 | super(TransformerEncoderLayer, self).__init__() 53 | 54 | self.self_attn = MultiHeadedAttention( 55 | heads, d_model, dropout=dropout) 56 | self.feed_forward = PositionwiseFeedForward(d_model, d_ff, dropout) 57 | self.layer_norm = nn.LayerNorm(d_model, eps=1e-6) 58 | self.dropout = nn.Dropout(dropout) 59 | 60 | def forward(self, iter, query, inputs, mask): 61 | if (iter != 0): 62 | input_norm = self.layer_norm(inputs) 63 | else: 64 | input_norm = inputs 65 | 66 | mask = mask.unsqueeze(1) 67 | context = self.self_attn(input_norm, input_norm, input_norm, 68 | mask=mask) 69 | out = self.dropout(context) + inputs 70 | return self.feed_forward(out) 71 | 72 | 73 | class ExtTransformerEncoder(nn.Module): 74 | def __init__(self, d_model, d_ff, heads, dropout, num_inter_layers=0): 75 | super(ExtTransformerEncoder, self).__init__() 76 | self.d_model = d_model 77 | self.num_inter_layers = num_inter_layers 78 | self.pos_emb = PositionalEncoding(dropout, d_model) 79 | self.transformer_inter = nn.ModuleList( 80 | [TransformerEncoderLayer(d_model, heads, d_ff, dropout) 81 | for _ in range(num_inter_layers)]) 82 | self.dropout = nn.Dropout(dropout) 83 | self.layer_norm = nn.LayerNorm(d_model, eps=1e-6) 84 | self.wo = nn.Linear(d_model, 1, bias=True) 85 | self.sigmoid = nn.Sigmoid() 86 | 87 | def forward(self, top_vecs, mask): 88 | """ See :obj:`EncoderBase.forward()`""" 89 | 90 | batch_size, n_sents = top_vecs.size(0), top_vecs.size(1) 91 | pos_emb = self.pos_emb.pe[:, :n_sents] 92 | x = top_vecs * mask[:, :, None].float() 93 | x = x + pos_emb 94 | 95 | for i in range(self.num_inter_layers): 96 | x = self.transformer_inter[i](i, x, x, 1 - mask) # all_sents * max_tokens * dim 97 | 98 | x = self.layer_norm(x) 99 | sent_scores = self.sigmoid(self.wo(x)) 100 | sent_scores = sent_scores.squeeze(-1) * mask.float() 101 | 102 | return sent_scores 103 | 104 | -------------------------------------------------------------------------------- /src/models/loss.py: -------------------------------------------------------------------------------- 1 | """ 2 | This file handles the details of the loss function during training. 3 | 4 | This includes: LossComputeBase and the standard NMTLossCompute, and 5 | sharded loss compute stuff. 6 | """ 7 | from __future__ import division 8 | import torch 9 | import torch.nn as nn 10 | import torch.nn.functional as F 11 | 12 | from models.reporter import Statistics 13 | 14 | 15 | def abs_loss(generator, symbols, vocab_size, device, train=True, label_smoothing=0.0): 16 | compute = NMTLossCompute( 17 | generator, symbols, vocab_size, 18 | label_smoothing=label_smoothing if train else 0.0) 19 | compute.to(device) 20 | return compute 21 | 22 | 23 | 24 | class LossComputeBase(nn.Module): 25 | """ 26 | Class for managing efficient loss computation. Handles 27 | sharding next step predictions and accumulating mutiple 28 | loss computations 29 | 30 | 31 | Users can implement their own loss computation strategy by making 32 | subclass of this one. Users need to implement the _compute_loss() 33 | and make_shard_state() methods. 34 | 35 | Args: 36 | generator (:obj:`nn.Module`) : 37 | module that maps the output of the decoder to a 38 | distribution over the target vocabulary. 
39 | tgt_vocab (:obj:`Vocab`) : 40 | torchtext vocab object representing the target output 41 | normalzation (str): normalize by "sents" or "tokens" 42 | """ 43 | 44 | def __init__(self, generator, pad_id): 45 | super(LossComputeBase, self).__init__() 46 | self.generator = generator 47 | self.padding_idx = pad_id 48 | 49 | 50 | 51 | def _make_shard_state(self, batch, output, attns=None): 52 | """ 53 | Make shard state dictionary for shards() to return iterable 54 | shards for efficient loss computation. Subclass must define 55 | this method to match its own _compute_loss() interface. 56 | Args: 57 | batch: the current batch. 58 | output: the predict output from the model. 59 | range_: the range of examples for computing, the whole 60 | batch or a trunc of it? 61 | attns: the attns dictionary returned from the model. 62 | """ 63 | return NotImplementedError 64 | 65 | def _compute_loss(self, batch, output, target, **kwargs): 66 | """ 67 | Compute the loss. Subclass must define this method. 68 | 69 | Args: 70 | 71 | batch: the current batch. 72 | output: the predict output from the model. 73 | target: the validate target to compare output with. 74 | **kwargs(optional): additional info for computing loss. 75 | """ 76 | return NotImplementedError 77 | 78 | def monolithic_compute_loss(self, batch, output): 79 | """ 80 | Compute the forward loss for the batch. 81 | 82 | Args: 83 | batch (batch): batch of labeled examples 84 | output (:obj:`FloatTensor`): 85 | output of decoder model `[tgt_len x batch x hidden]` 86 | attns (dict of :obj:`FloatTensor`) : 87 | dictionary of attention distributions 88 | `[tgt_len x batch x src_len]` 89 | Returns: 90 | :obj:`onmt.utils.Statistics`: loss statistics 91 | """ 92 | shard_state = self._make_shard_state(batch, output) 93 | _, batch_stats = self._compute_loss(batch, **shard_state) 94 | 95 | return batch_stats 96 | 97 | def sharded_compute_loss(self, batch, output, 98 | shard_size, 99 | normalization): 100 | """Compute the forward loss and backpropagate. Computation is done 101 | with shards and optionally truncation for memory efficiency. 102 | 103 | Also supports truncated BPTT for long sequences by taking a 104 | range in the decoder output sequence to back propagate in. 105 | Range is from `(cur_trunc, cur_trunc + trunc_size)`. 106 | 107 | Note sharding is an exact efficiency trick to relieve memory 108 | required for the generation buffers. Truncation is an 109 | approximate efficiency trick to relieve the memory required 110 | in the RNN buffers. 
111 | 112 | Args: 113 | batch (batch) : batch of labeled examples 114 | output (:obj:`FloatTensor`) : 115 | output of decoder model `[tgt_len x batch x hidden]` 116 | attns (dict) : dictionary of attention distributions 117 | `[tgt_len x batch x src_len]` 118 | cur_trunc (int) : starting position of truncation window 119 | trunc_size (int) : length of truncation window 120 | shard_size (int) : maximum number of examples in a shard 121 | normalization (int) : Loss is divided by this number 122 | 123 | Returns: 124 | :obj:`onmt.utils.Statistics`: validation loss statistics 125 | 126 | """ 127 | batch_stats = Statistics() 128 | shard_state = self._make_shard_state(batch, output) 129 | for shard in shards(shard_state, shard_size): 130 | loss, stats = self._compute_loss(batch, **shard) 131 | loss.div(float(normalization)).backward() 132 | batch_stats.update(stats) 133 | 134 | return batch_stats 135 | 136 | def _stats(self, loss, scores, target): 137 | """ 138 | Args: 139 | loss (:obj:`FloatTensor`): the loss computed by the loss criterion. 140 | scores (:obj:`FloatTensor`): a score for each possible output 141 | target (:obj:`FloatTensor`): true targets 142 | 143 | Returns: 144 | :obj:`onmt.utils.Statistics` : statistics for this batch. 145 | """ 146 | pred = scores.max(1)[1] 147 | non_padding = target.ne(self.padding_idx) 148 | num_correct = pred.eq(target) \ 149 | .masked_select(non_padding) \ 150 | .sum() \ 151 | .item() 152 | num_non_padding = non_padding.sum().item() 153 | return Statistics(loss.item(), num_non_padding, num_correct) 154 | 155 | def _bottle(self, _v): 156 | return _v.view(-1, _v.size(2)) 157 | 158 | def _unbottle(self, _v, batch_size): 159 | return _v.view(-1, batch_size, _v.size(1)) 160 | 161 | 162 | class LabelSmoothingLoss(nn.Module): 163 | """ 164 | With label smoothing, 165 | KL-divergence between q_{smoothed ground truth prob.}(w) 166 | and p_{prob. computed by model}(w) is minimized. 167 | """ 168 | def __init__(self, label_smoothing, tgt_vocab_size, ignore_index=-100): 169 | assert 0.0 < label_smoothing <= 1.0 170 | self.padding_idx = ignore_index 171 | super(LabelSmoothingLoss, self).__init__() 172 | 173 | smoothing_value = label_smoothing / (tgt_vocab_size - 2) 174 | one_hot = torch.full((tgt_vocab_size,), smoothing_value) 175 | one_hot[self.padding_idx] = 0 176 | self.register_buffer('one_hot', one_hot.unsqueeze(0)) 177 | self.confidence = 1.0 - label_smoothing 178 | 179 | def forward(self, output, target): 180 | """ 181 | output (FloatTensor): batch_size x n_classes 182 | target (LongTensor): batch_size 183 | """ 184 | model_prob = self.one_hot.repeat(target.size(0), 1) 185 | model_prob.scatter_(1, target.unsqueeze(1), self.confidence) 186 | model_prob.masked_fill_((target == self.padding_idx).unsqueeze(1), 0) 187 | 188 | return F.kl_div(output, model_prob, reduction='sum') 189 | 190 | 191 | class NMTLossCompute(LossComputeBase): 192 | """ 193 | Standard NMT Loss Computation. 
194 | """ 195 | 196 | def __init__(self, generator, symbols, vocab_size, 197 | label_smoothing=0.0): 198 | super(NMTLossCompute, self).__init__(generator, symbols['PAD']) 199 | self.sparse = not isinstance(generator[1], nn.LogSoftmax) 200 | if label_smoothing > 0: 201 | self.criterion = LabelSmoothingLoss( 202 | label_smoothing, vocab_size, ignore_index=self.padding_idx 203 | ) 204 | else: 205 | self.criterion = nn.NLLLoss( 206 | ignore_index=self.padding_idx, reduction='sum' 207 | ) 208 | 209 | def _make_shard_state(self, batch, output): 210 | return { 211 | "output": output, 212 | "target": batch.tgt[:,1:], 213 | } 214 | 215 | def _compute_loss(self, batch, output, target): 216 | bottled_output = self._bottle(output) 217 | scores = self.generator(bottled_output) 218 | gtruth =target.contiguous().view(-1) 219 | 220 | loss = self.criterion(scores, gtruth) 221 | 222 | stats = self._stats(loss.clone(), scores, gtruth) 223 | 224 | return loss, stats 225 | 226 | 227 | def filter_shard_state(state, shard_size=None): 228 | """ ? """ 229 | for k, v in state.items(): 230 | if shard_size is None: 231 | yield k, v 232 | 233 | if v is not None: 234 | v_split = [] 235 | if isinstance(v, torch.Tensor): 236 | for v_chunk in torch.split(v, shard_size): 237 | v_chunk = v_chunk.data.clone() 238 | v_chunk.requires_grad = v.requires_grad 239 | v_split.append(v_chunk) 240 | yield k, (v, v_split) 241 | 242 | 243 | def shards(state, shard_size, eval_only=False): 244 | """ 245 | Args: 246 | state: A dictionary which corresponds to the output of 247 | *LossCompute._make_shard_state(). The values for 248 | those keys are Tensor-like or None. 249 | shard_size: The maximum size of the shards yielded by the model. 250 | eval_only: If True, only yield the state, nothing else. 251 | Otherwise, yield shards. 252 | 253 | Yields: 254 | Each yielded shard is a dict. 255 | 256 | Side effect: 257 | After the last shard, this function does back-propagation. 258 | """ 259 | if eval_only: 260 | yield filter_shard_state(state) 261 | else: 262 | # non_none: the subdict of the state dictionary where the values 263 | # are not None. 264 | non_none = dict(filter_shard_state(state, shard_size)) 265 | 266 | # Now, the iteration: 267 | # state is a dictionary of sequences of tensor-like but we 268 | # want a sequence of dictionaries of tensors. 269 | # First, unzip the dictionary into a sequence of keys and a 270 | # sequence of tensor-like sequences. 271 | keys, values = zip(*((k, [v_chunk for v_chunk in v_split]) 272 | for k, (_, v_split) in non_none.items())) 273 | 274 | # Now, yield a dictionary for each shard. The keys are always 275 | # the same. values is a sequence of length #keys where each 276 | # element is a sequence of length #shards. We want to iterate 277 | # over the shards, not over the keys: therefore, the values need 278 | # to be re-zipped by shard and then each shard can be paired 279 | # with the keys. 
280 | for shard_tensors in zip(*values): 281 | yield dict(zip(keys, shard_tensors)) 282 | 283 | # Assumed backprop'd 284 | variables = [] 285 | for k, (v, v_split) in non_none.items(): 286 | if isinstance(v, torch.Tensor) and state[k].requires_grad: 287 | variables.extend(zip(torch.split(state[k], shard_size), 288 | [v_chunk.grad for v_chunk in v_split])) 289 | inputs, grads = zip(*variables) 290 | torch.autograd.backward(inputs, grads) 291 | -------------------------------------------------------------------------------- /src/models/model_builder.py: -------------------------------------------------------------------------------- 1 | import copy 2 | 3 | import torch 4 | import torch.nn as nn 5 | from pytorch_transformers import BertModel, BertConfig 6 | from torch.nn.init import xavier_uniform_ 7 | 8 | from models.decoder import TransformerDecoder 9 | from models.encoder import Classifier, ExtTransformerEncoder 10 | from models.optimizers import Optimizer 11 | 12 | def build_optim(args, model, checkpoint): 13 | """ Build optimizer """ 14 | 15 | if checkpoint is not None: 16 | optim = checkpoint['optim'][0] 17 | saved_optimizer_state_dict = optim.optimizer.state_dict() 18 | optim.optimizer.load_state_dict(saved_optimizer_state_dict) 19 | if args.visible_gpus != '-1': 20 | for state in optim.optimizer.state.values(): 21 | for k, v in state.items(): 22 | if torch.is_tensor(v): 23 | state[k] = v.cuda() 24 | 25 | if (optim.method == 'adam') and (len(optim.optimizer.state) < 1): 26 | raise RuntimeError( 27 | "Error: loaded Adam optimizer from existing model" + 28 | " but optimizer state is empty") 29 | 30 | else: 31 | optim = Optimizer( 32 | args.optim, args.lr, args.max_grad_norm, 33 | beta1=args.beta1, beta2=args.beta2, 34 | decay_method='noam', 35 | warmup_steps=args.warmup_steps) 36 | 37 | optim.set_parameters(list(model.named_parameters())) 38 | 39 | 40 | return optim 41 | 42 | def build_optim_bert(args, model, checkpoint): 43 | """ Build optimizer """ 44 | 45 | if checkpoint is not None: 46 | optim = checkpoint['optims'][0] 47 | saved_optimizer_state_dict = optim.optimizer.state_dict() 48 | optim.optimizer.load_state_dict(saved_optimizer_state_dict) 49 | if args.visible_gpus != '-1': 50 | for state in optim.optimizer.state.values(): 51 | for k, v in state.items(): 52 | if torch.is_tensor(v): 53 | state[k] = v.cuda() 54 | 55 | if (optim.method == 'adam') and (len(optim.optimizer.state) < 1): 56 | raise RuntimeError( 57 | "Error: loaded Adam optimizer from existing model" + 58 | " but optimizer state is empty") 59 | 60 | else: 61 | optim = Optimizer( 62 | args.optim, args.lr_bert, args.max_grad_norm, 63 | beta1=args.beta1, beta2=args.beta2, 64 | decay_method='noam', 65 | warmup_steps=args.warmup_steps_bert) 66 | 67 | params = [(n, p) for n, p in list(model.named_parameters()) if n.startswith('bert.model')] 68 | optim.set_parameters(params) 69 | 70 | 71 | return optim 72 | 73 | def build_optim_dec(args, model, checkpoint): 74 | """ Build optimizer """ 75 | 76 | if checkpoint is not None: 77 | optim = checkpoint['optims'][1] 78 | saved_optimizer_state_dict = optim.optimizer.state_dict() 79 | optim.optimizer.load_state_dict(saved_optimizer_state_dict) 80 | if args.visible_gpus != '-1': 81 | for state in optim.optimizer.state.values(): 82 | for k, v in state.items(): 83 | if torch.is_tensor(v): 84 | state[k] = v.cuda() 85 | 86 | if (optim.method == 'adam') and (len(optim.optimizer.state) < 1): 87 | raise RuntimeError( 88 | "Error: loaded Adam optimizer from existing model" + 89 | " but 
optimizer state is empty") 90 | 91 | else: 92 | optim = Optimizer( 93 | args.optim, args.lr_dec, args.max_grad_norm, 94 | beta1=args.beta1, beta2=args.beta2, 95 | decay_method='noam', 96 | warmup_steps=args.warmup_steps_dec) 97 | 98 | params = [(n, p) for n, p in list(model.named_parameters()) if not n.startswith('bert.model')] 99 | optim.set_parameters(params) 100 | 101 | 102 | return optim 103 | 104 | 105 | def get_generator(vocab_size, dec_hidden_size, device): 106 | gen_func = nn.LogSoftmax(dim=-1) 107 | generator = nn.Sequential( 108 | nn.Linear(dec_hidden_size, vocab_size), 109 | gen_func 110 | ) 111 | generator.to(device) 112 | 113 | return generator 114 | 115 | class Bert(nn.Module): 116 | def __init__(self, large, temp_dir, finetune=False): 117 | super(Bert, self).__init__() 118 | if(large): 119 | self.model = BertModel.from_pretrained('bert-large-uncased', cache_dir=temp_dir) 120 | else: 121 | self.model = BertModel.from_pretrained('bert-base-uncased', cache_dir=temp_dir) 122 | 123 | self.finetune = finetune 124 | 125 | def forward(self, x, segs, mask): 126 | if(self.finetune): 127 | top_vec, _ = self.model(x, segs, attention_mask=mask) 128 | else: 129 | self.eval() 130 | with torch.no_grad(): 131 | top_vec, _ = self.model(x, segs, attention_mask=mask) 132 | return top_vec 133 | 134 | 135 | class ExtSummarizer(nn.Module): 136 | def __init__(self, args, device, checkpoint): 137 | super(ExtSummarizer, self).__init__() 138 | self.args = args 139 | self.device = device 140 | self.bert = Bert(args.large, args.temp_dir, args.finetune_bert) 141 | 142 | self.ext_layer = ExtTransformerEncoder(self.bert.model.config.hidden_size, args.ext_ff_size, args.ext_heads, 143 | args.ext_dropout, args.ext_layers) 144 | if (args.encoder == 'baseline'): 145 | bert_config = BertConfig(self.bert.model.config.vocab_size, hidden_size=args.ext_hidden_size, 146 | num_hidden_layers=args.ext_layers, num_attention_heads=args.ext_heads, intermediate_size=args.ext_ff_size) 147 | self.bert.model = BertModel(bert_config) 148 | self.ext_layer = Classifier(self.bert.model.config.hidden_size) 149 | 150 | if(args.max_pos>512): 151 | my_pos_embeddings = nn.Embedding(args.max_pos, self.bert.model.config.hidden_size) 152 | my_pos_embeddings.weight.data[:512] = self.bert.model.embeddings.position_embeddings.weight.data 153 | my_pos_embeddings.weight.data[512:] = self.bert.model.embeddings.position_embeddings.weight.data[-1][None,:].repeat(args.max_pos-512,1) 154 | self.bert.model.embeddings.position_embeddings = my_pos_embeddings 155 | 156 | 157 | if checkpoint is not None: 158 | self.load_state_dict(checkpoint['model'], strict=True) 159 | else: 160 | if args.param_init != 0.0: 161 | for p in self.ext_layer.parameters(): 162 | p.data.uniform_(-args.param_init, args.param_init) 163 | if args.param_init_glorot: 164 | for p in self.ext_layer.parameters(): 165 | if p.dim() > 1: 166 | xavier_uniform_(p) 167 | 168 | self.to(device) 169 | 170 | def forward(self, src, segs, clss, mask_src, mask_cls): 171 | top_vec = self.bert(src, segs, mask_src) 172 | sents_vec = top_vec[torch.arange(top_vec.size(0)).unsqueeze(1), clss] 173 | sents_vec = sents_vec * mask_cls[:, :, None].float() 174 | sent_scores = self.ext_layer(sents_vec, mask_cls).squeeze(-1) 175 | return sent_scores, mask_cls 176 | 177 | 178 | class AbsSummarizer(nn.Module): 179 | def __init__(self, args, device, checkpoint=None, bert_from_extractive=None): 180 | super(AbsSummarizer, self).__init__() 181 | self.args = args 182 | self.device = device 183 | self.bert = 
Bert(args.large, args.temp_dir, args.finetune_bert) 184 | 185 | if bert_from_extractive is not None: 186 | self.bert.model.load_state_dict( 187 | dict([(n[11:], p) for n, p in bert_from_extractive.items() if n.startswith('bert.model')]), strict=True) 188 | 189 | if (args.encoder == 'baseline'): 190 | bert_config = BertConfig(self.bert.model.config.vocab_size, hidden_size=args.enc_hidden_size, 191 | num_hidden_layers=args.enc_layers, num_attention_heads=8, 192 | intermediate_size=args.enc_ff_size, 193 | hidden_dropout_prob=args.enc_dropout, 194 | attention_probs_dropout_prob=args.enc_dropout) 195 | self.bert.model = BertModel(bert_config) 196 | 197 | if(args.max_pos>512): 198 | my_pos_embeddings = nn.Embedding(args.max_pos, self.bert.model.config.hidden_size) 199 | my_pos_embeddings.weight.data[:512] = self.bert.model.embeddings.position_embeddings.weight.data 200 | my_pos_embeddings.weight.data[512:] = self.bert.model.embeddings.position_embeddings.weight.data[-1][None,:].repeat(args.max_pos-512,1) 201 | self.bert.model.embeddings.position_embeddings = my_pos_embeddings 202 | self.vocab_size = self.bert.model.config.vocab_size 203 | tgt_embeddings = nn.Embedding(self.vocab_size, self.bert.model.config.hidden_size, padding_idx=0) 204 | if (self.args.share_emb): 205 | tgt_embeddings.weight = copy.deepcopy(self.bert.model.embeddings.word_embeddings.weight) 206 | 207 | self.decoder = TransformerDecoder( 208 | self.args.dec_layers, 209 | self.args.dec_hidden_size, heads=self.args.dec_heads, 210 | d_ff=self.args.dec_ff_size, dropout=self.args.dec_dropout, embeddings=tgt_embeddings) 211 | 212 | self.generator = get_generator(self.vocab_size, self.args.dec_hidden_size, device) 213 | self.generator[0].weight = self.decoder.embeddings.weight 214 | 215 | 216 | if checkpoint is not None: 217 | self.load_state_dict(checkpoint['model'], strict=True) 218 | else: 219 | for module in self.decoder.modules(): 220 | if isinstance(module, (nn.Linear, nn.Embedding)): 221 | module.weight.data.normal_(mean=0.0, std=0.02) 222 | elif isinstance(module, nn.LayerNorm): 223 | module.bias.data.zero_() 224 | module.weight.data.fill_(1.0) 225 | if isinstance(module, nn.Linear) and module.bias is not None: 226 | module.bias.data.zero_() 227 | for p in self.generator.parameters(): 228 | if p.dim() > 1: 229 | xavier_uniform_(p) 230 | else: 231 | p.data.zero_() 232 | if(args.use_bert_emb): 233 | tgt_embeddings = nn.Embedding(self.vocab_size, self.bert.model.config.hidden_size, padding_idx=0) 234 | tgt_embeddings.weight = copy.deepcopy(self.bert.model.embeddings.word_embeddings.weight) 235 | self.decoder.embeddings = tgt_embeddings 236 | self.generator[0].weight = self.decoder.embeddings.weight 237 | 238 | self.to(device) 239 | 240 | def forward(self, src, tgt, segs, clss, mask_src, mask_tgt, mask_cls): 241 | top_vec = self.bert(src, segs, mask_src) 242 | dec_state = self.decoder.init_decoder_state(src, top_vec) 243 | decoder_outputs, state = self.decoder(tgt[:, :-1], top_vec, dec_state) 244 | return decoder_outputs, None 245 | -------------------------------------------------------------------------------- /src/models/optimizers.py: -------------------------------------------------------------------------------- 1 | """ Optimizers class """ 2 | import torch 3 | import torch.optim as optim 4 | from torch.nn.utils import clip_grad_norm_ 5 | 6 | 7 | # from onmt.utils import use_gpu 8 | # from models.adam import Adam 9 | 10 | 11 | def use_gpu(opt): 12 | """ 13 | Creates a boolean if gpu used 14 | """ 15 | return (hasattr(opt, 
'gpu_ranks') and len(opt.gpu_ranks) > 0) or \ 16 | (hasattr(opt, 'gpu') and opt.gpu > -1) 17 | 18 | def build_optim(model, opt, checkpoint): 19 | """ Build optimizer """ 20 | saved_optimizer_state_dict = None 21 | 22 | if opt.train_from: 23 | optim = checkpoint['optim'] 24 | # We need to save a copy of optim.optimizer.state_dict() for setting 25 | # the optimizer state later on in this method, since 26 | # the call optim.set_parameters(model.parameters()) will overwrite 27 | # optim.optimizer, and we then reload it with the values stored in 28 | # optim.optimizer.state_dict() 29 | saved_optimizer_state_dict = optim.optimizer.state_dict() 30 | else: 31 | optim = Optimizer( 32 | opt.optim, opt.learning_rate, opt.max_grad_norm, 33 | lr_decay=opt.learning_rate_decay, 34 | start_decay_steps=opt.start_decay_steps, 35 | decay_steps=opt.decay_steps, 36 | beta1=opt.adam_beta1, 37 | beta2=opt.adam_beta2, 38 | adagrad_accum=opt.adagrad_accumulator_init, 39 | decay_method=opt.decay_method, 40 | warmup_steps=opt.warmup_steps) 41 | 42 | optim.set_parameters(model.named_parameters()) 43 | 44 | if opt.train_from: 45 | optim.optimizer.load_state_dict(saved_optimizer_state_dict) 46 | if use_gpu(opt): 47 | for state in optim.optimizer.state.values(): 48 | for k, v in state.items(): 49 | if torch.is_tensor(v): 50 | state[k] = v.cuda() 51 | 52 | if (optim.method == 'adam') and (len(optim.optimizer.state) < 1): 53 | raise RuntimeError( 54 | "Error: loaded Adam optimizer from existing model" + 55 | " but optimizer state is empty") 56 | 57 | return optim 58 | 59 | 60 | class MultipleOptimizer(object): 61 | """ Implement multiple optimizers needed for sparse adam """ 62 | 63 | def __init__(self, op): 64 | """ Store the list of wrapped optimizers """ 65 | self.optimizers = op 66 | 67 | def zero_grad(self): 68 | """ Zero the gradients of every wrapped optimizer """ 69 | for op in self.optimizers: 70 | op.zero_grad() 71 | 72 | def step(self): 73 | """ Step every wrapped optimizer """ 74 | for op in self.optimizers: 75 | op.step() 76 | 77 | @property 78 | def state(self): 79 | """ Merged per-parameter state of all wrapped optimizers """ 80 | return {k: v for op in self.optimizers for k, v in op.state.items()} 81 | 82 | def state_dict(self): 83 | """ Return one state dict per wrapped optimizer """ 84 | return [op.state_dict() for op in self.optimizers] 85 | 86 | def load_state_dict(self, state_dicts): 87 | """ Load one state dict per wrapped optimizer """ 88 | assert len(state_dicts) == len(self.optimizers) 89 | for i in range(len(state_dicts)): 90 | self.optimizers[i].load_state_dict(state_dicts[i]) 91 | 92 | 93 | class Optimizer(object): 94 | """ 95 | Controller class for optimization. Mostly a thin 96 | wrapper for `optim`, but also useful for implementing 97 | rate scheduling beyond what is currently available. 98 | Also implements necessary methods for training RNNs such 99 | as grad manipulations. 100 | 101 | Args: 102 | method (:obj:`str`): one of [sgd, adagrad, adadelta, adam] 103 | lr (float): learning rate 104 | lr_decay (float, optional): learning rate decay multiplier 105 | start_decay_steps (int, optional): step to start learning rate decay 106 | beta1, beta2 (float, optional): parameters for adam 107 | adagrad_accum (float, optional): initialization parameter for adagrad 108 | decay_method (str, optional): custom decay options 109 | warmup_steps (int, optional): parameter for `noam` decay 110 | model_size (int, optional): parameter for `noam` decay 111 | 112 | We use the default parameters for Adam that are suggested by 113 | the original paper https://arxiv.org/pdf/1412.6980.pdf 114 | These values are also used by other established implementations, 115 | e.g.
https://www.tensorflow.org/api_docs/python/tf/train/AdamOptimizer 116 | https://keras.io/optimizers/ 117 | Recently there are slightly different values used in the paper 118 | "Attention is all you need" 119 | https://arxiv.org/pdf/1706.03762.pdf, particularly the value beta2=0.98 120 | was used there however, beta2=0.999 is still arguably the more 121 | established value, so we use that here as well 122 | """ 123 | 124 | def __init__(self, method, learning_rate, max_grad_norm, 125 | lr_decay=1, start_decay_steps=None, decay_steps=None, 126 | beta1=0.9, beta2=0.999, 127 | adagrad_accum=0.0, 128 | decay_method=None, 129 | warmup_steps=4000, weight_decay=0): 130 | self.last_ppl = None 131 | self.learning_rate = learning_rate 132 | self.original_lr = learning_rate 133 | self.max_grad_norm = max_grad_norm 134 | self.method = method 135 | self.lr_decay = lr_decay 136 | self.start_decay_steps = start_decay_steps 137 | self.decay_steps = decay_steps 138 | self.start_decay = False 139 | self._step = 0 140 | self.betas = [beta1, beta2] 141 | self.adagrad_accum = adagrad_accum 142 | self.decay_method = decay_method 143 | self.warmup_steps = warmup_steps 144 | self.weight_decay = weight_decay 145 | 146 | def set_parameters(self, params): 147 | """ ? """ 148 | self.params = [] 149 | self.sparse_params = [] 150 | for k, p in params: 151 | if p.requires_grad: 152 | if self.method != 'sparseadam' or "embed" not in k: 153 | self.params.append(p) 154 | else: 155 | self.sparse_params.append(p) 156 | if self.method == 'sgd': 157 | self.optimizer = optim.SGD(self.params, lr=self.learning_rate) 158 | elif self.method == 'adagrad': 159 | self.optimizer = optim.Adagrad(self.params, lr=self.learning_rate) 160 | for group in self.optimizer.param_groups: 161 | for p in group['params']: 162 | self.optimizer.state[p]['sum'] = self.optimizer\ 163 | .state[p]['sum'].fill_(self.adagrad_accum) 164 | elif self.method == 'adadelta': 165 | self.optimizer = optim.Adadelta(self.params, lr=self.learning_rate) 166 | elif self.method == 'adam': 167 | self.optimizer = optim.Adam(self.params, lr=self.learning_rate, 168 | betas=self.betas, eps=1e-9) 169 | else: 170 | raise RuntimeError("Invalid optim method: " + self.method) 171 | 172 | def _set_rate(self, learning_rate): 173 | self.learning_rate = learning_rate 174 | if self.method != 'sparseadam': 175 | self.optimizer.param_groups[0]['lr'] = self.learning_rate 176 | else: 177 | for op in self.optimizer.optimizers: 178 | op.param_groups[0]['lr'] = self.learning_rate 179 | 180 | def step(self): 181 | """Update the model parameters based on current gradients. 182 | 183 | Optionally, will employ gradient modification or update learning 184 | rate. 185 | """ 186 | self._step += 1 187 | 188 | # Decay method used in tensor2tensor. 
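# Noam schedule: lr = original_lr * min(step ** -0.5, step * warmup_steps ** -1.5),
# i.e. the rate grows linearly for `warmup_steps` steps and then decays as
# step ** -0.5. For example (hypothetical setting), with warmup_steps=8000 it
# peaks at original_lr * 8000 ** -0.5, roughly 0.011 * original_lr, at step 8000.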
189 | if self.decay_method == "noam": 190 | self._set_rate( 191 | self.original_lr * 192 | min(self._step ** (-0.5), 193 | self._step * self.warmup_steps**(-1.5))) 194 | 195 | else: 196 | if ((self.start_decay_steps is not None) and ( 197 | self._step >= self.start_decay_steps)): 198 | self.start_decay = True 199 | if self.start_decay: 200 | if ((self._step - self.start_decay_steps) 201 | % self.decay_steps == 0): 202 | self.learning_rate = self.learning_rate * self.lr_decay 203 | 204 | if self.method != 'sparseadam': 205 | self.optimizer.param_groups[0]['lr'] = self.learning_rate 206 | 207 | if self.max_grad_norm: 208 | clip_grad_norm_(self.params, self.max_grad_norm) 209 | self.optimizer.step() 210 | 211 | 212 | -------------------------------------------------------------------------------- /src/models/reporter.py: -------------------------------------------------------------------------------- 1 | """ Report manager utility """ 2 | from __future__ import print_function 3 | from datetime import datetime 4 | 5 | import time 6 | import math 7 | import sys 8 | 9 | from distributed import all_gather_list 10 | from others.logging import logger 11 | 12 | 13 | def build_report_manager(opt): 14 | if opt.tensorboard: 15 | from tensorboardX import SummaryWriter 16 | writer = SummaryWriter(opt.tensorboard_log_dir 17 | + datetime.now().strftime("/%b-%d_%H-%M-%S"), 18 | comment="Unmt") 19 | else: 20 | writer = None 21 | 22 | report_mgr = ReportMgr(opt.report_every, start_time=-1, 23 | tensorboard_writer=writer) 24 | return report_mgr 25 | 26 | 27 | class ReportMgrBase(object): 28 | """ 29 | Report Manager Base class 30 | Inherited classes should override: 31 | * `_report_training` 32 | * `_report_step` 33 | """ 34 | 35 | def __init__(self, report_every, start_time=-1.): 36 | """ 37 | Args: 38 | report_every(int): Report status every this many sentences 39 | start_time(float): manually set report start time. Negative values 40 | means that you will need to set it later or use `start()` 41 | """ 42 | self.report_every = report_every 43 | self.progress_step = 0 44 | self.start_time = start_time 45 | 46 | def start(self): 47 | self.start_time = time.time() 48 | 49 | def log(self, *args, **kwargs): 50 | logger.info(*args, **kwargs) 51 | 52 | def report_training(self, step, num_steps, learning_rate, 53 | report_stats, multigpu=False): 54 | """ 55 | This is the user-defined batch-level traing progress 56 | report function. 57 | 58 | Args: 59 | step(int): current step count. 60 | num_steps(int): total number of batches. 61 | learning_rate(float): current learning rate. 62 | report_stats(Statistics): old Statistics instance. 63 | Returns: 64 | report_stats(Statistics): updated Statistics instance. 
65 | """ 66 | if self.start_time < 0: 67 | raise ValueError("""ReportMgr needs to be started 68 | (set 'start_time' or use 'start()'""") 69 | 70 | if multigpu: 71 | report_stats = Statistics.all_gather_stats(report_stats) 72 | 73 | if step % self.report_every == 0: 74 | self._report_training( 75 | step, num_steps, learning_rate, report_stats) 76 | self.progress_step += 1 77 | return Statistics() 78 | 79 | def _report_training(self, *args, **kwargs): 80 | """ To be overridden """ 81 | raise NotImplementedError() 82 | 83 | def report_step(self, lr, step, train_stats=None, valid_stats=None): 84 | """ 85 | Report stats of a step 86 | 87 | Args: 88 | train_stats(Statistics): training stats 89 | valid_stats(Statistics): validation stats 90 | lr(float): current learning rate 91 | """ 92 | self._report_step( 93 | lr, step, train_stats=train_stats, valid_stats=valid_stats) 94 | 95 | def _report_step(self, *args, **kwargs): 96 | raise NotImplementedError() 97 | 98 | 99 | class ReportMgr(ReportMgrBase): 100 | def __init__(self, report_every, start_time=-1., tensorboard_writer=None): 101 | """ 102 | A report manager that writes statistics on standard output as well as 103 | (optionally) TensorBoard 104 | 105 | Args: 106 | report_every(int): Report status every this many sentences 107 | tensorboard_writer(:obj:`tensorboard.SummaryWriter`): 108 | The TensorBoard Summary writer to use or None 109 | """ 110 | super(ReportMgr, self).__init__(report_every, start_time) 111 | self.tensorboard_writer = tensorboard_writer 112 | 113 | def maybe_log_tensorboard(self, stats, prefix, learning_rate, step): 114 | if self.tensorboard_writer is not None: 115 | stats.log_tensorboard( 116 | prefix, self.tensorboard_writer, learning_rate, step) 117 | 118 | def _report_training(self, step, num_steps, learning_rate, 119 | report_stats): 120 | """ 121 | See base class method `ReportMgrBase.report_training`. 122 | """ 123 | report_stats.output(step, num_steps, 124 | learning_rate, self.start_time) 125 | 126 | # Log the progress using the number of batches on the x-axis. 127 | self.maybe_log_tensorboard(report_stats, 128 | "progress", 129 | learning_rate, 130 | step) 131 | report_stats = Statistics() 132 | 133 | return report_stats 134 | 135 | def _report_step(self, lr, step, train_stats=None, valid_stats=None): 136 | """ 137 | See base class method `ReportMgrBase.report_step`. 138 | """ 139 | if train_stats is not None: 140 | self.log('Train perplexity: %g' % train_stats.ppl()) 141 | self.log('Train accuracy: %g' % train_stats.accuracy()) 142 | 143 | self.maybe_log_tensorboard(train_stats, 144 | "train", 145 | lr, 146 | step) 147 | 148 | if valid_stats is not None: 149 | self.log('Validation perplexity: %g' % valid_stats.ppl()) 150 | self.log('Validation accuracy: %g' % valid_stats.accuracy()) 151 | 152 | self.maybe_log_tensorboard(valid_stats, 153 | "valid", 154 | lr, 155 | step) 156 | 157 | 158 | class Statistics(object): 159 | """ 160 | Accumulator for loss statistics. 
161 | Currently calculates: 162 | 163 | * accuracy 164 | * perplexity 165 | * elapsed time 166 | """ 167 | 168 | def __init__(self, loss=0, n_words=0, n_correct=0): 169 | self.loss = loss 170 | self.n_words = n_words 171 | self.n_docs = 0 172 | self.n_correct = n_correct 173 | self.n_src_words = 0 174 | self.start_time = time.time() 175 | 176 | @staticmethod 177 | def all_gather_stats(stat, max_size=4096): 178 | """ 179 | Gather a `Statistics` object across multiple processes/nodes 180 | 181 | Args: 182 | stat(:obj:Statistics): the statistics object to gather 183 | across all processes/nodes 184 | max_size(int): max buffer size to use 185 | 186 | Returns: 187 | `Statistics`, the updated stats object 188 | """ 189 | stats = Statistics.all_gather_stats_list([stat], max_size=max_size) 190 | return stats[0] 191 | 192 | @staticmethod 193 | def all_gather_stats_list(stat_list, max_size=4096): 194 | from torch.distributed import get_rank 195 | 196 | """ 197 | Gather a `Statistics` list across all processes/nodes 198 | 199 | Args: 200 | stat_list(list([`Statistics`])): list of statistics objects to 201 | gather across all processes/nodes 202 | max_size(int): max buffer size to use 203 | 204 | Returns: 205 | our_stats(list([`Statistics`])): list of updated stats 206 | """ 207 | # Get a list of world_size lists with len(stat_list) Statistics objects 208 | all_stats = all_gather_list(stat_list, max_size=max_size) 209 | 210 | our_rank = get_rank() 211 | our_stats = all_stats[our_rank] 212 | for other_rank, stats in enumerate(all_stats): 213 | if other_rank == our_rank: 214 | continue 215 | for i, stat in enumerate(stats): 216 | our_stats[i].update(stat, update_n_src_words=True) 217 | return our_stats 218 | 219 | def update(self, stat, update_n_src_words=False): 220 | """ 221 | Update statistics by summing values with another `Statistics` object 222 | 223 | Args: 224 | stat: another statistic object 225 | update_n_src_words(bool): whether to update (sum) `n_src_words` 226 | or not 227 | 228 | """ 229 | self.loss += stat.loss 230 | self.n_words += stat.n_words 231 | self.n_correct += stat.n_correct 232 | self.n_docs += stat.n_docs 233 | 234 | if update_n_src_words: 235 | self.n_src_words += stat.n_src_words 236 | 237 | def accuracy(self): 238 | """ compute accuracy """ 239 | return 100 * (self.n_correct / self.n_words) 240 | 241 | def xent(self): 242 | """ compute cross entropy """ 243 | return self.loss / self.n_words 244 | 245 | def ppl(self): 246 | """ compute perplexity """ 247 | return math.exp(min(self.loss / self.n_words, 100)) 248 | 249 | def elapsed_time(self): 250 | """ compute elapsed time """ 251 | return time.time() - self.start_time 252 | 253 | def output(self, step, num_steps, learning_rate, start): 254 | """Write out statistics to stdout. 255 | 256 | Args: 257 | step (int): current step 258 | num_steps (int): total number of steps 259 | start (int): start time of step.
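An illustrative log line (made-up values):
Step 50/ 5000; acc:  23.45; ppl: 88.23; xent: 4.48; lr: 0.00020000; 3500/4200 tok/s;     60 sec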
260 | """ 261 | t = self.elapsed_time() 262 | logger.info( 263 | ("Step %2d/%5d; acc: %6.2f; ppl: %5.2f; xent: %4.2f; " + 264 | "lr: %7.8f; %3.0f/%3.0f tok/s; %6.0f sec") 265 | % (step, num_steps, 266 | self.accuracy(), 267 | self.ppl(), 268 | self.xent(), 269 | learning_rate, 270 | self.n_src_words / (t + 1e-5), 271 | self.n_words / (t + 1e-5), 272 | time.time() - start)) 273 | sys.stdout.flush() 274 | 275 | def log_tensorboard(self, prefix, writer, learning_rate, step): 276 | """ display statistics to tensorboard """ 277 | t = self.elapsed_time() 278 | writer.add_scalar(prefix + "/xent", self.xent(), step) 279 | writer.add_scalar(prefix + "/ppl", self.ppl(), step) 280 | writer.add_scalar(prefix + "/accuracy", self.accuracy(), step) 281 | writer.add_scalar(prefix + "/tgtper", self.n_words / t, step) 282 | writer.add_scalar(prefix + "/lr", learning_rate, step) 283 | -------------------------------------------------------------------------------- /src/models/reporter_ext.py: -------------------------------------------------------------------------------- 1 | """ Report manager utility """ 2 | from __future__ import print_function 3 | 4 | import sys 5 | import time 6 | from datetime import datetime 7 | 8 | from others.logging import logger 9 | 10 | 11 | def build_report_manager(opt): 12 | if opt.tensorboard: 13 | from tensorboardX import SummaryWriter 14 | tensorboard_log_dir = opt.tensorboard_log_dir 15 | 16 | if not opt.train_from: 17 | tensorboard_log_dir += datetime.now().strftime("/%b-%d_%H-%M-%S") 18 | 19 | writer = SummaryWriter(tensorboard_log_dir, 20 | comment="Unmt") 21 | else: 22 | writer = None 23 | 24 | report_mgr = ReportMgr(opt.report_every, start_time=-1, 25 | tensorboard_writer=writer) 26 | return report_mgr 27 | 28 | 29 | class ReportMgrBase(object): 30 | """ 31 | Report Manager Base class 32 | Inherited classes should override: 33 | * `_report_training` 34 | * `_report_step` 35 | """ 36 | 37 | def __init__(self, report_every, start_time=-1.): 38 | """ 39 | Args: 40 | report_every(int): Report status every this many sentences 41 | start_time(float): manually set report start time. Negative values 42 | means that you will need to set it later or use `start()` 43 | """ 44 | self.report_every = report_every 45 | self.progress_step = 0 46 | self.start_time = start_time 47 | 48 | def start(self): 49 | self.start_time = time.time() 50 | 51 | def log(self, *args, **kwargs): 52 | logger.info(*args, **kwargs) 53 | 54 | def report_training(self, step, num_steps, learning_rate, 55 | report_stats, multigpu=False): 56 | """ 57 | This is the user-defined batch-level traing progress 58 | report function. 59 | 60 | Args: 61 | step(int): current step count. 62 | num_steps(int): total number of batches. 63 | learning_rate(float): current learning rate. 64 | report_stats(Statistics): old Statistics instance. 65 | Returns: 66 | report_stats(Statistics): updated Statistics instance. 
67 | """ 68 | if self.start_time < 0: 69 | raise ValueError("""ReportMgr needs to be started 70 | (set 'start_time' or use 'start()'""") 71 | 72 | if step % self.report_every == 0: 73 | if multigpu: 74 | report_stats = \ 75 | Statistics.all_gather_stats(report_stats) 76 | self._report_training( 77 | step, num_steps, learning_rate, report_stats) 78 | self.progress_step += 1 79 | return Statistics() 80 | else: 81 | return report_stats 82 | 83 | def _report_training(self, *args, **kwargs): 84 | """ To be overridden """ 85 | raise NotImplementedError() 86 | 87 | def report_step(self, lr, step, train_stats=None, valid_stats=None): 88 | """ 89 | Report stats of a step 90 | 91 | Args: 92 | train_stats(Statistics): training stats 93 | valid_stats(Statistics): validation stats 94 | lr(float): current learning rate 95 | """ 96 | self._report_step( 97 | lr, step, train_stats=train_stats, valid_stats=valid_stats) 98 | 99 | def _report_step(self, *args, **kwargs): 100 | raise NotImplementedError() 101 | 102 | 103 | class ReportMgr(ReportMgrBase): 104 | def __init__(self, report_every, start_time=-1., tensorboard_writer=None): 105 | """ 106 | A report manager that writes statistics on standard output as well as 107 | (optionally) TensorBoard 108 | 109 | Args: 110 | report_every(int): Report status every this many sentences 111 | tensorboard_writer(:obj:`tensorboard.SummaryWriter`): 112 | The TensorBoard Summary writer to use or None 113 | """ 114 | super(ReportMgr, self).__init__(report_every, start_time) 115 | self.tensorboard_writer = tensorboard_writer 116 | 117 | def maybe_log_tensorboard(self, stats, prefix, learning_rate, step): 118 | if self.tensorboard_writer is not None: 119 | stats.log_tensorboard( 120 | prefix, self.tensorboard_writer, learning_rate, step) 121 | 122 | def _report_training(self, step, num_steps, learning_rate, 123 | report_stats): 124 | """ 125 | See base class method `ReportMgrBase.report_training`. 126 | """ 127 | report_stats.output(step, num_steps, 128 | learning_rate, self.start_time) 129 | 130 | # Log the progress using the number of batches on the x-axis. 131 | self.maybe_log_tensorboard(report_stats, 132 | "progress", 133 | learning_rate, 134 | self.progress_step) 135 | report_stats = Statistics() 136 | 137 | return report_stats 138 | 139 | def _report_step(self, lr, step, train_stats=None, valid_stats=None): 140 | """ 141 | See base class method `ReportMgrBase.report_step`. 142 | """ 143 | if train_stats is not None: 144 | self.log('Train xent: %g' % train_stats.xent()) 145 | 146 | self.maybe_log_tensorboard(train_stats, 147 | "train", 148 | lr, 149 | step) 150 | 151 | if valid_stats is not None: 152 | self.log('Validation xent: %g at step %d' % (valid_stats.xent(), step)) 153 | 154 | self.maybe_log_tensorboard(valid_stats, 155 | "valid", 156 | lr, 157 | step) 158 | 159 | 160 | class Statistics(object): 161 | """ 162 | Accumulator for loss statistics. 
163 | Currently calculates: 164 | 165 | * cross-entropy 166 | * number of documents 167 | * elapsed time 168 | """ 169 | 170 | def __init__(self, loss=0, n_docs=0, n_correct=0): 171 | self.loss = loss 172 | self.n_docs = n_docs 173 | self.start_time = time.time() 174 | 175 | @staticmethod 176 | def all_gather_stats(stat, max_size=4096): 177 | """ 178 | Gather a `Statistics` object across multiple processes/nodes 179 | 180 | Args: 181 | stat(:obj:Statistics): the statistics object to gather 182 | across all processes/nodes 183 | max_size(int): max buffer size to use 184 | 185 | Returns: 186 | `Statistics`, the updated stats object 187 | """ 188 | stats = Statistics.all_gather_stats_list([stat], max_size=max_size) 189 | return stats[0] 190 | 191 | @staticmethod 192 | def all_gather_stats_list(stat_list, max_size=4096): 193 | """ 194 | Gather a `Statistics` list across all processes/nodes 195 | 196 | Args: 197 | stat_list(list([`Statistics`])): list of statistics objects to 198 | gather across all processes/nodes 199 | max_size(int): max buffer size to use 200 | 201 | Returns: 202 | our_stats(list([`Statistics`])): list of updated stats 203 | """ 204 | from torch.distributed import get_rank 205 | from distributed import all_gather_list 206 | 207 | # Get a list of world_size lists with len(stat_list) Statistics objects 208 | all_stats = all_gather_list(stat_list, max_size=max_size) 209 | 210 | our_rank = get_rank() 211 | our_stats = all_stats[our_rank] 212 | for other_rank, stats in enumerate(all_stats): 213 | if other_rank == our_rank: 214 | continue 215 | for i, stat in enumerate(stats): 216 | our_stats[i].update(stat, update_n_src_words=True) 217 | return our_stats 218 | 219 | def update(self, stat, update_n_src_words=False): 220 | """ 221 | Update statistics by summing values with another `Statistics` object 222 | 223 | Args: 224 | stat: another statistic object 225 | update_n_src_words(bool): whether to update (sum) `n_src_words` 226 | or not 227 | 228 | """ 229 | self.loss += stat.loss 230 | 231 | self.n_docs += stat.n_docs 232 | 233 | def xent(self): 234 | """ compute cross entropy """ 235 | if (self.n_docs == 0): 236 | return 0 237 | return self.loss / self.n_docs 238 | 239 | def elapsed_time(self): 240 | """ compute elapsed time """ 241 | return time.time() - self.start_time 242 | 243 | def output(self, step, num_steps, learning_rate, start): 244 | """Write out statistics to stdout. 245 | 246 | Args: 247 | step (int): current step 248 | num_steps (int): total number of steps 249 | start (int): start time of step.
250 | """ 251 | t = self.elapsed_time() 252 | step_fmt = "%2d" % step 253 | if num_steps > 0: 254 | step_fmt = "%s/%5d" % (step_fmt, num_steps) 255 | logger.info( 256 | ("Step %s; xent: %4.2f; " + 257 | "lr: %7.7f; %3.0f docs/s; %6.0f sec") 258 | % (step_fmt, 259 | self.xent(), 260 | learning_rate, 261 | self.n_docs / (t + 1e-5), 262 | time.time() - start)) 263 | sys.stdout.flush() 264 | 265 | def log_tensorboard(self, prefix, writer, learning_rate, step): 266 | """ display statistics to tensorboard """ 267 | t = self.elapsed_time() 268 | writer.add_scalar(prefix + "/xent", self.xent(), step) 269 | writer.add_scalar(prefix + "/lr", learning_rate, step) 270 | -------------------------------------------------------------------------------- /src/models/trainer.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import numpy as np 4 | import torch 5 | from tensorboardX import SummaryWriter 6 | 7 | import distributed 8 | from models.reporter import ReportMgr, Statistics 9 | from others.logging import logger 10 | from others.utils import test_rouge, rouge_results_to_str 11 | 12 | 13 | def _tally_parameters(model): 14 | n_params = sum([p.nelement() for p in model.parameters()]) 15 | return n_params 16 | 17 | 18 | def build_trainer(args, device_id, model, optims,loss): 19 | """ 20 | Simplify `Trainer` creation based on user `opt`s* 21 | Args: 22 | opt (:obj:`Namespace`): user options (usually from argument parsing) 23 | model (:obj:`onmt.models.NMTModel`): the model to train 24 | fields (dict): dict of fields 25 | optim (:obj:`onmt.utils.Optimizer`): optimizer used during training 26 | data_type (str): string describing the type of data 27 | e.g. "text", "img", "audio" 28 | model_saver(:obj:`onmt.models.ModelSaverBase`): the utility object 29 | used to save the model 30 | """ 31 | device = "cpu" if args.visible_gpus == '-1' else "cuda" 32 | 33 | 34 | grad_accum_count = args.accum_count 35 | n_gpu = args.world_size 36 | 37 | if device_id >= 0: 38 | gpu_rank = int(args.gpu_ranks[device_id]) 39 | else: 40 | gpu_rank = 0 41 | n_gpu = 0 42 | 43 | print('gpu_rank %d' % gpu_rank) 44 | 45 | tensorboard_log_dir = args.model_path 46 | 47 | writer = SummaryWriter(tensorboard_log_dir, comment="Unmt") 48 | 49 | report_manager = ReportMgr(args.report_every, start_time=-1, tensorboard_writer=writer) 50 | 51 | 52 | trainer = Trainer(args, model, optims, loss, grad_accum_count, n_gpu, gpu_rank, report_manager) 53 | 54 | # print(tr) 55 | if (model): 56 | n_params = _tally_parameters(model) 57 | logger.info('* number of parameters: %d' % n_params) 58 | 59 | return trainer 60 | 61 | 62 | class Trainer(object): 63 | """ 64 | Class that controls the training process. 65 | 66 | Args: 67 | model(:py:class:`onmt.models.model.NMTModel`): translation model 68 | to train 69 | train_loss(:obj:`onmt.utils.loss.LossComputeBase`): 70 | training loss computation 71 | valid_loss(:obj:`onmt.utils.loss.LossComputeBase`): 72 | training loss computation 73 | optim(:obj:`onmt.utils.optimizers.Optimizer`): 74 | the optimizer responsible for update 75 | trunc_size(int): length of truncated back propagation through time 76 | shard_size(int): compute loss in shards of this size for efficiency 77 | data_type(string): type of the source input: [text|img|audio] 78 | norm_method(string): normalization methods: [sents|tokens] 79 | grad_accum_count(int): accumulate gradients this many times. 
80 | report_manager(:obj:`onmt.utils.ReportMgrBase`): 81 | the object that creates reports, or None 82 | model_saver(:obj:`onmt.models.ModelSaverBase`): the saver is 83 | used to save a checkpoint. 84 | Thus nothing will be saved if this parameter is None 85 | """ 86 | 87 | def __init__(self, args, model, optims, loss, 88 | grad_accum_count=1, n_gpu=1, gpu_rank=1, 89 | report_manager=None): 90 | # Basic attributes. 91 | self.args = args 92 | self.save_checkpoint_steps = args.save_checkpoint_steps 93 | self.model = model 94 | self.optims = optims 95 | self.grad_accum_count = grad_accum_count 96 | self.n_gpu = n_gpu 97 | self.gpu_rank = gpu_rank 98 | self.report_manager = report_manager 99 | 100 | self.loss = loss 101 | 102 | assert grad_accum_count > 0 103 | # Set model in training mode. 104 | if (model): 105 | self.model.train() 106 | 107 | def train(self, train_iter_fct, train_steps, valid_iter_fct=None, valid_steps=-1): 108 | """ 109 | The main training loop, 110 | iterating over training data (i.e. `train_iter_fct`) 111 | and running validation (i.e. iterating over `valid_iter_fct`). 112 | 113 | Args: 114 | train_iter_fct(function): a function that returns the train 115 | iterator. e.g. something like 116 | train_iter_fct = lambda: generator(*args, **kwargs) 117 | valid_iter_fct(function): same as train_iter_fct, for valid data 118 | train_steps(int): number of training steps to run 119 | valid_steps(int): validation interval (not used in this loop) 120 | save_checkpoint_steps(int): checkpoint interval (read from `args`) 121 | 122 | Return: 123 | None 124 | """ 125 | logger.info('Start training...') 126 | 127 | # step = self.optim._step + 1 128 | step = self.optims[0]._step + 1 129 | 130 | true_batchs = [] 131 | accum = 0 132 | normalization = 0 133 | train_iter = train_iter_fct() 134 | 135 | total_stats = Statistics() 136 | report_stats = Statistics() 137 | self._start_report_manager(start_time=total_stats.start_time) 138 | 139 | while step <= train_steps: 140 | 141 | reduce_counter = 0 142 | for i, batch in enumerate(train_iter): 143 | if self.n_gpu == 0 or (i % self.n_gpu == self.gpu_rank): 144 | 145 | true_batchs.append(batch) 146 | num_tokens = batch.tgt[:, 1:].ne(self.loss.padding_idx).sum() 147 | normalization += num_tokens.item() 148 | accum += 1 149 | if accum == self.grad_accum_count: 150 | reduce_counter += 1 151 | if self.n_gpu > 1: 152 | normalization = sum(distributed 153 | .all_gather_list 154 | (normalization)) 155 | 156 | self._gradient_accumulation( 157 | true_batchs, normalization, total_stats, 158 | report_stats) 159 | 160 | report_stats = self._maybe_report_training( 161 | step, train_steps, 162 | self.optims[0].learning_rate, 163 | report_stats) 164 | 165 | true_batchs = [] 166 | accum = 0 167 | normalization = 0 168 | if (step % self.save_checkpoint_steps == 0 and self.gpu_rank == 0): 169 | self._save(step) 170 | 171 | step += 1 172 | if step > train_steps: 173 | break 174 | train_iter = train_iter_fct() 175 | 176 | return total_stats 177 | 178 | def validate(self, valid_iter, step=0): 179 | """ Validate model. 180 | valid_iter: validate data iterator 181 | Returns: 182 | :obj:`nmt.Statistics`: validation loss statistics 183 | """ 184 | # Set model in validating mode.
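# Validation reuses the training forward pass, but under torch.no_grad() and with
# the un-sharded monolithic_compute_loss, so no gradients are accumulated.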
185 | self.model.eval() 186 | stats = Statistics() 187 | 188 | with torch.no_grad(): 189 | for batch in valid_iter: 190 | src = batch.src 191 | tgt = batch.tgt 192 | segs = batch.segs 193 | clss = batch.clss 194 | mask_src = batch.mask_src 195 | mask_tgt = batch.mask_tgt 196 | mask_cls = batch.mask_cls 197 | 198 | outputs, _ = self.model(src, tgt, segs, clss, mask_src, mask_tgt, mask_cls) 199 | 200 | batch_stats = self.loss.monolithic_compute_loss(batch, outputs) 201 | stats.update(batch_stats) 202 | self._report_step(0, step, valid_stats=stats) 203 | return stats 204 | 205 | 206 | def _gradient_accumulation(self, true_batchs, normalization, total_stats, 207 | report_stats): 208 | if self.grad_accum_count > 1: 209 | self.model.zero_grad() 210 | 211 | for batch in true_batchs: 212 | if self.grad_accum_count == 1: 213 | self.model.zero_grad() 214 | 215 | src = batch.src 216 | tgt = batch.tgt 217 | segs = batch.segs 218 | clss = batch.clss 219 | mask_src = batch.mask_src 220 | mask_tgt = batch.mask_tgt 221 | mask_cls = batch.mask_cls 222 | 223 | outputs, scores = self.model(src, tgt,segs, clss, mask_src, mask_tgt, mask_cls) 224 | batch_stats = self.loss.sharded_compute_loss(batch, outputs, self.args.generator_shard_size, normalization) 225 | 226 | batch_stats.n_docs = int(src.size(0)) 227 | 228 | total_stats.update(batch_stats) 229 | report_stats.update(batch_stats) 230 | 231 | # 4. Update the parameters and statistics. 232 | if self.grad_accum_count == 1: 233 | # Multi GPU gradient gather 234 | if self.n_gpu > 1: 235 | grads = [p.grad.data for p in self.model.parameters() 236 | if p.requires_grad 237 | and p.grad is not None] 238 | distributed.all_reduce_and_rescale_tensors( 239 | grads, float(1)) 240 | 241 | for o in self.optims: 242 | o.step() 243 | 244 | # in case of multi step gradient accumulation, 245 | # update only after accum batches 246 | if self.grad_accum_count > 1: 247 | if self.n_gpu > 1: 248 | grads = [p.grad.data for p in self.model.parameters() 249 | if p.requires_grad 250 | and p.grad is not None] 251 | distributed.all_reduce_and_rescale_tensors( 252 | grads, float(1)) 253 | for o in self.optims: 254 | o.step() 255 | 256 | 257 | def test(self, test_iter, step, cal_lead=False, cal_oracle=False): 258 | """ Validate model. 259 | valid_iter: validate data iterator 260 | Returns: 261 | :obj:`nmt.Statistics`: validation loss statistics 262 | """ 263 | # Set model in validating mode. 
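# The two helpers below support trigram blocking: _get_ngrams collects the n-grams
# of a sentence and _block_tri checks whether a candidate shares any trigram with
# already selected sentences. Illustrative example:
# _get_ngrams(3, "the cat sat on the mat".split()) ->
# {('the', 'cat', 'sat'), ('cat', 'sat', 'on'), ('sat', 'on', 'the'), ('on', 'the', 'mat')}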
264 | def _get_ngrams(n, text): 265 | ngram_set = set() 266 | text_length = len(text) 267 | max_index_ngram_start = text_length - n 268 | for i in range(max_index_ngram_start + 1): 269 | ngram_set.add(tuple(text[i:i + n])) 270 | return ngram_set 271 | 272 | def _block_tri(c, p): 273 | tri_c = _get_ngrams(3, c.split()) 274 | for s in p: 275 | tri_s = _get_ngrams(3, s.split()) 276 | if len(tri_c.intersection(tri_s))>0: 277 | return True 278 | return False 279 | 280 | if (not cal_lead and not cal_oracle): 281 | self.model.eval() 282 | stats = Statistics() 283 | 284 | can_path = '%s_step%d.candidate'%(self.args.result_path,step) 285 | gold_path = '%s_step%d.gold' % (self.args.result_path, step) 286 | with open(can_path, 'w') as save_pred: 287 | with open(gold_path, 'w') as save_gold: 288 | with torch.no_grad(): 289 | for batch in test_iter: 290 | gold = [] 291 | pred = [] 292 | if (cal_lead): 293 | selected_ids = [list(range(batch.clss.size(1)))] * batch.batch_size 294 | for i, idx in enumerate(selected_ids): 295 | _pred = [] 296 | if(len(batch.src_str[i])==0): 297 | continue 298 | for j in selected_ids[i][:len(batch.src_str[i])]: 299 | if(j>=len( batch.src_str[i])): 300 | continue 301 | candidate = batch.src_str[i][j].strip() 302 | _pred.append(candidate) 303 | 304 | if ((not cal_oracle) and (not self.args.recall_eval) and len(_pred) == 3): 305 | break 306 | 307 | _pred = ''.join(_pred) 308 | if(self.args.recall_eval): 309 | _pred = ' '.join(_pred.split()[:len(batch.tgt_str[i].split())]) 310 | 311 | pred.append(_pred) 312 | gold.append(batch.tgt_str[i]) 313 | 314 | for i in range(len(gold)): 315 | save_gold.write(gold[i].strip()+'\n') 316 | for i in range(len(pred)): 317 | save_pred.write(pred[i].strip()+'\n') 318 | if(step!=-1 and self.args.report_rouge): 319 | rouges = test_rouge(self.args.temp_dir, can_path, gold_path) 320 | logger.info('Rouges at step %d \n%s' % (step, rouge_results_to_str(rouges))) 321 | self._report_step(0, step, valid_stats=stats) 322 | 323 | return stats 324 | 325 | def _save(self, step): 326 | real_model = self.model 327 | # real_generator = (self.generator.module 328 | # if isinstance(self.generator, torch.nn.DataParallel) 329 | # else self.generator) 330 | 331 | model_state_dict = real_model.state_dict() 332 | # generator_state_dict = real_generator.state_dict() 333 | checkpoint = { 334 | 'model': model_state_dict, 335 | # 'generator': generator_state_dict, 336 | 'opt': self.args, 337 | 'optims': self.optims, 338 | } 339 | checkpoint_path = os.path.join(self.args.model_path, 'model_step_%d.pt' % step) 340 | logger.info("Saving checkpoint %s" % checkpoint_path) 341 | # checkpoint_path = '%s_step_%d.pt' % (FLAGS.model_path, step) 342 | if (not os.path.exists(checkpoint_path)): 343 | torch.save(checkpoint, checkpoint_path) 344 | return checkpoint, checkpoint_path 345 | 346 | def _start_report_manager(self, start_time=None): 347 | """ 348 | Simple function to start report manager (if any) 349 | """ 350 | if self.report_manager is not None: 351 | if start_time is None: 352 | self.report_manager.start() 353 | else: 354 | self.report_manager.start_time = start_time 355 | 356 | def _maybe_gather_stats(self, stat): 357 | """ 358 | Gather statistics in multi-processes cases 359 | 360 | Args: 361 | stat(:obj:onmt.utils.Statistics): a Statistics object to gather 362 | or None (it returns None in this case) 363 | 364 | Returns: 365 | stat: the updated (or unchanged) stat object 366 | """ 367 | if stat is not None and self.n_gpu > 1: 368 | return 
Statistics.all_gather_stats(stat) 369 | return stat 370 | 371 | def _maybe_report_training(self, step, num_steps, learning_rate, 372 | report_stats): 373 | """ 374 | Simple function to report training stats (if report_manager is set) 375 | see `onmt.utils.ReportManagerBase.report_training` for doc 376 | """ 377 | if self.report_manager is not None: 378 | return self.report_manager.report_training( 379 | step, num_steps, learning_rate, report_stats, 380 | multigpu=self.n_gpu > 1) 381 | 382 | def _report_step(self, learning_rate, step, train_stats=None, 383 | valid_stats=None): 384 | """ 385 | Simple function to report stats (if report_manager is set) 386 | see `onmt.utils.ReportManagerBase.report_step` for doc 387 | """ 388 | if self.report_manager is not None: 389 | return self.report_manager.report_step( 390 | learning_rate, step, train_stats=train_stats, 391 | valid_stats=valid_stats) 392 | 393 | def _maybe_save(self, step): 394 | """ 395 | Save the model if a model saver is set 396 | """ 397 | if self.model_saver is not None: 398 | self.model_saver.maybe_save(step) 399 | -------------------------------------------------------------------------------- /src/others/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nlpyang/PreSumm/70b810e0f06d179022958dd35c1a3385fe87f28c/src/others/__init__.py -------------------------------------------------------------------------------- /src/others/logging.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from __future__ import absolute_import 3 | 4 | import logging 5 | 6 | logger = logging.getLogger() 7 | 8 | 9 | def init_logger(log_file=None, log_file_level=logging.NOTSET): 10 | log_format = logging.Formatter("[%(asctime)s %(levelname)s] %(message)s") 11 | logger = logging.getLogger() 12 | logger.setLevel(logging.INFO) 13 | 14 | console_handler = logging.StreamHandler() 15 | console_handler.setFormatter(log_format) 16 | logger.handlers = [console_handler] 17 | 18 | if log_file and log_file != '': 19 | file_handler = logging.FileHandler(log_file) 20 | file_handler.setLevel(log_file_level) 21 | file_handler.setFormatter(log_format) 22 | logger.addHandler(file_handler) 23 | 24 | return logger 25 | -------------------------------------------------------------------------------- /src/others/tokenization.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors and The HugginFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | """Tokenization classes.""" 16 | 17 | from __future__ import absolute_import, division, print_function, unicode_literals 18 | 19 | import collections 20 | import logging 21 | import os 22 | import unicodedata 23 | from io import open 24 | 25 | from pytorch_transformers import cached_path 26 | 27 | logger = logging.getLogger(__name__) 28 | 29 | PRETRAINED_VOCAB_ARCHIVE_MAP = { 30 | 'bert-base-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-vocab.txt", 31 | 'bert-large-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-vocab.txt", 32 | 'bert-base-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-vocab.txt", 33 | 'bert-large-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-vocab.txt", 34 | 'bert-base-multilingual-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-uncased-vocab.txt", 35 | 'bert-base-multilingual-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-cased-vocab.txt", 36 | 'bert-base-chinese': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese-vocab.txt", 37 | } 38 | PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP = { 39 | 'bert-base-uncased': 512, 40 | 'bert-large-uncased': 512, 41 | 'bert-base-cased': 512, 42 | 'bert-large-cased': 512, 43 | 'bert-base-multilingual-uncased': 512, 44 | 'bert-base-multilingual-cased': 512, 45 | 'bert-base-chinese': 512, 46 | } 47 | VOCAB_NAME = 'vocab.txt' 48 | 49 | 50 | def load_vocab(vocab_file): 51 | """Loads a vocabulary file into a dictionary.""" 52 | vocab = collections.OrderedDict() 53 | index = 0 54 | with open(vocab_file, "r", encoding="utf-8") as reader: 55 | while True: 56 | token = reader.readline() 57 | if not token: 58 | break 59 | token = token.strip() 60 | vocab[token] = index 61 | index += 1 62 | return vocab 63 | 64 | 65 | def whitespace_tokenize(text): 66 | """Runs basic whitespace cleaning and splitting on a peice of text.""" 67 | text = text.strip() 68 | if not text: 69 | return [] 70 | tokens = text.split() 71 | return tokens 72 | 73 | 74 | class BertTokenizer(object): 75 | """Runs end-to-end tokenization: punctuation splitting + wordpiece""" 76 | 77 | def __init__(self, vocab_file, do_lower_case=True, max_len=None, 78 | never_split=("[UNK]", "[SEP]", "[PAD]", "[CLS]", "[MASK]", "[unused0]", "[unused1]", "[unused2]", "[unused3]", "[unused4]", "[unused5]", "[unused6]")): 79 | 80 | if not os.path.isfile(vocab_file): 81 | raise ValueError( 82 | "Can't find a vocabulary file at path '{}'. 
To load the vocabulary from a Google pretrained " 83 | "model use `tokenizer = BertTokenizer.from_pretrained(PRETRAINED_MODEL_NAME)`".format(vocab_file)) 84 | self.do_lower_case = do_lower_case 85 | self.vocab = load_vocab(vocab_file) 86 | self.ids_to_tokens = collections.OrderedDict( 87 | [(ids, tok) for tok, ids in self.vocab.items()]) 88 | self.basic_tokenizer = BasicTokenizer(do_lower_case=do_lower_case, 89 | never_split=never_split) 90 | self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab) 91 | self.max_len = max_len if max_len is not None else int(1e12) 92 | 93 | def tokenize(self, text, use_bert_basic_tokenizer=False): 94 | split_tokens = [] 95 | if(use_bert_basic_tokenizer): 96 | pretokens = self.basic_tokenizer.tokenize(text) 97 | else: 98 | pretokens = list(enumerate(text.split())) 99 | 100 | for i,token in pretokens: 101 | # if(self.do_lower_case): 102 | # token = token.lower() 103 | subtokens = self.wordpiece_tokenizer.tokenize(token) 104 | for sub_token in subtokens: 105 | split_tokens.append(sub_token) 106 | return split_tokens 107 | 108 | def convert_tokens_to_ids(self, tokens): 109 | """Converts a sequence of tokens into ids using the vocab.""" 110 | ids = [] 111 | for token in tokens: 112 | ids.append(self.vocab[token]) 113 | # if len(ids) > self.max_len: 114 | # raise ValueError( 115 | # "Token indices sequence length is longer than the specified maximum " 116 | # " sequence length for this BERT model ({} > {}). Running this" 117 | # " sequence through BERT will result in indexing errors".format(len(ids), self.max_len) 118 | # ) 119 | return ids 120 | 121 | def convert_ids_to_tokens(self, ids): 122 | """Converts a sequence of ids in wordpiece tokens using the vocab.""" 123 | tokens = [] 124 | for i in ids: 125 | tokens.append(self.ids_to_tokens[i]) 126 | return tokens 127 | 128 | @classmethod 129 | def from_pretrained(cls, pretrained_model_name_or_path, cache_dir=None, *inputs, **kwargs): 130 | """ 131 | Instantiate a PreTrainedBertModel from a pre-trained model file. 132 | Download and cache the pre-trained model file if needed. 133 | """ 134 | if pretrained_model_name_or_path in PRETRAINED_VOCAB_ARCHIVE_MAP: 135 | vocab_file = PRETRAINED_VOCAB_ARCHIVE_MAP[pretrained_model_name_or_path] 136 | else: 137 | vocab_file = pretrained_model_name_or_path 138 | if os.path.isdir(vocab_file): 139 | vocab_file = os.path.join(vocab_file, VOCAB_NAME) 140 | # redirect to the cache, if necessary 141 | try: 142 | resolved_vocab_file = cached_path(vocab_file, cache_dir=cache_dir) 143 | except EnvironmentError: 144 | logger.error( 145 | "Model name '{}' was not found in model name list ({}). 
" 146 | "We assumed '{}' was a path or url but couldn't find any file " 147 | "associated to this path or url.".format( 148 | pretrained_model_name_or_path, 149 | ', '.join(PRETRAINED_VOCAB_ARCHIVE_MAP.keys()), 150 | vocab_file)) 151 | return None 152 | if resolved_vocab_file == vocab_file: 153 | logger.info("loading vocabulary file {}".format(vocab_file)) 154 | else: 155 | logger.info("loading vocabulary file {} from cache at {}".format( 156 | vocab_file, resolved_vocab_file)) 157 | if pretrained_model_name_or_path in PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP: 158 | # if we're using a pretrained model, ensure the tokenizer wont index sequences longer 159 | # than the number of positional embeddings 160 | max_len = PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP[pretrained_model_name_or_path] 161 | kwargs['max_len'] = min(kwargs.get('max_len', int(1e12)), max_len) 162 | # Instantiate tokenizer. 163 | tokenizer = cls(resolved_vocab_file, *inputs, **kwargs) 164 | return tokenizer 165 | 166 | 167 | class BasicTokenizer(object): 168 | """Runs basic tokenization (punctuation splitting, lower casing, etc.).""" 169 | 170 | def __init__(self, 171 | do_lower_case=True, 172 | never_split=("[UNK]", "[SEP]", "[PAD]", "[CLS]", "[MASK]")): 173 | """Constructs a BasicTokenizer. 174 | 175 | Args: 176 | do_lower_case: Whether to lower case the input. 177 | """ 178 | self.do_lower_case = do_lower_case 179 | self.never_split = never_split 180 | 181 | def tokenize(self, text): 182 | """Tokenizes a piece of text.""" 183 | text = self._clean_text(text) 184 | # This was added on November 1st, 2018 for the multilingual and Chinese 185 | # models. This is also applied to the English models now, but it doesn't 186 | # matter since the English models were not trained on any Chinese data 187 | # and generally don't have any Chinese data in them (there are Chinese 188 | # characters in the vocabulary because Wikipedia does have some Chinese 189 | # words in the English Wikipedia.). 
190 | text = self._tokenize_chinese_chars(text) 191 | orig_tokens = whitespace_tokenize(text) 192 | split_tokens = [] 193 | for i,token in enumerate(orig_tokens): 194 | if self.do_lower_case and token not in self.never_split: 195 | token = token.lower() 196 | token = self._run_strip_accents(token) 197 | # split_tokens.append(token) 198 | split_tokens.extend([(i,t) for t in self._run_split_on_punc(token)]) 199 | 200 | # output_tokens = whitespace_tokenize(" ".join(split_tokens)) 201 | return split_tokens 202 | 203 | def _run_strip_accents(self, text): 204 | """Strips accents from a piece of text.""" 205 | text = unicodedata.normalize("NFD", text) 206 | output = [] 207 | for char in text: 208 | cat = unicodedata.category(char) 209 | if cat == "Mn": 210 | continue 211 | output.append(char) 212 | return "".join(output) 213 | 214 | def _run_split_on_punc(self, text): 215 | """Splits punctuation on a piece of text.""" 216 | if text in self.never_split: 217 | return [text] 218 | chars = list(text) 219 | i = 0 220 | start_new_word = True 221 | output = [] 222 | while i < len(chars): 223 | char = chars[i] 224 | if _is_punctuation(char): 225 | output.append([char]) 226 | start_new_word = True 227 | else: 228 | if start_new_word: 229 | output.append([]) 230 | start_new_word = False 231 | output[-1].append(char) 232 | i += 1 233 | 234 | return ["".join(x) for x in output] 235 | 236 | def _tokenize_chinese_chars(self, text): 237 | """Adds whitespace around any CJK character.""" 238 | output = [] 239 | for char in text: 240 | cp = ord(char) 241 | if self._is_chinese_char(cp): 242 | output.append(" ") 243 | output.append(char) 244 | output.append(" ") 245 | else: 246 | output.append(char) 247 | return "".join(output) 248 | 249 | def _is_chinese_char(self, cp): 250 | """Checks whether CP is the codepoint of a CJK character.""" 251 | # This defines a "chinese character" as anything in the CJK Unicode block: 252 | # https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block) 253 | # 254 | # Note that the CJK Unicode block is NOT all Japanese and Korean characters, 255 | # despite its name. The modern Korean Hangul alphabet is a different block, 256 | # as is Japanese Hiragana and Katakana. Those alphabets are used to write 257 | # space-separated words, so they are not treated specially and handled 258 | # like the all of the other languages. 259 | if ((cp >= 0x4E00 and cp <= 0x9FFF) or # 260 | (cp >= 0x3400 and cp <= 0x4DBF) or # 261 | (cp >= 0x20000 and cp <= 0x2A6DF) or # 262 | (cp >= 0x2A700 and cp <= 0x2B73F) or # 263 | (cp >= 0x2B740 and cp <= 0x2B81F) or # 264 | (cp >= 0x2B820 and cp <= 0x2CEAF) or 265 | (cp >= 0xF900 and cp <= 0xFAFF) or # 266 | (cp >= 0x2F800 and cp <= 0x2FA1F)): # 267 | return True 268 | 269 | return False 270 | 271 | def _clean_text(self, text): 272 | """Performs invalid character removal and whitespace cleanup on text.""" 273 | output = [] 274 | for char in text: 275 | cp = ord(char) 276 | if cp == 0 or cp == 0xfffd or _is_control(char): 277 | continue 278 | if _is_whitespace(char): 279 | output.append(" ") 280 | else: 281 | output.append(char) 282 | return "".join(output) 283 | 284 | 285 | class WordpieceTokenizer(object): 286 | """Runs WordPiece tokenization.""" 287 | 288 | def __init__(self, vocab, unk_token="[UNK]", max_input_chars_per_word=100): 289 | self.vocab = vocab 290 | self.unk_token = unk_token 291 | self.max_input_chars_per_word = max_input_chars_per_word 292 | 293 | def tokenize(self, text): 294 | """Tokenizes a piece of text into its word pieces. 
295 | 296 | This uses a greedy longest-match-first algorithm to perform tokenization 297 | using the given vocabulary. 298 | 299 | For example: 300 | input = "unaffable" 301 | output = ["un", "##aff", "##able"] 302 | 303 | Args: 304 | text: A single token or whitespace separated tokens. This should have 305 | already been passed through `BasicTokenizer`. 306 | 307 | Returns: 308 | A list of wordpiece tokens. 309 | """ 310 | 311 | output_tokens = [] 312 | for token in whitespace_tokenize(text): 313 | chars = list(token) 314 | if len(chars) > self.max_input_chars_per_word: 315 | output_tokens.append(self.unk_token) 316 | continue 317 | 318 | is_bad = False 319 | start = 0 320 | sub_tokens = [] 321 | while start < len(chars): 322 | end = len(chars) 323 | cur_substr = None 324 | while start < end: 325 | substr = "".join(chars[start:end]) 326 | if start > 0: 327 | substr = "##" + substr 328 | if substr in self.vocab: 329 | cur_substr = substr 330 | break 331 | end -= 1 332 | if cur_substr is None: 333 | is_bad = True 334 | break 335 | sub_tokens.append(cur_substr) 336 | start = end 337 | 338 | if is_bad: 339 | output_tokens.append(self.unk_token) 340 | else: 341 | output_tokens.extend(sub_tokens) 342 | return output_tokens 343 | 344 | 345 | def _is_whitespace(char): 346 | """Checks whether `chars` is a whitespace character.""" 347 | # \t, \n, and \r are technically contorl characters but we treat them 348 | # as whitespace since they are generally considered as such. 349 | if char == " " or char == "\t" or char == "\n" or char == "\r": 350 | return True 351 | cat = unicodedata.category(char) 352 | if cat == "Zs": 353 | return True 354 | return False 355 | 356 | 357 | def _is_control(char): 358 | """Checks whether `chars` is a control character.""" 359 | # These are technically control characters but we count them as whitespace 360 | # characters. 361 | if char == "\t" or char == "\n" or char == "\r": 362 | return False 363 | cat = unicodedata.category(char) 364 | if cat.startswith("C"): 365 | return True 366 | return False 367 | 368 | 369 | def _is_punctuation(char): 370 | """Checks whether `chars` is a punctuation character.""" 371 | cp = ord(char) 372 | # We treat all non-letter/number ASCII as punctuation. 373 | # Characters such as "^", "$", and "`" are not in the Unicode 374 | # Punctuation class but we treat them as punctuation anyways, for 375 | # consistency. 
376 | if ((cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64) or 377 | (cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126)): 378 | return True 379 | cat = unicodedata.category(char) 380 | if cat.startswith("P"): 381 | return True 382 | return False 383 | -------------------------------------------------------------------------------- /src/others/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import re 3 | import shutil 4 | import time 5 | 6 | from others import pyrouge 7 | 8 | REMAP = {"-lrb-": "(", "-rrb-": ")", "-lcb-": "{", "-rcb-": "}", 9 | "-lsb-": "[", "-rsb-": "]", "``": '"', "''": '"'} 10 | 11 | 12 | def clean(x): 13 | return re.sub( 14 | r"-lrb-|-rrb-|-lcb-|-rcb-|-lsb-|-rsb-|``|''", 15 | lambda m: REMAP.get(m.group()), x) 16 | 17 | 18 | def process(params): 19 | temp_dir, data = params 20 | candidates, references, pool_id = data 21 | cnt = len(candidates) 22 | current_time = time.strftime('%Y-%m-%d-%H-%M-%S', time.localtime()) 23 | tmp_dir = os.path.join(temp_dir, "rouge-tmp-{}-{}".format(current_time, pool_id)) 24 | if not os.path.isdir(tmp_dir): 25 | os.mkdir(tmp_dir) 26 | os.mkdir(tmp_dir + "/candidate") 27 | os.mkdir(tmp_dir + "/reference") 28 | try: 29 | 30 | for i in range(cnt): 31 | if len(references[i]) < 1: 32 | continue 33 | with open(tmp_dir + "/candidate/cand.{}.txt".format(i), "w", 34 | encoding="utf-8") as f: 35 | f.write(candidates[i]) 36 | with open(tmp_dir + "/reference/ref.{}.txt".format(i), "w", 37 | encoding="utf-8") as f: 38 | f.write(references[i]) 39 | r = pyrouge.Rouge155(temp_dir=temp_dir) 40 | r.model_dir = tmp_dir + "/reference/" 41 | r.system_dir = tmp_dir + "/candidate/" 42 | r.model_filename_pattern = 'ref.#ID#.txt' 43 | r.system_filename_pattern = r'cand.(\d+).txt' 44 | rouge_results = r.convert_and_evaluate() 45 | print(rouge_results) 46 | results_dict = r.output_to_dict(rouge_results) 47 | finally: 48 | pass 49 | if os.path.isdir(tmp_dir): 50 | shutil.rmtree(tmp_dir) 51 | return results_dict 52 | 53 | 54 | def test_rouge(temp_dir, cand, ref): 55 | candidates = [line.strip() for line in open(cand, encoding='utf-8')] 56 | references = [line.strip() for line in open(ref, encoding='utf-8')] 57 | print(len(candidates)) 58 | print(len(references)) 59 | assert len(candidates) == len(references) 60 | 61 | cnt = len(candidates) 62 | current_time = time.strftime('%Y-%m-%d-%H-%M-%S', time.localtime()) 63 | tmp_dir = os.path.join(temp_dir, "rouge-tmp-{}".format(current_time)) 64 | if not os.path.isdir(tmp_dir): 65 | os.mkdir(tmp_dir) 66 | os.mkdir(tmp_dir + "/candidate") 67 | os.mkdir(tmp_dir + "/reference") 68 | try: 69 | 70 | for i in range(cnt): 71 | if len(references[i]) < 1: 72 | continue 73 | with open(tmp_dir + "/candidate/cand.{}.txt".format(i), "w", 74 | encoding="utf-8") as f: 75 | f.write(candidates[i]) 76 | with open(tmp_dir + "/reference/ref.{}.txt".format(i), "w", 77 | encoding="utf-8") as f: 78 | f.write(references[i]) 79 | r = pyrouge.Rouge155(temp_dir=temp_dir) 80 | r.model_dir = tmp_dir + "/reference/" 81 | r.system_dir = tmp_dir + "/candidate/" 82 | r.model_filename_pattern = 'ref.#ID#.txt' 83 | r.system_filename_pattern = r'cand.(\d+).txt' 84 | rouge_results = r.convert_and_evaluate() 85 | print(rouge_results) 86 | results_dict = r.output_to_dict(rouge_results) 87 | finally: 88 | pass 89 | if os.path.isdir(tmp_dir): 90 | shutil.rmtree(tmp_dir) 91 | return results_dict 92 | 93 | 94 | def tile(x, count, dim=0): 95 | """ 96 | Tiles x on dimension dim count times. 
97 | """ 98 | perm = list(range(len(x.size()))) 99 | if dim != 0: 100 | perm[0], perm[dim] = perm[dim], perm[0] 101 | x = x.permute(perm).contiguous() 102 | out_size = list(x.size()) 103 | out_size[0] *= count 104 | batch = x.size(0) 105 | x = x.view(batch, -1) \ 106 | .transpose(0, 1) \ 107 | .repeat(count, 1) \ 108 | .transpose(0, 1) \ 109 | .contiguous() \ 110 | .view(*out_size) 111 | if dim != 0: 112 | x = x.permute(perm).contiguous() 113 | return x 114 | 115 | def rouge_results_to_str(results_dict): 116 | return ">> ROUGE-F(1/2/3/l): {:.2f}/{:.2f}/{:.2f}\nROUGE-R(1/2/3/l): {:.2f}/{:.2f}/{:.2f}\n".format( 117 | results_dict["rouge_1_f_score"] * 100, 118 | results_dict["rouge_2_f_score"] * 100, 119 | # results_dict["rouge_3_f_score"] * 100, 120 | results_dict["rouge_l_f_score"] * 100, 121 | results_dict["rouge_1_recall"] * 100, 122 | results_dict["rouge_2_recall"] * 100, 123 | # results_dict["rouge_3_f_score"] * 100, 124 | results_dict["rouge_l_recall"] * 100 125 | 126 | # ,results_dict["rouge_su*_f_score"] * 100 127 | ) 128 | -------------------------------------------------------------------------------- /src/post_stats.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from os import path 3 | from functools import reduce 4 | import re 5 | 6 | def str2bool(v): 7 | if v.lower() in ('yes', 'true', 't', 'y', '1'): 8 | return True 9 | elif v.lower() in ('no', 'false', 'f', 'n', '0'): 10 | return False 11 | else: 12 | raise argparse.ArgumentTypeError('Boolean value expected.') 13 | 14 | 15 | 16 | def n_grams(tokens, n): 17 | l = len(tokens) 18 | return [tuple(tokens[i:i + n]) for i in range(l) if i + n < l] 19 | 20 | def has_repeat(elements): 21 | d = set(elements) 22 | return len(d) < len(elements) 23 | 24 | def cal_self_repeat(summary): 25 | ngram_repeats = {2: 0, 4: 0, 8: 0} 26 | sents = summary.split('<q>')  # generated summaries separate sentences with '<q>' 27 | for n in ngram_repeats.keys(): 28 | # Respect sentence boundary 29 | grams = reduce(lambda x, y: x + y, [n_grams(sent.split(), n) for sent in sents], []) 30 | ngram_repeats[n] += has_repeat(grams) 31 | return ngram_repeats 32 | 33 | def cal_novel(summary, gold, source, summary_ngram_novel, gold_ngram_novel): 34 | summary = summary.replace('<q>',' ') 35 | summary = re.sub(r' +', ' ', summary).strip() 36 | gold = gold.replace('<q>',' ') 37 | gold = re.sub(r' +', ' ', gold).strip() 38 | source = source.replace(' ##','') 39 | source = source.replace('[CLS]',' ').replace('[SEP]',' ').replace('[PAD]',' ') 40 | source = re.sub(r' +', ' ', source).strip() 41 | 42 | 43 | for n in summary_ngram_novel.keys(): 44 | summary_grams = set(n_grams(summary.split(), n)) 45 | gold_grams = set(n_grams(gold.split(), n)) 46 | source_grams = set(n_grams(source.split(), n)) 47 | joint = summary_grams.intersection(source_grams) 48 | novel = summary_grams - joint 49 | summary_ngram_novel[n][0] += 1.0*len(novel) 50 | summary_ngram_novel[n][1] += len(summary_grams) 51 | summary_ngram_novel[n][2] += 1.0 * len(novel) / (len(summary.split()) + 1e-6) 52 | joint = gold_grams.intersection(source_grams) 53 | novel = gold_grams - joint 54 | gold_ngram_novel[n][0] += 1.0*len(novel) 55 | gold_ngram_novel[n][1] += len(gold_grams) 56 | gold_ngram_novel[n][2] += 1.0 * len(novel) / (len(gold.split()) + 1e-6) 57 | 58 | 59 | def cal_repeat(args): 60 | candidate_lines = open(args.result_path+'.candidate').read().strip().split('\n') 61 | gold_lines = open(args.result_path+'.gold').read().strip().split('\n') 62 | src_lines =
open(args.result_path+'.raw_src').read().strip().split('\n') 63 | lines = zip(candidate_lines,gold_lines,src_lines) 64 | 65 | summary_ngram_novel = {1: [0, 0, 0], 2: [0, 0, 0], 4: [0, 0, 0]} 66 | gold_ngram_novel = {1: [0, 0, 0], 2: [0, 0, 0], 4: [0, 0, 0]} 67 | 68 | for c,g,s in lines: 69 | # self_repeats = cal_self_repeat(c) 70 | cal_novel(c, g, s,summary_ngram_novel, gold_ngram_novel) 71 | print(summary_ngram_novel, gold_ngram_novel) 72 | 73 | for n in summary_ngram_novel.keys(): 74 | # summary_ngram_novel[n] = summary_ngram_novel[n][2]/len(src_lines) 75 | # gold_ngram_novel[n] = gold_ngram_novel[n][2]/len(src_lines) 76 | summary_ngram_novel[n] = summary_ngram_novel[n][0]/summary_ngram_novel[n][1] 77 | gold_ngram_novel[n] = gold_ngram_novel[n][0]/gold_ngram_novel[n][1] 78 | print(summary_ngram_novel, gold_ngram_novel) 79 | 80 | 81 | if __name__ == '__main__': 82 | parser = argparse.ArgumentParser() 83 | parser.add_argument("-mode", default='', type=str) 84 | parser.add_argument("-result_path", default='../../results/cnndm.0') 85 | 86 | 87 | args = parser.parse_args() 88 | eval(args.mode + '(args)') 89 | -------------------------------------------------------------------------------- /src/prepro/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nlpyang/PreSumm/70b810e0f06d179022958dd35c1a3385fe87f28c/src/prepro/__init__.py -------------------------------------------------------------------------------- /src/prepro/data_builder.py: -------------------------------------------------------------------------------- 1 | import gc 2 | import glob 3 | import hashlib 4 | import itertools 5 | import json 6 | import os 7 | import random 8 | import re 9 | import subprocess 10 | from collections import Counter 11 | from os.path import join as pjoin 12 | 13 | import torch 14 | from multiprocess import Pool 15 | 16 | from others.logging import logger 17 | from others.tokenization import BertTokenizer 18 | from pytorch_transformers import XLNetTokenizer 19 | 20 | from others.utils import clean 21 | from prepro.utils import _get_word_ngrams 22 | 23 | import xml.etree.ElementTree as ET 24 | 25 | nyt_remove_words = ["photo", "graph", "chart", "map", "table", "drawing"] 26 | 27 | 28 | def recover_from_corenlp(s): 29 | s = re.sub(r' \'{\w}', '\'\g<1>', s) 30 | s = re.sub(r'\'\' {\w}', '\'\'\g<1>', s) 31 | 32 | 33 | 34 | def load_json(p, lower): 35 | source = [] 36 | tgt = [] 37 | flag = False 38 | for sent in json.load(open(p))['sentences']: 39 | tokens = [t['word'] for t in sent['tokens']] 40 | if (lower): 41 | tokens = [t.lower() for t in tokens] 42 | if (tokens[0] == '@highlight'): 43 | flag = True 44 | tgt.append([]) 45 | continue 46 | if (flag): 47 | tgt[-1].extend(tokens) 48 | else: 49 | source.append(tokens) 50 | 51 | source = [clean(' '.join(sent)).split() for sent in source] 52 | tgt = [clean(' '.join(sent)).split() for sent in tgt] 53 | return source, tgt 54 | 55 | 56 | 57 | def load_xml(p): 58 | tree = ET.parse(p) 59 | root = tree.getroot() 60 | title, byline, abs, paras = [], [], [], [] 61 | title_node = list(root.iter('hedline')) 62 | if (len(title_node) > 0): 63 | try: 64 | title = [p.text.lower().split() for p in list(title_node[0].iter('hl1'))][0] 65 | except: 66 | print(p) 67 | 68 | else: 69 | return None, None 70 | byline_node = list(root.iter('byline')) 71 | byline_node = [n for n in byline_node if n.attrib['class'] == 'normalized_byline'] 72 | if (len(byline_node) > 0): 73 | byline = 
byline_node[0].text.lower().split() 74 | abs_node = list(root.iter('abstract')) 75 | if (len(abs_node) > 0): 76 | try: 77 | abs = [p.text.lower().split() for p in list(abs_node[0].iter('p'))][0] 78 | except: 79 | print(p) 80 | 81 | else: 82 | return None, None 83 | abs = ' '.join(abs).split(';') 84 | abs[-1] = abs[-1].replace('(m)', '') 85 | abs[-1] = abs[-1].replace('(s)', '') 86 | 87 | for ww in nyt_remove_words: 88 | abs[-1] = abs[-1].replace('(' + ww + ')', '') 89 | abs = [p.split() for p in abs] 90 | abs = [p for p in abs if len(p) > 2] 91 | 92 | for doc_node in root.iter('block'): 93 | att = doc_node.get('class') 94 | # if(att == 'abstract'): 95 | # abs = [p.text for p in list(f.iter('p'))] 96 | if (att == 'full_text'): 97 | paras = [p.text.lower().split() for p in list(doc_node.iter('p'))] 98 | break 99 | if (len(paras) > 0): 100 | if (len(byline) > 0): 101 | paras = [title + ['[unused3]'] + byline + ['[unused4]']] + paras 102 | else: 103 | paras = [title + ['[unused3]']] + paras 104 | 105 | return paras, abs 106 | else: 107 | return None, None 108 | 109 | 110 | def tokenize(args): 111 | stories_dir = os.path.abspath(args.raw_path) 112 | tokenized_stories_dir = os.path.abspath(args.save_path) 113 | 114 | print("Preparing to tokenize %s to %s..." % (stories_dir, tokenized_stories_dir)) 115 | stories = os.listdir(stories_dir) 116 | # make IO list file 117 | print("Making list of files to tokenize...") 118 | with open("mapping_for_corenlp.txt", "w") as f: 119 | for s in stories: 120 | if (not s.endswith('story')): 121 | continue 122 | f.write("%s\n" % (os.path.join(stories_dir, s))) 123 | command = ['java', 'edu.stanford.nlp.pipeline.StanfordCoreNLP', '-annotators', 'tokenize,ssplit', 124 | '-ssplit.newlineIsSentenceBreak', 'always', '-filelist', 'mapping_for_corenlp.txt', '-outputFormat', 125 | 'json', '-outputDirectory', tokenized_stories_dir] 126 | print("Tokenizing %i files in %s and saving in %s..." % (len(stories), stories_dir, tokenized_stories_dir)) 127 | subprocess.call(command) 128 | print("Stanford CoreNLP Tokenizer has finished.") 129 | os.remove("mapping_for_corenlp.txt") 130 | 131 | # Check that the tokenized stories directory contains the same number of files as the original directory 132 | num_orig = len(os.listdir(stories_dir)) 133 | num_tokenized = len(os.listdir(tokenized_stories_dir)) 134 | if num_orig != num_tokenized: 135 | raise Exception( 136 | "The tokenized stories directory %s contains %i files, but it should contain the same number as %s (which has %i files). Was there an error during tokenization?" 
% ( 137 | tokenized_stories_dir, num_tokenized, stories_dir, num_orig)) 138 | print("Successfully finished tokenizing %s to %s.\n" % (stories_dir, tokenized_stories_dir)) 139 | 140 | def cal_rouge(evaluated_ngrams, reference_ngrams): 141 | reference_count = len(reference_ngrams) 142 | evaluated_count = len(evaluated_ngrams) 143 | 144 | overlapping_ngrams = evaluated_ngrams.intersection(reference_ngrams) 145 | overlapping_count = len(overlapping_ngrams) 146 | 147 | if evaluated_count == 0: 148 | precision = 0.0 149 | else: 150 | precision = overlapping_count / evaluated_count 151 | 152 | if reference_count == 0: 153 | recall = 0.0 154 | else: 155 | recall = overlapping_count / reference_count 156 | 157 | f1_score = 2.0 * ((precision * recall) / (precision + recall + 1e-8)) 158 | return {"f": f1_score, "p": precision, "r": recall} 159 | 160 | 161 | def greedy_selection(doc_sent_list, abstract_sent_list, summary_size): 162 | def _rouge_clean(s): 163 | return re.sub(r'[^a-zA-Z0-9 ]', '', s) 164 | 165 | max_rouge = 0.0 166 | abstract = sum(abstract_sent_list, []) 167 | abstract = _rouge_clean(' '.join(abstract)).split() 168 | sents = [_rouge_clean(' '.join(s)).split() for s in doc_sent_list] 169 | evaluated_1grams = [_get_word_ngrams(1, [sent]) for sent in sents] 170 | reference_1grams = _get_word_ngrams(1, [abstract]) 171 | evaluated_2grams = [_get_word_ngrams(2, [sent]) for sent in sents] 172 | reference_2grams = _get_word_ngrams(2, [abstract]) 173 | 174 | selected = [] 175 | for s in range(summary_size): 176 | cur_max_rouge = max_rouge 177 | cur_id = -1 178 | for i in range(len(sents)): 179 | if (i in selected): 180 | continue 181 | c = selected + [i] 182 | candidates_1 = [evaluated_1grams[idx] for idx in c] 183 | candidates_1 = set.union(*map(set, candidates_1)) 184 | candidates_2 = [evaluated_2grams[idx] for idx in c] 185 | candidates_2 = set.union(*map(set, candidates_2)) 186 | rouge_1 = cal_rouge(candidates_1, reference_1grams)['f'] 187 | rouge_2 = cal_rouge(candidates_2, reference_2grams)['f'] 188 | rouge_score = rouge_1 + rouge_2 189 | if rouge_score > cur_max_rouge: 190 | cur_max_rouge = rouge_score 191 | cur_id = i 192 | if (cur_id == -1): 193 | return selected 194 | selected.append(cur_id) 195 | max_rouge = cur_max_rouge 196 | 197 | return sorted(selected) 198 | 199 | 200 | def hashhex(s): 201 | """Returns a heximal formated SHA1 hash of the input string.""" 202 | h = hashlib.sha1() 203 | h.update(s.encode('utf-8')) 204 | return h.hexdigest() 205 | 206 | 207 | class BertData(): 208 | def __init__(self, args): 209 | self.args = args 210 | self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased', do_lower_case=True) 211 | 212 | self.sep_token = '[SEP]' 213 | self.cls_token = '[CLS]' 214 | self.pad_token = '[PAD]' 215 | self.tgt_bos = '[unused0]' 216 | self.tgt_eos = '[unused1]' 217 | self.tgt_sent_split = '[unused2]' 218 | self.sep_vid = self.tokenizer.vocab[self.sep_token] 219 | self.cls_vid = self.tokenizer.vocab[self.cls_token] 220 | self.pad_vid = self.tokenizer.vocab[self.pad_token] 221 | 222 | def preprocess(self, src, tgt, sent_labels, use_bert_basic_tokenizer=False, is_test=False): 223 | 224 | if ((not is_test) and len(src) == 0): 225 | return None 226 | 227 | original_src_txt = [' '.join(s) for s in src] 228 | 229 | idxs = [i for i, s in enumerate(src) if (len(s) > self.args.min_src_ntokens_per_sent)] 230 | 231 | _sent_labels = [0] * len(src) 232 | for l in sent_labels: 233 | _sent_labels[l] = 1 234 | 235 | src = [src[i][:self.args.max_src_ntokens_per_sent] for i in 
idxs] 236 | sent_labels = [_sent_labels[i] for i in idxs] 237 | src = src[:self.args.max_src_nsents] 238 | sent_labels = sent_labels[:self.args.max_src_nsents] 239 | 240 | if ((not is_test) and len(src) < self.args.min_src_nsents): 241 | return None 242 | 243 | src_txt = [' '.join(sent) for sent in src] 244 | text = ' {} {} '.format(self.sep_token, self.cls_token).join(src_txt) 245 | 246 | src_subtokens = self.tokenizer.tokenize(text) 247 | 248 | src_subtokens = [self.cls_token] + src_subtokens + [self.sep_token] 249 | src_subtoken_idxs = self.tokenizer.convert_tokens_to_ids(src_subtokens) 250 | _segs = [-1] + [i for i, t in enumerate(src_subtoken_idxs) if t == self.sep_vid] 251 | segs = [_segs[i] - _segs[i - 1] for i in range(1, len(_segs))] 252 | segments_ids = [] 253 | for i, s in enumerate(segs): 254 | if (i % 2 == 0): 255 | segments_ids += s * [0] 256 | else: 257 | segments_ids += s * [1] 258 | cls_ids = [i for i, t in enumerate(src_subtoken_idxs) if t == self.cls_vid] 259 | sent_labels = sent_labels[:len(cls_ids)] 260 | 261 | tgt_subtokens_str = '[unused0] ' + ' [unused2] '.join( 262 | [' '.join(self.tokenizer.tokenize(' '.join(tt), use_bert_basic_tokenizer=use_bert_basic_tokenizer)) for tt in tgt]) + ' [unused1]' 263 | tgt_subtoken = tgt_subtokens_str.split()[:self.args.max_tgt_ntokens] 264 | if ((not is_test) and len(tgt_subtoken) < self.args.min_tgt_ntokens): 265 | return None 266 | 267 | tgt_subtoken_idxs = self.tokenizer.convert_tokens_to_ids(tgt_subtoken) 268 | 269 | tgt_txt = '<q>'.join([' '.join(tt) for tt in tgt])  # gold sentences are joined with '<q>' 270 | src_txt = [original_src_txt[i] for i in idxs] 271 | 272 | return src_subtoken_idxs, sent_labels, tgt_subtoken_idxs, segments_ids, cls_ids, src_txt, tgt_txt 273 | 274 | 275 | def format_to_bert(args): 276 | if (args.dataset != ''): 277 | datasets = [args.dataset] 278 | else: 279 | datasets = ['train', 'valid', 'test'] 280 | for corpus_type in datasets: 281 | a_lst = [] 282 | for json_f in glob.glob(pjoin(args.raw_path, '*' + corpus_type + '.*.json')): 283 | real_name = json_f.split('/')[-1] 284 | a_lst.append((corpus_type, json_f, args, pjoin(args.save_path, real_name.replace('json', 'bert.pt')))) 285 | print(a_lst) 286 | pool = Pool(args.n_cpus) 287 | for d in pool.imap(_format_to_bert, a_lst): 288 | pass 289 | 290 | pool.close() 291 | pool.join() 292 | 293 | 294 | def _format_to_bert(params): 295 | corpus_type, json_file, args, save_file = params 296 | is_test = corpus_type == 'test' 297 | if (os.path.exists(save_file)): 298 | logger.info('Ignore %s' % save_file) 299 | return 300 | 301 | bert = BertData(args) 302 | 303 | logger.info('Processing %s' % json_file) 304 | jobs = json.load(open(json_file)) 305 | datasets = [] 306 | for d in jobs: 307 | source, tgt = d['src'], d['tgt'] 308 | 309 | sent_labels = greedy_selection(source[:args.max_src_nsents], tgt, 3) 310 | if (args.lower): 311 | source = [' '.join(s).lower().split() for s in source] 312 | tgt = [' '.join(s).lower().split() for s in tgt] 313 | b_data = bert.preprocess(source, tgt, sent_labels, use_bert_basic_tokenizer=args.use_bert_basic_tokenizer, 314 | is_test=is_test) 315 | # b_data = bert.preprocess(source, tgt, sent_labels, use_bert_basic_tokenizer=args.use_bert_basic_tokenizer) 316 | 317 | if (b_data is None): 318 | continue 319 | src_subtoken_idxs, sent_labels, tgt_subtoken_idxs, segments_ids, cls_ids, src_txt, tgt_txt = b_data 320 | b_data_dict = {"src": src_subtoken_idxs, "tgt": tgt_subtoken_idxs, 321 | "src_sent_labels": sent_labels, "segs": segments_ids, 'clss': cls_ids, 322 | 'src_txt':
src_txt, "tgt_txt": tgt_txt} 323 | datasets.append(b_data_dict) 324 | logger.info('Processed instances %d' % len(datasets)) 325 | logger.info('Saving to %s' % save_file) 326 | torch.save(datasets, save_file) 327 | datasets = [] 328 | gc.collect() 329 | 330 | 331 | def format_to_lines(args): 332 | corpus_mapping = {} 333 | for corpus_type in ['valid', 'test', 'train']: 334 | temp = [] 335 | for line in open(pjoin(args.map_path, 'mapping_' + corpus_type + '.txt')): 336 | temp.append(hashhex(line.strip())) 337 | corpus_mapping[corpus_type] = {key.strip(): 1 for key in temp} 338 | train_files, valid_files, test_files = [], [], [] 339 | for f in glob.glob(pjoin(args.raw_path, '*.json')): 340 | real_name = f.split('/')[-1].split('.')[0] 341 | if (real_name in corpus_mapping['valid']): 342 | valid_files.append(f) 343 | elif (real_name in corpus_mapping['test']): 344 | test_files.append(f) 345 | elif (real_name in corpus_mapping['train']): 346 | train_files.append(f) 347 | # else: 348 | # train_files.append(f) 349 | 350 | corpora = {'train': train_files, 'valid': valid_files, 'test': test_files} 351 | for corpus_type in ['train', 'valid', 'test']: 352 | a_lst = [(f, args) for f in corpora[corpus_type]] 353 | pool = Pool(args.n_cpus) 354 | dataset = [] 355 | p_ct = 0 356 | for d in pool.imap_unordered(_format_to_lines, a_lst): 357 | dataset.append(d) 358 | if (len(dataset) > args.shard_size): 359 | pt_file = "{:s}.{:s}.{:d}.json".format(args.save_path, corpus_type, p_ct) 360 | with open(pt_file, 'w') as save: 361 | # save.write('\n'.join(dataset)) 362 | save.write(json.dumps(dataset)) 363 | p_ct += 1 364 | dataset = [] 365 | 366 | pool.close() 367 | pool.join() 368 | if (len(dataset) > 0): 369 | pt_file = "{:s}.{:s}.{:d}.json".format(args.save_path, corpus_type, p_ct) 370 | with open(pt_file, 'w') as save: 371 | # save.write('\n'.join(dataset)) 372 | save.write(json.dumps(dataset)) 373 | p_ct += 1 374 | dataset = [] 375 | 376 | 377 | def _format_to_lines(params): 378 | f, args = params 379 | print(f) 380 | source, tgt = load_json(f, args.lower) 381 | return {'src': source, 'tgt': tgt} 382 | 383 | 384 | 385 | 386 | def format_xsum_to_lines(args): 387 | if (args.dataset != ''): 388 | datasets = [args.dataset] 389 | else: 390 | datasets = ['train', 'test', 'valid'] 391 | 392 | corpus_mapping = json.load(open(pjoin(args.raw_path, 'XSum-TRAINING-DEV-TEST-SPLIT-90-5-5.json'))) 393 | 394 | for corpus_type in datasets: 395 | mapped_fnames = corpus_mapping[corpus_type] 396 | root_src = pjoin(args.raw_path, 'restbody') 397 | root_tgt = pjoin(args.raw_path, 'firstsentence') 398 | # realnames = [fname.split('.')[0] for fname in os.listdir(root_src)] 399 | realnames = mapped_fnames 400 | 401 | a_lst = [(root_src, root_tgt, n) for n in realnames] 402 | pool = Pool(args.n_cpus) 403 | dataset = [] 404 | p_ct = 0 405 | for d in pool.imap_unordered(_format_xsum_to_lines, a_lst): 406 | if (d is None): 407 | continue 408 | dataset.append(d) 409 | if (len(dataset) > args.shard_size): 410 | pt_file = "{:s}.{:s}.{:d}.json".format(args.save_path, corpus_type, p_ct) 411 | with open(pt_file, 'w') as save: 412 | save.write(json.dumps(dataset)) 413 | p_ct += 1 414 | dataset = [] 415 | 416 | pool.close() 417 | pool.join() 418 | if (len(dataset) > 0): 419 | pt_file = "{:s}.{:s}.{:d}.json".format(args.save_path, corpus_type, p_ct) 420 | with open(pt_file, 'w') as save: 421 | save.write(json.dumps(dataset)) 422 | p_ct += 1 423 | dataset = [] 424 | 425 | 426 | def _format_xsum_to_lines(params): 427 | src_path, root_tgt, name = 
params 428 | f_src = pjoin(src_path, name + '.restbody') 429 | f_tgt = pjoin(root_tgt, name + '.fs') 430 | if (os.path.exists(f_src) and os.path.exists(f_tgt)): 431 | print(name) 432 | source = [] 433 | for sent in open(f_src): 434 | source.append(sent.split()) 435 | tgt = [] 436 | for sent in open(f_tgt): 437 | tgt.append(sent.split()) 438 | return {'src': source, 'tgt': tgt} 439 | return None 440 | -------------------------------------------------------------------------------- /src/prepro/smart_common_words.txt: -------------------------------------------------------------------------------- 1 | rrb 2 | llb 3 | lsb 4 | rsb 5 | reuters 6 | ap 7 | jan 8 | feb 9 | mar 10 | apr 11 | may 12 | jun 13 | jul 14 | aug 15 | sep 16 | oct 17 | nov 18 | dec 19 | tech 20 | news 21 | index 22 | mon 23 | tue 24 | wed 25 | thu 26 | fri 27 | sat 28 | 's 29 | a 30 | a's 31 | able 32 | about 33 | above 34 | according 35 | accordingly 36 | across 37 | actually 38 | after 39 | afterwards 40 | again 41 | against 42 | ain't 43 | all 44 | allow 45 | allows 46 | almost 47 | alone 48 | along 49 | already 50 | also 51 | although 52 | always 53 | am 54 | amid 55 | among 56 | amongst 57 | an 58 | and 59 | another 60 | any 61 | anybody 62 | anyhow 63 | anyone 64 | anything 65 | anyway 66 | anyways 67 | anywhere 68 | apart 69 | appear 70 | appreciate 71 | appropriate 72 | are 73 | aren't 74 | around 75 | as 76 | aside 77 | ask 78 | asking 79 | associated 80 | at 81 | available 82 | away 83 | awfully 84 | b 85 | be 86 | became 87 | because 88 | become 89 | becomes 90 | becoming 91 | been 92 | before 93 | beforehand 94 | behind 95 | being 96 | believe 97 | below 98 | beside 99 | besides 100 | best 101 | better 102 | between 103 | beyond 104 | both 105 | brief 106 | but 107 | by 108 | c 109 | c'mon 110 | c's 111 | came 112 | can 113 | can't 114 | cannot 115 | cant 116 | cause 117 | causes 118 | certain 119 | certainly 120 | changes 121 | clearly 122 | co 123 | com 124 | come 125 | comes 126 | concerning 127 | consequently 128 | consider 129 | considering 130 | contain 131 | containing 132 | contains 133 | corresponding 134 | could 135 | couldn't 136 | course 137 | currently 138 | d 139 | definitely 140 | described 141 | despite 142 | did 143 | didn't 144 | different 145 | do 146 | does 147 | doesn't 148 | doing 149 | don't 150 | done 151 | down 152 | downwards 153 | during 154 | e 155 | each 156 | edu 157 | eg 158 | e.g. 159 | eight 160 | either 161 | else 162 | elsewhere 163 | enough 164 | entirely 165 | especially 166 | et 167 | etc 168 | etc. 169 | even 170 | ever 171 | every 172 | everybody 173 | everyone 174 | everything 175 | everywhere 176 | ex 177 | exactly 178 | example 179 | except 180 | f 181 | far 182 | few 183 | fifth 184 | five 185 | followed 186 | following 187 | follows 188 | for 189 | former 190 | formerly 191 | forth 192 | four 193 | from 194 | further 195 | furthermore 196 | g 197 | get 198 | gets 199 | getting 200 | given 201 | gives 202 | go 203 | goes 204 | going 205 | gone 206 | got 207 | gotten 208 | greetings 209 | h 210 | had 211 | hadn't 212 | happens 213 | hardly 214 | has 215 | hasn't 216 | have 217 | haven't 218 | having 219 | he 220 | he's 221 | hello 222 | help 223 | hence 224 | her 225 | here 226 | here's 227 | hereafter 228 | hereby 229 | herein 230 | hereupon 231 | hers 232 | herself 233 | hi 234 | him 235 | himself 236 | his 237 | hither 238 | hopefully 239 | how 240 | howbeit 241 | however 242 | i 243 | i'd 244 | i'll 245 | i'm 246 | i've 247 | ie 248 | i.e. 
249 | if 250 | ignored 251 | immediate 252 | in 253 | inasmuch 254 | inc 255 | indeed 256 | indicate 257 | indicated 258 | indicates 259 | inner 260 | insofar 261 | instead 262 | into 263 | inward 264 | is 265 | isn't 266 | it 267 | it'd 268 | it'll 269 | it's 270 | its 271 | itself 272 | j 273 | just 274 | k 275 | keep 276 | keeps 277 | kept 278 | know 279 | knows 280 | known 281 | l 282 | lately 283 | later 284 | latter 285 | latterly 286 | least 287 | less 288 | lest 289 | let 290 | let's 291 | like 292 | liked 293 | likely 294 | little 295 | look 296 | looking 297 | looks 298 | ltd 299 | m 300 | mainly 301 | many 302 | may 303 | maybe 304 | me 305 | mean 306 | meanwhile 307 | merely 308 | might 309 | more 310 | moreover 311 | most 312 | mostly 313 | mr. 314 | ms. 315 | much 316 | must 317 | my 318 | myself 319 | n 320 | namely 321 | nd 322 | near 323 | nearly 324 | necessary 325 | need 326 | needs 327 | neither 328 | never 329 | nevertheless 330 | new 331 | next 332 | nine 333 | no 334 | nobody 335 | non 336 | none 337 | noone 338 | nor 339 | normally 340 | not 341 | nothing 342 | novel 343 | now 344 | nowhere 345 | o 346 | obviously 347 | of 348 | off 349 | often 350 | oh 351 | ok 352 | okay 353 | old 354 | on 355 | once 356 | one 357 | ones 358 | only 359 | onto 360 | or 361 | other 362 | others 363 | otherwise 364 | ought 365 | our 366 | ours 367 | ourselves 368 | out 369 | outside 370 | over 371 | overall 372 | own 373 | p 374 | particular 375 | particularly 376 | per 377 | perhaps 378 | placed 379 | please 380 | plus 381 | possible 382 | presumably 383 | probably 384 | provides 385 | q 386 | que 387 | quite 388 | qv 389 | r 390 | rather 391 | rd 392 | re 393 | really 394 | reasonably 395 | regarding 396 | regardless 397 | regards 398 | relatively 399 | respectively 400 | right 401 | s 402 | said 403 | same 404 | saw 405 | say 406 | saying 407 | says 408 | second 409 | secondly 410 | see 411 | seeing 412 | seem 413 | seemed 414 | seeming 415 | seems 416 | seen 417 | self 418 | selves 419 | sensible 420 | sent 421 | serious 422 | seriously 423 | seven 424 | several 425 | shall 426 | she 427 | should 428 | shouldn't 429 | since 430 | six 431 | so 432 | some 433 | somebody 434 | somehow 435 | someone 436 | something 437 | sometime 438 | sometimes 439 | somewhat 440 | somewhere 441 | soon 442 | sorry 443 | specified 444 | specify 445 | specifying 446 | still 447 | sub 448 | such 449 | sup 450 | sure 451 | t 452 | t's 453 | take 454 | taken 455 | tell 456 | tends 457 | th 458 | than 459 | thank 460 | thanks 461 | thanx 462 | that 463 | that's 464 | thats 465 | the 466 | their 467 | theirs 468 | them 469 | themselves 470 | then 471 | thence 472 | there 473 | there's 474 | thereafter 475 | thereby 476 | therefore 477 | therein 478 | theres 479 | thereupon 480 | these 481 | they 482 | they'd 483 | they'll 484 | they're 485 | they've 486 | think 487 | third 488 | this 489 | thorough 490 | thoroughly 491 | those 492 | though 493 | three 494 | through 495 | throughout 496 | thru 497 | thus 498 | to 499 | together 500 | too 501 | took 502 | toward 503 | towards 504 | tried 505 | tries 506 | truly 507 | try 508 | trying 509 | twice 510 | two 511 | u 512 | un 513 | under 514 | unfortunately 515 | unless 516 | unlikely 517 | until 518 | unto 519 | up 520 | upon 521 | us 522 | use 523 | used 524 | useful 525 | uses 526 | using 527 | usually 528 | uucp 529 | v 530 | value 531 | various 532 | very 533 | via 534 | viz 535 | vs 536 | w 537 | want 538 | wants 539 | was 540 | wasn't 541 | way 542 | we 
543 | we'd 544 | we'll 545 | we're 546 | we've 547 | welcome 548 | well 549 | went 550 | were 551 | weren't 552 | what 553 | what's 554 | whatever 555 | when 556 | whence 557 | whenever 558 | where 559 | where's 560 | whereafter 561 | whereas 562 | whereby 563 | wherein 564 | whereupon 565 | wherever 566 | whether 567 | which 568 | while 569 | whither 570 | who 571 | who's 572 | whoever 573 | whole 574 | whom 575 | whose 576 | why 577 | will 578 | willing 579 | wish 580 | with 581 | within 582 | without 583 | won't 584 | wonder 585 | would 586 | would 587 | wouldn't 588 | x 589 | y 590 | yes 591 | yet 592 | you 593 | you'd 594 | you'll 595 | you're 596 | you've 597 | your 598 | yours 599 | yourself 600 | yourselves 601 | z 602 | zero 603 | -------------------------------------------------------------------------------- /src/prepro/utils.py: -------------------------------------------------------------------------------- 1 | # stopwords = pkgutil.get_data(__package__, 'smart_common_words.txt') 2 | # stopwords = stopwords.decode('ascii').split('\n') 3 | # stopwords = {key.strip(): 1 for key in stopwords} 4 | 5 | 6 | def _get_ngrams(n, text): 7 | """Calcualtes n-grams. 8 | 9 | Args: 10 | n: which n-grams to calculate 11 | text: An array of tokens 12 | 13 | Returns: 14 | A set of n-grams 15 | """ 16 | ngram_set = set() 17 | text_length = len(text) 18 | max_index_ngram_start = text_length - n 19 | for i in range(max_index_ngram_start + 1): 20 | ngram_set.add(tuple(text[i:i + n])) 21 | return ngram_set 22 | 23 | 24 | def _get_word_ngrams(n, sentences): 25 | """Calculates word n-grams for multiple sentences. 26 | """ 27 | assert len(sentences) > 0 28 | assert n > 0 29 | 30 | # words = _split_into_words(sentences) 31 | 32 | words = sum(sentences, []) 33 | # words = [w for w in words if w not in stopwords] 34 | return _get_ngrams(n, words) 35 | -------------------------------------------------------------------------------- /src/preprocess.py: -------------------------------------------------------------------------------- 1 | #encoding=utf-8 2 | 3 | 4 | import argparse 5 | import time 6 | 7 | from others.logging import init_logger 8 | from prepro import data_builder 9 | 10 | 11 | def do_format_to_lines(args): 12 | print(time.clock()) 13 | data_builder.format_to_lines(args) 14 | print(time.clock()) 15 | 16 | def do_format_to_bert(args): 17 | print(time.clock()) 18 | data_builder.format_to_bert(args) 19 | print(time.clock()) 20 | 21 | 22 | 23 | def do_format_xsum_to_lines(args): 24 | print(time.clock()) 25 | data_builder.format_xsum_to_lines(args) 26 | print(time.clock()) 27 | 28 | def do_tokenize(args): 29 | print(time.clock()) 30 | data_builder.tokenize(args) 31 | print(time.clock()) 32 | 33 | 34 | def str2bool(v): 35 | if v.lower() in ('yes', 'true', 't', 'y', '1'): 36 | return True 37 | elif v.lower() in ('no', 'false', 'f', 'n', '0'): 38 | return False 39 | else: 40 | raise argparse.ArgumentTypeError('Boolean value expected.') 41 | 42 | 43 | if __name__ == '__main__': 44 | parser = argparse.ArgumentParser() 45 | parser.add_argument("-pretrained_model", default='bert', type=str) 46 | 47 | parser.add_argument("-mode", default='', type=str) 48 | parser.add_argument("-select_mode", default='greedy', type=str) 49 | parser.add_argument("-map_path", default='../../data/') 50 | parser.add_argument("-raw_path", default='../../line_data') 51 | parser.add_argument("-save_path", default='../../data/') 52 | 53 | parser.add_argument("-shard_size", default=2000, type=int) 54 | 
parser.add_argument('-min_src_nsents', default=3, type=int) 55 | parser.add_argument('-max_src_nsents', default=100, type=int) 56 | parser.add_argument('-min_src_ntokens_per_sent', default=5, type=int) 57 | parser.add_argument('-max_src_ntokens_per_sent', default=200, type=int) 58 | parser.add_argument('-min_tgt_ntokens', default=5, type=int) 59 | parser.add_argument('-max_tgt_ntokens', default=500, type=int) 60 | 61 | parser.add_argument("-lower", type=str2bool, nargs='?',const=True,default=True) 62 | parser.add_argument("-use_bert_basic_tokenizer", type=str2bool, nargs='?',const=True,default=False) 63 | 64 | parser.add_argument('-log_file', default='../../logs/cnndm.log') 65 | 66 | parser.add_argument('-dataset', default='') 67 | 68 | parser.add_argument('-n_cpus', default=2, type=int) 69 | 70 | 71 | args = parser.parse_args() 72 | init_logger(args.log_file) 73 | eval('data_builder.'+args.mode + '(args)') 74 | -------------------------------------------------------------------------------- /src/train.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | """ 3 | Main training workflow 4 | """ 5 | from __future__ import division 6 | 7 | import argparse 8 | import os 9 | from others.logging import init_logger 10 | from train_abstractive import validate_abs, train_abs, baseline, test_abs, test_text_abs 11 | from train_extractive import train_ext, validate_ext, test_ext 12 | 13 | model_flags = ['hidden_size', 'ff_size', 'heads', 'emb_size', 'enc_layers', 'enc_hidden_size', 'enc_ff_size', 14 | 'dec_layers', 'dec_hidden_size', 'dec_ff_size', 'encoder', 'ff_actv', 'use_interval'] 15 | 16 | 17 | def str2bool(v): 18 | if v.lower() in ('yes', 'true', 't', 'y', '1'): 19 | return True 20 | elif v.lower() in ('no', 'false', 'f', 'n', '0'): 21 | return False 22 | else: 23 | raise argparse.ArgumentTypeError('Boolean value expected.') 24 | 25 | 26 | 27 | 28 | if __name__ == '__main__': 29 | parser = argparse.ArgumentParser() 30 | parser.add_argument("-task", default='ext', type=str, choices=['ext', 'abs']) 31 | parser.add_argument("-encoder", default='bert', type=str, choices=['bert', 'baseline']) 32 | parser.add_argument("-mode", default='train', type=str, choices=['train', 'validate', 'test']) 33 | parser.add_argument("-bert_data_path", default='../bert_data_new/cnndm') 34 | parser.add_argument("-model_path", default='../models/') 35 | parser.add_argument("-result_path", default='../results/cnndm') 36 | parser.add_argument("-temp_dir", default='../temp') 37 | 38 | parser.add_argument("-batch_size", default=140, type=int) 39 | parser.add_argument("-test_batch_size", default=200, type=int) 40 | 41 | parser.add_argument("-max_pos", default=512, type=int) 42 | parser.add_argument("-use_interval", type=str2bool, nargs='?',const=True,default=True) 43 | parser.add_argument("-large", type=str2bool, nargs='?',const=True,default=False) 44 | parser.add_argument("-load_from_extractive", default='', type=str) 45 | 46 | parser.add_argument("-sep_optim", type=str2bool, nargs='?',const=True,default=False) 47 | parser.add_argument("-lr_bert", default=2e-3, type=float) 48 | parser.add_argument("-lr_dec", default=2e-3, type=float) 49 | parser.add_argument("-use_bert_emb", type=str2bool, nargs='?',const=True,default=False) 50 | 51 | parser.add_argument("-share_emb", type=str2bool, nargs='?', const=True, default=False) 52 | parser.add_argument("-finetune_bert", type=str2bool, nargs='?', const=True, default=True) 53 | parser.add_argument("-dec_dropout", 
default=0.2, type=float) 54 | parser.add_argument("-dec_layers", default=6, type=int) 55 | parser.add_argument("-dec_hidden_size", default=768, type=int) 56 | parser.add_argument("-dec_heads", default=8, type=int) 57 | parser.add_argument("-dec_ff_size", default=2048, type=int) 58 | parser.add_argument("-enc_hidden_size", default=512, type=int) 59 | parser.add_argument("-enc_ff_size", default=512, type=int) 60 | parser.add_argument("-enc_dropout", default=0.2, type=float) 61 | parser.add_argument("-enc_layers", default=6, type=int) 62 | 63 | # params for EXT 64 | parser.add_argument("-ext_dropout", default=0.2, type=float) 65 | parser.add_argument("-ext_layers", default=2, type=int) 66 | parser.add_argument("-ext_hidden_size", default=768, type=int) 67 | parser.add_argument("-ext_heads", default=8, type=int) 68 | parser.add_argument("-ext_ff_size", default=2048, type=int) 69 | 70 | parser.add_argument("-label_smoothing", default=0.1, type=float) 71 | parser.add_argument("-generator_shard_size", default=32, type=int) 72 | parser.add_argument("-alpha", default=0.6, type=float) 73 | parser.add_argument("-beam_size", default=5, type=int) 74 | parser.add_argument("-min_length", default=15, type=int) 75 | parser.add_argument("-max_length", default=150, type=int) 76 | parser.add_argument("-max_tgt_len", default=140, type=int) 77 | 78 | 79 | 80 | parser.add_argument("-param_init", default=0, type=float) 81 | parser.add_argument("-param_init_glorot", type=str2bool, nargs='?',const=True,default=True) 82 | parser.add_argument("-optim", default='adam', type=str) 83 | parser.add_argument("-lr", default=1, type=float) 84 | parser.add_argument("-beta1", default= 0.9, type=float) 85 | parser.add_argument("-beta2", default=0.999, type=float) 86 | parser.add_argument("-warmup_steps", default=8000, type=int) 87 | parser.add_argument("-warmup_steps_bert", default=8000, type=int) 88 | parser.add_argument("-warmup_steps_dec", default=8000, type=int) 89 | parser.add_argument("-max_grad_norm", default=0, type=float) 90 | 91 | parser.add_argument("-save_checkpoint_steps", default=5, type=int) 92 | parser.add_argument("-accum_count", default=1, type=int) 93 | parser.add_argument("-report_every", default=1, type=int) 94 | parser.add_argument("-train_steps", default=1000, type=int) 95 | parser.add_argument("-recall_eval", type=str2bool, nargs='?',const=True,default=False) 96 | 97 | 98 | parser.add_argument('-visible_gpus', default='-1', type=str) 99 | parser.add_argument('-gpu_ranks', default='0', type=str) 100 | parser.add_argument('-log_file', default='../logs/cnndm.log') 101 | parser.add_argument('-seed', default=666, type=int) 102 | 103 | parser.add_argument("-test_all", type=str2bool, nargs='?',const=True,default=False) 104 | parser.add_argument("-test_from", default='') 105 | parser.add_argument("-test_start_from", default=-1, type=int) 106 | 107 | parser.add_argument("-train_from", default='') 108 | parser.add_argument("-report_rouge", type=str2bool, nargs='?',const=True,default=True) 109 | parser.add_argument("-block_trigram", type=str2bool, nargs='?', const=True, default=True) 110 | 111 | args = parser.parse_args() 112 | args.gpu_ranks = [int(i) for i in range(len(args.visible_gpus.split(',')))] 113 | args.world_size = len(args.gpu_ranks) 114 | os.environ["CUDA_VISIBLE_DEVICES"] = args.visible_gpus 115 | 116 | init_logger(args.log_file) 117 | device = "cpu" if args.visible_gpus == '-1' else "cuda" 118 | device_id = 0 if device == "cuda" else -1 119 | 120 | if (args.task == 'abs'): 121 | if (args.mode == 
'train'): 122 | train_abs(args, device_id) 123 | elif (args.mode == 'validate'): 124 | validate_abs(args, device_id) 125 | elif (args.mode == 'lead'): 126 | baseline(args, cal_lead=True) 127 | elif (args.mode == 'oracle'): 128 | baseline(args, cal_oracle=True) 129 | if (args.mode == 'test'): 130 | cp = args.test_from 131 | try: 132 | step = int(cp.split('.')[-2].split('_')[-1]) 133 | except: 134 | step = 0 135 | test_abs(args, device_id, cp, step) 136 | elif (args.mode == 'test_text'): 137 | cp = args.test_from 138 | try: 139 | step = int(cp.split('.')[-2].split('_')[-1]) 140 | except: 141 | step = 0 142 | test_text_abs(args, device_id, cp, step) 143 | 144 | elif (args.task == 'ext'): 145 | if (args.mode == 'train'): 146 | train_ext(args, device_id) 147 | elif (args.mode == 'validate'): 148 | validate_ext(args, device_id) 149 | if (args.mode == 'test'): 150 | cp = args.test_from 151 | try: 152 | step = int(cp.split('.')[-2].split('_')[-1]) 153 | except: 154 | step = 0 155 | test_ext(args, device_id, cp, step) 156 | elif (args.mode == 'test_text'): 157 | cp = args.test_from 158 | try: 159 | step = int(cp.split('.')[-2].split('_')[-1]) 160 | except: 161 | step = 0 162 | test_text_abs(args, device_id, cp, step) 163 | -------------------------------------------------------------------------------- /src/train_abstractive.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | """ 3 | Main training workflow 4 | """ 5 | from __future__ import division 6 | 7 | import argparse 8 | import glob 9 | import os 10 | import random 11 | import signal 12 | import time 13 | 14 | import torch 15 | from pytorch_transformers import BertTokenizer 16 | 17 | import distributed 18 | from models import data_loader, model_builder 19 | from models.data_loader import load_dataset 20 | from models.loss import abs_loss 21 | from models.model_builder import AbsSummarizer 22 | from models.predictor import build_predictor 23 | from models.trainer import build_trainer 24 | from others.logging import logger, init_logger 25 | 26 | model_flags = ['hidden_size', 'ff_size', 'heads', 'emb_size', 'enc_layers', 'enc_hidden_size', 'enc_ff_size', 27 | 'dec_layers', 'dec_hidden_size', 'dec_ff_size', 'encoder', 'ff_actv', 'use_interval'] 28 | 29 | 30 | def str2bool(v): 31 | if v.lower() in ('yes', 'true', 't', 'y', '1'): 32 | return True 33 | elif v.lower() in ('no', 'false', 'f', 'n', '0'): 34 | return False 35 | else: 36 | raise argparse.ArgumentTypeError('Boolean value expected.') 37 | 38 | 39 | def train_abs_multi(args): 40 | """ Spawns 1 process per GPU """ 41 | init_logger() 42 | 43 | nb_gpu = args.world_size 44 | mp = torch.multiprocessing.get_context('spawn') 45 | 46 | # Create a thread to listen for errors in the child processes. 47 | error_queue = mp.SimpleQueue() 48 | error_handler = ErrorHandler(error_queue) 49 | 50 | # Train with multiprocessing. 
51 | procs = [] 52 | for i in range(nb_gpu): 53 | device_id = i 54 | procs.append(mp.Process(target=run, args=(args, 55 | device_id, error_queue,), daemon=True)) 56 | procs[i].start() 57 | logger.info(" Starting process pid: %d " % procs[i].pid) 58 | error_handler.add_child(procs[i].pid) 59 | for p in procs: 60 | p.join() 61 | 62 | 63 | def run(args, device_id, error_queue): 64 | """ run process """ 65 | 66 | setattr(args, 'gpu_ranks', [int(i) for i in args.gpu_ranks]) 67 | 68 | try: 69 | gpu_rank = distributed.multi_init(device_id, args.world_size, args.gpu_ranks) 70 | print('gpu_rank %d' % gpu_rank) 71 | if gpu_rank != args.gpu_ranks[device_id]: 72 | raise AssertionError("An error occurred in \ 73 | Distributed initialization") 74 | 75 | train_abs_single(args, device_id) 76 | except KeyboardInterrupt: 77 | pass # killed by parent, do nothing 78 | except Exception: 79 | # propagate exception to parent process, keeping original traceback 80 | import traceback 81 | error_queue.put((args.gpu_ranks[device_id], traceback.format_exc())) 82 | 83 | 84 | class ErrorHandler(object): 85 | """A class that listens for exceptions in children processes and propagates 86 | the tracebacks to the parent process.""" 87 | 88 | def __init__(self, error_queue): 89 | """ init error handler """ 90 | import signal 91 | import threading 92 | self.error_queue = error_queue 93 | self.children_pids = [] 94 | self.error_thread = threading.Thread( 95 | target=self.error_listener, daemon=True) 96 | self.error_thread.start() 97 | signal.signal(signal.SIGUSR1, self.signal_handler) 98 | 99 | def add_child(self, pid): 100 | """ error handler """ 101 | self.children_pids.append(pid) 102 | 103 | def error_listener(self): 104 | """ error listener """ 105 | (rank, original_trace) = self.error_queue.get() 106 | self.error_queue.put((rank, original_trace)) 107 | os.kill(os.getpid(), signal.SIGUSR1) 108 | 109 | def signal_handler(self, signalnum, stackframe): 110 | """ signal handler """ 111 | for pid in self.children_pids: 112 | os.kill(pid, signal.SIGINT) # kill children processes 113 | (rank, original_trace) = self.error_queue.get() 114 | msg = """\n\n-- Tracebacks above this line can probably 115 | be ignored --\n\n""" 116 | msg += original_trace 117 | raise Exception(msg) 118 | 119 | 120 | def validate_abs(args, device_id): 121 | timestep = 0 122 | if (args.test_all): 123 | cp_files = sorted(glob.glob(os.path.join(args.model_path, 'model_step_*.pt'))) 124 | cp_files.sort(key=os.path.getmtime) 125 | xent_lst = [] 126 | for i, cp in enumerate(cp_files): 127 | step = int(cp.split('.')[-2].split('_')[-1]) 128 | if (args.test_start_from != -1 and step < args.test_start_from): 129 | xent_lst.append((1e6, cp)) 130 | continue 131 | xent = validate(args, device_id, cp, step) 132 | xent_lst.append((xent, cp)) 133 | max_step = xent_lst.index(min(xent_lst)) 134 | if (i - max_step > 10): 135 | break 136 | xent_lst = sorted(xent_lst, key=lambda x: x[0])[:5] 137 | logger.info('PPL %s' % str(xent_lst)) 138 | for xent, cp in xent_lst: 139 | step = int(cp.split('.')[-2].split('_')[-1]) 140 | test_abs(args, device_id, cp, step) 141 | else: 142 | while (True): 143 | cp_files = sorted(glob.glob(os.path.join(args.model_path, 'model_step_*.pt'))) 144 | cp_files.sort(key=os.path.getmtime) 145 | if (cp_files): 146 | cp = cp_files[-1] 147 | time_of_cp = os.path.getmtime(cp) 148 | if (not os.path.getsize(cp) > 0): 149 | time.sleep(60) 150 | continue 151 | if (time_of_cp > timestep): 152 | timestep = time_of_cp 153 | step = 
int(cp.split('.')[-2].split('_')[-1]) 154 | validate(args, device_id, cp, step) 155 | test_abs(args, device_id, cp, step) 156 | 157 | cp_files = sorted(glob.glob(os.path.join(args.model_path, 'model_step_*.pt'))) 158 | cp_files.sort(key=os.path.getmtime) 159 | if (cp_files): 160 | cp = cp_files[-1] 161 | time_of_cp = os.path.getmtime(cp) 162 | if (time_of_cp > timestep): 163 | continue 164 | else: 165 | time.sleep(300) 166 | 167 | 168 | def validate(args, device_id, pt, step): 169 | device = "cpu" if args.visible_gpus == '-1' else "cuda" 170 | if (pt != ''): 171 | test_from = pt 172 | else: 173 | test_from = args.test_from 174 | logger.info('Loading checkpoint from %s' % test_from) 175 | checkpoint = torch.load(test_from, map_location=lambda storage, loc: storage) 176 | opt = vars(checkpoint['opt']) 177 | for k in opt.keys(): 178 | if (k in model_flags): 179 | setattr(args, k, opt[k]) 180 | print(args) 181 | 182 | model = AbsSummarizer(args, device, checkpoint) 183 | model.eval() 184 | 185 | valid_iter = data_loader.Dataloader(args, load_dataset(args, 'valid', shuffle=False), 186 | args.batch_size, device, 187 | shuffle=False, is_test=False) 188 | 189 | tokenizer = BertTokenizer.from_pretrained('bert-base-uncased', do_lower_case=True, cache_dir=args.temp_dir) 190 | symbols = {'BOS': tokenizer.vocab['[unused0]'], 'EOS': tokenizer.vocab['[unused1]'], 191 | 'PAD': tokenizer.vocab['[PAD]'], 'EOQ': tokenizer.vocab['[unused2]']} 192 | 193 | valid_loss = abs_loss(model.generator, symbols, model.vocab_size, train=False, device=device) 194 | 195 | trainer = build_trainer(args, device_id, model, None, valid_loss) 196 | stats = trainer.validate(valid_iter, step) 197 | return stats.xent() 198 | 199 | 200 | def test_abs(args, device_id, pt, step): 201 | device = "cpu" if args.visible_gpus == '-1' else "cuda" 202 | if (pt != ''): 203 | test_from = pt 204 | else: 205 | test_from = args.test_from 206 | logger.info('Loading checkpoint from %s' % test_from) 207 | 208 | checkpoint = torch.load(test_from, map_location=lambda storage, loc: storage) 209 | opt = vars(checkpoint['opt']) 210 | for k in opt.keys(): 211 | if (k in model_flags): 212 | setattr(args, k, opt[k]) 213 | print(args) 214 | 215 | model = AbsSummarizer(args, device, checkpoint) 216 | model.eval() 217 | 218 | test_iter = data_loader.Dataloader(args, load_dataset(args, 'test', shuffle=False), 219 | args.test_batch_size, device, 220 | shuffle=False, is_test=True) 221 | tokenizer = BertTokenizer.from_pretrained('bert-base-uncased', do_lower_case=True, cache_dir=args.temp_dir) 222 | symbols = {'BOS': tokenizer.vocab['[unused0]'], 'EOS': tokenizer.vocab['[unused1]'], 223 | 'PAD': tokenizer.vocab['[PAD]'], 'EOQ': tokenizer.vocab['[unused2]']} 224 | predictor = build_predictor(args, tokenizer, symbols, model, logger) 225 | predictor.translate(test_iter, step) 226 | 227 | 228 | def test_text_abs(args, device_id, pt, step): 229 | device = "cpu" if args.visible_gpus == '-1' else "cuda" 230 | if (pt != ''): 231 | test_from = pt 232 | else: 233 | test_from = args.test_from 234 | logger.info('Loading checkpoint from %s' % test_from) 235 | 236 | checkpoint = torch.load(test_from, map_location=lambda storage, loc: storage) 237 | opt = vars(checkpoint['opt']) 238 | for k in opt.keys(): 239 | if (k in model_flags): 240 | setattr(args, k, opt[k]) 241 | print(args) 242 | 243 | model = AbsSummarizer(args, device, checkpoint) 244 | model.eval() 245 | 246 | test_iter = data_loader.Dataloader(args, load_dataset(args, 'test', shuffle=False), 247 | 
args.test_batch_size, device, 248 | shuffle=False, is_test=True) 249 | tokenizer = BertTokenizer.from_pretrained('bert-base-uncased', do_lower_case=True, cache_dir=args.temp_dir) 250 | symbols = {'BOS': tokenizer.vocab['[unused0]'], 'EOS': tokenizer.vocab['[unused1]'], 251 | 'PAD': tokenizer.vocab['[PAD]'], 'EOQ': tokenizer.vocab['[unused2]']} 252 | predictor = build_predictor(args, tokenizer, symbols, model, logger) 253 | predictor.translate(test_iter, step) 254 | 255 | 256 | def baseline(args, cal_lead=False, cal_oracle=False): 257 | test_iter = data_loader.Dataloader(args, load_dataset(args, 'test', shuffle=False), 258 | args.batch_size, 'cpu', 259 | shuffle=False, is_test=True) 260 | 261 | trainer = build_trainer(args, '-1', None, None, None) 262 | # 263 | if (cal_lead): 264 | trainer.test(test_iter, 0, cal_lead=True) 265 | elif (cal_oracle): 266 | trainer.test(test_iter, 0, cal_oracle=True) 267 | 268 | 269 | def train_abs(args, device_id): 270 | if (args.world_size > 1): 271 | train_abs_multi(args) 272 | else: 273 | train_abs_single(args, device_id) 274 | 275 | 276 | def train_abs_single(args, device_id): 277 | init_logger(args.log_file) 278 | logger.info(str(args)) 279 | device = "cpu" if args.visible_gpus == '-1' else "cuda" 280 | logger.info('Device ID %d' % device_id) 281 | logger.info('Device %s' % device) 282 | torch.manual_seed(args.seed) 283 | random.seed(args.seed) 284 | torch.backends.cudnn.deterministic = True 285 | 286 | if device_id >= 0: 287 | torch.cuda.set_device(device_id) 288 | torch.cuda.manual_seed(args.seed) 289 | 290 | if args.train_from != '': 291 | logger.info('Loading checkpoint from %s' % args.train_from) 292 | checkpoint = torch.load(args.train_from, 293 | map_location=lambda storage, loc: storage) 294 | opt = vars(checkpoint['opt']) 295 | for k in opt.keys(): 296 | if (k in model_flags): 297 | setattr(args, k, opt[k]) 298 | else: 299 | checkpoint = None 300 | 301 | if (args.load_from_extractive != ''): 302 | logger.info('Loading bert from extractive model %s' % args.load_from_extractive) 303 | bert_from_extractive = torch.load(args.load_from_extractive, map_location=lambda storage, loc: storage) 304 | bert_from_extractive = bert_from_extractive['model'] 305 | else: 306 | bert_from_extractive = None 307 | torch.manual_seed(args.seed) 308 | random.seed(args.seed) 309 | torch.backends.cudnn.deterministic = True 310 | 311 | def train_iter_fct(): 312 | return data_loader.Dataloader(args, load_dataset(args, 'train', shuffle=True), args.batch_size, device, 313 | shuffle=True, is_test=False) 314 | 315 | model = AbsSummarizer(args, device, checkpoint, bert_from_extractive) 316 | if (args.sep_optim): 317 | optim_bert = model_builder.build_optim_bert(args, model, checkpoint) 318 | optim_dec = model_builder.build_optim_dec(args, model, checkpoint) 319 | optim = [optim_bert, optim_dec] 320 | else: 321 | optim = [model_builder.build_optim(args, model, checkpoint)] 322 | 323 | logger.info(model) 324 | 325 | tokenizer = BertTokenizer.from_pretrained('bert-base-uncased', do_lower_case=True, cache_dir=args.temp_dir) 326 | symbols = {'BOS': tokenizer.vocab['[unused0]'], 'EOS': tokenizer.vocab['[unused1]'], 327 | 'PAD': tokenizer.vocab['[PAD]'], 'EOQ': tokenizer.vocab['[unused2]']} 328 | 329 | train_loss = abs_loss(model.generator, symbols, model.vocab_size, device, train=True, 330 | label_smoothing=args.label_smoothing) 331 | 332 | trainer = build_trainer(args, device_id, model, optim, train_loss) 333 | 334 | trainer.train(train_iter_fct, args.train_steps) 335 | 
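Editorial note (illustrative only, not from the repository): the checkpoint-loading pattern above — torch.load with a CPU map_location, then overriding the command-line args with any model_flags stored in checkpoint['opt'] — is repeated verbatim in validate(), test_abs(), test_text_abs() and train_abs_single(). A minimal sketch of that shared logic follows; the helper name load_checkpoint_overriding_flags is hypothetical and does not exist in the code base.

    import torch

    def load_checkpoint_overriding_flags(args, path, model_flags):
        # Hypothetical helper (not in the repository): load a checkpoint onto CPU
        # regardless of the device it was saved on.
        checkpoint = torch.load(path, map_location=lambda storage, loc: storage)
        # Architecture flags saved with the checkpoint take precedence over the
        # current command-line values, so the rebuilt model matches the trained one.
        opt = vars(checkpoint['opt'])
        for k in opt.keys():
            if k in model_flags:
                setattr(args, k, opt[k])
        return checkpoint

With such a helper, validate() above would reduce to one call (checkpoint = load_checkpoint_overriding_flags(args, test_from, model_flags)) before constructing AbsSummarizer.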
-------------------------------------------------------------------------------- /src/train_extractive.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | """ 3 | Main training workflow 4 | """ 5 | from __future__ import division 6 | 7 | import argparse 8 | import glob 9 | import os 10 | import random 11 | import signal 12 | import time 13 | 14 | import torch 15 | 16 | import distributed 17 | from models import data_loader, model_builder 18 | from models.data_loader import load_dataset 19 | from models.model_builder import ExtSummarizer 20 | from models.trainer_ext import build_trainer 21 | from others.logging import logger, init_logger 22 | 23 | model_flags = ['hidden_size', 'ff_size', 'heads', 'inter_layers', 'encoder', 'ff_actv', 'use_interval', 'rnn_size'] 24 | 25 | 26 | def train_multi_ext(args): 27 | """ Spawns 1 process per GPU """ 28 | init_logger() 29 | 30 | nb_gpu = args.world_size 31 | mp = torch.multiprocessing.get_context('spawn') 32 | 33 | # Create a thread to listen for errors in the child processes. 34 | error_queue = mp.SimpleQueue() 35 | error_handler = ErrorHandler(error_queue) 36 | 37 | # Train with multiprocessing. 38 | procs = [] 39 | for i in range(nb_gpu): 40 | device_id = i 41 | procs.append(mp.Process(target=run, args=(args, 42 | device_id, error_queue,), daemon=True)) 43 | procs[i].start() 44 | logger.info(" Starting process pid: %d " % procs[i].pid) 45 | error_handler.add_child(procs[i].pid) 46 | for p in procs: 47 | p.join() 48 | 49 | 50 | def run(args, device_id, error_queue): 51 | """ run process """ 52 | setattr(args, 'gpu_ranks', [int(i) for i in args.gpu_ranks]) 53 | 54 | try: 55 | gpu_rank = distributed.multi_init(device_id, args.world_size, args.gpu_ranks) 56 | print('gpu_rank %d' % gpu_rank) 57 | if gpu_rank != args.gpu_ranks[device_id]: 58 | raise AssertionError("An error occurred in \ 59 | Distributed initialization") 60 | 61 | train_single_ext(args, device_id) 62 | except KeyboardInterrupt: 63 | pass # killed by parent, do nothing 64 | except Exception: 65 | # propagate exception to parent process, keeping original traceback 66 | import traceback 67 | error_queue.put((args.gpu_ranks[device_id], traceback.format_exc())) 68 | 69 | 70 | class ErrorHandler(object): 71 | """A class that listens for exceptions in children processes and propagates 72 | the tracebacks to the parent process.""" 73 | 74 | def __init__(self, error_queue): 75 | """ init error handler """ 76 | import signal 77 | import threading 78 | self.error_queue = error_queue 79 | self.children_pids = [] 80 | self.error_thread = threading.Thread( 81 | target=self.error_listener, daemon=True) 82 | self.error_thread.start() 83 | signal.signal(signal.SIGUSR1, self.signal_handler) 84 | 85 | def add_child(self, pid): 86 | """ error handler """ 87 | self.children_pids.append(pid) 88 | 89 | def error_listener(self): 90 | """ error listener """ 91 | (rank, original_trace) = self.error_queue.get() 92 | self.error_queue.put((rank, original_trace)) 93 | os.kill(os.getpid(), signal.SIGUSR1) 94 | 95 | def signal_handler(self, signalnum, stackframe): 96 | """ signal handler """ 97 | for pid in self.children_pids: 98 | os.kill(pid, signal.SIGINT) # kill children processes 99 | (rank, original_trace) = self.error_queue.get() 100 | msg = """\n\n-- Tracebacks above this line can probably 101 | be ignored --\n\n""" 102 | msg += original_trace 103 | raise Exception(msg) 104 | 105 | 106 | def validate_ext(args, device_id): 107 | timestep = 0 108 | if 
(args.test_all): 109 | cp_files = sorted(glob.glob(os.path.join(args.model_path, 'model_step_*.pt'))) 110 | cp_files.sort(key=os.path.getmtime) 111 | xent_lst = [] 112 | for i, cp in enumerate(cp_files): 113 | step = int(cp.split('.')[-2].split('_')[-1]) 114 | xent = validate(args, device_id, cp, step) 115 | xent_lst.append((xent, cp)) 116 | max_step = xent_lst.index(min(xent_lst)) 117 | if (i - max_step > 10): 118 | break 119 | xent_lst = sorted(xent_lst, key=lambda x: x[0])[:3] 120 | logger.info('PPL %s' % str(xent_lst)) 121 | for xent, cp in xent_lst: 122 | step = int(cp.split('.')[-2].split('_')[-1]) 123 | test_ext(args, device_id, cp, step) 124 | else: 125 | while (True): 126 | cp_files = sorted(glob.glob(os.path.join(args.model_path, 'model_step_*.pt'))) 127 | cp_files.sort(key=os.path.getmtime) 128 | if (cp_files): 129 | cp = cp_files[-1] 130 | time_of_cp = os.path.getmtime(cp) 131 | if (not os.path.getsize(cp) > 0): 132 | time.sleep(60) 133 | continue 134 | if (time_of_cp > timestep): 135 | timestep = time_of_cp 136 | step = int(cp.split('.')[-2].split('_')[-1]) 137 | validate(args, device_id, cp, step) 138 | test_ext(args, device_id, cp, step) 139 | 140 | cp_files = sorted(glob.glob(os.path.join(args.model_path, 'model_step_*.pt'))) 141 | cp_files.sort(key=os.path.getmtime) 142 | if (cp_files): 143 | cp = cp_files[-1] 144 | time_of_cp = os.path.getmtime(cp) 145 | if (time_of_cp > timestep): 146 | continue 147 | else: 148 | time.sleep(300) 149 | 150 | 151 | def validate(args, device_id, pt, step): 152 | device = "cpu" if args.visible_gpus == '-1' else "cuda" 153 | if (pt != ''): 154 | test_from = pt 155 | else: 156 | test_from = args.test_from 157 | logger.info('Loading checkpoint from %s' % test_from) 158 | checkpoint = torch.load(test_from, map_location=lambda storage, loc: storage) 159 | opt = vars(checkpoint['opt']) 160 | for k in opt.keys(): 161 | if (k in model_flags): 162 | setattr(args, k, opt[k]) 163 | print(args) 164 | 165 | model = ExtSummarizer(args, device, checkpoint) 166 | model.eval() 167 | 168 | valid_iter = data_loader.Dataloader(args, load_dataset(args, 'valid', shuffle=False), 169 | args.batch_size, device, 170 | shuffle=False, is_test=False) 171 | trainer = build_trainer(args, device_id, model, None) 172 | stats = trainer.validate(valid_iter, step) 173 | return stats.xent() 174 | 175 | 176 | def test_ext(args, device_id, pt, step): 177 | device = "cpu" if args.visible_gpus == '-1' else "cuda" 178 | if (pt != ''): 179 | test_from = pt 180 | else: 181 | test_from = args.test_from 182 | logger.info('Loading checkpoint from %s' % test_from) 183 | checkpoint = torch.load(test_from, map_location=lambda storage, loc: storage) 184 | opt = vars(checkpoint['opt']) 185 | for k in opt.keys(): 186 | if (k in model_flags): 187 | setattr(args, k, opt[k]) 188 | print(args) 189 | 190 | model = ExtSummarizer(args, device, checkpoint) 191 | model.eval() 192 | 193 | test_iter = data_loader.Dataloader(args, load_dataset(args, 'test', shuffle=False), 194 | args.test_batch_size, device, 195 | shuffle=False, is_test=True) 196 | trainer = build_trainer(args, device_id, model, None) 197 | trainer.test(test_iter, step) 198 | 199 | def train_ext(args, device_id): 200 | if (args.world_size > 1): 201 | train_multi_ext(args) 202 | else: 203 | train_single_ext(args, device_id) 204 | 205 | 206 | def train_single_ext(args, device_id): 207 | init_logger(args.log_file) 208 | 209 | device = "cpu" if args.visible_gpus == '-1' else "cuda" 210 | logger.info('Device ID %d' % device_id) 211 | 
logger.info('Device %s' % device) 212 | torch.manual_seed(args.seed) 213 | random.seed(args.seed) 214 | torch.backends.cudnn.deterministic = True 215 | 216 | if device_id >= 0: 217 | torch.cuda.set_device(device_id) 218 | torch.cuda.manual_seed(args.seed) 219 | 220 | torch.manual_seed(args.seed) 221 | random.seed(args.seed) 222 | torch.backends.cudnn.deterministic = True 223 | 224 | if args.train_from != '': 225 | logger.info('Loading checkpoint from %s' % args.train_from) 226 | checkpoint = torch.load(args.train_from, 227 | map_location=lambda storage, loc: storage) 228 | opt = vars(checkpoint['opt']) 229 | for k in opt.keys(): 230 | if (k in model_flags): 231 | setattr(args, k, opt[k]) 232 | else: 233 | checkpoint = None 234 | 235 | def train_iter_fct(): 236 | return data_loader.Dataloader(args, load_dataset(args, 'train', shuffle=True), args.batch_size, device, 237 | shuffle=True, is_test=False) 238 | 239 | model = ExtSummarizer(args, device, checkpoint) 240 | optim = model_builder.build_optim(args, model, checkpoint) 241 | 242 | logger.info(model) 243 | 244 | trainer = build_trainer(args, device_id, model, optim) 245 | trainer.train(train_iter_fct, args.train_steps) 246 | -------------------------------------------------------------------------------- /src/translate/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nlpyang/PreSumm/70b810e0f06d179022958dd35c1a3385fe87f28c/src/translate/__init__.py -------------------------------------------------------------------------------- /src/translate/beam.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | import torch 3 | from translate import penalties 4 | 5 | 6 | class Beam(object): 7 | """ 8 | Class for managing the internals of the beam search process. 9 | 10 | Takes care of beams, back pointers, and scores. 11 | 12 | Args: 13 | size (int): beam size 14 | pad, bos, eos (int): indices of padding, beginning, and ending. 15 | n_best (int): nbest size to use 16 | cuda (bool): use gpu 17 | global_scorer (:obj:`GlobalScorer`) 18 | """ 19 | 20 | def __init__(self, size, pad, bos, eos, 21 | n_best=1, cuda=False, 22 | global_scorer=None, 23 | min_length=0, 24 | stepwise_penalty=False, 25 | block_ngram_repeat=0, 26 | exclusion_tokens=set()): 27 | 28 | self.size = size 29 | self.tt = torch.cuda if cuda else torch 30 | 31 | # The score for each translation on the beam. 32 | self.scores = self.tt.FloatTensor(size).zero_() 33 | self.all_scores = [] 34 | 35 | # The backpointers at each time-step. 36 | self.prev_ks = [] 37 | 38 | # The outputs at each time-step. 39 | self.next_ys = [self.tt.LongTensor(size) 40 | .fill_(pad)] 41 | self.next_ys[0][0] = bos 42 | 43 | # Has EOS topped the beam yet. 44 | self._eos = eos 45 | self.eos_top = False 46 | 47 | # The attentions (matrix) for each time. 48 | self.attn = [] 49 | 50 | # Time and k pair for finished. 51 | self.finished = [] 52 | self.n_best = n_best 53 | 54 | # Information for global scoring. 55 | self.global_scorer = global_scorer 56 | self.global_state = {} 57 | 58 | # Minimum prediction length 59 | self.min_length = min_length 60 | 61 | # Apply Penalty at every step 62 | self.stepwise_penalty = stepwise_penalty 63 | self.block_ngram_repeat = block_ngram_repeat 64 | self.exclusion_tokens = exclusion_tokens 65 | 66 | def get_current_state(self): 67 | "Get the outputs for the current timestep." 
68 | return self.next_ys[-1] 69 | 70 | def get_current_origin(self): 71 | "Get the backpointers for the current timestep." 72 | return self.prev_ks[-1] 73 | 74 | def advance(self, word_probs, attn_out): 75 | """ 76 | Given prob over words for every last beam `wordLk` and attention 77 | `attn_out`: Compute and update the beam search. 78 | 79 | Parameters: 80 | 81 | * `word_probs`- probs of advancing from the last step (K x words) 82 | * `attn_out`- attention at the last step 83 | 84 | Returns: True if beam search is complete. 85 | """ 86 | num_words = word_probs.size(1) 87 | if self.stepwise_penalty: 88 | self.global_scorer.update_score(self, attn_out) 89 | # force the output to be longer than self.min_length 90 | cur_len = len(self.next_ys) 91 | if cur_len < self.min_length: 92 | for k in range(len(word_probs)): 93 | word_probs[k][self._eos] = -1e20 94 | # Sum the previous scores. 95 | if len(self.prev_ks) > 0: 96 | beam_scores = word_probs + \ 97 | self.scores.unsqueeze(1).expand_as(word_probs) 98 | # Don't let EOS have children. 99 | for i in range(self.next_ys[-1].size(0)): 100 | if self.next_ys[-1][i] == self._eos: 101 | beam_scores[i] = -1e20 102 | 103 | # Block ngram repeats 104 | if self.block_ngram_repeat > 0: 105 | ngrams = [] 106 | le = len(self.next_ys) 107 | for j in range(self.next_ys[-1].size(0)): 108 | hyp, _ = self.get_hyp(le - 1, j) 109 | ngrams = set() 110 | fail = False 111 | gram = [] 112 | for i in range(le - 1): 113 | # Last n tokens, n = block_ngram_repeat 114 | gram = (gram + 115 | [hyp[i].item()])[-self.block_ngram_repeat:] 116 | # Skip the blocking if it is in the exclusion list 117 | if set(gram) & self.exclusion_tokens: 118 | continue 119 | if tuple(gram) in ngrams: 120 | fail = True 121 | ngrams.add(tuple(gram)) 122 | if fail: 123 | beam_scores[j] = -10e20 124 | else: 125 | beam_scores = word_probs[0] 126 | flat_beam_scores = beam_scores.view(-1) 127 | best_scores, best_scores_id = flat_beam_scores.topk(self.size, 0, 128 | True, True) 129 | 130 | self.all_scores.append(self.scores) 131 | self.scores = best_scores 132 | 133 | # best_scores_id is flattened beam x word array, so calculate which 134 | # word and beam each score came from 135 | prev_k = best_scores_id / num_words 136 | self.prev_ks.append(prev_k) 137 | self.next_ys.append((best_scores_id - prev_k * num_words)) 138 | self.attn.append(attn_out.index_select(0, prev_k)) 139 | self.global_scorer.update_global_state(self) 140 | 141 | for i in range(self.next_ys[-1].size(0)): 142 | if self.next_ys[-1][i] == self._eos: 143 | global_scores = self.global_scorer.score(self, self.scores) 144 | s = global_scores[i] 145 | self.finished.append((s, len(self.next_ys) - 1, i)) 146 | 147 | # End condition is when top-of-beam is EOS and no global score. 148 | if self.next_ys[-1][0] == self._eos: 149 | self.all_scores.append(self.scores) 150 | self.eos_top = True 151 | 152 | def done(self): 153 | return self.eos_top and len(self.finished) >= self.n_best 154 | 155 | def sort_finished(self, minimum=None): 156 | if minimum is not None: 157 | i = 0 158 | # Add from beam until we have minimum outputs. 
159 | while len(self.finished) < minimum: 160 | global_scores = self.global_scorer.score(self, self.scores) 161 | s = global_scores[i] 162 | self.finished.append((s, len(self.next_ys) - 1, i)) 163 | i += 1 164 | 165 | self.finished.sort(key=lambda a: -a[0]) 166 | scores = [sc for sc, _, _ in self.finished] 167 | ks = [(t, k) for _, t, k in self.finished] 168 | return scores, ks 169 | 170 | def get_hyp(self, timestep, k): 171 | """ 172 | Walk back to construct the full hypothesis. 173 | """ 174 | hyp, attn = [], [] 175 | for j in range(len(self.prev_ks[:timestep]) - 1, -1, -1): 176 | hyp.append(self.next_ys[j + 1][k]) 177 | attn.append(self.attn[j][k]) 178 | k = self.prev_ks[j][k] 179 | return hyp[::-1], torch.stack(attn[::-1]) 180 | 181 | 182 | class GNMTGlobalScorer(object): 183 | """ 184 | NMT re-ranking score from 185 | "Google's Neural Machine Translation System" :cite:`wu2016google` 186 | 187 | Args: 188 | alpha (float): length parameter 189 | beta (float): coverage parameter 190 | """ 191 | 192 | def __init__(self, alpha, length_penalty): 193 | self.alpha = alpha 194 | penalty_builder = penalties.PenaltyBuilder(length_penalty) 195 | # Term will be subtracted from probability 196 | # Probability will be divided by this 197 | self.length_penalty = penalty_builder.length_penalty() 198 | 199 | def score(self, beam, logprobs): 200 | """ 201 | Rescores a prediction based on penalty functions 202 | """ 203 | normalized_probs = self.length_penalty(beam, 204 | logprobs, 205 | self.alpha) 206 | 207 | return normalized_probs 208 | 209 | -------------------------------------------------------------------------------- /src/translate/penalties.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | import torch 3 | 4 | 5 | class PenaltyBuilder(object): 6 | """ 7 | Returns the Length and Coverage Penalty function for Beam Search. 8 | 9 | Args: 10 | length_pen (str): option name of length pen 11 | cov_pen (str): option name of cov pen 12 | """ 13 | 14 | def __init__(self, length_pen): 15 | self.length_pen = length_pen 16 | 17 | def length_penalty(self): 18 | if self.length_pen == "wu": 19 | return self.length_wu 20 | elif self.length_pen == "avg": 21 | return self.length_average 22 | else: 23 | return self.length_none 24 | 25 | """ 26 | Below are all the different penalty terms implemented so far 27 | """ 28 | 29 | 30 | def length_wu(self, beam, logprobs, alpha=0.): 31 | """ 32 | NMT length re-ranking score from 33 | "Google's Neural Machine Translation System" :cite:`wu2016google`. 34 | """ 35 | 36 | modifier = (((5 + len(beam.next_ys)) ** alpha) / 37 | ((5 + 1) ** alpha)) 38 | return (logprobs / modifier) 39 | 40 | def length_average(self, beam, logprobs, alpha=0.): 41 | """ 42 | Returns the average probability of tokens in a sequence. 43 | """ 44 | return logprobs / len(beam.next_ys) 45 | 46 | def length_none(self, beam, logprobs, alpha=0., beta=0.): 47 | """ 48 | Returns unmodified scores. 49 | """ 50 | return logprobs --------------------------------------------------------------------------------
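Editorial note (illustrative only, not part of the repository): the "wu" option returned by PenaltyBuilder.length_penalty() implements the GNMT length normalisation, dividing the cumulative log-probability by ((5 + |y|)^alpha) / (6^alpha), where |y| is the current hypothesis length. The sketch below shows how GNMTGlobalScorer and PenaltyBuilder combine; it assumes src/ is on PYTHONPATH, and DummyBeam is a stand-in because length_wu only reads len(beam.next_ys).

    import torch
    from translate.beam import GNMTGlobalScorer  # assumes src/ is on PYTHONPATH

    class DummyBeam:
        """Stand-in for Beam: length_wu only looks at len(beam.next_ys)."""
        def __init__(self, n_steps):
            self.next_ys = [None] * n_steps

    scorer = GNMTGlobalScorer(alpha=0.9, length_penalty='wu')
    logprobs = torch.tensor([-2.0, -3.5, -4.1])  # cumulative log-probs of 3 hypotheses
    beam = DummyBeam(n_steps=12)
    # modifier = ((5 + 12) ** 0.9) / (6 ** 0.9) ~= 2.55, so every score is divided by ~2.55
    print(scorer.score(beam, logprobs))

Longer hypotheses get a larger divisor, so the usual penalty for accumulating more negative log-probability terms is dampened; with alpha = 0 the modifier is 1 and scoring reduces to the raw sum, matching length_none above.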