├── data_util ├── config.py └── __pycache__ │ └── config.cpython-36.pyc ├── .gitignore ├── transformer ├── Constants.py ├── __init__.py ├── Modules.py ├── Optim.py ├── Layers.py ├── SubLayers.py ├── Beam.py ├── Translator.py └── Models.py ├── .idea └── vcs.xml ├── init.py ├── templates ├── summarization_form.html └── base.html ├── .github └── workflows │ └── pythonapp.yml ├── requirements.txt ├── models.py ├── summarize.py ├── README.md ├── preprocess.py ├── dataloader.py ├── server.py ├── train.py └── tokenizer.py /data_util/config.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | max_src_size = 512 4 | max_tgt_size = 512 -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | /data/ 2 | /__pycache__/ 3 | /.idea/ 4 | /transformer/__pycache__/ 5 | -------------------------------------------------------------------------------- /data_util/__pycache__/config.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IwasakiYuuki/Bert-abstractive-text-summarization/HEAD/data_util/__pycache__/config.cpython-36.pyc -------------------------------------------------------------------------------- /transformer/Constants.py: -------------------------------------------------------------------------------- 1 | 2 | PAD = 0 3 | UNK = 1 4 | BOS = 2 5 | EOS = 3 6 | 7 | PAD_WORD = '[PAD]' 8 | UNK_WORD = '[UNK]' 9 | BOS_WORD = '[CLS]' 10 | EOS_WORD = '[SEP]' 11 | -------------------------------------------------------------------------------- /.idea/vcs.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | -------------------------------------------------------------------------------- /init.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017 Victor Huang 2 | # Released under the MIT license 3 | # https://github.com/jadore801120/attention-is-all-you-need-pytorch/blob/master/LICENSE 4 | 5 | 6 | __copyright__ = 'Copyright (c) 2017 Victor Huang' 7 | -------------------------------------------------------------------------------- /transformer/__init__.py: -------------------------------------------------------------------------------- 1 | import transformer.Constants 2 | import transformer.Modules 3 | import transformer.Layers 4 | import transformer.SubLayers 5 | import transformer.Models 6 | import transformer.Translator 7 | import transformer.Beam 8 | import transformer.Optim 9 | 10 | __all__ = [ 11 | transformer.Constants, transformer.Modules, transformer.Layers, 12 | transformer.SubLayers, transformer.Models, transformer.Optim, 13 | transformer.Translator, transformer.Beam] 14 | -------------------------------------------------------------------------------- /templates/summarization_form.html: -------------------------------------------------------------------------------- 1 | {% extends "base.html" %} 2 | {% block content %} 3 |
4 |
5 |
6 |
7 | 
8 | 
10 | {% if summary %} 
11 | 
14 | 
15 | {{ summary }} 
16 | 
17 | {% endif %} 
18 | 
19 | 
20 |
21 |
22 |
23 | {% endblock %} -------------------------------------------------------------------------------- /transformer/Modules.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import numpy as np 4 | 5 | __author__ = "Yu-Hsiang Huang" 6 | 7 | class ScaledDotProductAttention(nn.Module): 8 | ''' Scaled Dot-Product Attention ''' 9 | 10 | def __init__(self, temperature, attn_dropout=0.1): 11 | super().__init__() 12 | self.temperature = temperature 13 | self.dropout = nn.Dropout(attn_dropout) 14 | self.softmax = nn.Softmax(dim=2) 15 | 16 | def forward(self, q, k, v, mask=None): 17 | 18 | attn = torch.bmm(q, k.transpose(1, 2)) 19 | attn = attn / self.temperature 20 | 21 | if mask is not None: 22 | attn = attn.masked_fill(mask, -np.inf) 23 | 24 | attn = self.softmax(attn) 25 | attn = self.dropout(attn) 26 | output = torch.bmm(attn, v) 27 | 28 | return output, attn 29 | -------------------------------------------------------------------------------- /.github/workflows/pythonapp.yml: -------------------------------------------------------------------------------- 1 | name: Python application 2 | 3 | on: [push] 4 | 5 | jobs: 6 | build: 7 | 8 | runs-on: ubuntu-latest 9 | 10 | steps: 11 | - uses: actions/checkout@v1 12 | - name: Set up Python 3.7 13 | uses: actions/setup-python@v1 14 | with: 15 | python-version: 3.7 16 | - name: Install dependencies 17 | run: | 18 | python -m pip install --upgrade pip 19 | pip install -r requirements.txt 20 | - name: Lint with flake8 21 | run: | 22 | pip install flake8 23 | # stop the build if there are Python syntax errors or undefined names 24 | flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics 25 | # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide 26 | flake8 . 
--count --exit-zero --max-complexity=10 --max-line-length=127 --statistics 27 | - name: Test with pytest 28 | run: | 29 | pip install pytest 30 | pytest 31 | -------------------------------------------------------------------------------- /templates/base.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | {{ title }} 5 | 6 | 7 | 8 | 9 | {% block content %} 10 | {% endblock %} 11 | 12 | 13 | 14 | 15 | -------------------------------------------------------------------------------- /transformer/Optim.py: -------------------------------------------------------------------------------- 1 | '''A wrapper class for optimizer ''' 2 | import numpy as np 3 | 4 | class ScheduledOptim(): 5 | '''A simple wrapper class for learning rate scheduling''' 6 | 7 | def __init__(self, optimizer, d_model, n_warmup_steps): 8 | self._optimizer = optimizer 9 | self.n_warmup_steps = n_warmup_steps 10 | self.n_current_steps = 0 11 | # self.init_lr = np.power(d_model, -0.5) 12 | self.init_lr = 0.0001 13 | 14 | def step_and_update_lr(self): 15 | "Step with the inner optimizer" 16 | self._update_learning_rate() 17 | self._optimizer.step() 18 | 19 | def zero_grad(self): 20 | "Zero out the gradients by the inner optimizer" 21 | self._optimizer.zero_grad() 22 | 23 | def _get_lr_scale(self): 24 | return np.min([ 25 | np.power((self.n_current_steps / 32), -0.5), 26 | # np.power(self.n_warmup_steps, -1.5) * (self.n_current_steps / 32)]) 27 | np.power(self.n_warmup_steps, -1.5) * (self.n_current_steps / 32)]) / (self.n_warmup_steps ** -0.5) 28 | 29 | def _update_learning_rate(self): 30 | ''' Learning rate scheduling per step ''' 31 | 32 | self.n_current_steps += 1 33 | lr = self.init_lr * self._get_lr_scale() 34 | 35 | for param_group in self._optimizer.param_groups: 36 | param_group['lr'] = lr 37 | 38 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | flask 2 | absl-py==0.8.0 3 | asn1crypto==0.24.0 4 | astor==0.8.0 5 | backcall==0.1.0 6 | beautifulsoup4==4.6.0 7 | boto3==1.9.241 8 | botocore==1.12.241 9 | certifi==2018.4.16 10 | cffi==1.11.5 11 | chardet==3.0.4 12 | conda==4.5.8 13 | conda-build==3.12.0 14 | cryptography==2.3 15 | Cython==0.29.13 16 | decorator==4.3.0 17 | docutils==0.15.2 18 | filelock==3.0.4 19 | gast==0.2.2 20 | glob2==0.6 21 | google-pasta==0.1.7 22 | grpcio==1.24.1 23 | h5py==2.10.0 24 | idna==2.6 25 | ipython==6.4.0 26 | ipython-genutils==0.2.0 27 | jedi==0.12.1 28 | Jinja2==2.10 29 | jmespath==0.9.4 30 | Keras-Applications==1.0.8 31 | Keras-Preprocessing==1.1.0 32 | Markdown==3.1.1 33 | MarkupSafe==1.0 34 | mecab-python3==0.996.2 35 | mkl-fft==1.0.4 36 | mkl-random==1.0.1 37 | numpy==1.17.2 38 | olefile==0.45.1 39 | opt-einsum==3.1.0 40 | pandas==0.25.1 41 | parso==0.3.1 42 | pexpect==4.6.0 43 | pickleshare==0.7.4 44 | Pillow==6.2.0 45 | pkginfo==1.4.2 46 | prompt-toolkit==1.0.15 47 | protobuf==3.9.2 48 | psutil==5.4.6 49 | ptyprocess==0.6.0 50 | pycosat==0.6.3 51 | pycparser==2.18 52 | Pygments==2.2.0 53 | pyknp==0.4.1 54 | pyOpenSSL==18.0.0 55 | PySocks==1.6.8 56 | python-dateutil==2.8.0 57 | pytorch-pretrained-bert==0.6.2 58 | pytz==2019.2 59 | PyYAML==5.1 60 | regex==2019.8.19 61 | requests==2.20.0 62 | ruamel-yaml==0.15.37 63 | s3transfer==0.2.1 64 | scipy==1.1.0 65 | simplegeneric==0.8.1 66 | six==1.11.0 67 | tensorboard==2.0.0 68 | tensorboardX==1.8 69 | tensorflow==2.0.0 70 | tensorflow-estimator==2.0.0 71 | 
termcolor==1.1.0 72 | torch==0.4.1 73 | torchvision==0.2.1 74 | tqdm==4.36.1 75 | traitlets==4.3.2 76 | urllib3==1.24.2 77 | wcwidth==0.1.7 78 | Werkzeug==0.16.0 79 | wrapt==1.11.2 80 | -------------------------------------------------------------------------------- /transformer/Layers.py: -------------------------------------------------------------------------------- 1 | ''' Define the Layers ''' 2 | import torch.nn as nn 3 | from transformer.SubLayers import MultiHeadAttention, PositionwiseFeedForward 4 | 5 | __author__ = "Yu-Hsiang Huang" 6 | 7 | 8 | class EncoderLayer(nn.Module): 9 | ''' Compose with two layers ''' 10 | 11 | def __init__(self, d_model, d_inner, n_head, d_k, d_v, dropout=0.1): 12 | super(EncoderLayer, self).__init__() 13 | self.slf_attn = MultiHeadAttention( 14 | n_head, d_model, d_k, d_v, dropout=dropout) 15 | self.pos_ffn = PositionwiseFeedForward(d_model, d_inner, dropout=dropout) 16 | 17 | def forward(self, enc_input, non_pad_mask=None, slf_attn_mask=None): 18 | enc_output, enc_slf_attn = self.slf_attn( 19 | enc_input, enc_input, enc_input, mask=slf_attn_mask) 20 | enc_output *= non_pad_mask 21 | 22 | enc_output = self.pos_ffn(enc_output) 23 | enc_output *= non_pad_mask 24 | 25 | return enc_output, enc_slf_attn 26 | 27 | 28 | class DecoderLayer(nn.Module): 29 | ''' Compose with three layers ''' 30 | 31 | def __init__(self, d_model, d_inner, n_head, d_k, d_v, dropout=0.1): 32 | super(DecoderLayer, self).__init__() 33 | self.slf_attn = MultiHeadAttention(n_head, d_model, d_k, d_v, dropout=dropout) 34 | self.enc_attn = MultiHeadAttention(n_head, d_model, d_k, d_v, dropout=dropout) 35 | self.pos_ffn = PositionwiseFeedForward(d_model, d_inner, dropout=dropout) 36 | 37 | def forward(self, dec_input, enc_output, non_pad_mask=None, slf_attn_mask=None, dec_enc_attn_mask=None): 38 | dec_output, dec_slf_attn = self.slf_attn( 39 | dec_input, dec_input, dec_input, mask=slf_attn_mask) 40 | dec_output *= non_pad_mask 41 | 42 | dec_output, dec_enc_attn = self.enc_attn( 43 | dec_output, enc_output, enc_output, mask=dec_enc_attn_mask) 44 | dec_output *= non_pad_mask 45 | 46 | dec_output = self.pos_ffn(dec_output) 47 | dec_output *= non_pad_mask 48 | 49 | return dec_output, dec_slf_attn, dec_enc_attn 50 | -------------------------------------------------------------------------------- /models.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from pytorch_pretrained_bert.modeling import BertModel, BertConfig 4 | from transformer.Models import Decoder 5 | 6 | 7 | class AbstractiveTextSummarizationUsingBert(nn.Module): 8 | def __init__(self, bert_model_path, n_tgt_vocab, len_max_seq, d_word_vec=768, d_model=768, d_inner=3072, 9 | n_layers=12, n_head=12, d_k=64, d_v=64, dropout=0.1): 10 | 11 | super().__init__() 12 | 13 | self.encoder = BertModel.from_pretrained(bert_model_path) 14 | self.config = BertConfig(bert_model_path+'bert_config.json') 15 | self.decoder = Decoder( 16 | n_tgt_vocab=n_tgt_vocab, len_max_seq=len_max_seq, 17 | d_word_vec=d_word_vec, d_model=d_model, d_inner=d_inner, 18 | n_layers=n_layers, n_head=n_head, d_k=d_k, d_v=d_v, 19 | dropout=dropout) 20 | self.tgt_word_prj = nn.Linear(d_model, n_tgt_vocab, bias=False) 21 | nn.init.xavier_normal_(self.tgt_word_prj.weight) 22 | self.tgt_word_prj.weight = self.decoder.tgt_word_emb.weight 23 | self.x_logit_scale = (d_model ** -0.5) 24 | self.o_l = nn.Linear(d_model, 512, bias=False) 25 | self.h_l = nn.Linear(512, 1, bias=True) 26 | 
nn.init.xavier_normal_(self.o_l.weight) 27 | nn.init.xavier_normal_(self.h_l.weight) 28 | self.a_l_1 = nn.Linear(d_model, 512, bias=False) 29 | self.a_l_2 = nn.Linear(d_model, 512, bias=False) 30 | nn.init.xavier_normal_(self.a_l_1.weight) 31 | nn.init.xavier_normal_(self.a_l_2.weight) 32 | 33 | def forward(self, src_seq, src_sen, tgt_seq, tgt_pos): 34 | tgt_seq, tgt_pos = tgt_seq[:, :-1], tgt_pos[:, :-1] 35 | 36 | enc_output, _ = self.encoder(src_seq, src_sen, output_all_encoded_layers=False) 37 | dec_output, *_ = self.decoder(tgt_seq, tgt_pos, src_seq, enc_output) 38 | 39 | # o = self.o_l(dec_output) 40 | # p_gen = torch.sigmoid(self.h_l(o).view(-1, 1)) 41 | 42 | seq_logit = self.tgt_word_prj(dec_output) * self.x_logit_scale 43 | # a = self.a_l_1(dec_output) 44 | # a = torch.bmm(a, enc_output) 45 | # a = self.a_l_2(a) 46 | 47 | return seq_logit.view(-1, seq_logit.size(2)) 48 | -------------------------------------------------------------------------------- /summarize.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.utils.data 3 | import argparse 4 | from tqdm import tqdm 5 | 6 | from dataloader import src_collate_fn, TextSummarizationDataset 7 | from transformer.Translator import Summarizer 8 | 9 | 10 | def main(): 11 | 12 | parser = argparse.ArgumentParser(description='translate.py') 13 | 14 | parser.add_argument('-model', required=True, 15 | help='Path to model .pt file') 16 | parser.add_argument('-src', required=True, 17 | help='Source sequence to decode (one line per sequence)') 18 | parser.add_argument('-vocab', required=True, 19 | help='Source sequence to decode (one line per sequence)') 20 | parser.add_argument('-output', default='pred.txt', 21 | help="""Path to output the predictions (each line will 22 | be the decoded sequence""") 23 | parser.add_argument('-beam_size', type=int, default=5, 24 | help='Beam size') 25 | parser.add_argument('-batch_size', type=int, default=30, 26 | help='Batch size') 27 | parser.add_argument('-n_best', type=int, default=1, 28 | help="""If verbose is set, will output the n_best 29 | decoded sentences""") 30 | parser.add_argument('-no_cuda', action='store_true') 31 | 32 | opt = parser.parse_args() 33 | opt.cuda = not opt.no_cuda 34 | 35 | # Prepare DataLoader 36 | data = torch.load(opt.src) 37 | data['settings'].cuda = opt.cuda 38 | 39 | test_loader = torch.utils.data.DataLoader( 40 | TextSummarizationDataset( 41 | # src_word2idx=preprocess_data['dict']['src'], 42 | src_word2idx=data['dict']['src'], 43 | # tgt_word2idx=preprocess_data['dict']['tgt'], 44 | tgt_word2idx=data['dict']['tgt'], 45 | # src_insts=test_src_insts), 46 | src_insts=data['valid']['src'][0:10]), 47 | num_workers=2, 48 | batch_size=opt.batch_size, 49 | collate_fn=src_collate_fn) 50 | translator = Summarizer(opt) 51 | with open(opt.output, 'w') as f: 52 | for batch in tqdm(test_loader, mininterval=2, desc=' - (Test)', leave=False): 53 | all_hyp, all_scores = translator.translate_batch(*batch) 54 | for idx_seqs in all_hyp: 55 | for idx_seq in idx_seqs: 56 | print(len(idx_seq)) 57 | pred_line = ''.join([test_loader.dataset.tgt_idx2word[idx] for idx in idx_seq]) 58 | f.write('[@]' + pred_line + '\n\n') 59 | print('[Info] Finished.') 60 | 61 | 62 | if __name__ == "__main__": 63 | main() 64 | -------------------------------------------------------------------------------- /transformer/SubLayers.py: -------------------------------------------------------------------------------- 1 | ''' Define the sublayers in 
encoder/decoder layer ''' 2 | import numpy as np 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from transformer.Modules import ScaledDotProductAttention 6 | 7 | __author__ = "Yu-Hsiang Huang" 8 | 9 | class MultiHeadAttention(nn.Module): 10 | ''' Multi-Head Attention module ''' 11 | 12 | def __init__(self, n_head, d_model, d_k, d_v, dropout=0.1): 13 | super().__init__() 14 | 15 | self.n_head = n_head 16 | self.d_k = d_k 17 | self.d_v = d_v 18 | 19 | self.w_qs = nn.Linear(d_model, n_head * d_k) 20 | self.w_ks = nn.Linear(d_model, n_head * d_k) 21 | self.w_vs = nn.Linear(d_model, n_head * d_v) 22 | nn.init.normal_(self.w_qs.weight, mean=0, std=np.sqrt(2.0 / (d_model + d_k))) 23 | nn.init.normal_(self.w_ks.weight, mean=0, std=np.sqrt(2.0 / (d_model + d_k))) 24 | nn.init.normal_(self.w_vs.weight, mean=0, std=np.sqrt(2.0 / (d_model + d_v))) 25 | 26 | self.attention = ScaledDotProductAttention(temperature=np.power(d_k, 0.5)) 27 | self.layer_norm = nn.LayerNorm(d_model) 28 | 29 | self.fc = nn.Linear(n_head * d_v, d_model) 30 | nn.init.xavier_normal_(self.fc.weight) 31 | 32 | self.dropout = nn.Dropout(dropout) 33 | 34 | 35 | def forward(self, q, k, v, mask=None): 36 | 37 | d_k, d_v, n_head = self.d_k, self.d_v, self.n_head 38 | 39 | sz_b, len_q, _ = q.size() 40 | sz_b, len_k, _ = k.size() 41 | sz_b, len_v, _ = v.size() 42 | 43 | residual = q 44 | 45 | q = self.w_qs(q).view(sz_b, len_q, n_head, d_k) 46 | k = self.w_ks(k).view(sz_b, len_k, n_head, d_k) 47 | v = self.w_vs(v).view(sz_b, len_v, n_head, d_v) 48 | 49 | q = q.permute(2, 0, 1, 3).contiguous().view(-1, len_q, d_k) # (n*b) x lq x dk 50 | k = k.permute(2, 0, 1, 3).contiguous().view(-1, len_k, d_k) # (n*b) x lk x dk 51 | v = v.permute(2, 0, 1, 3).contiguous().view(-1, len_v, d_v) # (n*b) x lv x dv 52 | 53 | mask = mask.repeat(n_head, 1, 1) # (n*b) x .. x .. 54 | output, attn = self.attention(q, k, v, mask=mask) 55 | 56 | output = output.view(n_head, sz_b, len_q, d_v) 57 | output = output.permute(1, 2, 0, 3).contiguous().view(sz_b, len_q, -1) # b x lq x (n*dv) 58 | 59 | output = self.dropout(self.fc(output)) 60 | output = self.layer_norm(output + residual) 61 | 62 | return output, attn 63 | 64 | class PositionwiseFeedForward(nn.Module): 65 | ''' A two-feed-forward-layer module ''' 66 | 67 | def __init__(self, d_in, d_hid, dropout=0.1): 68 | super().__init__() 69 | self.w_1 = nn.Conv1d(d_in, d_hid, 1) # position-wise 70 | self.w_2 = nn.Conv1d(d_hid, d_in, 1) # position-wise 71 | self.layer_norm = nn.LayerNorm(d_in) 72 | self.dropout = nn.Dropout(dropout) 73 | 74 | def forward(self, x): 75 | residual = x 76 | output = x.transpose(1, 2) 77 | output = self.w_2(F.relu(self.w_1(output))) 78 | output = output.transpose(1, 2) 79 | output = self.dropout(output) 80 | output = self.layer_norm(output + residual) 81 | return output 82 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Abstractive text summarization using BERT 2 | This is the models using BERT (refer the paper [Pretraining-Based Natural Language Generation for Text Summarization 3 | ](https://arxiv.org/abs/1902.09243) ) for one of the NLP(Natural Language Processing) task, abstractive text summarization. 4 | 5 | ## Requirements 6 | - Python 3.6.5+ 7 | - Pytorch 0.4.1+ 8 | - Tensorflow 9 | - Pandas 10 | - tqdm 11 | - Numpy 12 | - MeCab 13 | - Tensorboard X and others... 
14 | 
15 | All packages used here can be installed with pip as follows: 
16 | 
17 | ~~~ 
18 | pip install -r requirements.txt 
19 | ~~~ 
20 | 
21 | ## Docker 
22 | If you train the model on a GPU, the easiest route is to use the [PyTorch Docker images](https://hub.docker.com/r/pytorch/pytorch) from Docker Hub. 
23 | 
24 | In this study, pytorch/pytorch:0.4.1-cuda9-cudnn7-devel (2.62 GB) was used. 
25 | 
26 | ## Before using 
27 | Before using the code, follow the steps below. 
28 | 1. Make a directory named "data/checkpoint" under the project root, 
29 | and put the BERT model, vocabulary file, and config file for BERT there. 
30 | These files can be downloaded [here](http://nlp.ist.i.kyoto-u.ac.jp/index.php?BERT%E6%97%A5%E6%9C%AC%E8%AA%9EPretrained%E3%83%A2%E3%83%87%E3%83%AB). 
31 | 
32 | 2. Put the data file for training and validation under /workspace/data/. The format is as follows: 
33 | 
34 | ```preprocess.py 
35 | data = { 
36 |     'settings': opt, 
37 |     'dict': { 
38 |         'src': text2token, 
39 |         'tgt': text2token}, 
40 |     'train': { 
41 |         'src': content[:100000], 
42 |         'tgt': summary[:100000]}, 
43 |     'valid': { 
44 |         'src': content[100000:], 
45 |         'tgt': summary[100000:]}} 
46 | torch.save(data, opt.save_data) 
47 | ``` 
48 | 
49 | The overall directory structure is as follows: 
50 | ``` 
51 | `-- data                         # under workspace 
52 |     |-- checkpoint 
53 |     |   |-- bert_config.json    # BERT config file 
54 |     |   |-- pytorch_model.bin   # BERT model file 
55 |     |   `-- vocab.txt           # vocabulary file 
56 |     `-- preprocessed_data.data  # train and valid data file 
57 | ``` 
58 | ## Setting 
59 | |Name |Value | 
60 | |---|---| 
61 | |Encoder |BERT | 
62 | |Decoder |Transformer (decoder only) | 
63 | |Embed dimension |768 | 
64 | |Hidden dimension |3072 | 
65 | |Encoder layers |12 | 
66 | |Decoder layers |8 | 
67 | |Optimizer |Adam | 
68 | |Learning rate |init=0.0001 | 
69 | |Warmup steps |4000 | 
70 | |Input max length |512 | 
71 | |Batch size |4 | 
72 | 
73 | ## Usage 
74 | ### Train the model 
75 | ``` 
76 | python train.py -data data/preprocessed_data.data -bert_path data/checkpoint/ -proj_share_weight -label_smoothing -batch_size 4 -epoch 10 -save_model trained -save_mode best 
77 | ``` 
78 | ### Generate summaries with the trained model 
79 | ``` 
80 | python summarize.py -model data/checkpoint/trained/trained.chkpt -src data/preprocessed_data.data -vocab data/checkpoint/vocab.txt -output pred.txt 
81 | ``` 
82 | 
83 | ## Result 
84 | ### TensorboardX image 
85 | ![image](https://user-images.githubusercontent.com/24263438/66286505-cd044800-e90c-11e9-8bb8-659173def48d.png) 
86 | 
87 | 
88 | ## TODO 
89 | - Evaluate the model with scores such as ROUGE-N 
90 | - Add some examples 
91 | 
92 | ## Acknowledgements 
93 | - The repository structure and much of the code are borrowed from [jadore801120/attention-is-all-you-need-pytorch](https://github.com/jadore801120/attention-is-all-you-need-pytorch).
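
For reference, the `preprocessed_data.data` file used above is produced by `preprocess.py`. A minimal sketch of that step, assuming the source articles sit in a SQLite database with an `Article` table holding `content` and `summary` columns (the database filename below is a placeholder):

```
# Hypothetical database path; the vocab path and max length follow the settings above.
python preprocess.py -vocab data/checkpoint/vocab.txt -data data/articles.db -save_data data/preprocessed_data.data --max_word_seq_len 512
```

The repository also includes a small Flask service (`server.py`) that wraps the trained model and listens on port 6006. Assuming it is running and the model, data, and vocabulary paths hard-coded near its top point at real files, a request might look like the sketch below: the body is a JSON object with a `source_texts` list, and the response carries a `summaries` list.

```
# Placeholder input text; replace with the article(s) to summarize.
curl -X POST http://localhost:6006/ -H "Content-Type: application/json" -d '{"source_texts": ["<article text>"]}'
```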
-------------------------------------------------------------------------------- /preprocess.py: -------------------------------------------------------------------------------- 1 | import io, sys 2 | sys.stdout = io.TextIOWrapper(sys.stdout.buffer, encoding='utf-8') 3 | import sqlite3 4 | import argparse 5 | from tqdm import tqdm 6 | import pandas.io.sql as psql 7 | import transformer.Constants as Constants 8 | import torch 9 | #import MeCab 10 | from tokenizer import FullTokenizer 11 | #m = MeCab.Tagger('-Owakati') 12 | 13 | def build_vocab(path_to_file): 14 | with open(path_to_file, encoding='utf-8') as f: 15 | vocab = f.readlines() 16 | token2text = {k: v.rstrip() for k, v in enumerate(vocab)} 17 | text2token = {v: k for k, v in token2text.items()} 18 | 19 | return text2token, token2text 20 | 21 | 22 | def get_content_summary_from_df(df): 23 | content = df['content'].values.tolist() 24 | summary = df['summary'].values.tolist() 25 | 26 | return content, summary 27 | 28 | 29 | def convert_text_to_token(content, summary, text2token, max_len, tokenizer): 30 | tokens_content = [] 31 | tokens_summary = [] 32 | for d_content, d_summary in tqdm(zip(content, summary), ascii=True, total=len(content)): 33 | tokens_content.append(convert_text_to_token_seq(d_content, text2token, max_len, tokenizer)) 34 | tokens_summary.append(convert_text_to_token_seq(d_summary, text2token, max_len, tokenizer)) 35 | 36 | return tokens_content, tokens_summary 37 | 38 | 39 | def convert_text_to_token_seq(text, text2token, max_len, tokenizer): 40 | if len(text) > 2000: 41 | text = text[:2000] 42 | splited_text = tokenizer.tokenize(convert_num_half_to_full(text.replace('。\n', '\n').replace('\n', '。\n'))) 43 | if len(splited_text) > (max_len - 2): 44 | splited_text = splited_text[:max_len-2] 45 | splited_text = [Constants.BOS] + \ 46 | [text2token.get(i, Constants.UNK) for i in splited_text] + \ 47 | [Constants.EOS] 48 | return splited_text 49 | 50 | 51 | def convert_num_half_to_full(text): 52 | table = str.maketrans({ 53 | '0': '0', 54 | '1': '1', 55 | '2': '2', 56 | '3': '3', 57 | '4': '4', 58 | '5': '5', 59 | '6': '6', 60 | '7': '7', 61 | '8': '8', 62 | '9': '9', 63 | }) 64 | return text.translate(table) 65 | 66 | 67 | def main(): 68 | parser = argparse.ArgumentParser() 69 | parser.add_argument('-vocab', required=True) 70 | parser.add_argument('-data', required=True) 71 | parser.add_argument('-save_data', required=True) 72 | parser.add_argument('--max_word_seq_len', required=True, type=int) 73 | opt = parser.parse_args() 74 | opt.max_token_seq_len = opt.max_word_seq_len 75 | 76 | tokenizer = FullTokenizer(opt.vocab) 77 | connection = sqlite3.connect(opt.data) 78 | cursor = connection.cursor() 79 | df = psql.read_sql("SELECT * FROM Article;", connection) 80 | print('Finished reading db file.') 81 | text2token, token2text = build_vocab(opt.vocab) 82 | print('Finished building vocab.') 83 | content, summary = get_content_summary_from_df(df) 84 | content, summary = convert_text_to_token(content, summary, text2token, opt.max_word_seq_len, tokenizer) 85 | data = { 86 | 'settings': opt, 87 | 'dict': { 88 | 'src': text2token, 89 | 'tgt': text2token}, 90 | 'train': { 91 | 'src': content[:100000], 92 | 'tgt': summary[:100000]}, 93 | 'valid': { 94 | 'src': content[100000:], 95 | 'tgt': summary[100000:]}} 96 | torch.save(data, opt.save_data) 97 | 98 | 99 | if __name__ == '__main__': 100 | main() 101 | -------------------------------------------------------------------------------- /dataloader.py: 
-------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.utils.data 4 | 5 | from transformer import Constants 6 | from data_util import config 7 | 8 | 9 | def paired_collate_fn(insts): 10 | src_insts, tgt_insts = list(zip(*insts)) 11 | src_insts = src_collate_fn(src_insts) 12 | tgt_insts = tgt_collate_fn(tgt_insts) 13 | return (*src_insts, *tgt_insts) 14 | 15 | 16 | def src_collate_fn(insts): 17 | ''' Pad the instance to the max seq length in batch ''' 18 | 19 | max_len = config.max_src_size 20 | 21 | batch_seq = np.array([ 22 | inst + [Constants.PAD] * (max_len - len(inst)) 23 | for inst in insts]) 24 | 25 | batch_pos = [] 26 | for inst in batch_seq: 27 | cnt = 0 28 | buf = [] 29 | for i in inst: 30 | buf.append(cnt) 31 | if i == 7: 32 | cnt = cnt ^ 1 33 | batch_pos.append(buf) 34 | batch_pos = np.array(batch_pos) 35 | 36 | batch_seq = torch.LongTensor(batch_seq) 37 | batch_pos = torch.LongTensor(batch_pos) 38 | 39 | return batch_seq, batch_pos 40 | 41 | 42 | def tgt_collate_fn(insts): 43 | ''' Pad the instance to the max seq length in batch ''' 44 | 45 | max_len = max(len(inst) for inst in insts) 46 | if max_len > config.max_tgt_size: 47 | max_len = config.max_tgt_size 48 | 49 | batch_seq = np.array([ 50 | inst + [Constants.PAD] * (max_len - len(inst)) 51 | for inst in insts]) 52 | 53 | batch_pos = np.array([ 54 | [pos_i+1 if w_i != Constants.PAD else 0 55 | for pos_i, w_i in enumerate(inst)] for inst in batch_seq]) 56 | 57 | batch_seq = torch.LongTensor(batch_seq) 58 | batch_pos = torch.LongTensor(batch_pos) 59 | 60 | return batch_seq, batch_pos 61 | 62 | 63 | class TextSummarizationDataset(torch.utils.data.Dataset): 64 | def __init__( 65 | self, src_word2idx, tgt_word2idx, 66 | src_insts=None, tgt_insts=None): 67 | 68 | assert src_insts 69 | assert not tgt_insts or (len(src_insts) == len(tgt_insts)) 70 | 71 | src_idx2word = {idx:word for word, idx in src_word2idx.items()} 72 | self._src_word2idx = src_word2idx 73 | self._src_idx2word = src_idx2word 74 | self._src_insts = src_insts 75 | 76 | tgt_idx2word = {idx:word for word, idx in tgt_word2idx.items()} 77 | self._tgt_word2idx = tgt_word2idx 78 | self._tgt_idx2word = tgt_idx2word 79 | self._tgt_insts = tgt_insts 80 | 81 | @property 82 | def n_insts(self): 83 | """ Property for dataset size """ 84 | return len(self._src_insts) 85 | 86 | @property 87 | def src_vocab_size(self): 88 | """ Property for vocab size """ 89 | return len(self._src_word2idx) 90 | 91 | @property 92 | def tgt_vocab_size(self): 93 | """ Property for vocab size """ 94 | return len(self._tgt_word2idx) 95 | 96 | @property 97 | def src_word2idx(self): 98 | """ Property for word dictionary """ 99 | return self._src_word2idx 100 | 101 | @property 102 | def tgt_word2idx(self): 103 | """ Property for word dictionary """ 104 | return self._tgt_word2idx 105 | 106 | @property 107 | def src_idx2word(self): 108 | """ Property for index dictionary """ 109 | return self._src_idx2word 110 | 111 | @property 112 | def tgt_idx2word(self): 113 | """ Property for index dictionary """ 114 | return self._tgt_idx2word 115 | 116 | def __len__(self): 117 | return self.n_insts 118 | 119 | def __getitem__(self, idx): 120 | if self._tgt_insts: 121 | return self._src_insts[idx], self._tgt_insts[idx] 122 | return self._src_insts[idx] 123 | 124 | -------------------------------------------------------------------------------- /transformer/Beam.py: 
-------------------------------------------------------------------------------- 1 | """ Manage beam search info structure. 2 | 3 | Heavily borrowed from OpenNMT-py. 4 | For code in OpenNMT-py, please check the following link: 5 | https://github.com/OpenNMT/OpenNMT-py/blob/master/onmt/Beam.py 6 | """ 7 | 8 | import torch 9 | import transformer.Constants as Constants 10 | 11 | 12 | class Beam(object): 13 | """ Beam search """ 14 | 15 | def __init__(self, size, block_ngram_repeat=3, exclusion_tokens=set(), device=False): 16 | 17 | self.size = size 18 | self._done = False 19 | self.block_ngram_repeat = block_ngram_repeat 20 | self.exclusion_tokens = exclusion_tokens 21 | 22 | # The score for each translation on the beam. 23 | self.scores = torch.zeros((size,), dtype=torch.float, device=device) 24 | self.all_scores = [] 25 | 26 | # The backpointers at each time-step. 27 | self.prev_ks = [] 28 | 29 | # The outputs at each time-step. 30 | self.next_ys = [torch.full((size,), Constants.PAD, dtype=torch.long, device=device)] 31 | self.next_ys[0][0] = Constants.BOS 32 | 33 | def get_current_state(self): 34 | """Get the outputs for the current timestep.""" 35 | return self.get_tentative_hypothesis() 36 | 37 | def get_current_origin(self): 38 | """Get the backpointers for the current timestep.""" 39 | return self.prev_ks[-1] 40 | 41 | @property 42 | def done(self): 43 | return self._done 44 | 45 | def advance(self, word_prob): 46 | """Update beam status and check if finished or not.""" 47 | num_words = word_prob.size(1) 48 | 49 | # Sum the previous scores. 50 | if len(self.prev_ks) > 0: 51 | beam_lk = word_prob + self.scores.unsqueeze(1).expand_as(word_prob) 52 | # Block ngram repeats 53 | if self.block_ngram_repeat > 0: 54 | le = len(self.next_ys) 55 | for j in range(self.next_ys[-1].size(0)): 56 | hyp = self.get_hyp(le - 1, j) 57 | ngrams = set() 58 | fail = False 59 | gram = [] 60 | for i in range(le - 1): 61 | # Last n tokens, n = block_ngram_repeat 62 | gram = (gram + 63 | [hyp[i].item()])[-self.block_ngram_repeat:] 64 | # Skip the blocking if it is in the exclusion list 65 | if set(gram) & self.exclusion_tokens: 66 | continue 67 | if tuple(gram) in ngrams: 68 | fail = True 69 | ngrams.add(tuple(gram)) 70 | if fail: 71 | beam_lk[j] = -10e20 72 | else: 73 | beam_lk = word_prob[0] 74 | 75 | flat_beam_lk = beam_lk.view(-1) 76 | 77 | best_scores, best_scores_id = flat_beam_lk.topk(self.size, 0, True, True) # 1st sort 78 | best_scores, best_scores_id = flat_beam_lk.topk(self.size, 0, True, True) # 2nd sort 79 | 80 | self.all_scores.append(self.scores) 81 | self.scores = best_scores 82 | 83 | # bestScoresId is flattened as a (beam x word) array, 84 | # so we need to calculate which word and beam each score came from 85 | prev_k = best_scores_id / num_words 86 | self.prev_ks.append(prev_k) 87 | self.next_ys.append(best_scores_id - prev_k * num_words) 88 | 89 | # End condition is when top-of-beam is EOS. 
90 | if self.next_ys[-1][0].item() == Constants.EOS: 91 | self._done = True 92 | self.all_scores.append(self.scores) 93 | 94 | return self._done 95 | 96 | def sort_scores(self): 97 | """Sort the scores.""" 98 | return torch.sort(self.scores, 0, True) 99 | 100 | def get_the_best_score_and_idx(self): 101 | """Get the score of the best in the beam.""" 102 | scores, ids = self.sort_scores() 103 | return scores[1], ids[1] 104 | 105 | def get_tentative_hypothesis(self): 106 | """Get the decoded sequence for the current timestep.""" 107 | 108 | if len(self.next_ys) == 1: 109 | dec_seq = self.next_ys[0].unsqueeze(1) 110 | else: 111 | _, keys = self.sort_scores() 112 | hyps = [self.get_hypothesis(k) for k in keys] 113 | hyps = [[Constants.BOS] + h for h in hyps] 114 | dec_seq = torch.LongTensor(hyps) 115 | 116 | return dec_seq 117 | 118 | def get_hyp(self, timestep, k): 119 | """Walk back to construct the full hypothesis.""" 120 | hyp = [] 121 | for j in range(len(self.prev_ks[:timestep]) - 1, -1, -1): 122 | hyp.append(self.next_ys[j + 1][k]) 123 | k = self.prev_ks[j][k] 124 | return hyp[::-1] 125 | 126 | def get_hypothesis(self, k): 127 | """ Walk back to construct the full hypothesis. """ 128 | hyp = [] 129 | for j in range(len(self.prev_ks) - 1, -1, -1): 130 | hyp.append(self.next_ys[j+1][k]) 131 | k = self.prev_ks[j][k] 132 | 133 | return list(map(lambda x: x.item(), hyp[::-1])) 134 | -------------------------------------------------------------------------------- /server.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.utils.data 3 | import argparse 4 | from tqdm import tqdm 5 | 6 | from dataloader import src_collate_fn, TextSummarizationDataset 7 | from transformer.Translator import Summarizer 8 | from flask import Flask, jsonify, request 9 | from tokenizer import FullTokenizer 10 | import transformer.Constants as Constants 11 | 12 | app = Flask(__name__) 13 | app.config['JSON_AS_ASCII'] = False 14 | 15 | parser = argparse.ArgumentParser(description='translate.py') 16 | parser.add_argument('-model', default='/workspace/Bert-abstractive-text-summarization/data/checkpoint/trained' 17 | '/trained_20191004.chkpt', 18 | help='Path to model .pt file') 19 | parser.add_argument('-src', default='/workspace/Bert-abstractive-text-summarization/data/preprocessed_data.data', 20 | help='Source sequence to decode (one line per sequence)') 21 | parser.add_argument('-vocab', default='/workspace/Bert-abstractive-text-summarization/data/checkpoint/vocab.txt', 22 | help='Source sequence to decode (one line per sequence)') 23 | parser.add_argument('-output', default='pred.txt', 24 | help="""Path to output the predictions (each line will be the decoded sequence""") 25 | parser.add_argument('-beam_size', type=int, default=5, 26 | help='Beam size') 27 | parser.add_argument('-batch_size', type=int, default=30, 28 | help='Batch size') 29 | parser.add_argument('-n_best', type=int, default=1, 30 | help="""If verbose is set, will output the n_best decoded sentences""") 31 | parser.add_argument('-no_cuda', action='store_true') 32 | opt = parser.parse_args() 33 | opt.cuda = not opt.no_cuda 34 | 35 | # Prepare DataLoader 36 | data = torch.load(opt.src) 37 | data['settings'].cuda = opt.cuda 38 | 39 | # Create Translator Model 40 | translator = Summarizer(opt) 41 | 42 | # Create Tokenizer 43 | tokenizer = FullTokenizer(opt.vocab) 44 | 45 | 46 | @app.route('/', methods=['POST']) 47 | def summarization(): 48 | json_data = request.get_json() 49 | 50 | 
data_loader = preprocess(json_data) 51 | summaries = summarize(data_loader) 52 | summaries = remove_symbol(summaries) 53 | 54 | return jsonify({ 55 | 'summaries': summaries, 56 | }) 57 | 58 | 59 | def preprocess(json_data): 60 | """ 61 | Preprocess input data. 62 | 1. Extract text list from JSON data. 63 | 2. Divide text list into the words. 64 | 3. Convert text list to vocabulary id. 65 | 4. Convert id list to DataLoader. 66 | 67 | Args: 68 | json_data (object): received JSON data. 69 | 70 | Returns: 71 | data_loader (object): This data has been processed with tokenize, token2id, and toDataloader. 72 | 73 | """ 74 | # 1. Extract text list from JSON data. 75 | texts = get_text_from_json(json_data) 76 | 77 | # 2. Divide text list into the words. 78 | tokenized_texts = tokenize(texts) 79 | 80 | # 3. Convert text list to vocabulary id. 81 | tokens = text2token(tokenized_texts) 82 | 83 | # 4. Convert id list to DataLoader. 84 | data_loader = toDataLoader(tokens) 85 | 86 | return data_loader 87 | 88 | 89 | def get_text_from_json(json_data): 90 | """ 91 | Extract text list from JSON data. 92 | 93 | Args: 94 | json_data (object): input data. 95 | 96 | Returns: 97 | texts (list): text list 98 | 99 | """ 100 | return json_data['source_texts'] 101 | 102 | 103 | def tokenize(texts): 104 | """ 105 | Tokenize row text list. We use MeCab to tokenize sentences. 106 | 107 | Args: 108 | texts (list): The text list to tokenize. 109 | 110 | Returns: 111 | tokenized_texts (list): The tokenized text list. 112 | 113 | """ 114 | max_len = 512 115 | splited_texts = [] 116 | 117 | for text in texts: 118 | splited_text = tokenizer.tokenize(_convert_num_half_to_full(text.replace('。\n', '\n').replace('\n', '。\n'))) 119 | if len(splited_text) > (max_len - 2): 120 | splited_text = splited_text[:max_len-2] 121 | splited_texts.append(splited_text) 122 | 123 | return splited_texts 124 | 125 | 126 | def _convert_num_half_to_full(text): 127 | table = str.maketrans({ 128 | '0': '0', 129 | '1': '1', 130 | '2': '2', 131 | '3': '3', 132 | '4': '4', 133 | '5': '5', 134 | '6': '6', 135 | '7': '7', 136 | '8': '8', 137 | '9': '9', 138 | }) 139 | return text.translate(table) 140 | 141 | 142 | def text2token(texts): 143 | """ 144 | Convert input text list to vocabulary id (token) list. 145 | 146 | Args: 147 | texts (list): input text list. 148 | 149 | Returns: 150 | tokens (list): vocabulary id list. 151 | 152 | """ 153 | tokens = [] 154 | 155 | for text in texts: 156 | token = [Constants.BOS] + \ 157 | [data['dict']['src'].get(i, Constants.UNK) for i in text] + \ 158 | [Constants.EOS] 159 | tokens.append(token) 160 | 161 | return tokens 162 | 163 | 164 | def toDataLoader(tokens): 165 | """ 166 | Create DataLoader object from input vocabulary id list. 167 | 168 | Args: 169 | tokens (list): vocabulary id list. 170 | 171 | Returns: 172 | data_loader (object): DataLoader created from ids. 173 | 174 | """ 175 | return torch.utils.data.DataLoader( 176 | TextSummarizationDataset( 177 | src_word2idx=data['dict']['src'], 178 | tgt_word2idx=data['dict']['tgt'], 179 | src_insts=tokens), 180 | num_workers=2, 181 | batch_size=opt.batch_size, 182 | collate_fn=src_collate_fn) 183 | 184 | 185 | def summarize(data_loader): 186 | """ 187 | Summarize text in DataLoader with trained Deep Learning Model. 188 | 189 | Args: 190 | data_loader (DataLoader): inputted DataLoader. 191 | 192 | Returns: 193 | summarized_texts (list): summarized test list. 
194 | 195 | """ 196 | pred_lines = [] 197 | 198 | # Prediction 199 | with open(opt.output, 'w') as f: 200 | for batch in tqdm(data_loader, mininterval=2, desc=' - (Test)', leave=False): 201 | all_hyp, all_scores = translator.translate_batch(*batch) 202 | for idx_seqs in all_hyp: 203 | for idx_seq in idx_seqs: 204 | pred_line = ''.join([data_loader.dataset.tgt_idx2word[idx] for idx in idx_seq]) 205 | pred_lines.append(pred_line) 206 | 207 | return pred_lines 208 | 209 | 210 | def remove_symbol(texts): 211 | """ 212 | Remove symbol "##", "[SEP]". 213 | 214 | Args: 215 | texts(list) : input text list 216 | 217 | Returns: 218 | removed_text(list): text list that removed symbol "##", "[SEP]" 219 | 220 | """ 221 | removed_texts = [] 222 | for text in texts: 223 | removed_texts.append(text.replace('##', '').replace('[SEP]', '')) 224 | 225 | return removed_texts 226 | 227 | 228 | if __name__ == '__main__': 229 | app.run(host='0.0.0.0', port=6006) 230 | -------------------------------------------------------------------------------- /transformer/Translator.py: -------------------------------------------------------------------------------- 1 | ''' This module will handle the text generation with beam search. ''' 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | from models import AbstractiveTextSummarizationUsingBert 8 | from transformer.Beam import Beam 9 | 10 | 11 | class Summarizer(object): 12 | def __init__(self, opt): 13 | self.opt = opt 14 | self.device = torch.device('cuda' if opt.cuda else 'cpu') 15 | 16 | checkpoint = torch.load(opt.model) 17 | model_opt = checkpoint['settings'] 18 | self.model_opt = model_opt 19 | 20 | model = AbstractiveTextSummarizationUsingBert( 21 | # model_opt.bert_path, 22 | 'data/checkpoint/', 23 | model_opt.tgt_vocab_size, 24 | model_opt.max_token_seq_len, 25 | d_k=model_opt.d_k, 26 | d_v=model_opt.d_v, 27 | d_model=model_opt.d_model, 28 | d_word_vec=model_opt.d_word_vec, 29 | d_inner=model_opt.d_inner_hid, 30 | n_layers=model_opt.n_layers, 31 | n_head=model_opt.n_head, 32 | dropout=model_opt.dropout) 33 | 34 | model.load_state_dict(checkpoint['model']) 35 | print('[Info] Trained model state loaded.') 36 | 37 | model.word_prob_prj = nn.LogSoftmax(dim=1) 38 | 39 | model = model.to(self.device) 40 | 41 | self.model = model 42 | self.model.eval() 43 | 44 | def translate_batch(self, src_seq, src_pos): 45 | 46 | def get_inst_idx_to_tensor_position_map(inst_idx_list): 47 | return {inst_idx: tensor_position for tensor_position, inst_idx in enumerate(inst_idx_list)} 48 | 49 | def collect_active_part(beamed_tensor, curr_active_inst_idx, n_prev_active_inst, n_bm): 50 | _, *d_hs = beamed_tensor.size() 51 | n_curr_active_inst = len(curr_active_inst_idx) 52 | new_shape = (n_curr_active_inst * n_bm, *d_hs) 53 | 54 | beamed_tensor = beamed_tensor.view(n_prev_active_inst, -1) 55 | beamed_tensor = beamed_tensor.index_select(0, curr_active_inst_idx) 56 | beamed_tensor = beamed_tensor.view(*new_shape) 57 | 58 | return beamed_tensor 59 | 60 | def collate_active_info( 61 | src_seq, src_enc, inst_idx_to_position_map, active_inst_idx_list): 62 | # Sentences which are still active are collected, 63 | # so the decoder will not run on completed sentences. 
64 | n_prev_active_inst = len(inst_idx_to_position_map) 65 | active_inst_idx = [inst_idx_to_position_map[k] for k in active_inst_idx_list] 66 | active_inst_idx = torch.LongTensor(active_inst_idx).to(self.device) 67 | 68 | active_src_seq = collect_active_part(src_seq, active_inst_idx, n_prev_active_inst, n_bm) 69 | active_src_enc = collect_active_part(src_enc, active_inst_idx, n_prev_active_inst, n_bm) 70 | active_inst_idx_to_position_map = get_inst_idx_to_tensor_position_map(active_inst_idx_list) 71 | 72 | return active_src_seq, active_src_enc, active_inst_idx_to_position_map 73 | 74 | def beam_decode_step( 75 | inst_dec_beams, len_dec_seq, src_seq, enc_output, inst_idx_to_position_map, n_bm): 76 | def prepare_beam_dec_seq(inst_dec_beams, len_dec_seq): 77 | dec_partial_seq = [b.get_current_state() for b in inst_dec_beams if not b.done] 78 | dec_partial_seq = torch.stack(dec_partial_seq).to(self.device) 79 | dec_partial_seq = dec_partial_seq.view(-1, len_dec_seq) 80 | return dec_partial_seq 81 | 82 | def prepare_beam_dec_pos(len_dec_seq, n_active_inst, n_bm): 83 | dec_partial_pos = torch.arange(1, len_dec_seq + 1, dtype=torch.long, device=self.device) 84 | dec_partial_pos = dec_partial_pos.unsqueeze(0).repeat(n_active_inst * n_bm, 1) 85 | return dec_partial_pos 86 | 87 | def predict_word(dec_seq, dec_pos, src_seq, enc_output, n_active_inst, n_bm): 88 | dec_output, *_ = self.model.decoder(dec_seq, dec_pos, src_seq, enc_output) 89 | dec_output = dec_output[:, -1, :] # Pick the last step: (bh * bm) * d_h 90 | dec_output = self.model.tgt_word_prj(dec_output) 91 | dec_output[:, 0] = -float('inf') 92 | dec_output[:, 1] = -float('inf') 93 | word_prob = F.log_softmax(dec_output, dim=1) 94 | word_prob = word_prob.view(n_active_inst, n_bm, -1) 95 | 96 | return word_prob 97 | 98 | def collect_active_inst_idx_list(inst_beams, word_prob, inst_idx_to_position_map): 99 | active_inst_idx_list = [] 100 | for inst_idx, inst_position in inst_idx_to_position_map.items(): 101 | is_inst_complete = inst_beams[inst_idx].advance(word_prob[inst_position]) 102 | if not is_inst_complete: 103 | active_inst_idx_list += [inst_idx] 104 | 105 | return active_inst_idx_list 106 | 107 | n_active_inst = len(inst_idx_to_position_map) 108 | 109 | dec_seq = prepare_beam_dec_seq(inst_dec_beams, len_dec_seq) 110 | dec_pos = prepare_beam_dec_pos(len_dec_seq, n_active_inst, n_bm) 111 | word_prob = predict_word(dec_seq, dec_pos, src_seq, enc_output, n_active_inst, n_bm) 112 | 113 | # Update the beam with predicted word prob information and collect incomplete instances 114 | active_inst_idx_list = collect_active_inst_idx_list( 115 | inst_dec_beams, word_prob, inst_idx_to_position_map) 116 | 117 | return active_inst_idx_list 118 | 119 | def collect_hypothesis_and_scores(inst_dec_beams, n_best): 120 | all_hyp, all_scores = [], [] 121 | for inst_idx in range(len(inst_dec_beams)): 122 | scores, tail_idxs = inst_dec_beams[inst_idx].sort_scores() 123 | all_scores += [scores[:n_best]] 124 | 125 | hyps = [inst_dec_beams[inst_idx].get_hypothesis(i) for i in tail_idxs[:n_best]] 126 | all_hyp += [hyps] 127 | return all_hyp, all_scores 128 | 129 | with torch.no_grad(): 130 | #-- Encode 131 | src_seq, src_pos = src_seq.to(self.device), src_pos.to(self.device) 132 | src_enc, _ = self.model.encoder(src_seq, output_all_encoded_layers=False) 133 | 134 | #-- Repeat data for beam search 135 | n_bm = self.opt.beam_size 136 | n_inst, len_s, d_h = src_enc.size() 137 | src_seq = src_seq.repeat(1, n_bm).view(n_inst * n_bm, len_s) 138 | src_enc = 
src_enc.repeat(1, n_bm, 1).view(n_inst * n_bm, len_s, d_h) 139 | 140 | #-- Prepare beams 141 | inst_dec_beams = [Beam(n_bm, device=self.device) for _ in range(n_inst)] 142 | 143 | #-- Bookkeeping for active or not 144 | active_inst_idx_list = list(range(n_inst)) 145 | inst_idx_to_position_map = get_inst_idx_to_tensor_position_map(active_inst_idx_list) 146 | 147 | #-- Decode 148 | for len_dec_seq in range(1, self.model_opt.max_token_seq_len + 1): 149 | 150 | active_inst_idx_list = beam_decode_step( 151 | inst_dec_beams, len_dec_seq, src_seq, src_enc, inst_idx_to_position_map, n_bm) 152 | 153 | if not active_inst_idx_list: 154 | break # all instances have finished their path to 155 | 156 | src_seq, src_enc, inst_idx_to_position_map = collate_active_info( 157 | src_seq, src_enc, inst_idx_to_position_map, active_inst_idx_list) 158 | 159 | batch_hyp, batch_scores = collect_hypothesis_and_scores(inst_dec_beams, self.opt.n_best) 160 | 161 | return batch_hyp, batch_scores 162 | -------------------------------------------------------------------------------- /transformer/Models.py: -------------------------------------------------------------------------------- 1 | ''' Define the Transformer model ''' 2 | import torch 3 | import torch.nn as nn 4 | import numpy as np 5 | import transformer.Constants as Constants 6 | from transformer.Layers import EncoderLayer, DecoderLayer 7 | 8 | __author__ = "Yu-Hsiang Huang" 9 | 10 | def get_non_pad_mask(seq): 11 | assert seq.dim() == 2 12 | return seq.ne(Constants.PAD).type(torch.float).unsqueeze(-1) 13 | 14 | def get_sinusoid_encoding_table(n_position, d_hid, padding_idx=None): 15 | ''' Sinusoid position encoding table ''' 16 | 17 | def cal_angle(position, hid_idx): 18 | return position / np.power(10000, 2 * (hid_idx // 2) / d_hid) 19 | 20 | def get_posi_angle_vec(position): 21 | return [cal_angle(position, hid_j) for hid_j in range(d_hid)] 22 | 23 | sinusoid_table = np.array([get_posi_angle_vec(pos_i) for pos_i in range(n_position)]) 24 | 25 | sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2]) # dim 2i 26 | sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2]) # dim 2i+1 27 | 28 | if padding_idx is not None: 29 | # zero vector for padding dimension 30 | sinusoid_table[padding_idx] = 0. 31 | 32 | return torch.FloatTensor(sinusoid_table) 33 | 34 | def get_attn_key_pad_mask(seq_k, seq_q): 35 | ''' For masking out the padding part of key sequence. ''' 36 | 37 | # Expand to fit the shape of key query attention matrix. 38 | len_q = seq_q.size(1) 39 | padding_mask = seq_k.eq(Constants.PAD) 40 | padding_mask = padding_mask.unsqueeze(1).expand(-1, len_q, -1) # b x lq x lk 41 | 42 | return padding_mask 43 | 44 | def get_subsequent_mask(seq): 45 | ''' For masking out the subsequent info. ''' 46 | 47 | sz_b, len_s = seq.size() 48 | subsequent_mask = torch.triu( 49 | torch.ones((len_s, len_s), device=seq.device, dtype=torch.uint8), diagonal=1) 50 | subsequent_mask = subsequent_mask.unsqueeze(0).expand(sz_b, -1, -1) # b x ls x ls 51 | 52 | return subsequent_mask 53 | 54 | class Encoder(nn.Module): 55 | ''' A encoder model with self attention mechanism. 
''' 56 | 57 | def __init__( 58 | self, 59 | n_src_vocab, len_max_seq, d_word_vec, 60 | n_layers, n_head, d_k, d_v, 61 | d_model, d_inner, dropout=0.1): 62 | 63 | super().__init__() 64 | 65 | n_position = len_max_seq + 1 66 | 67 | self.src_word_emb = nn.Embedding( 68 | n_src_vocab, d_word_vec, padding_idx=Constants.PAD) 69 | 70 | self.position_enc = nn.Embedding.from_pretrained( 71 | get_sinusoid_encoding_table(n_position, d_word_vec, padding_idx=0), 72 | freeze=True) 73 | 74 | self.layer_stack = nn.ModuleList([ 75 | EncoderLayer(d_model, d_inner, n_head, d_k, d_v, dropout=dropout) 76 | for _ in range(n_layers)]) 77 | 78 | def forward(self, src_seq, src_pos, return_attns=False): 79 | 80 | enc_slf_attn_list = [] 81 | 82 | # -- Prepare masks 83 | slf_attn_mask = get_attn_key_pad_mask(seq_k=src_seq, seq_q=src_seq) 84 | non_pad_mask = get_non_pad_mask(src_seq) 85 | 86 | # -- Forward 87 | enc_output = self.src_word_emb(src_seq) + self.position_enc(src_pos) 88 | 89 | for enc_layer in self.layer_stack: 90 | enc_output, enc_slf_attn = enc_layer( 91 | enc_output, 92 | non_pad_mask=non_pad_mask, 93 | slf_attn_mask=slf_attn_mask) 94 | if return_attns: 95 | enc_slf_attn_list += [enc_slf_attn] 96 | 97 | if return_attns: 98 | return enc_output, enc_slf_attn_list 99 | return enc_output, 100 | 101 | class Decoder(nn.Module): 102 | ''' A decoder model with self attention mechanism. ''' 103 | 104 | def __init__( 105 | self, 106 | n_tgt_vocab, len_max_seq, d_word_vec, 107 | n_layers, n_head, d_k, d_v, 108 | d_model, d_inner, dropout=0.1): 109 | 110 | super().__init__() 111 | n_position = len_max_seq + 1 112 | 113 | self.tgt_word_emb = nn.Embedding( 114 | n_tgt_vocab, d_word_vec, padding_idx=Constants.PAD) 115 | 116 | self.position_enc = nn.Embedding.from_pretrained( 117 | get_sinusoid_encoding_table(n_position, d_word_vec, padding_idx=0), 118 | freeze=True) 119 | 120 | self.layer_stack = nn.ModuleList([ 121 | DecoderLayer(d_model, d_inner, n_head, d_k, d_v, dropout=dropout) 122 | for _ in range(n_layers)]) 123 | 124 | def forward(self, tgt_seq, tgt_pos, src_seq, enc_output, return_attns=False): 125 | 126 | dec_slf_attn_list, dec_enc_attn_list = [], [] 127 | 128 | # -- Prepare masks 129 | non_pad_mask = get_non_pad_mask(tgt_seq) 130 | 131 | slf_attn_mask_subseq = get_subsequent_mask(tgt_seq) 132 | slf_attn_mask_keypad = get_attn_key_pad_mask(seq_k=tgt_seq, seq_q=tgt_seq) 133 | slf_attn_mask = (slf_attn_mask_keypad + slf_attn_mask_subseq).gt(0) 134 | 135 | dec_enc_attn_mask = get_attn_key_pad_mask(seq_k=src_seq, seq_q=tgt_seq) 136 | 137 | # -- Forward 138 | dec_output = self.tgt_word_emb(tgt_seq) + self.position_enc(tgt_pos) 139 | 140 | for dec_layer in self.layer_stack: 141 | dec_output, dec_slf_attn, dec_enc_attn = dec_layer( 142 | dec_output, enc_output, 143 | non_pad_mask=non_pad_mask, 144 | slf_attn_mask=slf_attn_mask, 145 | dec_enc_attn_mask=dec_enc_attn_mask) 146 | 147 | if return_attns: 148 | dec_slf_attn_list += [dec_slf_attn] 149 | dec_enc_attn_list += [dec_enc_attn] 150 | 151 | if return_attns: 152 | return dec_output, dec_slf_attn_list, dec_enc_attn_list 153 | return dec_output, 154 | 155 | class Transformer(nn.Module): 156 | ''' A sequence to sequence model with attention mechanism. 
''' 157 | 158 | def __init__( 159 | self, 160 | n_src_vocab, n_tgt_vocab, len_max_seq, 161 | d_word_vec=512, d_model=512, d_inner=2048, 162 | n_layers=6, n_head=8, d_k=64, d_v=64, dropout=0.1, 163 | tgt_emb_prj_weight_sharing=True, 164 | emb_src_tgt_weight_sharing=True): 165 | 166 | super().__init__() 167 | 168 | self.encoder = Encoder( 169 | n_src_vocab=n_src_vocab, len_max_seq=len_max_seq, 170 | d_word_vec=d_word_vec, d_model=d_model, d_inner=d_inner, 171 | n_layers=n_layers, n_head=n_head, d_k=d_k, d_v=d_v, 172 | dropout=dropout) 173 | 174 | self.decoder = Decoder( 175 | n_tgt_vocab=n_tgt_vocab, len_max_seq=len_max_seq, 176 | d_word_vec=d_word_vec, d_model=d_model, d_inner=d_inner, 177 | n_layers=n_layers, n_head=n_head, d_k=d_k, d_v=d_v, 178 | dropout=dropout) 179 | 180 | self.tgt_word_prj = nn.Linear(d_model, n_tgt_vocab, bias=False) 181 | nn.init.xavier_normal_(self.tgt_word_prj.weight) 182 | 183 | assert d_model == d_word_vec, \ 184 | 'To facilitate the residual connections, \ 185 | the dimensions of all module outputs shall be the same.' 186 | 187 | if tgt_emb_prj_weight_sharing: 188 | # Share the weight matrix between target word embedding & the final logit dense layer 189 | self.tgt_word_prj.weight = self.decoder.tgt_word_emb.weight 190 | self.x_logit_scale = (d_model ** -0.5) 191 | else: 192 | self.x_logit_scale = 1. 193 | 194 | if emb_src_tgt_weight_sharing: 195 | # Share the weight matrix between source & target word embeddings 196 | assert n_src_vocab == n_tgt_vocab, \ 197 | "To share word embedding table, the vocabulary size of src/tgt shall be the same." 198 | self.encoder.src_word_emb.weight = self.decoder.tgt_word_emb.weight 199 | 200 | def forward(self, src_seq, src_pos, tgt_seq, tgt_pos): 201 | 202 | tgt_seq, tgt_pos = tgt_seq[:, :-1], tgt_pos[:, :-1] 203 | 204 | enc_output, *_ = self.encoder(src_seq, src_pos) 205 | dec_output, *_ = self.decoder(tgt_seq, tgt_pos, src_seq, enc_output) 206 | seq_logit = self.tgt_word_prj(dec_output) * self.x_logit_scale 207 | 208 | return seq_logit.view(-1, seq_logit.size(2)) 209 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | """ 2 | This script handling the training process. 3 | """ 4 | 5 | import argparse 6 | import math 7 | import time 8 | 9 | from tqdm import tqdm 10 | import torch 11 | import torch.nn.functional as F 12 | import torch.optim as optim 13 | import torch.utils.data 14 | import transformer.Constants as Constants 15 | from dataloader import TextSummarizationDataset, paired_collate_fn 16 | from transformer.Optim import ScheduledOptim 17 | import tensorboardX as tbx 18 | from models import AbstractiveTextSummarizationUsingBert 19 | 20 | 21 | def cal_performance(pred, x, gold, smoothing=False): 22 | ''' Apply label smoothing if needed ''' 23 | 24 | loss = cal_loss(pred, x, gold, smoothing) 25 | 26 | pred = pred.max(1)[1] 27 | gold = gold.contiguous().view(-1) 28 | non_pad_mask = gold.ne(Constants.PAD) 29 | n_correct = pred.eq(gold) 30 | n_correct = n_correct.masked_select(non_pad_mask).sum().item() 31 | 32 | return loss, n_correct 33 | 34 | 35 | def cal_loss(pred, x, gold, smoothing): 36 | ''' Calculate cross entropy loss, apply label smoothing if needed. 
''' 37 | 38 | gold = gold.contiguous().view(-1) 39 | 40 | if smoothing: 41 | eps = 0.1 42 | n_class = pred.size(1) 43 | 44 | one_hot = torch.zeros_like(pred).scatter(1, gold.view(-1, 1), 1) 45 | one_hot = one_hot * (1 - eps) + (1 - one_hot) * eps / (n_class - 1) 46 | # x = x.repeat(1, int(a.size(0)/x.size(0))).view(-1, a.size(1)) 47 | # prb = (1-p_gen)*F.softmax(pred, dim=1) 48 | # a = p_gen*F.softmax(a, dim=1) 49 | prb = F.softmax(pred, dim=1) 50 | # prb = prb.scatter_add(1, x, a) 51 | log_prb = torch.log(prb) 52 | non_pad_mask = gold.ne(Constants.PAD) 53 | loss = -(one_hot * log_prb).sum(dim=1) 54 | loss = loss.masked_select(non_pad_mask).sum() # average later 55 | else: 56 | loss = F.cross_entropy(pred, gold, ignore_index=Constants.PAD, reduction='sum') 57 | 58 | return loss 59 | 60 | 61 | def train_epoch(model, training_data, optimizer, device, smoothing, step): 62 | ''' Epoch operation in training phase''' 63 | 64 | model.train() 65 | 66 | total_loss = 0 67 | n_word_total = 0 68 | n_word_correct = 0 69 | 70 | for i, batch in enumerate(tqdm( 71 | training_data, mininterval=2, 72 | desc=' - (Training) ', leave=False, ascii=True)): 73 | 74 | # prepare data 75 | src_seq, src_pos, tgt_seq, tgt_pos = map(lambda x: x.to(device), batch) 76 | gold = tgt_seq[:, 1:] 77 | 78 | # forward 79 | optimizer.zero_grad() 80 | # pred, a, p_gen = model(src_seq, src_pos, tgt_seq, tgt_pos) 81 | pred = model(src_seq, src_pos, tgt_seq, tgt_pos) 82 | 83 | # backward 84 | loss, n_correct = cal_performance(pred, src_seq, gold, smoothing=smoothing) 85 | loss.backward() 86 | 87 | # update parameters 88 | optimizer.step_and_update_lr() 89 | 90 | # note keeping 91 | total_loss += loss.item() 92 | 93 | non_pad_mask = gold.ne(Constants.PAD) 94 | n_word = non_pad_mask.sum().item() 95 | n_word_total += n_word 96 | n_word_correct += n_correct 97 | writer.add_scalars('data/train_loss', {'train_loss_each_batch': loss.item() / n_word}, i+step) 98 | writer.add_scalars('data/train_accu', {'train_accu_each_batch': n_correct / n_word}, i+step) 99 | 100 | loss_per_word = total_loss/n_word_total 101 | accuracy = n_word_correct/n_word_total 102 | return loss_per_word, accuracy, step + i 103 | 104 | 105 | def eval_epoch(model, validation_data, device, step): 106 | ''' Epoch operation in evaluation phase ''' 107 | 108 | model.eval() 109 | 110 | total_loss = 0 111 | n_word_total = 0 112 | n_word_correct = 0 113 | 114 | with torch.no_grad(): 115 | for i, batch in enumerate(tqdm( 116 | validation_data, mininterval=2, 117 | desc=' - (Validation) ', leave=False, ascii=True)): 118 | 119 | # prepare data 120 | src_seq, src_pos, tgt_seq, tgt_pos = map(lambda x: x.to(device), batch) 121 | gold = tgt_seq[:, 1:] 122 | 123 | # forward 124 | pred = model(src_seq, src_pos, tgt_seq, tgt_pos) 125 | loss, n_correct = cal_performance(pred, src_seq, gold, smoothing=False) 126 | 127 | # note keeping 128 | total_loss += loss.item() 129 | 130 | non_pad_mask = gold.ne(Constants.PAD) 131 | n_word = non_pad_mask.sum().item() 132 | n_word_total += n_word 133 | n_word_correct += n_correct 134 | writer.add_scalars('data/valid_loss', {'valid_loss_each_batch': loss.item() / n_word}, i+step) 135 | writer.add_scalars('data/valid_accu', {'valid_accu_each_batch': n_correct / n_word}, i+step) 136 | 137 | loss_per_word = total_loss/n_word_total 138 | accuracy = n_word_correct/n_word_total 139 | return loss_per_word, accuracy, step + i 140 | 141 | 142 | def train(model, training_data, validation_data, optimizer, device, opt): 143 | ''' Start training ''' 144 | 145 | 
log_train_file = None 146 | log_valid_file = None 147 | 148 | if opt.log: 149 | log_train_file = opt.log + '.train.log' 150 | log_valid_file = opt.log + '.valid.log' 151 | 152 | print('[Info] Training performance will be written to file: {} and {}'.format( 153 | log_train_file, log_valid_file)) 154 | 155 | with open(log_train_file, 'w') as log_tf, open(log_valid_file, 'w') as log_vf: 156 | log_tf.write('epoch,loss,ppl,accuracy\n') 157 | log_vf.write('epoch,loss,ppl,accuracy\n') 158 | 159 | valid_accus = [] 160 | train_step = 0 161 | valid_step = 0 162 | for epoch_i in range(opt.epoch): 163 | print('[ Epoch', epoch_i, ']') 164 | 165 | start = time.time() 166 | train_loss, train_accu, train_step = train_epoch( 167 | model, training_data, optimizer, device, opt.label_smoothing, train_step) 168 | print(' - (Training) ppl: {ppl: 8.5f}, accuracy: {accu:3.3f} %, ' \ 169 | 'elapse: {elapse:3.3f} min'.format( 170 | ppl=math.exp(min(train_loss, 100)), accu=100*train_accu, 171 | elapse=(time.time()-start)/60)) 172 | 173 | start = time.time() 174 | valid_loss, valid_accu, valid_step = eval_epoch(model, validation_data, device, valid_step) 175 | print(' - (Validation) ppl: {ppl: 8.5f}, accuracy: {accu:3.3f} %, ' \ 176 | 'elapse: {elapse:3.3f} min'.format( 177 | ppl=math.exp(min(valid_loss, 100)), accu=100*valid_accu, 178 | elapse=(time.time()-start)/60)) 179 | 180 | valid_accus += [valid_accu] 181 | 182 | model_state_dict = model.state_dict() 183 | checkpoint = { 184 | 'model': model_state_dict, 185 | 'settings': opt, 186 | 'epoch': epoch_i} 187 | 188 | if opt.save_model: 189 | if opt.save_mode == 'all': 190 | model_name = 'data/checkpoint/trained/' + opt.save_model + '_accu_{accu:3.3f}.chkpt'.format(accu=100*valid_accu) 191 | torch.save(checkpoint, model_name) 192 | elif opt.save_mode == 'best': 193 | model_name = 'data/checkpoint/trained/' + opt.save_model + '.chkpt' 194 | if valid_accu >= max(valid_accus): 195 | torch.save(checkpoint, model_name) 196 | print(' - [Info] The checkpoint file has been updated.') 197 | 198 | if log_train_file and log_valid_file: 199 | with open(log_train_file, 'a') as log_tf, open(log_valid_file, 'a') as log_vf: 200 | log_tf.write('{epoch},{loss: 8.5f},{ppl: 8.5f},{accu:3.3f}\n'.format( 201 | epoch=epoch_i, loss=train_loss, 202 | ppl=math.exp(min(train_loss, 100)), accu=100*train_accu)) 203 | log_vf.write('{epoch},{loss: 8.5f},{ppl: 8.5f},{accu:3.3f}\n'.format( 204 | epoch=epoch_i, loss=valid_loss, 205 | ppl=math.exp(min(valid_loss, 100)), accu=100*valid_accu)) 206 | 207 | def main(): 208 | ''' Main function ''' 209 | parser = argparse.ArgumentParser() 210 | 211 | parser.add_argument('-data', required=True) 212 | parser.add_argument('-bert_path', required=True) 213 | 214 | parser.add_argument('-epoch', type=int, default=10) 215 | parser.add_argument('-batch_size', type=int, default=64) 216 | 217 | parser.add_argument('-d_model', type=int, default=768) 218 | parser.add_argument('-d_inner_hid', type=int, default=3072) 219 | parser.add_argument('-d_k', type=int, default=64) 220 | parser.add_argument('-d_v', type=int, default=64) 221 | 222 | parser.add_argument('-n_head', type=int, default=12) 223 | parser.add_argument('-n_layers', type=int, default=8) 224 | parser.add_argument('-n_warmup_steps', type=int, default=4000) 225 | 226 | parser.add_argument('-dropout', type=float, default=0.1) 227 | parser.add_argument('-embs_share_weight', action='store_true') 228 | parser.add_argument('-proj_share_weight', action='store_true') 229 | 230 | parser.add_argument('-log', 
default=None) 231 | parser.add_argument('-save_model', default=None) 232 | parser.add_argument('-save_mode', type=str, choices=['all', 'best'], default='best') 233 | 234 | parser.add_argument('-no_cuda', action='store_true') 235 | parser.add_argument('-label_smoothing', action='store_true') 236 | 237 | opt = parser.parse_args() 238 | opt.cuda = not opt.no_cuda 239 | opt.d_word_vec = opt.d_model 240 | 241 | #========= Loading Dataset =========# 242 | data = torch.load(opt.data) 243 | opt.max_token_seq_len = data['settings'].max_token_seq_len 244 | 245 | training_data, validation_data = prepare_dataloaders(data, opt) 246 | 247 | opt.src_vocab_size = training_data.dataset.src_vocab_size 248 | opt.tgt_vocab_size = training_data.dataset.tgt_vocab_size 249 | 250 | #========= Preparing Model =========# 251 | if opt.embs_share_weight: 252 | assert training_data.dataset.src_word2idx == training_data.dataset.tgt_word2idx, \ 253 | 'The src/tgt word2idx table are different but asked to share word embedding.' 254 | 255 | print(opt) 256 | 257 | device = torch.device('cuda' if opt.cuda else 'cpu') 258 | model = AbstractiveTextSummarizationUsingBert( 259 | opt.bert_path, 260 | opt.tgt_vocab_size, 261 | opt.max_token_seq_len, 262 | d_k=opt.d_k, 263 | d_v=opt.d_v, 264 | d_model=opt.d_model, 265 | d_word_vec=opt.d_word_vec, 266 | d_inner=opt.d_inner_hid, 267 | n_layers=opt.n_layers, 268 | n_head=opt.n_head, 269 | dropout=opt.dropout).to(device) 270 | 271 | optimizer = ScheduledOptim( 272 | optim.Adam( 273 | filter(lambda x: x.requires_grad, model.parameters()), 274 | betas=(0.9, 0.999), eps=1e-09), 275 | opt.d_model, opt.n_warmup_steps) 276 | 277 | train(model, training_data, validation_data, optimizer, device ,opt) 278 | 279 | 280 | def prepare_dataloaders(data, opt): 281 | # ========= Preparing DataLoader =========# 282 | train_loader = torch.utils.data.DataLoader( 283 | TextSummarizationDataset( 284 | src_word2idx=data['dict']['src'], 285 | tgt_word2idx=data['dict']['tgt'], 286 | src_insts=data['train']['src'], 287 | tgt_insts=data['train']['tgt']), 288 | num_workers=2, 289 | batch_size=opt.batch_size, 290 | collate_fn=paired_collate_fn, 291 | shuffle=True) 292 | 293 | valid_loader = torch.utils.data.DataLoader( 294 | TextSummarizationDataset( 295 | src_word2idx=data['dict']['src'], 296 | tgt_word2idx=data['dict']['tgt'], 297 | src_insts=data['valid']['src'], 298 | tgt_insts=data['valid']['tgt']), 299 | num_workers=2, 300 | batch_size=opt.batch_size, 301 | collate_fn=paired_collate_fn) 302 | return train_loader, valid_loader 303 | 304 | 305 | if __name__ == '__main__': 306 | writer = tbx.SummaryWriter(log_dir='data/tensorboardx/runs') 307 | main() 308 | writer.export_scalars_to_json("data/all_scalars.json") 309 | writer.close() 310 | -------------------------------------------------------------------------------- /tokenizer.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """Tokenization classes.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import collections 22 | import re 23 | import unicodedata 24 | import six 25 | import tensorflow as tf 26 | 27 | 28 | def validate_case_matches_checkpoint(do_lower_case, init_checkpoint): 29 | """Checks whether the casing config is consistent with the checkpoint name.""" 30 | 31 | # The casing has to be passed in by the user and there is no explicit check 32 | # as to whether it matches the checkpoint. The casing information probably 33 | # should have been stored in the bert_config.json file, but it's not, so 34 | # we have to heuristically detect it to validate. 35 | 36 | if not init_checkpoint: 37 | return 38 | 39 | m = re.match("^.*?([A-Za-z0-9_-]+)/bert_model.ckpt", init_checkpoint) 40 | if m is None: 41 | return 42 | 43 | model_name = m.group(1) 44 | 45 | lower_models = [ 46 | "uncased_L-24_H-1024_A-16", "uncased_L-12_H-768_A-12", 47 | "multilingual_L-12_H-768_A-12", "chinese_L-12_H-768_A-12" 48 | ] 49 | 50 | cased_models = [ 51 | "cased_L-12_H-768_A-12", "cased_L-24_H-1024_A-16", 52 | "multi_cased_L-12_H-768_A-12" 53 | ] 54 | 55 | is_bad_config = False 56 | if model_name in lower_models and not do_lower_case: 57 | is_bad_config = True 58 | actual_flag = "False" 59 | case_name = "lowercased" 60 | opposite_flag = "True" 61 | 62 | if model_name in cased_models and do_lower_case: 63 | is_bad_config = True 64 | actual_flag = "True" 65 | case_name = "cased" 66 | opposite_flag = "False" 67 | 68 | if is_bad_config: 69 | raise ValueError( 70 | "You passed in `--do_lower_case=%s` with `--init_checkpoint=%s`. " 71 | "However, `%s` seems to be a %s model, so you " 72 | "should pass in `--do_lower_case=%s` so that the fine-tuning matches " 73 | "how the model was pre-training. If this error is wrong, please " 74 | "just comment out this check." % (actual_flag, init_checkpoint, 75 | model_name, case_name, opposite_flag)) 76 | 77 | 78 | def convert_to_unicode(text): 79 | """Converts `text` to Unicode (if it's not already), assuming utf-8 input.""" 80 | if six.PY3: 81 | if isinstance(text, str): 82 | return text 83 | elif isinstance(text, bytes): 84 | return text.decode("utf-8", "ignore") 85 | else: 86 | raise ValueError("Unsupported string type: %s" % (type(text))) 87 | elif six.PY2: 88 | if isinstance(text, str): 89 | return text.decode("utf-8", "ignore") 90 | elif isinstance(text, unicode): 91 | return text 92 | else: 93 | raise ValueError("Unsupported string type: %s" % (type(text))) 94 | else: 95 | raise ValueError("Not running on Python2 or Python 3?") 96 | 97 | 98 | def printable_text(text): 99 | """Returns text encoded in a way suitable for print or `tf.logging`.""" 100 | 101 | # These functions want `str` for both Python2 and Python3, but in one case 102 | # it's a Unicode string and in the other it's a byte string. 
103 | if six.PY3: 104 | if isinstance(text, str): 105 | return text 106 | elif isinstance(text, bytes): 107 | return text.decode("utf-8", "ignore") 108 | else: 109 | raise ValueError("Unsupported string type: %s" % (type(text))) 110 | elif six.PY2: 111 | if isinstance(text, str): 112 | return text 113 | elif isinstance(text, unicode): 114 | return text.encode("utf-8") 115 | else: 116 | raise ValueError("Unsupported string type: %s" % (type(text))) 117 | else: 118 | raise ValueError("Not running on Python2 or Python 3?") 119 | 120 | 121 | def load_vocab(vocab_file): 122 | """Loads a vocabulary file into a dictionary.""" 123 | vocab = collections.OrderedDict() 124 | index = 0 125 | with tf.io.gfile.GFile(vocab_file, "r") as reader: 126 | while True: 127 | token = convert_to_unicode(reader.readline()) 128 | if not token: 129 | break 130 | token = token.strip() 131 | vocab[token] = index 132 | index += 1 133 | return vocab 134 | 135 | 136 | def convert_by_vocab(vocab, items): 137 | """Converts a sequence of [tokens|ids] using the vocab.""" 138 | output = [] 139 | for item in items: 140 | output.append(vocab[item]) 141 | return output 142 | 143 | 144 | def convert_tokens_to_ids(vocab, tokens): 145 | return convert_by_vocab(vocab, tokens) 146 | 147 | 148 | def convert_ids_to_tokens(inv_vocab, ids): 149 | return convert_by_vocab(inv_vocab, ids) 150 | 151 | 152 | def whitespace_tokenize(text): 153 | """Runs basic whitespace cleaning and splitting on a piece of text.""" 154 | text = text.strip() 155 | if not text: 156 | return [] 157 | tokens = text.split() 158 | return tokens 159 | 160 | 161 | class FullTokenizer(object): 162 | """Runs end-to-end tokenziation.""" 163 | 164 | def __init__(self, vocab_file, do_lower_case=True): 165 | self.vocab = load_vocab(vocab_file) 166 | self.inv_vocab = {v: k for k, v in self.vocab.items()} 167 | #self.tokenizer = BasicTokenizer(do_lower_case=do_lower_case) 168 | #self.tokenizer = JumanTokenizer() 169 | self.tokenizer = MeCabTokenizer() 170 | self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab) 171 | 172 | def tokenize(self, text): 173 | split_tokens = [] 174 | for token in self.tokenizer.tokenize(text): 175 | for sub_token in self.wordpiece_tokenizer.tokenize(token): 176 | split_tokens.append(sub_token) 177 | 178 | return split_tokens 179 | 180 | def convert_tokens_to_ids(self, tokens): 181 | return convert_by_vocab(self.vocab, tokens) 182 | 183 | def convert_ids_to_tokens(self, ids): 184 | return convert_by_vocab(self.inv_vocab, ids) 185 | 186 | 187 | class BasicTokenizer(object): 188 | """Runs basic tokenization (punctuation splitting, lower casing, etc.).""" 189 | 190 | def __init__(self, do_lower_case=True): 191 | """Constructs a BasicTokenizer. 192 | 193 | Args: 194 | do_lower_case: Whether to lower case the input. 195 | """ 196 | self.do_lower_case = do_lower_case 197 | 198 | def tokenize(self, text): 199 | """Tokenizes a piece of text.""" 200 | text = convert_to_unicode(text) 201 | text = self._clean_text(text) 202 | 203 | # This was added on November 1st, 2018 for the multilingual and Chinese 204 | # models. This is also applied to the English models now, but it doesn't 205 | # matter since the English models were not trained on any Chinese data 206 | # and generally don't have any Chinese data in them (there are Chinese 207 | # characters in the vocabulary because Wikipedia does have some Chinese 208 | # words in the English Wikipedia.). 
209 | text = self._tokenize_chinese_chars(text) 210 | 211 | orig_tokens = whitespace_tokenize(text) 212 | split_tokens = [] 213 | for token in orig_tokens: 214 | if self.do_lower_case: 215 | token = token.lower() 216 | token = self._run_strip_accents(token) 217 | split_tokens.extend(self._run_split_on_punc(token)) 218 | 219 | output_tokens = whitespace_tokenize(" ".join(split_tokens)) 220 | return output_tokens 221 | 222 | def _run_strip_accents(self, text): 223 | """Strips accents from a piece of text.""" 224 | text = unicodedata.normalize("NFD", text) 225 | output = [] 226 | for char in text: 227 | cat = unicodedata.category(char) 228 | if cat == "Mn": 229 | continue 230 | output.append(char) 231 | return "".join(output) 232 | 233 | def _run_split_on_punc(self, text): 234 | """Splits punctuation on a piece of text.""" 235 | chars = list(text) 236 | i = 0 237 | start_new_word = True 238 | output = [] 239 | while i < len(chars): 240 | char = chars[i] 241 | if _is_punctuation(char): 242 | output.append([char]) 243 | start_new_word = True 244 | else: 245 | if start_new_word: 246 | output.append([]) 247 | start_new_word = False 248 | output[-1].append(char) 249 | i += 1 250 | 251 | return ["".join(x) for x in output] 252 | 253 | def _tokenize_chinese_chars(self, text): 254 | """Adds whitespace around any CJK character.""" 255 | output = [] 256 | for char in text: 257 | cp = ord(char) 258 | if self._is_chinese_char(cp): 259 | output.append(" ") 260 | output.append(char) 261 | output.append(" ") 262 | else: 263 | output.append(char) 264 | return "".join(output) 265 | 266 | def _is_chinese_char(self, cp): 267 | """Checks whether CP is the codepoint of a CJK character.""" 268 | # This defines a "chinese character" as anything in the CJK Unicode block: 269 | # https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block) 270 | # 271 | # Note that the CJK Unicode block is NOT all Japanese and Korean characters, 272 | # despite its name. The modern Korean Hangul alphabet is a different block, 273 | # as is Japanese Hiragana and Katakana. Those alphabets are used to write 274 | # space-separated words, so they are not treated specially and handled 275 | # like the all of the other languages. 
276 | if ((cp >= 0x4E00 and cp <= 0x9FFF) or # 277 | (cp >= 0x3400 and cp <= 0x4DBF) or # 278 | (cp >= 0x20000 and cp <= 0x2A6DF) or # 279 | (cp >= 0x2A700 and cp <= 0x2B73F) or # 280 | (cp >= 0x2B740 and cp <= 0x2B81F) or # 281 | (cp >= 0x2B820 and cp <= 0x2CEAF) or 282 | (cp >= 0xF900 and cp <= 0xFAFF) or # 283 | (cp >= 0x2F800 and cp <= 0x2FA1F)): # 284 | return True 285 | 286 | return False 287 | 288 | def _clean_text(self, text): 289 | """Performs invalid character removal and whitespace cleanup on text.""" 290 | output = [] 291 | for char in text: 292 | cp = ord(char) 293 | if cp == 0 or cp == 0xfffd or _is_control(char): 294 | continue 295 | if _is_whitespace(char): 296 | output.append(" ") 297 | else: 298 | output.append(char) 299 | return "".join(output) 300 | 301 | 302 | class JumanTokenizer(BasicTokenizer): 303 | def __init__(self): 304 | from pyknp import Juman 305 | 306 | self.do_lower_case = False 307 | self._jumanpp = Juman() 308 | 309 | def tokenize(self, text): 310 | """Tokenizes a piece of text with Juman.""" 311 | 312 | text = convert_to_unicode(text) 313 | text = self._clean_text(text) 314 | 315 | 316 | juman_result = self._jumanpp.analysis(text.replace(' ', '')) 317 | split_tokens = [] 318 | for mrph in juman_result.mrph_list(): 319 | split_tokens.extend(self._run_split_on_punc(mrph.midasi)) 320 | 321 | output_tokens = whitespace_tokenize(" ".join(split_tokens)) 322 | return output_tokens 323 | 324 | 325 | class MeCabTokenizer(BasicTokenizer): 326 | def __init__(self): 327 | import MeCab 328 | 329 | self.do_lower_case = False 330 | self._mecab = MeCab.Tagger('-Owakati') 331 | 332 | def tokenize(self, text): 333 | """Tokenizes a piece of text with MeCab.""" 334 | 335 | text = convert_to_unicode(text) 336 | text = self._clean_text(text) 337 | 338 | mecab_result = self._mecab.parse(text) 339 | output_tokens = mecab_result.split(' ') 340 | return output_tokens 341 | 342 | 343 | 344 | class WordpieceTokenizer(object): 345 | """Runs WordPiece tokenization.""" 346 | 347 | def __init__(self, vocab, unk_token="[UNK]", max_input_chars_per_word=200): 348 | self.vocab = vocab 349 | self.unk_token = unk_token 350 | self.max_input_chars_per_word = max_input_chars_per_word 351 | 352 | def tokenize(self, text): 353 | """Tokenizes a piece of text into its word pieces. 354 | 355 | This uses a greedy longest-match-first algorithm to perform tokenization 356 | using the given vocabulary. 357 | 358 | For example: 359 | input = "unaffable" 360 | output = ["un", "##aff", "##able"] 361 | 362 | Args: 363 | text: A single token or whitespace separated tokens. This should have 364 | already been passed through `BasicTokenizer`. 365 | 366 | Returns: 367 | A list of wordpiece tokens. 
368 | """ 369 | 370 | text = convert_to_unicode(text) 371 | 372 | output_tokens = [] 373 | for token in whitespace_tokenize(text): 374 | chars = list(token) 375 | if len(chars) > self.max_input_chars_per_word: 376 | output_tokens.append(self.unk_token) 377 | continue 378 | 379 | is_bad = False 380 | start = 0 381 | sub_tokens = [] 382 | while start < len(chars): 383 | end = len(chars) 384 | cur_substr = None 385 | while start < end: 386 | substr = "".join(chars[start:end]) 387 | if start > 0: 388 | substr = "##" + substr 389 | if substr in self.vocab: 390 | cur_substr = substr 391 | break 392 | end -= 1 393 | if cur_substr is None: 394 | is_bad = True 395 | break 396 | sub_tokens.append(cur_substr) 397 | start = end 398 | 399 | if is_bad: 400 | output_tokens.append(self.unk_token) 401 | else: 402 | output_tokens.extend(sub_tokens) 403 | return output_tokens 404 | 405 | 406 | def _is_whitespace(char): 407 | """Checks whether `chars` is a whitespace character.""" 408 | # \t, \n, and \r are technically contorl characters but we treat them 409 | # as whitespace since they are generally considered as such. 410 | if char == " " or char == "\t" or char == "\n" or char == "\r": 411 | return True 412 | cat = unicodedata.category(char) 413 | if cat == "Zs": 414 | return True 415 | return False 416 | 417 | 418 | def _is_control(char): 419 | """Checks whether `chars` is a control character.""" 420 | # These are technically control characters but we count them as whitespace 421 | # characters. 422 | if char == "\t" or char == "\n" or char == "\r": 423 | return False 424 | cat = unicodedata.category(char) 425 | if cat in ("Cc", "Cf"): 426 | return True 427 | return False 428 | 429 | 430 | def _is_punctuation(char): 431 | """Checks whether `chars` is a punctuation character.""" 432 | cp = ord(char) 433 | # We treat all non-letter/number ASCII as punctuation. 434 | # Characters such as "^", "$", and "`" are not in the Unicode 435 | # Punctuation class but we treat them as punctuation anyways, for 436 | # consistency. 437 | if ((cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64) or 438 | (cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126)): 439 | return True 440 | cat = unicodedata.category(char) 441 | if cat.startswith("P"): 442 | return True 443 | return False 444 | --------------------------------------------------------------------------------