├── data_util
│   ├── config.py
│   └── __pycache__
│       └── config.cpython-36.pyc
├── .gitignore
├── transformer
│   ├── Constants.py
│   ├── __init__.py
│   ├── Modules.py
│   ├── Optim.py
│   ├── Layers.py
│   ├── SubLayers.py
│   ├── Beam.py
│   ├── Translator.py
│   └── Models.py
├── .idea
│   └── vcs.xml
├── init.py
├── templates
│   ├── summarization_form.html
│   └── base.html
├── .github
│   └── workflows
│       └── pythonapp.yml
├── requirements.txt
├── models.py
├── summarize.py
├── README.md
├── preprocess.py
├── dataloader.py
├── server.py
├── train.py
└── tokenizer.py
/data_util/config.py:
--------------------------------------------------------------------------------
1 |
2 |
3 | max_src_size = 512
4 | max_tgt_size = 512
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | /data/
2 | /__pycache__/
3 | /.idea/
4 | /transformer/__pycache__/
5 |
--------------------------------------------------------------------------------
/data_util/__pycache__/config.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IwasakiYuuki/Bert-abstractive-text-summarization/HEAD/data_util/__pycache__/config.cpython-36.pyc
--------------------------------------------------------------------------------
/transformer/Constants.py:
--------------------------------------------------------------------------------
1 |
2 | PAD = 0
3 | UNK = 1
4 | BOS = 2
5 | EOS = 3
6 |
7 | PAD_WORD = '[PAD]'
8 | UNK_WORD = '[UNK]'
9 | BOS_WORD = '[CLS]'
10 | EOS_WORD = '[SEP]'
11 |
--------------------------------------------------------------------------------
/.idea/vcs.xml:
--------------------------------------------------------------------------------
1 | <?xml version="1.0" encoding="UTF-8"?>
2 | <project version="4">
3 |   <component name="VcsDirectoryMappings">
4 |     <mapping directory="$PROJECT_DIR$" vcs="Git" />
5 |   </component>
6 | </project>
--------------------------------------------------------------------------------
/init.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2017 Victor Huang
2 | # Released under the MIT license
3 | # https://github.com/jadore801120/attention-is-all-you-need-pytorch/blob/master/LICENSE
4 |
5 |
6 | __copyright__ = 'Copyright (c) 2017 Victor Huang'
7 |
--------------------------------------------------------------------------------
/transformer/__init__.py:
--------------------------------------------------------------------------------
1 | import transformer.Constants
2 | import transformer.Modules
3 | import transformer.Layers
4 | import transformer.SubLayers
5 | import transformer.Models
6 | import transformer.Translator
7 | import transformer.Beam
8 | import transformer.Optim
9 |
10 | __all__ = [
11 | transformer.Constants, transformer.Modules, transformer.Layers,
12 | transformer.SubLayers, transformer.Models, transformer.Optim,
13 | transformer.Translator, transformer.Beam]
14 |
--------------------------------------------------------------------------------
/templates/summarization_form.html:
--------------------------------------------------------------------------------
1 | {% extends "base.html" %}
2 | {% block content %}
3 |
23 | {% endblock %}
--------------------------------------------------------------------------------
/transformer/Modules.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import numpy as np
4 |
5 | __author__ = "Yu-Hsiang Huang"
6 |
7 | class ScaledDotProductAttention(nn.Module):
8 | ''' Scaled Dot-Product Attention '''
9 |
10 | def __init__(self, temperature, attn_dropout=0.1):
11 | super().__init__()
12 | self.temperature = temperature
13 | self.dropout = nn.Dropout(attn_dropout)
14 | self.softmax = nn.Softmax(dim=2)
15 |
16 | def forward(self, q, k, v, mask=None):
17 |
18 | attn = torch.bmm(q, k.transpose(1, 2))
19 | attn = attn / self.temperature
20 |
21 | if mask is not None:
22 | attn = attn.masked_fill(mask, -np.inf)
23 |
24 | attn = self.softmax(attn)
25 | attn = self.dropout(attn)
26 | output = torch.bmm(attn, v)
27 |
28 | return output, attn
29 |
--------------------------------------------------------------------------------
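
A minimal usage sketch for `ScaledDotProductAttention` (toy shapes, not part of the repository), showing the 3-D inputs and outputs it expects:

```python
import torch
from transformer.Modules import ScaledDotProductAttention

# Toy shapes: batch of 2, query length 4, key/value length 6, d_k = d_v = 64.
q = torch.randn(2, 4, 64)
k = torch.randn(2, 6, 64)
v = torch.randn(2, 6, 64)

# Temperature sqrt(d_k), as MultiHeadAttention in SubLayers.py constructs it.
attn_layer = ScaledDotProductAttention(temperature=64 ** 0.5)
output, attn = attn_layer(q, k, v)  # no padding mask in this toy example

print(output.shape)  # torch.Size([2, 4, 64]) - weighted sum of v per query position
print(attn.shape)    # torch.Size([2, 4, 6])  - attention weights over key positions
```
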
/.github/workflows/pythonapp.yml:
--------------------------------------------------------------------------------
1 | name: Python application
2 |
3 | on: [push]
4 |
5 | jobs:
6 | build:
7 |
8 | runs-on: ubuntu-latest
9 |
10 | steps:
11 | - uses: actions/checkout@v1
12 | - name: Set up Python 3.7
13 | uses: actions/setup-python@v1
14 | with:
15 | python-version: 3.7
16 | - name: Install dependencies
17 | run: |
18 | python -m pip install --upgrade pip
19 | pip install -r requirements.txt
20 | - name: Lint with flake8
21 | run: |
22 | pip install flake8
23 | # stop the build if there are Python syntax errors or undefined names
24 | flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics
25 | # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide
26 | flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics
27 | - name: Test with pytest
28 | run: |
29 | pip install pytest
30 | pytest
31 |
--------------------------------------------------------------------------------
/templates/base.html:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 | {{ title }}
5 |
6 |
7 |
8 |
9 | {% block content %}
10 | {% endblock %}
11 |
12 |
13 |
14 |
15 |
--------------------------------------------------------------------------------
/transformer/Optim.py:
--------------------------------------------------------------------------------
1 | '''A wrapper class for optimizer '''
2 | import numpy as np
3 |
4 | class ScheduledOptim():
5 | '''A simple wrapper class for learning rate scheduling'''
6 |
7 | def __init__(self, optimizer, d_model, n_warmup_steps):
8 | self._optimizer = optimizer
9 | self.n_warmup_steps = n_warmup_steps
10 | self.n_current_steps = 0
11 | # self.init_lr = np.power(d_model, -0.5)
12 | self.init_lr = 0.0001
13 |
14 | def step_and_update_lr(self):
15 | "Step with the inner optimizer"
16 | self._update_learning_rate()
17 | self._optimizer.step()
18 |
19 | def zero_grad(self):
20 | "Zero out the gradients by the inner optimizer"
21 | self._optimizer.zero_grad()
22 |
23 | def _get_lr_scale(self):
24 | return np.min([
25 | np.power((self.n_current_steps / 32), -0.5),
26 | # np.power(self.n_warmup_steps, -1.5) * (self.n_current_steps / 32)])
27 | np.power(self.n_warmup_steps, -1.5) * (self.n_current_steps / 32)]) / (self.n_warmup_steps ** -0.5)
28 |
29 | def _update_learning_rate(self):
30 | ''' Learning rate scheduling per step '''
31 |
32 | self.n_current_steps += 1
33 | lr = self.init_lr * self._get_lr_scale()
34 |
35 | for param_group in self._optimizer.param_groups:
36 | param_group['lr'] = lr
37 |
38 |
--------------------------------------------------------------------------------
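
A small sketch of how `ScheduledOptim` is meant to be used: it wraps a regular optimizer, and `step_and_update_lr()` replaces the usual `optimizer.step()`. The toy model and Adam hyper-parameters below are illustrative only; `d_model=768` and 4000 warmup steps are the values listed in the README.

```python
import torch
import torch.nn as nn
import torch.optim as optim
from transformer.Optim import ScheduledOptim

model = nn.Linear(10, 2)  # stand-in for the real summarization model

optimizer = ScheduledOptim(
    optim.Adam(model.parameters(), betas=(0.9, 0.98), eps=1e-09),
    d_model=768, n_warmup_steps=4000)

x, y = torch.randn(4, 10), torch.randn(4, 2)
for _ in range(3):
    optimizer.zero_grad()
    loss = nn.functional.mse_loss(model(x), y)
    loss.backward()
    optimizer.step_and_update_lr()  # adjusts the learning rate, then steps the inner Adam
```
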
/requirements.txt:
--------------------------------------------------------------------------------
1 | flask
2 | absl-py==0.8.0
3 | asn1crypto==0.24.0
4 | astor==0.8.0
5 | backcall==0.1.0
6 | beautifulsoup4==4.6.0
7 | boto3==1.9.241
8 | botocore==1.12.241
9 | certifi==2018.4.16
10 | cffi==1.11.5
11 | chardet==3.0.4
12 | conda==4.5.8
13 | conda-build==3.12.0
14 | cryptography==2.3
15 | Cython==0.29.13
16 | decorator==4.3.0
17 | docutils==0.15.2
18 | filelock==3.0.4
19 | gast==0.2.2
20 | glob2==0.6
21 | google-pasta==0.1.7
22 | grpcio==1.24.1
23 | h5py==2.10.0
24 | idna==2.6
25 | ipython==6.4.0
26 | ipython-genutils==0.2.0
27 | jedi==0.12.1
28 | Jinja2==2.10
29 | jmespath==0.9.4
30 | Keras-Applications==1.0.8
31 | Keras-Preprocessing==1.1.0
32 | Markdown==3.1.1
33 | MarkupSafe==1.0
34 | mecab-python3==0.996.2
35 | mkl-fft==1.0.4
36 | mkl-random==1.0.1
37 | numpy==1.17.2
38 | olefile==0.45.1
39 | opt-einsum==3.1.0
40 | pandas==0.25.1
41 | parso==0.3.1
42 | pexpect==4.6.0
43 | pickleshare==0.7.4
44 | Pillow==6.2.0
45 | pkginfo==1.4.2
46 | prompt-toolkit==1.0.15
47 | protobuf==3.9.2
48 | psutil==5.4.6
49 | ptyprocess==0.6.0
50 | pycosat==0.6.3
51 | pycparser==2.18
52 | Pygments==2.2.0
53 | pyknp==0.4.1
54 | pyOpenSSL==18.0.0
55 | PySocks==1.6.8
56 | python-dateutil==2.8.0
57 | pytorch-pretrained-bert==0.6.2
58 | pytz==2019.2
59 | PyYAML==5.1
60 | regex==2019.8.19
61 | requests==2.20.0
62 | ruamel-yaml==0.15.37
63 | s3transfer==0.2.1
64 | scipy==1.1.0
65 | simplegeneric==0.8.1
66 | six==1.11.0
67 | tensorboard==2.0.0
68 | tensorboardX==1.8
69 | tensorflow==2.0.0
70 | tensorflow-estimator==2.0.0
71 | termcolor==1.1.0
72 | torch==0.4.1
73 | torchvision==0.2.1
74 | tqdm==4.36.1
75 | traitlets==4.3.2
76 | urllib3==1.24.2
77 | wcwidth==0.1.7
78 | Werkzeug==0.16.0
79 | wrapt==1.11.2
80 |
--------------------------------------------------------------------------------
/transformer/Layers.py:
--------------------------------------------------------------------------------
1 | ''' Define the Layers '''
2 | import torch.nn as nn
3 | from transformer.SubLayers import MultiHeadAttention, PositionwiseFeedForward
4 |
5 | __author__ = "Yu-Hsiang Huang"
6 |
7 |
8 | class EncoderLayer(nn.Module):
9 | ''' Compose with two layers '''
10 |
11 | def __init__(self, d_model, d_inner, n_head, d_k, d_v, dropout=0.1):
12 | super(EncoderLayer, self).__init__()
13 | self.slf_attn = MultiHeadAttention(
14 | n_head, d_model, d_k, d_v, dropout=dropout)
15 | self.pos_ffn = PositionwiseFeedForward(d_model, d_inner, dropout=dropout)
16 |
17 | def forward(self, enc_input, non_pad_mask=None, slf_attn_mask=None):
18 | enc_output, enc_slf_attn = self.slf_attn(
19 | enc_input, enc_input, enc_input, mask=slf_attn_mask)
20 | enc_output *= non_pad_mask
21 |
22 | enc_output = self.pos_ffn(enc_output)
23 | enc_output *= non_pad_mask
24 |
25 | return enc_output, enc_slf_attn
26 |
27 |
28 | class DecoderLayer(nn.Module):
29 | ''' Compose with three layers '''
30 |
31 | def __init__(self, d_model, d_inner, n_head, d_k, d_v, dropout=0.1):
32 | super(DecoderLayer, self).__init__()
33 | self.slf_attn = MultiHeadAttention(n_head, d_model, d_k, d_v, dropout=dropout)
34 | self.enc_attn = MultiHeadAttention(n_head, d_model, d_k, d_v, dropout=dropout)
35 | self.pos_ffn = PositionwiseFeedForward(d_model, d_inner, dropout=dropout)
36 |
37 | def forward(self, dec_input, enc_output, non_pad_mask=None, slf_attn_mask=None, dec_enc_attn_mask=None):
38 | dec_output, dec_slf_attn = self.slf_attn(
39 | dec_input, dec_input, dec_input, mask=slf_attn_mask)
40 | dec_output *= non_pad_mask
41 |
42 | dec_output, dec_enc_attn = self.enc_attn(
43 | dec_output, enc_output, enc_output, mask=dec_enc_attn_mask)
44 | dec_output *= non_pad_mask
45 |
46 | dec_output = self.pos_ffn(dec_output)
47 | dec_output *= non_pad_mask
48 |
49 | return dec_output, dec_slf_attn, dec_enc_attn
50 |
--------------------------------------------------------------------------------
/models.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from pytorch_pretrained_bert.modeling import BertModel, BertConfig
4 | from transformer.Models import Decoder
5 |
6 |
7 | class AbstractiveTextSummarizationUsingBert(nn.Module):
8 | def __init__(self, bert_model_path, n_tgt_vocab, len_max_seq, d_word_vec=768, d_model=768, d_inner=3072,
9 | n_layers=12, n_head=12, d_k=64, d_v=64, dropout=0.1):
10 |
11 | super().__init__()
12 |
13 | self.encoder = BertModel.from_pretrained(bert_model_path)
14 | self.config = BertConfig(bert_model_path+'bert_config.json')
15 | self.decoder = Decoder(
16 | n_tgt_vocab=n_tgt_vocab, len_max_seq=len_max_seq,
17 | d_word_vec=d_word_vec, d_model=d_model, d_inner=d_inner,
18 | n_layers=n_layers, n_head=n_head, d_k=d_k, d_v=d_v,
19 | dropout=dropout)
20 | self.tgt_word_prj = nn.Linear(d_model, n_tgt_vocab, bias=False)
21 | nn.init.xavier_normal_(self.tgt_word_prj.weight)
22 | self.tgt_word_prj.weight = self.decoder.tgt_word_emb.weight
23 | self.x_logit_scale = (d_model ** -0.5)
24 | self.o_l = nn.Linear(d_model, 512, bias=False)
25 | self.h_l = nn.Linear(512, 1, bias=True)
26 | nn.init.xavier_normal_(self.o_l.weight)
27 | nn.init.xavier_normal_(self.h_l.weight)
28 | self.a_l_1 = nn.Linear(d_model, 512, bias=False)
29 | self.a_l_2 = nn.Linear(d_model, 512, bias=False)
30 | nn.init.xavier_normal_(self.a_l_1.weight)
31 | nn.init.xavier_normal_(self.a_l_2.weight)
32 |
33 | def forward(self, src_seq, src_sen, tgt_seq, tgt_pos):
34 | tgt_seq, tgt_pos = tgt_seq[:, :-1], tgt_pos[:, :-1]
35 |
36 | enc_output, _ = self.encoder(src_seq, src_sen, output_all_encoded_layers=False)
37 | dec_output, *_ = self.decoder(tgt_seq, tgt_pos, src_seq, enc_output)
38 |
39 | # o = self.o_l(dec_output)
40 | # p_gen = torch.sigmoid(self.h_l(o).view(-1, 1))
41 |
42 | seq_logit = self.tgt_word_prj(dec_output) * self.x_logit_scale
43 | # a = self.a_l_1(dec_output)
44 | # a = torch.bmm(a, enc_output)
45 | # a = self.a_l_2(a)
46 |
47 | return seq_logit.view(-1, seq_logit.size(2))
48 |
--------------------------------------------------------------------------------
/summarize.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.utils.data
3 | import argparse
4 | from tqdm import tqdm
5 |
6 | from dataloader import src_collate_fn, TextSummarizationDataset
7 | from transformer.Translator import Summarizer
8 |
9 |
10 | def main():
11 |
12 | parser = argparse.ArgumentParser(description='translate.py')
13 |
14 | parser.add_argument('-model', required=True,
15 | help='Path to model .pt file')
16 | parser.add_argument('-src', required=True,
17 | help='Source sequence to decode (one line per sequence)')
18 | parser.add_argument('-vocab', required=True,
19 | help='Source sequence to decode (one line per sequence)')
20 | parser.add_argument('-output', default='pred.txt',
21 | help="""Path to output the predictions (each line will
22 | be the decoded sequence""")
23 | parser.add_argument('-beam_size', type=int, default=5,
24 | help='Beam size')
25 | parser.add_argument('-batch_size', type=int, default=30,
26 | help='Batch size')
27 | parser.add_argument('-n_best', type=int, default=1,
28 | help="""If verbose is set, will output the n_best
29 | decoded sentences""")
30 | parser.add_argument('-no_cuda', action='store_true')
31 |
32 | opt = parser.parse_args()
33 | opt.cuda = not opt.no_cuda
34 |
35 | # Prepare DataLoader
36 | data = torch.load(opt.src)
37 | data['settings'].cuda = opt.cuda
38 |
39 | test_loader = torch.utils.data.DataLoader(
40 | TextSummarizationDataset(
41 | # src_word2idx=preprocess_data['dict']['src'],
42 | src_word2idx=data['dict']['src'],
43 | # tgt_word2idx=preprocess_data['dict']['tgt'],
44 | tgt_word2idx=data['dict']['tgt'],
45 | # src_insts=test_src_insts),
46 | src_insts=data['valid']['src'][0:10]),
47 | num_workers=2,
48 | batch_size=opt.batch_size,
49 | collate_fn=src_collate_fn)
50 | translator = Summarizer(opt)
51 | with open(opt.output, 'w') as f:
52 | for batch in tqdm(test_loader, mininterval=2, desc=' - (Test)', leave=False):
53 | all_hyp, all_scores = translator.translate_batch(*batch)
54 | for idx_seqs in all_hyp:
55 | for idx_seq in idx_seqs:
56 | print(len(idx_seq))
57 | pred_line = ''.join([test_loader.dataset.tgt_idx2word[idx] for idx in idx_seq])
58 | f.write('[@]' + pred_line + '\n\n')
59 | print('[Info] Finished.')
60 |
61 |
62 | if __name__ == "__main__":
63 | main()
64 |
--------------------------------------------------------------------------------
/transformer/SubLayers.py:
--------------------------------------------------------------------------------
1 | ''' Define the sublayers in encoder/decoder layer '''
2 | import numpy as np
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | from transformer.Modules import ScaledDotProductAttention
6 |
7 | __author__ = "Yu-Hsiang Huang"
8 |
9 | class MultiHeadAttention(nn.Module):
10 | ''' Multi-Head Attention module '''
11 |
12 | def __init__(self, n_head, d_model, d_k, d_v, dropout=0.1):
13 | super().__init__()
14 |
15 | self.n_head = n_head
16 | self.d_k = d_k
17 | self.d_v = d_v
18 |
19 | self.w_qs = nn.Linear(d_model, n_head * d_k)
20 | self.w_ks = nn.Linear(d_model, n_head * d_k)
21 | self.w_vs = nn.Linear(d_model, n_head * d_v)
22 | nn.init.normal_(self.w_qs.weight, mean=0, std=np.sqrt(2.0 / (d_model + d_k)))
23 | nn.init.normal_(self.w_ks.weight, mean=0, std=np.sqrt(2.0 / (d_model + d_k)))
24 | nn.init.normal_(self.w_vs.weight, mean=0, std=np.sqrt(2.0 / (d_model + d_v)))
25 |
26 | self.attention = ScaledDotProductAttention(temperature=np.power(d_k, 0.5))
27 | self.layer_norm = nn.LayerNorm(d_model)
28 |
29 | self.fc = nn.Linear(n_head * d_v, d_model)
30 | nn.init.xavier_normal_(self.fc.weight)
31 |
32 | self.dropout = nn.Dropout(dropout)
33 |
34 |
35 | def forward(self, q, k, v, mask=None):
36 |
37 | d_k, d_v, n_head = self.d_k, self.d_v, self.n_head
38 |
39 | sz_b, len_q, _ = q.size()
40 | sz_b, len_k, _ = k.size()
41 | sz_b, len_v, _ = v.size()
42 |
43 | residual = q
44 |
45 | q = self.w_qs(q).view(sz_b, len_q, n_head, d_k)
46 | k = self.w_ks(k).view(sz_b, len_k, n_head, d_k)
47 | v = self.w_vs(v).view(sz_b, len_v, n_head, d_v)
48 |
49 | q = q.permute(2, 0, 1, 3).contiguous().view(-1, len_q, d_k) # (n*b) x lq x dk
50 | k = k.permute(2, 0, 1, 3).contiguous().view(-1, len_k, d_k) # (n*b) x lk x dk
51 | v = v.permute(2, 0, 1, 3).contiguous().view(-1, len_v, d_v) # (n*b) x lv x dv
52 |
53 | mask = mask.repeat(n_head, 1, 1) # (n*b) x .. x ..
54 | output, attn = self.attention(q, k, v, mask=mask)
55 |
56 | output = output.view(n_head, sz_b, len_q, d_v)
57 | output = output.permute(1, 2, 0, 3).contiguous().view(sz_b, len_q, -1) # b x lq x (n*dv)
58 |
59 | output = self.dropout(self.fc(output))
60 | output = self.layer_norm(output + residual)
61 |
62 | return output, attn
63 |
64 | class PositionwiseFeedForward(nn.Module):
65 | ''' A two-feed-forward-layer module '''
66 |
67 | def __init__(self, d_in, d_hid, dropout=0.1):
68 | super().__init__()
69 | self.w_1 = nn.Conv1d(d_in, d_hid, 1) # position-wise
70 | self.w_2 = nn.Conv1d(d_hid, d_in, 1) # position-wise
71 | self.layer_norm = nn.LayerNorm(d_in)
72 | self.dropout = nn.Dropout(dropout)
73 |
74 | def forward(self, x):
75 | residual = x
76 | output = x.transpose(1, 2)
77 | output = self.w_2(F.relu(self.w_1(output)))
78 | output = output.transpose(1, 2)
79 | output = self.dropout(output)
80 | output = self.layer_norm(output + residual)
81 | return output
82 |
--------------------------------------------------------------------------------
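
A self-contained sketch (toy sizes, not from the repository) running `MultiHeadAttention` as self-attention followed by `PositionwiseFeedForward`; the real model uses d_model=768, n_head=12, d_k=d_v=64:

```python
import torch
from transformer.SubLayers import MultiHeadAttention, PositionwiseFeedForward

d_model, n_head, d_k, d_v = 16, 2, 8, 8
mha = MultiHeadAttention(n_head, d_model, d_k, d_v, dropout=0.1)
ffn = PositionwiseFeedForward(d_model, d_hid=32, dropout=0.1)

x = torch.randn(3, 5, d_model)                  # batch of 3, sequence length 5
# All-zero uint8 mask = attend everywhere; uint8 matches the torch==0.4.1 pinned
# in requirements.txt (newer PyTorch would use a bool mask).
mask = torch.zeros(3, 5, 5, dtype=torch.uint8)

out, attn = mha(x, x, x, mask=mask)             # self-attention: q = k = v = x
out = ffn(out)
print(out.shape)   # torch.Size([3, 5, 16])
print(attn.shape)  # torch.Size([6, 5, 5]) = (n_head * batch) x len_q x len_k
```
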
/README.md:
--------------------------------------------------------------------------------
1 | # Abstractive text summarization using BERT
2 | This is a model that uses BERT (see the paper [Pretraining-Based Natural Language Generation for Text Summarization
3 | ](https://arxiv.org/abs/1902.09243)) for abstractive text summarization, one of the NLP (Natural Language Processing) tasks.
4 |
5 | ## Requirements
6 | - Python 3.6.5+
7 | - Pytorch 0.4.1+
8 | - Tensorflow
9 | - Pandas
10 | - tqdm
11 | - Numpy
12 | - MeCab
13 | - Tensorboard X and others...
14 |
15 | All packages used here can be installed with pip as follows:
16 |
17 | ~~~
18 | pip install -r requirements.txt
19 | ~~~
20 |
21 | ## Docker
22 | If you train the model on a GPU, the easiest way is to use the [PyTorch Docker images](https://hub.docker.com/r/pytorch/pytorch) on Docker Hub.
23 | 
24 | In this work, pytorch/pytorch:0.4.1-cuda9-cudnn7-devel (2.62 GB) was used.
25 |
26 | ## Before using
27 | When you use this, please follow the steps below.
28 | 1. Make a directory named "/data/checkpoint" under the project root,
29 | and put the BERT model, the vocabulary file, and the config file for BERT there.
30 | These files can be downloaded [here](http://nlp.ist.i.kyoto-u.ac.jp/index.php?BERT%E6%97%A5%E6%9C%AC%E8%AA%9EPretrained%E3%83%A2%E3%83%87%E3%83%AB).
31 | 
32 | 2. Put the data file for training and validation under /workspace/data/. The format is as follows:
33 |
34 | ```preprocess.py
35 | data = {
36 | 'settings': opt,
37 | 'dict': {
38 | 'src': text2token,
39 | 'tgt': text2token},
40 | 'train': {
41 | 'src': content[:100000],
42 | 'tgt': summary[:100000]},
43 | 'valid': {
44 | 'src': content[100000:],
45 | 'tgt': summary[100000:]}}
46 | torch.save(data, opt.save_data)
47 | ```
48 |
49 | The overall directory structure is as follows:
50 | ```
51 | `-- data # under workspace
52 | |-- checkpoint
53 | | |-- bert_config.json # BERT config file
54 | | |-- pytorch_model.bin # BERT model file
55 | | `-- vocab.txt # vocabulary file
56 | `-- preprocessed_data.data # train and valid data file
57 | ```
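
To sanity-check the preprocessed file, it can be loaded back with `torch.load` (a minimal sketch; the path matches the commands in the Usage section below):

```python
import torch

data = torch.load('data/preprocessed_data.data')
print(len(data['dict']['src']))       # vocabulary size
print(len(data['train']['src']))      # number of training articles
print(data['train']['src'][0][:10])   # first token ids of the first article
```
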
58 | ## Setting
59 | |Name |Value |
60 | |---|---|
61 | |Encoder |BERT |
62 | |Decoder |Transformer (Only Decoder) |
63 | |Embed dimension |768 |
64 | |Hidden dimension |3072 |
65 | |Encoder layers |12 |
66 | |Decoder layers |8 |
67 | |Optimizer |Adam |
68 | |Learning rate |init=0.0001 |
69 | |Warmup steps |4000 |
70 | |Input max length |512 |
71 | |Batch size |4 |
72 |
73 | ## Usage
74 | ### Train the model
75 | ```
76 | python train.py -data data/preprocessed_data.data -bert_path data/checkpoint/ -proj_share_weight -label_smoothing -batch_size 4 -epoch 10 -save_model trained -save_mode best
77 | ```
78 | ### Generate summarization with trained model
79 | ```
80 | python summarize.py -model data/checkpoint/trained/trained.chkpt -src data/preprocessed_data.data -vocab data/checkpoint/vocab.txt -output pred.txt
81 | ```
82 |
83 | ## Result
84 | ### Tensorboard X image
85 | 
86 |
87 |
88 | ## TODO
89 | - Evaluate the model with scores such as ROUGE-N
90 | - Make some examples
91 |
92 | ## Acknowledgements
93 | - The repository structure and much of the code are borrowed from [jadore801120/attention-is-all-you-need-pytorch](https://github.com/jadore801120/attention-is-all-you-need-pytorch).
--------------------------------------------------------------------------------
/preprocess.py:
--------------------------------------------------------------------------------
1 | import io, sys
2 | sys.stdout = io.TextIOWrapper(sys.stdout.buffer, encoding='utf-8')
3 | import sqlite3
4 | import argparse
5 | from tqdm import tqdm
6 | import pandas.io.sql as psql
7 | import transformer.Constants as Constants
8 | import torch
9 | #import MeCab
10 | from tokenizer import FullTokenizer
11 | #m = MeCab.Tagger('-Owakati')
12 |
13 | def build_vocab(path_to_file):
14 | with open(path_to_file, encoding='utf-8') as f:
15 | vocab = f.readlines()
16 | token2text = {k: v.rstrip() for k, v in enumerate(vocab)}
17 | text2token = {v: k for k, v in token2text.items()}
18 |
19 | return text2token, token2text
20 |
21 |
22 | def get_content_summary_from_df(df):
23 | content = df['content'].values.tolist()
24 | summary = df['summary'].values.tolist()
25 |
26 | return content, summary
27 |
28 |
29 | def convert_text_to_token(content, summary, text2token, max_len, tokenizer):
30 | tokens_content = []
31 | tokens_summary = []
32 | for d_content, d_summary in tqdm(zip(content, summary), ascii=True, total=len(content)):
33 | tokens_content.append(convert_text_to_token_seq(d_content, text2token, max_len, tokenizer))
34 | tokens_summary.append(convert_text_to_token_seq(d_summary, text2token, max_len, tokenizer))
35 |
36 | return tokens_content, tokens_summary
37 |
38 |
39 | def convert_text_to_token_seq(text, text2token, max_len, tokenizer):
40 | if len(text) > 2000:
41 | text = text[:2000]
42 | splited_text = tokenizer.tokenize(convert_num_half_to_full(text.replace('。\n', '\n').replace('\n', '。\n')))
43 | if len(splited_text) > (max_len - 2):
44 | splited_text = splited_text[:max_len-2]
45 | splited_text = [Constants.BOS] + \
46 | [text2token.get(i, Constants.UNK) for i in splited_text] + \
47 | [Constants.EOS]
48 | return splited_text
49 |
50 |
51 | def convert_num_half_to_full(text):
52 | table = str.maketrans({
53 |         '0': '０',
54 |         '1': '１',
55 |         '2': '２',
56 |         '3': '３',
57 |         '4': '４',
58 |         '5': '５',
59 |         '6': '６',
60 |         '7': '７',
61 |         '8': '８',
62 |         '9': '９',
63 | })
64 | return text.translate(table)
65 |
66 |
67 | def main():
68 | parser = argparse.ArgumentParser()
69 | parser.add_argument('-vocab', required=True)
70 | parser.add_argument('-data', required=True)
71 | parser.add_argument('-save_data', required=True)
72 | parser.add_argument('--max_word_seq_len', required=True, type=int)
73 | opt = parser.parse_args()
74 | opt.max_token_seq_len = opt.max_word_seq_len
75 |
76 | tokenizer = FullTokenizer(opt.vocab)
77 | connection = sqlite3.connect(opt.data)
78 | cursor = connection.cursor()
79 | df = psql.read_sql("SELECT * FROM Article;", connection)
80 | print('Finished reading db file.')
81 | text2token, token2text = build_vocab(opt.vocab)
82 | print('Finished building vocab.')
83 | content, summary = get_content_summary_from_df(df)
84 | content, summary = convert_text_to_token(content, summary, text2token, opt.max_word_seq_len, tokenizer)
85 | data = {
86 | 'settings': opt,
87 | 'dict': {
88 | 'src': text2token,
89 | 'tgt': text2token},
90 | 'train': {
91 | 'src': content[:100000],
92 | 'tgt': summary[:100000]},
93 | 'valid': {
94 | 'src': content[100000:],
95 | 'tgt': summary[100000:]}}
96 | torch.save(data, opt.save_data)
97 |
98 |
99 | if __name__ == '__main__':
100 | main()
101 |
--------------------------------------------------------------------------------
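
For reference, a hypothetical invocation of preprocess.py consistent with the arguments it defines (the SQLite database name is an example; the script expects an `Article` table with `content` and `summary` columns):

```
python preprocess.py -vocab data/checkpoint/vocab.txt -data articles.db -save_data data/preprocessed_data.data --max_word_seq_len 512
```
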
/dataloader.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.utils.data
4 |
5 | from transformer import Constants
6 | from data_util import config
7 |
8 |
9 | def paired_collate_fn(insts):
10 | src_insts, tgt_insts = list(zip(*insts))
11 | src_insts = src_collate_fn(src_insts)
12 | tgt_insts = tgt_collate_fn(tgt_insts)
13 | return (*src_insts, *tgt_insts)
14 |
15 |
16 | def src_collate_fn(insts):
17 | ''' Pad the instance to the max seq length in batch '''
18 |
19 | max_len = config.max_src_size
20 |
21 | batch_seq = np.array([
22 | inst + [Constants.PAD] * (max_len - len(inst))
23 | for inst in insts])
24 |
25 | batch_pos = []
26 | for inst in batch_seq:
27 | cnt = 0
28 | buf = []
29 | for i in inst:
30 | buf.append(cnt)
31 | if i == 7:
32 | cnt = cnt ^ 1
33 | batch_pos.append(buf)
34 | batch_pos = np.array(batch_pos)
35 |
36 | batch_seq = torch.LongTensor(batch_seq)
37 | batch_pos = torch.LongTensor(batch_pos)
38 |
39 | return batch_seq, batch_pos
40 |
41 |
42 | def tgt_collate_fn(insts):
43 | ''' Pad the instance to the max seq length in batch '''
44 |
45 | max_len = max(len(inst) for inst in insts)
46 | if max_len > config.max_tgt_size:
47 | max_len = config.max_tgt_size
48 |
49 | batch_seq = np.array([
50 | inst + [Constants.PAD] * (max_len - len(inst))
51 | for inst in insts])
52 |
53 | batch_pos = np.array([
54 | [pos_i+1 if w_i != Constants.PAD else 0
55 | for pos_i, w_i in enumerate(inst)] for inst in batch_seq])
56 |
57 | batch_seq = torch.LongTensor(batch_seq)
58 | batch_pos = torch.LongTensor(batch_pos)
59 |
60 | return batch_seq, batch_pos
61 |
62 |
63 | class TextSummarizationDataset(torch.utils.data.Dataset):
64 | def __init__(
65 | self, src_word2idx, tgt_word2idx,
66 | src_insts=None, tgt_insts=None):
67 |
68 | assert src_insts
69 | assert not tgt_insts or (len(src_insts) == len(tgt_insts))
70 |
71 | src_idx2word = {idx:word for word, idx in src_word2idx.items()}
72 | self._src_word2idx = src_word2idx
73 | self._src_idx2word = src_idx2word
74 | self._src_insts = src_insts
75 |
76 | tgt_idx2word = {idx:word for word, idx in tgt_word2idx.items()}
77 | self._tgt_word2idx = tgt_word2idx
78 | self._tgt_idx2word = tgt_idx2word
79 | self._tgt_insts = tgt_insts
80 |
81 | @property
82 | def n_insts(self):
83 | """ Property for dataset size """
84 | return len(self._src_insts)
85 |
86 | @property
87 | def src_vocab_size(self):
88 | """ Property for vocab size """
89 | return len(self._src_word2idx)
90 |
91 | @property
92 | def tgt_vocab_size(self):
93 | """ Property for vocab size """
94 | return len(self._tgt_word2idx)
95 |
96 | @property
97 | def src_word2idx(self):
98 | """ Property for word dictionary """
99 | return self._src_word2idx
100 |
101 | @property
102 | def tgt_word2idx(self):
103 | """ Property for word dictionary """
104 | return self._tgt_word2idx
105 |
106 | @property
107 | def src_idx2word(self):
108 | """ Property for index dictionary """
109 | return self._src_idx2word
110 |
111 | @property
112 | def tgt_idx2word(self):
113 | """ Property for index dictionary """
114 | return self._tgt_idx2word
115 |
116 | def __len__(self):
117 | return self.n_insts
118 |
119 | def __getitem__(self, idx):
120 | if self._tgt_insts:
121 | return self._src_insts[idx], self._tgt_insts[idx]
122 | return self._src_insts[idx]
123 |
124 |
--------------------------------------------------------------------------------
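
A toy sketch (made-up token ids, not from the repository) of how `TextSummarizationDataset` and `paired_collate_fn` combine into a training DataLoader. Note that `src_collate_fn` pads sources to `config.max_src_size` and builds BERT-style segment ids, toggling whenever token id 7 appears (presumably the sentence delimiter in the vocabulary used).

```python
import torch.utils.data
from dataloader import TextSummarizationDataset, paired_collate_fn

# Made-up word2idx dicts and token sequences, only to illustrate shapes.
word2idx = {'[PAD]': 0, '[UNK]': 1, '[CLS]': 2, '[SEP]': 3, 'a': 4, 'b': 5}
src_insts = [[2, 4, 5, 3], [2, 5, 3]]
tgt_insts = [[2, 4, 3], [2, 5, 4, 3]]

loader = torch.utils.data.DataLoader(
    TextSummarizationDataset(
        src_word2idx=word2idx, tgt_word2idx=word2idx,
        src_insts=src_insts, tgt_insts=tgt_insts),
    batch_size=2, collate_fn=paired_collate_fn)

src_seq, src_pos, tgt_seq, tgt_pos = next(iter(loader))
print(src_seq.shape)  # torch.Size([2, 512]): sources padded to config.max_src_size
print(tgt_seq.shape)  # torch.Size([2, 4]): targets padded to the longest target in the batch
```
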
/transformer/Beam.py:
--------------------------------------------------------------------------------
1 | """ Manage beam search info structure.
2 |
3 | Heavily borrowed from OpenNMT-py.
4 | For code in OpenNMT-py, please check the following link:
5 | https://github.com/OpenNMT/OpenNMT-py/blob/master/onmt/Beam.py
6 | """
7 |
8 | import torch
9 | import transformer.Constants as Constants
10 |
11 |
12 | class Beam(object):
13 | """ Beam search """
14 |
15 | def __init__(self, size, block_ngram_repeat=3, exclusion_tokens=set(), device=False):
16 |
17 | self.size = size
18 | self._done = False
19 | self.block_ngram_repeat = block_ngram_repeat
20 | self.exclusion_tokens = exclusion_tokens
21 |
22 | # The score for each translation on the beam.
23 | self.scores = torch.zeros((size,), dtype=torch.float, device=device)
24 | self.all_scores = []
25 |
26 | # The backpointers at each time-step.
27 | self.prev_ks = []
28 |
29 | # The outputs at each time-step.
30 | self.next_ys = [torch.full((size,), Constants.PAD, dtype=torch.long, device=device)]
31 | self.next_ys[0][0] = Constants.BOS
32 |
33 | def get_current_state(self):
34 | """Get the outputs for the current timestep."""
35 | return self.get_tentative_hypothesis()
36 |
37 | def get_current_origin(self):
38 | """Get the backpointers for the current timestep."""
39 | return self.prev_ks[-1]
40 |
41 | @property
42 | def done(self):
43 | return self._done
44 |
45 | def advance(self, word_prob):
46 | """Update beam status and check if finished or not."""
47 | num_words = word_prob.size(1)
48 |
49 | # Sum the previous scores.
50 | if len(self.prev_ks) > 0:
51 | beam_lk = word_prob + self.scores.unsqueeze(1).expand_as(word_prob)
52 | # Block ngram repeats
53 | if self.block_ngram_repeat > 0:
54 | le = len(self.next_ys)
55 | for j in range(self.next_ys[-1].size(0)):
56 | hyp = self.get_hyp(le - 1, j)
57 | ngrams = set()
58 | fail = False
59 | gram = []
60 | for i in range(le - 1):
61 | # Last n tokens, n = block_ngram_repeat
62 | gram = (gram +
63 | [hyp[i].item()])[-self.block_ngram_repeat:]
64 | # Skip the blocking if it is in the exclusion list
65 | if set(gram) & self.exclusion_tokens:
66 | continue
67 | if tuple(gram) in ngrams:
68 | fail = True
69 | ngrams.add(tuple(gram))
70 | if fail:
71 | beam_lk[j] = -10e20
72 | else:
73 | beam_lk = word_prob[0]
74 |
75 | flat_beam_lk = beam_lk.view(-1)
76 |
77 | best_scores, best_scores_id = flat_beam_lk.topk(self.size, 0, True, True) # 1st sort
78 | best_scores, best_scores_id = flat_beam_lk.topk(self.size, 0, True, True) # 2nd sort
79 |
80 | self.all_scores.append(self.scores)
81 | self.scores = best_scores
82 |
83 | # bestScoresId is flattened as a (beam x word) array,
84 | # so we need to calculate which word and beam each score came from
85 | prev_k = best_scores_id / num_words
86 | self.prev_ks.append(prev_k)
87 | self.next_ys.append(best_scores_id - prev_k * num_words)
88 |
89 | # End condition is when top-of-beam is EOS.
90 | if self.next_ys[-1][0].item() == Constants.EOS:
91 | self._done = True
92 | self.all_scores.append(self.scores)
93 |
94 | return self._done
95 |
96 | def sort_scores(self):
97 | """Sort the scores."""
98 | return torch.sort(self.scores, 0, True)
99 |
100 | def get_the_best_score_and_idx(self):
101 | """Get the score of the best in the beam."""
102 | scores, ids = self.sort_scores()
103 | return scores[1], ids[1]
104 |
105 | def get_tentative_hypothesis(self):
106 | """Get the decoded sequence for the current timestep."""
107 |
108 | if len(self.next_ys) == 1:
109 | dec_seq = self.next_ys[0].unsqueeze(1)
110 | else:
111 | _, keys = self.sort_scores()
112 | hyps = [self.get_hypothesis(k) for k in keys]
113 | hyps = [[Constants.BOS] + h for h in hyps]
114 | dec_seq = torch.LongTensor(hyps)
115 |
116 | return dec_seq
117 |
118 | def get_hyp(self, timestep, k):
119 | """Walk back to construct the full hypothesis."""
120 | hyp = []
121 | for j in range(len(self.prev_ks[:timestep]) - 1, -1, -1):
122 | hyp.append(self.next_ys[j + 1][k])
123 | k = self.prev_ks[j][k]
124 | return hyp[::-1]
125 |
126 | def get_hypothesis(self, k):
127 | """ Walk back to construct the full hypothesis. """
128 | hyp = []
129 | for j in range(len(self.prev_ks) - 1, -1, -1):
130 | hyp.append(self.next_ys[j+1][k])
131 | k = self.prev_ks[j][k]
132 |
133 | return list(map(lambda x: x.item(), hyp[::-1]))
134 |
--------------------------------------------------------------------------------
/server.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.utils.data
3 | import argparse
4 | from tqdm import tqdm
5 |
6 | from dataloader import src_collate_fn, TextSummarizationDataset
7 | from transformer.Translator import Summarizer
8 | from flask import Flask, jsonify, request
9 | from tokenizer import FullTokenizer
10 | import transformer.Constants as Constants
11 |
12 | app = Flask(__name__)
13 | app.config['JSON_AS_ASCII'] = False
14 |
15 | parser = argparse.ArgumentParser(description='translate.py')
16 | parser.add_argument('-model', default='/workspace/Bert-abstractive-text-summarization/data/checkpoint/trained'
17 | '/trained_20191004.chkpt',
18 | help='Path to model .pt file')
19 | parser.add_argument('-src', default='/workspace/Bert-abstractive-text-summarization/data/preprocessed_data.data',
20 | help='Source sequence to decode (one line per sequence)')
21 | parser.add_argument('-vocab', default='/workspace/Bert-abstractive-text-summarization/data/checkpoint/vocab.txt',
22 |                     help='Path to the vocabulary file')
23 | parser.add_argument('-output', default='pred.txt',
24 | help="""Path to output the predictions (each line will be the decoded sequence""")
25 | parser.add_argument('-beam_size', type=int, default=5,
26 | help='Beam size')
27 | parser.add_argument('-batch_size', type=int, default=30,
28 | help='Batch size')
29 | parser.add_argument('-n_best', type=int, default=1,
30 | help="""If verbose is set, will output the n_best decoded sentences""")
31 | parser.add_argument('-no_cuda', action='store_true')
32 | opt = parser.parse_args()
33 | opt.cuda = not opt.no_cuda
34 |
35 | # Prepare DataLoader
36 | data = torch.load(opt.src)
37 | data['settings'].cuda = opt.cuda
38 |
39 | # Create Translator Model
40 | translator = Summarizer(opt)
41 |
42 | # Create Tokenizer
43 | tokenizer = FullTokenizer(opt.vocab)
44 |
45 |
46 | @app.route('/', methods=['POST'])
47 | def summarization():
48 | json_data = request.get_json()
49 |
50 | data_loader = preprocess(json_data)
51 | summaries = summarize(data_loader)
52 | summaries = remove_symbol(summaries)
53 |
54 | return jsonify({
55 | 'summaries': summaries,
56 | })
57 |
58 |
59 | def preprocess(json_data):
60 | """
61 | Preprocess input data.
62 | 1. Extract text list from JSON data.
63 | 2. Divide text list into the words.
64 | 3. Convert text list to vocabulary id.
65 | 4. Convert id list to DataLoader.
66 |
67 | Args:
68 | json_data (object): received JSON data.
69 |
70 | Returns:
71 | data_loader (object): This data has been processed with tokenize, token2id, and toDataloader.
72 |
73 | """
74 | # 1. Extract text list from JSON data.
75 | texts = get_text_from_json(json_data)
76 |
77 | # 2. Divide text list into the words.
78 | tokenized_texts = tokenize(texts)
79 |
80 | # 3. Convert text list to vocabulary id.
81 | tokens = text2token(tokenized_texts)
82 |
83 | # 4. Convert id list to DataLoader.
84 | data_loader = toDataLoader(tokens)
85 |
86 | return data_loader
87 |
88 |
89 | def get_text_from_json(json_data):
90 | """
91 | Extract text list from JSON data.
92 |
93 | Args:
94 | json_data (object): input data.
95 |
96 | Returns:
97 | texts (list): text list
98 |
99 | """
100 | return json_data['source_texts']
101 |
102 |
103 | def tokenize(texts):
104 | """
105 |     Tokenize raw text list. We use MeCab to tokenize sentences.
106 |
107 | Args:
108 | texts (list): The text list to tokenize.
109 |
110 | Returns:
111 | tokenized_texts (list): The tokenized text list.
112 |
113 | """
114 | max_len = 512
115 | splited_texts = []
116 |
117 | for text in texts:
118 | splited_text = tokenizer.tokenize(_convert_num_half_to_full(text.replace('。\n', '\n').replace('\n', '。\n')))
119 | if len(splited_text) > (max_len - 2):
120 | splited_text = splited_text[:max_len-2]
121 | splited_texts.append(splited_text)
122 |
123 | return splited_texts
124 |
125 |
126 | def _convert_num_half_to_full(text):
127 | table = str.maketrans({
128 |         '0': '０',
129 |         '1': '１',
130 |         '2': '２',
131 |         '3': '３',
132 |         '4': '４',
133 |         '5': '５',
134 |         '6': '６',
135 |         '7': '７',
136 |         '8': '８',
137 |         '9': '９',
138 | })
139 | return text.translate(table)
140 |
141 |
142 | def text2token(texts):
143 | """
144 | Convert input text list to vocabulary id (token) list.
145 |
146 | Args:
147 | texts (list): input text list.
148 |
149 | Returns:
150 | tokens (list): vocabulary id list.
151 |
152 | """
153 | tokens = []
154 |
155 | for text in texts:
156 | token = [Constants.BOS] + \
157 | [data['dict']['src'].get(i, Constants.UNK) for i in text] + \
158 | [Constants.EOS]
159 | tokens.append(token)
160 |
161 | return tokens
162 |
163 |
164 | def toDataLoader(tokens):
165 | """
166 | Create DataLoader object from input vocabulary id list.
167 |
168 | Args:
169 | tokens (list): vocabulary id list.
170 |
171 | Returns:
172 | data_loader (object): DataLoader created from ids.
173 |
174 | """
175 | return torch.utils.data.DataLoader(
176 | TextSummarizationDataset(
177 | src_word2idx=data['dict']['src'],
178 | tgt_word2idx=data['dict']['tgt'],
179 | src_insts=tokens),
180 | num_workers=2,
181 | batch_size=opt.batch_size,
182 | collate_fn=src_collate_fn)
183 |
184 |
185 | def summarize(data_loader):
186 | """
187 |     Summarize the text in the DataLoader with the trained deep learning model.
188 | 
189 |     Args:
190 |         data_loader (DataLoader): input DataLoader.
191 | 
192 |     Returns:
193 |         summarized_texts (list): summarized text list.
194 | 
195 | """
196 | pred_lines = []
197 |
198 | # Prediction
199 | with open(opt.output, 'w') as f:
200 | for batch in tqdm(data_loader, mininterval=2, desc=' - (Test)', leave=False):
201 | all_hyp, all_scores = translator.translate_batch(*batch)
202 | for idx_seqs in all_hyp:
203 | for idx_seq in idx_seqs:
204 | pred_line = ''.join([data_loader.dataset.tgt_idx2word[idx] for idx in idx_seq])
205 | pred_lines.append(pred_line)
206 |
207 | return pred_lines
208 |
209 |
210 | def remove_symbol(texts):
211 |     """
212 |     Remove the symbols "##" and "[SEP]".
213 | 
214 |     Args:
215 |         texts (list): input text list.
216 | 
217 |     Returns:
218 |         removed_texts (list): text list with the symbols "##" and "[SEP]" removed.
219 | 
220 |     """
221 | removed_texts = []
222 | for text in texts:
223 | removed_texts.append(text.replace('##', '').replace('[SEP]', ''))
224 |
225 | return removed_texts
226 |
227 |
228 | if __name__ == '__main__':
229 | app.run(host='0.0.0.0', port=6006)
230 |
--------------------------------------------------------------------------------
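
A hypothetical client call against the Flask app above: it listens on port 6006 and expects a JSON body with a `source_texts` list (`requests` is already pinned in requirements.txt).

```python
import requests

# Example request to a locally running server.py instance (URL and text are examples).
resp = requests.post(
    'http://localhost:6006/',
    json={'source_texts': ['ここに要約したい記事の本文を入れます。']})
print(resp.json()['summaries'])
```
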
/transformer/Translator.py:
--------------------------------------------------------------------------------
1 | ''' This module will handle the text generation with beam search. '''
2 |
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 |
7 | from models import AbstractiveTextSummarizationUsingBert
8 | from transformer.Beam import Beam
9 |
10 |
11 | class Summarizer(object):
12 | def __init__(self, opt):
13 | self.opt = opt
14 | self.device = torch.device('cuda' if opt.cuda else 'cpu')
15 |
16 | checkpoint = torch.load(opt.model)
17 | model_opt = checkpoint['settings']
18 | self.model_opt = model_opt
19 |
20 | model = AbstractiveTextSummarizationUsingBert(
21 | # model_opt.bert_path,
22 | 'data/checkpoint/',
23 | model_opt.tgt_vocab_size,
24 | model_opt.max_token_seq_len,
25 | d_k=model_opt.d_k,
26 | d_v=model_opt.d_v,
27 | d_model=model_opt.d_model,
28 | d_word_vec=model_opt.d_word_vec,
29 | d_inner=model_opt.d_inner_hid,
30 | n_layers=model_opt.n_layers,
31 | n_head=model_opt.n_head,
32 | dropout=model_opt.dropout)
33 |
34 | model.load_state_dict(checkpoint['model'])
35 | print('[Info] Trained model state loaded.')
36 |
37 | model.word_prob_prj = nn.LogSoftmax(dim=1)
38 |
39 | model = model.to(self.device)
40 |
41 | self.model = model
42 | self.model.eval()
43 |
44 | def translate_batch(self, src_seq, src_pos):
45 |
46 | def get_inst_idx_to_tensor_position_map(inst_idx_list):
47 | return {inst_idx: tensor_position for tensor_position, inst_idx in enumerate(inst_idx_list)}
48 |
49 | def collect_active_part(beamed_tensor, curr_active_inst_idx, n_prev_active_inst, n_bm):
50 | _, *d_hs = beamed_tensor.size()
51 | n_curr_active_inst = len(curr_active_inst_idx)
52 | new_shape = (n_curr_active_inst * n_bm, *d_hs)
53 |
54 | beamed_tensor = beamed_tensor.view(n_prev_active_inst, -1)
55 | beamed_tensor = beamed_tensor.index_select(0, curr_active_inst_idx)
56 | beamed_tensor = beamed_tensor.view(*new_shape)
57 |
58 | return beamed_tensor
59 |
60 | def collate_active_info(
61 | src_seq, src_enc, inst_idx_to_position_map, active_inst_idx_list):
62 | # Sentences which are still active are collected,
63 | # so the decoder will not run on completed sentences.
64 | n_prev_active_inst = len(inst_idx_to_position_map)
65 | active_inst_idx = [inst_idx_to_position_map[k] for k in active_inst_idx_list]
66 | active_inst_idx = torch.LongTensor(active_inst_idx).to(self.device)
67 |
68 | active_src_seq = collect_active_part(src_seq, active_inst_idx, n_prev_active_inst, n_bm)
69 | active_src_enc = collect_active_part(src_enc, active_inst_idx, n_prev_active_inst, n_bm)
70 | active_inst_idx_to_position_map = get_inst_idx_to_tensor_position_map(active_inst_idx_list)
71 |
72 | return active_src_seq, active_src_enc, active_inst_idx_to_position_map
73 |
74 | def beam_decode_step(
75 | inst_dec_beams, len_dec_seq, src_seq, enc_output, inst_idx_to_position_map, n_bm):
76 | def prepare_beam_dec_seq(inst_dec_beams, len_dec_seq):
77 | dec_partial_seq = [b.get_current_state() for b in inst_dec_beams if not b.done]
78 | dec_partial_seq = torch.stack(dec_partial_seq).to(self.device)
79 | dec_partial_seq = dec_partial_seq.view(-1, len_dec_seq)
80 | return dec_partial_seq
81 |
82 | def prepare_beam_dec_pos(len_dec_seq, n_active_inst, n_bm):
83 | dec_partial_pos = torch.arange(1, len_dec_seq + 1, dtype=torch.long, device=self.device)
84 | dec_partial_pos = dec_partial_pos.unsqueeze(0).repeat(n_active_inst * n_bm, 1)
85 | return dec_partial_pos
86 |
87 | def predict_word(dec_seq, dec_pos, src_seq, enc_output, n_active_inst, n_bm):
88 | dec_output, *_ = self.model.decoder(dec_seq, dec_pos, src_seq, enc_output)
89 | dec_output = dec_output[:, -1, :] # Pick the last step: (bh * bm) * d_h
90 | dec_output = self.model.tgt_word_prj(dec_output)
91 | dec_output[:, 0] = -float('inf')
92 | dec_output[:, 1] = -float('inf')
93 | word_prob = F.log_softmax(dec_output, dim=1)
94 | word_prob = word_prob.view(n_active_inst, n_bm, -1)
95 |
96 | return word_prob
97 |
98 | def collect_active_inst_idx_list(inst_beams, word_prob, inst_idx_to_position_map):
99 | active_inst_idx_list = []
100 | for inst_idx, inst_position in inst_idx_to_position_map.items():
101 | is_inst_complete = inst_beams[inst_idx].advance(word_prob[inst_position])
102 | if not is_inst_complete:
103 | active_inst_idx_list += [inst_idx]
104 |
105 | return active_inst_idx_list
106 |
107 | n_active_inst = len(inst_idx_to_position_map)
108 |
109 | dec_seq = prepare_beam_dec_seq(inst_dec_beams, len_dec_seq)
110 | dec_pos = prepare_beam_dec_pos(len_dec_seq, n_active_inst, n_bm)
111 | word_prob = predict_word(dec_seq, dec_pos, src_seq, enc_output, n_active_inst, n_bm)
112 |
113 | # Update the beam with predicted word prob information and collect incomplete instances
114 | active_inst_idx_list = collect_active_inst_idx_list(
115 | inst_dec_beams, word_prob, inst_idx_to_position_map)
116 |
117 | return active_inst_idx_list
118 |
119 | def collect_hypothesis_and_scores(inst_dec_beams, n_best):
120 | all_hyp, all_scores = [], []
121 | for inst_idx in range(len(inst_dec_beams)):
122 | scores, tail_idxs = inst_dec_beams[inst_idx].sort_scores()
123 | all_scores += [scores[:n_best]]
124 |
125 | hyps = [inst_dec_beams[inst_idx].get_hypothesis(i) for i in tail_idxs[:n_best]]
126 | all_hyp += [hyps]
127 | return all_hyp, all_scores
128 |
129 | with torch.no_grad():
130 | #-- Encode
131 | src_seq, src_pos = src_seq.to(self.device), src_pos.to(self.device)
132 | src_enc, _ = self.model.encoder(src_seq, output_all_encoded_layers=False)
133 |
134 | #-- Repeat data for beam search
135 | n_bm = self.opt.beam_size
136 | n_inst, len_s, d_h = src_enc.size()
137 | src_seq = src_seq.repeat(1, n_bm).view(n_inst * n_bm, len_s)
138 | src_enc = src_enc.repeat(1, n_bm, 1).view(n_inst * n_bm, len_s, d_h)
139 |
140 | #-- Prepare beams
141 | inst_dec_beams = [Beam(n_bm, device=self.device) for _ in range(n_inst)]
142 |
143 | #-- Bookkeeping for active or not
144 | active_inst_idx_list = list(range(n_inst))
145 | inst_idx_to_position_map = get_inst_idx_to_tensor_position_map(active_inst_idx_list)
146 |
147 | #-- Decode
148 | for len_dec_seq in range(1, self.model_opt.max_token_seq_len + 1):
149 |
150 | active_inst_idx_list = beam_decode_step(
151 | inst_dec_beams, len_dec_seq, src_seq, src_enc, inst_idx_to_position_map, n_bm)
152 |
153 | if not active_inst_idx_list:
154 |                     break  # all instances have finished their path to <EOS>
155 |
156 | src_seq, src_enc, inst_idx_to_position_map = collate_active_info(
157 | src_seq, src_enc, inst_idx_to_position_map, active_inst_idx_list)
158 |
159 | batch_hyp, batch_scores = collect_hypothesis_and_scores(inst_dec_beams, self.opt.n_best)
160 |
161 | return batch_hyp, batch_scores
162 |
--------------------------------------------------------------------------------
/transformer/Models.py:
--------------------------------------------------------------------------------
1 | ''' Define the Transformer model '''
2 | import torch
3 | import torch.nn as nn
4 | import numpy as np
5 | import transformer.Constants as Constants
6 | from transformer.Layers import EncoderLayer, DecoderLayer
7 |
8 | __author__ = "Yu-Hsiang Huang"
9 |
10 | def get_non_pad_mask(seq):
11 | assert seq.dim() == 2
12 | return seq.ne(Constants.PAD).type(torch.float).unsqueeze(-1)
13 |
14 | def get_sinusoid_encoding_table(n_position, d_hid, padding_idx=None):
15 | ''' Sinusoid position encoding table '''
16 |
17 | def cal_angle(position, hid_idx):
18 | return position / np.power(10000, 2 * (hid_idx // 2) / d_hid)
19 |
20 | def get_posi_angle_vec(position):
21 | return [cal_angle(position, hid_j) for hid_j in range(d_hid)]
22 |
23 | sinusoid_table = np.array([get_posi_angle_vec(pos_i) for pos_i in range(n_position)])
24 |
25 | sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2]) # dim 2i
26 | sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2]) # dim 2i+1
27 |
28 | if padding_idx is not None:
29 | # zero vector for padding dimension
30 | sinusoid_table[padding_idx] = 0.
31 |
32 | return torch.FloatTensor(sinusoid_table)
33 |
34 | def get_attn_key_pad_mask(seq_k, seq_q):
35 | ''' For masking out the padding part of key sequence. '''
36 |
37 | # Expand to fit the shape of key query attention matrix.
38 | len_q = seq_q.size(1)
39 | padding_mask = seq_k.eq(Constants.PAD)
40 | padding_mask = padding_mask.unsqueeze(1).expand(-1, len_q, -1) # b x lq x lk
41 |
42 | return padding_mask
43 |
44 | def get_subsequent_mask(seq):
45 | ''' For masking out the subsequent info. '''
46 |
47 | sz_b, len_s = seq.size()
48 | subsequent_mask = torch.triu(
49 | torch.ones((len_s, len_s), device=seq.device, dtype=torch.uint8), diagonal=1)
50 | subsequent_mask = subsequent_mask.unsqueeze(0).expand(sz_b, -1, -1) # b x ls x ls
51 |
52 | return subsequent_mask
53 |
54 | class Encoder(nn.Module):
55 |     ''' An encoder model with self attention mechanism. '''
56 |
57 | def __init__(
58 | self,
59 | n_src_vocab, len_max_seq, d_word_vec,
60 | n_layers, n_head, d_k, d_v,
61 | d_model, d_inner, dropout=0.1):
62 |
63 | super().__init__()
64 |
65 | n_position = len_max_seq + 1
66 |
67 | self.src_word_emb = nn.Embedding(
68 | n_src_vocab, d_word_vec, padding_idx=Constants.PAD)
69 |
70 | self.position_enc = nn.Embedding.from_pretrained(
71 | get_sinusoid_encoding_table(n_position, d_word_vec, padding_idx=0),
72 | freeze=True)
73 |
74 | self.layer_stack = nn.ModuleList([
75 | EncoderLayer(d_model, d_inner, n_head, d_k, d_v, dropout=dropout)
76 | for _ in range(n_layers)])
77 |
78 | def forward(self, src_seq, src_pos, return_attns=False):
79 |
80 | enc_slf_attn_list = []
81 |
82 | # -- Prepare masks
83 | slf_attn_mask = get_attn_key_pad_mask(seq_k=src_seq, seq_q=src_seq)
84 | non_pad_mask = get_non_pad_mask(src_seq)
85 |
86 | # -- Forward
87 | enc_output = self.src_word_emb(src_seq) + self.position_enc(src_pos)
88 |
89 | for enc_layer in self.layer_stack:
90 | enc_output, enc_slf_attn = enc_layer(
91 | enc_output,
92 | non_pad_mask=non_pad_mask,
93 | slf_attn_mask=slf_attn_mask)
94 | if return_attns:
95 | enc_slf_attn_list += [enc_slf_attn]
96 |
97 | if return_attns:
98 | return enc_output, enc_slf_attn_list
99 | return enc_output,
100 |
101 | class Decoder(nn.Module):
102 | ''' A decoder model with self attention mechanism. '''
103 |
104 | def __init__(
105 | self,
106 | n_tgt_vocab, len_max_seq, d_word_vec,
107 | n_layers, n_head, d_k, d_v,
108 | d_model, d_inner, dropout=0.1):
109 |
110 | super().__init__()
111 | n_position = len_max_seq + 1
112 |
113 | self.tgt_word_emb = nn.Embedding(
114 | n_tgt_vocab, d_word_vec, padding_idx=Constants.PAD)
115 |
116 | self.position_enc = nn.Embedding.from_pretrained(
117 | get_sinusoid_encoding_table(n_position, d_word_vec, padding_idx=0),
118 | freeze=True)
119 |
120 | self.layer_stack = nn.ModuleList([
121 | DecoderLayer(d_model, d_inner, n_head, d_k, d_v, dropout=dropout)
122 | for _ in range(n_layers)])
123 |
124 | def forward(self, tgt_seq, tgt_pos, src_seq, enc_output, return_attns=False):
125 |
126 | dec_slf_attn_list, dec_enc_attn_list = [], []
127 |
128 | # -- Prepare masks
129 | non_pad_mask = get_non_pad_mask(tgt_seq)
130 |
131 | slf_attn_mask_subseq = get_subsequent_mask(tgt_seq)
132 | slf_attn_mask_keypad = get_attn_key_pad_mask(seq_k=tgt_seq, seq_q=tgt_seq)
133 | slf_attn_mask = (slf_attn_mask_keypad + slf_attn_mask_subseq).gt(0)
134 |
135 | dec_enc_attn_mask = get_attn_key_pad_mask(seq_k=src_seq, seq_q=tgt_seq)
136 |
137 | # -- Forward
138 | dec_output = self.tgt_word_emb(tgt_seq) + self.position_enc(tgt_pos)
139 |
140 | for dec_layer in self.layer_stack:
141 | dec_output, dec_slf_attn, dec_enc_attn = dec_layer(
142 | dec_output, enc_output,
143 | non_pad_mask=non_pad_mask,
144 | slf_attn_mask=slf_attn_mask,
145 | dec_enc_attn_mask=dec_enc_attn_mask)
146 |
147 | if return_attns:
148 | dec_slf_attn_list += [dec_slf_attn]
149 | dec_enc_attn_list += [dec_enc_attn]
150 |
151 | if return_attns:
152 | return dec_output, dec_slf_attn_list, dec_enc_attn_list
153 | return dec_output,
154 |
155 | class Transformer(nn.Module):
156 | ''' A sequence to sequence model with attention mechanism. '''
157 |
158 | def __init__(
159 | self,
160 | n_src_vocab, n_tgt_vocab, len_max_seq,
161 | d_word_vec=512, d_model=512, d_inner=2048,
162 | n_layers=6, n_head=8, d_k=64, d_v=64, dropout=0.1,
163 | tgt_emb_prj_weight_sharing=True,
164 | emb_src_tgt_weight_sharing=True):
165 |
166 | super().__init__()
167 |
168 | self.encoder = Encoder(
169 | n_src_vocab=n_src_vocab, len_max_seq=len_max_seq,
170 | d_word_vec=d_word_vec, d_model=d_model, d_inner=d_inner,
171 | n_layers=n_layers, n_head=n_head, d_k=d_k, d_v=d_v,
172 | dropout=dropout)
173 |
174 | self.decoder = Decoder(
175 | n_tgt_vocab=n_tgt_vocab, len_max_seq=len_max_seq,
176 | d_word_vec=d_word_vec, d_model=d_model, d_inner=d_inner,
177 | n_layers=n_layers, n_head=n_head, d_k=d_k, d_v=d_v,
178 | dropout=dropout)
179 |
180 | self.tgt_word_prj = nn.Linear(d_model, n_tgt_vocab, bias=False)
181 | nn.init.xavier_normal_(self.tgt_word_prj.weight)
182 |
183 | assert d_model == d_word_vec, \
184 | 'To facilitate the residual connections, \
185 | the dimensions of all module outputs shall be the same.'
186 |
187 | if tgt_emb_prj_weight_sharing:
188 | # Share the weight matrix between target word embedding & the final logit dense layer
189 | self.tgt_word_prj.weight = self.decoder.tgt_word_emb.weight
190 | self.x_logit_scale = (d_model ** -0.5)
191 | else:
192 | self.x_logit_scale = 1.
193 |
194 | if emb_src_tgt_weight_sharing:
195 | # Share the weight matrix between source & target word embeddings
196 | assert n_src_vocab == n_tgt_vocab, \
197 | "To share word embedding table, the vocabulary size of src/tgt shall be the same."
198 | self.encoder.src_word_emb.weight = self.decoder.tgt_word_emb.weight
199 |
200 | def forward(self, src_seq, src_pos, tgt_seq, tgt_pos):
201 |
202 | tgt_seq, tgt_pos = tgt_seq[:, :-1], tgt_pos[:, :-1]
203 |
204 | enc_output, *_ = self.encoder(src_seq, src_pos)
205 | dec_output, *_ = self.decoder(tgt_seq, tgt_pos, src_seq, enc_output)
206 | seq_logit = self.tgt_word_prj(dec_output) * self.x_logit_scale
207 |
208 | return seq_logit.view(-1, seq_logit.size(2))
209 |
--------------------------------------------------------------------------------
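
A short self-contained sketch (toy sequence, not from the repository) of the mask helpers defined at the top of Models.py, with PAD = 0 as in Constants.py:

```python
import torch
from transformer.Models import get_non_pad_mask, get_attn_key_pad_mask, get_subsequent_mask

seq = torch.tensor([[2, 5, 6, 3, 0, 0]])  # one sequence ending in two PAD tokens

print(get_non_pad_mask(seq).squeeze(-1))                  # [[1., 1., 1., 1., 0., 0.]]
print(get_attn_key_pad_mask(seq_k=seq, seq_q=seq)[0, 0])  # last two key positions (PAD) are masked
print(get_subsequent_mask(seq)[0])                        # upper-triangular mask hiding future positions
```
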
/train.py:
--------------------------------------------------------------------------------
1 | """
2 | This script handling the training process.
3 | """
4 |
5 | import argparse
6 | import math
7 | import time
8 |
9 | from tqdm import tqdm
10 | import torch
11 | import torch.nn.functional as F
12 | import torch.optim as optim
13 | import torch.utils.data
14 | import transformer.Constants as Constants
15 | from dataloader import TextSummarizationDataset, paired_collate_fn
16 | from transformer.Optim import ScheduledOptim
17 | import tensorboardX as tbx
18 | from models import AbstractiveTextSummarizationUsingBert
19 |
20 |
21 | def cal_performance(pred, x, gold, smoothing=False):
22 | ''' Apply label smoothing if needed '''
23 |
24 | loss = cal_loss(pred, x, gold, smoothing)
25 |
26 | pred = pred.max(1)[1]
27 | gold = gold.contiguous().view(-1)
28 | non_pad_mask = gold.ne(Constants.PAD)
29 | n_correct = pred.eq(gold)
30 | n_correct = n_correct.masked_select(non_pad_mask).sum().item()
31 |
32 | return loss, n_correct
33 |
34 |
35 | def cal_loss(pred, x, gold, smoothing):
36 | ''' Calculate cross entropy loss, apply label smoothing if needed. '''
37 |
38 | gold = gold.contiguous().view(-1)
39 |
40 | if smoothing:
41 | eps = 0.1
42 | n_class = pred.size(1)
43 |
44 | one_hot = torch.zeros_like(pred).scatter(1, gold.view(-1, 1), 1)
45 | one_hot = one_hot * (1 - eps) + (1 - one_hot) * eps / (n_class - 1)
46 | # x = x.repeat(1, int(a.size(0)/x.size(0))).view(-1, a.size(1))
47 | # prb = (1-p_gen)*F.softmax(pred, dim=1)
48 | # a = p_gen*F.softmax(a, dim=1)
49 | prb = F.softmax(pred, dim=1)
50 | # prb = prb.scatter_add(1, x, a)
51 | log_prb = torch.log(prb)
52 | non_pad_mask = gold.ne(Constants.PAD)
53 | loss = -(one_hot * log_prb).sum(dim=1)
54 | loss = loss.masked_select(non_pad_mask).sum() # average later
55 | else:
56 | loss = F.cross_entropy(pred, gold, ignore_index=Constants.PAD, reduction='sum')
57 |
58 | return loss
59 |
60 |
61 | def train_epoch(model, training_data, optimizer, device, smoothing, step):
62 | ''' Epoch operation in training phase'''
63 |
64 | model.train()
65 |
66 | total_loss = 0
67 | n_word_total = 0
68 | n_word_correct = 0
69 |
70 | for i, batch in enumerate(tqdm(
71 | training_data, mininterval=2,
72 | desc=' - (Training) ', leave=False, ascii=True)):
73 |
74 | # prepare data
75 | src_seq, src_pos, tgt_seq, tgt_pos = map(lambda x: x.to(device), batch)
76 | gold = tgt_seq[:, 1:]
77 |
78 | # forward
79 | optimizer.zero_grad()
80 | # pred, a, p_gen = model(src_seq, src_pos, tgt_seq, tgt_pos)
81 | pred = model(src_seq, src_pos, tgt_seq, tgt_pos)
82 |
83 | # backward
84 | loss, n_correct = cal_performance(pred, src_seq, gold, smoothing=smoothing)
85 | loss.backward()
86 |
87 | # update parameters
88 | optimizer.step_and_update_lr()
89 |
90 | # note keeping
91 | total_loss += loss.item()
92 |
93 | non_pad_mask = gold.ne(Constants.PAD)
94 | n_word = non_pad_mask.sum().item()
95 | n_word_total += n_word
96 | n_word_correct += n_correct
97 | writer.add_scalars('data/train_loss', {'train_loss_each_batch': loss.item() / n_word}, i+step)
98 | writer.add_scalars('data/train_accu', {'train_accu_each_batch': n_correct / n_word}, i+step)
99 |
100 | loss_per_word = total_loss/n_word_total
101 | accuracy = n_word_correct/n_word_total
102 | return loss_per_word, accuracy, step + i
103 |
104 |
105 | def eval_epoch(model, validation_data, device, step):
106 | ''' Epoch operation in evaluation phase '''
107 |
108 | model.eval()
109 |
110 | total_loss = 0
111 | n_word_total = 0
112 | n_word_correct = 0
113 |
114 | with torch.no_grad():
115 | for i, batch in enumerate(tqdm(
116 | validation_data, mininterval=2,
117 | desc=' - (Validation) ', leave=False, ascii=True)):
118 |
119 | # prepare data
120 | src_seq, src_pos, tgt_seq, tgt_pos = map(lambda x: x.to(device), batch)
121 | gold = tgt_seq[:, 1:]
122 |
123 | # forward
124 | pred = model(src_seq, src_pos, tgt_seq, tgt_pos)
125 | loss, n_correct = cal_performance(pred, src_seq, gold, smoothing=False)
126 |
127 | # note keeping
128 | total_loss += loss.item()
129 |
130 | non_pad_mask = gold.ne(Constants.PAD)
131 | n_word = non_pad_mask.sum().item()
132 | n_word_total += n_word
133 | n_word_correct += n_correct
134 | writer.add_scalars('data/valid_loss', {'valid_loss_each_batch': loss.item() / n_word}, i+step)
135 | writer.add_scalars('data/valid_accu', {'valid_accu_each_batch': n_correct / n_word}, i+step)
136 |
137 | loss_per_word = total_loss/n_word_total
138 | accuracy = n_word_correct/n_word_total
139 | return loss_per_word, accuracy, step + i
140 |
141 |
142 | def train(model, training_data, validation_data, optimizer, device, opt):
143 | ''' Start training '''
144 |
145 | log_train_file = None
146 | log_valid_file = None
147 |
148 | if opt.log:
149 | log_train_file = opt.log + '.train.log'
150 | log_valid_file = opt.log + '.valid.log'
151 |
152 | print('[Info] Training performance will be written to file: {} and {}'.format(
153 | log_train_file, log_valid_file))
154 |
155 | with open(log_train_file, 'w') as log_tf, open(log_valid_file, 'w') as log_vf:
156 | log_tf.write('epoch,loss,ppl,accuracy\n')
157 | log_vf.write('epoch,loss,ppl,accuracy\n')
158 |
159 | valid_accus = []
160 | train_step = 0
161 | valid_step = 0
162 | for epoch_i in range(opt.epoch):
163 | print('[ Epoch', epoch_i, ']')
164 |
165 | start = time.time()
166 | train_loss, train_accu, train_step = train_epoch(
167 | model, training_data, optimizer, device, opt.label_smoothing, train_step)
168 | print(' - (Training) ppl: {ppl: 8.5f}, accuracy: {accu:3.3f} %, ' \
169 | 'elapse: {elapse:3.3f} min'.format(
170 | ppl=math.exp(min(train_loss, 100)), accu=100*train_accu,
171 | elapse=(time.time()-start)/60))
172 |
173 | start = time.time()
174 | valid_loss, valid_accu, valid_step = eval_epoch(model, validation_data, device, valid_step)
175 | print(' - (Validation) ppl: {ppl: 8.5f}, accuracy: {accu:3.3f} %, ' \
176 | 'elapse: {elapse:3.3f} min'.format(
177 | ppl=math.exp(min(valid_loss, 100)), accu=100*valid_accu,
178 | elapse=(time.time()-start)/60))
179 |
180 | valid_accus += [valid_accu]
181 |
182 | model_state_dict = model.state_dict()
183 | checkpoint = {
184 | 'model': model_state_dict,
185 | 'settings': opt,
186 | 'epoch': epoch_i}
187 |
188 | if opt.save_model:
189 | if opt.save_mode == 'all':
190 | model_name = 'data/checkpoint/trained/' + opt.save_model + '_accu_{accu:3.3f}.chkpt'.format(accu=100*valid_accu)
191 | torch.save(checkpoint, model_name)
192 | elif opt.save_mode == 'best':
193 | model_name = 'data/checkpoint/trained/' + opt.save_model + '.chkpt'
194 | if valid_accu >= max(valid_accus):
195 | torch.save(checkpoint, model_name)
196 | print(' - [Info] The checkpoint file has been updated.')
197 |
198 | if log_train_file and log_valid_file:
199 | with open(log_train_file, 'a') as log_tf, open(log_valid_file, 'a') as log_vf:
200 | log_tf.write('{epoch},{loss: 8.5f},{ppl: 8.5f},{accu:3.3f}\n'.format(
201 | epoch=epoch_i, loss=train_loss,
202 | ppl=math.exp(min(train_loss, 100)), accu=100*train_accu))
203 | log_vf.write('{epoch},{loss: 8.5f},{ppl: 8.5f},{accu:3.3f}\n'.format(
204 | epoch=epoch_i, loss=valid_loss,
205 | ppl=math.exp(min(valid_loss, 100)), accu=100*valid_accu))
206 |
207 | def main():
208 | ''' Main function '''
209 | parser = argparse.ArgumentParser()
210 |
211 | parser.add_argument('-data', required=True)
212 | parser.add_argument('-bert_path', required=True)
213 |
214 | parser.add_argument('-epoch', type=int, default=10)
215 | parser.add_argument('-batch_size', type=int, default=64)
216 |
217 | parser.add_argument('-d_model', type=int, default=768)
218 | parser.add_argument('-d_inner_hid', type=int, default=3072)
219 | parser.add_argument('-d_k', type=int, default=64)
220 | parser.add_argument('-d_v', type=int, default=64)
221 |
222 | parser.add_argument('-n_head', type=int, default=12)
223 | parser.add_argument('-n_layers', type=int, default=8)
224 | parser.add_argument('-n_warmup_steps', type=int, default=4000)
225 |
226 | parser.add_argument('-dropout', type=float, default=0.1)
227 | parser.add_argument('-embs_share_weight', action='store_true')
228 | parser.add_argument('-proj_share_weight', action='store_true')
229 |
230 | parser.add_argument('-log', default=None)
231 | parser.add_argument('-save_model', default=None)
232 | parser.add_argument('-save_mode', type=str, choices=['all', 'best'], default='best')
233 |
234 | parser.add_argument('-no_cuda', action='store_true')
235 | parser.add_argument('-label_smoothing', action='store_true')
236 |
237 | opt = parser.parse_args()
238 | opt.cuda = not opt.no_cuda
239 | opt.d_word_vec = opt.d_model
240 |
241 | #========= Loading Dataset =========#
242 | data = torch.load(opt.data)
243 | opt.max_token_seq_len = data['settings'].max_token_seq_len
244 |
245 | training_data, validation_data = prepare_dataloaders(data, opt)
246 |
247 | opt.src_vocab_size = training_data.dataset.src_vocab_size
248 | opt.tgt_vocab_size = training_data.dataset.tgt_vocab_size
249 |
250 | #========= Preparing Model =========#
251 | if opt.embs_share_weight:
252 | assert training_data.dataset.src_word2idx == training_data.dataset.tgt_word2idx, \
253 | 'The src/tgt word2idx table are different but asked to share word embedding.'
254 |
255 | print(opt)
256 |
257 | device = torch.device('cuda' if opt.cuda else 'cpu')
258 | model = AbstractiveTextSummarizationUsingBert(
259 | opt.bert_path,
260 | opt.tgt_vocab_size,
261 | opt.max_token_seq_len,
262 | d_k=opt.d_k,
263 | d_v=opt.d_v,
264 | d_model=opt.d_model,
265 | d_word_vec=opt.d_word_vec,
266 | d_inner=opt.d_inner_hid,
267 | n_layers=opt.n_layers,
268 | n_head=opt.n_head,
269 | dropout=opt.dropout).to(device)
270 |
271 | optimizer = ScheduledOptim(
272 | optim.Adam(
273 | filter(lambda x: x.requires_grad, model.parameters()),
274 | betas=(0.9, 0.999), eps=1e-09),
275 | opt.d_model, opt.n_warmup_steps)
276 |
277 | train(model, training_data, validation_data, optimizer, device, opt)
278 |
279 |
280 | def prepare_dataloaders(data, opt):
281 | # ========= Preparing DataLoader =========#
282 | train_loader = torch.utils.data.DataLoader(
283 | TextSummarizationDataset(
284 | src_word2idx=data['dict']['src'],
285 | tgt_word2idx=data['dict']['tgt'],
286 | src_insts=data['train']['src'],
287 | tgt_insts=data['train']['tgt']),
288 | num_workers=2,
289 | batch_size=opt.batch_size,
290 | collate_fn=paired_collate_fn,
291 | shuffle=True)
292 |
293 | valid_loader = torch.utils.data.DataLoader(
294 | TextSummarizationDataset(
295 | src_word2idx=data['dict']['src'],
296 | tgt_word2idx=data['dict']['tgt'],
297 | src_insts=data['valid']['src'],
298 | tgt_insts=data['valid']['tgt']),
299 | num_workers=2,
300 | batch_size=opt.batch_size,
301 | collate_fn=paired_collate_fn)
302 | return train_loader, valid_loader
303 |
304 |
305 | if __name__ == '__main__':
306 | writer = tbx.SummaryWriter(log_dir='data/tensorboardx/runs')
307 | main()
308 | writer.export_scalars_to_json("data/all_scalars.json")
309 | writer.close()
310 |
--------------------------------------------------------------------------------
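For reference, one possible way to launch this script; the data and BERT paths below are placeholders (no files with these names ship in the repository), and the checkpoint directory hard-coded in train() must exist before the run:

    mkdir -p data/checkpoint/trained
    python train.py \
        -data data/preprocessed.pt \
        -bert_path /path/to/japanese_bert \
        -epoch 10 -batch_size 64 \
        -save_model summarizer -save_mode best \
        -log data/train -label_smoothing

The -data file is torch.load()ed and is expected to hold 'settings', 'dict', 'train' and 'valid' entries, while -bert_path is handed straight to AbstractiveTextSummarizationUsingBert, so it should point at the pretrained BERT weights that model class expects.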
/tokenizer.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The Google AI Language Team Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | """Tokenization classes."""
16 |
17 | from __future__ import absolute_import
18 | from __future__ import division
19 | from __future__ import print_function
20 |
21 | import collections
22 | import re
23 | import unicodedata
24 | import six
25 | import tensorflow as tf
26 |
27 |
28 | def validate_case_matches_checkpoint(do_lower_case, init_checkpoint):
29 | """Checks whether the casing config is consistent with the checkpoint name."""
30 |
31 | # The casing has to be passed in by the user and there is no explicit check
32 | # as to whether it matches the checkpoint. The casing information probably
33 | # should have been stored in the bert_config.json file, but it's not, so
34 | # we have to heuristically detect it to validate.
35 |
36 | if not init_checkpoint:
37 | return
38 |
39 | m = re.match("^.*?([A-Za-z0-9_-]+)/bert_model.ckpt", init_checkpoint)
40 | if m is None:
41 | return
42 |
43 | model_name = m.group(1)
44 |
45 | lower_models = [
46 | "uncased_L-24_H-1024_A-16", "uncased_L-12_H-768_A-12",
47 | "multilingual_L-12_H-768_A-12", "chinese_L-12_H-768_A-12"
48 | ]
49 |
50 | cased_models = [
51 | "cased_L-12_H-768_A-12", "cased_L-24_H-1024_A-16",
52 | "multi_cased_L-12_H-768_A-12"
53 | ]
54 |
55 | is_bad_config = False
56 | if model_name in lower_models and not do_lower_case:
57 | is_bad_config = True
58 | actual_flag = "False"
59 | case_name = "lowercased"
60 | opposite_flag = "True"
61 |
62 | if model_name in cased_models and do_lower_case:
63 | is_bad_config = True
64 | actual_flag = "True"
65 | case_name = "cased"
66 | opposite_flag = "False"
67 |
68 | if is_bad_config:
69 | raise ValueError(
70 | "You passed in `--do_lower_case=%s` with `--init_checkpoint=%s`. "
71 | "However, `%s` seems to be a %s model, so you "
72 | "should pass in `--do_lower_case=%s` so that the fine-tuning matches "
73 | "how the model was pre-training. If this error is wrong, please "
74 | "just comment out this check." % (actual_flag, init_checkpoint,
75 | model_name, case_name, opposite_flag))
76 |
77 |
78 | def convert_to_unicode(text):
79 | """Converts `text` to Unicode (if it's not already), assuming utf-8 input."""
80 | if six.PY3:
81 | if isinstance(text, str):
82 | return text
83 | elif isinstance(text, bytes):
84 | return text.decode("utf-8", "ignore")
85 | else:
86 | raise ValueError("Unsupported string type: %s" % (type(text)))
87 | elif six.PY2:
88 | if isinstance(text, str):
89 | return text.decode("utf-8", "ignore")
90 | elif isinstance(text, unicode):
91 | return text
92 | else:
93 | raise ValueError("Unsupported string type: %s" % (type(text)))
94 | else:
95 | raise ValueError("Not running on Python 2 or Python 3?")
96 |
97 |
98 | def printable_text(text):
99 | """Returns text encoded in a way suitable for print or `tf.logging`."""
100 |
101 | # These functions want `str` for both Python2 and Python3, but in one case
102 | # it's a Unicode string and in the other it's a byte string.
103 | if six.PY3:
104 | if isinstance(text, str):
105 | return text
106 | elif isinstance(text, bytes):
107 | return text.decode("utf-8", "ignore")
108 | else:
109 | raise ValueError("Unsupported string type: %s" % (type(text)))
110 | elif six.PY2:
111 | if isinstance(text, str):
112 | return text
113 | elif isinstance(text, unicode):
114 | return text.encode("utf-8")
115 | else:
116 | raise ValueError("Unsupported string type: %s" % (type(text)))
117 | else:
118 | raise ValueError("Not running on Python 2 or Python 3?")
119 |
120 |
121 | def load_vocab(vocab_file):
122 | """Loads a vocabulary file into a dictionary."""
123 | vocab = collections.OrderedDict()
124 | index = 0
125 | with tf.io.gfile.GFile(vocab_file, "r") as reader:
126 | while True:
127 | token = convert_to_unicode(reader.readline())
128 | if not token:
129 | break
130 | token = token.strip()
131 | vocab[token] = index
132 | index += 1
133 | return vocab
134 |
135 |
136 | def convert_by_vocab(vocab, items):
137 | """Converts a sequence of [tokens|ids] using the vocab."""
138 | output = []
139 | for item in items:
140 | output.append(vocab[item])
141 | return output
142 |
143 |
144 | def convert_tokens_to_ids(vocab, tokens):
145 | return convert_by_vocab(vocab, tokens)
146 |
147 |
148 | def convert_ids_to_tokens(inv_vocab, ids):
149 | return convert_by_vocab(inv_vocab, ids)
150 |
151 |
152 | def whitespace_tokenize(text):
153 | """Runs basic whitespace cleaning and splitting on a piece of text."""
154 | text = text.strip()
155 | if not text:
156 | return []
157 | tokens = text.split()
158 | return tokens
159 |
160 |
161 | class FullTokenizer(object):
162 | """Runs end-to-end tokenziation."""
163 |
164 | def __init__(self, vocab_file, do_lower_case=True):
165 | self.vocab = load_vocab(vocab_file)
166 | self.inv_vocab = {v: k for k, v in self.vocab.items()}
167 | #self.tokenizer = BasicTokenizer(do_lower_case=do_lower_case)
168 | #self.tokenizer = JumanTokenizer()
169 | self.tokenizer = MeCabTokenizer()
170 | self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab)
171 |
172 | def tokenize(self, text):
173 | split_tokens = []
174 | for token in self.tokenizer.tokenize(text):
175 | for sub_token in self.wordpiece_tokenizer.tokenize(token):
176 | split_tokens.append(sub_token)
177 |
178 | return split_tokens
179 |
180 | def convert_tokens_to_ids(self, tokens):
181 | return convert_by_vocab(self.vocab, tokens)
182 |
183 | def convert_ids_to_tokens(self, ids):
184 | return convert_by_vocab(self.inv_vocab, ids)
185 |
186 |
187 | class BasicTokenizer(object):
188 | """Runs basic tokenization (punctuation splitting, lower casing, etc.)."""
189 |
190 | def __init__(self, do_lower_case=True):
191 | """Constructs a BasicTokenizer.
192 |
193 | Args:
194 | do_lower_case: Whether to lower case the input.
195 | """
196 | self.do_lower_case = do_lower_case
197 |
198 | def tokenize(self, text):
199 | """Tokenizes a piece of text."""
200 | text = convert_to_unicode(text)
201 | text = self._clean_text(text)
202 |
203 | # This was added on November 1st, 2018 for the multilingual and Chinese
204 | # models. This is also applied to the English models now, but it doesn't
205 | # matter since the English models were not trained on any Chinese data
206 | # and generally don't have any Chinese data in them (there are Chinese
207 | # characters in the vocabulary because Wikipedia does have some Chinese
208 | # words in the English Wikipedia.).
209 | text = self._tokenize_chinese_chars(text)
210 |
211 | orig_tokens = whitespace_tokenize(text)
212 | split_tokens = []
213 | for token in orig_tokens:
214 | if self.do_lower_case:
215 | token = token.lower()
216 | token = self._run_strip_accents(token)
217 | split_tokens.extend(self._run_split_on_punc(token))
218 |
219 | output_tokens = whitespace_tokenize(" ".join(split_tokens))
220 | return output_tokens
221 |
222 | def _run_strip_accents(self, text):
223 | """Strips accents from a piece of text."""
224 | text = unicodedata.normalize("NFD", text)
225 | output = []
226 | for char in text:
227 | cat = unicodedata.category(char)
228 | if cat == "Mn":
229 | continue
230 | output.append(char)
231 | return "".join(output)
232 |
233 | def _run_split_on_punc(self, text):
234 | """Splits punctuation on a piece of text."""
235 | chars = list(text)
236 | i = 0
237 | start_new_word = True
238 | output = []
239 | while i < len(chars):
240 | char = chars[i]
241 | if _is_punctuation(char):
242 | output.append([char])
243 | start_new_word = True
244 | else:
245 | if start_new_word:
246 | output.append([])
247 | start_new_word = False
248 | output[-1].append(char)
249 | i += 1
250 |
251 | return ["".join(x) for x in output]
252 |
253 | def _tokenize_chinese_chars(self, text):
254 | """Adds whitespace around any CJK character."""
255 | output = []
256 | for char in text:
257 | cp = ord(char)
258 | if self._is_chinese_char(cp):
259 | output.append(" ")
260 | output.append(char)
261 | output.append(" ")
262 | else:
263 | output.append(char)
264 | return "".join(output)
265 |
266 | def _is_chinese_char(self, cp):
267 | """Checks whether CP is the codepoint of a CJK character."""
268 | # This defines a "chinese character" as anything in the CJK Unicode block:
269 | # https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)
270 | #
271 | # Note that the CJK Unicode block is NOT all Japanese and Korean characters,
272 | # despite its name. The modern Korean Hangul alphabet is a different block,
273 | # as is Japanese Hiragana and Katakana. Those alphabets are used to write
274 | # space-separated words, so they are not treated specially and handled
275 | # like all of the other languages.
276 | if ((cp >= 0x4E00 and cp <= 0x9FFF) or #
277 | (cp >= 0x3400 and cp <= 0x4DBF) or #
278 | (cp >= 0x20000 and cp <= 0x2A6DF) or #
279 | (cp >= 0x2A700 and cp <= 0x2B73F) or #
280 | (cp >= 0x2B740 and cp <= 0x2B81F) or #
281 | (cp >= 0x2B820 and cp <= 0x2CEAF) or
282 | (cp >= 0xF900 and cp <= 0xFAFF) or #
283 | (cp >= 0x2F800 and cp <= 0x2FA1F)): #
284 | return True
285 |
286 | return False
287 |
288 | def _clean_text(self, text):
289 | """Performs invalid character removal and whitespace cleanup on text."""
290 | output = []
291 | for char in text:
292 | cp = ord(char)
293 | if cp == 0 or cp == 0xfffd or _is_control(char):
294 | continue
295 | if _is_whitespace(char):
296 | output.append(" ")
297 | else:
298 | output.append(char)
299 | return "".join(output)
300 |
301 |
302 | class JumanTokenizer(BasicTokenizer):
303 | def __init__(self):
304 | from pyknp import Juman
305 |
306 | self.do_lower_case = False
307 | self._jumanpp = Juman()
308 |
309 | def tokenize(self, text):
310 | """Tokenizes a piece of text with Juman."""
311 |
312 | text = convert_to_unicode(text)
313 | text = self._clean_text(text)
314 |
315 |
316 | juman_result = self._jumanpp.analysis(text.replace(' ', ''))
317 | split_tokens = []
318 | for mrph in juman_result.mrph_list():
319 | split_tokens.extend(self._run_split_on_punc(mrph.midasi))
320 |
321 | output_tokens = whitespace_tokenize(" ".join(split_tokens))
322 | return output_tokens
323 |
324 |
325 | class MeCabTokenizer(BasicTokenizer):
326 | def __init__(self):
327 | import MeCab
328 |
329 | self.do_lower_case = False
330 | self._mecab = MeCab.Tagger('-Owakati')
331 |
332 | def tokenize(self, text):
333 | """Tokenizes a piece of text with Juman."""
334 |
335 | text = convert_to_unicode(text)
336 | text = self._clean_text(text)
337 |
338 | mecab_result = self._mecab.parse(text)
339 | output_tokens = mecab_result.strip().split(' ')  # strip the trailing whitespace/newline of MeCab's wakati output before splitting
340 | return output_tokens
341 |
342 |
343 |
344 | class WordpieceTokenizer(object):
345 | """Runs WordPiece tokenziation."""
346 |
347 | def __init__(self, vocab, unk_token="[UNK]", max_input_chars_per_word=200):
348 | self.vocab = vocab
349 | self.unk_token = unk_token
350 | self.max_input_chars_per_word = max_input_chars_per_word
351 |
352 | def tokenize(self, text):
353 | """Tokenizes a piece of text into its word pieces.
354 |
355 | This uses a greedy longest-match-first algorithm to perform tokenization
356 | using the given vocabulary.
357 |
358 | For example:
359 | input = "unaffable"
360 | output = ["un", "##aff", "##able"]
361 |
362 | Args:
363 | text: A single token or whitespace separated tokens. This should have
364 | already been passed through `BasicTokenizer`.
365 |
366 | Returns:
367 | A list of wordpiece tokens.
368 | """
369 |
370 | text = convert_to_unicode(text)
371 |
372 | output_tokens = []
373 | for token in whitespace_tokenize(text):
374 | chars = list(token)
375 | if len(chars) > self.max_input_chars_per_word:
376 | output_tokens.append(self.unk_token)
377 | continue
378 |
379 | is_bad = False
380 | start = 0
381 | sub_tokens = []
382 | while start < len(chars):
383 | end = len(chars)
384 | cur_substr = None
385 | while start < end:
386 | substr = "".join(chars[start:end])
387 | if start > 0:
388 | substr = "##" + substr
389 | if substr in self.vocab:
390 | cur_substr = substr
391 | break
392 | end -= 1
393 | if cur_substr is None:
394 | is_bad = True
395 | break
396 | sub_tokens.append(cur_substr)
397 | start = end
398 |
399 | if is_bad:
400 | output_tokens.append(self.unk_token)
401 | else:
402 | output_tokens.extend(sub_tokens)
403 | return output_tokens
404 |
405 |
406 | def _is_whitespace(char):
407 | """Checks whether `chars` is a whitespace character."""
408 | # \t, \n, and \r are technically contorl characters but we treat them
409 | # as whitespace since they are generally considered as such.
410 | if char == " " or char == "\t" or char == "\n" or char == "\r":
411 | return True
412 | cat = unicodedata.category(char)
413 | if cat == "Zs":
414 | return True
415 | return False
416 |
417 |
418 | def _is_control(char):
419 | """Checks whether `chars` is a control character."""
420 | # These are technically control characters but we count them as whitespace
421 | # characters.
422 | if char == "\t" or char == "\n" or char == "\r":
423 | return False
424 | cat = unicodedata.category(char)
425 | if cat in ("Cc", "Cf"):
426 | return True
427 | return False
428 |
429 |
430 | def _is_punctuation(char):
431 | """Checks whether `chars` is a punctuation character."""
432 | cp = ord(char)
433 | # We treat all non-letter/number ASCII as punctuation.
434 | # Characters such as "^", "$", and "`" are not in the Unicode
435 | # Punctuation class but we treat them as punctuation anyways, for
436 | # consistency.
437 | if ((cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64) or
438 | (cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126)):
439 | return True
440 | cat = unicodedata.category(char)
441 | if cat.startswith("P"):
442 | return True
443 | return False
444 |
--------------------------------------------------------------------------------
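A minimal sketch of using this tokenizer end to end, assuming MeCab (mecab-python3 plus a dictionary), tensorflow, and a BERT WordPiece vocabulary file are available; 'vocab.txt' and the sample sentence are placeholders, not files or data from the repository:

    from tokenizer import FullTokenizer

    # 'vocab.txt' is a placeholder path to the WordPiece vocabulary shipped with the pretrained BERT model.
    tok = FullTokenizer('vocab.txt')               # do_lower_case is accepted but unused: MeCabTokenizer is hard-coded
    tokens = tok.tokenize('今日はいい天気です。')   # MeCab word split, then greedy longest-match WordPiece split
    ids = tok.convert_tokens_to_ids(tokens)        # plain vocab lookup; unmatched pieces arrive as '[UNK]' from WordPiece
    print(tokens)
    print(ids)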