├── LICENSE
├── README.md
├── bert_data
│   └── .gitignore
├── json_data
│   └── cnndm_sample.train.0.json
├── logs
│   └── .gitignore
├── models
│   └── .gitignore
├── raw_data
│   └── .gitignore
├── requirements.txt
├── results
│   └── .gitignore
├── src
│   ├── cal_rouge.py
│   ├── distributed.py
│   ├── models
│   │   ├── __init__.py
│   │   ├── adam.py
│   │   ├── data_loader.py
│   │   ├── decoder.py
│   │   ├── encoder.py
│   │   ├── loss.py
│   │   ├── model_builder.py
│   │   ├── neural.py
│   │   ├── optimizers.py
│   │   ├── predictor.py
│   │   ├── reporter.py
│   │   ├── reporter_ext.py
│   │   ├── trainer.py
│   │   └── trainer_ext.py
│   ├── others
│   │   ├── __init__.py
│   │   ├── logging.py
│   │   ├── pyrouge.py
│   │   ├── tokenization.py
│   │   └── utils.py
│   ├── post_stats.py
│   ├── prepro
│   │   ├── __init__.py
│   │   ├── data_builder.py
│   │   ├── smart_common_words.txt
│   │   └── utils.py
│   ├── preprocess.py
│   ├── train.py
│   ├── train_abstractive.py
│   ├── train_extractive.py
│   └── translate
│       ├── __init__.py
│       ├── beam.py
│       └── penalties.py
└── urls
    ├── cnn_mapping_test.txt
    ├── cnn_mapping_train.txt
    ├── cnn_mapping_valid.txt
    ├── mapping_test.txt
    ├── mapping_train.txt
    └── mapping_valid.txt
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2019 Yang Liu
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # PreSumm
2 |
3 | **This code is for EMNLP 2019 paper [Text Summarization with Pretrained Encoders](https://arxiv.org/abs/1908.08345)**
4 |
5 | **Updates Jan 22 2020**: Now you can **Summarize Raw Text Input!** Switch to the dev branch and use `-mode test_text` with `-text_src $RAW_SRC.TXT` to point to your text file. Please still use the master branch for normal training and evaluation; the dev branch should only be used for test_text mode.
6 | * For abstractive summarization use `-task abs`; for extractive summarization use `-task ext`.
7 | * Use `-test_from $PT_FILE$` to specify your model checkpoint file.
8 | * Format of the source text file:
9 | * For **abstractive summarization**, each line is a document.
10 | * If you want to do **extractive summarization**, please insert ` [CLS] [SEP] ` as your sentence boundaries.
11 | * There are example input files in the [raw_data directory](https://github.com/nlpyang/PreSumm/tree/dev/raw_data)
12 | * If you also have reference summaries aligned with your source input, use `-text_tgt $RAW_TGT.TXT` so the order is preserved for evaluation (see the example command below).
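
A minimal sketch of such a raw-text run on the dev branch, using the same placeholders as above (depending on your setup you may also need flags such as `-log_file` or `-result_path`):

```
# $RAW_SRC.TXT, $RAW_TGT.TXT and $PT_FILE$ are placeholders, as above
python train.py -task abs -mode test_text -text_src $RAW_SRC.TXT -text_tgt $RAW_TGT.TXT -test_from $PT_FILE$
```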
13 |
14 |
15 | Results on CNN/DailyMail (20/8/2019):
16 |
17 |
18 | | Models | ROUGE-1 | ROUGE-2 | ROUGE-L |
19 | | :--- | :---: | :---: | :---: |
20 | | **Extractive** | | | |
21 | | TransformerExt | 40.90 | 18.02 | 37.17 |
22 | | BertSumExt | 43.23 | 20.24 | 39.63 |
23 | | BertSumExt (large) | 43.85 | 20.34 | 39.90 |
24 | | **Abstractive** | | | |
25 | | TransformerAbs | 40.21 | 17.76 | 37.09 |
26 | | BertSumAbs | 41.72 | 19.39 | 38.76 |
27 | | BertSumExtAbs | 42.13 | 19.60 | 39.18 |
28 |
69 | **Python version**: This code is written in Python 3.6.
70 |
71 | **Package Requirements**: torch==1.1.0 pytorch_transformers tensorboardX multiprocess pyrouge
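
A typical way to install these is shown below (a sketch, assuming pip and the pinned versions from `requirements.txt`; note that `pyrouge` is only a wrapper and also needs a working ROUGE-1.5.5 installation configured separately):

```
pip install torch==1.1.0 pytorch-transformers==1.2.0 tensorboardX==1.9 multiprocess==0.70.9 pyrouge==0.1.3
```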
72 |
73 |
74 |
75 | **Updates**: To encode a text longer than 512 tokens (for example 800), set `-max_pos` to 800 during both preprocessing and training (see the sketch below).
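
For instance, with 800 positions the training commands below keep all their other flags and only change the position limit (a sketch; the preprocessing step needs the matching length setting as well):

```
# "..." stands for the remaining flags of the full training commands shown below
python train.py ... -max_pos 800 ...
```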
76 |
77 |
78 | Some code is borrowed from [OpenNMT-py](https://github.com/OpenNMT/OpenNMT-py).
79 |
80 | ## Trained Models
81 | [CNN/DM BertExt](https://drive.google.com/open?id=1kKWoV0QCbeIuFt85beQgJ4v0lujaXobJ)
82 |
83 | [CNN/DM BertExtAbs](https://drive.google.com/open?id=1-IKVCtc4Q-BdZpjXc4s70_fRsWnjtYLr)
84 |
85 | [CNN/DM TransformerAbs](https://drive.google.com/open?id=1yLCqT__ilQ3mf5YUUCw9-UToesX5Roxy)
86 |
87 | [XSum BertExtAbs](https://drive.google.com/open?id=1H50fClyTkNprWJNh10HWdGEdDdQIkzsI)
88 |
89 | ## System Outputs
90 |
91 | [CNN/DM and XSum](https://drive.google.com/file/d/1kYA384UEAQkvmZ-yWZAfxw7htCbCwFzC)
92 |
93 | ## Data Preparation For XSum
94 | [Pre-processed data](https://drive.google.com/open?id=1BWBN1coTWGBqrWoOfRc5dhojPHhatbYs)
95 |
96 |
97 | ## Data Preparation For CNN/Dailymail
98 | ### Option 1: download the processed data
99 |
100 | [Pre-processed data](https://drive.google.com/open?id=1DN7ClZCCXsk2KegmC6t4ClBwtAf5galI)
101 |
102 | Unzip the downloaded archive and put all `.pt` files into `bert_data` (a sketch of the commands is shown below).
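
For example, from the repository root (a sketch; the archive name is a placeholder, and you may need to adjust paths if the archive unpacks into a subdirectory):

```
# DOWNLOADED_ARCHIVE.zip is a placeholder for the downloaded file name
unzip DOWNLOADED_ARCHIVE.zip
mv *.pt bert_data/
```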
103 |
104 | ### Option 2: process the data yourself
105 |
106 | #### Step 1. Download Stories
107 | Download and unzip the `stories` directories from [here](http://cs.nyu.edu/~kcho/DMQA/) for both CNN and Daily Mail. Put all `.story` files in one directory (e.g. `../raw_stories`)
108 |
109 | #### Step 2. Download Stanford CoreNLP
110 | We will need Stanford CoreNLP to tokenize the data. Download it [here](https://stanfordnlp.github.io/CoreNLP/) and unzip it. Then add the following line to your `.bash_profile`:
111 | ```
112 | export CLASSPATH=/path/to/stanford-corenlp-full-2017-06-09/stanford-corenlp-3.8.0.jar
113 | ```
114 | replacing `/path/to/` with the path to where you saved the `stanford-corenlp-full-2017-06-09` directory.
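
A quick way to check that the classpath is set correctly is to tokenize a test sentence with the PTBTokenizer class shipped with CoreNLP (this only verifies the setup; it is not part of the pipeline):

```
echo "Please tokenize this text." | java edu.stanford.nlp.process.PTBTokenizer
```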
115 |
116 | #### Step 3. Sentence Splitting and Tokenization
117 |
118 | ```
119 | python preprocess.py -mode tokenize -raw_path RAW_PATH -save_path TOKENIZED_PATH
120 | ```
121 |
122 | * `RAW_PATH` is the directory containing the story files (`../raw_stories`); `TOKENIZED_PATH` is the target directory for the tokenized files (`../merged_stories_tokenized`)
123 |
124 |
125 | #### Step 4. Format to Simpler Json Files
126 |
127 | ```
128 | python preprocess.py -mode format_to_lines -raw_path RAW_PATH -save_path JSON_PATH -n_cpus 1 -use_bert_basic_tokenizer false -map_path MAP_PATH
129 | ```
130 |
131 | * `RAW_PATH` is the directory containing tokenized files (`../merged_stories_tokenized`), `JSON_PATH` is the target directory to save the generated json files (`../json_data/cnndm`), `MAP_PATH` is the directory containing the urls files (`../urls`)
132 |
133 | #### Step 5. Format to PyTorch Files
134 | ```
135 | python preprocess.py -mode format_to_bert -raw_path JSON_PATH -save_path BERT_DATA_PATH -lower -n_cpus 1 -log_file ../logs/preprocess.log
136 | ```
137 |
138 | * `JSON_PATH` is the directory containing json files (`../json_data`), `BERT_DATA_PATH` is the target directory to save the generated binary files (`../bert_data`)
139 |
140 | ## Model Training
141 |
142 | **First run: For the first run, use a single GPU so the code can download the BERT model. Use `-visible_gpus -1` (see the sketch below); after the download finishes, you can kill the process and rerun with multiple GPUs.**
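
A sketch of such a first run for the extractive setting (same flags as the full command below, only `-visible_gpus` differs; once the BERT weights have been downloaded and cached, stop the process and relaunch with your real GPU list):

```
# first run on a single device so the BERT model can be downloaded
python train.py -task ext -mode train -bert_data_path BERT_DATA_PATH -model_path MODEL_PATH -visible_gpus -1 -log_file ../logs/ext_bert_cnndm
```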
143 |
144 | ### Extractive Setting
145 |
146 | ```
147 | python train.py -task ext -mode train -bert_data_path BERT_DATA_PATH -ext_dropout 0.1 -model_path MODEL_PATH -lr 2e-3 -visible_gpus 0,1,2 -report_every 50 -save_checkpoint_steps 1000 -batch_size 3000 -train_steps 50000 -accum_count 2 -log_file ../logs/ext_bert_cnndm -use_interval true -warmup_steps 10000 -max_pos 512
148 | ```
149 |
150 | ### Abstractive Setting
151 |
152 | #### TransformerAbs (baseline)
153 | ```
154 | python train.py -mode train -accum_count 5 -batch_size 300 -bert_data_path BERT_DATA_PATH -dec_dropout 0.1 -log_file ../../logs/cnndm_baseline -lr 0.05 -model_path MODEL_PATH -save_checkpoint_steps 2000 -seed 777 -sep_optim false -train_steps 200000 -use_bert_emb true -use_interval true -warmup_steps 8000 -visible_gpus 0,1,2,3 -max_pos 512 -report_every 50 -enc_hidden_size 512 -enc_layers 6 -enc_ff_size 2048 -enc_dropout 0.1 -dec_layers 6 -dec_hidden_size 512 -dec_ff_size 2048 -encoder baseline -task abs
155 | ```
156 | #### BertAbs
157 | ```
158 | python train.py -task abs -mode train -bert_data_path BERT_DATA_PATH -dec_dropout 0.2 -model_path MODEL_PATH -sep_optim true -lr_bert 0.002 -lr_dec 0.2 -save_checkpoint_steps 2000 -batch_size 140 -train_steps 200000 -report_every 50 -accum_count 5 -use_bert_emb true -use_interval true -warmup_steps_bert 20000 -warmup_steps_dec 10000 -max_pos 512 -visible_gpus 0,1,2,3 -log_file ../logs/abs_bert_cnndm
159 | ```
160 | #### BertExtAbs
161 | ```
162 | python train.py -task abs -mode train -bert_data_path BERT_DATA_PATH -dec_dropout 0.2 -model_path MODEL_PATH -sep_optim true -lr_bert 0.002 -lr_dec 0.2 -save_checkpoint_steps 2000 -batch_size 140 -train_steps 200000 -report_every 50 -accum_count 5 -use_bert_emb true -use_interval true -warmup_steps_bert 20000 -warmup_steps_dec 10000 -max_pos 512 -visible_gpus 0,1,2,3 -log_file ../logs/abs_bert_cnndm -load_from_extractive EXT_CKPT
163 | ```
164 | * `EXT_CKPT` is the saved `.pt` checkpoint of the extractive model.
165 |
166 |
167 |
168 |
169 | ## Model Evaluation
170 | ### CNN/DM
171 | ```
172 | python train.py -task abs -mode validate -batch_size 3000 -test_batch_size 500 -bert_data_path BERT_DATA_PATH -log_file ../logs/val_abs_bert_cnndm -model_path MODEL_PATH -sep_optim true -use_interval true -visible_gpus 1 -max_pos 512 -max_length 200 -alpha 0.95 -min_length 50 -result_path ../logs/abs_bert_cnndm
173 | ```
174 | ### XSum
175 | ```
176 | python train.py -task abs -mode validate -batch_size 3000 -test_batch_size 500 -bert_data_path BERT_DATA_PATH -log_file ../logs/val_abs_bert_cnndm -model_path MODEL_PATH -sep_optim true -use_interval true -visible_gpus 1 -max_pos 512 -min_length 20 -max_length 100 -alpha 0.9 -result_path ../logs/abs_bert_cnndm
177 | ```
178 | * `-mode` can be {`validate`, `test`}: `validate` inspects the model directory and evaluates the model for each newly saved checkpoint; `test` needs to be used with `-test_from`, indicating the checkpoint you want to use (see the example below)
179 | * `MODEL_PATH` is the directory of saved checkpoints
180 | * If you use `-mode validate` with `-test_all`, the system will load all saved checkpoints and select the top ones to generate summaries (this will take a while)
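
A sketch of a `test` run on CNN/DM, reusing the validation command above but pointing `-test_from` at one specific checkpoint:

```
# model_step_XXXXX.pt is a placeholder; use an actual checkpoint saved in MODEL_PATH
python train.py -task abs -mode test -test_from MODEL_PATH/model_step_XXXXX.pt -batch_size 3000 -test_batch_size 500 -bert_data_path BERT_DATA_PATH -log_file ../logs/test_abs_bert_cnndm -sep_optim true -use_interval true -visible_gpus 1 -max_pos 512 -max_length 200 -alpha 0.95 -min_length 50 -result_path ../logs/abs_bert_cnndm
```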
181 |
182 |
--------------------------------------------------------------------------------
/bert_data/.gitignore:
--------------------------------------------------------------------------------
1 | *
2 | !.gitignore
--------------------------------------------------------------------------------
/logs/.gitignore:
--------------------------------------------------------------------------------
1 | *
2 | !.gitignore
--------------------------------------------------------------------------------
/models/.gitignore:
--------------------------------------------------------------------------------
1 | *
2 | !.gitignore
--------------------------------------------------------------------------------
/raw_data/.gitignore:
--------------------------------------------------------------------------------
1 | *
2 | !.gitignore
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | multiprocess==0.70.9
2 | numpy==1.17.2
3 | pyrouge==0.1.3
4 | pytorch-transformers==1.2.0
5 | tensorboardX==1.9
6 | torch==1.1.0
--------------------------------------------------------------------------------
/results/.gitignore:
--------------------------------------------------------------------------------
1 | *
2 | !.gitignore
--------------------------------------------------------------------------------
/src/cal_rouge.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import time
4 | # from multiprocess import Pool as Pool2
5 | from multiprocessing import Pool
6 |
7 | import shutil
8 | import sys
9 | import codecs
10 |
11 | # from onmt.utils.logging import init_logger, logger
12 | from others import pyrouge
13 |
14 |
15 | def process(data):
16 | candidates, references, pool_id = data
17 | cnt = len(candidates)
18 | current_time = time.strftime('%Y-%m-%d-%H-%M-%S', time.localtime())
19 | tmp_dir = "rouge-tmp-{}-{}".format(current_time,pool_id)
20 | if not os.path.isdir(tmp_dir):
21 | os.mkdir(tmp_dir)
22 | os.mkdir(tmp_dir + "/candidate")
23 | os.mkdir(tmp_dir + "/reference")
24 |     try:
25 |         for i in range(cnt):
26 |             if len(references[i]) < 1:
27 |                 continue
28 |             with open(tmp_dir + "/candidate/cand.{}.txt".format(i), "w",
29 |                       encoding="utf-8") as f:
30 |                 f.write(candidates[i])
31 |             with open(tmp_dir + "/reference/ref.{}.txt".format(i), "w",
32 |                       encoding="utf-8") as f:
33 |                 f.write(references[i])
34 |         r = pyrouge.Rouge155()
35 |         r.model_dir = tmp_dir + "/reference/"
36 |         r.system_dir = tmp_dir + "/candidate/"
37 |         r.model_filename_pattern = 'ref.#ID#.txt'
38 |         r.system_filename_pattern = r'cand.(\d+).txt'
39 |         rouge_results = r.convert_and_evaluate()
40 |         print(rouge_results)
41 |         results_dict = r.output_to_dict(rouge_results)
42 |     finally:
43 |         # Clean up the temporary ROUGE files even if evaluation raised an exception.
44 |         if os.path.isdir(tmp_dir):
45 |             shutil.rmtree(tmp_dir)
46 |     return results_dict
48 |
49 |
50 |
51 |
52 | def chunks(l, n):
53 | """Yield successive n-sized chunks from l."""
54 | for i in range(0, len(l), n):
55 | yield l[i:i + n]
56 |
57 | def test_rouge(cand, ref,num_processes):
58 | """Calculate ROUGE scores of sequences passed as an iterator
59 | e.g. a list of str, an open file, StringIO or even sys.stdin
60 | """
61 | candidates = [line.strip() for line in cand]
62 | references = [line.strip() for line in ref]
63 |
64 | print(len(candidates))
65 | print(len(references))
66 | assert len(candidates) == len(references)
67 | candidates_chunks = list(chunks(candidates, int(len(candidates)/num_processes)))
68 | references_chunks = list(chunks(references, int(len(references)/num_processes)))
69 | n_pool = len(candidates_chunks)
70 | arg_lst = []
71 | for i in range(n_pool):
72 | arg_lst.append((candidates_chunks[i],references_chunks[i],i))
73 | pool = Pool(n_pool)
74 | results = pool.map(process,arg_lst)
75 | final_results = {}
76 | for i,r in enumerate(results):
77 | for k in r:
78 | if(k not in final_results):
79 | final_results[k] = r[k]*len(candidates_chunks[i])
80 | else:
81 | final_results[k] += r[k] * len(candidates_chunks[i])
82 | for k in final_results:
83 | final_results[k] = final_results[k]/len(candidates)
84 | return final_results
85 | def rouge_results_to_str(results_dict):
86 | return ">> ROUGE-F(1/2/3/l): {:.2f}/{:.2f}/{:.2f}\nROUGE-R(1/2/3/l): {:.2f}/{:.2f}/{:.2f}\n".format(
87 | results_dict["rouge_1_f_score"] * 100,
88 | results_dict["rouge_2_f_score"] * 100,
89 | # results_dict["rouge_3_f_score"] * 100,
90 | results_dict["rouge_l_f_score"] * 100,
91 | results_dict["rouge_1_recall"] * 100,
92 | results_dict["rouge_2_recall"] * 100,
93 | # results_dict["rouge_3_f_score"] * 100,
94 | results_dict["rouge_l_recall"] * 100
95 |
96 | # ,results_dict["rouge_su*_f_score"] * 100
97 | )
98 |
99 |
100 | if __name__ == "__main__":
101 | # init_logger('test_rouge.log')
102 | parser = argparse.ArgumentParser()
103 | parser.add_argument('-c', type=str, default="candidate.txt",
104 | help='candidate file')
105 | parser.add_argument('-r', type=str, default="reference.txt",
106 | help='reference file')
107 | parser.add_argument('-p', type=int, default=1,
108 | help='number of processes')
109 | args = parser.parse_args()
110 | print(args.c)
111 | print(args.r)
112 | print(args.p)
113 | if args.c.upper() == "STDIN":
114 | candidates = sys.stdin
115 | else:
116 | candidates = codecs.open(args.c, encoding="utf-8")
117 | references = codecs.open(args.r, encoding="utf-8")
118 |
119 | results_dict = test_rouge(candidates, references,args.p)
120 | # return 0
121 | print(time.strftime('%H:%M:%S', time.localtime())
122 | )
123 | print(rouge_results_to_str(results_dict))
124 | # logger.info(rouge_results_to_str(results_dict))
--------------------------------------------------------------------------------
/src/distributed.py:
--------------------------------------------------------------------------------
1 | """ Pytorch Distributed utils
2 | This piece of code was heavily inspired by the equivalent of Fairseq-py
3 | https://github.com/pytorch/fairseq
4 | """
5 |
6 |
7 | from __future__ import print_function
8 |
9 | import math
10 | import pickle
11 |
12 | import torch.distributed
13 |
14 | from others.logging import logger
15 |
16 |
17 | def is_master(gpu_ranks, device_id):
18 | return gpu_ranks[device_id] == 0
19 |
20 |
21 | def multi_init(device_id, world_size,gpu_ranks):
22 | print(gpu_ranks)
23 | dist_init_method = 'tcp://localhost:10000'
24 | dist_world_size = world_size
25 | torch.distributed.init_process_group(
26 | backend='nccl', init_method=dist_init_method,
27 | world_size=dist_world_size, rank=gpu_ranks[device_id])
28 | gpu_rank = torch.distributed.get_rank()
29 | if not is_master(gpu_ranks, device_id):
30 | # print('not master')
31 | logger.disabled = True
32 |
33 | return gpu_rank
34 |
35 |
36 |
37 | def all_reduce_and_rescale_tensors(tensors, rescale_denom,
38 | buffer_size=10485760):
39 | """All-reduce and rescale tensors in chunks of the specified size.
40 |
41 | Args:
42 | tensors: list of Tensors to all-reduce
43 | rescale_denom: denominator for rescaling summed Tensors
44 | buffer_size: all-reduce chunk size in bytes
45 | """
46 | # buffer size in bytes, determine equiv. # of elements based on data type
47 | buffer_t = tensors[0].new(
48 | math.ceil(buffer_size / tensors[0].element_size())).zero_()
49 | buffer = []
50 |
51 | def all_reduce_buffer():
52 | # copy tensors into buffer_t
53 | offset = 0
54 | for t in buffer:
55 | numel = t.numel()
56 | buffer_t[offset:offset+numel].copy_(t.view(-1))
57 | offset += numel
58 |
59 | # all-reduce and rescale
60 | torch.distributed.all_reduce(buffer_t[:offset])
61 | buffer_t.div_(rescale_denom)
62 |
63 | # copy all-reduced buffer back into tensors
64 | offset = 0
65 | for t in buffer:
66 | numel = t.numel()
67 | t.view(-1).copy_(buffer_t[offset:offset+numel])
68 | offset += numel
69 |
70 | filled = 0
71 | for t in tensors:
72 | sz = t.numel() * t.element_size()
73 | if sz > buffer_size:
74 | # tensor is bigger than buffer, all-reduce and rescale directly
75 | torch.distributed.all_reduce(t)
76 | t.div_(rescale_denom)
77 | elif filled + sz > buffer_size:
78 | # buffer is full, all-reduce and replace buffer with grad
79 | all_reduce_buffer()
80 | buffer = [t]
81 | filled = sz
82 | else:
83 | # add tensor to buffer
84 | buffer.append(t)
85 | filled += sz
86 |
87 | if len(buffer) > 0:
88 | all_reduce_buffer()
89 |
90 |
91 | def all_gather_list(data, max_size=4096):
92 | """Gathers arbitrary data from all nodes into a list."""
93 | world_size = torch.distributed.get_world_size()
94 | if not hasattr(all_gather_list, '_in_buffer') or \
95 | max_size != all_gather_list._in_buffer.size():
96 | all_gather_list._in_buffer = torch.cuda.ByteTensor(max_size)
97 | all_gather_list._out_buffers = [
98 | torch.cuda.ByteTensor(max_size)
99 | for i in range(world_size)
100 | ]
101 | in_buffer = all_gather_list._in_buffer
102 | out_buffers = all_gather_list._out_buffers
103 |
104 | enc = pickle.dumps(data)
105 | enc_size = len(enc)
106 | if enc_size + 2 > max_size:
107 | raise ValueError(
108 | 'encoded data exceeds max_size: {}'.format(enc_size + 2))
109 | assert max_size < 255*256
110 | in_buffer[0] = enc_size // 255 # this encoding works for max_size < 65k
111 | in_buffer[1] = enc_size % 255
112 | in_buffer[2:enc_size+2] = torch.ByteTensor(list(enc))
113 |
114 | torch.distributed.all_gather(out_buffers, in_buffer.cuda())
115 |
116 | results = []
117 | for i in range(world_size):
118 | out_buffer = out_buffers[i]
119 | size = (255 * out_buffer[0].item()) + out_buffer[1].item()
120 |
121 | bytes_list = bytes(out_buffer[2:size+2].tolist())
122 | result = pickle.loads(bytes_list)
123 | results.append(result)
124 | return results
125 |
--------------------------------------------------------------------------------
/src/models/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/nlpyang/PreSumm/70b810e0f06d179022958dd35c1a3385fe87f28c/src/models/__init__.py
--------------------------------------------------------------------------------
/src/models/adam.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | from torch.optim.optimizer import Optimizer
4 |
5 |
6 | class Adam(Optimizer):
7 | r"""Implements Adam algorithm.
8 |
9 | It has been proposed in `Adam: A Method for Stochastic Optimization`_.
10 |
11 | Arguments:
12 | params (iterable): iterable of parameters to optimize or dicts defining
13 | parameter groups
14 | lr (float, optional): learning rate (default: 1e-3)
15 | betas (Tuple[float, float], optional): coefficients used for computing
16 | running averages of gradient and its square (default: (0.9, 0.999))
17 | eps (float, optional): term added to the denominator to improve
18 | numerical stability (default: 1e-8)
19 | weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
20 | amsgrad (boolean, optional): whether to use the AMSGrad variant of this
21 | algorithm from the paper `On the Convergence of Adam and Beyond`_
22 | (default: False)
23 |
24 | .. _Adam\: A Method for Stochastic Optimization:
25 | https://arxiv.org/abs/1412.6980
26 | .. _On the Convergence of Adam and Beyond:
27 | https://openreview.net/forum?id=ryQu7f-RZ
28 | """
29 |
30 | def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8,
31 | weight_decay=0, amsgrad=False):
32 | if not 0.0 <= lr:
33 | raise ValueError("Invalid learning rate: {}".format(lr))
34 | if not 0.0 <= eps:
35 | raise ValueError("Invalid epsilon value: {}".format(eps))
36 | if not 0.0 <= betas[0] < 1.0:
37 | raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
38 | if not 0.0 <= betas[1] < 1.0:
39 | raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
40 | defaults = dict(lr=lr, betas=betas, eps=eps,
41 | weight_decay=weight_decay, amsgrad=amsgrad)
42 | super(Adam, self).__init__(params, defaults)
43 |
44 | def __setstate__(self, state):
45 | super(Adam, self).__setstate__(state)
46 | for group in self.param_groups:
47 | group.setdefault('amsgrad', False)
48 |
49 | def step(self, closure=None):
50 | """Performs a single optimization step.
51 | Arguments:
52 | closure (callable, optional): A closure that reevaluates the model
53 | and returns the loss.
54 | """
55 | loss = None
56 | if closure is not None:
57 | loss = closure()
58 |
59 |
60 | for group in self.param_groups:
61 | for p in group['params']:
62 | if p.grad is None:
63 | continue
64 | grad = p.grad.data
65 | if grad.is_sparse:
66 | raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead')
67 |
68 | state = self.state[p]
69 |
70 | # State initialization
71 | if len(state) == 0:
72 | state['step'] = 0
73 | # Exponential moving average of gradient values
74 | state['next_m'] = torch.zeros_like(p.data)
75 | # Exponential moving average of squared gradient values
76 | state['next_v'] = torch.zeros_like(p.data)
77 |
78 | next_m, next_v = state['next_m'], state['next_v']
79 | beta1, beta2 = group['betas']
80 |
81 | # Decay the first and second moment running average coefficient
82 | # In-place operations to update the averages at the same time
83 | next_m.mul_(beta1).add_(1 - beta1, grad)
84 | next_v.mul_(beta2).addcmul_(1 - beta2, grad, grad)
85 | update = next_m / (next_v.sqrt() + group['eps'])
86 |
87 | # Just adding the square of the weights to the loss function is *not*
88 | # the correct way of using L2 regularization/weight decay with Adam,
89 | # since that will interact with the m and v parameters in strange ways.
90 | #
91 | # Instead we want to decay the weights in a manner that doesn't interact
92 | # with the m/v parameters. This is equivalent to adding the square
93 | # of the weights to the loss with plain (non-momentum) SGD.
94 | if group['weight_decay'] > 0.0:
95 | update += group['weight_decay'] * p.data
96 |
97 | lr_scheduled = group['lr']
98 |
99 | update_with_lr = lr_scheduled * update
100 | p.data.add_(-update_with_lr)
101 |
102 | state['step'] += 1
103 |
104 | # step_size = lr_scheduled * math.sqrt(bias_correction2) / bias_correction1
105 | # No bias correction
106 | # bias_correction1 = 1 - beta1 ** state['step']
107 | # bias_correction2 = 1 - beta2 ** state['step']
108 |
109 | return loss
--------------------------------------------------------------------------------
/src/models/data_loader.py:
--------------------------------------------------------------------------------
1 | import bisect
2 | import gc
3 | import glob
4 | import random
5 |
6 | import torch
7 |
8 | from others.logging import logger
9 |
10 |
11 |
12 | class Batch(object):
13 | def _pad(self, data, pad_id, width=-1):
14 | if (width == -1):
15 | width = max(len(d) for d in data)
16 | rtn_data = [d + [pad_id] * (width - len(d)) for d in data]
17 | return rtn_data
18 |
19 | def __init__(self, data=None, device=None, is_test=False):
20 | """Create a Batch from a list of examples."""
21 | if data is not None:
22 | self.batch_size = len(data)
23 | pre_src = [x[0] for x in data]
24 | pre_tgt = [x[1] for x in data]
25 | pre_segs = [x[2] for x in data]
26 | pre_clss = [x[3] for x in data]
27 | pre_src_sent_labels = [x[4] for x in data]
28 |
29 | src = torch.tensor(self._pad(pre_src, 0))
30 | tgt = torch.tensor(self._pad(pre_tgt, 0))
31 |
32 | segs = torch.tensor(self._pad(pre_segs, 0))
33 | mask_src = 1 - (src == 0)
34 | mask_tgt = 1 - (tgt == 0)
35 |
36 |
37 | clss = torch.tensor(self._pad(pre_clss, -1))
38 | src_sent_labels = torch.tensor(self._pad(pre_src_sent_labels, 0))
39 | mask_cls = 1 - (clss == -1)
40 | clss[clss == -1] = 0
41 | setattr(self, 'clss', clss.to(device))
42 | setattr(self, 'mask_cls', mask_cls.to(device))
43 | setattr(self, 'src_sent_labels', src_sent_labels.to(device))
44 |
45 |
46 | setattr(self, 'src', src.to(device))
47 | setattr(self, 'tgt', tgt.to(device))
48 | setattr(self, 'segs', segs.to(device))
49 | setattr(self, 'mask_src', mask_src.to(device))
50 | setattr(self, 'mask_tgt', mask_tgt.to(device))
51 |
52 |
53 | if (is_test):
54 | src_str = [x[-2] for x in data]
55 | setattr(self, 'src_str', src_str)
56 | tgt_str = [x[-1] for x in data]
57 | setattr(self, 'tgt_str', tgt_str)
58 |
59 | def __len__(self):
60 | return self.batch_size
61 |
62 |
63 |
64 |
65 | def load_dataset(args, corpus_type, shuffle):
66 | """
67 | Dataset generator. Don't do extra stuff here, like printing,
68 | because it will be postponed to the first loading time.
69 |
70 | Args:
71 | corpus_type: 'train' or 'valid'
72 | Returns:
73 | A list of dataset, the dataset(s) are lazily loaded.
74 | """
75 | assert corpus_type in ["train", "valid", "test"]
76 |
77 | def _lazy_dataset_loader(pt_file, corpus_type):
78 | dataset = torch.load(pt_file)
79 | logger.info('Loading %s dataset from %s, number of examples: %d' %
80 | (corpus_type, pt_file, len(dataset)))
81 | return dataset
82 |
83 | # Sort the glob output by file name (by increasing indexes).
84 | pts = sorted(glob.glob(args.bert_data_path + '.' + corpus_type + '.[0-9]*.pt'))
85 | if pts:
86 | if (shuffle):
87 | random.shuffle(pts)
88 |
89 | for pt in pts:
90 | yield _lazy_dataset_loader(pt, corpus_type)
91 | else:
92 | # Only one inputters.*Dataset, simple!
93 | pt = args.bert_data_path + '.' + corpus_type + '.pt'
94 | yield _lazy_dataset_loader(pt, corpus_type)
95 |
96 |
97 | def abs_batch_size_fn(new, count):
98 | src, tgt = new[0], new[1]
99 | global max_n_sents, max_n_tokens, max_size
100 | if count == 1:
101 | max_size = 0
102 | max_n_sents=0
103 | max_n_tokens=0
104 | max_n_sents = max(max_n_sents, len(tgt))
105 | max_size = max(max_size, max_n_sents)
106 | src_elements = count * max_size
107 | if (count > 6):
108 | return src_elements + 1e3
109 | return src_elements
110 |
111 |
112 | def ext_batch_size_fn(new, count):
113 | if (len(new) == 4):
114 | pass
115 | src, labels = new[0], new[4]
116 | global max_n_sents, max_n_tokens, max_size
117 | if count == 1:
118 | max_size = 0
119 | max_n_sents = 0
120 | max_n_tokens = 0
121 | max_n_sents = max(max_n_sents, len(src))
122 | max_size = max(max_size, max_n_sents)
123 | src_elements = count * max_size
124 | return src_elements
125 |
126 |
127 | class Dataloader(object):
128 | def __init__(self, args, datasets, batch_size,
129 | device, shuffle, is_test):
130 | self.args = args
131 | self.datasets = datasets
132 | self.batch_size = batch_size
133 | self.device = device
134 | self.shuffle = shuffle
135 | self.is_test = is_test
136 | self.cur_iter = self._next_dataset_iterator(datasets)
137 | assert self.cur_iter is not None
138 |
139 | def __iter__(self):
140 | dataset_iter = (d for d in self.datasets)
141 | while self.cur_iter is not None:
142 | for batch in self.cur_iter:
143 | yield batch
144 | self.cur_iter = self._next_dataset_iterator(dataset_iter)
145 |
146 |
147 | def _next_dataset_iterator(self, dataset_iter):
148 | try:
149 | # Drop the current dataset for decreasing memory
150 | if hasattr(self, "cur_dataset"):
151 | self.cur_dataset = None
152 | gc.collect()
153 | del self.cur_dataset
154 | gc.collect()
155 |
156 | self.cur_dataset = next(dataset_iter)
157 | except StopIteration:
158 | return None
159 |
160 | return DataIterator(args = self.args,
161 | dataset=self.cur_dataset, batch_size=self.batch_size,
162 | device=self.device, shuffle=self.shuffle, is_test=self.is_test)
163 |
164 |
165 | class DataIterator(object):
166 | def __init__(self, args, dataset, batch_size, device=None, is_test=False,
167 | shuffle=True):
168 | self.args = args
169 | self.batch_size, self.is_test, self.dataset = batch_size, is_test, dataset
170 | self.iterations = 0
171 | self.device = device
172 | self.shuffle = shuffle
173 |
174 | self.sort_key = lambda x: len(x[1])
175 |
176 | self._iterations_this_epoch = 0
177 | if (self.args.task == 'abs'):
178 | self.batch_size_fn = abs_batch_size_fn
179 | else:
180 | self.batch_size_fn = ext_batch_size_fn
181 |
182 | def data(self):
183 | if self.shuffle:
184 | random.shuffle(self.dataset)
185 | xs = self.dataset
186 | return xs
187 |
188 |
189 |
190 |
191 |
192 |
193 | def preprocess(self, ex, is_test):
194 | src = ex['src']
195 | tgt = ex['tgt'][:self.args.max_tgt_len][:-1]+[2]
196 | src_sent_labels = ex['src_sent_labels']
197 | segs = ex['segs']
198 | if(not self.args.use_interval):
199 | segs=[0]*len(segs)
200 | clss = ex['clss']
201 | src_txt = ex['src_txt']
202 | tgt_txt = ex['tgt_txt']
203 |
204 | end_id = [src[-1]]
205 | src = src[:-1][:self.args.max_pos - 1] + end_id
206 | segs = segs[:self.args.max_pos]
207 | max_sent_id = bisect.bisect_left(clss, self.args.max_pos)
208 | src_sent_labels = src_sent_labels[:max_sent_id]
209 | clss = clss[:max_sent_id]
210 | # src_txt = src_txt[:max_sent_id]
211 |
212 |
213 |
214 | if(is_test):
215 | return src, tgt, segs, clss, src_sent_labels, src_txt, tgt_txt
216 | else:
217 | return src, tgt, segs, clss, src_sent_labels
218 |
219 | def batch_buffer(self, data, batch_size):
220 | minibatch, size_so_far = [], 0
221 | for ex in data:
222 | if(len(ex['src'])==0):
223 | continue
224 | ex = self.preprocess(ex, self.is_test)
225 | if(ex is None):
226 | continue
227 | minibatch.append(ex)
228 | size_so_far = self.batch_size_fn(ex, len(minibatch))
229 | if size_so_far == batch_size:
230 | yield minibatch
231 | minibatch, size_so_far = [], 0
232 | elif size_so_far > batch_size:
233 | yield minibatch[:-1]
234 | minibatch, size_so_far = minibatch[-1:], self.batch_size_fn(ex, 1)
235 | if minibatch:
236 | yield minibatch
237 |
238 | def batch(self, data, batch_size):
239 | """Yield elements from data in chunks of batch_size."""
240 | minibatch, size_so_far = [], 0
241 | for ex in data:
242 | minibatch.append(ex)
243 | size_so_far = self.batch_size_fn(ex, len(minibatch))
244 | if size_so_far == batch_size:
245 | yield minibatch
246 | minibatch, size_so_far = [], 0
247 | elif size_so_far > batch_size:
248 | yield minibatch[:-1]
249 | minibatch, size_so_far = minibatch[-1:], self.batch_size_fn(ex, 1)
250 | if minibatch:
251 | yield minibatch
252 |
253 | def create_batches(self):
254 | """ Create batches """
255 | data = self.data()
256 | for buffer in self.batch_buffer(data, self.batch_size * 300):
257 |
258 | if (self.args.task == 'abs'):
259 | p_batch = sorted(buffer, key=lambda x: len(x[2]))
260 | p_batch = sorted(p_batch, key=lambda x: len(x[1]))
261 | else:
262 | p_batch = sorted(buffer, key=lambda x: len(x[2]))
263 |
264 | p_batch = self.batch(p_batch, self.batch_size)
265 |
266 |
267 | p_batch = list(p_batch)
268 | if (self.shuffle):
269 | random.shuffle(p_batch)
270 | for b in p_batch:
271 | if(len(b)==0):
272 | continue
273 | yield b
274 |
275 | def __iter__(self):
276 | while True:
277 | self.batches = self.create_batches()
278 | for idx, minibatch in enumerate(self.batches):
279 | # fast-forward if loaded from state
280 | if self._iterations_this_epoch > idx:
281 | continue
282 | self.iterations += 1
283 | self._iterations_this_epoch += 1
284 | batch = Batch(minibatch, self.device, self.is_test)
285 |
286 | yield batch
287 | return
288 |
289 |
290 | class TextDataloader(object):
291 | def __init__(self, args, datasets, batch_size,
292 | device, shuffle, is_test):
293 |         self.args = args
294 |         self.datasets = datasets
295 |         self.batch_size = batch_size
296 |         self.device = device
297 |         self.shuffle = shuffle
298 |         self.is_test = is_test
296 |
297 | def data(self):
298 | if self.shuffle:
299 | random.shuffle(self.dataset)
300 | xs = self.dataset
301 | return xs
302 |
303 | def preprocess(self, ex, is_test):
304 | src = ex['src']
305 | tgt = ex['tgt'][:self.args.max_tgt_len][:-1] + [2]
306 | src_sent_labels = ex['src_sent_labels']
307 | segs = ex['segs']
308 | if (not self.args.use_interval):
309 | segs = [0] * len(segs)
310 | clss = ex['clss']
311 | src_txt = ex['src_txt']
312 | tgt_txt = ex['tgt_txt']
313 |
314 | end_id = [src[-1]]
315 | src = src[:-1][:self.args.max_pos - 1] + end_id
316 | segs = segs[:self.args.max_pos]
317 | max_sent_id = bisect.bisect_left(clss, self.args.max_pos)
318 | src_sent_labels = src_sent_labels[:max_sent_id]
319 | clss = clss[:max_sent_id]
320 | # src_txt = src_txt[:max_sent_id]
321 |
322 | if (is_test):
323 | return src, tgt, segs, clss, src_sent_labels, src_txt, tgt_txt
324 | else:
325 | return src, tgt, segs, clss, src_sent_labels
326 |
327 | def batch_buffer(self, data, batch_size):
328 | minibatch, size_so_far = [], 0
329 | for ex in data:
330 | if (len(ex['src']) == 0):
331 | continue
332 | ex = self.preprocess(ex, self.is_test)
333 | if (ex is None):
334 | continue
335 | minibatch.append(ex)
336 | size_so_far = simple_batch_size_fn(ex, len(minibatch))
337 | if size_so_far == batch_size:
338 | yield minibatch
339 | minibatch, size_so_far = [], 0
340 | elif size_so_far > batch_size:
341 | yield minibatch[:-1]
342 | minibatch, size_so_far = minibatch[-1:], simple_batch_size_fn(ex, 1)
343 | if minibatch:
344 | yield minibatch
345 |
346 | def create_batches(self):
347 | """ Create batches """
348 | data = self.data()
349 | for buffer in self.batch_buffer(data, self.batch_size * 300):
350 | if (self.args.task == 'abs'):
351 | p_batch = sorted(buffer, key=lambda x: len(x[2]))
352 | p_batch = sorted(p_batch, key=lambda x: len(x[1]))
353 | else:
354 | p_batch = sorted(buffer, key=lambda x: len(x[2]))
355 |             p_batch = batch(p_batch, self.batch_size)
358 |
359 | p_batch = list(p_batch)
360 | if (self.shuffle):
361 | random.shuffle(p_batch)
362 | for b in p_batch:
363 | if (len(b) == 0):
364 | continue
365 | yield b
366 |
367 | def __iter__(self):
368 | while True:
369 | self.batches = self.create_batches()
370 | for idx, minibatch in enumerate(self.batches):
371 | # fast-forward if loaded from state
372 | if self._iterations_this_epoch > idx:
373 | continue
374 | self.iterations += 1
375 | self._iterations_this_epoch += 1
376 | batch = Batch(minibatch, self.device, self.is_test)
377 |
378 | yield batch
379 | return
380 |
--------------------------------------------------------------------------------
/src/models/decoder.py:
--------------------------------------------------------------------------------
1 | """
2 | Implementation of "Attention is All You Need"
3 | """
4 |
5 | import torch
6 | import torch.nn as nn
7 | import numpy as np
8 |
9 | from models.encoder import PositionalEncoding
10 | from models.neural import MultiHeadedAttention, PositionwiseFeedForward, DecoderState
11 |
12 | MAX_SIZE = 5000
13 |
14 |
15 | class TransformerDecoderLayer(nn.Module):
16 | """
17 | Args:
18 | d_model (int): the dimension of keys/values/queries in
19 | MultiHeadedAttention, also the input size of
20 | the first-layer of the PositionwiseFeedForward.
21 | heads (int): the number of heads for MultiHeadedAttention.
22 | d_ff (int): the second-layer of the PositionwiseFeedForward.
23 | dropout (float): dropout probability(0-1.0).
24 | self_attn_type (string): type of self-attention scaled-dot, average
25 | """
26 |
27 | def __init__(self, d_model, heads, d_ff, dropout):
28 | super(TransformerDecoderLayer, self).__init__()
29 |
30 |
31 | self.self_attn = MultiHeadedAttention(
32 | heads, d_model, dropout=dropout)
33 |
34 | self.context_attn = MultiHeadedAttention(
35 | heads, d_model, dropout=dropout)
36 | self.feed_forward = PositionwiseFeedForward(d_model, d_ff, dropout)
37 | self.layer_norm_1 = nn.LayerNorm(d_model, eps=1e-6)
38 | self.layer_norm_2 = nn.LayerNorm(d_model, eps=1e-6)
39 | self.drop = nn.Dropout(dropout)
40 | mask = self._get_attn_subsequent_mask(MAX_SIZE)
41 | # Register self.mask as a buffer in TransformerDecoderLayer, so
42 | # it gets TransformerDecoderLayer's cuda behavior automatically.
43 | self.register_buffer('mask', mask)
44 |
45 | def forward(self, inputs, memory_bank, src_pad_mask, tgt_pad_mask,
46 | previous_input=None, layer_cache=None, step=None):
47 | """
48 | Args:
49 | inputs (`FloatTensor`): `[batch_size x 1 x model_dim]`
50 | memory_bank (`FloatTensor`): `[batch_size x src_len x model_dim]`
51 | src_pad_mask (`LongTensor`): `[batch_size x 1 x src_len]`
52 | tgt_pad_mask (`LongTensor`): `[batch_size x 1 x 1]`
53 |
54 | Returns:
55 | (`FloatTensor`, `FloatTensor`, `FloatTensor`):
56 |
57 | * output `[batch_size x 1 x model_dim]`
58 | * attn `[batch_size x 1 x src_len]`
59 | * all_input `[batch_size x current_step x model_dim]`
60 |
61 | """
62 | dec_mask = torch.gt(tgt_pad_mask +
63 | self.mask[:, :tgt_pad_mask.size(1),
64 | :tgt_pad_mask.size(1)], 0)
65 | input_norm = self.layer_norm_1(inputs)
66 | all_input = input_norm
67 | if previous_input is not None:
68 | all_input = torch.cat((previous_input, input_norm), dim=1)
69 | dec_mask = None
70 |
71 | query = self.self_attn(all_input, all_input, input_norm,
72 | mask=dec_mask,
73 | layer_cache=layer_cache,
74 | type="self")
75 |
76 | query = self.drop(query) + inputs
77 |
78 | query_norm = self.layer_norm_2(query)
79 | mid = self.context_attn(memory_bank, memory_bank, query_norm,
80 | mask=src_pad_mask,
81 | layer_cache=layer_cache,
82 | type="context")
83 | output = self.feed_forward(self.drop(mid) + query)
84 |
85 | return output, all_input
86 | # return output
87 |
88 | def _get_attn_subsequent_mask(self, size):
89 | """
90 | Get an attention mask to avoid using the subsequent info.
91 |
92 | Args:
93 | size: int
94 |
95 | Returns:
96 | (`LongTensor`):
97 |
98 | * subsequent_mask `[1 x size x size]`
99 | """
100 | attn_shape = (1, size, size)
101 | subsequent_mask = np.triu(np.ones(attn_shape), k=1).astype('uint8')
102 | subsequent_mask = torch.from_numpy(subsequent_mask)
103 | return subsequent_mask
104 |
105 |
106 |
107 | class TransformerDecoder(nn.Module):
108 | """
109 | The Transformer decoder from "Attention is All You Need".
110 |
111 |
112 | .. mermaid::
113 |
114 | graph BT
115 | A[input]
116 | B[multi-head self-attn]
117 | BB[multi-head src-attn]
118 | C[feed forward]
119 | O[output]
120 | A --> B
121 | B --> BB
122 | BB --> C
123 | C --> O
124 |
125 |
126 | Args:
127 | num_layers (int): number of encoder layers.
128 | d_model (int): size of the model
129 | heads (int): number of heads
130 | d_ff (int): size of the inner FF layer
131 | dropout (float): dropout parameters
132 | embeddings (:obj:`onmt.modules.Embeddings`):
133 | embeddings to use, should have positional encodings
134 |        attn_type (str): if using a separate copy attention
135 | """
136 |
137 | def __init__(self, num_layers, d_model, heads, d_ff, dropout, embeddings):
138 | super(TransformerDecoder, self).__init__()
139 |
140 | # Basic attributes.
141 | self.decoder_type = 'transformer'
142 | self.num_layers = num_layers
143 | self.embeddings = embeddings
144 | self.pos_emb = PositionalEncoding(dropout,self.embeddings.embedding_dim)
145 |
146 |
147 | # Build TransformerDecoder.
148 | self.transformer_layers = nn.ModuleList(
149 | [TransformerDecoderLayer(d_model, heads, d_ff, dropout)
150 | for _ in range(num_layers)])
151 |
152 | self.layer_norm = nn.LayerNorm(d_model, eps=1e-6)
153 |
154 | def forward(self, tgt, memory_bank, state, memory_lengths=None,
155 | step=None, cache=None,memory_masks=None):
156 | """
157 | See :obj:`onmt.modules.RNNDecoderBase.forward()`
158 | """
159 |
160 | src_words = state.src
161 | tgt_words = tgt
162 | src_batch, src_len = src_words.size()
163 | tgt_batch, tgt_len = tgt_words.size()
164 |
165 | # Run the forward pass of the TransformerDecoder.
166 | # emb = self.embeddings(tgt, step=step)
167 | emb = self.embeddings(tgt)
168 | assert emb.dim() == 3 # len x batch x embedding_dim
169 |
170 | output = self.pos_emb(emb, step)
171 |
172 | src_memory_bank = memory_bank
173 | padding_idx = self.embeddings.padding_idx
174 | tgt_pad_mask = tgt_words.data.eq(padding_idx).unsqueeze(1) \
175 | .expand(tgt_batch, tgt_len, tgt_len)
176 |
177 | if (not memory_masks is None):
178 | src_len = memory_masks.size(-1)
179 | src_pad_mask = memory_masks.expand(src_batch, tgt_len, src_len)
180 |
181 | else:
182 | src_pad_mask = src_words.data.eq(padding_idx).unsqueeze(1) \
183 | .expand(src_batch, tgt_len, src_len)
184 |
185 | if state.cache is None:
186 | saved_inputs = []
187 |
188 | for i in range(self.num_layers):
189 | prev_layer_input = None
190 | if state.cache is None:
191 | if state.previous_input is not None:
192 | prev_layer_input = state.previous_layer_inputs[i]
193 | output, all_input \
194 | = self.transformer_layers[i](
195 | output, src_memory_bank,
196 | src_pad_mask, tgt_pad_mask,
197 | previous_input=prev_layer_input,
198 | layer_cache=state.cache["layer_{}".format(i)]
199 | if state.cache is not None else None,
200 | step=step)
201 | if state.cache is None:
202 | saved_inputs.append(all_input)
203 |
204 | if state.cache is None:
205 | saved_inputs = torch.stack(saved_inputs)
206 |
207 | output = self.layer_norm(output)
208 |
209 | # Process the result and update the attentions.
210 |
211 | if state.cache is None:
212 | state = state.update_state(tgt, saved_inputs)
213 |
214 | return output, state
215 |
216 | def init_decoder_state(self, src, memory_bank,
217 | with_cache=False):
218 | """ Init decoder state """
219 | state = TransformerDecoderState(src)
220 | if with_cache:
221 | state._init_cache(memory_bank, self.num_layers)
222 | return state
223 |
224 |
225 |
226 | class TransformerDecoderState(DecoderState):
227 | """ Transformer Decoder state base class """
228 |
229 | def __init__(self, src):
230 | """
231 | Args:
232 | src (FloatTensor): a sequence of source words tensors
233 | with optional feature tensors, of size (len x batch).
234 | """
235 | self.src = src
236 | self.previous_input = None
237 | self.previous_layer_inputs = None
238 | self.cache = None
239 |
240 | @property
241 | def _all(self):
242 | """
243 | Contains attributes that need to be updated in self.beam_update().
244 | """
245 | if (self.previous_input is not None
246 | and self.previous_layer_inputs is not None):
247 | return (self.previous_input,
248 | self.previous_layer_inputs,
249 | self.src)
250 | else:
251 | return (self.src,)
252 |
253 | def detach(self):
254 | if self.previous_input is not None:
255 | self.previous_input = self.previous_input.detach()
256 | if self.previous_layer_inputs is not None:
257 | self.previous_layer_inputs = self.previous_layer_inputs.detach()
258 | self.src = self.src.detach()
259 |
260 | def update_state(self, new_input, previous_layer_inputs):
261 | state = TransformerDecoderState(self.src)
262 | state.previous_input = new_input
263 | state.previous_layer_inputs = previous_layer_inputs
264 | return state
265 |
266 | def _init_cache(self, memory_bank, num_layers):
267 | self.cache = {}
268 |
269 | for l in range(num_layers):
270 | layer_cache = {
271 | "memory_keys": None,
272 | "memory_values": None
273 | }
274 | layer_cache["self_keys"] = None
275 | layer_cache["self_values"] = None
276 | self.cache["layer_{}".format(l)] = layer_cache
277 |
278 | def repeat_beam_size_times(self, beam_size):
279 | """ Repeat beam_size times along batch dimension. """
280 | self.src = self.src.data.repeat(1, beam_size, 1)
281 |
282 | def map_batch_fn(self, fn):
283 | def _recursive_map(struct, batch_dim=0):
284 | for k, v in struct.items():
285 | if v is not None:
286 | if isinstance(v, dict):
287 | _recursive_map(v)
288 | else:
289 | struct[k] = fn(v, batch_dim)
290 |
291 | self.src = fn(self.src, 0)
292 | if self.cache is not None:
293 | _recursive_map(self.cache)
294 |
295 |
296 |
297 |
--------------------------------------------------------------------------------
/src/models/encoder.py:
--------------------------------------------------------------------------------
1 | import math
2 |
3 | import torch
4 | import torch.nn as nn
5 |
6 | from models.neural import MultiHeadedAttention, PositionwiseFeedForward
7 |
8 |
9 | class Classifier(nn.Module):
10 | def __init__(self, hidden_size):
11 | super(Classifier, self).__init__()
12 | self.linear1 = nn.Linear(hidden_size, 1)
13 | self.sigmoid = nn.Sigmoid()
14 |
15 | def forward(self, x, mask_cls):
16 | h = self.linear1(x).squeeze(-1)
17 | sent_scores = self.sigmoid(h) * mask_cls.float()
18 | return sent_scores
19 |
20 |
21 | class PositionalEncoding(nn.Module):
22 |
23 | def __init__(self, dropout, dim, max_len=5000):
24 | pe = torch.zeros(max_len, dim)
25 | position = torch.arange(0, max_len).unsqueeze(1)
26 | div_term = torch.exp((torch.arange(0, dim, 2, dtype=torch.float) *
27 | -(math.log(10000.0) / dim)))
28 | pe[:, 0::2] = torch.sin(position.float() * div_term)
29 | pe[:, 1::2] = torch.cos(position.float() * div_term)
30 | pe = pe.unsqueeze(0)
31 | super(PositionalEncoding, self).__init__()
32 | self.register_buffer('pe', pe)
33 | self.dropout = nn.Dropout(p=dropout)
34 | self.dim = dim
35 |
36 | def forward(self, emb, step=None):
37 | emb = emb * math.sqrt(self.dim)
38 | if (step):
39 | emb = emb + self.pe[:, step][:, None, :]
40 |
41 | else:
42 | emb = emb + self.pe[:, :emb.size(1)]
43 | emb = self.dropout(emb)
44 | return emb
45 |
46 | def get_emb(self, emb):
47 | return self.pe[:, :emb.size(1)]
48 |
49 |
50 | class TransformerEncoderLayer(nn.Module):
51 | def __init__(self, d_model, heads, d_ff, dropout):
52 | super(TransformerEncoderLayer, self).__init__()
53 |
54 | self.self_attn = MultiHeadedAttention(
55 | heads, d_model, dropout=dropout)
56 | self.feed_forward = PositionwiseFeedForward(d_model, d_ff, dropout)
57 | self.layer_norm = nn.LayerNorm(d_model, eps=1e-6)
58 | self.dropout = nn.Dropout(dropout)
59 |
60 | def forward(self, iter, query, inputs, mask):
61 | if (iter != 0):
62 | input_norm = self.layer_norm(inputs)
63 | else:
64 | input_norm = inputs
65 |
66 | mask = mask.unsqueeze(1)
67 | context = self.self_attn(input_norm, input_norm, input_norm,
68 | mask=mask)
69 | out = self.dropout(context) + inputs
70 | return self.feed_forward(out)
71 |
72 |
73 | class ExtTransformerEncoder(nn.Module):
74 | def __init__(self, d_model, d_ff, heads, dropout, num_inter_layers=0):
75 | super(ExtTransformerEncoder, self).__init__()
76 | self.d_model = d_model
77 | self.num_inter_layers = num_inter_layers
78 | self.pos_emb = PositionalEncoding(dropout, d_model)
79 | self.transformer_inter = nn.ModuleList(
80 | [TransformerEncoderLayer(d_model, heads, d_ff, dropout)
81 | for _ in range(num_inter_layers)])
82 | self.dropout = nn.Dropout(dropout)
83 | self.layer_norm = nn.LayerNorm(d_model, eps=1e-6)
84 | self.wo = nn.Linear(d_model, 1, bias=True)
85 | self.sigmoid = nn.Sigmoid()
86 |
87 | def forward(self, top_vecs, mask):
88 | """ See :obj:`EncoderBase.forward()`"""
89 |
90 | batch_size, n_sents = top_vecs.size(0), top_vecs.size(1)
91 | pos_emb = self.pos_emb.pe[:, :n_sents]
92 | x = top_vecs * mask[:, :, None].float()
93 | x = x + pos_emb
94 |
95 | for i in range(self.num_inter_layers):
96 | x = self.transformer_inter[i](i, x, x, 1 - mask) # all_sents * max_tokens * dim
97 |
98 | x = self.layer_norm(x)
99 | sent_scores = self.sigmoid(self.wo(x))
100 | sent_scores = sent_scores.squeeze(-1) * mask.float()
101 |
102 | return sent_scores
103 |
104 |
--------------------------------------------------------------------------------
/src/models/loss.py:
--------------------------------------------------------------------------------
1 | """
2 | This file handles the details of the loss function during training.
3 |
4 | This includes: LossComputeBase and the standard NMTLossCompute, and
5 | sharded loss compute stuff.
6 | """
7 | from __future__ import division
8 | import torch
9 | import torch.nn as nn
10 | import torch.nn.functional as F
11 |
12 | from models.reporter import Statistics
13 |
14 |
15 | def abs_loss(generator, symbols, vocab_size, device, train=True, label_smoothing=0.0):
16 | compute = NMTLossCompute(
17 | generator, symbols, vocab_size,
18 | label_smoothing=label_smoothing if train else 0.0)
19 | compute.to(device)
20 | return compute
21 |
22 |
23 |
24 | class LossComputeBase(nn.Module):
25 | """
26 | Class for managing efficient loss computation. Handles
27 |     sharding next step predictions and accumulating multiple
28 | loss computations
29 |
30 |
31 | Users can implement their own loss computation strategy by making
32 | subclass of this one. Users need to implement the _compute_loss()
33 | and make_shard_state() methods.
34 |
35 | Args:
36 | generator (:obj:`nn.Module`) :
37 | module that maps the output of the decoder to a
38 | distribution over the target vocabulary.
39 | tgt_vocab (:obj:`Vocab`) :
40 | torchtext vocab object representing the target output
41 |         normalization (str): normalize by "sents" or "tokens"
42 | """
43 |
44 | def __init__(self, generator, pad_id):
45 | super(LossComputeBase, self).__init__()
46 | self.generator = generator
47 | self.padding_idx = pad_id
48 |
49 |
50 |
51 | def _make_shard_state(self, batch, output, attns=None):
52 | """
53 | Make shard state dictionary for shards() to return iterable
54 | shards for efficient loss computation. Subclass must define
55 | this method to match its own _compute_loss() interface.
56 | Args:
57 | batch: the current batch.
58 | output: the predict output from the model.
59 | range_: the range of examples for computing, the whole
60 | batch or a trunc of it?
61 | attns: the attns dictionary returned from the model.
62 | """
63 | return NotImplementedError
64 |
65 | def _compute_loss(self, batch, output, target, **kwargs):
66 | """
67 | Compute the loss. Subclass must define this method.
68 |
69 | Args:
70 |
71 | batch: the current batch.
72 | output: the predict output from the model.
73 | target: the validate target to compare output with.
74 | **kwargs(optional): additional info for computing loss.
75 | """
76 | return NotImplementedError
77 |
78 | def monolithic_compute_loss(self, batch, output):
79 | """
80 | Compute the forward loss for the batch.
81 |
82 | Args:
83 | batch (batch): batch of labeled examples
84 | output (:obj:`FloatTensor`):
85 | output of decoder model `[tgt_len x batch x hidden]`
86 | attns (dict of :obj:`FloatTensor`) :
87 | dictionary of attention distributions
88 | `[tgt_len x batch x src_len]`
89 | Returns:
90 | :obj:`onmt.utils.Statistics`: loss statistics
91 | """
92 | shard_state = self._make_shard_state(batch, output)
93 | _, batch_stats = self._compute_loss(batch, **shard_state)
94 |
95 | return batch_stats
96 |
97 | def sharded_compute_loss(self, batch, output,
98 | shard_size,
99 | normalization):
100 | """Compute the forward loss and backpropagate. Computation is done
101 | with shards and optionally truncation for memory efficiency.
102 |
103 | Also supports truncated BPTT for long sequences by taking a
104 | range in the decoder output sequence to back propagate in.
105 | Range is from `(cur_trunc, cur_trunc + trunc_size)`.
106 |
107 | Note sharding is an exact efficiency trick to relieve memory
108 | required for the generation buffers. Truncation is an
109 | approximate efficiency trick to relieve the memory required
110 | in the RNN buffers.
111 |
112 | Args:
113 | batch (batch) : batch of labeled examples
114 | output (:obj:`FloatTensor`) :
115 | output of decoder model `[tgt_len x batch x hidden]`
116 | attns (dict) : dictionary of attention distributions
117 | `[tgt_len x batch x src_len]`
118 | cur_trunc (int) : starting position of truncation window
119 | trunc_size (int) : length of truncation window
120 | shard_size (int) : maximum number of examples in a shard
121 | normalization (int) : Loss is divided by this number
122 |
123 | Returns:
124 | :obj:`onmt.utils.Statistics`: validation loss statistics
125 |
126 | """
127 | batch_stats = Statistics()
128 | shard_state = self._make_shard_state(batch, output)
129 | for shard in shards(shard_state, shard_size):
130 | loss, stats = self._compute_loss(batch, **shard)
131 | loss.div(float(normalization)).backward()
132 | batch_stats.update(stats)
133 |
134 | return batch_stats
135 |
136 | def _stats(self, loss, scores, target):
137 | """
138 | Args:
139 | loss (:obj:`FloatTensor`): the loss computed by the loss criterion.
140 | scores (:obj:`FloatTensor`): a score for each possible output
141 | target (:obj:`FloatTensor`): true targets
142 |
143 | Returns:
144 | :obj:`onmt.utils.Statistics` : statistics for this batch.
145 | """
146 | pred = scores.max(1)[1]
147 | non_padding = target.ne(self.padding_idx)
148 | num_correct = pred.eq(target) \
149 | .masked_select(non_padding) \
150 | .sum() \
151 | .item()
152 | num_non_padding = non_padding.sum().item()
153 | return Statistics(loss.item(), num_non_padding, num_correct)
154 |
155 | def _bottle(self, _v):
156 | return _v.view(-1, _v.size(2))
157 |
158 | def _unbottle(self, _v, batch_size):
159 | return _v.view(-1, batch_size, _v.size(1))
160 |
161 |
162 | class LabelSmoothingLoss(nn.Module):
163 | """
164 | With label smoothing,
165 | KL-divergence between q_{smoothed ground truth prob.}(w)
166 | and p_{prob. computed by model}(w) is minimized.
167 | """
168 | def __init__(self, label_smoothing, tgt_vocab_size, ignore_index=-100):
169 | assert 0.0 < label_smoothing <= 1.0
170 | self.padding_idx = ignore_index
171 | super(LabelSmoothingLoss, self).__init__()
172 |
173 | smoothing_value = label_smoothing / (tgt_vocab_size - 2)
174 | one_hot = torch.full((tgt_vocab_size,), smoothing_value)
175 | one_hot[self.padding_idx] = 0
176 | self.register_buffer('one_hot', one_hot.unsqueeze(0))
177 | self.confidence = 1.0 - label_smoothing
178 |
179 | def forward(self, output, target):
180 | """
181 | output (FloatTensor): batch_size x n_classes
182 | target (LongTensor): batch_size
183 | """
184 | model_prob = self.one_hot.repeat(target.size(0), 1)
185 | model_prob.scatter_(1, target.unsqueeze(1), self.confidence)
186 | model_prob.masked_fill_((target == self.padding_idx).unsqueeze(1), 0)
187 |
188 | return F.kl_div(output, model_prob, reduction='sum')
189 |
190 |
191 | class NMTLossCompute(LossComputeBase):
192 | """
193 | Standard NMT Loss Computation.
194 | """
195 |
196 | def __init__(self, generator, symbols, vocab_size,
197 | label_smoothing=0.0):
198 | super(NMTLossCompute, self).__init__(generator, symbols['PAD'])
199 | self.sparse = not isinstance(generator[1], nn.LogSoftmax)
200 | if label_smoothing > 0:
201 | self.criterion = LabelSmoothingLoss(
202 | label_smoothing, vocab_size, ignore_index=self.padding_idx
203 | )
204 | else:
205 | self.criterion = nn.NLLLoss(
206 | ignore_index=self.padding_idx, reduction='sum'
207 | )
208 |
209 | def _make_shard_state(self, batch, output):
210 | return {
211 | "output": output,
212 | "target": batch.tgt[:,1:],
213 | }
214 |
215 | def _compute_loss(self, batch, output, target):
216 | bottled_output = self._bottle(output)
217 | scores = self.generator(bottled_output)
218 |         gtruth = target.contiguous().view(-1)
219 |
220 | loss = self.criterion(scores, gtruth)
221 |
222 | stats = self._stats(loss.clone(), scores, gtruth)
223 |
224 | return loss, stats
225 |
226 |
227 | def filter_shard_state(state, shard_size=None):
228 | """ ? """
229 | for k, v in state.items():
230 | if shard_size is None:
231 | yield k, v
232 |
233 | if v is not None:
234 | v_split = []
235 | if isinstance(v, torch.Tensor):
236 | for v_chunk in torch.split(v, shard_size):
237 | v_chunk = v_chunk.data.clone()
238 | v_chunk.requires_grad = v.requires_grad
239 | v_split.append(v_chunk)
240 | yield k, (v, v_split)
241 |
242 |
243 | def shards(state, shard_size, eval_only=False):
244 | """
245 | Args:
246 | state: A dictionary which corresponds to the output of
247 | *LossCompute._make_shard_state(). The values for
248 | those keys are Tensor-like or None.
249 | shard_size: The maximum size of the shards yielded by the model.
250 | eval_only: If True, only yield the state, nothing else.
251 | Otherwise, yield shards.
252 |
253 | Yields:
254 | Each yielded shard is a dict.
255 |
256 | Side effect:
257 | After the last shard, this function does back-propagation.
258 | """
259 | if eval_only:
260 | yield filter_shard_state(state)
261 | else:
262 | # non_none: the subdict of the state dictionary where the values
263 | # are not None.
264 | non_none = dict(filter_shard_state(state, shard_size))
265 |
266 | # Now, the iteration:
267 | # state is a dictionary of sequences of tensor-like but we
268 | # want a sequence of dictionaries of tensors.
269 | # First, unzip the dictionary into a sequence of keys and a
270 | # sequence of tensor-like sequences.
271 | keys, values = zip(*((k, [v_chunk for v_chunk in v_split])
272 | for k, (_, v_split) in non_none.items()))
273 |
274 | # Now, yield a dictionary for each shard. The keys are always
275 | # the same. values is a sequence of length #keys where each
276 | # element is a sequence of length #shards. We want to iterate
277 | # over the shards, not over the keys: therefore, the values need
278 | # to be re-zipped by shard and then each shard can be paired
279 | # with the keys.
280 | for shard_tensors in zip(*values):
281 | yield dict(zip(keys, shard_tensors))
282 |
283 | # Assumed backprop'd
284 | variables = []
285 | for k, (v, v_split) in non_none.items():
286 | if isinstance(v, torch.Tensor) and state[k].requires_grad:
287 | variables.extend(zip(torch.split(state[k], shard_size),
288 | [v_chunk.grad for v_chunk in v_split]))
289 | inputs, grads = zip(*variables)
290 | torch.autograd.backward(inputs, grads)
291 |
--------------------------------------------------------------------------------
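The sharding trick documented in `sharded_compute_loss()` and implemented by `filter_shard_state()` / `shards()` above can be hard to follow from the comments alone. The following toy sketch (illustrative only, not repository code) reproduces the mechanics on a tiny tensor: each shard is a detached clone whose loss is back-propagated on its own, and the accumulated shard gradients are then pushed through the original graph with a single `torch.autograd.backward` call, giving the same gradient as a full-batch backward while keeping only one shard's generation buffers alive at a time.

import torch

torch.manual_seed(0)
x = torch.randn(6, 4, requires_grad=True)   # stands in for the decoder output
output = x * 2                              # upstream graph we want gradients through

shard_size = 2
chunks, grads = [], []
for chunk in torch.split(output, shard_size):
    # clone + detach, mirroring what filter_shard_state() does with .data.clone()
    c = chunk.detach().clone().requires_grad_(True)
    shard_loss = (c ** 2).sum()             # per-shard loss; only this shard's
    shard_loss.backward()                   # buffers are alive at this point
    chunks.append(chunk)
    grads.append(c.grad)

# push the accumulated shard gradients back through the original graph,
# mirroring the final torch.autograd.backward(inputs, grads) call in shards()
torch.autograd.backward(chunks, grads)

# identical to back-propagating the full loss at once: d/dx sum((2x)^2) = 8x
assert torch.allclose(x.grad, 8 * x)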
/src/models/model_builder.py:
--------------------------------------------------------------------------------
1 | import copy
2 |
3 | import torch
4 | import torch.nn as nn
5 | from pytorch_transformers import BertModel, BertConfig
6 | from torch.nn.init import xavier_uniform_
7 |
8 | from models.decoder import TransformerDecoder
9 | from models.encoder import Classifier, ExtTransformerEncoder
10 | from models.optimizers import Optimizer
11 |
12 | def build_optim(args, model, checkpoint):
13 | """ Build optimizer """
14 |
15 | if checkpoint is not None:
16 | optim = checkpoint['optim'][0]
17 | saved_optimizer_state_dict = optim.optimizer.state_dict()
18 | optim.optimizer.load_state_dict(saved_optimizer_state_dict)
19 | if args.visible_gpus != '-1':
20 | for state in optim.optimizer.state.values():
21 | for k, v in state.items():
22 | if torch.is_tensor(v):
23 | state[k] = v.cuda()
24 |
25 | if (optim.method == 'adam') and (len(optim.optimizer.state) < 1):
26 | raise RuntimeError(
27 | "Error: loaded Adam optimizer from existing model" +
28 | " but optimizer state is empty")
29 |
30 | else:
31 | optim = Optimizer(
32 | args.optim, args.lr, args.max_grad_norm,
33 | beta1=args.beta1, beta2=args.beta2,
34 | decay_method='noam',
35 | warmup_steps=args.warmup_steps)
36 |
37 | optim.set_parameters(list(model.named_parameters()))
38 |
39 |
40 | return optim
41 |
42 | def build_optim_bert(args, model, checkpoint):
43 | """ Build optimizer """
44 |
45 | if checkpoint is not None:
46 | optim = checkpoint['optims'][0]
47 | saved_optimizer_state_dict = optim.optimizer.state_dict()
48 | optim.optimizer.load_state_dict(saved_optimizer_state_dict)
49 | if args.visible_gpus != '-1':
50 | for state in optim.optimizer.state.values():
51 | for k, v in state.items():
52 | if torch.is_tensor(v):
53 | state[k] = v.cuda()
54 |
55 | if (optim.method == 'adam') and (len(optim.optimizer.state) < 1):
56 | raise RuntimeError(
57 | "Error: loaded Adam optimizer from existing model" +
58 | " but optimizer state is empty")
59 |
60 | else:
61 | optim = Optimizer(
62 | args.optim, args.lr_bert, args.max_grad_norm,
63 | beta1=args.beta1, beta2=args.beta2,
64 | decay_method='noam',
65 | warmup_steps=args.warmup_steps_bert)
66 |
67 | params = [(n, p) for n, p in list(model.named_parameters()) if n.startswith('bert.model')]
68 | optim.set_parameters(params)
69 |
70 |
71 | return optim
72 |
73 | def build_optim_dec(args, model, checkpoint):
74 | """ Build optimizer """
75 |
76 | if checkpoint is not None:
77 | optim = checkpoint['optims'][1]
78 | saved_optimizer_state_dict = optim.optimizer.state_dict()
79 | optim.optimizer.load_state_dict(saved_optimizer_state_dict)
80 | if args.visible_gpus != '-1':
81 | for state in optim.optimizer.state.values():
82 | for k, v in state.items():
83 | if torch.is_tensor(v):
84 | state[k] = v.cuda()
85 |
86 | if (optim.method == 'adam') and (len(optim.optimizer.state) < 1):
87 | raise RuntimeError(
88 | "Error: loaded Adam optimizer from existing model" +
89 | " but optimizer state is empty")
90 |
91 | else:
92 | optim = Optimizer(
93 | args.optim, args.lr_dec, args.max_grad_norm,
94 | beta1=args.beta1, beta2=args.beta2,
95 | decay_method='noam',
96 | warmup_steps=args.warmup_steps_dec)
97 |
98 | params = [(n, p) for n, p in list(model.named_parameters()) if not n.startswith('bert.model')]
99 | optim.set_parameters(params)
100 |
101 |
102 | return optim
103 |
104 |
105 | def get_generator(vocab_size, dec_hidden_size, device):
106 | gen_func = nn.LogSoftmax(dim=-1)
107 | generator = nn.Sequential(
108 | nn.Linear(dec_hidden_size, vocab_size),
109 | gen_func
110 | )
111 | generator.to(device)
112 |
113 | return generator
114 |
115 | class Bert(nn.Module):
116 | def __init__(self, large, temp_dir, finetune=False):
117 | super(Bert, self).__init__()
118 | if(large):
119 | self.model = BertModel.from_pretrained('bert-large-uncased', cache_dir=temp_dir)
120 | else:
121 | self.model = BertModel.from_pretrained('bert-base-uncased', cache_dir=temp_dir)
122 |
123 | self.finetune = finetune
124 |
125 | def forward(self, x, segs, mask):
126 | if(self.finetune):
127 | top_vec, _ = self.model(x, segs, attention_mask=mask)
128 | else:
129 | self.eval()
130 | with torch.no_grad():
131 | top_vec, _ = self.model(x, segs, attention_mask=mask)
132 | return top_vec
133 |
134 |
135 | class ExtSummarizer(nn.Module):
136 | def __init__(self, args, device, checkpoint):
137 | super(ExtSummarizer, self).__init__()
138 | self.args = args
139 | self.device = device
140 | self.bert = Bert(args.large, args.temp_dir, args.finetune_bert)
141 |
142 | self.ext_layer = ExtTransformerEncoder(self.bert.model.config.hidden_size, args.ext_ff_size, args.ext_heads,
143 | args.ext_dropout, args.ext_layers)
144 | if (args.encoder == 'baseline'):
145 | bert_config = BertConfig(self.bert.model.config.vocab_size, hidden_size=args.ext_hidden_size,
146 | num_hidden_layers=args.ext_layers, num_attention_heads=args.ext_heads, intermediate_size=args.ext_ff_size)
147 | self.bert.model = BertModel(bert_config)
148 | self.ext_layer = Classifier(self.bert.model.config.hidden_size)
149 |
150 | if(args.max_pos>512):
151 | my_pos_embeddings = nn.Embedding(args.max_pos, self.bert.model.config.hidden_size)
152 | my_pos_embeddings.weight.data[:512] = self.bert.model.embeddings.position_embeddings.weight.data
153 | my_pos_embeddings.weight.data[512:] = self.bert.model.embeddings.position_embeddings.weight.data[-1][None,:].repeat(args.max_pos-512,1)
154 | self.bert.model.embeddings.position_embeddings = my_pos_embeddings
155 |
156 |
157 | if checkpoint is not None:
158 | self.load_state_dict(checkpoint['model'], strict=True)
159 | else:
160 | if args.param_init != 0.0:
161 | for p in self.ext_layer.parameters():
162 | p.data.uniform_(-args.param_init, args.param_init)
163 | if args.param_init_glorot:
164 | for p in self.ext_layer.parameters():
165 | if p.dim() > 1:
166 | xavier_uniform_(p)
167 |
168 | self.to(device)
169 |
170 | def forward(self, src, segs, clss, mask_src, mask_cls):
171 | top_vec = self.bert(src, segs, mask_src)
172 | sents_vec = top_vec[torch.arange(top_vec.size(0)).unsqueeze(1), clss]
173 | sents_vec = sents_vec * mask_cls[:, :, None].float()
174 | sent_scores = self.ext_layer(sents_vec, mask_cls).squeeze(-1)
175 | return sent_scores, mask_cls
176 |
177 |
178 | class AbsSummarizer(nn.Module):
179 | def __init__(self, args, device, checkpoint=None, bert_from_extractive=None):
180 | super(AbsSummarizer, self).__init__()
181 | self.args = args
182 | self.device = device
183 | self.bert = Bert(args.large, args.temp_dir, args.finetune_bert)
184 |
185 | if bert_from_extractive is not None:
186 | self.bert.model.load_state_dict(
187 | dict([(n[11:], p) for n, p in bert_from_extractive.items() if n.startswith('bert.model')]), strict=True)
188 |
189 | if (args.encoder == 'baseline'):
190 | bert_config = BertConfig(self.bert.model.config.vocab_size, hidden_size=args.enc_hidden_size,
191 | num_hidden_layers=args.enc_layers, num_attention_heads=8,
192 | intermediate_size=args.enc_ff_size,
193 | hidden_dropout_prob=args.enc_dropout,
194 | attention_probs_dropout_prob=args.enc_dropout)
195 | self.bert.model = BertModel(bert_config)
196 |
197 | if(args.max_pos>512):
198 | my_pos_embeddings = nn.Embedding(args.max_pos, self.bert.model.config.hidden_size)
199 | my_pos_embeddings.weight.data[:512] = self.bert.model.embeddings.position_embeddings.weight.data
200 | my_pos_embeddings.weight.data[512:] = self.bert.model.embeddings.position_embeddings.weight.data[-1][None,:].repeat(args.max_pos-512,1)
201 | self.bert.model.embeddings.position_embeddings = my_pos_embeddings
202 | self.vocab_size = self.bert.model.config.vocab_size
203 | tgt_embeddings = nn.Embedding(self.vocab_size, self.bert.model.config.hidden_size, padding_idx=0)
204 | if (self.args.share_emb):
205 | tgt_embeddings.weight = copy.deepcopy(self.bert.model.embeddings.word_embeddings.weight)
206 |
207 | self.decoder = TransformerDecoder(
208 | self.args.dec_layers,
209 | self.args.dec_hidden_size, heads=self.args.dec_heads,
210 | d_ff=self.args.dec_ff_size, dropout=self.args.dec_dropout, embeddings=tgt_embeddings)
211 |
212 | self.generator = get_generator(self.vocab_size, self.args.dec_hidden_size, device)
213 | self.generator[0].weight = self.decoder.embeddings.weight
214 |
215 |
216 | if checkpoint is not None:
217 | self.load_state_dict(checkpoint['model'], strict=True)
218 | else:
219 | for module in self.decoder.modules():
220 | if isinstance(module, (nn.Linear, nn.Embedding)):
221 | module.weight.data.normal_(mean=0.0, std=0.02)
222 | elif isinstance(module, nn.LayerNorm):
223 | module.bias.data.zero_()
224 | module.weight.data.fill_(1.0)
225 | if isinstance(module, nn.Linear) and module.bias is not None:
226 | module.bias.data.zero_()
227 | for p in self.generator.parameters():
228 | if p.dim() > 1:
229 | xavier_uniform_(p)
230 | else:
231 | p.data.zero_()
232 | if(args.use_bert_emb):
233 | tgt_embeddings = nn.Embedding(self.vocab_size, self.bert.model.config.hidden_size, padding_idx=0)
234 | tgt_embeddings.weight = copy.deepcopy(self.bert.model.embeddings.word_embeddings.weight)
235 | self.decoder.embeddings = tgt_embeddings
236 | self.generator[0].weight = self.decoder.embeddings.weight
237 |
238 | self.to(device)
239 |
240 | def forward(self, src, tgt, segs, clss, mask_src, mask_tgt, mask_cls):
241 | top_vec = self.bert(src, segs, mask_src)
242 | dec_state = self.decoder.init_decoder_state(src, top_vec)
243 | decoder_outputs, state = self.decoder(tgt[:, :-1], top_vec, dec_state)
244 | return decoder_outputs, None
245 |
--------------------------------------------------------------------------------
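Both `ExtSummarizer` and `AbsSummarizer` above accept inputs longer than BERT's 512 positions by enlarging the position-embedding table: the pretrained table fills the first 512 rows and the last learned embedding is repeated for every extra position. A standalone sketch with toy sizes (illustrative only, not repository code):

import torch
import torch.nn as nn

hidden_size, max_pos = 768, 640
old_pos_emb = nn.Embedding(512, hidden_size)   # stands in for BERT's pretrained table

new_pos_emb = nn.Embedding(max_pos, hidden_size)
new_pos_emb.weight.data[:512] = old_pos_emb.weight.data
new_pos_emb.weight.data[512:] = old_pos_emb.weight.data[-1][None, :].repeat(max_pos - 512, 1)

# every position beyond 511 starts from a copy of the last pretrained embedding
assert torch.equal(new_pos_emb.weight.data[600], old_pos_emb.weight.data[511])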
/src/models/optimizers.py:
--------------------------------------------------------------------------------
1 | """ Optimizers class """
2 | import torch
3 | import torch.optim as optim
4 | from torch.nn.utils import clip_grad_norm_
5 |
6 |
7 | # from onmt.utils import use_gpu
8 | # from models.adam import Adam
9 |
10 |
11 | def use_gpu(opt):
12 | """
13 |     Returns True if a GPU is being used.
14 | """
15 | return (hasattr(opt, 'gpu_ranks') and len(opt.gpu_ranks) > 0) or \
16 | (hasattr(opt, 'gpu') and opt.gpu > -1)
17 |
18 | def build_optim(model, opt, checkpoint):
19 | """ Build optimizer """
20 | saved_optimizer_state_dict = None
21 |
22 | if opt.train_from:
23 | optim = checkpoint['optim']
24 | # We need to save a copy of optim.optimizer.state_dict() for setting
25 |         # the optimizer state later on in Stage 2 of this method, since
26 |         # the method optim.set_parameters(model.parameters()) will overwrite
27 |         # optim.optimizer, along with the values stored in
28 | # optim.optimizer.state_dict()
29 | saved_optimizer_state_dict = optim.optimizer.state_dict()
30 | else:
31 | optim = Optimizer(
32 | opt.optim, opt.learning_rate, opt.max_grad_norm,
33 | lr_decay=opt.learning_rate_decay,
34 | start_decay_steps=opt.start_decay_steps,
35 | decay_steps=opt.decay_steps,
36 | beta1=opt.adam_beta1,
37 | beta2=opt.adam_beta2,
38 | adagrad_accum=opt.adagrad_accumulator_init,
39 | decay_method=opt.decay_method,
40 | warmup_steps=opt.warmup_steps)
41 |
42 | optim.set_parameters(model.named_parameters())
43 |
44 | if opt.train_from:
45 | optim.optimizer.load_state_dict(saved_optimizer_state_dict)
46 | if use_gpu(opt):
47 | for state in optim.optimizer.state.values():
48 | for k, v in state.items():
49 | if torch.is_tensor(v):
50 | state[k] = v.cuda()
51 |
52 | if (optim.method == 'adam') and (len(optim.optimizer.state) < 1):
53 | raise RuntimeError(
54 | "Error: loaded Adam optimizer from existing model" +
55 | " but optimizer state is empty")
56 |
57 | return optim
58 |
59 |
60 | class MultipleOptimizer(object):
61 | """ Implement multiple optimizers needed for sparse adam """
62 |
63 | def __init__(self, op):
64 | """ ? """
65 | self.optimizers = op
66 |
67 | def zero_grad(self):
68 | """ ? """
69 | for op in self.optimizers:
70 | op.zero_grad()
71 |
72 | def step(self):
73 | """ ? """
74 | for op in self.optimizers:
75 | op.step()
76 |
77 | @property
78 | def state(self):
79 | """ ? """
80 | return {k: v for op in self.optimizers for k, v in op.state.items()}
81 |
82 | def state_dict(self):
83 | """ ? """
84 | return [op.state_dict() for op in self.optimizers]
85 |
86 | def load_state_dict(self, state_dicts):
87 | """ ? """
88 | assert len(state_dicts) == len(self.optimizers)
89 | for i in range(len(state_dicts)):
90 | self.optimizers[i].load_state_dict(state_dicts[i])
91 |
92 |
93 | class Optimizer(object):
94 | """
95 | Controller class for optimization. Mostly a thin
96 | wrapper for `optim`, but also useful for implementing
97 | rate scheduling beyond what is currently available.
98 | Also implements necessary methods for training RNNs such
99 | as grad manipulations.
100 |
101 | Args:
102 | method (:obj:`str`): one of [sgd, adagrad, adadelta, adam]
103 | lr (float): learning rate
104 | lr_decay (float, optional): learning rate decay multiplier
105 | start_decay_steps (int, optional): step to start learning rate decay
106 | beta1, beta2 (float, optional): parameters for adam
107 | adagrad_accum (float, optional): initialization parameter for adagrad
108 |         decay_method (str, optional): custom decay options
109 |         warmup_steps (int, optional): parameter for `noam` decay
110 |         model_size (int, optional): parameter for `noam` decay
111 |
112 | We use the default parameters for Adam that are suggested by
113 | the original paper https://arxiv.org/pdf/1412.6980.pdf
114 | These values are also used by other established implementations,
115 | e.g. https://www.tensorflow.org/api_docs/python/tf/train/AdamOptimizer
116 | https://keras.io/optimizers/
117 |     Slightly different values were used in the paper
118 |     "Attention Is All You Need"
119 |     (https://arxiv.org/pdf/1706.03762.pdf), notably beta2=0.98.
120 |     However, beta2=0.999 is still arguably the more established value,
121 |     so we use it here as well.
122 | """
123 |
124 | def __init__(self, method, learning_rate, max_grad_norm,
125 | lr_decay=1, start_decay_steps=None, decay_steps=None,
126 | beta1=0.9, beta2=0.999,
127 | adagrad_accum=0.0,
128 | decay_method=None,
129 | warmup_steps=4000, weight_decay=0):
130 | self.last_ppl = None
131 | self.learning_rate = learning_rate
132 | self.original_lr = learning_rate
133 | self.max_grad_norm = max_grad_norm
134 | self.method = method
135 | self.lr_decay = lr_decay
136 | self.start_decay_steps = start_decay_steps
137 | self.decay_steps = decay_steps
138 | self.start_decay = False
139 | self._step = 0
140 | self.betas = [beta1, beta2]
141 | self.adagrad_accum = adagrad_accum
142 | self.decay_method = decay_method
143 | self.warmup_steps = warmup_steps
144 | self.weight_decay = weight_decay
145 |
146 | def set_parameters(self, params):
147 | """ ? """
148 | self.params = []
149 | self.sparse_params = []
150 | for k, p in params:
151 | if p.requires_grad:
152 | if self.method != 'sparseadam' or "embed" not in k:
153 | self.params.append(p)
154 | else:
155 | self.sparse_params.append(p)
156 | if self.method == 'sgd':
157 | self.optimizer = optim.SGD(self.params, lr=self.learning_rate)
158 | elif self.method == 'adagrad':
159 | self.optimizer = optim.Adagrad(self.params, lr=self.learning_rate)
160 | for group in self.optimizer.param_groups:
161 | for p in group['params']:
162 | self.optimizer.state[p]['sum'] = self.optimizer\
163 | .state[p]['sum'].fill_(self.adagrad_accum)
164 | elif self.method == 'adadelta':
165 | self.optimizer = optim.Adadelta(self.params, lr=self.learning_rate)
166 | elif self.method == 'adam':
167 | self.optimizer = optim.Adam(self.params, lr=self.learning_rate,
168 | betas=self.betas, eps=1e-9)
169 | else:
170 | raise RuntimeError("Invalid optim method: " + self.method)
171 |
172 | def _set_rate(self, learning_rate):
173 | self.learning_rate = learning_rate
174 | if self.method != 'sparseadam':
175 | self.optimizer.param_groups[0]['lr'] = self.learning_rate
176 | else:
177 | for op in self.optimizer.optimizers:
178 | op.param_groups[0]['lr'] = self.learning_rate
179 |
180 | def step(self):
181 | """Update the model parameters based on current gradients.
182 |
183 | Optionally, will employ gradient modification or update learning
184 | rate.
185 | """
186 | self._step += 1
187 |
188 | # Decay method used in tensor2tensor.
189 | if self.decay_method == "noam":
190 | self._set_rate(
191 | self.original_lr *
192 | min(self._step ** (-0.5),
193 | self._step * self.warmup_steps**(-1.5)))
194 |
195 | else:
196 | if ((self.start_decay_steps is not None) and (
197 | self._step >= self.start_decay_steps)):
198 | self.start_decay = True
199 | if self.start_decay:
200 | if ((self._step - self.start_decay_steps)
201 | % self.decay_steps == 0):
202 | self.learning_rate = self.learning_rate * self.lr_decay
203 |
204 | if self.method != 'sparseadam':
205 | self.optimizer.param_groups[0]['lr'] = self.learning_rate
206 |
207 | if self.max_grad_norm:
208 | clip_grad_norm_(self.params, self.max_grad_norm)
209 | self.optimizer.step()
210 |
211 |
212 |
--------------------------------------------------------------------------------
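The `noam` branch of `Optimizer.step()` above sets the learning rate to `original_lr * min(step ** -0.5, step * warmup_steps ** -1.5)`: a linear warm-up for `warmup_steps` updates followed by inverse-square-root decay, peaking exactly at `step == warmup_steps`. A short sketch of the schedule (the sample values are illustrative, not defaults taken from the code):

def noam_lr(step, original_lr, warmup_steps):
    # same expression as the decay_method == "noam" branch of Optimizer.step()
    return original_lr * min(step ** (-0.5), step * warmup_steps ** (-1.5))

for step in (1, 4000, 8000, 20000, 200000):
    print(step, noam_lr(step, original_lr=2e-3, warmup_steps=8000))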
/src/models/reporter.py:
--------------------------------------------------------------------------------
1 | """ Report manager utility """
2 | from __future__ import print_function
3 | from datetime import datetime
4 |
5 | import time
6 | import math
7 | import sys
8 |
9 | from distributed import all_gather_list
10 | from others.logging import logger
11 |
12 |
13 | def build_report_manager(opt):
14 | if opt.tensorboard:
15 | from tensorboardX import SummaryWriter
16 | writer = SummaryWriter(opt.tensorboard_log_dir
17 | + datetime.now().strftime("/%b-%d_%H-%M-%S"),
18 | comment="Unmt")
19 | else:
20 | writer = None
21 |
22 | report_mgr = ReportMgr(opt.report_every, start_time=-1,
23 | tensorboard_writer=writer)
24 | return report_mgr
25 |
26 |
27 | class ReportMgrBase(object):
28 | """
29 | Report Manager Base class
30 | Inherited classes should override:
31 | * `_report_training`
32 | * `_report_step`
33 | """
34 |
35 | def __init__(self, report_every, start_time=-1.):
36 | """
37 | Args:
38 | report_every(int): Report status every this many sentences
39 | start_time(float): manually set report start time. Negative values
40 |                 mean that you will need to set it later or use `start()`
41 | """
42 | self.report_every = report_every
43 | self.progress_step = 0
44 | self.start_time = start_time
45 |
46 | def start(self):
47 | self.start_time = time.time()
48 |
49 | def log(self, *args, **kwargs):
50 | logger.info(*args, **kwargs)
51 |
52 | def report_training(self, step, num_steps, learning_rate,
53 | report_stats, multigpu=False):
54 | """
55 |         This is the user-defined batch-level training progress
56 | report function.
57 |
58 | Args:
59 | step(int): current step count.
60 | num_steps(int): total number of batches.
61 | learning_rate(float): current learning rate.
62 | report_stats(Statistics): old Statistics instance.
63 | Returns:
64 | report_stats(Statistics): updated Statistics instance.
65 | """
66 | if self.start_time < 0:
67 | raise ValueError("""ReportMgr needs to be started
68 | (set 'start_time' or use 'start()'""")
69 |
70 | if multigpu:
71 | report_stats = Statistics.all_gather_stats(report_stats)
72 |
73 | if step % self.report_every == 0:
74 | self._report_training(
75 | step, num_steps, learning_rate, report_stats)
76 | self.progress_step += 1
77 | return Statistics()
78 |
79 | def _report_training(self, *args, **kwargs):
80 | """ To be overridden """
81 | raise NotImplementedError()
82 |
83 | def report_step(self, lr, step, train_stats=None, valid_stats=None):
84 | """
85 | Report stats of a step
86 |
87 | Args:
88 | train_stats(Statistics): training stats
89 | valid_stats(Statistics): validation stats
90 | lr(float): current learning rate
91 | """
92 | self._report_step(
93 | lr, step, train_stats=train_stats, valid_stats=valid_stats)
94 |
95 | def _report_step(self, *args, **kwargs):
96 | raise NotImplementedError()
97 |
98 |
99 | class ReportMgr(ReportMgrBase):
100 | def __init__(self, report_every, start_time=-1., tensorboard_writer=None):
101 | """
102 | A report manager that writes statistics on standard output as well as
103 | (optionally) TensorBoard
104 |
105 | Args:
106 | report_every(int): Report status every this many sentences
107 | tensorboard_writer(:obj:`tensorboard.SummaryWriter`):
108 | The TensorBoard Summary writer to use or None
109 | """
110 | super(ReportMgr, self).__init__(report_every, start_time)
111 | self.tensorboard_writer = tensorboard_writer
112 |
113 | def maybe_log_tensorboard(self, stats, prefix, learning_rate, step):
114 | if self.tensorboard_writer is not None:
115 | stats.log_tensorboard(
116 | prefix, self.tensorboard_writer, learning_rate, step)
117 |
118 | def _report_training(self, step, num_steps, learning_rate,
119 | report_stats):
120 | """
121 | See base class method `ReportMgrBase.report_training`.
122 | """
123 | report_stats.output(step, num_steps,
124 | learning_rate, self.start_time)
125 |
126 | # Log the progress using the number of batches on the x-axis.
127 | self.maybe_log_tensorboard(report_stats,
128 | "progress",
129 | learning_rate,
130 | step)
131 | report_stats = Statistics()
132 |
133 | return report_stats
134 |
135 | def _report_step(self, lr, step, train_stats=None, valid_stats=None):
136 | """
137 | See base class method `ReportMgrBase.report_step`.
138 | """
139 | if train_stats is not None:
140 | self.log('Train perplexity: %g' % train_stats.ppl())
141 | self.log('Train accuracy: %g' % train_stats.accuracy())
142 |
143 | self.maybe_log_tensorboard(train_stats,
144 | "train",
145 | lr,
146 | step)
147 |
148 | if valid_stats is not None:
149 | self.log('Validation perplexity: %g' % valid_stats.ppl())
150 | self.log('Validation accuracy: %g' % valid_stats.accuracy())
151 |
152 | self.maybe_log_tensorboard(valid_stats,
153 | "valid",
154 | lr,
155 | step)
156 |
157 |
158 | class Statistics(object):
159 | """
160 | Accumulator for loss statistics.
161 | Currently calculates:
162 |
163 | * accuracy
164 | * perplexity
165 | * elapsed time
166 | """
167 |
168 | def __init__(self, loss=0, n_words=0, n_correct=0):
169 | self.loss = loss
170 | self.n_words = n_words
171 | self.n_docs = 0
172 | self.n_correct = n_correct
173 | self.n_src_words = 0
174 | self.start_time = time.time()
175 |
176 | @staticmethod
177 | def all_gather_stats(stat, max_size=4096):
178 | """
179 |         Gather a `Statistics` object across multiple processes/nodes
180 |
181 | Args:
182 | stat(:obj:Statistics): the statistics object to gather
183 |                 across all processes/nodes
184 | max_size(int): max buffer size to use
185 |
186 | Returns:
187 |             `Statistics`, the updated stats object
188 | """
189 | stats = Statistics.all_gather_stats_list([stat], max_size=max_size)
190 | return stats[0]
191 |
192 | @staticmethod
193 | def all_gather_stats_list(stat_list, max_size=4096):
194 | from torch.distributed import get_rank
195 |
196 | """
197 |         Gather a `Statistics` list across all processes/nodes
198 |
199 | Args:
200 | stat_list(list([`Statistics`])): list of statistics objects to
201 |                 gather across all processes/nodes
202 | max_size(int): max buffer size to use
203 |
204 | Returns:
205 | our_stats(list([`Statistics`])): list of updated stats
206 | """
207 | # Get a list of world_size lists with len(stat_list) Statistics objects
208 | all_stats = all_gather_list(stat_list, max_size=max_size)
209 |
210 | our_rank = get_rank()
211 | our_stats = all_stats[our_rank]
212 | for other_rank, stats in enumerate(all_stats):
213 | if other_rank == our_rank:
214 | continue
215 | for i, stat in enumerate(stats):
216 | our_stats[i].update(stat, update_n_src_words=True)
217 | return our_stats
218 |
219 | def update(self, stat, update_n_src_words=False):
220 | """
221 |         Update statistics by summing values with another `Statistics` object
222 |
223 | Args:
224 | stat: another statistic object
225 | update_n_src_words(bool): whether to update (sum) `n_src_words`
226 | or not
227 |
228 | """
229 | self.loss += stat.loss
230 | self.n_words += stat.n_words
231 | self.n_correct += stat.n_correct
232 | self.n_docs += stat.n_docs
233 |
234 | if update_n_src_words:
235 | self.n_src_words += stat.n_src_words
236 |
237 | def accuracy(self):
238 | """ compute accuracy """
239 | return 100 * (self.n_correct / self.n_words)
240 |
241 | def xent(self):
242 | """ compute cross entropy """
243 | return self.loss / self.n_words
244 |
245 | def ppl(self):
246 | """ compute perplexity """
247 | return math.exp(min(self.loss / self.n_words, 100))
248 |
249 | def elapsed_time(self):
250 | """ compute elapsed time """
251 | return time.time() - self.start_time
252 |
253 | def output(self, step, num_steps, learning_rate, start):
254 | """Write out statistics to stdout.
255 |
256 | Args:
257 | step (int): current step
258 | n_batch (int): total batches
259 | start (int): start time of step.
260 | """
261 | t = self.elapsed_time()
262 | logger.info(
263 | ("Step %2d/%5d; acc: %6.2f; ppl: %5.2f; xent: %4.2f; " +
264 | "lr: %7.8f; %3.0f/%3.0f tok/s; %6.0f sec")
265 | % (step, num_steps,
266 | self.accuracy(),
267 | self.ppl(),
268 | self.xent(),
269 | learning_rate,
270 | self.n_src_words / (t + 1e-5),
271 | self.n_words / (t + 1e-5),
272 | time.time() - start))
273 | sys.stdout.flush()
274 |
275 | def log_tensorboard(self, prefix, writer, learning_rate, step):
276 | """ display statistics to tensorboard """
277 | t = self.elapsed_time()
278 | writer.add_scalar(prefix + "/xent", self.xent(), step)
279 | writer.add_scalar(prefix + "/ppl", self.ppl(), step)
280 | writer.add_scalar(prefix + "/accuracy", self.accuracy(), step)
281 | writer.add_scalar(prefix + "/tgtper", self.n_words / t, step)
282 | writer.add_scalar(prefix + "/lr", learning_rate, step)
283 |
--------------------------------------------------------------------------------
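The `Statistics` accumulator above derives every reported number from three running sums (`loss`, `n_words`, `n_correct`). A small usage sketch (assuming `src` is on the Python path so `models.reporter` is importable; the numbers are made up):

from models.reporter import Statistics

total = Statistics()
for loss, n_words, n_correct in [(240.0, 100, 60), (180.0, 80, 56)]:
    total.update(Statistics(loss=loss, n_words=n_words, n_correct=n_correct))

print(total.accuracy())   # 100 * n_correct / n_words  -> about 64.4
print(total.xent())       # loss / n_words             -> about 2.33
print(total.ppl())        # exp(min(loss / n_words, 100))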
/src/models/reporter_ext.py:
--------------------------------------------------------------------------------
1 | """ Report manager utility """
2 | from __future__ import print_function
3 |
4 | import sys
5 | import time
6 | from datetime import datetime
7 |
8 | from others.logging import logger
9 |
10 |
11 | def build_report_manager(opt):
12 | if opt.tensorboard:
13 | from tensorboardX import SummaryWriter
14 | tensorboard_log_dir = opt.tensorboard_log_dir
15 |
16 | if not opt.train_from:
17 | tensorboard_log_dir += datetime.now().strftime("/%b-%d_%H-%M-%S")
18 |
19 | writer = SummaryWriter(tensorboard_log_dir,
20 | comment="Unmt")
21 | else:
22 | writer = None
23 |
24 | report_mgr = ReportMgr(opt.report_every, start_time=-1,
25 | tensorboard_writer=writer)
26 | return report_mgr
27 |
28 |
29 | class ReportMgrBase(object):
30 | """
31 | Report Manager Base class
32 | Inherited classes should override:
33 | * `_report_training`
34 | * `_report_step`
35 | """
36 |
37 | def __init__(self, report_every, start_time=-1.):
38 | """
39 | Args:
40 | report_every(int): Report status every this many sentences
41 | start_time(float): manually set report start time. Negative values
42 |                 mean that you will need to set it later or use `start()`
43 | """
44 | self.report_every = report_every
45 | self.progress_step = 0
46 | self.start_time = start_time
47 |
48 | def start(self):
49 | self.start_time = time.time()
50 |
51 | def log(self, *args, **kwargs):
52 | logger.info(*args, **kwargs)
53 |
54 | def report_training(self, step, num_steps, learning_rate,
55 | report_stats, multigpu=False):
56 | """
57 |         This is the user-defined batch-level training progress
58 | report function.
59 |
60 | Args:
61 | step(int): current step count.
62 | num_steps(int): total number of batches.
63 | learning_rate(float): current learning rate.
64 | report_stats(Statistics): old Statistics instance.
65 | Returns:
66 | report_stats(Statistics): updated Statistics instance.
67 | """
68 | if self.start_time < 0:
69 | raise ValueError("""ReportMgr needs to be started
70 | (set 'start_time' or use 'start()'""")
71 |
72 | if step % self.report_every == 0:
73 | if multigpu:
74 | report_stats = \
75 | Statistics.all_gather_stats(report_stats)
76 | self._report_training(
77 | step, num_steps, learning_rate, report_stats)
78 | self.progress_step += 1
79 | return Statistics()
80 | else:
81 | return report_stats
82 |
83 | def _report_training(self, *args, **kwargs):
84 | """ To be overridden """
85 | raise NotImplementedError()
86 |
87 | def report_step(self, lr, step, train_stats=None, valid_stats=None):
88 | """
89 | Report stats of a step
90 |
91 | Args:
92 | train_stats(Statistics): training stats
93 | valid_stats(Statistics): validation stats
94 | lr(float): current learning rate
95 | """
96 | self._report_step(
97 | lr, step, train_stats=train_stats, valid_stats=valid_stats)
98 |
99 | def _report_step(self, *args, **kwargs):
100 | raise NotImplementedError()
101 |
102 |
103 | class ReportMgr(ReportMgrBase):
104 | def __init__(self, report_every, start_time=-1., tensorboard_writer=None):
105 | """
106 | A report manager that writes statistics on standard output as well as
107 | (optionally) TensorBoard
108 |
109 | Args:
110 | report_every(int): Report status every this many sentences
111 | tensorboard_writer(:obj:`tensorboard.SummaryWriter`):
112 | The TensorBoard Summary writer to use or None
113 | """
114 | super(ReportMgr, self).__init__(report_every, start_time)
115 | self.tensorboard_writer = tensorboard_writer
116 |
117 | def maybe_log_tensorboard(self, stats, prefix, learning_rate, step):
118 | if self.tensorboard_writer is not None:
119 | stats.log_tensorboard(
120 | prefix, self.tensorboard_writer, learning_rate, step)
121 |
122 | def _report_training(self, step, num_steps, learning_rate,
123 | report_stats):
124 | """
125 | See base class method `ReportMgrBase.report_training`.
126 | """
127 | report_stats.output(step, num_steps,
128 | learning_rate, self.start_time)
129 |
130 | # Log the progress using the number of batches on the x-axis.
131 | self.maybe_log_tensorboard(report_stats,
132 | "progress",
133 | learning_rate,
134 | self.progress_step)
135 | report_stats = Statistics()
136 |
137 | return report_stats
138 |
139 | def _report_step(self, lr, step, train_stats=None, valid_stats=None):
140 | """
141 | See base class method `ReportMgrBase.report_step`.
142 | """
143 | if train_stats is not None:
144 | self.log('Train xent: %g' % train_stats.xent())
145 |
146 | self.maybe_log_tensorboard(train_stats,
147 | "train",
148 | lr,
149 | step)
150 |
151 | if valid_stats is not None:
152 | self.log('Validation xent: %g at step %d' % (valid_stats.xent(), step))
153 |
154 | self.maybe_log_tensorboard(valid_stats,
155 | "valid",
156 | lr,
157 | step)
158 |
159 |
160 | class Statistics(object):
161 | """
162 | Accumulator for loss statistics.
163 | Currently calculates:
164 |
165 |     * cross entropy (xent)
166 |     * number of documents seen
167 |     * elapsed time
168 | """
169 |
170 | def __init__(self, loss=0, n_docs=0, n_correct=0):
171 | self.loss = loss
172 | self.n_docs = n_docs
173 | self.start_time = time.time()
174 |
175 | @staticmethod
176 | def all_gather_stats(stat, max_size=4096):
177 | """
178 |         Gather a `Statistics` object across multiple processes/nodes
179 |
180 | Args:
181 | stat(:obj:Statistics): the statistics object to gather
182 |                 across all processes/nodes
183 | max_size(int): max buffer size to use
184 |
185 | Returns:
186 |             `Statistics`, the updated stats object
187 | """
188 | stats = Statistics.all_gather_stats_list([stat], max_size=max_size)
189 | return stats[0]
190 |
191 | @staticmethod
192 | def all_gather_stats_list(stat_list, max_size=4096):
193 | """
194 |         Gather a `Statistics` list across all processes/nodes
195 |
196 | Args:
197 | stat_list(list([`Statistics`])): list of statistics objects to
198 |                 gather across all processes/nodes
199 | max_size(int): max buffer size to use
200 |
201 | Returns:
202 | our_stats(list([`Statistics`])): list of updated stats
203 | """
204 | from torch.distributed import get_rank
205 | from distributed import all_gather_list
206 |
207 | # Get a list of world_size lists with len(stat_list) Statistics objects
208 | all_stats = all_gather_list(stat_list, max_size=max_size)
209 |
210 | our_rank = get_rank()
211 | our_stats = all_stats[our_rank]
212 | for other_rank, stats in enumerate(all_stats):
213 | if other_rank == our_rank:
214 | continue
215 | for i, stat in enumerate(stats):
216 | our_stats[i].update(stat, update_n_src_words=True)
217 | return our_stats
218 |
219 | def update(self, stat, update_n_src_words=False):
220 | """
221 |         Update statistics by summing values with another `Statistics` object
222 |
223 | Args:
224 | stat: another statistic object
225 | update_n_src_words(bool): whether to update (sum) `n_src_words`
226 | or not
227 |
228 | """
229 | self.loss += stat.loss
230 |
231 | self.n_docs += stat.n_docs
232 |
233 | def xent(self):
234 | """ compute cross entropy """
235 | if (self.n_docs == 0):
236 | return 0
237 | return self.loss / self.n_docs
238 |
239 | def elapsed_time(self):
240 | """ compute elapsed time """
241 | return time.time() - self.start_time
242 |
243 | def output(self, step, num_steps, learning_rate, start):
244 | """Write out statistics to stdout.
245 |
246 | Args:
247 | step (int): current step
248 | n_batch (int): total batches
249 | start (int): start time of step.
250 | """
251 | t = self.elapsed_time()
252 | step_fmt = "%2d" % step
253 | if num_steps > 0:
254 | step_fmt = "%s/%5d" % (step_fmt, num_steps)
255 | logger.info(
256 | ("Step %s; xent: %4.2f; " +
257 | "lr: %7.7f; %3.0f docs/s; %6.0f sec")
258 | % (step_fmt,
259 | self.xent(),
260 | learning_rate,
261 | self.n_docs / (t + 1e-5),
262 | time.time() - start))
263 | sys.stdout.flush()
264 |
265 | def log_tensorboard(self, prefix, writer, learning_rate, step):
266 | """ display statistics to tensorboard """
267 | t = self.elapsed_time()
268 | writer.add_scalar(prefix + "/xent", self.xent(), step)
269 | writer.add_scalar(prefix + "/lr", learning_rate, step)
270 |
--------------------------------------------------------------------------------
/src/models/trainer.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import numpy as np
4 | import torch
5 | from tensorboardX import SummaryWriter
6 |
7 | import distributed
8 | from models.reporter import ReportMgr, Statistics
9 | from others.logging import logger
10 | from others.utils import test_rouge, rouge_results_to_str
11 |
12 |
13 | def _tally_parameters(model):
14 | n_params = sum([p.nelement() for p in model.parameters()])
15 | return n_params
16 |
17 |
18 | def build_trainer(args, device_id, model, optims, loss):
19 | """
20 |     Simplify `Trainer` creation based on user `opt`s.
21 |     Args:
22 |         args (:obj:`Namespace`): user options (usually from argument parsing)
23 |         device_id (int): GPU device id, or a negative value when running on CPU
24 |         model (:obj:`nn.Module`): the model to train
25 |         optims (list of :obj:`models.optimizers.Optimizer`): optimizer(s)
26 |             used during training (the encoder and decoder may each have
27 |             their own)
28 |         loss (:obj:`models.loss.LossComputeBase`): the loss computation
29 |             object used for training and validation
30 |     """
31 | device = "cpu" if args.visible_gpus == '-1' else "cuda"
32 |
33 |
34 | grad_accum_count = args.accum_count
35 | n_gpu = args.world_size
36 |
37 | if device_id >= 0:
38 | gpu_rank = int(args.gpu_ranks[device_id])
39 | else:
40 | gpu_rank = 0
41 | n_gpu = 0
42 |
43 | print('gpu_rank %d' % gpu_rank)
44 |
45 | tensorboard_log_dir = args.model_path
46 |
47 | writer = SummaryWriter(tensorboard_log_dir, comment="Unmt")
48 |
49 | report_manager = ReportMgr(args.report_every, start_time=-1, tensorboard_writer=writer)
50 |
51 |
52 | trainer = Trainer(args, model, optims, loss, grad_accum_count, n_gpu, gpu_rank, report_manager)
53 |
54 | # print(tr)
55 | if (model):
56 | n_params = _tally_parameters(model)
57 | logger.info('* number of parameters: %d' % n_params)
58 |
59 | return trainer
60 |
61 |
62 | class Trainer(object):
63 | """
64 | Class that controls the training process.
65 |
66 | Args:
67 | model(:py:class:`onmt.models.model.NMTModel`): translation model
68 | to train
69 | train_loss(:obj:`onmt.utils.loss.LossComputeBase`):
70 | training loss computation
71 | valid_loss(:obj:`onmt.utils.loss.LossComputeBase`):
72 |            validation loss computation
73 | optim(:obj:`onmt.utils.optimizers.Optimizer`):
74 | the optimizer responsible for update
75 | trunc_size(int): length of truncated back propagation through time
76 | shard_size(int): compute loss in shards of this size for efficiency
77 | data_type(string): type of the source input: [text|img|audio]
78 | norm_method(string): normalization methods: [sents|tokens]
79 | grad_accum_count(int): accumulate gradients this many times.
80 | report_manager(:obj:`onmt.utils.ReportMgrBase`):
81 | the object that creates reports, or None
82 | model_saver(:obj:`onmt.models.ModelSaverBase`): the saver is
83 | used to save a checkpoint.
84 | Thus nothing will be saved if this parameter is None
85 | """
86 |
87 | def __init__(self, args, model, optims, loss,
88 | grad_accum_count=1, n_gpu=1, gpu_rank=1,
89 | report_manager=None):
90 | # Basic attributes.
91 | self.args = args
92 | self.save_checkpoint_steps = args.save_checkpoint_steps
93 | self.model = model
94 | self.optims = optims
95 | self.grad_accum_count = grad_accum_count
96 | self.n_gpu = n_gpu
97 | self.gpu_rank = gpu_rank
98 | self.report_manager = report_manager
99 |
100 | self.loss = loss
101 |
102 | assert grad_accum_count > 0
103 | # Set model in training mode.
104 | if (model):
105 | self.model.train()
106 |
107 | def train(self, train_iter_fct, train_steps, valid_iter_fct=None, valid_steps=-1):
108 | """
109 |         The main training loop: iterates over the training data
110 |         (i.e. `train_iter_fct`) and periodically runs validation
111 |         (i.e. iterates over `valid_iter_fct`).
112 |
113 | Args:
114 | train_iter_fct(function): a function that returns the train
115 | iterator. e.g. something like
116 | train_iter_fct = lambda: generator(*args, **kwargs)
117 | valid_iter_fct(function): same as train_iter_fct, for valid data
118 | train_steps(int):
119 | valid_steps(int):
120 | save_checkpoint_steps(int):
121 |
122 | Return:
123 | None
124 | """
125 | logger.info('Start training...')
126 |
127 | # step = self.optim._step + 1
128 | step = self.optims[0]._step + 1
129 |
130 | true_batchs = []
131 | accum = 0
132 | normalization = 0
133 | train_iter = train_iter_fct()
134 |
135 | total_stats = Statistics()
136 | report_stats = Statistics()
137 | self._start_report_manager(start_time=total_stats.start_time)
138 |
139 | while step <= train_steps:
140 |
141 | reduce_counter = 0
142 | for i, batch in enumerate(train_iter):
143 | if self.n_gpu == 0 or (i % self.n_gpu == self.gpu_rank):
144 |
145 | true_batchs.append(batch)
146 | num_tokens = batch.tgt[:, 1:].ne(self.loss.padding_idx).sum()
147 | normalization += num_tokens.item()
148 | accum += 1
149 | if accum == self.grad_accum_count:
150 | reduce_counter += 1
151 | if self.n_gpu > 1:
152 | normalization = sum(distributed
153 | .all_gather_list
154 | (normalization))
155 |
156 | self._gradient_accumulation(
157 | true_batchs, normalization, total_stats,
158 | report_stats)
159 |
160 | report_stats = self._maybe_report_training(
161 | step, train_steps,
162 | self.optims[0].learning_rate,
163 | report_stats)
164 |
165 | true_batchs = []
166 | accum = 0
167 | normalization = 0
168 | if (step % self.save_checkpoint_steps == 0 and self.gpu_rank == 0):
169 | self._save(step)
170 |
171 | step += 1
172 | if step > train_steps:
173 | break
174 | train_iter = train_iter_fct()
175 |
176 | return total_stats
177 |
178 | def validate(self, valid_iter, step=0):
179 | """ Validate model.
180 |             valid_iter: validation data iterator
181 | Returns:
182 | :obj:`nmt.Statistics`: validation loss statistics
183 | """
184 | # Set model in validating mode.
185 | self.model.eval()
186 | stats = Statistics()
187 |
188 | with torch.no_grad():
189 | for batch in valid_iter:
190 | src = batch.src
191 | tgt = batch.tgt
192 | segs = batch.segs
193 | clss = batch.clss
194 | mask_src = batch.mask_src
195 | mask_tgt = batch.mask_tgt
196 | mask_cls = batch.mask_cls
197 |
198 | outputs, _ = self.model(src, tgt, segs, clss, mask_src, mask_tgt, mask_cls)
199 |
200 | batch_stats = self.loss.monolithic_compute_loss(batch, outputs)
201 | stats.update(batch_stats)
202 | self._report_step(0, step, valid_stats=stats)
203 | return stats
204 |
205 |
206 | def _gradient_accumulation(self, true_batchs, normalization, total_stats,
207 | report_stats):
208 | if self.grad_accum_count > 1:
209 | self.model.zero_grad()
210 |
211 | for batch in true_batchs:
212 | if self.grad_accum_count == 1:
213 | self.model.zero_grad()
214 |
215 | src = batch.src
216 | tgt = batch.tgt
217 | segs = batch.segs
218 | clss = batch.clss
219 | mask_src = batch.mask_src
220 | mask_tgt = batch.mask_tgt
221 | mask_cls = batch.mask_cls
222 |
223 |             outputs, scores = self.model(src, tgt, segs, clss, mask_src, mask_tgt, mask_cls)
224 | batch_stats = self.loss.sharded_compute_loss(batch, outputs, self.args.generator_shard_size, normalization)
225 |
226 | batch_stats.n_docs = int(src.size(0))
227 |
228 | total_stats.update(batch_stats)
229 | report_stats.update(batch_stats)
230 |
231 | # 4. Update the parameters and statistics.
232 | if self.grad_accum_count == 1:
233 | # Multi GPU gradient gather
234 | if self.n_gpu > 1:
235 | grads = [p.grad.data for p in self.model.parameters()
236 | if p.requires_grad
237 | and p.grad is not None]
238 | distributed.all_reduce_and_rescale_tensors(
239 | grads, float(1))
240 |
241 | for o in self.optims:
242 | o.step()
243 |
244 | # in case of multi step gradient accumulation,
245 | # update only after accum batches
246 | if self.grad_accum_count > 1:
247 | if self.n_gpu > 1:
248 | grads = [p.grad.data for p in self.model.parameters()
249 | if p.requires_grad
250 | and p.grad is not None]
251 | distributed.all_reduce_and_rescale_tensors(
252 | grads, float(1))
253 | for o in self.optims:
254 | o.step()
255 |
256 |
257 | def test(self, test_iter, step, cal_lead=False, cal_oracle=False):
258 |         """ Test the model.
259 |             test_iter: test data iterator
260 |         Returns:
261 |             :obj:`nmt.Statistics`: test statistics
262 | """
263 | # Set model in validating mode.
264 | def _get_ngrams(n, text):
265 | ngram_set = set()
266 | text_length = len(text)
267 | max_index_ngram_start = text_length - n
268 | for i in range(max_index_ngram_start + 1):
269 | ngram_set.add(tuple(text[i:i + n]))
270 | return ngram_set
271 |
272 | def _block_tri(c, p):
273 | tri_c = _get_ngrams(3, c.split())
274 | for s in p:
275 | tri_s = _get_ngrams(3, s.split())
276 | if len(tri_c.intersection(tri_s))>0:
277 | return True
278 | return False
279 |
280 | if (not cal_lead and not cal_oracle):
281 | self.model.eval()
282 | stats = Statistics()
283 |
284 | can_path = '%s_step%d.candidate'%(self.args.result_path,step)
285 | gold_path = '%s_step%d.gold' % (self.args.result_path, step)
286 | with open(can_path, 'w') as save_pred:
287 | with open(gold_path, 'w') as save_gold:
288 | with torch.no_grad():
289 | for batch in test_iter:
290 | gold = []
291 | pred = []
292 | if (cal_lead):
293 | selected_ids = [list(range(batch.clss.size(1)))] * batch.batch_size
294 | for i, idx in enumerate(selected_ids):
295 | _pred = []
296 | if(len(batch.src_str[i])==0):
297 | continue
298 | for j in selected_ids[i][:len(batch.src_str[i])]:
299 | if(j>=len( batch.src_str[i])):
300 | continue
301 | candidate = batch.src_str[i][j].strip()
302 | _pred.append(candidate)
303 |
304 | if ((not cal_oracle) and (not self.args.recall_eval) and len(_pred) == 3):
305 | break
306 |
307 | _pred = ''.join(_pred)
308 | if(self.args.recall_eval):
309 | _pred = ' '.join(_pred.split()[:len(batch.tgt_str[i].split())])
310 |
311 | pred.append(_pred)
312 | gold.append(batch.tgt_str[i])
313 |
314 | for i in range(len(gold)):
315 | save_gold.write(gold[i].strip()+'\n')
316 | for i in range(len(pred)):
317 | save_pred.write(pred[i].strip()+'\n')
318 | if(step!=-1 and self.args.report_rouge):
319 | rouges = test_rouge(self.args.temp_dir, can_path, gold_path)
320 | logger.info('Rouges at step %d \n%s' % (step, rouge_results_to_str(rouges)))
321 | self._report_step(0, step, valid_stats=stats)
322 |
323 | return stats
324 |
325 | def _save(self, step):
326 | real_model = self.model
327 | # real_generator = (self.generator.module
328 | # if isinstance(self.generator, torch.nn.DataParallel)
329 | # else self.generator)
330 |
331 | model_state_dict = real_model.state_dict()
332 | # generator_state_dict = real_generator.state_dict()
333 | checkpoint = {
334 | 'model': model_state_dict,
335 | # 'generator': generator_state_dict,
336 | 'opt': self.args,
337 | 'optims': self.optims,
338 | }
339 | checkpoint_path = os.path.join(self.args.model_path, 'model_step_%d.pt' % step)
340 | logger.info("Saving checkpoint %s" % checkpoint_path)
341 | # checkpoint_path = '%s_step_%d.pt' % (FLAGS.model_path, step)
342 | if (not os.path.exists(checkpoint_path)):
343 | torch.save(checkpoint, checkpoint_path)
344 | return checkpoint, checkpoint_path
345 |
346 | def _start_report_manager(self, start_time=None):
347 | """
348 | Simple function to start report manager (if any)
349 | """
350 | if self.report_manager is not None:
351 | if start_time is None:
352 | self.report_manager.start()
353 | else:
354 | self.report_manager.start_time = start_time
355 |
356 | def _maybe_gather_stats(self, stat):
357 | """
358 | Gather statistics in multi-processes cases
359 |
360 | Args:
361 | stat(:obj:onmt.utils.Statistics): a Statistics object to gather
362 | or None (it returns None in this case)
363 |
364 | Returns:
365 | stat: the updated (or unchanged) stat object
366 | """
367 | if stat is not None and self.n_gpu > 1:
368 | return Statistics.all_gather_stats(stat)
369 | return stat
370 |
371 | def _maybe_report_training(self, step, num_steps, learning_rate,
372 | report_stats):
373 | """
374 | Simple function to report training stats (if report_manager is set)
375 | see `onmt.utils.ReportManagerBase.report_training` for doc
376 | """
377 | if self.report_manager is not None:
378 | return self.report_manager.report_training(
379 | step, num_steps, learning_rate, report_stats,
380 | multigpu=self.n_gpu > 1)
381 |
382 | def _report_step(self, learning_rate, step, train_stats=None,
383 | valid_stats=None):
384 | """
385 | Simple function to report stats (if report_manager is set)
386 | see `onmt.utils.ReportManagerBase.report_step` for doc
387 | """
388 | if self.report_manager is not None:
389 | return self.report_manager.report_step(
390 | learning_rate, step, train_stats=train_stats,
391 | valid_stats=valid_stats)
392 |
393 | def _maybe_save(self, step):
394 | """
395 | Save the model if a model saver is set
396 | """
397 | if self.model_saver is not None:
398 | self.model_saver.maybe_save(step)
399 |
--------------------------------------------------------------------------------
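`Trainer.train()` above expects `train_iter_fct` to be a zero-argument callable that builds a fresh iterator on every call, because the loop re-invokes it whenever the corpus is exhausted and keeps going until `train_steps` optimizer steps have been taken. A self-contained toy (plain strings instead of real batches; not repository code) showing that contract:

def make_train_iter_fct(batches):
    def train_iter_fct():
        return iter(batches)        # a new pass over the data on every call
    return train_iter_fct

train_iter_fct = make_train_iter_fct(["batch-%d" % i for i in range(3)])

step, train_steps = 1, 7
while step <= train_steps:          # same outer shape as Trainer.train()
    for batch in train_iter_fct():
        print("step", step, "->", batch)
        step += 1
        if step > train_steps:
            break
    # the corpus was exhausted: the next while-iteration rebuilds the iterator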
/src/others/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/nlpyang/PreSumm/70b810e0f06d179022958dd35c1a3385fe87f28c/src/others/__init__.py
--------------------------------------------------------------------------------
/src/others/logging.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | from __future__ import absolute_import
3 |
4 | import logging
5 |
6 | logger = logging.getLogger()
7 |
8 |
9 | def init_logger(log_file=None, log_file_level=logging.NOTSET):
10 | log_format = logging.Formatter("[%(asctime)s %(levelname)s] %(message)s")
11 | logger = logging.getLogger()
12 | logger.setLevel(logging.INFO)
13 |
14 | console_handler = logging.StreamHandler()
15 | console_handler.setFormatter(log_format)
16 | logger.handlers = [console_handler]
17 |
18 | if log_file and log_file != '':
19 | file_handler = logging.FileHandler(log_file)
20 | file_handler.setLevel(log_file_level)
21 | file_handler.setFormatter(log_format)
22 | logger.addHandler(file_handler)
23 |
24 | return logger
25 |
--------------------------------------------------------------------------------
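`init_logger()` above wires the shared module-level `logger` to the console and, optionally, to a file. A minimal usage sketch (the log path is illustrative; assumes `src` is on the Python path):

from others.logging import init_logger, logger

init_logger('../logs/example.log')   # omit the argument to log to the console only
logger.info('logging initialised')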
/src/others/tokenization.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The Google AI Language Team Authors and The HugginFace Inc. team.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | """Tokenization classes."""
16 |
17 | from __future__ import absolute_import, division, print_function, unicode_literals
18 |
19 | import collections
20 | import logging
21 | import os
22 | import unicodedata
23 | from io import open
24 |
25 | from pytorch_transformers import cached_path
26 |
27 | logger = logging.getLogger(__name__)
28 |
29 | PRETRAINED_VOCAB_ARCHIVE_MAP = {
30 | 'bert-base-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-vocab.txt",
31 | 'bert-large-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-vocab.txt",
32 | 'bert-base-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-vocab.txt",
33 | 'bert-large-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-vocab.txt",
34 | 'bert-base-multilingual-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-uncased-vocab.txt",
35 | 'bert-base-multilingual-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-cased-vocab.txt",
36 | 'bert-base-chinese': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese-vocab.txt",
37 | }
38 | PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP = {
39 | 'bert-base-uncased': 512,
40 | 'bert-large-uncased': 512,
41 | 'bert-base-cased': 512,
42 | 'bert-large-cased': 512,
43 | 'bert-base-multilingual-uncased': 512,
44 | 'bert-base-multilingual-cased': 512,
45 | 'bert-base-chinese': 512,
46 | }
47 | VOCAB_NAME = 'vocab.txt'
48 |
49 |
50 | def load_vocab(vocab_file):
51 | """Loads a vocabulary file into a dictionary."""
52 | vocab = collections.OrderedDict()
53 | index = 0
54 | with open(vocab_file, "r", encoding="utf-8") as reader:
55 | while True:
56 | token = reader.readline()
57 | if not token:
58 | break
59 | token = token.strip()
60 | vocab[token] = index
61 | index += 1
62 | return vocab
63 |
64 |
65 | def whitespace_tokenize(text):
66 |     """Runs basic whitespace cleaning and splitting on a piece of text."""
67 | text = text.strip()
68 | if not text:
69 | return []
70 | tokens = text.split()
71 | return tokens
72 |
73 |
74 | class BertTokenizer(object):
75 | """Runs end-to-end tokenization: punctuation splitting + wordpiece"""
76 |
77 | def __init__(self, vocab_file, do_lower_case=True, max_len=None,
78 | never_split=("[UNK]", "[SEP]", "[PAD]", "[CLS]", "[MASK]", "[unused0]", "[unused1]", "[unused2]", "[unused3]", "[unused4]", "[unused5]", "[unused6]")):
79 |
80 | if not os.path.isfile(vocab_file):
81 | raise ValueError(
82 | "Can't find a vocabulary file at path '{}'. To load the vocabulary from a Google pretrained "
83 | "model use `tokenizer = BertTokenizer.from_pretrained(PRETRAINED_MODEL_NAME)`".format(vocab_file))
84 | self.do_lower_case = do_lower_case
85 | self.vocab = load_vocab(vocab_file)
86 | self.ids_to_tokens = collections.OrderedDict(
87 | [(ids, tok) for tok, ids in self.vocab.items()])
88 | self.basic_tokenizer = BasicTokenizer(do_lower_case=do_lower_case,
89 | never_split=never_split)
90 | self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab)
91 | self.max_len = max_len if max_len is not None else int(1e12)
92 |
93 | def tokenize(self, text, use_bert_basic_tokenizer=False):
94 | split_tokens = []
95 | if(use_bert_basic_tokenizer):
96 | pretokens = self.basic_tokenizer.tokenize(text)
97 | else:
98 | pretokens = list(enumerate(text.split()))
99 |
100 | for i,token in pretokens:
101 | # if(self.do_lower_case):
102 | # token = token.lower()
103 | subtokens = self.wordpiece_tokenizer.tokenize(token)
104 | for sub_token in subtokens:
105 | split_tokens.append(sub_token)
106 | return split_tokens
107 |
108 | def convert_tokens_to_ids(self, tokens):
109 | """Converts a sequence of tokens into ids using the vocab."""
110 | ids = []
111 | for token in tokens:
112 | ids.append(self.vocab[token])
113 | # if len(ids) > self.max_len:
114 | # raise ValueError(
115 | # "Token indices sequence length is longer than the specified maximum "
116 | # " sequence length for this BERT model ({} > {}). Running this"
117 | # " sequence through BERT will result in indexing errors".format(len(ids), self.max_len)
118 | # )
119 | return ids
120 |
121 | def convert_ids_to_tokens(self, ids):
122 | """Converts a sequence of ids in wordpiece tokens using the vocab."""
123 | tokens = []
124 | for i in ids:
125 | tokens.append(self.ids_to_tokens[i])
126 | return tokens
127 |
128 | @classmethod
129 | def from_pretrained(cls, pretrained_model_name_or_path, cache_dir=None, *inputs, **kwargs):
130 | """
131 | Instantiate a PreTrainedBertModel from a pre-trained model file.
132 | Download and cache the pre-trained model file if needed.
133 | """
134 | if pretrained_model_name_or_path in PRETRAINED_VOCAB_ARCHIVE_MAP:
135 | vocab_file = PRETRAINED_VOCAB_ARCHIVE_MAP[pretrained_model_name_or_path]
136 | else:
137 | vocab_file = pretrained_model_name_or_path
138 | if os.path.isdir(vocab_file):
139 | vocab_file = os.path.join(vocab_file, VOCAB_NAME)
140 | # redirect to the cache, if necessary
141 | try:
142 | resolved_vocab_file = cached_path(vocab_file, cache_dir=cache_dir)
143 | except EnvironmentError:
144 | logger.error(
145 | "Model name '{}' was not found in model name list ({}). "
146 | "We assumed '{}' was a path or url but couldn't find any file "
147 | "associated to this path or url.".format(
148 | pretrained_model_name_or_path,
149 | ', '.join(PRETRAINED_VOCAB_ARCHIVE_MAP.keys()),
150 | vocab_file))
151 | return None
152 | if resolved_vocab_file == vocab_file:
153 | logger.info("loading vocabulary file {}".format(vocab_file))
154 | else:
155 | logger.info("loading vocabulary file {} from cache at {}".format(
156 | vocab_file, resolved_vocab_file))
157 | if pretrained_model_name_or_path in PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP:
158 | # if we're using a pretrained model, ensure the tokenizer won't index sequences longer
159 | # than the number of positional embeddings
160 | max_len = PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP[pretrained_model_name_or_path]
161 | kwargs['max_len'] = min(kwargs.get('max_len', int(1e12)), max_len)
162 | # Instantiate tokenizer.
163 | tokenizer = cls(resolved_vocab_file, *inputs, **kwargs)
164 | return tokenizer
165 |
166 |
167 | class BasicTokenizer(object):
168 | """Runs basic tokenization (punctuation splitting, lower casing, etc.)."""
169 |
170 | def __init__(self,
171 | do_lower_case=True,
172 | never_split=("[UNK]", "[SEP]", "[PAD]", "[CLS]", "[MASK]")):
173 | """Constructs a BasicTokenizer.
174 |
175 | Args:
176 | do_lower_case: Whether to lower case the input.
177 | """
178 | self.do_lower_case = do_lower_case
179 | self.never_split = never_split
180 |
181 | def tokenize(self, text):
182 | """Tokenizes a piece of text."""
183 | text = self._clean_text(text)
184 | # This was added on November 1st, 2018 for the multilingual and Chinese
185 | # models. This is also applied to the English models now, but it doesn't
186 | # matter since the English models were not trained on any Chinese data
187 | # and generally don't have any Chinese data in them (there are Chinese
188 | # characters in the vocabulary because Wikipedia does have some Chinese
189 | # words in the English Wikipedia.).
190 | text = self._tokenize_chinese_chars(text)
191 | orig_tokens = whitespace_tokenize(text)
192 | split_tokens = []
193 | for i,token in enumerate(orig_tokens):
194 | if self.do_lower_case and token not in self.never_split:
195 | token = token.lower()
196 | token = self._run_strip_accents(token)
197 | # split_tokens.append(token)
198 | split_tokens.extend([(i,t) for t in self._run_split_on_punc(token)])
199 |
200 | # output_tokens = whitespace_tokenize(" ".join(split_tokens))
201 | return split_tokens
202 |
203 | def _run_strip_accents(self, text):
204 | """Strips accents from a piece of text."""
205 | text = unicodedata.normalize("NFD", text)
206 | output = []
207 | for char in text:
208 | cat = unicodedata.category(char)
209 | if cat == "Mn":
210 | continue
211 | output.append(char)
212 | return "".join(output)
213 |
214 | def _run_split_on_punc(self, text):
215 | """Splits punctuation on a piece of text."""
216 | if text in self.never_split:
217 | return [text]
218 | chars = list(text)
219 | i = 0
220 | start_new_word = True
221 | output = []
222 | while i < len(chars):
223 | char = chars[i]
224 | if _is_punctuation(char):
225 | output.append([char])
226 | start_new_word = True
227 | else:
228 | if start_new_word:
229 | output.append([])
230 | start_new_word = False
231 | output[-1].append(char)
232 | i += 1
233 |
234 | return ["".join(x) for x in output]
235 |
236 | def _tokenize_chinese_chars(self, text):
237 | """Adds whitespace around any CJK character."""
238 | output = []
239 | for char in text:
240 | cp = ord(char)
241 | if self._is_chinese_char(cp):
242 | output.append(" ")
243 | output.append(char)
244 | output.append(" ")
245 | else:
246 | output.append(char)
247 | return "".join(output)
248 |
249 | def _is_chinese_char(self, cp):
250 | """Checks whether CP is the codepoint of a CJK character."""
251 | # This defines a "chinese character" as anything in the CJK Unicode block:
252 | # https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)
253 | #
254 | # Note that the CJK Unicode block is NOT all Japanese and Korean characters,
255 | # despite its name. The modern Korean Hangul alphabet is a different block,
256 | # as is Japanese Hiragana and Katakana. Those alphabets are used to write
257 | # space-separated words, so they are not treated specially and handled
258 | # like all of the other languages.
259 | if ((cp >= 0x4E00 and cp <= 0x9FFF) or #
260 | (cp >= 0x3400 and cp <= 0x4DBF) or #
261 | (cp >= 0x20000 and cp <= 0x2A6DF) or #
262 | (cp >= 0x2A700 and cp <= 0x2B73F) or #
263 | (cp >= 0x2B740 and cp <= 0x2B81F) or #
264 | (cp >= 0x2B820 and cp <= 0x2CEAF) or
265 | (cp >= 0xF900 and cp <= 0xFAFF) or #
266 | (cp >= 0x2F800 and cp <= 0x2FA1F)): #
267 | return True
268 |
269 | return False
270 |
271 | def _clean_text(self, text):
272 | """Performs invalid character removal and whitespace cleanup on text."""
273 | output = []
274 | for char in text:
275 | cp = ord(char)
276 | if cp == 0 or cp == 0xfffd or _is_control(char):
277 | continue
278 | if _is_whitespace(char):
279 | output.append(" ")
280 | else:
281 | output.append(char)
282 | return "".join(output)
283 |
284 |
285 | class WordpieceTokenizer(object):
286 | """Runs WordPiece tokenization."""
287 |
288 | def __init__(self, vocab, unk_token="[UNK]", max_input_chars_per_word=100):
289 | self.vocab = vocab
290 | self.unk_token = unk_token
291 | self.max_input_chars_per_word = max_input_chars_per_word
292 |
293 | def tokenize(self, text):
294 | """Tokenizes a piece of text into its word pieces.
295 |
296 | This uses a greedy longest-match-first algorithm to perform tokenization
297 | using the given vocabulary.
298 |
299 | For example:
300 | input = "unaffable"
301 | output = ["un", "##aff", "##able"]
302 |
303 | Args:
304 | text: A single token or whitespace separated tokens. This should have
305 | already been passed through `BasicTokenizer`.
306 |
307 | Returns:
308 | A list of wordpiece tokens.
309 | """
310 |
311 | output_tokens = []
312 | for token in whitespace_tokenize(text):
313 | chars = list(token)
314 | if len(chars) > self.max_input_chars_per_word:
315 | output_tokens.append(self.unk_token)
316 | continue
317 |
318 | is_bad = False
319 | start = 0
320 | sub_tokens = []
321 | while start < len(chars):
322 | end = len(chars)
323 | cur_substr = None
324 | while start < end:
325 | substr = "".join(chars[start:end])
326 | if start > 0:
327 | substr = "##" + substr
328 | if substr in self.vocab:
329 | cur_substr = substr
330 | break
331 | end -= 1
332 | if cur_substr is None:
333 | is_bad = True
334 | break
335 | sub_tokens.append(cur_substr)
336 | start = end
337 |
338 | if is_bad:
339 | output_tokens.append(self.unk_token)
340 | else:
341 | output_tokens.extend(sub_tokens)
342 | return output_tokens
343 |
344 |
345 | def _is_whitespace(char):
346 | """Checks whether `char` is a whitespace character."""
347 | # \t, \n, and \r are technically control characters but we treat them
348 | # as whitespace since they are generally considered as such.
349 | if char == " " or char == "\t" or char == "\n" or char == "\r":
350 | return True
351 | cat = unicodedata.category(char)
352 | if cat == "Zs":
353 | return True
354 | return False
355 |
356 |
357 | def _is_control(char):
358 | """Checks whether `char` is a control character."""
359 | # These are technically control characters but we count them as whitespace
360 | # characters.
361 | if char == "\t" or char == "\n" or char == "\r":
362 | return False
363 | cat = unicodedata.category(char)
364 | if cat.startswith("C"):
365 | return True
366 | return False
367 |
368 |
369 | def _is_punctuation(char):
370 | """Checks whether `char` is a punctuation character."""
371 | cp = ord(char)
372 | # We treat all non-letter/number ASCII as punctuation.
373 | # Characters such as "^", "$", and "`" are not in the Unicode
374 | # Punctuation class but we treat them as punctuation anyways, for
375 | # consistency.
376 | if ((cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64) or
377 | (cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126)):
378 | return True
379 | cat = unicodedata.category(char)
380 | if cat.startswith("P"):
381 | return True
382 | return False
383 |
--------------------------------------------------------------------------------
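
A minimal sketch of the greedy longest-match-first WordPiece pass implemented above. The toy vocabulary and input strings are hypothetical; the import assumes `src/` is on `PYTHONPATH` so `others.tokenization` is importable.

```python
from others.tokenization import WordpieceTokenizer

# Hypothetical toy vocabulary; a real BERT vocab.txt has ~30k entries.
toy_vocab = {"un": 0, "##aff": 1, "##able": 2, "[UNK]": 3}
wp = WordpieceTokenizer(vocab=toy_vocab)

print(wp.tokenize("unaffable"))  # ['un', '##aff', '##able'] (the docstring example)
print(wp.tokenize("xyzzy"))      # no prefix matches the vocab -> ['[UNK]']
```
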
/src/others/utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import re
3 | import shutil
4 | import time
5 |
6 | from others import pyrouge
7 |
8 | REMAP = {"-lrb-": "(", "-rrb-": ")", "-lcb-": "{", "-rcb-": "}",
9 | "-lsb-": "[", "-rsb-": "]", "``": '"', "''": '"'}
10 |
11 |
12 | def clean(x):
13 | return re.sub(
14 | r"-lrb-|-rrb-|-lcb-|-rcb-|-lsb-|-rsb-|``|''",
15 | lambda m: REMAP.get(m.group()), x)
16 |
17 |
18 | def process(params):
19 | temp_dir, data = params
20 | candidates, references, pool_id = data
21 | cnt = len(candidates)
22 | current_time = time.strftime('%Y-%m-%d-%H-%M-%S', time.localtime())
23 | tmp_dir = os.path.join(temp_dir, "rouge-tmp-{}-{}".format(current_time, pool_id))
24 | if not os.path.isdir(tmp_dir):
25 | os.mkdir(tmp_dir)
26 | os.mkdir(tmp_dir + "/candidate")
27 | os.mkdir(tmp_dir + "/reference")
28 | try:
29 |
30 | for i in range(cnt):
31 | if len(references[i]) < 1:
32 | continue
33 | with open(tmp_dir + "/candidate/cand.{}.txt".format(i), "w",
34 | encoding="utf-8") as f:
35 | f.write(candidates[i])
36 | with open(tmp_dir + "/reference/ref.{}.txt".format(i), "w",
37 | encoding="utf-8") as f:
38 | f.write(references[i])
39 | r = pyrouge.Rouge155(temp_dir=temp_dir)
40 | r.model_dir = tmp_dir + "/reference/"
41 | r.system_dir = tmp_dir + "/candidate/"
42 | r.model_filename_pattern = 'ref.#ID#.txt'
43 | r.system_filename_pattern = r'cand.(\d+).txt'
44 | rouge_results = r.convert_and_evaluate()
45 | print(rouge_results)
46 | results_dict = r.output_to_dict(rouge_results)
47 | finally:
48 | pass
49 | if os.path.isdir(tmp_dir):
50 | shutil.rmtree(tmp_dir)
51 | return results_dict
52 |
53 |
54 | def test_rouge(temp_dir, cand, ref):
55 | candidates = [line.strip() for line in open(cand, encoding='utf-8')]
56 | references = [line.strip() for line in open(ref, encoding='utf-8')]
57 | print(len(candidates))
58 | print(len(references))
59 | assert len(candidates) == len(references)
60 |
61 | cnt = len(candidates)
62 | current_time = time.strftime('%Y-%m-%d-%H-%M-%S', time.localtime())
63 | tmp_dir = os.path.join(temp_dir, "rouge-tmp-{}".format(current_time))
64 | if not os.path.isdir(tmp_dir):
65 | os.mkdir(tmp_dir)
66 | os.mkdir(tmp_dir + "/candidate")
67 | os.mkdir(tmp_dir + "/reference")
68 | try:
69 |
70 | for i in range(cnt):
71 | if len(references[i]) < 1:
72 | continue
73 | with open(tmp_dir + "/candidate/cand.{}.txt".format(i), "w",
74 | encoding="utf-8") as f:
75 | f.write(candidates[i])
76 | with open(tmp_dir + "/reference/ref.{}.txt".format(i), "w",
77 | encoding="utf-8") as f:
78 | f.write(references[i])
79 | r = pyrouge.Rouge155(temp_dir=temp_dir)
80 | r.model_dir = tmp_dir + "/reference/"
81 | r.system_dir = tmp_dir + "/candidate/"
82 | r.model_filename_pattern = 'ref.#ID#.txt'
83 | r.system_filename_pattern = r'cand.(\d+).txt'
84 | rouge_results = r.convert_and_evaluate()
85 | print(rouge_results)
86 | results_dict = r.output_to_dict(rouge_results)
87 | finally:
88 | pass
89 | if os.path.isdir(tmp_dir):
90 | shutil.rmtree(tmp_dir)
91 | return results_dict
92 |
93 |
94 | def tile(x, count, dim=0):
95 | """
96 | Tiles x on dimension dim count times.
97 | """
98 | perm = list(range(len(x.size())))
99 | if dim != 0:
100 | perm[0], perm[dim] = perm[dim], perm[0]
101 | x = x.permute(perm).contiguous()
102 | out_size = list(x.size())
103 | out_size[0] *= count
104 | batch = x.size(0)
105 | x = x.view(batch, -1) \
106 | .transpose(0, 1) \
107 | .repeat(count, 1) \
108 | .transpose(0, 1) \
109 | .contiguous() \
110 | .view(*out_size)
111 | if dim != 0:
112 | x = x.permute(perm).contiguous()
113 | return x
114 |
115 | def rouge_results_to_str(results_dict):
116 | return ">> ROUGE-F(1/2/l): {:.2f}/{:.2f}/{:.2f}\nROUGE-R(1/2/l): {:.2f}/{:.2f}/{:.2f}\n".format(
117 | results_dict["rouge_1_f_score"] * 100,
118 | results_dict["rouge_2_f_score"] * 100,
119 | # results_dict["rouge_3_f_score"] * 100,
120 | results_dict["rouge_l_f_score"] * 100,
121 | results_dict["rouge_1_recall"] * 100,
122 | results_dict["rouge_2_recall"] * 100,
123 | # results_dict["rouge_3_f_score"] * 100,
124 | results_dict["rouge_l_recall"] * 100
125 |
126 | # ,results_dict["rouge_su*_f_score"] * 100
127 | )
128 |
--------------------------------------------------------------------------------
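
A quick, hypothetical check of `tile()` above: it repeats every slice along `dim` `count` times consecutively (equivalent to `Tensor.repeat_interleave`), which is how a batch is expanded to `beam_size` copies before decoding. Assumes `src/` is on `PYTHONPATH` and the repo's requirements (PyTorch) are installed.

```python
import torch
from others.utils import tile

x = torch.tensor([[1, 2], [3, 4]])
print(tile(x, 3, dim=0))
# tensor([[1, 2], [1, 2], [1, 2],
#         [3, 4], [3, 4], [3, 4]])
assert torch.equal(tile(x, 3, dim=0), x.repeat_interleave(3, dim=0))
```
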
/src/post_stats.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | from os import path
3 | from functools import reduce
4 | import re
5 |
6 | def str2bool(v):
7 | if v.lower() in ('yes', 'true', 't', 'y', '1'):
8 | return True
9 | elif v.lower() in ('no', 'false', 'f', 'n', '0'):
10 | return False
11 | else:
12 | raise argparse.ArgumentTypeError('Boolean value expected.')
13 |
14 |
15 |
16 | def n_grams(tokens, n):
17 | l = len(tokens)
18 | return [tuple(tokens[i:i + n]) for i in range(l) if i + n <= l]
19 |
20 | def has_repeat(elements):
21 | d = set(elements)
22 | return len(d) < len(elements)
23 |
24 | def cal_self_repeat(summary):
25 | ngram_repeats = {2: 0, 4: 0, 8: 0}
26 | sents = summary.split('<q>')
27 | for n in ngram_repeats.keys():
28 | # Respect sentence boundary
29 | grams = reduce(lambda x, y: x + y, [n_grams(sent.split(), n) for sent in sents], [])
30 | ngram_repeats[n] += has_repeat(grams)
31 | return ngram_repeats
32 |
33 | def cal_novel(summary, gold, source, summary_ngram_novel, gold_ngram_novel):
34 | summary = summary.replace('<q>',' ')
35 | summary = re.sub(r' +', ' ', summary).strip()
36 | gold = gold.replace('<q>',' ')
37 | gold = re.sub(r' +', ' ', gold).strip()
38 | source = source.replace(' ##','')
39 | source = source.replace('[CLS]',' ').replace('[SEP]',' ').replace('[PAD]',' ')
40 | source = re.sub(r' +', ' ', source).strip()
41 |
42 |
43 | for n in summary_ngram_novel.keys():
44 | summary_grams = set(n_grams(summary.split(), n))
45 | gold_grams = set(n_grams(gold.split(), n))
46 | source_grams = set(n_grams(source.split(), n))
47 | joint = summary_grams.intersection(source_grams)
48 | novel = summary_grams - joint
49 | summary_ngram_novel[n][0] += 1.0*len(novel)
50 | summary_ngram_novel[n][1] += len(summary_grams)
51 | summary_ngram_novel[n][2] += 1.0 * len(novel) / (len(summary.split()) + 1e-6)
52 | joint = gold_grams.intersection(source_grams)
53 | novel = gold_grams - joint
54 | gold_ngram_novel[n][0] += 1.0*len(novel)
55 | gold_ngram_novel[n][1] += len(gold_grams)
56 | gold_ngram_novel[n][2] += 1.0 * len(novel) / (len(gold.split()) + 1e-6)
57 |
58 |
59 | def cal_repeat(args):
60 | candidate_lines = open(args.result_path+'.candidate').read().strip().split('\n')
61 | gold_lines = open(args.result_path+'.gold').read().strip().split('\n')
62 | src_lines = open(args.result_path+'.raw_src').read().strip().split('\n')
63 | lines = zip(candidate_lines,gold_lines,src_lines)
64 |
65 | summary_ngram_novel = {1: [0, 0, 0], 2: [0, 0, 0], 4: [0, 0, 0]}
66 | gold_ngram_novel = {1: [0, 0, 0], 2: [0, 0, 0], 4: [0, 0, 0]}
67 |
68 | for c,g,s in lines:
69 | # self_repeats = cal_self_repeat(c)
70 | cal_novel(c, g, s,summary_ngram_novel, gold_ngram_novel)
71 | print(summary_ngram_novel, gold_ngram_novel)
72 |
73 | for n in summary_ngram_novel.keys():
74 | # summary_ngram_novel[n] = summary_ngram_novel[n][2]/len(src_lines)
75 | # gold_ngram_novel[n] = gold_ngram_novel[n][2]/len(src_lines)
76 | summary_ngram_novel[n] = summary_ngram_novel[n][0]/summary_ngram_novel[n][1]
77 | gold_ngram_novel[n] = gold_ngram_novel[n][0]/gold_ngram_novel[n][1]
78 | print(summary_ngram_novel, gold_ngram_novel)
79 |
80 |
81 | if __name__ == '__main__':
82 | parser = argparse.ArgumentParser()
83 | parser.add_argument("-mode", default='', type=str)
84 | parser.add_argument("-result_path", default='../../results/cnndm.0')
85 |
86 |
87 | args = parser.parse_args()
88 | eval(args.mode + '(args)')
89 |
--------------------------------------------------------------------------------
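
A toy illustration of the repetition helpers above; the example string is hypothetical, and the import assumes `src/` is the working directory. `has_repeat` only reports whether any n-gram occurs more than once, which is how `cal_self_repeat` counts summaries that repeat themselves.

```python
from post_stats import n_grams, has_repeat

summary = "the cat sat on the mat because the cat was tired"
print(has_repeat(n_grams(summary.split(), 2)))  # True: ('the', 'cat') occurs twice
print(has_repeat(n_grams(summary.split(), 4)))  # False: no 4-gram repeats
```
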
/src/prepro/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/nlpyang/PreSumm/70b810e0f06d179022958dd35c1a3385fe87f28c/src/prepro/__init__.py
--------------------------------------------------------------------------------
/src/prepro/data_builder.py:
--------------------------------------------------------------------------------
1 | import gc
2 | import glob
3 | import hashlib
4 | import itertools
5 | import json
6 | import os
7 | import random
8 | import re
9 | import subprocess
10 | from collections import Counter
11 | from os.path import join as pjoin
12 |
13 | import torch
14 | from multiprocess import Pool
15 |
16 | from others.logging import logger
17 | from others.tokenization import BertTokenizer
18 | from pytorch_transformers import XLNetTokenizer
19 |
20 | from others.utils import clean
21 | from prepro.utils import _get_word_ngrams
22 |
23 | import xml.etree.ElementTree as ET
24 |
25 | nyt_remove_words = ["photo", "graph", "chart", "map", "table", "drawing"]
26 |
27 |
28 | def recover_from_corenlp(s):
29 |     s = re.sub(r" '(\w)", r"'\g<1>", s)
30 |     s = re.sub(r"'' (\w)", r"''\g<1>", s)
31 |     return s
32 |
33 |
34 | def load_json(p, lower):
35 | source = []
36 | tgt = []
37 | flag = False
38 | for sent in json.load(open(p))['sentences']:
39 | tokens = [t['word'] for t in sent['tokens']]
40 | if (lower):
41 | tokens = [t.lower() for t in tokens]
42 | if (tokens[0] == '@highlight'):
43 | flag = True
44 | tgt.append([])
45 | continue
46 | if (flag):
47 | tgt[-1].extend(tokens)
48 | else:
49 | source.append(tokens)
50 |
51 | source = [clean(' '.join(sent)).split() for sent in source]
52 | tgt = [clean(' '.join(sent)).split() for sent in tgt]
53 | return source, tgt
54 |
55 |
56 |
57 | def load_xml(p):
58 | tree = ET.parse(p)
59 | root = tree.getroot()
60 | title, byline, abs, paras = [], [], [], []
61 | title_node = list(root.iter('hedline'))
62 | if (len(title_node) > 0):
63 | try:
64 | title = [p.text.lower().split() for p in list(title_node[0].iter('hl1'))][0]
65 | except:
66 | print(p)
67 |
68 | else:
69 | return None, None
70 | byline_node = list(root.iter('byline'))
71 | byline_node = [n for n in byline_node if n.attrib['class'] == 'normalized_byline']
72 | if (len(byline_node) > 0):
73 | byline = byline_node[0].text.lower().split()
74 | abs_node = list(root.iter('abstract'))
75 | if (len(abs_node) > 0):
76 | try:
77 | abs = [p.text.lower().split() for p in list(abs_node[0].iter('p'))][0]
78 | except:
79 | print(p)
80 |
81 | else:
82 | return None, None
83 | abs = ' '.join(abs).split(';')
84 | abs[-1] = abs[-1].replace('(m)', '')
85 | abs[-1] = abs[-1].replace('(s)', '')
86 |
87 | for ww in nyt_remove_words:
88 | abs[-1] = abs[-1].replace('(' + ww + ')', '')
89 | abs = [p.split() for p in abs]
90 | abs = [p for p in abs if len(p) > 2]
91 |
92 | for doc_node in root.iter('block'):
93 | att = doc_node.get('class')
94 | # if(att == 'abstract'):
95 | # abs = [p.text for p in list(f.iter('p'))]
96 | if (att == 'full_text'):
97 | paras = [p.text.lower().split() for p in list(doc_node.iter('p'))]
98 | break
99 | if (len(paras) > 0):
100 | if (len(byline) > 0):
101 | paras = [title + ['[unused3]'] + byline + ['[unused4]']] + paras
102 | else:
103 | paras = [title + ['[unused3]']] + paras
104 |
105 | return paras, abs
106 | else:
107 | return None, None
108 |
109 |
110 | def tokenize(args):
111 | stories_dir = os.path.abspath(args.raw_path)
112 | tokenized_stories_dir = os.path.abspath(args.save_path)
113 |
114 | print("Preparing to tokenize %s to %s..." % (stories_dir, tokenized_stories_dir))
115 | stories = os.listdir(stories_dir)
116 | # make IO list file
117 | print("Making list of files to tokenize...")
118 | with open("mapping_for_corenlp.txt", "w") as f:
119 | for s in stories:
120 | if (not s.endswith('story')):
121 | continue
122 | f.write("%s\n" % (os.path.join(stories_dir, s)))
123 | command = ['java', 'edu.stanford.nlp.pipeline.StanfordCoreNLP', '-annotators', 'tokenize,ssplit',
124 | '-ssplit.newlineIsSentenceBreak', 'always', '-filelist', 'mapping_for_corenlp.txt', '-outputFormat',
125 | 'json', '-outputDirectory', tokenized_stories_dir]
126 | print("Tokenizing %i files in %s and saving in %s..." % (len(stories), stories_dir, tokenized_stories_dir))
127 | subprocess.call(command)
128 | print("Stanford CoreNLP Tokenizer has finished.")
129 | os.remove("mapping_for_corenlp.txt")
130 |
131 | # Check that the tokenized stories directory contains the same number of files as the original directory
132 | num_orig = len(os.listdir(stories_dir))
133 | num_tokenized = len(os.listdir(tokenized_stories_dir))
134 | if num_orig != num_tokenized:
135 | raise Exception(
136 | "The tokenized stories directory %s contains %i files, but it should contain the same number as %s (which has %i files). Was there an error during tokenization?" % (
137 | tokenized_stories_dir, num_tokenized, stories_dir, num_orig))
138 | print("Successfully finished tokenizing %s to %s.\n" % (stories_dir, tokenized_stories_dir))
139 |
140 | def cal_rouge(evaluated_ngrams, reference_ngrams):
141 | reference_count = len(reference_ngrams)
142 | evaluated_count = len(evaluated_ngrams)
143 |
144 | overlapping_ngrams = evaluated_ngrams.intersection(reference_ngrams)
145 | overlapping_count = len(overlapping_ngrams)
146 |
147 | if evaluated_count == 0:
148 | precision = 0.0
149 | else:
150 | precision = overlapping_count / evaluated_count
151 |
152 | if reference_count == 0:
153 | recall = 0.0
154 | else:
155 | recall = overlapping_count / reference_count
156 |
157 | f1_score = 2.0 * ((precision * recall) / (precision + recall + 1e-8))
158 | return {"f": f1_score, "p": precision, "r": recall}
159 |
160 |
161 | def greedy_selection(doc_sent_list, abstract_sent_list, summary_size):
162 | def _rouge_clean(s):
163 | return re.sub(r'[^a-zA-Z0-9 ]', '', s)
164 |
165 | max_rouge = 0.0
166 | abstract = sum(abstract_sent_list, [])
167 | abstract = _rouge_clean(' '.join(abstract)).split()
168 | sents = [_rouge_clean(' '.join(s)).split() for s in doc_sent_list]
169 | evaluated_1grams = [_get_word_ngrams(1, [sent]) for sent in sents]
170 | reference_1grams = _get_word_ngrams(1, [abstract])
171 | evaluated_2grams = [_get_word_ngrams(2, [sent]) for sent in sents]
172 | reference_2grams = _get_word_ngrams(2, [abstract])
173 |
174 | selected = []
175 | for s in range(summary_size):
176 | cur_max_rouge = max_rouge
177 | cur_id = -1
178 | for i in range(len(sents)):
179 | if (i in selected):
180 | continue
181 | c = selected + [i]
182 | candidates_1 = [evaluated_1grams[idx] for idx in c]
183 | candidates_1 = set.union(*map(set, candidates_1))
184 | candidates_2 = [evaluated_2grams[idx] for idx in c]
185 | candidates_2 = set.union(*map(set, candidates_2))
186 | rouge_1 = cal_rouge(candidates_1, reference_1grams)['f']
187 | rouge_2 = cal_rouge(candidates_2, reference_2grams)['f']
188 | rouge_score = rouge_1 + rouge_2
189 | if rouge_score > cur_max_rouge:
190 | cur_max_rouge = rouge_score
191 | cur_id = i
192 | if (cur_id == -1):
193 | return selected
194 | selected.append(cur_id)
195 | max_rouge = cur_max_rouge
196 |
197 | return sorted(selected)
198 |
199 |
200 | def hashhex(s):
201 | """Returns a hexadecimal-formatted SHA1 hash of the input string."""
202 | h = hashlib.sha1()
203 | h.update(s.encode('utf-8'))
204 | return h.hexdigest()
205 |
206 |
207 | class BertData():
208 | def __init__(self, args):
209 | self.args = args
210 | self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased', do_lower_case=True)
211 |
212 | self.sep_token = '[SEP]'
213 | self.cls_token = '[CLS]'
214 | self.pad_token = '[PAD]'
215 | self.tgt_bos = '[unused0]'
216 | self.tgt_eos = '[unused1]'
217 | self.tgt_sent_split = '[unused2]'
218 | self.sep_vid = self.tokenizer.vocab[self.sep_token]
219 | self.cls_vid = self.tokenizer.vocab[self.cls_token]
220 | self.pad_vid = self.tokenizer.vocab[self.pad_token]
221 |
222 | def preprocess(self, src, tgt, sent_labels, use_bert_basic_tokenizer=False, is_test=False):
223 |
224 | if ((not is_test) and len(src) == 0):
225 | return None
226 |
227 | original_src_txt = [' '.join(s) for s in src]
228 |
229 | idxs = [i for i, s in enumerate(src) if (len(s) > self.args.min_src_ntokens_per_sent)]
230 |
231 | _sent_labels = [0] * len(src)
232 | for l in sent_labels:
233 | _sent_labels[l] = 1
234 |
235 | src = [src[i][:self.args.max_src_ntokens_per_sent] for i in idxs]
236 | sent_labels = [_sent_labels[i] for i in idxs]
237 | src = src[:self.args.max_src_nsents]
238 | sent_labels = sent_labels[:self.args.max_src_nsents]
239 |
240 | if ((not is_test) and len(src) < self.args.min_src_nsents):
241 | return None
242 |
243 | src_txt = [' '.join(sent) for sent in src]
244 | text = ' {} {} '.format(self.sep_token, self.cls_token).join(src_txt)
245 |
246 | src_subtokens = self.tokenizer.tokenize(text)
247 |
248 | src_subtokens = [self.cls_token] + src_subtokens + [self.sep_token]
249 | src_subtoken_idxs = self.tokenizer.convert_tokens_to_ids(src_subtokens)
250 | _segs = [-1] + [i for i, t in enumerate(src_subtoken_idxs) if t == self.sep_vid]
251 | segs = [_segs[i] - _segs[i - 1] for i in range(1, len(_segs))]
252 | segments_ids = []
253 | for i, s in enumerate(segs):
254 | if (i % 2 == 0):
255 | segments_ids += s * [0]
256 | else:
257 | segments_ids += s * [1]
258 | cls_ids = [i for i, t in enumerate(src_subtoken_idxs) if t == self.cls_vid]
259 | sent_labels = sent_labels[:len(cls_ids)]
260 |
261 | tgt_subtokens_str = '[unused0] ' + ' [unused2] '.join(
262 | [' '.join(self.tokenizer.tokenize(' '.join(tt), use_bert_basic_tokenizer=use_bert_basic_tokenizer)) for tt in tgt]) + ' [unused1]'
263 | tgt_subtoken = tgt_subtokens_str.split()[:self.args.max_tgt_ntokens]
264 | if ((not is_test) and len(tgt_subtoken) < self.args.min_tgt_ntokens):
265 | return None
266 |
267 | tgt_subtoken_idxs = self.tokenizer.convert_tokens_to_ids(tgt_subtoken)
268 |
269 | tgt_txt = '<q>'.join([' '.join(tt) for tt in tgt])
270 | src_txt = [original_src_txt[i] for i in idxs]
271 |
272 | return src_subtoken_idxs, sent_labels, tgt_subtoken_idxs, segments_ids, cls_ids, src_txt, tgt_txt
273 |
274 |
275 | def format_to_bert(args):
276 | if (args.dataset != ''):
277 | datasets = [args.dataset]
278 | else:
279 | datasets = ['train', 'valid', 'test']
280 | for corpus_type in datasets:
281 | a_lst = []
282 | for json_f in glob.glob(pjoin(args.raw_path, '*' + corpus_type + '.*.json')):
283 | real_name = json_f.split('/')[-1]
284 | a_lst.append((corpus_type, json_f, args, pjoin(args.save_path, real_name.replace('json', 'bert.pt'))))
285 | print(a_lst)
286 | pool = Pool(args.n_cpus)
287 | for d in pool.imap(_format_to_bert, a_lst):
288 | pass
289 |
290 | pool.close()
291 | pool.join()
292 |
293 |
294 | def _format_to_bert(params):
295 | corpus_type, json_file, args, save_file = params
296 | is_test = corpus_type == 'test'
297 | if (os.path.exists(save_file)):
298 | logger.info('Ignore %s' % save_file)
299 | return
300 |
301 | bert = BertData(args)
302 |
303 | logger.info('Processing %s' % json_file)
304 | jobs = json.load(open(json_file))
305 | datasets = []
306 | for d in jobs:
307 | source, tgt = d['src'], d['tgt']
308 |
309 | sent_labels = greedy_selection(source[:args.max_src_nsents], tgt, 3)
310 | if (args.lower):
311 | source = [' '.join(s).lower().split() for s in source]
312 | tgt = [' '.join(s).lower().split() for s in tgt]
313 | b_data = bert.preprocess(source, tgt, sent_labels, use_bert_basic_tokenizer=args.use_bert_basic_tokenizer,
314 | is_test=is_test)
315 | # b_data = bert.preprocess(source, tgt, sent_labels, use_bert_basic_tokenizer=args.use_bert_basic_tokenizer)
316 |
317 | if (b_data is None):
318 | continue
319 | src_subtoken_idxs, sent_labels, tgt_subtoken_idxs, segments_ids, cls_ids, src_txt, tgt_txt = b_data
320 | b_data_dict = {"src": src_subtoken_idxs, "tgt": tgt_subtoken_idxs,
321 | "src_sent_labels": sent_labels, "segs": segments_ids, 'clss': cls_ids,
322 | 'src_txt': src_txt, "tgt_txt": tgt_txt}
323 | datasets.append(b_data_dict)
324 | logger.info('Processed instances %d' % len(datasets))
325 | logger.info('Saving to %s' % save_file)
326 | torch.save(datasets, save_file)
327 | datasets = []
328 | gc.collect()
329 |
330 |
331 | def format_to_lines(args):
332 | corpus_mapping = {}
333 | for corpus_type in ['valid', 'test', 'train']:
334 | temp = []
335 | for line in open(pjoin(args.map_path, 'mapping_' + corpus_type + '.txt')):
336 | temp.append(hashhex(line.strip()))
337 | corpus_mapping[corpus_type] = {key.strip(): 1 for key in temp}
338 | train_files, valid_files, test_files = [], [], []
339 | for f in glob.glob(pjoin(args.raw_path, '*.json')):
340 | real_name = f.split('/')[-1].split('.')[0]
341 | if (real_name in corpus_mapping['valid']):
342 | valid_files.append(f)
343 | elif (real_name in corpus_mapping['test']):
344 | test_files.append(f)
345 | elif (real_name in corpus_mapping['train']):
346 | train_files.append(f)
347 | # else:
348 | # train_files.append(f)
349 |
350 | corpora = {'train': train_files, 'valid': valid_files, 'test': test_files}
351 | for corpus_type in ['train', 'valid', 'test']:
352 | a_lst = [(f, args) for f in corpora[corpus_type]]
353 | pool = Pool(args.n_cpus)
354 | dataset = []
355 | p_ct = 0
356 | for d in pool.imap_unordered(_format_to_lines, a_lst):
357 | dataset.append(d)
358 | if (len(dataset) > args.shard_size):
359 | pt_file = "{:s}.{:s}.{:d}.json".format(args.save_path, corpus_type, p_ct)
360 | with open(pt_file, 'w') as save:
361 | # save.write('\n'.join(dataset))
362 | save.write(json.dumps(dataset))
363 | p_ct += 1
364 | dataset = []
365 |
366 | pool.close()
367 | pool.join()
368 | if (len(dataset) > 0):
369 | pt_file = "{:s}.{:s}.{:d}.json".format(args.save_path, corpus_type, p_ct)
370 | with open(pt_file, 'w') as save:
371 | # save.write('\n'.join(dataset))
372 | save.write(json.dumps(dataset))
373 | p_ct += 1
374 | dataset = []
375 |
376 |
377 | def _format_to_lines(params):
378 | f, args = params
379 | print(f)
380 | source, tgt = load_json(f, args.lower)
381 | return {'src': source, 'tgt': tgt}
382 |
383 |
384 |
385 |
386 | def format_xsum_to_lines(args):
387 | if (args.dataset != ''):
388 | datasets = [args.dataset]
389 | else:
390 | datasets = ['train', 'test', 'valid']
391 |
392 | corpus_mapping = json.load(open(pjoin(args.raw_path, 'XSum-TRAINING-DEV-TEST-SPLIT-90-5-5.json')))
393 |
394 | for corpus_type in datasets:
395 | mapped_fnames = corpus_mapping[corpus_type]
396 | root_src = pjoin(args.raw_path, 'restbody')
397 | root_tgt = pjoin(args.raw_path, 'firstsentence')
398 | # realnames = [fname.split('.')[0] for fname in os.listdir(root_src)]
399 | realnames = mapped_fnames
400 |
401 | a_lst = [(root_src, root_tgt, n) for n in realnames]
402 | pool = Pool(args.n_cpus)
403 | dataset = []
404 | p_ct = 0
405 | for d in pool.imap_unordered(_format_xsum_to_lines, a_lst):
406 | if (d is None):
407 | continue
408 | dataset.append(d)
409 | if (len(dataset) > args.shard_size):
410 | pt_file = "{:s}.{:s}.{:d}.json".format(args.save_path, corpus_type, p_ct)
411 | with open(pt_file, 'w') as save:
412 | save.write(json.dumps(dataset))
413 | p_ct += 1
414 | dataset = []
415 |
416 | pool.close()
417 | pool.join()
418 | if (len(dataset) > 0):
419 | pt_file = "{:s}.{:s}.{:d}.json".format(args.save_path, corpus_type, p_ct)
420 | with open(pt_file, 'w') as save:
421 | save.write(json.dumps(dataset))
422 | p_ct += 1
423 | dataset = []
424 |
425 |
426 | def _format_xsum_to_lines(params):
427 | src_path, root_tgt, name = params
428 | f_src = pjoin(src_path, name + '.restbody')
429 | f_tgt = pjoin(root_tgt, name + '.fs')
430 | if (os.path.exists(f_src) and os.path.exists(f_tgt)):
431 | print(name)
432 | source = []
433 | for sent in open(f_src):
434 | source.append(sent.split())
435 | tgt = []
436 | for sent in open(f_tgt):
437 | tgt.append(sent.split())
438 | return {'src': source, 'tgt': tgt}
439 | return None
440 |
--------------------------------------------------------------------------------
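
A tiny, hypothetical run of the `greedy_selection()` oracle above: it greedily adds the sentence whose inclusion most improves ROUGE-1 + ROUGE-2 F1 against the abstract, and the chosen indices become the extractive supervision (`sent_labels`) in `_format_to_bert`. The toy document is made up; the import assumes `src/` is on `PYTHONPATH` and the repo's requirements are installed.

```python
from prepro.data_builder import greedy_selection

doc = [
    "the stock market fell sharply on friday".split(),
    "the weather in london was mild".split(),
    "analysts blamed rising interest rates for the fall".split(),
]
abstract = ["stocks fell sharply as interest rates rose".split()]

print(greedy_selection(doc, abstract, summary_size=2))  # [0, 2]
```
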
/src/prepro/smart_common_words.txt:
--------------------------------------------------------------------------------
1 | rrb
2 | llb
3 | lsb
4 | rsb
5 | reuters
6 | ap
7 | jan
8 | feb
9 | mar
10 | apr
11 | may
12 | jun
13 | jul
14 | aug
15 | sep
16 | oct
17 | nov
18 | dec
19 | tech
20 | news
21 | index
22 | mon
23 | tue
24 | wed
25 | thu
26 | fri
27 | sat
28 | 's
29 | a
30 | a's
31 | able
32 | about
33 | above
34 | according
35 | accordingly
36 | across
37 | actually
38 | after
39 | afterwards
40 | again
41 | against
42 | ain't
43 | all
44 | allow
45 | allows
46 | almost
47 | alone
48 | along
49 | already
50 | also
51 | although
52 | always
53 | am
54 | amid
55 | among
56 | amongst
57 | an
58 | and
59 | another
60 | any
61 | anybody
62 | anyhow
63 | anyone
64 | anything
65 | anyway
66 | anyways
67 | anywhere
68 | apart
69 | appear
70 | appreciate
71 | appropriate
72 | are
73 | aren't
74 | around
75 | as
76 | aside
77 | ask
78 | asking
79 | associated
80 | at
81 | available
82 | away
83 | awfully
84 | b
85 | be
86 | became
87 | because
88 | become
89 | becomes
90 | becoming
91 | been
92 | before
93 | beforehand
94 | behind
95 | being
96 | believe
97 | below
98 | beside
99 | besides
100 | best
101 | better
102 | between
103 | beyond
104 | both
105 | brief
106 | but
107 | by
108 | c
109 | c'mon
110 | c's
111 | came
112 | can
113 | can't
114 | cannot
115 | cant
116 | cause
117 | causes
118 | certain
119 | certainly
120 | changes
121 | clearly
122 | co
123 | com
124 | come
125 | comes
126 | concerning
127 | consequently
128 | consider
129 | considering
130 | contain
131 | containing
132 | contains
133 | corresponding
134 | could
135 | couldn't
136 | course
137 | currently
138 | d
139 | definitely
140 | described
141 | despite
142 | did
143 | didn't
144 | different
145 | do
146 | does
147 | doesn't
148 | doing
149 | don't
150 | done
151 | down
152 | downwards
153 | during
154 | e
155 | each
156 | edu
157 | eg
158 | e.g.
159 | eight
160 | either
161 | else
162 | elsewhere
163 | enough
164 | entirely
165 | especially
166 | et
167 | etc
168 | etc.
169 | even
170 | ever
171 | every
172 | everybody
173 | everyone
174 | everything
175 | everywhere
176 | ex
177 | exactly
178 | example
179 | except
180 | f
181 | far
182 | few
183 | fifth
184 | five
185 | followed
186 | following
187 | follows
188 | for
189 | former
190 | formerly
191 | forth
192 | four
193 | from
194 | further
195 | furthermore
196 | g
197 | get
198 | gets
199 | getting
200 | given
201 | gives
202 | go
203 | goes
204 | going
205 | gone
206 | got
207 | gotten
208 | greetings
209 | h
210 | had
211 | hadn't
212 | happens
213 | hardly
214 | has
215 | hasn't
216 | have
217 | haven't
218 | having
219 | he
220 | he's
221 | hello
222 | help
223 | hence
224 | her
225 | here
226 | here's
227 | hereafter
228 | hereby
229 | herein
230 | hereupon
231 | hers
232 | herself
233 | hi
234 | him
235 | himself
236 | his
237 | hither
238 | hopefully
239 | how
240 | howbeit
241 | however
242 | i
243 | i'd
244 | i'll
245 | i'm
246 | i've
247 | ie
248 | i.e.
249 | if
250 | ignored
251 | immediate
252 | in
253 | inasmuch
254 | inc
255 | indeed
256 | indicate
257 | indicated
258 | indicates
259 | inner
260 | insofar
261 | instead
262 | into
263 | inward
264 | is
265 | isn't
266 | it
267 | it'd
268 | it'll
269 | it's
270 | its
271 | itself
272 | j
273 | just
274 | k
275 | keep
276 | keeps
277 | kept
278 | know
279 | knows
280 | known
281 | l
282 | lately
283 | later
284 | latter
285 | latterly
286 | least
287 | less
288 | lest
289 | let
290 | let's
291 | like
292 | liked
293 | likely
294 | little
295 | look
296 | looking
297 | looks
298 | ltd
299 | m
300 | mainly
301 | many
302 | may
303 | maybe
304 | me
305 | mean
306 | meanwhile
307 | merely
308 | might
309 | more
310 | moreover
311 | most
312 | mostly
313 | mr.
314 | ms.
315 | much
316 | must
317 | my
318 | myself
319 | n
320 | namely
321 | nd
322 | near
323 | nearly
324 | necessary
325 | need
326 | needs
327 | neither
328 | never
329 | nevertheless
330 | new
331 | next
332 | nine
333 | no
334 | nobody
335 | non
336 | none
337 | noone
338 | nor
339 | normally
340 | not
341 | nothing
342 | novel
343 | now
344 | nowhere
345 | o
346 | obviously
347 | of
348 | off
349 | often
350 | oh
351 | ok
352 | okay
353 | old
354 | on
355 | once
356 | one
357 | ones
358 | only
359 | onto
360 | or
361 | other
362 | others
363 | otherwise
364 | ought
365 | our
366 | ours
367 | ourselves
368 | out
369 | outside
370 | over
371 | overall
372 | own
373 | p
374 | particular
375 | particularly
376 | per
377 | perhaps
378 | placed
379 | please
380 | plus
381 | possible
382 | presumably
383 | probably
384 | provides
385 | q
386 | que
387 | quite
388 | qv
389 | r
390 | rather
391 | rd
392 | re
393 | really
394 | reasonably
395 | regarding
396 | regardless
397 | regards
398 | relatively
399 | respectively
400 | right
401 | s
402 | said
403 | same
404 | saw
405 | say
406 | saying
407 | says
408 | second
409 | secondly
410 | see
411 | seeing
412 | seem
413 | seemed
414 | seeming
415 | seems
416 | seen
417 | self
418 | selves
419 | sensible
420 | sent
421 | serious
422 | seriously
423 | seven
424 | several
425 | shall
426 | she
427 | should
428 | shouldn't
429 | since
430 | six
431 | so
432 | some
433 | somebody
434 | somehow
435 | someone
436 | something
437 | sometime
438 | sometimes
439 | somewhat
440 | somewhere
441 | soon
442 | sorry
443 | specified
444 | specify
445 | specifying
446 | still
447 | sub
448 | such
449 | sup
450 | sure
451 | t
452 | t's
453 | take
454 | taken
455 | tell
456 | tends
457 | th
458 | than
459 | thank
460 | thanks
461 | thanx
462 | that
463 | that's
464 | thats
465 | the
466 | their
467 | theirs
468 | them
469 | themselves
470 | then
471 | thence
472 | there
473 | there's
474 | thereafter
475 | thereby
476 | therefore
477 | therein
478 | theres
479 | thereupon
480 | these
481 | they
482 | they'd
483 | they'll
484 | they're
485 | they've
486 | think
487 | third
488 | this
489 | thorough
490 | thoroughly
491 | those
492 | though
493 | three
494 | through
495 | throughout
496 | thru
497 | thus
498 | to
499 | together
500 | too
501 | took
502 | toward
503 | towards
504 | tried
505 | tries
506 | truly
507 | try
508 | trying
509 | twice
510 | two
511 | u
512 | un
513 | under
514 | unfortunately
515 | unless
516 | unlikely
517 | until
518 | unto
519 | up
520 | upon
521 | us
522 | use
523 | used
524 | useful
525 | uses
526 | using
527 | usually
528 | uucp
529 | v
530 | value
531 | various
532 | very
533 | via
534 | viz
535 | vs
536 | w
537 | want
538 | wants
539 | was
540 | wasn't
541 | way
542 | we
543 | we'd
544 | we'll
545 | we're
546 | we've
547 | welcome
548 | well
549 | went
550 | were
551 | weren't
552 | what
553 | what's
554 | whatever
555 | when
556 | whence
557 | whenever
558 | where
559 | where's
560 | whereafter
561 | whereas
562 | whereby
563 | wherein
564 | whereupon
565 | wherever
566 | whether
567 | which
568 | while
569 | whither
570 | who
571 | who's
572 | whoever
573 | whole
574 | whom
575 | whose
576 | why
577 | will
578 | willing
579 | wish
580 | with
581 | within
582 | without
583 | won't
584 | wonder
585 | would
586 | would
587 | wouldn't
588 | x
589 | y
590 | yes
591 | yet
592 | you
593 | you'd
594 | you'll
595 | you're
596 | you've
597 | your
598 | yours
599 | yourself
600 | yourselves
601 | z
602 | zero
603 |
--------------------------------------------------------------------------------
/src/prepro/utils.py:
--------------------------------------------------------------------------------
1 | # stopwords = pkgutil.get_data(__package__, 'smart_common_words.txt')
2 | # stopwords = stopwords.decode('ascii').split('\n')
3 | # stopwords = {key.strip(): 1 for key in stopwords}
4 |
5 |
6 | def _get_ngrams(n, text):
7 | """Calculates n-grams.
8 |
9 | Args:
10 | n: which n-grams to calculate
11 | text: An array of tokens
12 |
13 | Returns:
14 | A set of n-grams
15 | """
16 | ngram_set = set()
17 | text_length = len(text)
18 | max_index_ngram_start = text_length - n
19 | for i in range(max_index_ngram_start + 1):
20 | ngram_set.add(tuple(text[i:i + n]))
21 | return ngram_set
22 |
23 |
24 | def _get_word_ngrams(n, sentences):
25 | """Calculates word n-grams for multiple sentences.
26 | """
27 | assert len(sentences) > 0
28 | assert n > 0
29 |
30 | # words = _split_into_words(sentences)
31 |
32 | words = sum(sentences, [])
33 | # words = [w for w in words if w not in stopwords]
34 | return _get_ngrams(n, words)
35 |
--------------------------------------------------------------------------------
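
Note that `_get_word_ngrams()` concatenates the sentences before extracting n-grams, so an n-gram can span a sentence boundary. A small, hypothetical check (set order may vary):

```python
from prepro.utils import _get_word_ngrams  # assumes src/ is on PYTHONPATH

sents = [["the", "cat"], ["sat", "down"]]
print(_get_word_ngrams(2, sents))
# {('the', 'cat'), ('cat', 'sat'), ('sat', 'down')}
```
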
/src/preprocess.py:
--------------------------------------------------------------------------------
1 | #encoding=utf-8
2 |
3 |
4 | import argparse
5 | import time
6 |
7 | from others.logging import init_logger
8 | from prepro import data_builder
9 |
10 |
11 | def do_format_to_lines(args):
12 | print(time.clock())
13 | data_builder.format_to_lines(args)
14 | print(time.clock())
15 |
16 | def do_format_to_bert(args):
17 | print(time.clock())
18 | data_builder.format_to_bert(args)
19 | print(time.clock())
20 |
21 |
22 |
23 | def do_format_xsum_to_lines(args):
24 | print(time.clock())
25 | data_builder.format_xsum_to_lines(args)
26 | print(time.clock())
27 |
28 | def do_tokenize(args):
29 | print(time.clock())
30 | data_builder.tokenize(args)
31 | print(time.clock())
32 |
33 |
34 | def str2bool(v):
35 | if v.lower() in ('yes', 'true', 't', 'y', '1'):
36 | return True
37 | elif v.lower() in ('no', 'false', 'f', 'n', '0'):
38 | return False
39 | else:
40 | raise argparse.ArgumentTypeError('Boolean value expected.')
41 |
42 |
43 | if __name__ == '__main__':
44 | parser = argparse.ArgumentParser()
45 | parser.add_argument("-pretrained_model", default='bert', type=str)
46 |
47 | parser.add_argument("-mode", default='', type=str)
48 | parser.add_argument("-select_mode", default='greedy', type=str)
49 | parser.add_argument("-map_path", default='../../data/')
50 | parser.add_argument("-raw_path", default='../../line_data')
51 | parser.add_argument("-save_path", default='../../data/')
52 |
53 | parser.add_argument("-shard_size", default=2000, type=int)
54 | parser.add_argument('-min_src_nsents', default=3, type=int)
55 | parser.add_argument('-max_src_nsents', default=100, type=int)
56 | parser.add_argument('-min_src_ntokens_per_sent', default=5, type=int)
57 | parser.add_argument('-max_src_ntokens_per_sent', default=200, type=int)
58 | parser.add_argument('-min_tgt_ntokens', default=5, type=int)
59 | parser.add_argument('-max_tgt_ntokens', default=500, type=int)
60 |
61 | parser.add_argument("-lower", type=str2bool, nargs='?',const=True,default=True)
62 | parser.add_argument("-use_bert_basic_tokenizer", type=str2bool, nargs='?',const=True,default=False)
63 |
64 | parser.add_argument('-log_file', default='../../logs/cnndm.log')
65 |
66 | parser.add_argument('-dataset', default='')
67 |
68 | parser.add_argument('-n_cpus', default=2, type=int)
69 |
70 |
71 | args = parser.parse_args()
72 | init_logger(args.log_file)
73 | eval('data_builder.'+args.mode + '(args)')
74 |
--------------------------------------------------------------------------------
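
The `-mode` flag above is dispatched via `eval()` to the function of the same name in `prepro/data_builder.py`. A hypothetical equivalent without `eval()`, shown for the `format_to_bert` step; the paths are placeholders and all field names mirror the command-line flags above.

```python
import argparse
from prepro import data_builder  # assumes src/ is on PYTHONPATH

args = argparse.Namespace(
    raw_path='../json_data', save_path='../bert_data/cnndm', dataset='',
    n_cpus=2, lower=True, use_bert_basic_tokenizer=False,
    min_src_nsents=3, max_src_nsents=100,
    min_src_ntokens_per_sent=5, max_src_ntokens_per_sent=200,
    min_tgt_ntokens=5, max_tgt_ntokens=500)

data_builder.format_to_bert(args)  # same effect as: -mode format_to_bert
```
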
/src/train.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | """
3 | Main training workflow
4 | """
5 | from __future__ import division
6 |
7 | import argparse
8 | import os
9 | from others.logging import init_logger
10 | from train_abstractive import validate_abs, train_abs, baseline, test_abs, test_text_abs
11 | from train_extractive import train_ext, validate_ext, test_ext
12 |
13 | model_flags = ['hidden_size', 'ff_size', 'heads', 'emb_size', 'enc_layers', 'enc_hidden_size', 'enc_ff_size',
14 | 'dec_layers', 'dec_hidden_size', 'dec_ff_size', 'encoder', 'ff_actv', 'use_interval']
15 |
16 |
17 | def str2bool(v):
18 | if v.lower() in ('yes', 'true', 't', 'y', '1'):
19 | return True
20 | elif v.lower() in ('no', 'false', 'f', 'n', '0'):
21 | return False
22 | else:
23 | raise argparse.ArgumentTypeError('Boolean value expected.')
24 |
25 |
26 |
27 |
28 | if __name__ == '__main__':
29 | parser = argparse.ArgumentParser()
30 | parser.add_argument("-task", default='ext', type=str, choices=['ext', 'abs'])
31 | parser.add_argument("-encoder", default='bert', type=str, choices=['bert', 'baseline'])
32 | parser.add_argument("-mode", default='train', type=str, choices=['train', 'validate', 'test'])
33 | parser.add_argument("-bert_data_path", default='../bert_data_new/cnndm')
34 | parser.add_argument("-model_path", default='../models/')
35 | parser.add_argument("-result_path", default='../results/cnndm')
36 | parser.add_argument("-temp_dir", default='../temp')
37 |
38 | parser.add_argument("-batch_size", default=140, type=int)
39 | parser.add_argument("-test_batch_size", default=200, type=int)
40 |
41 | parser.add_argument("-max_pos", default=512, type=int)
42 | parser.add_argument("-use_interval", type=str2bool, nargs='?',const=True,default=True)
43 | parser.add_argument("-large", type=str2bool, nargs='?',const=True,default=False)
44 | parser.add_argument("-load_from_extractive", default='', type=str)
45 |
46 | parser.add_argument("-sep_optim", type=str2bool, nargs='?',const=True,default=False)
47 | parser.add_argument("-lr_bert", default=2e-3, type=float)
48 | parser.add_argument("-lr_dec", default=2e-3, type=float)
49 | parser.add_argument("-use_bert_emb", type=str2bool, nargs='?',const=True,default=False)
50 |
51 | parser.add_argument("-share_emb", type=str2bool, nargs='?', const=True, default=False)
52 | parser.add_argument("-finetune_bert", type=str2bool, nargs='?', const=True, default=True)
53 | parser.add_argument("-dec_dropout", default=0.2, type=float)
54 | parser.add_argument("-dec_layers", default=6, type=int)
55 | parser.add_argument("-dec_hidden_size", default=768, type=int)
56 | parser.add_argument("-dec_heads", default=8, type=int)
57 | parser.add_argument("-dec_ff_size", default=2048, type=int)
58 | parser.add_argument("-enc_hidden_size", default=512, type=int)
59 | parser.add_argument("-enc_ff_size", default=512, type=int)
60 | parser.add_argument("-enc_dropout", default=0.2, type=float)
61 | parser.add_argument("-enc_layers", default=6, type=int)
62 |
63 | # params for EXT
64 | parser.add_argument("-ext_dropout", default=0.2, type=float)
65 | parser.add_argument("-ext_layers", default=2, type=int)
66 | parser.add_argument("-ext_hidden_size", default=768, type=int)
67 | parser.add_argument("-ext_heads", default=8, type=int)
68 | parser.add_argument("-ext_ff_size", default=2048, type=int)
69 |
70 | parser.add_argument("-label_smoothing", default=0.1, type=float)
71 | parser.add_argument("-generator_shard_size", default=32, type=int)
72 | parser.add_argument("-alpha", default=0.6, type=float)
73 | parser.add_argument("-beam_size", default=5, type=int)
74 | parser.add_argument("-min_length", default=15, type=int)
75 | parser.add_argument("-max_length", default=150, type=int)
76 | parser.add_argument("-max_tgt_len", default=140, type=int)
77 |
78 |
79 |
80 | parser.add_argument("-param_init", default=0, type=float)
81 | parser.add_argument("-param_init_glorot", type=str2bool, nargs='?',const=True,default=True)
82 | parser.add_argument("-optim", default='adam', type=str)
83 | parser.add_argument("-lr", default=1, type=float)
84 | parser.add_argument("-beta1", default= 0.9, type=float)
85 | parser.add_argument("-beta2", default=0.999, type=float)
86 | parser.add_argument("-warmup_steps", default=8000, type=int)
87 | parser.add_argument("-warmup_steps_bert", default=8000, type=int)
88 | parser.add_argument("-warmup_steps_dec", default=8000, type=int)
89 | parser.add_argument("-max_grad_norm", default=0, type=float)
90 |
91 | parser.add_argument("-save_checkpoint_steps", default=5, type=int)
92 | parser.add_argument("-accum_count", default=1, type=int)
93 | parser.add_argument("-report_every", default=1, type=int)
94 | parser.add_argument("-train_steps", default=1000, type=int)
95 | parser.add_argument("-recall_eval", type=str2bool, nargs='?',const=True,default=False)
96 |
97 |
98 | parser.add_argument('-visible_gpus', default='-1', type=str)
99 | parser.add_argument('-gpu_ranks', default='0', type=str)
100 | parser.add_argument('-log_file', default='../logs/cnndm.log')
101 | parser.add_argument('-seed', default=666, type=int)
102 |
103 | parser.add_argument("-test_all", type=str2bool, nargs='?',const=True,default=False)
104 | parser.add_argument("-test_from", default='')
105 | parser.add_argument("-test_start_from", default=-1, type=int)
106 |
107 | parser.add_argument("-train_from", default='')
108 | parser.add_argument("-report_rouge", type=str2bool, nargs='?',const=True,default=True)
109 | parser.add_argument("-block_trigram", type=str2bool, nargs='?', const=True, default=True)
110 |
111 | args = parser.parse_args()
112 | args.gpu_ranks = [int(i) for i in range(len(args.visible_gpus.split(',')))]
113 | args.world_size = len(args.gpu_ranks)
114 | os.environ["CUDA_VISIBLE_DEVICES"] = args.visible_gpus
115 |
116 | init_logger(args.log_file)
117 | device = "cpu" if args.visible_gpus == '-1' else "cuda"
118 | device_id = 0 if device == "cuda" else -1
119 |
120 | if (args.task == 'abs'):
121 | if (args.mode == 'train'):
122 | train_abs(args, device_id)
123 | elif (args.mode == 'validate'):
124 | validate_abs(args, device_id)
125 | elif (args.mode == 'lead'):
126 | baseline(args, cal_lead=True)
127 | elif (args.mode == 'oracle'):
128 | baseline(args, cal_oracle=True)
129 | if (args.mode == 'test'):
130 | cp = args.test_from
131 | try:
132 | step = int(cp.split('.')[-2].split('_')[-1])
133 | except:
134 | step = 0
135 | test_abs(args, device_id, cp, step)
136 | elif (args.mode == 'test_text'):
137 | cp = args.test_from
138 | try:
139 | step = int(cp.split('.')[-2].split('_')[-1])
140 | except:
141 | step = 0
142 | test_text_abs(args, device_id, cp, step)
143 |
144 | elif (args.task == 'ext'):
145 | if (args.mode == 'train'):
146 | train_ext(args, device_id)
147 | elif (args.mode == 'validate'):
148 | validate_ext(args, device_id)
149 | if (args.mode == 'test'):
150 | cp = args.test_from
151 | try:
152 | step = int(cp.split('.')[-2].split('_')[-1])
153 | except:
154 | step = 0
155 | test_ext(args, device_id, cp, step)
156 | elif (args.mode == 'test_text'):
157 | cp = args.test_from
158 | try:
159 | step = int(cp.split('.')[-2].split('_')[-1])
160 | except:
161 | step = 0
162 | test_text_abs(args, device_id, cp, step)
163 |
--------------------------------------------------------------------------------
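
The test branches above recover the global step from the checkpoint filename (checkpoints are saved as `model_step_<N>.pt`, as the glob in `train_abstractive.py` below also assumes). A quick check with a hypothetical path:

```python
cp = '../models/model_step_148000.pt'
step = int(cp.split('.')[-2].split('_')[-1])
print(step)  # 148000
```
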
/src/train_abstractive.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | """
3 | Main training workflow
4 | """
5 | from __future__ import division
6 |
7 | import argparse
8 | import glob
9 | import os
10 | import random
11 | import signal
12 | import time
13 |
14 | import torch
15 | from pytorch_transformers import BertTokenizer
16 |
17 | import distributed
18 | from models import data_loader, model_builder
19 | from models.data_loader import load_dataset
20 | from models.loss import abs_loss
21 | from models.model_builder import AbsSummarizer
22 | from models.predictor import build_predictor
23 | from models.trainer import build_trainer
24 | from others.logging import logger, init_logger
25 |
26 | model_flags = ['hidden_size', 'ff_size', 'heads', 'emb_size', 'enc_layers', 'enc_hidden_size', 'enc_ff_size',
27 | 'dec_layers', 'dec_hidden_size', 'dec_ff_size', 'encoder', 'ff_actv', 'use_interval']
28 |
29 |
30 | def str2bool(v):
31 | if v.lower() in ('yes', 'true', 't', 'y', '1'):
32 | return True
33 | elif v.lower() in ('no', 'false', 'f', 'n', '0'):
34 | return False
35 | else:
36 | raise argparse.ArgumentTypeError('Boolean value expected.')
37 |
38 |
39 | def train_abs_multi(args):
40 | """ Spawns 1 process per GPU """
41 | init_logger()
42 |
43 | nb_gpu = args.world_size
44 | mp = torch.multiprocessing.get_context('spawn')
45 |
46 | # Create a thread to listen for errors in the child processes.
47 | error_queue = mp.SimpleQueue()
48 | error_handler = ErrorHandler(error_queue)
49 |
50 | # Train with multiprocessing.
51 | procs = []
52 | for i in range(nb_gpu):
53 | device_id = i
54 | procs.append(mp.Process(target=run, args=(args,
55 | device_id, error_queue,), daemon=True))
56 | procs[i].start()
57 | logger.info(" Starting process pid: %d " % procs[i].pid)
58 | error_handler.add_child(procs[i].pid)
59 | for p in procs:
60 | p.join()
61 |
62 |
63 | def run(args, device_id, error_queue):
64 | """ run process """
65 |
66 | setattr(args, 'gpu_ranks', [int(i) for i in args.gpu_ranks])
67 |
68 | try:
69 | gpu_rank = distributed.multi_init(device_id, args.world_size, args.gpu_ranks)
70 | print('gpu_rank %d' % gpu_rank)
71 | if gpu_rank != args.gpu_ranks[device_id]:
72 | raise AssertionError("An error occurred in \
73 | Distributed initialization")
74 |
75 | train_abs_single(args, device_id)
76 | except KeyboardInterrupt:
77 | pass # killed by parent, do nothing
78 | except Exception:
79 | # propagate exception to parent process, keeping original traceback
80 | import traceback
81 | error_queue.put((args.gpu_ranks[device_id], traceback.format_exc()))
82 |
83 |
84 | class ErrorHandler(object):
85 | """A class that listens for exceptions in child processes and propagates
86 | the tracebacks to the parent process."""
87 |
88 | def __init__(self, error_queue):
89 | """ init error handler """
90 | import signal
91 | import threading
92 | self.error_queue = error_queue
93 | self.children_pids = []
94 | self.error_thread = threading.Thread(
95 | target=self.error_listener, daemon=True)
96 | self.error_thread.start()
97 | signal.signal(signal.SIGUSR1, self.signal_handler)
98 |
99 | def add_child(self, pid):
100 | """ error handler """
101 | self.children_pids.append(pid)
102 |
103 | def error_listener(self):
104 | """ error listener """
105 | (rank, original_trace) = self.error_queue.get()
106 | self.error_queue.put((rank, original_trace))
107 | os.kill(os.getpid(), signal.SIGUSR1)
108 |
109 | def signal_handler(self, signalnum, stackframe):
110 | """ signal handler """
111 | for pid in self.children_pids:
112 | os.kill(pid, signal.SIGINT) # kill children processes
113 | (rank, original_trace) = self.error_queue.get()
114 | msg = """\n\n-- Tracebacks above this line can probably
115 | be ignored --\n\n"""
116 | msg += original_trace
117 | raise Exception(msg)
118 |
119 |
120 | def validate_abs(args, device_id):
121 | timestep = 0
122 | if (args.test_all):
123 | cp_files = sorted(glob.glob(os.path.join(args.model_path, 'model_step_*.pt')))
124 | cp_files.sort(key=os.path.getmtime)
125 | xent_lst = []
126 | for i, cp in enumerate(cp_files):
127 | step = int(cp.split('.')[-2].split('_')[-1])
128 | if (args.test_start_from != -1 and step < args.test_start_from):
129 | xent_lst.append((1e6, cp))
130 | continue
131 | xent = validate(args, device_id, cp, step)
132 | xent_lst.append((xent, cp))
133 | max_step = xent_lst.index(min(xent_lst))
134 | if (i - max_step > 10):
135 | break
136 | xent_lst = sorted(xent_lst, key=lambda x: x[0])[:5]
137 | logger.info('PPL %s' % str(xent_lst))
138 | for xent, cp in xent_lst:
139 | step = int(cp.split('.')[-2].split('_')[-1])
140 | test_abs(args, device_id, cp, step)
141 | else:
142 | while (True):
143 | cp_files = sorted(glob.glob(os.path.join(args.model_path, 'model_step_*.pt')))
144 | cp_files.sort(key=os.path.getmtime)
145 | if (cp_files):
146 | cp = cp_files[-1]
147 | time_of_cp = os.path.getmtime(cp)
148 | if (not os.path.getsize(cp) > 0):
149 | time.sleep(60)
150 | continue
151 | if (time_of_cp > timestep):
152 | timestep = time_of_cp
153 | step = int(cp.split('.')[-2].split('_')[-1])
154 | validate(args, device_id, cp, step)
155 | test_abs(args, device_id, cp, step)
156 |
157 | cp_files = sorted(glob.glob(os.path.join(args.model_path, 'model_step_*.pt')))
158 | cp_files.sort(key=os.path.getmtime)
159 | if (cp_files):
160 | cp = cp_files[-1]
161 | time_of_cp = os.path.getmtime(cp)
162 | if (time_of_cp > timestep):
163 | continue
164 | else:
165 | time.sleep(300)
166 |
167 |
168 | def validate(args, device_id, pt, step):
169 | device = "cpu" if args.visible_gpus == '-1' else "cuda"
170 | if (pt != ''):
171 | test_from = pt
172 | else:
173 | test_from = args.test_from
174 | logger.info('Loading checkpoint from %s' % test_from)
175 | checkpoint = torch.load(test_from, map_location=lambda storage, loc: storage)
176 | opt = vars(checkpoint['opt'])
177 | for k in opt.keys():
178 | if (k in model_flags):
179 | setattr(args, k, opt[k])
180 | print(args)
181 |
182 | model = AbsSummarizer(args, device, checkpoint)
183 | model.eval()
184 |
185 | valid_iter = data_loader.Dataloader(args, load_dataset(args, 'valid', shuffle=False),
186 | args.batch_size, device,
187 | shuffle=False, is_test=False)
188 |
189 | tokenizer = BertTokenizer.from_pretrained('bert-base-uncased', do_lower_case=True, cache_dir=args.temp_dir)
190 | symbols = {'BOS': tokenizer.vocab['[unused0]'], 'EOS': tokenizer.vocab['[unused1]'],
191 | 'PAD': tokenizer.vocab['[PAD]'], 'EOQ': tokenizer.vocab['[unused2]']}
192 |
193 | valid_loss = abs_loss(model.generator, symbols, model.vocab_size, train=False, device=device)
194 |
195 | trainer = build_trainer(args, device_id, model, None, valid_loss)
196 | stats = trainer.validate(valid_iter, step)
197 | return stats.xent()
198 |
199 |
200 | def test_abs(args, device_id, pt, step):
201 | device = "cpu" if args.visible_gpus == '-1' else "cuda"
202 | if (pt != ''):
203 | test_from = pt
204 | else:
205 | test_from = args.test_from
206 | logger.info('Loading checkpoint from %s' % test_from)
207 |
208 | checkpoint = torch.load(test_from, map_location=lambda storage, loc: storage)
209 | opt = vars(checkpoint['opt'])
210 | for k in opt.keys():
211 | if (k in model_flags):
212 | setattr(args, k, opt[k])
213 | print(args)
214 |
215 | model = AbsSummarizer(args, device, checkpoint)
216 | model.eval()
217 |
218 | test_iter = data_loader.Dataloader(args, load_dataset(args, 'test', shuffle=False),
219 | args.test_batch_size, device,
220 | shuffle=False, is_test=True)
221 | tokenizer = BertTokenizer.from_pretrained('bert-base-uncased', do_lower_case=True, cache_dir=args.temp_dir)
222 | symbols = {'BOS': tokenizer.vocab['[unused0]'], 'EOS': tokenizer.vocab['[unused1]'],
223 | 'PAD': tokenizer.vocab['[PAD]'], 'EOQ': tokenizer.vocab['[unused2]']}
224 | predictor = build_predictor(args, tokenizer, symbols, model, logger)
225 | predictor.translate(test_iter, step)
226 |
227 |
228 | def test_text_abs(args, device_id, pt, step):
229 | device = "cpu" if args.visible_gpus == '-1' else "cuda"
230 | if (pt != ''):
231 | test_from = pt
232 | else:
233 | test_from = args.test_from
234 | logger.info('Loading checkpoint from %s' % test_from)
235 |
236 | checkpoint = torch.load(test_from, map_location=lambda storage, loc: storage)
237 | opt = vars(checkpoint['opt'])
238 | for k in opt.keys():
239 | if (k in model_flags):
240 | setattr(args, k, opt[k])
241 | print(args)
242 |
243 | model = AbsSummarizer(args, device, checkpoint)
244 | model.eval()
245 |
246 | test_iter = data_loader.Dataloader(args, load_dataset(args, 'test', shuffle=False),
247 | args.test_batch_size, device,
248 | shuffle=False, is_test=True)
249 | tokenizer = BertTokenizer.from_pretrained('bert-base-uncased', do_lower_case=True, cache_dir=args.temp_dir)
250 | symbols = {'BOS': tokenizer.vocab['[unused0]'], 'EOS': tokenizer.vocab['[unused1]'],
251 | 'PAD': tokenizer.vocab['[PAD]'], 'EOQ': tokenizer.vocab['[unused2]']}
252 | predictor = build_predictor(args, tokenizer, symbols, model, logger)
253 | predictor.translate(test_iter, step)
254 |
255 |
256 | def baseline(args, cal_lead=False, cal_oracle=False):
257 | test_iter = data_loader.Dataloader(args, load_dataset(args, 'test', shuffle=False),
258 | args.batch_size, 'cpu',
259 | shuffle=False, is_test=True)
260 |
261 | trainer = build_trainer(args, '-1', None, None, None)
262 | #
263 | if (cal_lead):
264 | trainer.test(test_iter, 0, cal_lead=True)
265 | elif (cal_oracle):
266 | trainer.test(test_iter, 0, cal_oracle=True)
267 |
268 |
269 | def train_abs(args, device_id):
270 | if (args.world_size > 1):
271 | train_abs_multi(args)
272 | else:
273 | train_abs_single(args, device_id)
274 |
275 |
276 | def train_abs_single(args, device_id):
277 | init_logger(args.log_file)
278 | logger.info(str(args))
279 | device = "cpu" if args.visible_gpus == '-1' else "cuda"
280 | logger.info('Device ID %d' % device_id)
281 | logger.info('Device %s' % device)
282 | torch.manual_seed(args.seed)
283 | random.seed(args.seed)
284 | torch.backends.cudnn.deterministic = True
285 |
286 | if device_id >= 0:
287 | torch.cuda.set_device(device_id)
288 | torch.cuda.manual_seed(args.seed)
289 |
290 | if args.train_from != '':
291 | logger.info('Loading checkpoint from %s' % args.train_from)
292 | checkpoint = torch.load(args.train_from,
293 | map_location=lambda storage, loc: storage)
294 | opt = vars(checkpoint['opt'])
295 | for k in opt.keys():
296 | if (k in model_flags):
297 | setattr(args, k, opt[k])
298 | else:
299 | checkpoint = None
300 |
301 | if (args.load_from_extractive != ''):
302 | logger.info('Loading bert from extractive model %s' % args.load_from_extractive)
303 | bert_from_extractive = torch.load(args.load_from_extractive, map_location=lambda storage, loc: storage)
304 | bert_from_extractive = bert_from_extractive['model']
305 | else:
306 | bert_from_extractive = None
307 | torch.manual_seed(args.seed)
308 | random.seed(args.seed)
309 | torch.backends.cudnn.deterministic = True
310 |
311 | def train_iter_fct():
312 | return data_loader.Dataloader(args, load_dataset(args, 'train', shuffle=True), args.batch_size, device,
313 | shuffle=True, is_test=False)
314 |
315 | model = AbsSummarizer(args, device, checkpoint, bert_from_extractive)
316 | if (args.sep_optim):
317 | optim_bert = model_builder.build_optim_bert(args, model, checkpoint)
318 | optim_dec = model_builder.build_optim_dec(args, model, checkpoint)
319 | optim = [optim_bert, optim_dec]
320 | else:
321 | optim = [model_builder.build_optim(args, model, checkpoint)]
322 |
323 | logger.info(model)
324 |
325 | tokenizer = BertTokenizer.from_pretrained('bert-base-uncased', do_lower_case=True, cache_dir=args.temp_dir)
326 | symbols = {'BOS': tokenizer.vocab['[unused0]'], 'EOS': tokenizer.vocab['[unused1]'],
327 | 'PAD': tokenizer.vocab['[PAD]'], 'EOQ': tokenizer.vocab['[unused2]']}
328 |
329 | train_loss = abs_loss(model.generator, symbols, model.vocab_size, device, train=True,
330 | label_smoothing=args.label_smoothing)
331 |
332 | trainer = build_trainer(args, device_id, model, optim, train_loss)
333 |
334 | trainer.train(train_iter_fct, args.train_steps)
335 |
--------------------------------------------------------------------------------
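A note on the multi-GPU path above: `train_abs_multi`/`run` spawn one process per GPU, and `ErrorHandler` makes sure a crash in any child surfaces in the parent instead of silently hanging the job. Below is a minimal, standalone sketch of that same pattern; the module, function, and variable names are illustrative and not part of this repository, and it is POSIX-only because (like the original) it relies on SIGUSR1.

```python
# error_handler_sketch.py -- standalone sketch of the child-traceback propagation pattern.
import os
import signal
import threading
import time
import traceback
import multiprocessing as mp


def worker(rank, error_queue):
    """Child process: do some work, ship any traceback back to the parent."""
    try:
        if rank == 1:
            raise RuntimeError("simulated failure in child %d" % rank)
        time.sleep(5)  # stand-in for real training work
    except KeyboardInterrupt:
        pass  # killed by the parent, exit quietly
    except Exception:
        # Propagate the exception to the parent, keeping the original traceback.
        error_queue.put((rank, traceback.format_exc()))


class ErrorHandler(object):
    """Listens on a queue for child tracebacks and re-raises them in the parent."""

    def __init__(self, error_queue):
        self.error_queue = error_queue
        self.children_pids = []
        # A daemon thread blocks on the queue; when a traceback arrives it pokes
        # the main thread with SIGUSR1 so the exception is raised there.
        threading.Thread(target=self.listen, daemon=True).start()
        signal.signal(signal.SIGUSR1, self.raise_error)

    def add_child(self, pid):
        self.children_pids.append(pid)

    def listen(self):
        rank, trace = self.error_queue.get()
        self.error_queue.put((rank, trace))   # keep it around for the signal handler
        os.kill(os.getpid(), signal.SIGUSR1)  # interrupt the main thread

    def raise_error(self, signum, frame):
        for pid in self.children_pids:        # ask the remaining children to stop
            try:
                os.kill(pid, signal.SIGINT)
            except ProcessLookupError:
                pass                          # child already exited
        rank, trace = self.error_queue.get()
        raise Exception("child %d failed:\n%s" % (rank, trace))


if __name__ == "__main__":
    ctx = mp.get_context("spawn")
    queue = ctx.SimpleQueue()
    handler = ErrorHandler(queue)
    procs = []
    for rank in range(2):
        p = ctx.Process(target=worker, args=(rank, queue), daemon=True)
        p.start()
        handler.add_child(p.pid)
        procs.append(p)
    for p in procs:
        p.join()
```

Run as a script, child 1 fails: its traceback travels through the queue, the listener thread signals the parent, and the parent interrupts the surviving child and re-raises the traceback.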
/src/train_extractive.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | """
3 | Main training workflow
4 | """
5 | from __future__ import division
6 |
7 | import argparse
8 | import glob
9 | import os
10 | import random
11 | import signal
12 | import time
13 |
14 | import torch
15 |
16 | import distributed
17 | from models import data_loader, model_builder
18 | from models.data_loader import load_dataset
19 | from models.model_builder import ExtSummarizer
20 | from models.trainer_ext import build_trainer
21 | from others.logging import logger, init_logger
22 |
23 | model_flags = ['hidden_size', 'ff_size', 'heads', 'inter_layers', 'encoder', 'ff_actv', 'use_interval', 'rnn_size']
24 |
25 |
26 | def train_multi_ext(args):
27 | """ Spawns 1 process per GPU """
28 | init_logger()
29 |
30 | nb_gpu = args.world_size
31 | mp = torch.multiprocessing.get_context('spawn')
32 |
33 | # Create a thread to listen for errors in the child processes.
34 | error_queue = mp.SimpleQueue()
35 | error_handler = ErrorHandler(error_queue)
36 |
37 | # Train with multiprocessing.
38 | procs = []
39 | for i in range(nb_gpu):
40 | device_id = i
41 | procs.append(mp.Process(target=run, args=(args,
42 | device_id, error_queue,), daemon=True))
43 | procs[i].start()
44 | logger.info(" Starting process pid: %d " % procs[i].pid)
45 | error_handler.add_child(procs[i].pid)
46 | for p in procs:
47 | p.join()
48 |
49 |
50 | def run(args, device_id, error_queue):
51 | """ run process """
52 | setattr(args, 'gpu_ranks', [int(i) for i in args.gpu_ranks])
53 |
54 | try:
55 | gpu_rank = distributed.multi_init(device_id, args.world_size, args.gpu_ranks)
56 | print('gpu_rank %d' % gpu_rank)
57 | if gpu_rank != args.gpu_ranks[device_id]:
58 |             raise AssertionError("An error occurred in "
59 |                                  "distributed initialization")
60 |
61 | train_single_ext(args, device_id)
62 | except KeyboardInterrupt:
63 | pass # killed by parent, do nothing
64 | except Exception:
65 | # propagate exception to parent process, keeping original traceback
66 | import traceback
67 | error_queue.put((args.gpu_ranks[device_id], traceback.format_exc()))
68 |
69 |
70 | class ErrorHandler(object):
71 | """A class that listens for exceptions in children processes and propagates
72 | the tracebacks to the parent process."""
73 |
74 | def __init__(self, error_queue):
75 | """ init error handler """
76 | import signal
77 | import threading
78 | self.error_queue = error_queue
79 | self.children_pids = []
80 | self.error_thread = threading.Thread(
81 | target=self.error_listener, daemon=True)
82 | self.error_thread.start()
83 | signal.signal(signal.SIGUSR1, self.signal_handler)
84 |
85 | def add_child(self, pid):
86 | """ error handler """
87 | self.children_pids.append(pid)
88 |
89 | def error_listener(self):
90 | """ error listener """
91 | (rank, original_trace) = self.error_queue.get()
92 | self.error_queue.put((rank, original_trace))
93 | os.kill(os.getpid(), signal.SIGUSR1)
94 |
95 | def signal_handler(self, signalnum, stackframe):
96 | """ signal handler """
97 | for pid in self.children_pids:
98 | os.kill(pid, signal.SIGINT) # kill children processes
99 | (rank, original_trace) = self.error_queue.get()
100 | msg = """\n\n-- Tracebacks above this line can probably
101 | be ignored --\n\n"""
102 | msg += original_trace
103 | raise Exception(msg)
104 |
105 |
106 | def validate_ext(args, device_id):
107 | timestep = 0
108 | if (args.test_all):
109 | cp_files = sorted(glob.glob(os.path.join(args.model_path, 'model_step_*.pt')))
110 | cp_files.sort(key=os.path.getmtime)
111 | xent_lst = []
112 | for i, cp in enumerate(cp_files):
113 | step = int(cp.split('.')[-2].split('_')[-1])
114 | xent = validate(args, device_id, cp, step)
115 | xent_lst.append((xent, cp))
116 |             best_i = xent_lst.index(min(xent_lst))  # checkpoint with the lowest validation loss so far
117 |             if (i - best_i > 10):  # stop once 10 checkpoints pass without improvement
118 | break
119 | xent_lst = sorted(xent_lst, key=lambda x: x[0])[:3]
120 | logger.info('PPL %s' % str(xent_lst))
121 | for xent, cp in xent_lst:
122 | step = int(cp.split('.')[-2].split('_')[-1])
123 | test_ext(args, device_id, cp, step)
124 | else:
125 | while (True):
126 | cp_files = sorted(glob.glob(os.path.join(args.model_path, 'model_step_*.pt')))
127 | cp_files.sort(key=os.path.getmtime)
128 | if (cp_files):
129 | cp = cp_files[-1]
130 | time_of_cp = os.path.getmtime(cp)
131 | if (not os.path.getsize(cp) > 0):
132 | time.sleep(60)
133 | continue
134 | if (time_of_cp > timestep):
135 | timestep = time_of_cp
136 | step = int(cp.split('.')[-2].split('_')[-1])
137 | validate(args, device_id, cp, step)
138 | test_ext(args, device_id, cp, step)
139 |
140 | cp_files = sorted(glob.glob(os.path.join(args.model_path, 'model_step_*.pt')))
141 | cp_files.sort(key=os.path.getmtime)
142 | if (cp_files):
143 | cp = cp_files[-1]
144 | time_of_cp = os.path.getmtime(cp)
145 | if (time_of_cp > timestep):
146 | continue
147 | else:
148 | time.sleep(300)
149 |
150 |
151 | def validate(args, device_id, pt, step):
152 | device = "cpu" if args.visible_gpus == '-1' else "cuda"
153 | if (pt != ''):
154 | test_from = pt
155 | else:
156 | test_from = args.test_from
157 | logger.info('Loading checkpoint from %s' % test_from)
158 | checkpoint = torch.load(test_from, map_location=lambda storage, loc: storage)
159 | opt = vars(checkpoint['opt'])
160 | for k in opt.keys():
161 | if (k in model_flags):
162 | setattr(args, k, opt[k])
163 | print(args)
164 |
165 | model = ExtSummarizer(args, device, checkpoint)
166 | model.eval()
167 |
168 | valid_iter = data_loader.Dataloader(args, load_dataset(args, 'valid', shuffle=False),
169 | args.batch_size, device,
170 | shuffle=False, is_test=False)
171 | trainer = build_trainer(args, device_id, model, None)
172 | stats = trainer.validate(valid_iter, step)
173 | return stats.xent()
174 |
175 |
176 | def test_ext(args, device_id, pt, step):
177 | device = "cpu" if args.visible_gpus == '-1' else "cuda"
178 | if (pt != ''):
179 | test_from = pt
180 | else:
181 | test_from = args.test_from
182 | logger.info('Loading checkpoint from %s' % test_from)
183 | checkpoint = torch.load(test_from, map_location=lambda storage, loc: storage)
184 | opt = vars(checkpoint['opt'])
185 | for k in opt.keys():
186 | if (k in model_flags):
187 | setattr(args, k, opt[k])
188 | print(args)
189 |
190 | model = ExtSummarizer(args, device, checkpoint)
191 | model.eval()
192 |
193 | test_iter = data_loader.Dataloader(args, load_dataset(args, 'test', shuffle=False),
194 | args.test_batch_size, device,
195 | shuffle=False, is_test=True)
196 | trainer = build_trainer(args, device_id, model, None)
197 | trainer.test(test_iter, step)
198 |
199 | def train_ext(args, device_id):
200 | if (args.world_size > 1):
201 | train_multi_ext(args)
202 | else:
203 | train_single_ext(args, device_id)
204 |
205 |
206 | def train_single_ext(args, device_id):
207 | init_logger(args.log_file)
208 |
209 | device = "cpu" if args.visible_gpus == '-1' else "cuda"
210 | logger.info('Device ID %d' % device_id)
211 | logger.info('Device %s' % device)
212 | torch.manual_seed(args.seed)
213 | random.seed(args.seed)
214 | torch.backends.cudnn.deterministic = True
215 |
216 | if device_id >= 0:
217 | torch.cuda.set_device(device_id)
218 | torch.cuda.manual_seed(args.seed)
219 |
220 | torch.manual_seed(args.seed)
221 | random.seed(args.seed)
222 | torch.backends.cudnn.deterministic = True
223 |
224 | if args.train_from != '':
225 | logger.info('Loading checkpoint from %s' % args.train_from)
226 | checkpoint = torch.load(args.train_from,
227 | map_location=lambda storage, loc: storage)
228 | opt = vars(checkpoint['opt'])
229 | for k in opt.keys():
230 | if (k in model_flags):
231 | setattr(args, k, opt[k])
232 | else:
233 | checkpoint = None
234 |
235 | def train_iter_fct():
236 | return data_loader.Dataloader(args, load_dataset(args, 'train', shuffle=True), args.batch_size, device,
237 | shuffle=True, is_test=False)
238 |
239 | model = ExtSummarizer(args, device, checkpoint)
240 | optim = model_builder.build_optim(args, model, checkpoint)
241 |
242 | logger.info(model)
243 |
244 | trainer = build_trainer(args, device_id, model, optim)
245 | trainer.train(train_iter_fct, args.train_steps)
246 |
--------------------------------------------------------------------------------
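With `-test_all`, `validate_ext` (and `validate_abs` in the abstractive file above) walks every `model_step_*.pt` checkpoint in modification-time order, validates each one, stops once the best validation cross-entropy has not improved for 10 checkpoints, and then runs testing only on the lowest-loss checkpoints (top 3 here, top 5 in the abstractive case). A compact sketch of just that selection logic, with a hypothetical `validate_fn` standing in for `validate(args, device_id, cp, step)`:

```python
# Sketch of the -test_all checkpoint selection: validate in mtime order,
# early-stop after `patience` checkpoints without improvement, keep the k best.
import glob
import os


def select_checkpoints(model_path, validate_fn, patience=10, keep=3):
    cp_files = sorted(glob.glob(os.path.join(model_path, 'model_step_*.pt')),
                      key=os.path.getmtime)
    xent_lst = []
    for i, cp in enumerate(cp_files):
        step = int(cp.split('.')[-2].split('_')[-1])
        xent_lst.append((validate_fn(cp, step), cp))
        best_i = xent_lst.index(min(xent_lst))   # index of the lowest loss so far
        if i - best_i > patience:                # no improvement for `patience` checkpoints
            break
    # the `keep` checkpoints with the lowest validation cross-entropy
    return sorted(xent_lst, key=lambda x: x[0])[:keep]
```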
/src/translate/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/nlpyang/PreSumm/70b810e0f06d179022958dd35c1a3385fe87f28c/src/translate/__init__.py
--------------------------------------------------------------------------------
/src/translate/beam.py:
--------------------------------------------------------------------------------
1 | from __future__ import division
2 | import torch
3 | from translate import penalties
4 |
5 |
6 | class Beam(object):
7 | """
8 | Class for managing the internals of the beam search process.
9 |
10 | Takes care of beams, back pointers, and scores.
11 |
12 | Args:
13 | size (int): beam size
14 | pad, bos, eos (int): indices of padding, beginning, and ending.
15 | n_best (int): nbest size to use
16 | cuda (bool): use gpu
17 | global_scorer (:obj:`GlobalScorer`)
18 | """
19 |
20 | def __init__(self, size, pad, bos, eos,
21 | n_best=1, cuda=False,
22 | global_scorer=None,
23 | min_length=0,
24 | stepwise_penalty=False,
25 | block_ngram_repeat=0,
26 | exclusion_tokens=set()):
27 |
28 | self.size = size
29 | self.tt = torch.cuda if cuda else torch
30 |
31 | # The score for each translation on the beam.
32 | self.scores = self.tt.FloatTensor(size).zero_()
33 | self.all_scores = []
34 |
35 | # The backpointers at each time-step.
36 | self.prev_ks = []
37 |
38 | # The outputs at each time-step.
39 | self.next_ys = [self.tt.LongTensor(size)
40 | .fill_(pad)]
41 | self.next_ys[0][0] = bos
42 |
43 | # Has EOS topped the beam yet.
44 | self._eos = eos
45 | self.eos_top = False
46 |
47 | # The attentions (matrix) for each time.
48 | self.attn = []
49 |
50 | # Time and k pair for finished.
51 | self.finished = []
52 | self.n_best = n_best
53 |
54 | # Information for global scoring.
55 | self.global_scorer = global_scorer
56 | self.global_state = {}
57 |
58 | # Minimum prediction length
59 | self.min_length = min_length
60 |
61 | # Apply Penalty at every step
62 | self.stepwise_penalty = stepwise_penalty
63 | self.block_ngram_repeat = block_ngram_repeat
64 | self.exclusion_tokens = exclusion_tokens
65 |
66 | def get_current_state(self):
67 | "Get the outputs for the current timestep."
68 | return self.next_ys[-1]
69 |
70 | def get_current_origin(self):
71 | "Get the backpointers for the current timestep."
72 | return self.prev_ks[-1]
73 |
74 | def advance(self, word_probs, attn_out):
75 | """
76 |         Given the probabilities over words for every last beam `word_probs`
77 |         and the attention `attn_out`, compute and update the beam search.
78 |
79 | Parameters:
80 |
81 | * `word_probs`- probs of advancing from the last step (K x words)
82 | * `attn_out`- attention at the last step
83 |
84 | Returns: True if beam search is complete.
85 | """
86 | num_words = word_probs.size(1)
87 | if self.stepwise_penalty:
88 | self.global_scorer.update_score(self, attn_out)
89 | # force the output to be longer than self.min_length
90 | cur_len = len(self.next_ys)
91 | if cur_len < self.min_length:
92 | for k in range(len(word_probs)):
93 | word_probs[k][self._eos] = -1e20
94 | # Sum the previous scores.
95 | if len(self.prev_ks) > 0:
96 | beam_scores = word_probs + \
97 | self.scores.unsqueeze(1).expand_as(word_probs)
98 | # Don't let EOS have children.
99 | for i in range(self.next_ys[-1].size(0)):
100 | if self.next_ys[-1][i] == self._eos:
101 | beam_scores[i] = -1e20
102 |
103 | # Block ngram repeats
104 | if self.block_ngram_repeat > 0:
105 | ngrams = []
106 | le = len(self.next_ys)
107 | for j in range(self.next_ys[-1].size(0)):
108 | hyp, _ = self.get_hyp(le - 1, j)
109 | ngrams = set()
110 | fail = False
111 | gram = []
112 | for i in range(le - 1):
113 | # Last n tokens, n = block_ngram_repeat
114 | gram = (gram +
115 | [hyp[i].item()])[-self.block_ngram_repeat:]
116 | # Skip the blocking if it is in the exclusion list
117 | if set(gram) & self.exclusion_tokens:
118 | continue
119 | if tuple(gram) in ngrams:
120 | fail = True
121 | ngrams.add(tuple(gram))
122 | if fail:
123 | beam_scores[j] = -10e20
124 | else:
125 | beam_scores = word_probs[0]
126 | flat_beam_scores = beam_scores.view(-1)
127 | best_scores, best_scores_id = flat_beam_scores.topk(self.size, 0,
128 | True, True)
129 |
130 | self.all_scores.append(self.scores)
131 | self.scores = best_scores
132 |
133 | # best_scores_id is flattened beam x word array, so calculate which
134 | # word and beam each score came from
135 |         prev_k = best_scores_id // num_words  # integer division keeps beam indices integral
136 | self.prev_ks.append(prev_k)
137 | self.next_ys.append((best_scores_id - prev_k * num_words))
138 | self.attn.append(attn_out.index_select(0, prev_k))
139 | self.global_scorer.update_global_state(self)
140 |
141 | for i in range(self.next_ys[-1].size(0)):
142 | if self.next_ys[-1][i] == self._eos:
143 | global_scores = self.global_scorer.score(self, self.scores)
144 | s = global_scores[i]
145 | self.finished.append((s, len(self.next_ys) - 1, i))
146 |
147 | # End condition is when top-of-beam is EOS and no global score.
148 | if self.next_ys[-1][0] == self._eos:
149 | self.all_scores.append(self.scores)
150 | self.eos_top = True
151 |
152 | def done(self):
153 | return self.eos_top and len(self.finished) >= self.n_best
154 |
155 | def sort_finished(self, minimum=None):
156 | if minimum is not None:
157 | i = 0
158 | # Add from beam until we have minimum outputs.
159 | while len(self.finished) < minimum:
160 | global_scores = self.global_scorer.score(self, self.scores)
161 | s = global_scores[i]
162 | self.finished.append((s, len(self.next_ys) - 1, i))
163 | i += 1
164 |
165 | self.finished.sort(key=lambda a: -a[0])
166 | scores = [sc for sc, _, _ in self.finished]
167 | ks = [(t, k) for _, t, k in self.finished]
168 | return scores, ks
169 |
170 | def get_hyp(self, timestep, k):
171 | """
172 | Walk back to construct the full hypothesis.
173 | """
174 | hyp, attn = [], []
175 | for j in range(len(self.prev_ks[:timestep]) - 1, -1, -1):
176 | hyp.append(self.next_ys[j + 1][k])
177 | attn.append(self.attn[j][k])
178 | k = self.prev_ks[j][k]
179 | return hyp[::-1], torch.stack(attn[::-1])
180 |
181 |
182 | class GNMTGlobalScorer(object):
183 | """
184 | NMT re-ranking score from
185 | "Google's Neural Machine Translation System" :cite:`wu2016google`
186 |
187 | Args:
188 | alpha (float): length parameter
189 |        length_penalty (str): name of the length penalty to apply ("wu" or "avg")
190 | """
191 |
192 | def __init__(self, alpha, length_penalty):
193 | self.alpha = alpha
194 | penalty_builder = penalties.PenaltyBuilder(length_penalty)
195 |         # The summed log-probability of a finished hypothesis is divided
196 |         # by this length penalty when beams are re-ranked.
197 | self.length_penalty = penalty_builder.length_penalty()
198 |
199 | def score(self, beam, logprobs):
200 | """
201 | Rescores a prediction based on penalty functions
202 | """
203 | normalized_probs = self.length_penalty(beam,
204 | logprobs,
205 | self.alpha)
206 |
207 | return normalized_probs
208 |
209 |
--------------------------------------------------------------------------------
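The trickiest bookkeeping in `Beam.advance()` is splitting the flattened top-k indices back into a parent beam and a token id: the `(beam_size x vocab)` score matrix is flattened, `topk` returns flat indices, and each index decomposes as `idx // vocab` (which beam the hypothesis extends, stored in `prev_ks`) and `idx % vocab` (the chosen word, stored in `next_ys`, computed above as `best_scores_id - prev_k * num_words`). A toy example with made-up scores:

```python
# Toy illustration of the flat-index decomposition used in Beam.advance().
import torch

beam_size, vocab = 3, 5
beam_scores = torch.tensor([[0.1, 0.0, 0.0, 0.9, 0.0],
                            [0.0, 0.8, 0.0, 0.0, 0.0],
                            [0.0, 0.0, 0.0, 0.0, 0.7]])

best_scores, best_ids = beam_scores.view(-1).topk(beam_size, 0, True, True)
prev_k = best_ids // vocab          # parent beam of each survivor -> tensor([0, 1, 2])
next_y = best_ids - prev_k * vocab  # chosen word id               -> tensor([3, 1, 4])
print(best_scores, prev_k, next_y)
```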
/src/translate/penalties.py:
--------------------------------------------------------------------------------
1 | from __future__ import division
2 | import torch
3 |
4 |
5 | class PenaltyBuilder(object):
6 | """
7 |     Builds the length penalty function used by beam search.
8 | 
9 |     Args:
10 |         length_pen (str): name of the length penalty to use
11 |                           ("wu", "avg", or anything else for no penalty)
12 | """
13 |
14 | def __init__(self, length_pen):
15 | self.length_pen = length_pen
16 |
17 | def length_penalty(self):
18 | if self.length_pen == "wu":
19 | return self.length_wu
20 | elif self.length_pen == "avg":
21 | return self.length_average
22 | else:
23 | return self.length_none
24 |
25 | """
26 | Below are all the different penalty terms implemented so far
27 | """
28 |
29 |
30 | def length_wu(self, beam, logprobs, alpha=0.):
31 | """
32 | NMT length re-ranking score from
33 | "Google's Neural Machine Translation System" :cite:`wu2016google`.
34 | """
35 |
36 | modifier = (((5 + len(beam.next_ys)) ** alpha) /
37 | ((5 + 1) ** alpha))
38 | return (logprobs / modifier)
39 |
40 | def length_average(self, beam, logprobs, alpha=0.):
41 | """
42 |         Returns the per-token average log probability of a sequence.
43 | """
44 | return logprobs / len(beam.next_ys)
45 |
46 | def length_none(self, beam, logprobs, alpha=0., beta=0.):
47 | """
48 | Returns unmodified scores.
49 | """
50 | return logprobs
--------------------------------------------------------------------------------
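For a concrete feel of `length_wu`, the hypothesis score (a sum of log-probabilities, so negative) is divided by `((5 + length) ** alpha) / (6 ** alpha)`, where the code takes `len(beam.next_ys)` (which includes the initial BOS symbol) as the length. With `alpha = 0` the divisor is 1 and nothing changes; larger `alpha` shrinks the magnitude of long hypotheses' scores, so beam search penalises length less. A quick check with made-up numbers:

```python
# Numeric check of the Wu et al. length penalty implemented in length_wu().
def wu_modifier(length, alpha):
    return ((5 + length) ** alpha) / ((5 + 1) ** alpha)

logprob = -12.0                        # summed log-probability of a 20-token hypothesis
for alpha in (0.0, 0.6, 0.9):
    print(alpha, logprob / wu_modifier(20, alpha))
# alpha=0.0 -> -12.00, alpha=0.6 -> about -5.10, alpha=0.9 -> about -3.32
```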