├── data ├── __init__.py ├── data.py └── batcher.py ├── model ├── __init__.py ├── attention.py ├── TransformerEncoder.py ├── util.py ├── rnn.py ├── DeepLSTM.py └── extract.py ├── transformer ├── Layers.pyc ├── Constants.pyc ├── Modules.pyc ├── SubLayers.pyc ├── __init__.pyc ├── Constants.py ├── __init__.py ├── Modules.py ├── Optim.py ├── Layers.py ├── SubLayers.py ├── Beam.py ├── Translator.py └── Models.py ├── evaluate.py ├── utils.py ├── decoding.py ├── metric.py ├── README.md ├── training.py └── main.py /data/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /model/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /transformer/Layers.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maszhongming/Effective_Extractive_Summarization/HEAD/transformer/Layers.pyc -------------------------------------------------------------------------------- /transformer/Constants.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maszhongming/Effective_Extractive_Summarization/HEAD/transformer/Constants.pyc -------------------------------------------------------------------------------- /transformer/Modules.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maszhongming/Effective_Extractive_Summarization/HEAD/transformer/Modules.pyc -------------------------------------------------------------------------------- /transformer/SubLayers.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/maszhongming/Effective_Extractive_Summarization/HEAD/transformer/SubLayers.pyc -------------------------------------------------------------------------------- /transformer/__init__.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maszhongming/Effective_Extractive_Summarization/HEAD/transformer/__init__.pyc -------------------------------------------------------------------------------- /transformer/Constants.py: -------------------------------------------------------------------------------- 1 | 2 | PAD = 0 3 | UNK = 1 4 | BOS = 2 5 | EOS = 3 6 | 7 | PAD_WORD = '' 8 | UNK_WORD = '' 9 | BOS_WORD = '' 10 | EOS_WORD = '' 11 | -------------------------------------------------------------------------------- /transformer/__init__.py: -------------------------------------------------------------------------------- 1 | import transformer.Constants 2 | import transformer.Modules 3 | import transformer.Layers 4 | import transformer.SubLayers 5 | import transformer.Models 6 | import transformer.Translator 7 | import transformer.Beam 8 | import transformer.Optim 9 | 10 | __all__ = [ 11 | transformer.Constants, transformer.Modules, transformer.Layers, 12 | transformer.SubLayers, transformer.Models, transformer.Optim, 13 | transformer.Translator, transformer.Beam] 14 | -------------------------------------------------------------------------------- /transformer/Modules.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import numpy as np 4 | 5 | __author__ = "Yu-Hsiang Huang" 6 | 7 | class ScaledDotProductAttention(nn.Module): 8 | ''' Scaled Dot-Product Attention ''' 9 | 10 | def __init__(self, temperature, attn_dropout=0.1): 11 | super().__init__() 12 | self.temperature = temperature 13 | self.dropout = nn.Dropout(attn_dropout) 14 | self.softmax = nn.Softmax(dim=2) 15 | 16 | def forward(self, 
q, k, v, mask=None): 17 | 18 | attn = torch.bmm(q, k.transpose(1, 2)) 19 | attn = attn / self.temperature 20 | 21 | if mask is not None: 22 | attn = attn.masked_fill(mask, -np.inf) 23 | 24 | attn = self.softmax(attn) 25 | attn = self.dropout(attn) 26 | output = torch.bmm(attn, v) 27 | 28 | return output, attn 29 | -------------------------------------------------------------------------------- /data/data.py: -------------------------------------------------------------------------------- 1 | """ CNN/DM dataset""" 2 | import json 3 | import re 4 | import os 5 | from os.path import join 6 | 7 | from torch.utils.data import Dataset 8 | 9 | 10 | class CnnDmDataset(Dataset): 11 | def __init__(self, split: str, path: str) -> None: 12 | assert split in ['train', 'val', 'test'] 13 | self._data_path = join(path, split) 14 | self._n_data = _count_data(self._data_path) 15 | 16 | def __len__(self) -> int: 17 | return self._n_data 18 | 19 | def __getitem__(self, i: int): 20 | with open(join(self._data_path, '{}.json'.format(i))) as f: 21 | js = json.loads(f.read()) 22 | return js 23 | 24 | 25 | def _count_data(path): 26 | """ count number of data in the given path""" 27 | matcher = re.compile(r'[0-9]+\.json') 28 | match = lambda name: bool(matcher.match(name)) 29 | names = os.listdir(path) 30 | n_data = len(list(filter(match, names))) 31 | return n_data 32 | -------------------------------------------------------------------------------- /model/attention.py: -------------------------------------------------------------------------------- 1 | """ attention functions """ 2 | from torch.nn import functional as F 3 | 4 | 5 | def dot_attention_score(key, query): 6 | """[B, Tk, D], [(Bs), B, Tq, D] -> [(Bs), B, Tq, Tk]""" 7 | return query.matmul(key.transpose(1, 2)) 8 | 9 | def prob_normalize(score, mask): 10 | """ [(...), T] 11 | user should handle mask shape""" 12 | score = score.masked_fill(mask == 0, -1e18) 13 | norm_score = F.softmax(score, dim=-1) 14 | return norm_score 15 | 16 
| def attention_aggregate(value, score): 17 | """[B, Tv, D], [(Bs), B, Tq, Tv] -> [(Bs), B, Tq, D]""" 18 | output = score.matmul(value) 19 | return output 20 | 21 | 22 | def step_attention(query, key, value, mem_mask=None): 23 | """ query[(Bs), B, D], key[B, T, D], value[B, T, D]""" 24 | score = dot_attention_score(key, query.unsqueeze(-2)) 25 | if mem_mask is None: 26 | norm_score = F.softmax(score, dim=-1) 27 | else: 28 | norm_score = prob_normalize(score, mem_mask) 29 | output = attention_aggregate(value, norm_score) 30 | return output.squeeze(-2), norm_score.squeeze(-2) 31 | -------------------------------------------------------------------------------- /transformer/Optim.py: -------------------------------------------------------------------------------- 1 | '''A wrapper class for optimizer ''' 2 | import numpy as np 3 | 4 | class ScheduledOptim(): 5 | '''A simple wrapper class for learning rate scheduling''' 6 | 7 | def __init__(self, optimizer, d_model, n_warmup_steps): 8 | self._optimizer = optimizer 9 | self.n_warmup_steps = n_warmup_steps 10 | self.n_current_steps = 0 11 | self.init_lr = np.power(d_model, -0.5) 12 | 13 | def step_and_update_lr(self): 14 | "Step with the inner optimizer" 15 | self._update_learning_rate() 16 | self._optimizer.step() 17 | 18 | def zero_grad(self): 19 | "Zero out the gradients by the inner optimizer" 20 | self._optimizer.zero_grad() 21 | 22 | def _get_lr_scale(self): 23 | return np.min([ 24 | np.power(self.n_current_steps, -0.5), 25 | np.power(self.n_warmup_steps, -1.5) * self.n_current_steps]) 26 | 27 | def _update_learning_rate(self): 28 | ''' Learning rate scheduling per step ''' 29 | 30 | self.n_current_steps += 1 31 | lr = self.init_lr * self._get_lr_scale() 32 | 33 | for param_group in self._optimizer.param_groups: 34 | param_group['lr'] = lr 35 | 36 | -------------------------------------------------------------------------------- /evaluate.py: 
-------------------------------------------------------------------------------- 1 | """ evaluation scripts""" 2 | import re 3 | import os 4 | from os.path import join 5 | import logging 6 | import tempfile 7 | import subprocess as sp 8 | 9 | from cytoolz import curry 10 | 11 | from pyrouge import Rouge155 12 | from pyrouge.utils import log 13 | 14 | _ROUGE_PATH = '/path/to/RELEASE-1.5.5' 15 | 16 | def eval_rouge(dec_dir, ref_dir): 17 | """ evaluate by original Perl implementation""" 18 | # silence pyrouge logging 19 | assert _ROUGE_PATH is not None 20 | log.get_global_console_logger().setLevel(logging.WARNING) 21 | dec_pattern = '(\d+).dec' 22 | ref_pattern = '#ID#.ref' 23 | cmd = '-c 95 -r 1000 -n 2 -m' 24 | with tempfile.TemporaryDirectory() as tmp_dir: 25 | Rouge155.convert_summaries_to_rouge_format( 26 | dec_dir, join(tmp_dir, 'dec')) 27 | Rouge155.convert_summaries_to_rouge_format( 28 | ref_dir, join(tmp_dir, 'ref')) 29 | Rouge155.write_config_static( 30 | join(tmp_dir, 'dec'), dec_pattern, 31 | join(tmp_dir, 'ref'), ref_pattern, 32 | join(tmp_dir, 'settings.xml'), system_id=1 33 | ) 34 | cmd = (join(_ROUGE_PATH, 'ROUGE-1.5.5.pl') 35 | + ' -e {} '.format(join(_ROUGE_PATH, 'data')) 36 | + cmd 37 | + ' -a {}'.format(join(tmp_dir, 'settings.xml'))) 38 | output = sp.check_output(cmd.split(' '), universal_newlines=True) 39 | return output 40 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | """ utility functions""" 2 | import re 3 | import os 4 | from os.path import basename 5 | 6 | import gensim 7 | import torch 8 | from torch import nn 9 | 10 | 11 | def count_data(path): 12 | """ count number of data in the given path""" 13 | matcher = re.compile(r'[0-9]+\.json') 14 | match = lambda name: bool(matcher.match(name)) 15 | names = os.listdir(path) 16 | n_data = len(list(filter(match, names))) 17 | return n_data 18 | 19 | 20 | PAD = 0 21 | 
UNK = 1 22 | START = 2 23 | END = 3 24 | def make_vocab(wc, vocab_size): 25 | word2id, id2word = {}, {} 26 | word2id[''] = PAD 27 | word2id[''] = UNK 28 | word2id[''] = START 29 | word2id[''] = END 30 | for i, (w, _) in enumerate(wc.most_common(vocab_size), 4): 31 | word2id[w] = i 32 | return word2id 33 | 34 | 35 | def make_embedding(id2word, w2v_file, initializer=None): 36 | attrs = basename(w2v_file).split('.') #word2vec.{dim}d.{vsize}k.bin 37 | w2v = gensim.models.Word2Vec.load(w2v_file).wv 38 | vocab_size = len(id2word) 39 | emb_dim = int(attrs[-3][:-1]) 40 | embedding = nn.Embedding(vocab_size, emb_dim).weight 41 | if initializer is not None: 42 | initializer(embedding) 43 | 44 | oovs = [] 45 | with torch.no_grad(): 46 | for i in range(len(id2word)): 47 | # NOTE: id2word can be list or dict 48 | if i == START: 49 | embedding[i, :] = torch.Tensor(w2v['']) 50 | elif i == END: 51 | embedding[i, :] = torch.Tensor(w2v[r'<\s>']) 52 | elif id2word[i] in w2v: 53 | embedding[i, :] = torch.Tensor(w2v[id2word[i]]) 54 | else: 55 | oovs.append(i) 56 | return embedding, oovs 57 | -------------------------------------------------------------------------------- /model/TransformerEncoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from transformer.Layers import EncoderLayer 5 | 6 | 7 | class TransformerEncoder(nn.Module): 8 | 9 | def __init__(self, input_size, hidden_size, n_layer, decoder, ff_size=2048, n_head=8, dropout=0.1): 10 | 11 | super(TransformerEncoder, self).__init__() 12 | 13 | self.input_size = input_size 14 | self.hidden_size = hidden_size 15 | self.d_inner = ff_size 16 | self.n_head = n_head 17 | self.dropout = dropout 18 | self.decoder = decoder 19 | if decoder == 'SL': 20 | self.n_head = 16 21 | self.dropout = 0.2 22 | self.n_layer = n_layer 23 | self.d_k = self.d_v = int(self.hidden_size / self.n_head) 24 | 25 | self.layer_stack = nn.ModuleList([ 26 | 
EncoderLayer(self.hidden_size, self.d_inner, self.n_head, self.d_k, self.d_v, dropout=self.dropout) 27 | for _ in range(self.n_layer)]) 28 | 29 | def forward(self, inputs, non_pad_mask, attn_mask): 30 | 31 | bs, seq_len = inputs.size(0), inputs.size(1) 32 | 33 | assert inputs.size() == (bs, seq_len, self.hidden_size) 34 | assert non_pad_mask.size() == (bs, seq_len, 1) 35 | assert attn_mask.size() == (bs, seq_len, seq_len) 36 | 37 | output = [] 38 | for layer in self.layer_stack: 39 | inputs, enc_slf_atten = layer(inputs, non_pad_mask=non_pad_mask, slf_attn_mask=attn_mask) 40 | output.append(inputs) 41 | if self.decoder == 'SL': 42 | output = output[-1] + output[-2] + output[-3] + output[-4] 43 | else: 44 | output = output[-1] 45 | 46 | assert output.size() == (bs, seq_len, self.hidden_size) 47 | 48 | return output 49 | 50 | -------------------------------------------------------------------------------- /transformer/Layers.py: -------------------------------------------------------------------------------- 1 | ''' Define the Layers ''' 2 | import torch.nn as nn 3 | from transformer.SubLayers import MultiHeadAttention, PositionwiseFeedForward 4 | 5 | __author__ = "Yu-Hsiang Huang" 6 | 7 | 8 | class EncoderLayer(nn.Module): 9 | ''' Compose with two layers ''' 10 | 11 | def __init__(self, d_model, d_inner, n_head, d_k, d_v, dropout=0.1): 12 | super(EncoderLayer, self).__init__() 13 | self.slf_attn = MultiHeadAttention( 14 | n_head, d_model, d_k, d_v, dropout=dropout) 15 | self.pos_ffn = PositionwiseFeedForward(d_model, d_inner, dropout=dropout) 16 | 17 | def forward(self, enc_input, non_pad_mask=None, slf_attn_mask=None): 18 | enc_output, enc_slf_attn = self.slf_attn( 19 | enc_input, enc_input, enc_input, mask=slf_attn_mask) 20 | enc_output *= non_pad_mask 21 | 22 | enc_output = self.pos_ffn(enc_output) 23 | enc_output *= non_pad_mask 24 | 25 | return enc_output, enc_slf_attn 26 | 27 | 28 | class DecoderLayer(nn.Module): 29 | ''' Compose with three layers ''' 30 | 31 
| def __init__(self, d_model, d_inner, n_head, d_k, d_v, dropout=0.1): 32 | super(DecoderLayer, self).__init__() 33 | self.slf_attn = MultiHeadAttention(n_head, d_model, d_k, d_v, dropout=dropout) 34 | self.enc_attn = MultiHeadAttention(n_head, d_model, d_k, d_v, dropout=dropout) 35 | self.pos_ffn = PositionwiseFeedForward(d_model, d_inner, dropout=dropout) 36 | 37 | def forward(self, dec_input, enc_output, non_pad_mask=None, slf_attn_mask=None, dec_enc_attn_mask=None): 38 | dec_output, dec_slf_attn = self.slf_attn( 39 | dec_input, dec_input, dec_input, mask=slf_attn_mask) 40 | dec_output *= non_pad_mask 41 | 42 | dec_output, dec_enc_attn = self.enc_attn( 43 | dec_output, enc_output, enc_output, mask=dec_enc_attn_mask) 44 | dec_output *= non_pad_mask 45 | 46 | dec_output = self.pos_ffn(dec_output) 47 | dec_output *= non_pad_mask 48 | 49 | return dec_output, dec_slf_attn, dec_enc_attn 50 | -------------------------------------------------------------------------------- /transformer/SubLayers.py: -------------------------------------------------------------------------------- 1 | ''' Define the sublayers in encoder/decoder layer ''' 2 | import numpy as np 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from transformer.Modules import ScaledDotProductAttention 6 | 7 | __author__ = "Yu-Hsiang Huang" 8 | 9 | class MultiHeadAttention(nn.Module): 10 | ''' Multi-Head Attention module ''' 11 | 12 | def __init__(self, n_head, d_model, d_k, d_v, dropout=0.1): 13 | super().__init__() 14 | 15 | self.n_head = n_head 16 | self.d_k = d_k 17 | self.d_v = d_v 18 | 19 | self.w_qs = nn.Linear(d_model, n_head * d_k) 20 | self.w_ks = nn.Linear(d_model, n_head * d_k) 21 | self.w_vs = nn.Linear(d_model, n_head * d_v) 22 | nn.init.xavier_normal_(self.w_qs.weight) 23 | nn.init.xavier_normal_(self.w_ks.weight) 24 | nn.init.xavier_normal_(self.w_vs.weight) 25 | 26 | self.attention = ScaledDotProductAttention(temperature=np.power(d_k, 0.5)) 27 | self.layer_norm = 
nn.LayerNorm(d_model) 28 | 29 | self.fc = nn.Linear(n_head * d_v, d_model) 30 | nn.init.xavier_normal_(self.fc.weight) 31 | 32 | self.dropout = nn.Dropout(dropout) 33 | 34 | 35 | def forward(self, q, k, v, mask=None): 36 | 37 | d_k, d_v, n_head = self.d_k, self.d_v, self.n_head 38 | 39 | sz_b, len_q, _ = q.size() 40 | sz_b, len_k, _ = k.size() 41 | sz_b, len_v, _ = v.size() 42 | 43 | residual = q 44 | 45 | q = self.w_qs(q).view(sz_b, len_q, n_head, d_k) 46 | k = self.w_ks(k).view(sz_b, len_k, n_head, d_k) 47 | v = self.w_vs(v).view(sz_b, len_v, n_head, d_v) 48 | 49 | q = q.permute(2, 0, 1, 3).contiguous().view(-1, len_q, d_k) # (n*b) x lq x dk 50 | k = k.permute(2, 0, 1, 3).contiguous().view(-1, len_k, d_k) # (n*b) x lk x dk 51 | v = v.permute(2, 0, 1, 3).contiguous().view(-1, len_v, d_v) # (n*b) x lv x dv 52 | 53 | if mask is not None: 54 | mask = mask.repeat(n_head, 1, 1) # (n*b) x .. x .. 55 | output, attn = self.attention(q, k, v, mask=mask) 56 | 57 | output = output.view(n_head, sz_b, len_q, d_v) 58 | output = output.permute(1, 2, 0, 3).contiguous().view(sz_b, len_q, -1) # b x lq x (n*dv) 59 | 60 | output = self.dropout(self.fc(output)) 61 | output = self.layer_norm(output + residual) 62 | 63 | return output, attn 64 | 65 | class PositionwiseFeedForward(nn.Module): 66 | ''' A two-feed-forward-layer module ''' 67 | 68 | def __init__(self, d_in, d_hid, dropout=0.1): 69 | super().__init__() 70 | self.w_1 = nn.Conv1d(d_in, d_hid, 1) # position-wise 71 | self.w_2 = nn.Conv1d(d_hid, d_in, 1) # position-wise 72 | self.layer_norm = nn.LayerNorm(d_in) 73 | self.dropout = nn.Dropout(dropout) 74 | 75 | def forward(self, x): 76 | residual = x 77 | output = x.transpose(1, 2) 78 | output = self.w_2(F.relu(self.w_1(output))) 79 | output = output.transpose(1, 2) 80 | output = self.dropout(output) 81 | output = self.layer_norm(output + residual) 82 | return output 83 | -------------------------------------------------------------------------------- /transformer/Beam.py: 
-------------------------------------------------------------------------------- 1 | """ Manage beam search info structure. 2 | 3 | Heavily borrowed from OpenNMT-py. 4 | For code in OpenNMT-py, please check the following link: 5 | https://github.com/OpenNMT/OpenNMT-py/blob/master/onmt/Beam.py 6 | """ 7 | 8 | import torch 9 | import numpy as np 10 | import transformer.Constants as Constants 11 | 12 | class Beam(): 13 | ''' Beam search ''' 14 | 15 | def __init__(self, size, device=False): 16 | 17 | self.size = size 18 | self._done = False 19 | 20 | # The score for each translation on the beam. 21 | self.scores = torch.zeros((size,), dtype=torch.float, device=device) 22 | self.all_scores = [] 23 | 24 | # The backpointers at each time-step. 25 | self.prev_ks = [] 26 | 27 | # The outputs at each time-step. 28 | self.next_ys = [torch.full((size,), Constants.PAD, dtype=torch.long, device=device)] 29 | self.next_ys[0][0] = Constants.BOS 30 | 31 | def get_current_state(self): 32 | "Get the outputs for the current timestep." 33 | return self.get_tentative_hypothesis() 34 | 35 | def get_current_origin(self): 36 | "Get the backpointers for the current timestep." 37 | return self.prev_ks[-1] 38 | 39 | @property 40 | def done(self): 41 | return self._done 42 | 43 | def advance(self, word_prob): 44 | "Update beam status and check if finished or not." 45 | num_words = word_prob.size(1) 46 | 47 | # Sum the previous scores. 
48 | if len(self.prev_ks) > 0: 49 | beam_lk = word_prob + self.scores.unsqueeze(1).expand_as(word_prob) 50 | else: 51 | beam_lk = word_prob[0] 52 | 53 | flat_beam_lk = beam_lk.view(-1) 54 | 55 | best_scores, best_scores_id = flat_beam_lk.topk(self.size, 0, True, True) # 1st sort 56 | best_scores, best_scores_id = flat_beam_lk.topk(self.size, 0, True, True) # 2nd sort 57 | 58 | self.all_scores.append(self.scores) 59 | self.scores = best_scores 60 | 61 | # bestScoresId is flattened as a (beam x word) array, 62 | # so we need to calculate which word and beam each score came from 63 | prev_k = best_scores_id / num_words 64 | self.prev_ks.append(prev_k) 65 | self.next_ys.append(best_scores_id - prev_k * num_words) 66 | 67 | # End condition is when top-of-beam is EOS. 68 | if self.next_ys[-1][0].item() == Constants.EOS: 69 | self._done = True 70 | self.all_scores.append(self.scores) 71 | 72 | return self._done 73 | 74 | def sort_scores(self): 75 | "Sort the scores." 76 | return torch.sort(self.scores, 0, True) 77 | 78 | def get_the_best_score_and_idx(self): 79 | "Get the score of the best in the beam." 80 | scores, ids = self.sort_scores() 81 | return scores[1], ids[1] 82 | 83 | def get_tentative_hypothesis(self): 84 | "Get the decoded sequence for the current timestep." 85 | 86 | if len(self.next_ys) == 1: 87 | dec_seq = self.next_ys[0].unsqueeze(1) 88 | else: 89 | _, keys = self.sort_scores() 90 | hyps = [self.get_hypothesis(k) for k in keys] 91 | hyps = [[Constants.BOS] + h for h in hyps] 92 | dec_seq = torch.LongTensor(hyps) 93 | 94 | return dec_seq 95 | 96 | def get_hypothesis(self, k): 97 | """ Walk back to construct the full hypothesis. 
""" 98 | hyp = [] 99 | for j in range(len(self.prev_ks) - 1, -1, -1): 100 | hyp.append(self.next_ys[j+1][k]) 101 | k = self.prev_ks[j][k] 102 | 103 | return list(map(lambda x: x.item(), hyp[::-1])) 104 | -------------------------------------------------------------------------------- /decoding.py: -------------------------------------------------------------------------------- 1 | """ decoding utilities""" 2 | import json 3 | import re 4 | import os 5 | import torch 6 | from os.path import join 7 | import pickle as pkl 8 | from itertools import starmap 9 | 10 | from cytoolz import curry 11 | from pytorch_pretrained_bert.tokenization import BertTokenizer 12 | 13 | from utils import PAD, UNK, START, END 14 | from model.extract import Summarizer 15 | from model.rl import ActorCritic 16 | from data.batcher import conver2id, pad_batch_tensorize 17 | from data.data import CnnDmDataset 18 | 19 | 20 | DATASET_DIR = './CNNDM' 21 | 22 | class DecodeDataset(CnnDmDataset): 23 | """ get the article sentences only (for decoding use)""" 24 | def __init__(self, split): 25 | assert split in ['val', 'test'] 26 | super().__init__(split, DATASET_DIR) 27 | 28 | def __getitem__(self, i): 29 | js_data = super().__getitem__(i) 30 | art_sents = js_data['article'] 31 | return art_sents 32 | 33 | 34 | def make_html_safe(s): 35 | """Rouge use html, has to make output html safe""" 36 | return s.replace("<", "<").replace(">", ">") 37 | 38 | 39 | def sort_ckpt(model_dir, reverse=False): 40 | """ reverse=False->loss, reverse=True->reward """ 41 | ckpts = os.listdir(join(model_dir, 'ckpt')) 42 | ckpt_matcher = re.compile('^ckpt-.*-[0-9]*') 43 | ckpts = sorted([c for c in ckpts if ckpt_matcher.match(c)], 44 | key=lambda c: float(c.split('-')[1]), reverse=reverse) 45 | return ckpts 46 | 47 | def get_n_ext(split, idx): 48 | path = join(DATASET_DIR, '{}/{}.json'.format(split, idx)) 49 | with open(path) as f: 50 | data = json.loads(f.read()) 51 | if data['source'] == 'CNN': 52 | return 2 53 | else: 54 
| return 3 55 | 56 | class Extractor(object): 57 | def __init__(self, ext_dir, ext_ckpt, emb_type, max_ext=6, cuda=True): 58 | ext_meta = json.load(open(join(ext_dir, 'meta.json'))) 59 | ext_args = ext_meta['model_args'] 60 | extractor = Summarizer(**ext_args) 61 | extractor.load_state_dict(ext_ckpt) 62 | if emb_type == 'W2V': 63 | word2id = pkl.load(open(join(ext_dir, 'vocab.pkl'), 'rb')) 64 | self._word2id = word2id 65 | else: 66 | self._tokenizer = BertTokenizer.from_pretrained( 67 | '/path/to/uncased_L-24_H-1024_A-16/vocab.txt') 68 | self._emb_type = emb_type 69 | self._device = torch.device('cuda' if cuda else 'cpu') 70 | self._net = extractor.to(self._device) 71 | ## self._id2word = {i: w for w, i in word2id.items()} 72 | self._max_ext = max_ext 73 | 74 | def __call__(self, raw_article_sents): 75 | self._net.eval() 76 | n_art = len(raw_article_sents) 77 | if self._emb_type == 'W2V': 78 | articles = conver2id(UNK, self._word2id, raw_article_sents) 79 | else: 80 | articles = [self._tokenizer.convert_tokens_to_ids(sentence) 81 | for sentence in raw_article_sents] 82 | article = pad_batch_tensorize(articles, PAD, cuda=False 83 | ).to(self._device) 84 | indices = self._net.extract([article], k=min(n_art, self._max_ext)) 85 | return indices 86 | 87 | 88 | class ArticleBatcher(object): 89 | def __init__(self, word2id, cuda=True): 90 | self._device = torch.device('cuda' if cuda else 'cpu') 91 | self._word2id = word2id 92 | self._device = torch.device('cuda' if cuda else 'cpu') 93 | 94 | def __call__(self, raw_article_sents): 95 | articles = conver2id(UNK, self._word2id, raw_article_sents) 96 | article = pad_batch_tensorize(articles, PAD, cuda=False 97 | ).to(self._device) 98 | return article 99 | 100 | -------------------------------------------------------------------------------- /model/util.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | from torch.nn import functional as F 5 | 6 | 7 | 
#################### general sequence helper ######################### 8 | def len_mask(lens, device): 9 | """ users are resposible for shaping 10 | Return: tensor_type [B, T] 11 | """ 12 | max_len = max(lens) 13 | batch_size = len(lens) 14 | mask = torch.ByteTensor(batch_size, max_len).to(device) 15 | mask.fill_(0) 16 | for i, l in enumerate(lens): 17 | mask[i, :l].fill_(1) 18 | return mask 19 | 20 | def sequence_mean(sequence, seq_lens, dim=1): 21 | if seq_lens: 22 | assert sequence.size(0) == len(seq_lens) # batch_size 23 | sum_ = torch.sum(sequence, dim=dim, keepdim=False) 24 | mean = torch.stack([s/l for s, l in zip(sum_, seq_lens)], dim=0) 25 | else: 26 | mean = torch.mean(sequence, dim=dim, keepdim=False) 27 | return mean 28 | 29 | def sequence_loss(logits, targets, sent_num, decoder, xent_fn=None, pad_idx=0): 30 | """ functional interface of SequenceLoss""" 31 | 32 | if decoder == 'PN': 33 | 34 | assert logits.size()[:-1] == targets.size() 35 | 36 | mask = targets != pad_idx 37 | target = targets.masked_select(mask) 38 | logit = logits.masked_select( 39 | mask.unsqueeze(2).expand_as(logits) 40 | ).contiguous().view(-1, logits.size(-1)) 41 | 42 | else: 43 | bs, seq_len, d = logits.size() 44 | target_len = targets.size(-1) 45 | assert d == 2 46 | 47 | sent_mask = len_mask(sent_num, logits.get_device()) 48 | target = torch.LongTensor(bs, seq_len).to(logits.device) 49 | target.fill_(0) 50 | 51 | for i in range(bs): 52 | for j in range(target_len): 53 | if targets[i][j] != -1: 54 | target[i][targets[i][j]] = 1 55 | sent_mask = sent_mask.contiguous().view(-1) 56 | target = target.contiguous().view(-1) 57 | logit = logits.contiguous().view(-1, logits.size(-1)) 58 | 59 | target = target.masked_select(sent_mask).contiguous().view(-1) 60 | logit = logit.masked_select(sent_mask.unsqueeze(-1).expand_as(logit)).contiguous().view(-1, logit.size(-1)) 61 | 62 | assert logit.size(-1) == 2 63 | assert logit.size()[:-1] == target.size() 64 | 65 | if xent_fn: 66 | loss = 
xent_fn(logit, target) 67 | else: 68 | loss = F.cross_entropy(logit, target) 69 | assert (not math.isnan(loss.mean().item()) 70 | and not math.isinf(loss.mean().item())) 71 | return loss 72 | 73 | 74 | #################### LSTM helper ######################### 75 | 76 | def reorder_sequence(sequence_emb, order, batch_first=False): 77 | """ 78 | sequence_emb: [T, B, D] if not batch_first 79 | order: list of sequence length 80 | """ 81 | batch_dim = 0 if batch_first else 1 82 | assert len(order) == sequence_emb.size()[batch_dim] 83 | 84 | order = torch.LongTensor(order).to(sequence_emb.get_device()) 85 | sorted_ = sequence_emb.index_select(index=order, dim=batch_dim) 86 | 87 | return sorted_ 88 | 89 | def reorder_lstm_states(lstm_states, order): 90 | """ 91 | lstm_states: (H, C) of tensor [layer, batch, hidden] 92 | order: list of sequence length 93 | """ 94 | assert isinstance(lstm_states, tuple) 95 | assert len(lstm_states) == 2 96 | assert lstm_states[0].size() == lstm_states[1].size() 97 | assert len(order) == lstm_states[0].size()[1] 98 | 99 | order = torch.LongTensor(order).to(lstm_states[0].get_device()) 100 | sorted_states = (lstm_states[0].index_select(index=order, dim=1), 101 | lstm_states[1].index_select(index=order, dim=1)) 102 | 103 | return sorted_states 104 | -------------------------------------------------------------------------------- /metric.py: -------------------------------------------------------------------------------- 1 | """ ROUGE utils""" 2 | import os 3 | import threading 4 | import subprocess as sp 5 | from collections import Counter, deque 6 | 7 | from cytoolz import concat, curry 8 | 9 | 10 | def make_n_grams(seq, n): 11 | """ return iterator """ 12 | ngrams = (tuple(seq[i:i+n]) for i in range(len(seq)-n+1)) 13 | return ngrams 14 | 15 | def _n_gram_match(summ, ref, n): 16 | summ_grams = Counter(make_n_grams(summ, n)) 17 | ref_grams = Counter(make_n_grams(ref, n)) 18 | grams = min(summ_grams, ref_grams, key=len) 19 | count = 
sum(min(summ_grams[g], ref_grams[g]) for g in grams) 20 | return count 21 | 22 | @curry 23 | def compute_rouge_n(output, reference, n=1, mode='f'): 24 | """ compute ROUGE-N for a single pair of summary and reference""" 25 | assert mode in list('fpr') # F-1, precision, recall 26 | match = _n_gram_match(reference, output, n) 27 | if match == 0: 28 | score = 0.0 29 | else: 30 | precision = match / len(output) 31 | recall = match / len(reference) 32 | f_score = 2 * (precision * recall) / (precision + recall) 33 | if mode == 'p': 34 | score = precision 35 | if mode == 'r': 36 | score = recall 37 | else: 38 | score = f_score 39 | return score 40 | 41 | 42 | def _lcs_dp(a, b): 43 | """ compute the len dp of lcs""" 44 | dp = [[0 for _ in range(0, len(b)+1)] 45 | for _ in range(0, len(a)+1)] 46 | # dp[i][j]: lcs_len(a[:i], b[:j]) 47 | for i in range(1, len(a)+1): 48 | for j in range(1, len(b)+1): 49 | if a[i-1] == b[j-1]: 50 | dp[i][j] = dp[i-1][j-1] + 1 51 | else: 52 | dp[i][j] = max(dp[i-1][j], dp[i][j-1]) 53 | return dp 54 | 55 | def _lcs_len(a, b): 56 | """ compute the length of longest common subsequence between a and b""" 57 | dp = _lcs_dp(a, b) 58 | return dp[-1][-1] 59 | 60 | @curry 61 | def compute_rouge_l(output, reference, mode='f'): 62 | """ compute ROUGE-L for a single pair of summary and reference 63 | output, reference are list of words 64 | """ 65 | assert mode in list('fpr') # F-1, precision, recall 66 | lcs = _lcs_len(output, reference) 67 | if lcs == 0: 68 | score = 0.0 69 | else: 70 | precision = lcs / len(output) 71 | recall = lcs / len(reference) 72 | f_score = 2 * (precision * recall) / (precision + recall) 73 | if mode == 'p': 74 | score = precision 75 | if mode == 'r': 76 | score = recall 77 | else: 78 | score = f_score 79 | return score 80 | 81 | 82 | def _lcs(a, b): 83 | """ compute the longest common subsequence between a and b""" 84 | dp = _lcs_dp(a, b) 85 | i = len(a) 86 | j = len(b) 87 | lcs = deque() 88 | while (i > 0 and j > 0): 89 | if 
a[i-1] == b[j-1]: 90 | lcs.appendleft(a[i-1]) 91 | i -= 1 92 | j -= 1 93 | elif dp[i-1][j] >= dp[i][j-1]: 94 | i -= 1 95 | else: 96 | j -= 1 97 | assert len(lcs) == dp[-1][-1] 98 | return lcs 99 | 100 | def compute_rouge_l_summ(summs, refs, mode='f'): 101 | """ summary level ROUGE-L""" 102 | assert mode in list('fpr') # F-1, precision, recall 103 | tot_hit = 0 104 | ref_cnt = Counter(concat(refs)) 105 | summ_cnt = Counter(concat(summs)) 106 | for ref in refs: 107 | for summ in summs: 108 | lcs = _lcs(summ, ref) 109 | for gram in lcs: 110 | if ref_cnt[gram] > 0 and summ_cnt[gram] > 0: 111 | tot_hit += 1 112 | ref_cnt[gram] -= 1 113 | summ_cnt[gram] -= 1 114 | if tot_hit == 0: 115 | score = 0.0 116 | else: 117 | precision = tot_hit / sum((len(s) for s in summs)) 118 | recall = tot_hit / sum((len(r) for r in refs)) 119 | f_score = 2 * (precision * recall) / (precision + recall) 120 | if mode == 'p': 121 | score = precision 122 | if mode == 'r': 123 | score = recall 124 | else: 125 | score = f_score 126 | return score 127 | 128 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Effective_Extractive_Summarization 2 | Code for ACL 2019 paper (oral): 3 | *[Searching for Effective Neural Extractive Summarization: What Works and What's Next](https://arxiv.org/abs/1907.03491)* 4 | 5 | If you use our code or data, please cite our paper: 6 | ``` 7 | @inproceedings{zhong2019searching, 8 | title={Searching for Effective Neural Extractive Summarization: What Works and What’s Next}, 9 | author={Zhong, Ming and Liu, Pengfei and Wang, Danqing and Qiu, Xipeng and Huang, Xuan-Jing}, 10 | booktitle={Proceedings of the 57th Conference of the Association for Computational Linguistics}, 11 | pages={1049--1058}, 12 | year={2019} 13 | } 14 | 15 | ``` 16 | 17 | ## Dependencies 18 | - Python 3.7 19 | - [PyTorch](https://github.com/pytorch/pytorch) 1.1.0 20 | - 
[gensim](https://github.com/RaRe-Technologies/gensim) 21 | - [cytoolz](https://github.com/pytoolz/cytoolz) 22 | - [tensorboardX](https://github.com/lanpa/tensorboard-pytorch) 23 | - [pyrouge](https://github.com/bheinzerling/pyrouge) 24 | - [pytorch-pretrained-bert](https://github.com/huggingface/pytorch-transformers) 0.6.1 25 | - the package has since been renamed pytorch-transformers; you can still use *pip* to install pytorch-pretrained-bert (0.6.1) 26 | - You should download the BERT model (bert-large-uncased) and convert it to a PyTorch version, which gives a folder called `uncased_L-24_H-1024_A-16` 27 | 28 | The code only supports running on *Linux*. 29 | 30 | ## Data 31 | 32 | We have already processed the CNN/DailyMail dataset; you can download it through [this link](https://drive.google.com/open?id=1QB9hVPF_YkJslaX4INnUZGS9OVL1Pr3O) and unzip it in the current path. It contains the train, val, test, refs and word2vec folders and vocab_cnt.pkl, which you should put in `./CNNDM`. 33 | 34 | ## Path 35 | 36 | You should fill in the three paths in the files before running the code. 37 | 1. path to RELEASE-1.5.5 (evaluate.py line 14), example: `/home/ROUGE/RELEASE-1.5.5` 38 | 2. path to vocab.txt (decoding.py line 67 and data/batcher.py line 13), example: `/home/pretrain_model/uncased_L-24_H-1024_A-16/vocab.txt` 39 | 3. path to BERT model (model/extract.py line 255), example: `/home/pretrain_model/uncased_L-24_H-1024_A-16` 40 | 41 | ## Train 42 | 43 | We currently provide a variety of options to combine into a model. For the encoder, we provide **BiLSTM/Transformer/DeepLSTM**. For the decoder, we provide **Sequence Labeling/Pointer Network**. For the type of word embedding, we provide **Word2Vec/BERT**. 44 | We have only tested the code on GPUs, and we strongly recommend training your model on a GPU because of the long training time.
 45 | 46 | To run BiLSTM + Pointer Network + Word2Vec model, run 47 | 48 | ``` 49 | CUDA_VISIBLE_DEVICES=0 python main.py --mode=train --encoder=BiLSTM --decoder=PN --emb_type=W2V 50 | ``` 51 | 52 | To run Transformer + Sequence Labeling + Word2Vec model, run 53 | 54 | ``` 55 | CUDA_VISIBLE_DEVICES=0 python main.py --mode=train --encoder=Transformer --decoder=SL --emb_type=W2V 56 | ``` 57 | 58 | To run DeepLSTM + Pointer Network + BERT model (models with BERT have a long training time), run 59 | 60 | ``` 61 | CUDA_VISIBLE_DEVICES=0 python main.py --mode=train --encoder=DeepLSTM --decoder=PN --emb_type=BERT 62 | ``` 63 | 64 | You can try any other combination to train your own model. 65 | 66 | ## Test 67 | 68 | After completing the training process, you can test the five best models and obtain ROUGE scores with the following instructions. 69 | You only need to switch the mode to test, leaving the other options unchanged. 70 | 71 | For example, when you test BiLSTM + Pointer Network + Word2Vec model, run 72 | 73 | ``` 74 | CUDA_VISIBLE_DEVICES=0 python main.py --mode=test --encoder=BiLSTM --decoder=PN --emb_type=W2V 75 | ``` 76 | The results will be printed on the screen and saved in the `BiLSTM_PN_W2V` folder. 77 | 78 | 79 | ## Output 80 | 81 | You can find the outputs produced by our different models in this paper on CNN/DailyMail through [this link](https://drive.google.com/open?id=1VuZp2Wq3TH4kaXzuID34_e6VkBxa5Gkf). 82 | 83 | ## Note 84 | 1. Part of our code uses the implementation of [fast_abs_rl](https://github.com/ChenRocks/fast_abs_rl) and [Transformer](https://github.com/jadore801120/attention-is-all-you-need-pytorch). Thanks for their work! 85 | 2. For the code of reinforcement learning, we use the implementation in [fast_abs_rl](https://github.com/ChenRocks/fast_abs_rl). The only difference is that we changed the parameter *reward* in their file *train_full_rl.py* to ROUGE-1-precision and the parameter *stop* to 8.5.
If you are interested in the implementation of this part, please directly refer to their code. 86 | -------------------------------------------------------------------------------- /model/rnn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.nn import functional as F 4 | from torch.nn import init 5 | 6 | from .util import reorder_sequence, reorder_lstm_states 7 | 8 | 9 | def lstm_encoder(sequence, lstm, 10 | seq_lens=None, init_states=None, embedding=None): 11 | """ functional LSTM encoder (sequence is [b, t]/[b, t, d], 12 | lstm should be rolled lstm)""" 13 | batch_size = sequence.size(0) 14 | if not lstm.batch_first: 15 | sequence = sequence.transpose(0, 1) 16 | emb_sequence = (embedding(sequence) if embedding is not None 17 | else sequence) 18 | if seq_lens: 19 | assert batch_size == len(seq_lens) 20 | sort_ind = sorted(range(len(seq_lens)), 21 | key=lambda i: seq_lens[i], reverse=True) 22 | seq_lens = [seq_lens[i] for i in sort_ind] 23 | emb_sequence = reorder_sequence(emb_sequence, sort_ind, 24 | lstm.batch_first) 25 | 26 | if init_states is None: 27 | device = sequence.get_device() 28 | init_states = init_lstm_states(lstm, batch_size, device) 29 | else: 30 | init_states = (init_states[0].contiguous(), 31 | init_states[1].contiguous()) 32 | 33 | if seq_lens: 34 | packed_seq = nn.utils.rnn.pack_padded_sequence(emb_sequence, 35 | seq_lens) 36 | packed_out, final_states = lstm(packed_seq, init_states) 37 | lstm_out, _ = nn.utils.rnn.pad_packed_sequence(packed_out) 38 | 39 | back_map = {ind: i for i, ind in enumerate(sort_ind)} 40 | reorder_ind = [back_map[i] for i in range(len(seq_lens))] 41 | lstm_out = reorder_sequence(lstm_out, reorder_ind, lstm.batch_first) 42 | final_states = reorder_lstm_states(final_states, reorder_ind) 43 | else: 44 | lstm_out, final_states = lstm(emb_sequence, init_states) 45 | 46 | return lstm_out, final_states 47 | 48 | 49 | def 
init_lstm_states(lstm, batch_size, device): 50 | n_layer = lstm.num_layers*(2 if lstm.bidirectional else 1) 51 | n_hidden = lstm.hidden_size 52 | 53 | states = (torch.zeros(n_layer, batch_size, n_hidden).to(device), 54 | torch.zeros(n_layer, batch_size, n_hidden).to(device)) 55 | return states 56 | 57 | 58 | class StackedLSTMCells(nn.Module): 59 | """ stack multiple LSTM Cells""" 60 | def __init__(self, cells, dropout=0.0): 61 | super().__init__() 62 | self._cells = nn.ModuleList(cells) 63 | self._dropout = dropout 64 | 65 | def forward(self, input_, state): 66 | """ 67 | Arguments: 68 | input_: FloatTensor (batch, input_size) 69 | states: tuple of the H, C LSTM states 70 | FloatTensor (num_layers, batch, hidden_size) 71 | Returns: 72 | LSTM states 73 | new_h: (num_layers, batch, hidden_size) 74 | new_c: (num_layers, batch, hidden_size) 75 | """ 76 | hs = [] 77 | cs = [] 78 | for i, cell in enumerate(self._cells): 79 | s = (state[0][i, :, :], state[1][i, :, :]) 80 | h, c = cell(input_, s) 81 | hs.append(h) 82 | cs.append(c) 83 | input_ = F.dropout(h, p=self._dropout, training=self.training) 84 | 85 | new_h = torch.stack(hs, dim=0) 86 | new_c = torch.stack(cs, dim=0) 87 | 88 | return new_h, new_c 89 | 90 | @property 91 | def hidden_size(self): 92 | return self._cells[0].hidden_size 93 | 94 | @property 95 | def input_size(self): 96 | return self._cells[0].input_size 97 | 98 | @property 99 | def num_layers(self): 100 | return len(self._cells) 101 | 102 | @property 103 | def bidirectional(self): 104 | return self._cells[0].bidirectional 105 | 106 | 107 | class MultiLayerLSTMCells(StackedLSTMCells): 108 | """ 109 | This class is a one-step version of the cudnn LSTM 110 | , or multi-layer version of LSTMCell 111 | """ 112 | def __init__(self, input_size, hidden_size, num_layers, 113 | bias=True, dropout=0.0): 114 | """ same as nn.LSTM but without (bidirectional)""" 115 | cells = [] 116 | cells.append(nn.LSTMCell(input_size, hidden_size, bias)) 117 | for _ in 
range(num_layers-1): 118 | cells.append(nn.LSTMCell(hidden_size, hidden_size, bias)) 119 | super().__init__(cells, dropout) 120 | 121 | @property 122 | def bidirectional(self): 123 | return False 124 | 125 | def reset_parameters(self): 126 | for cell in self._cells: 127 | # xavier initilization 128 | gate_size = self.hidden_size / 4 129 | for weight in [cell.weight_ih, cell.weight_hh]: 130 | for w in torch.chunk(weight, 4, dim=0): 131 | init.xavier_normal_(w) 132 | #forget bias = 1 133 | for bias in [cell.bias_ih, cell.bias_hh]: 134 | torch.chunk(bias, 4, dim=0)[1].data.fill_(1) 135 | 136 | @staticmethod 137 | def convert(lstm): 138 | """ convert from a cudnn LSTM""" 139 | lstm_cell = MultiLayerLSTMCells( 140 | lstm.input_size, lstm.hidden_size, 141 | lstm.num_layers, dropout=lstm.dropout) 142 | for i, cell in enumerate(lstm_cell._cells): 143 | cell.weight_ih.data.copy_(getattr(lstm, 'weight_ih_l{}'.format(i))) 144 | cell.weight_hh.data.copy_(getattr(lstm, 'weight_hh_l{}'.format(i))) 145 | cell.bias_ih.data.copy_(getattr(lstm, 'bias_ih_l{}'.format(i))) 146 | cell.bias_hh.data.copy_(getattr(lstm, 'bias_hh_l{}'.format(i))) 147 | return lstm_cell 148 | -------------------------------------------------------------------------------- /data/batcher.py: -------------------------------------------------------------------------------- 1 | """ batching """ 2 | import random 3 | from collections import defaultdict 4 | 5 | from toolz.sandbox import unzip 6 | from cytoolz import curry, concat, compose 7 | from cytoolz import curried 8 | 9 | import torch 10 | import torch.multiprocessing as mp 11 | from pytorch_pretrained_bert.tokenization import BertTokenizer 12 | tokenizer = BertTokenizer.from_pretrained( 13 | '/path/to/uncased_L-24_H-1024_A-16/vocab.txt') 14 | MAX_ARTICLE_LEN = 512 15 | 16 | # Batching functions 17 | def coll_fn_extract(data): 18 | def is_good_data(d): 19 | """ make sure data is not empty""" 20 | source_sents, extracts = d 21 | return source_sents and 
extracts 22 | batch = list(filter(is_good_data, data)) 23 | assert all(map(is_good_data, batch)) 24 | return batch 25 | 26 | @curry 27 | def tokenize(max_len, emb_type, texts): 28 | if emb_type == 'W2V': 29 | return [t.lower().split()[:max_len] for t in texts] 30 | else: 31 | truncated_article = [] 32 | left = MAX_ARTICLE_LEN - 2 33 | for sentence in texts: 34 | tokens = tokenizer.tokenize(sentence) 35 | tokens = tokens[:max_len] 36 | if left >= len(tokens): 37 | truncated_article.append(tokens) 38 | left -= len(tokens) 39 | else: 40 | truncated_article.append(tokens[0:left]) 41 | break 42 | return truncated_article 43 | 44 | def conver2id(unk, word2id, words_list): 45 | word2id = defaultdict(lambda: unk, word2id) 46 | return [[word2id[w] for w in words] for words in words_list] 47 | 48 | @curry 49 | def prepro_fn_extract(max_src_len, max_src_num, emb_type, batch): 50 | def prepro_one(sample): 51 | source_sents, extracts = sample 52 | tokenized_sents = tokenize(max_src_len, emb_type, source_sents)[:max_src_num] 53 | cleaned_extracts = list(filter(lambda e: e < len(tokenized_sents), 54 | extracts)) 55 | return tokenized_sents, cleaned_extracts 56 | batch = list(map(prepro_one, batch)) 57 | return batch 58 | 59 | @curry 60 | def convert_batch_extract_ptr(unk, word2id, emb_type, batch): 61 | def convert_one(sample): 62 | source_sents, extracts = sample 63 | if emb_type == 'W2V': 64 | id_sents = conver2id(unk, word2id, source_sents) 65 | else: 66 | id_sents = [tokenizer.convert_tokens_to_ids(sentence) 67 | for sentence in source_sents] 68 | return id_sents, extracts 69 | batch = list(map(convert_one, batch)) 70 | return batch 71 | 72 | @curry 73 | def pad_batch_tensorize(inputs, pad, cuda=True): 74 | """pad_batch_tensorize 75 | 76 | :param inputs: List of size B containing torch tensors of shape [T, ...] 77 | :type inputs: List[np.ndarray] 78 | :rtype: TorchTensor of size (B, T, ...) 
79 | """ 80 | tensor_type = torch.cuda.LongTensor if cuda else torch.LongTensor 81 | batch_size = len(inputs) 82 | max_len = max(len(ids) for ids in inputs) 83 | tensor_shape = (batch_size, max_len) 84 | tensor = tensor_type(*tensor_shape) 85 | tensor.fill_(pad) 86 | for i, ids in enumerate(inputs): 87 | tensor[i, :len(ids)] = tensor_type(ids) 88 | return tensor 89 | 90 | @curry 91 | def batchify_fn_extract_ptr(pad, data, cuda=True): 92 | source_lists, targets = tuple(map(list, unzip(data))) 93 | 94 | src_nums = list(map(len, source_lists)) 95 | sources = list(map(pad_batch_tensorize(pad=pad, cuda=cuda), source_lists)) 96 | 97 | # PAD is -1 (dummy extraction index) for using sequence loss 98 | target = pad_batch_tensorize(targets, pad=-1, cuda=cuda) 99 | remove_last = lambda tgt: tgt[:-1] 100 | 101 | tar_in = pad_batch_tensorize( 102 | list(map(remove_last, targets)), 103 | pad=-0, cuda=cuda # use 0 here for feeding first conv sentence repr. 104 | ) 105 | 106 | fw_args = (sources, src_nums, tar_in) 107 | loss_args = (target, src_nums) 108 | return fw_args, loss_args 109 | 110 | def _batch2q(loader, prepro, q, single_run=True): 111 | epoch = 0 112 | while True: 113 | for batch in loader: 114 | q.put(prepro(batch)) 115 | if single_run: 116 | break 117 | epoch += 1 118 | q.put(epoch) 119 | q.put(None) 120 | 121 | class BucketedGenerater(object): 122 | def __init__(self, loader, prepro, 123 | sort_key, batchify, 124 | single_run=True, queue_size=8, fork=True): 125 | self._loader = loader 126 | self._prepro = prepro 127 | self._sort_key = sort_key 128 | self._batchify = batchify 129 | self._single_run = single_run 130 | if fork: 131 | ctx = mp.get_context('forkserver') 132 | self._queue = ctx.Queue(queue_size) 133 | else: 134 | # for easier debugging 135 | self._queue = None 136 | self._process = None 137 | 138 | def __call__(self, batch_size: int): 139 | def get_batches(hyper_batch): 140 | indexes = list(range(0, len(hyper_batch), batch_size)) 141 | if not 
self._single_run: 142 | # random shuffle for training batches: only the batch order stays random, since the stable sort below re-orders hyper_batch (its shuffle just randomizes ties) 143 | random.shuffle(hyper_batch) 144 | random.shuffle(indexes) 145 | hyper_batch.sort(key=self._sort_key) 146 | for i in indexes: 147 | batch = self._batchify(hyper_batch[i:i+batch_size]) 148 | yield batch 149 | 150 | if self._queue is not None: 151 | ctx = mp.get_context('forkserver') 152 | self._process = ctx.Process( 153 | target=_batch2q, 154 | args=(self._loader, self._prepro, 155 | self._queue, self._single_run) 156 | ) 157 | self._process.start() 158 | while True: 159 | d = self._queue.get() 160 | if d is None: 161 | break 162 | if isinstance(d, int): # _batch2q puts the finished epoch count between epochs 163 | print('\nepoch {} done'.format(d)) 164 | continue 165 | yield from get_batches(d) 166 | self._process.join() 167 | else: 168 | i = 0 169 | while True: 170 | for batch in self._loader: 171 | yield from get_batches(self._prepro(batch)) 172 | if self._single_run: 173 | break 174 | i += 1 175 | print('\nepoch {} done'.format(i)) 176 | 177 | def terminate(self): 178 | if self._process is not None: 179 | self._process.terminate() 180 | self._process.join() 181 | -------------------------------------------------------------------------------- /model/DeepLSTM.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.init as init 7 | import torch.nn.functional as F 8 | from torch.autograd import Variable 9 | from torch.distributions import Bernoulli 10 | 11 | class DeepLSTM(nn.Module): 12 | def __init__(self, input_size, hidden_size, num_layers, recurrent_dropout, use_orthnormal_init=True, fix_mask=True, use_cuda=True): 13 | super(DeepLSTM, self).__init__() 14 | 15 | self.fix_mask = fix_mask 16 | self.use_cuda = use_cuda 17 | self.input_size = input_size 18 | self.num_layers = num_layers 19 | self.hidden_size = hidden_size 20 | self.recurrent_dropout = recurrent_dropout 21 | 22 | self.lstms = nn.ModuleList([None] * self.num_layers)
23 | self.highway_gate_input = nn.ModuleList([None] * self.num_layers) 24 | self.highway_gate_state = nn.ModuleList([nn.Linear(hidden_size, hidden_size)] * self.num_layers) 25 | self.highway_linear_input = nn.ModuleList([None] * self.num_layers) 26 | 27 | # self._input_w = nn.Parameter(torch.Tensor(input_size, hidden_size)) 28 | # init.xavier_normal_(self._input_w) 29 | 30 | for l in range(self.num_layers): 31 | input_dim = input_size if l == 0 else hidden_size 32 | 33 | self.lstms[l] = nn.LSTMCell(input_size=input_dim, hidden_size=hidden_size) 34 | self.highway_gate_input[l] = nn.Linear(input_dim, hidden_size) 35 | self.highway_linear_input[l] = nn.Linear(input_dim, hidden_size, bias=False) 36 | 37 | # initing W for LSTM 38 | for l in range(self.num_layers): 39 | if use_orthnormal_init: 40 | # initing W using orthnormal init 41 | init.orthogonal_(self.lstms[l].weight_ih) 42 | init.orthogonal_(self.lstms[l].weight_hh) 43 | init.orthogonal_(self.highway_gate_input[l].weight.data) 44 | init.orthogonal_(self.highway_gate_state[l].weight.data) 45 | init.orthogonal_(self.highway_linear_input[l].weight.data) 46 | else: 47 | # initing W using xavier_normal 48 | init_weight_value = 6.0 49 | init.xavier_normal_(self.lstms[l].weight_ih, gain=np.sqrt(init_weight_value)) 50 | init.xavier_normal_(self.lstms[l].weight_hh, gain=np.sqrt(init_weight_value)) 51 | init.xavier_normal_(self.highway_gate_input[l].weight.data, gain=np.sqrt(init_weight_value)) 52 | init.xavier_normal_(self.highway_gate_state[l].weight.data, gain=np.sqrt(init_weight_value)) 53 | init.xavier_normal_(self.highway_linear_input[l].weight.data, gain=np.sqrt(init_weight_value)) 54 | 55 | def init_hidden(self, batch_size, hidden_size): 56 | # the first is the hidden h 57 | # the second is the cell c 58 | if self.use_cuda: 59 | return (torch.zeros(batch_size, hidden_size).cuda(), 60 | torch.zeros(batch_size, hidden_size).cuda()) 61 | else: 62 | return (torch.zeros(batch_size, hidden_size), 63 | 
torch.zeros(batch_size, hidden_size)) 64 | 65 | def forward(self, inputs, input_masks, Train): 66 | 67 | ''' 68 | inputs: [[seq_len, batch, Co * kernel_sizes], n_layer * [None]] (list) 69 | input_masks: [[seq_len, batch, Co * kernel_sizes], n_layer * [None]] (list) 70 | ''' 71 | 72 | batch_size, seq_len = inputs[0].size(1), inputs[0].size(0) 73 | 74 | self.inputs = inputs 75 | self.input_masks = input_masks 76 | 77 | if self.fix_mask: 78 | self.output_dropout_layers = [None] * self.num_layers 79 | for l in range(self.num_layers): 80 | binary_mask = torch.rand((batch_size, self.hidden_size)) > self.recurrent_dropout 81 | # This scaling ensures expected values and variances of the output of applying this mask and the original tensor are the same. 82 | # from allennlp.nn.util.py 83 | self.output_dropout_layers[l] = binary_mask.float().div(1.0 - self.recurrent_dropout) 84 | if self.use_cuda: 85 | self.output_dropout_layers[l] = self.output_dropout_layers[l].cuda() 86 | 87 | for l in range(self.num_layers): 88 | h, c = self.init_hidden(batch_size, self.hidden_size) 89 | outputs_list = [] 90 | for t in range(len(self.inputs[l])): 91 | x = self.inputs[l][t] 92 | m = self.input_masks[l][t].float() 93 | h_temp, c_temp = self.lstms[l].forward(x, (h, c)) # [batch, hidden_size] 94 | r = torch.sigmoid(self.highway_gate_input[l](x) + self.highway_gate_state[l](h)) 95 | lx = self.highway_linear_input[l](x) # [batch, hidden_size] 96 | h_temp = r * h_temp + (1 - r) * lx 97 | 98 | if Train: 99 | if self.fix_mask: 100 | h_temp = self.output_dropout_layers[l] * h_temp 101 | else: 102 | h_temp = F.dropout(h_temp, p=self.recurrent_dropout) 103 | 104 | h = m * h_temp + (1 - m) * h 105 | c = m * c_temp + (1 - m) * c 106 | outputs_list.append(h) 107 | outputs = torch.stack(outputs_list, 0) # [seq_len, batch, hidden_size] 108 | self.inputs[l + 1] = DeepLSTM.flip(outputs, 0) # reverse [seq_len, batch, hidden_size] 109 | self.input_masks[l + 1] = DeepLSTM.flip(self.input_masks[l], 0) 110 | 
 111 | self.output_state = self.inputs # (num_layers + 1) entries: the original input, then each layer's time-reversed [seq_len, batch, hidden_size] output 112 | 113 | # flip -2 layer 114 | # self.output_state[-2] = DeepLSTM.flip(self.output_state[-2], 0) 115 | 116 | # concat last two layer 117 | # self.output_state = torch.cat([self.output_state[-1], self.output_state[-2]], dim=-1).transpose(0, 1) 118 | 119 | self.output_state = self.output_state[-1].transpose(0, 1) # last layer only; NOTE(review): every layer reverses time before feeding the next, so with an odd num_layers this is in reversed time order -- confirm callers expect that 120 | 121 | assert self.output_state.size() == (batch_size, seq_len, self.hidden_size) 122 | 123 | return self.output_state 124 | 125 | @staticmethod 126 | def flip(x, dim): # reverse x along dimension `dim` (negative dims allowed) 127 | xsize = x.size() 128 | dim = x.dim() + dim if dim < 0 else dim 129 | x = x.contiguous() 130 | x = x.view(-1, *xsize[dim:]).contiguous() 131 | x = x.view(x.size(0), x.size(1), -1)[:, getattr(torch.arange(x.size(1) - 1, 132 | -1, -1), ('cpu','cuda')[x.is_cuda])().long(), :] 133 | return x.view(xsize) 134 | -------------------------------------------------------------------------------- /transformer/Translator.py: -------------------------------------------------------------------------------- 1 | ''' This module will handle the text generation with beam search.
''' 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | from transformer.Models import Transformer 8 | from transformer.Beam import Beam 9 | 10 | class Translator(object): 11 | ''' Load with trained model and handle the beam search ''' 12 | 13 | def __init__(self, opt): 14 | self.opt = opt 15 | self.device = torch.device('cuda' if opt.cuda else 'cpu') 16 | 17 | checkpoint = torch.load(opt.model) 18 | model_opt = checkpoint['settings'] 19 | self.model_opt = model_opt 20 | 21 | model = Transformer( 22 | model_opt.src_vocab_size, 23 | model_opt.tgt_vocab_size, 24 | model_opt.max_token_seq_len, 25 | tgt_emb_prj_weight_sharing=model_opt.proj_share_weight, 26 | emb_src_tgt_weight_sharing=model_opt.embs_share_weight, 27 | d_k=model_opt.d_k, 28 | d_v=model_opt.d_v, 29 | d_model=model_opt.d_model, 30 | d_word_vec=model_opt.d_word_vec, 31 | d_inner=model_opt.d_inner_hid, 32 | n_layers=model_opt.n_layers, 33 | n_head=model_opt.n_head, 34 | dropout=model_opt.dropout) 35 | 36 | model.load_state_dict(checkpoint['model']) 37 | print('[Info] Trained model state loaded.') 38 | 39 | model.word_prob_prj = nn.LogSoftmax(dim=1) 40 | 41 | model = model.to(self.device) 42 | 43 | self.model = model 44 | self.model.eval() 45 | 46 | def translate_batch(self, src_seq, src_pos): 47 | ''' Translation work in one batch ''' 48 | 49 | def get_inst_idx_to_tensor_position_map(inst_idx_list): 50 | ''' Indicate the position of an instance in a tensor. ''' 51 | return {inst_idx: tensor_position for tensor_position, inst_idx in enumerate(inst_idx_list)} 52 | 53 | def collect_active_part(beamed_tensor, curr_active_inst_idx, n_prev_active_inst, n_bm): 54 | ''' Collect tensor parts associated to active instances. 
''' 55 | 56 | _, *d_hs = beamed_tensor.size() 57 | n_curr_active_inst = len(curr_active_inst_idx) 58 | new_shape = (n_curr_active_inst * n_bm, *d_hs) 59 | 60 | beamed_tensor = beamed_tensor.view(n_prev_active_inst, -1) # one row per instance (its n_bm beam rows flattened together) 61 | beamed_tensor = beamed_tensor.index_select(0, curr_active_inst_idx) # keep only the rows of still-active instances 62 | beamed_tensor = beamed_tensor.view(*new_shape) 63 | 64 | return beamed_tensor 65 | 66 | def collate_active_info( 67 | src_seq, src_enc, inst_idx_to_position_map, active_inst_idx_list): 68 | # Sentences which are still active are collected, 69 | # so the decoder will not run on completed sentences. 70 | n_prev_active_inst = len(inst_idx_to_position_map) 71 | active_inst_idx = [inst_idx_to_position_map[k] for k in active_inst_idx_list] 72 | active_inst_idx = torch.LongTensor(active_inst_idx).to(self.device) 73 | 74 | active_src_seq = collect_active_part(src_seq, active_inst_idx, n_prev_active_inst, n_bm) 75 | active_src_enc = collect_active_part(src_enc, active_inst_idx, n_prev_active_inst, n_bm) 76 | active_inst_idx_to_position_map = get_inst_idx_to_tensor_position_map(active_inst_idx_list) 77 | 78 | return active_src_seq, active_src_enc, active_inst_idx_to_position_map 79 | 80 | def beam_decode_step( 81 | inst_dec_beams, len_dec_seq, src_seq, enc_output, inst_idx_to_position_map, n_bm): 82 | ''' Decode and update beam status, and then return active beam idx ''' 83 | 84 | def prepare_beam_dec_seq(inst_dec_beams, len_dec_seq): 85 | dec_partial_seq = [b.get_current_state() for b in inst_dec_beams if not b.done] 86 | dec_partial_seq = torch.stack(dec_partial_seq).to(self.device) 87 | dec_partial_seq = dec_partial_seq.view(-1, len_dec_seq) 88 | return dec_partial_seq 89 | 90 | def prepare_beam_dec_pos(len_dec_seq, n_active_inst, n_bm): 91 | dec_partial_pos = torch.arange(1, len_dec_seq + 1, dtype=torch.long, device=self.device) 92 | dec_partial_pos = dec_partial_pos.unsqueeze(0).repeat(n_active_inst * n_bm, 1) 93 | return dec_partial_pos 94 | 95 | def predict_word(dec_seq,
dec_pos, src_seq, enc_output, n_active_inst, n_bm): 96 | dec_output, *_ = self.model.decoder(dec_seq, dec_pos, src_seq, enc_output) 97 | dec_output = dec_output[:, -1, :] # Pick the last step: (n_active_inst * n_bm) x d_h 98 | word_prob = F.log_softmax(self.model.tgt_word_prj(dec_output), dim=1) 99 | word_prob = word_prob.view(n_active_inst, n_bm, -1) 100 | 101 | return word_prob 102 | 103 | def collect_active_inst_idx_list(inst_beams, word_prob, inst_idx_to_position_map): 104 | active_inst_idx_list = [] 105 | for inst_idx, inst_position in inst_idx_to_position_map.items(): 106 | is_inst_complete = inst_beams[inst_idx].advance(word_prob[inst_position]) 107 | if not is_inst_complete: 108 | active_inst_idx_list += [inst_idx] 109 | 110 | return active_inst_idx_list 111 | 112 | n_active_inst = len(inst_idx_to_position_map) 113 | 114 | dec_seq = prepare_beam_dec_seq(inst_dec_beams, len_dec_seq) 115 | dec_pos = prepare_beam_dec_pos(len_dec_seq, n_active_inst, n_bm) 116 | word_prob = predict_word(dec_seq, dec_pos, src_seq, enc_output, n_active_inst, n_bm) 117 | 118 | # Update the beam with predicted word prob information and collect incomplete instances 119 | active_inst_idx_list = collect_active_inst_idx_list( 120 | inst_dec_beams, word_prob, inst_idx_to_position_map) 121 | 122 | return active_inst_idx_list 123 | 124 | def collect_hypothesis_and_scores(inst_dec_beams, n_best): 125 | all_hyp, all_scores = [], [] 126 | for inst_idx in range(len(inst_dec_beams)): 127 | scores, tail_idxs = inst_dec_beams[inst_idx].sort_scores() 128 | all_scores += [scores[:n_best]] 129 | 130 | hyps = [inst_dec_beams[inst_idx].get_hypothesis(i) for i in tail_idxs[:n_best]] 131 | all_hyp += [hyps] 132 | return all_hyp, all_scores 133 | 134 | with torch.no_grad(): 135 | #-- Encode 136 | src_seq, src_pos = src_seq.to(self.device), src_pos.to(self.device) 137 | src_enc, *_ = self.model.encoder(src_seq, src_pos) 138 | 139 | #-- Repeat data for beam search 140 | n_bm = self.opt.beam_size 141 | n_inst, len_s,
d_h = src_enc.size() 142 | src_seq = src_seq.repeat(1, n_bm).view(n_inst * n_bm, len_s) 143 | src_enc = src_enc.repeat(1, n_bm, 1).view(n_inst * n_bm, len_s, d_h) 144 | 145 | #-- Prepare beams 146 | inst_dec_beams = [Beam(n_bm, device=self.device) for _ in range(n_inst)] 147 | 148 | #-- Bookkeeping for active or not 149 | active_inst_idx_list = list(range(n_inst)) 150 | inst_idx_to_position_map = get_inst_idx_to_tensor_position_map(active_inst_idx_list) 151 | 152 | #-- Decode 153 | for len_dec_seq in range(1, self.model_opt.max_token_seq_len + 1): 154 | 155 | active_inst_idx_list = beam_decode_step( 156 | inst_dec_beams, len_dec_seq, src_seq, src_enc, inst_idx_to_position_map, n_bm) 157 | 158 | if not active_inst_idx_list: 159 | break # all instances have finished their path to EOS 160 | 161 | src_seq, src_enc, inst_idx_to_position_map = collate_active_info( 162 | src_seq, src_enc, inst_idx_to_position_map, active_inst_idx_list) 163 | 164 | batch_hyp, batch_scores = collect_hypothesis_and_scores(inst_dec_beams, self.opt.n_best) 165 | 166 | return batch_hyp, batch_scores 167 | -------------------------------------------------------------------------------- /transformer/Models.py: -------------------------------------------------------------------------------- 1 | ''' Define the Transformer model ''' 2 | import torch 3 | import torch.nn as nn 4 | import numpy as np 5 | import transformer.Constants as Constants 6 | from transformer.Layers import EncoderLayer, DecoderLayer 7 | 8 | __author__ = "Yu-Hsiang Huang" 9 | 10 | def get_non_pad_mask(seq): 11 | assert seq.dim() == 2 # NOTE(review): seq is presumably (batch, len) ids; returns a (batch, len, 1) float mask, 1.0 where seq != PAD 12 | return seq.ne(Constants.PAD).type(torch.float).unsqueeze(-1) 13 | 14 | def get_sinusoid_encoding_table(n_position, d_hid, padding_idx=None): 15 | ''' Sinusoid position encoding table ''' 16 | 17 | def cal_angle(position, hid_idx): 18 | return position / np.power(10000, 2 * (hid_idx // 2) / d_hid) 19 | 20 | def get_posi_angle_vec(position): 21 | return [cal_angle(position, hid_j) for hid_j in
range(d_hid)] 22 | 23 | sinusoid_table = np.array([get_posi_angle_vec(pos_i) for pos_i in range(n_position)]) 24 | 25 | sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2]) # dim 2i 26 | sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2]) # dim 2i+1 27 | 28 | if padding_idx is not None: 29 | # zero vector for the padding position (row padding_idx of the table) 30 | sinusoid_table[padding_idx] = 0. 31 | 32 | return torch.FloatTensor(sinusoid_table) 33 | 34 | def get_attn_key_pad_mask(seq_k, seq_q): 35 | ''' For masking out the padding part of key sequence. ''' 36 | 37 | # Expand to fit the shape of key query attention matrix. 38 | len_q = seq_q.size(1) 39 | padding_mask = seq_k.eq(Constants.PAD) 40 | padding_mask = padding_mask.unsqueeze(1).expand(-1, len_q, -1) # b x lq x lk 41 | 42 | return padding_mask 43 | 44 | def get_subsequent_mask(seq): 45 | ''' Mask out future positions: 1 strictly above the diagonal, expanded to b x ls x ls. ''' 46 | 47 | sz_b, len_s = seq.size() 48 | subsequent_mask = torch.triu( 49 | torch.ones((len_s, len_s), device=seq.device, dtype=torch.uint8), diagonal=1) 50 | subsequent_mask = subsequent_mask.unsqueeze(0).expand(sz_b, -1, -1) # b x ls x ls 51 | 52 | return subsequent_mask 53 | 54 | class Encoder(nn.Module): 55 | ''' An encoder model with self attention mechanism.
''' 56 | 57 | def __init__( 58 | self, 59 | n_src_vocab, len_max_seq, d_word_vec, 60 | n_layers, n_head, d_k, d_v, 61 | d_model, d_inner, dropout=0.1): 62 | 63 | super().__init__() 64 | 65 | n_position = len_max_seq + 1 # +1 because row 0 of the position table is the zeroed padding position (padding_idx=0 below) 66 | 67 | self.src_word_emb = nn.Embedding( 68 | n_src_vocab, d_word_vec, padding_idx=Constants.PAD) 69 | 70 | self.position_enc = nn.Embedding.from_pretrained( 71 | get_sinusoid_encoding_table(n_position, d_word_vec, padding_idx=0), 72 | freeze=True) 73 | 74 | self.layer_stack = nn.ModuleList([ 75 | EncoderLayer(d_model, d_inner, n_head, d_k, d_v, dropout=dropout) 76 | for _ in range(n_layers)]) 77 | 78 | def forward(self, src_seq, src_pos, return_attns=False): 79 | 80 | enc_slf_attn_list = [] 81 | 82 | # -- Prepare masks 83 | slf_attn_mask = get_attn_key_pad_mask(seq_k=src_seq, seq_q=src_seq) 84 | non_pad_mask = get_non_pad_mask(src_seq) 85 | 86 | # -- Forward 87 | enc_output = self.src_word_emb(src_seq) + self.position_enc(src_pos) 88 | 89 | for enc_layer in self.layer_stack: 90 | enc_output, enc_slf_attn = enc_layer( 91 | enc_output, 92 | non_pad_mask=non_pad_mask, 93 | slf_attn_mask=slf_attn_mask) 94 | if return_attns: 95 | enc_slf_attn_list += [enc_slf_attn] 96 | 97 | if return_attns: 98 | return enc_output, enc_slf_attn_list 99 | return enc_output, # 1-tuple, so callers can always unpack with `enc_output, *_ = ...` 100 | 101 | class Decoder(nn.Module): 102 | ''' A decoder model with self attention mechanism.
''' 103 | 104 | def __init__( 105 | self, 106 | n_tgt_vocab, len_max_seq, d_word_vec, 107 | n_layers, n_head, d_k, d_v, 108 | d_model, d_inner, dropout=0.1): 109 | 110 | super().__init__() 111 | n_position = len_max_seq + 1 112 | 113 | self.tgt_word_emb = nn.Embedding( 114 | n_tgt_vocab, d_word_vec, padding_idx=Constants.PAD) 115 | 116 | self.position_enc = nn.Embedding.from_pretrained( 117 | get_sinusoid_encoding_table(n_position, d_word_vec, padding_idx=0), 118 | freeze=True) 119 | 120 | self.layer_stack = nn.ModuleList([ 121 | DecoderLayer(d_model, d_inner, n_head, d_k, d_v, dropout=dropout) 122 | for _ in range(n_layers)]) 123 | 124 | def forward(self, tgt_seq, tgt_pos, src_seq, enc_output, return_attns=False): 125 | 126 | dec_slf_attn_list, dec_enc_attn_list = [], [] 127 | 128 | # -- Prepare masks 129 | non_pad_mask = get_non_pad_mask(tgt_seq) 130 | 131 | slf_attn_mask_subseq = get_subsequent_mask(tgt_seq) 132 | slf_attn_mask_keypad = get_attn_key_pad_mask(seq_k=tgt_seq, seq_q=tgt_seq) 133 | slf_attn_mask = (slf_attn_mask_keypad + slf_attn_mask_subseq).gt(0) # masked when the key is PAD or lies in the future 134 | 135 | dec_enc_attn_mask = get_attn_key_pad_mask(seq_k=src_seq, seq_q=tgt_seq) 136 | 137 | # -- Forward 138 | dec_output = self.tgt_word_emb(tgt_seq) + self.position_enc(tgt_pos) 139 | 140 | for dec_layer in self.layer_stack: 141 | dec_output, dec_slf_attn, dec_enc_attn = dec_layer( 142 | dec_output, enc_output, 143 | non_pad_mask=non_pad_mask, 144 | slf_attn_mask=slf_attn_mask, 145 | dec_enc_attn_mask=dec_enc_attn_mask) 146 | 147 | if return_attns: 148 | dec_slf_attn_list += [dec_slf_attn] 149 | dec_enc_attn_list += [dec_enc_attn] 150 | 151 | if return_attns: 152 | return dec_output, dec_slf_attn_list, dec_enc_attn_list 153 | return dec_output, 154 | 155 | class Transformer(nn.Module): 156 | ''' A sequence to sequence model with attention mechanism.
''' 157 | 158 | def __init__( 159 | self, 160 | n_src_vocab, n_tgt_vocab, len_max_seq, 161 | d_word_vec=512, d_model=512, d_inner=2048, 162 | n_layers=6, n_head=8, d_k=64, d_v=64, dropout=0.1, 163 | tgt_emb_prj_weight_sharing=True, 164 | emb_src_tgt_weight_sharing=True): 165 | 166 | super().__init__() 167 | 168 | self.encoder = Encoder( 169 | n_src_vocab=n_src_vocab, len_max_seq=len_max_seq, 170 | d_word_vec=d_word_vec, d_model=d_model, d_inner=d_inner, 171 | n_layers=n_layers, n_head=n_head, d_k=d_k, d_v=d_v, 172 | dropout=dropout) 173 | 174 | self.decoder = Decoder( 175 | n_tgt_vocab=n_tgt_vocab, len_max_seq=len_max_seq, 176 | d_word_vec=d_word_vec, d_model=d_model, d_inner=d_inner, 177 | n_layers=n_layers, n_head=n_head, d_k=d_k, d_v=d_v, 178 | dropout=dropout) 179 | 180 | self.tgt_word_prj = nn.Linear(d_model, n_tgt_vocab, bias=False) 181 | nn.init.xavier_normal_(self.tgt_word_prj.weight) 182 | 183 | assert d_model == d_word_vec, \ 184 | 'To facilitate the residual connections, \ 185 | the dimensions of all module outputs shall be the same.' 186 | 187 | if tgt_emb_prj_weight_sharing: 188 | # Share the weight matrix between target word embedding & the final logit dense layer 189 | self.tgt_word_prj.weight = self.decoder.tgt_word_emb.weight 190 | self.x_logit_scale = (d_model ** -0.5) 191 | else: 192 | self.x_logit_scale = 1. 193 | 194 | if emb_src_tgt_weight_sharing: 195 | # Share the weight matrix between source & target word embeddings 196 | assert n_src_vocab == n_tgt_vocab, \ 197 | "To share word embedding table, the vocabulary size of src/tgt shall be the same." 
198 | self.encoder.src_word_emb.weight = self.decoder.tgt_word_emb.weight 199 | 200 | def forward(self, src_seq, src_pos, tgt_seq, tgt_pos): 201 | 202 | tgt_seq, tgt_pos = tgt_seq[:, :-1], tgt_pos[:, :-1] 203 | 204 | enc_output, *_ = self.encoder(src_seq, src_pos) 205 | dec_output, *_ = self.decoder(tgt_seq, tgt_pos, src_seq, enc_output) 206 | seq_logit = self.tgt_word_prj(dec_output) * self.x_logit_scale 207 | 208 | return seq_logit.view(-1, seq_logit.size(2)) 209 | -------------------------------------------------------------------------------- /training.py: -------------------------------------------------------------------------------- 1 | """ module providing basic training utilities""" 2 | import os 3 | from os.path import join 4 | from time import time 5 | from datetime import timedelta 6 | from itertools import starmap 7 | 8 | from cytoolz import curry, reduce 9 | 10 | import torch 11 | from torch.nn.utils import clip_grad_norm_ 12 | from torch.optim.lr_scheduler import ReduceLROnPlateau 13 | import tensorboardX 14 | 15 | 16 | def get_basic_grad_fn(net, clip_grad, max_grad=10): 17 | def f(): 18 | grad_norm = clip_grad_norm_( 19 | [p for p in net.parameters() if p.requires_grad], clip_grad) 20 | if max_grad is not None and grad_norm >= max_grad: 21 | # print('WARNING: Exploding Gradients {:.2f}'.format(grad_norm)) 22 | grad_norm = max_grad 23 | grad_log = {} 24 | grad_log['grad_norm'] = grad_norm 25 | return grad_log 26 | return f 27 | 28 | @curry 29 | def compute_loss(net, criterion, decoder, fw_args, loss_args): 30 | loss = criterion(*((net(*fw_args),) + loss_args + (decoder, ))) 31 | return loss 32 | 33 | @curry 34 | def val_step(loss_step, fw_args, loss_args): 35 | loss = loss_step(fw_args, loss_args) 36 | return loss.size(0), loss.sum().item() 37 | 38 | @curry 39 | def basic_validate(net, criterion, decoder, val_batches): 40 | print('running validation ... 
', end='') 41 | net.eval() 42 | net._isTrain = False # for DeepLSTM 43 | start = time() 44 | with torch.no_grad(): 45 | validate_fn = val_step(compute_loss(net, criterion, decoder)) 46 | n_data, tot_loss = reduce( 47 | lambda a, b: (a[0]+b[0], a[1]+b[1]), 48 | starmap(validate_fn, val_batches), 49 | (0, 0) 50 | ) 51 | val_loss = tot_loss / n_data 52 | print( 53 | 'validation finished in {} '.format( 54 | timedelta(seconds=int(time()-start))) 55 | ) 56 | print('validation loss: {:.4f} ... '.format(val_loss)) 57 | return {'loss': val_loss} 58 | 59 | 60 | class BasicPipeline(object): 61 | def __init__(self, net, decoder, 62 | train_batcher, val_batcher, batch_size, 63 | val_fn, criterion, optim, grad_fn=None): 64 | self._net = net 65 | self._decoder = decoder 66 | self._train_batcher = train_batcher 67 | self._val_batcher = val_batcher 68 | self._criterion = criterion 69 | self._opt = optim 70 | # grad_fn is calleble without input args that modifyies gradient 71 | # it should return a dictionary of logging values 72 | self._grad_fn = grad_fn 73 | self._val_fn = val_fn 74 | 75 | self._n_epoch = 0 # epoch not very useful? 
76 | self._batch_size = batch_size 77 | self._batches = self.batches() 78 | 79 | def batches(self): 80 | while True: 81 | for fw_args, bw_args in self._train_batcher(self._batch_size): 82 | yield fw_args, bw_args 83 | self._n_epoch += 1 84 | 85 | def get_loss_args(self, net_out, bw_args): 86 | if isinstance(net_out, tuple): 87 | loss_args = net_out + bw_args + (self._decoder, ) 88 | else: 89 | loss_args = (net_out, ) + bw_args + (self._decoder, ) 90 | return loss_args 91 | 92 | def train_step(self): 93 | # forward pass of model 94 | self._net.train() 95 | self._net._isTrain = True # for DeepLSTM 96 | fw_args, bw_args = next(self._batches) 97 | net_out = self._net(*fw_args) 98 | 99 | # get logs and output for logging, backward 100 | log_dict = {} 101 | loss_args = self.get_loss_args(net_out, bw_args) 102 | 103 | # backward and update ( and optional gradient monitoring ) 104 | loss = self._criterion(*loss_args).mean() 105 | loss.backward() 106 | log_dict['loss'] = loss.item() 107 | if self._grad_fn is not None: 108 | log_dict.update(self._grad_fn()) 109 | self._opt.step() 110 | self._net.zero_grad() 111 | 112 | return log_dict 113 | 114 | def validate(self): 115 | return self._val_fn(self._val_batcher(self._batch_size)) 116 | 117 | def checkpoint(self, save_path, step, val_metric=None): 118 | save_dict = {} 119 | if val_metric is not None: 120 | name = 'ckpt-{:6f}-{}'.format(val_metric, step) 121 | save_dict['val_metric'] = val_metric 122 | else: 123 | name = 'ckpt-{}'.format(step) 124 | 125 | save_dict['state_dict'] = self._net.state_dict() 126 | save_dict['optimizer'] = self._opt.state_dict() 127 | torch.save(save_dict, join(save_path, name)) 128 | 129 | def terminate(self): 130 | self._train_batcher.terminate() 131 | self._val_batcher.terminate() 132 | 133 | 134 | class BasicTrainer(object): 135 | """ Basic trainer with minimal function and early stopping""" 136 | def __init__(self, pipeline, save_dir, ckpt_freq, patience, 137 | scheduler=None, val_mode='loss'): 
138 | assert isinstance(pipeline, BasicPipeline) 139 | assert val_mode in ['loss', 'score'] 140 | self._pipeline = pipeline 141 | self._save_dir = save_dir 142 | self._logger = tensorboardX.SummaryWriter(join(save_dir, 'log')) 143 | os.makedirs(join(save_dir, 'ckpt')) 144 | 145 | self._ckpt_freq = ckpt_freq 146 | self._patience = patience 147 | self._sched = scheduler 148 | self._val_mode = val_mode 149 | 150 | self._step = 0 151 | self._running_loss = None 152 | # state vars for early stopping 153 | self._current_p = 0 154 | self._best_val = None 155 | 156 | def log(self, log_dict): 157 | loss = log_dict['loss'] if 'loss' in log_dict else log_dict['reward'] 158 | if self._running_loss is not None: 159 | self._running_loss = 0.99*self._running_loss + 0.01*loss 160 | else: 161 | self._running_loss = loss 162 | print('train step: {}, {}: {:.4f}\r'.format( 163 | self._step, 164 | 'loss' if 'loss' in log_dict else 'reward', 165 | self._running_loss), end='') 166 | for key, value in log_dict.items(): 167 | self._logger.add_scalar( 168 | '{}_'.format(key), value, self._step) 169 | 170 | def validate(self): 171 | print() 172 | val_log = self._pipeline.validate() 173 | for key, value in val_log.items(): 174 | self._logger.add_scalar( 175 | 'val_{}'.format(key), 176 | value, self._step 177 | ) 178 | if 'reward' in val_log: 179 | val_metric = val_log['reward'] 180 | else: 181 | val_metric = (val_log['loss'] if self._val_mode == 'loss' 182 | else val_log['score']) 183 | return val_metric 184 | 185 | def checkpoint(self): 186 | val_metric = self.validate() 187 | self._pipeline.checkpoint( 188 | join(self._save_dir, 'ckpt'), self._step, val_metric) 189 | if isinstance(self._sched, ReduceLROnPlateau): 190 | self._sched.step(val_metric) 191 | else: 192 | self._sched.step() 193 | stop = self.check_stop(val_metric) 194 | return stop 195 | 196 | def check_stop(self, val_metric): 197 | if self._best_val is None: 198 | self._best_val = val_metric 199 | elif ((val_metric < 
self._best_val and self._val_mode == 'loss') 200 | or (val_metric > self._best_val and self._val_mode == 'score')): 201 | self._current_p = 0 202 | self._best_val = val_metric 203 | else: 204 | self._current_p += 1 205 | print('Current minimum loss = {}'.format(self._best_val)) 206 | return self._current_p >= self._patience 207 | 208 | def train(self): 209 | try: 210 | start = time() 211 | print('Start training') 212 | while True: 213 | log_dict = self._pipeline.train_step() 214 | self._step += 1 215 | self.log(log_dict) 216 | 217 | if self._step % self._ckpt_freq == 0: 218 | stop = self.checkpoint() 219 | if stop: 220 | break 221 | print('Training finised in ', timedelta(seconds=time()-start)) 222 | finally: 223 | self._pipeline.terminate() 224 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import os 4 | from os.path import join, exists 5 | import pickle as pkl 6 | import random # 7 | from time import time 8 | from datetime import timedelta 9 | 10 | from cytoolz import compose 11 | 12 | import torch 13 | from torch import optim 14 | from torch.nn import functional as F 15 | from torch.optim.lr_scheduler import ReduceLROnPlateau 16 | from torch.utils.data import DataLoader 17 | 18 | from model.extract import Summarizer 19 | from model.util import sequence_loss 20 | from training import get_basic_grad_fn, basic_validate 21 | from training import BasicPipeline, BasicTrainer 22 | from decoding import Extractor, DecodeDataset 23 | from decoding import sort_ckpt, get_n_ext 24 | from evaluate import eval_rouge 25 | 26 | from utils import PAD, UNK 27 | from utils import make_vocab, make_embedding 28 | 29 | from data.data import CnnDmDataset 30 | from data.batcher import tokenize 31 | from data.batcher import coll_fn_extract 32 | from data.batcher import prepro_fn_extract 33 | from data.batcher import 
convert_batch_extract_ptr 34 | from data.batcher import batchify_fn_extract_ptr 35 | from data.batcher import BucketedGenerater 36 | 37 | BUCKET_SIZE = 6400 38 | 39 | DATA_DIR = './CNNDM' 40 | 41 | class ExtractDataset(CnnDmDataset): 42 | """ article sentences -> extraction indices 43 | (dataset created by greedily matching ROUGE) 44 | """ 45 | def __init__(self, split): 46 | super().__init__(split, DATA_DIR) 47 | 48 | def __getitem__(self, i): 49 | js_data = super().__getitem__(i) 50 | art_sents, extracts = js_data['article'], js_data['label'] 51 | return art_sents, extracts 52 | 53 | def set_parameters(args): 54 | if args.encoder == 'BiLSTM': 55 | args.lr = 1e-4 if args.emb_type == 'W2V' else 5e-5 56 | args.clip = 2.0 57 | args.encoder_layer = 1 # Actually 2 layers (bidirectional) 58 | args.encoder_hidden = 512 59 | elif args.encoder == 'Transformer' and args.decoder == 'PN': 60 | args.encoder_layer = 4 61 | args.encoder_hidden = 512 62 | elif args.encoder == 'Transformer' and args.decoder == 'SL': 63 | args.lr = 5e-5 64 | args.decay = 0.8 65 | args.patience = 10 66 | args.encoder_layer = 12 67 | args.encoder_hidden = 512 68 | elif args.encoder == 'DeepLSTM' and args.decoder == 'PN': 69 | args.batch = 16 70 | args.ckpt_freq = 6000 71 | args.encoder_layer = 4 72 | args.encoder_hidden = 2048 73 | elif args.encoder == 'DeepLSTM' and args.decoder == 'SL': 74 | args.lr = 5e-5 75 | args.batch = 16 76 | args.ckpt_freq = 6000 77 | args.encoder_layer = 8 78 | args.encoder_hidden = 2048 79 | return args 80 | 81 | def build_batchers(decoder, emb_type, word2id, cuda, debug): 82 | prepro = prepro_fn_extract(args.max_word, args.max_sent, emb_type) 83 | def sort_key(sample): 84 | src_sents, _ = sample 85 | return len(src_sents) 86 | batchify_fn = batchify_fn_extract_ptr 87 | convert_batch = convert_batch_extract_ptr 88 | batchify = compose(batchify_fn(PAD, cuda=cuda), 89 | convert_batch(UNK, word2id, emb_type)) 90 | 91 | train_loader = DataLoader( 92 | ExtractDataset('train'), 
batch_size=BUCKET_SIZE, 93 | shuffle=not debug, 94 | num_workers=4 if cuda and not debug else 0, 95 | collate_fn=coll_fn_extract 96 | ) 97 | train_batcher = BucketedGenerater(train_loader, prepro, sort_key, batchify, 98 | single_run=False, fork=not debug) 99 | 100 | val_loader = DataLoader( 101 | ExtractDataset('val'), batch_size=BUCKET_SIZE, 102 | shuffle=False, num_workers=4 if cuda and not debug else 0, 103 | collate_fn=coll_fn_extract 104 | ) 105 | val_batcher = BucketedGenerater(val_loader, prepro, sort_key, batchify, 106 | single_run=True, fork=not debug) 107 | return train_batcher, val_batcher 108 | 109 | 110 | def configure_net(encoder, decoder, emb_type, vocab_size, emb_dim, 111 | conv_hidden, encoder_hidden, encoder_layer): 112 | model_args = {} 113 | model_args['encoder'] = encoder 114 | model_args['decoder'] = decoder 115 | model_args['emb_type'] = emb_type 116 | model_args['vocab_size'] = vocab_size 117 | model_args['emb_dim'] = emb_dim 118 | model_args['conv_hidden'] = conv_hidden 119 | model_args['encoder_hidden'] = encoder_hidden 120 | model_args['encoder_layer'] = encoder_layer 121 | 122 | model = Summarizer(**model_args) 123 | return model, model_args 124 | 125 | 126 | def configure_training(decoder, opt, lr, clip_grad, lr_decay, batch_size): 127 | """ supports Adam optimizer only""" 128 | assert opt in ['adam'] 129 | 130 | opt_kwargs = {} 131 | opt_kwargs['lr'] = lr 132 | 133 | train_params = {} 134 | train_params['optimizer'] = (opt, opt_kwargs) 135 | train_params['clip_grad_norm'] = clip_grad 136 | train_params['batch_size'] = batch_size 137 | train_params['lr_decay'] = lr_decay 138 | 139 | ce = lambda logit, target: F.cross_entropy(logit, target, reduction='none') 140 | def criterion(logits, targets, sent_num, decoder): 141 | return sequence_loss(logits, targets, sent_num, decoder, ce, pad_idx=-1) 142 | 143 | return criterion, train_params 144 | 145 | 146 | def train(args): 147 | 148 | assert args.encoder in ['BiLSTM', 'DeepLSTM', 
'Transformer'] 149 | assert args.decoder in ['SL', 'PN'] 150 | assert args.emb_type in ['W2V', 'BERT'] 151 | 152 | # create data batcher, vocabulary 153 | # batcher 154 | with open(join(DATA_DIR, 'vocab_cnt.pkl'), 'rb') as f: 155 | wc = pkl.load(f) 156 | word2id = make_vocab(wc, args.vsize) 157 | train_batcher, val_batcher = build_batchers(args.decoder, args.emb_type, 158 | word2id, args.cuda, args.debug) 159 | 160 | # make model 161 | model, model_args = configure_net(args.encoder, args.decoder, args.emb_type, len(word2id), 162 | args.emb_dim, args.conv_hidden, args.encoder_hidden, 163 | args.encoder_layer) 164 | 165 | if args.emb_type == 'W2V': 166 | # NOTE: the pretrained embedding having the same dimension 167 | # as args.emb_dim should already be trained 168 | w2v_path='./CNNDM/word2vec/word2vec.128d.226k.bin' 169 | embedding, _ = make_embedding( 170 | {i: w for w, i in word2id.items()}, w2v_path) 171 | model.set_embedding(embedding) 172 | 173 | # configure training setting 174 | criterion, train_params = configure_training( 175 | args.decoder, 'adam', args.lr, args.clip, args.decay, args.batch 176 | ) 177 | 178 | # save experiment setting 179 | if not exists(args.path): 180 | os.makedirs(args.path) 181 | with open(join(args.path, 'vocab.pkl'), 'wb') as f: 182 | pkl.dump(word2id, f, pkl.HIGHEST_PROTOCOL) 183 | meta = {} 184 | meta['model_args'] = model_args 185 | meta['traing_params'] = train_params 186 | with open(join(args.path, 'meta.json'), 'w') as f: 187 | json.dump(meta, f, indent=4) 188 | 189 | # prepare trainer 190 | val_fn = basic_validate(model, criterion, args.decoder) 191 | grad_fn = get_basic_grad_fn(model, args.clip) 192 | optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), **train_params['optimizer'][1]) 193 | scheduler = ReduceLROnPlateau(optimizer, 'min', verbose=True, 194 | factor=args.decay, min_lr=2e-5, 195 | patience=args.lr_p) 196 | 197 | if args.cuda: 198 | model = model.cuda() 199 | pipeline = 
BasicPipeline(model, args.decoder, 200 | train_batcher, val_batcher, args.batch, val_fn, 201 | criterion, optimizer, grad_fn) 202 | trainer = BasicTrainer(pipeline, args.path, 203 | args.ckpt_freq, args.patience, scheduler) 204 | 205 | # for name, para in net.named_parameters(): 206 | # if para.requires_grad: 207 | # print(name) 208 | 209 | print('Start training with the following hyper-parameters:') 210 | print(meta) 211 | trainer.train() 212 | 213 | def test(args, split): 214 | ext_dir = args.path 215 | ckpts = sort_ckpt(ext_dir) 216 | 217 | # setup loader 218 | def coll(batch): 219 | articles = list(filter(bool, batch)) 220 | return articles 221 | dataset = DecodeDataset(split) 222 | 223 | n_data = len(dataset) 224 | loader = DataLoader( 225 | dataset, batch_size=args.batch, shuffle=False, num_workers=4, 226 | collate_fn=coll 227 | ) 228 | 229 | # decode and evaluate top 5 models 230 | os.mkdir(join(args.path, 'decode')) 231 | os.mkdir(join(args.path, 'ROUGE')) 232 | for i in range(min(5, len(ckpts))): 233 | print('Start loading checkpoint {} !'.format(ckpts[i])) 234 | cur_ckpt = torch.load( 235 | join(ext_dir, 'ckpt/{}'.format(ckpts[i])) 236 | )['state_dict'] 237 | extractor = Extractor(ext_dir, cur_ckpt, args.emb_type, cuda=args.cuda) 238 | save_path = join(args.path, 'decode/{}'.format(ckpts[i])) 239 | os.mkdir(save_path) 240 | 241 | # decoding 242 | ext_list = [] 243 | cur_idx = 0 244 | start = time() 245 | with torch.no_grad(): 246 | for raw_article_batch in loader: 247 | tokenized_article_batch = map(tokenize(None, args.emb_type), raw_article_batch) 248 | for raw_art_sents in tokenized_article_batch: 249 | ext_idx = extractor(raw_art_sents) 250 | ext_list.append(ext_idx) 251 | cur_idx += 1 252 | print('{}/{} ({:.2f}%) decoded in {} seconds\r'.format( 253 | cur_idx, n_data, cur_idx/n_data*100, timedelta(seconds=int(time()-start)) 254 | ), end='') 255 | print() 256 | 257 | # write files 258 | for file_idx, ext_ids in enumerate(ext_list): 259 | dec = [] 260 | 
data_path = join(DATA_DIR, '{}/{}.json'.format(split, file_idx)) 261 | with open(data_path) as f: 262 | data = json.loads(f.read()) 263 | n_ext = 2 if data['source'] == 'CNN' else 3 264 | n_ext = min(n_ext, len(data['article'])) 265 | for j in range(n_ext): 266 | sent_idx = ext_ids[j] 267 | dec.append(data['article'][sent_idx]) 268 | with open(join(save_path, '{}.dec'.format(file_idx)), 'w') as f: 269 | for sent in dec: 270 | print(sent, file=f) 271 | 272 | # evaluate current model 273 | print('Starting evaluating ROUGE !') 274 | dec_path = save_path 275 | ref_path = join(DATA_DIR, 'refs/{}'.format(split)) 276 | ROUGE = eval_rouge(dec_path, ref_path) 277 | print(ROUGE) 278 | with open(join(args.path, 'ROUGE/{}.txt'.format(ckpts[i])), 'w') as f: 279 | print(ROUGE, file=f) 280 | 281 | if __name__ == '__main__': 282 | parser = argparse.ArgumentParser( 283 | description='training of the different encoder/decoder/emb_type' 284 | ) 285 | parser.add_argument('--path', default='./result', help='root of the model') 286 | parser.add_argument('--mode', type=str, required=True, help='train/test') 287 | 288 | # model options 289 | parser.add_argument('--encoder', type=str, required=True, 290 | help='BiLSTM/DeepLSTM/Transformer') 291 | parser.add_argument('--decoder', type=str, required=True, 292 | help='SL(Sequence Labeling)/PN(Pointer Network)') 293 | parser.add_argument('--emb_type', type=str, required=True, 294 | help='W2V(word2vec)/BERT') 295 | parser.add_argument('--vsize', type=int, action='store', default=30000, 296 | help='vocabulary size') 297 | parser.add_argument('--emb_dim', type=int, action='store', default=128, 298 | help='the dimension of word embedding') 299 | parser.add_argument('--vocab_path', action='store', 300 | help='use pretrained word2vec embedding') 301 | parser.add_argument('--conv_hidden', type=int, action='store', default=100, 302 | help='the number of hidden units of Conv') 303 | parser.add_argument('--encoder_hidden', type=int, action='store', 
default=512, 304 | help='the number of hidden units of encoder') 305 | parser.add_argument('--encoder_layer', type=int, action='store', default=1, 306 | help='the number of layers of encoder') 307 | 308 | # length limit 309 | parser.add_argument('--max_word', type=int, action='store', default=100, 310 | help='maximun words in a single article sentence') 311 | parser.add_argument('--max_sent', type=int, action='store', default=60, 312 | help='maximun sentences in an article article') 313 | # training options 314 | parser.add_argument('--lr', type=float, action='store', default=2e-5, 315 | help='learning rate') 316 | parser.add_argument('--decay', type=float, action='store', default=0.5, 317 | help='learning rate decay ratio') 318 | parser.add_argument('--lr_p', type=int, action='store', default=0, 319 | help='patience for learning rate decay') 320 | parser.add_argument('--clip', type=float, action='store', default=0.5, 321 | help='gradient clipping') 322 | parser.add_argument('--batch', type=int, action='store', default=32, 323 | help='the training batch size') 324 | parser.add_argument( 325 | '--ckpt_freq', type=int, action='store', default=3000, 326 | help='number of update steps for checkpoint and validation' 327 | ) 328 | parser.add_argument('--patience', type=int, action='store', default=5, 329 | help='patience for early stopping') 330 | 331 | parser.add_argument('--debug', action='store_true', 332 | help='run in debugging mode') 333 | parser.add_argument('--no-cuda', action='store_true', 334 | help='disable GPU training') 335 | 336 | args = parser.parse_args() 337 | args.cuda = torch.cuda.is_available() and not args.no_cuda 338 | args = set_parameters(args) 339 | args.path = './{}_{}_{}'.format(args.encoder, args.decoder, args.emb_type) 340 | assert args.mode in ['train', 'test'] 341 | 342 | if args.mode == 'train': 343 | train(args) 344 | else: 345 | test(args, 'test') 346 | -------------------------------------------------------------------------------- 
/model/extract.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import math 3 | import numpy as np 4 | from math import sqrt 5 | from torch import nn 6 | from torch.nn import init 7 | from torch.nn import functional as F 8 | 9 | from .rnn import MultiLayerLSTMCells 10 | from .rnn import lstm_encoder 11 | from .util import sequence_mean, len_mask 12 | from .attention import prob_normalize 13 | from .DeepLSTM import DeepLSTM 14 | from .TransformerEncoder import TransformerEncoder 15 | from transformer.Models import get_sinusoid_encoding_table 16 | 17 | from pytorch_pretrained_bert.modeling import BertModel 18 | 19 | INI = 1e-2 20 | MAX_ARTICLE_LEN = 512 21 | 22 | class ConvSentEncoder(nn.Module): 23 | """ 24 | Convolutional word-level sentence encoder 25 | w/ max-over-time pooling, [3, 4, 5] kernel sizes, ReLU activation 26 | """ 27 | def __init__(self, vocab_size, emb_dim, n_hidden, dropout, emb_type): 28 | super().__init__() 29 | self._emb_type = emb_type 30 | if emb_type == 'W2V': 31 | self._embedding = nn.Embedding(vocab_size, emb_dim, padding_idx=0) 32 | self._convs = nn.ModuleList([nn.Conv1d(emb_dim, n_hidden, i) 33 | for i in range(3, 6)]) 34 | self._dropout = dropout 35 | self._grad_handle = None 36 | 37 | def forward(self, input_): 38 | if self._emb_type == 'W2V': 39 | emb_input = self._embedding(input_) 40 | else: 41 | emb_input = input_ 42 | conv_in = F.dropout(emb_input.transpose(1, 2), 43 | self._dropout, training=self.training) 44 | output = torch.cat([F.relu(conv(conv_in)).max(dim=2)[0] 45 | for conv in self._convs], dim=1) 46 | return output 47 | 48 | def set_embedding(self, embedding): 49 | """embedding is the weight matrix""" 50 | assert self._embedding.weight.size() == embedding.size() 51 | self._embedding.weight.data.copy_(embedding) 52 | 53 | 54 | class LSTMEncoder(nn.Module): 55 | def __init__(self, input_dim, n_hidden, n_layer, dropout, bidirectional): 56 | super().__init__() 57 | 
self._init_h = nn.Parameter( 58 | torch.Tensor(n_layer*(2 if bidirectional else 1), n_hidden)) 59 | self._init_c = nn.Parameter( 60 | torch.Tensor(n_layer*(2 if bidirectional else 1), n_hidden)) 61 | init.uniform_(self._init_h, -INI, INI) 62 | init.uniform_(self._init_c, -INI, INI) 63 | self._lstm = nn.LSTM(input_dim, n_hidden, n_layer, 64 | dropout=dropout, bidirectional=bidirectional) 65 | 66 | def forward(self, input_, in_lens=None): 67 | """ [batch_size, max_num_sent, input_dim] Tensor""" 68 | size = (self._init_h.size(0), input_.size(0), self._init_h.size(1)) 69 | init_states = (self._init_h.unsqueeze(1).expand(*size), 70 | self._init_c.unsqueeze(1).expand(*size)) 71 | lstm_out, _ = lstm_encoder( 72 | input_, self._lstm, in_lens, init_states) 73 | return lstm_out.transpose(0, 1) 74 | 75 | @property 76 | def input_size(self): 77 | return self._lstm.input_size 78 | 79 | @property 80 | def hidden_size(self): 81 | return self._lstm.hidden_size 82 | 83 | @property 84 | def num_layers(self): 85 | return self._lstm.num_layers 86 | 87 | @property 88 | def bidirectional(self): 89 | return self._lstm.bidirectional 90 | 91 | class LSTMPointerNet(nn.Module): 92 | """Pointer network as in Vinyals et al """ 93 | def __init__(self, input_dim, n_hidden, n_layer, 94 | dropout, n_hop, bidirectional=True): 95 | super().__init__() 96 | self._init_h = nn.Parameter(torch.Tensor(n_layer, n_hidden)) 97 | self._init_c = nn.Parameter(torch.Tensor(n_layer, n_hidden)) 98 | self._init_i = torch.Tensor(input_dim) 99 | init.uniform_(self._init_h, -INI, INI) 100 | init.uniform_(self._init_c, -INI, INI) 101 | init.uniform_(self._init_i, -0.1, 0.1) 102 | self._lstm = nn.LSTM( 103 | input_dim, n_hidden, n_layer, 104 | bidirectional=False, dropout=dropout 105 | ) 106 | self._lstm_cell = None 107 | 108 | # attention parameters 109 | self._attn_wm = nn.Parameter(torch.Tensor(input_dim, n_hidden)) 110 | self._attn_wq = nn.Parameter(torch.Tensor(n_hidden, n_hidden)) 111 | self._attn_v = 
nn.Parameter(torch.Tensor(n_hidden)) 112 | init.xavier_normal_(self._attn_wm) 113 | init.xavier_normal_(self._attn_wq) 114 | init.uniform_(self._attn_v, -INI, INI) 115 | 116 | # hop parameters 117 | self._hop_wm = nn.Parameter(torch.Tensor(input_dim, n_hidden)) 118 | self._hop_wq = nn.Parameter(torch.Tensor(n_hidden, n_hidden)) 119 | self._hop_v = nn.Parameter(torch.Tensor(n_hidden)) 120 | init.xavier_normal_(self._hop_wm) 121 | init.xavier_normal_(self._hop_wq) 122 | init.uniform_(self._hop_v, -INI, INI) 123 | self._n_hop = n_hop 124 | 125 | def forward(self, attn_mem, mem_sizes, lstm_in): 126 | """atten_mem: Tensor of size [batch_size, max_sent_num, input_dim]""" 127 | attn_feat, hop_feat, lstm_states, init_i = self._prepare(attn_mem, mem_sizes) 128 | lstm_in = torch.cat([init_i, lstm_in], dim=1).transpose(0, 1) 129 | query, final_states = self._lstm(lstm_in, lstm_states) 130 | query = query.transpose(0, 1) 131 | for _ in range(self._n_hop): 132 | query = LSTMPointerNet.attention( 133 | hop_feat, query, self._hop_v, self._hop_wq, mem_sizes) 134 | output = LSTMPointerNet.attention_score( 135 | attn_feat, query, self._attn_v, self._attn_wq) 136 | return output # unormalized extraction logit 137 | 138 | def extract(self, attn_mem, mem_sizes, k): 139 | """extract k sentences, decode only, batch_size==1""" 140 | attn_feat, hop_feat, lstm_states, lstm_in = self._prepare(attn_mem, mem_sizes) 141 | lstm_in = lstm_in.squeeze(1) 142 | if self._lstm_cell is None: 143 | self._lstm_cell = MultiLayerLSTMCells.convert( 144 | self._lstm).to(attn_mem.get_device()) 145 | extracts = [] 146 | for _ in range(k): 147 | h, c = self._lstm_cell(lstm_in, lstm_states) 148 | query = h[-1] 149 | for _ in range(self._n_hop): 150 | query = LSTMPointerNet.attention( 151 | hop_feat, query, self._hop_v, self._hop_wq, mem_sizes) 152 | score = LSTMPointerNet.attention_score( 153 | attn_feat, query, self._attn_v, self._attn_wq) 154 | score = score.squeeze() 155 | for e in extracts: 156 | score[e] = 
-1e6 157 | ext = score.max(dim=0)[1].item() 158 | extracts.append(ext) 159 | lstm_states = (h, c) 160 | lstm_in = attn_mem[:, ext, :] 161 | return extracts 162 | 163 | def _prepare(self, attn_mem, mem_sizes): 164 | attn_feat = torch.matmul(attn_mem, self._attn_wm.unsqueeze(0)) 165 | hop_feat = torch.matmul(attn_mem, self._hop_wm.unsqueeze(0)) 166 | bs = attn_mem.size(0) 167 | n_l, d = self._init_h.size() 168 | size = (n_l, bs, d) 169 | lstm_states = (self._init_h.unsqueeze(1).expand(*size).contiguous(), 170 | self._init_c.unsqueeze(1).expand(*size).contiguous()) 171 | d = self._init_i.size(0) 172 | 173 | # random 174 | # init_i = self._init_i.unsqueeze(0).unsqueeze(1).expand(bs, 1, d) 175 | 176 | # last state 177 | # init_i = attn_mem[:, -1, :].unsqueeze(1).expand(bs, 1, d) 178 | 179 | # max pooling 180 | init_i = LSTMPointerNet.max_pooling(attn_mem, mem_sizes) 181 | 182 | # mean pooling 183 | # init_i = LSTMPointerNet.mean_pooling(attn_mem, mem_sizes) 184 | 185 | return attn_feat, hop_feat, lstm_states, init_i 186 | 187 | @staticmethod 188 | def attention_score(attention, query, v, w): 189 | """ unnormalized attention score""" 190 | sum_ = attention.unsqueeze(1) + torch.matmul( 191 | query, w.unsqueeze(0) 192 | ).unsqueeze(2) # [B, Nq, Ns, D] 193 | 194 | score = torch.matmul( 195 | torch.tanh(sum_), v.unsqueeze(0).unsqueeze(1).unsqueeze(3) 196 | ).squeeze(3) # [B, Nq, Ns] 197 | 198 | return score 199 | 200 | @staticmethod 201 | def attention(attention, query, v, w, mem_sizes): 202 | """ attention context vector""" 203 | score = LSTMPointerNet.attention_score(attention, query, v, w) 204 | 205 | if mem_sizes is None: # decode 206 | norm_score = F.softmax(score, dim=-1) 207 | else: 208 | mask = len_mask(mem_sizes, score.get_device()).unsqueeze(-2) 209 | norm_score = prob_normalize(score, mask) 210 | 211 | output = torch.matmul(norm_score, attention) 212 | return output 213 | 214 | @staticmethod 215 | def mean_pooling(attn_mem, mem_sizes): 216 | if mem_sizes is None: 
# decode 217 | lens = torch.Tensor([attn_mem.size(1)]).cuda() 218 | else: 219 | lens = torch.Tensor(mem_sizes).unsqueeze(1).cuda() 220 | init_i = torch.sum(attn_mem, dim=1) / lens 221 | init_i = init_i.unsqueeze(1) 222 | return init_i 223 | 224 | @staticmethod 225 | def max_pooling(attn_mem, mem_sizes): 226 | if mem_sizes is not None: 227 | # not in decode 228 | B, Ns = attn_mem.size(0), attn_mem.size(1) 229 | mask = torch.ByteTensor(B, Ns).cuda() 230 | mask.fill_(0) 231 | for i, l in enumerate(mem_sizes): 232 | mask[i, :l].fill_(1) 233 | mask = mask.unsqueeze(-1) 234 | attn_mem = attn_mem.masked_fill(mask == 0, -1e18) 235 | init_i = attn_mem.max(dim=1, keepdim=True)[0] 236 | return init_i 237 | 238 | 239 | class Summarizer(nn.Module): 240 | """ Different encoder/decoder/embedding type """ 241 | def __init__(self, encoder, decoder, emb_type, emb_dim, vocab_size, 242 | conv_hidden, encoder_hidden, encoder_layer, 243 | isTrain=True, n_hop=1, dropout=0.0): 244 | super().__init__() 245 | self._encoder = encoder 246 | self._decoder = decoder 247 | self._emb_type = emb_type 248 | 249 | self._sent_enc = ConvSentEncoder( 250 | vocab_size, emb_dim, conv_hidden, dropout, emb_type) 251 | 252 | # BERT 253 | if emb_type == 'BERT': 254 | self._bert = BertModel.from_pretrained( 255 | '/path/to/uncased_L-24_H-1024_A-16') 256 | self._bert.eval() 257 | for p in self._bert.parameters(): 258 | p.requires_grad = False 259 | self._bert_w = nn.Linear(1024*4, emb_dim) 260 | 261 | # Sentence Encoder 262 | if encoder == 'BiLSTM': 263 | enc_out_dim = encoder_hidden * 2 # bidirectional 264 | self._art_enc = LSTMEncoder( 265 | 3*conv_hidden, encoder_hidden, encoder_layer, 266 | dropout=dropout, bidirectional=True 267 | ) 268 | elif encoder == 'Transformer': 269 | enc_out_dim = encoder_hidden 270 | self._art_enc = TransformerEncoder( 271 | 3*conv_hidden, encoder_hidden, encoder_layer, decoder) 272 | 273 | self._emb_w = nn.Linear(3*conv_hidden, encoder_hidden) 274 | self.sent_pos_embed = 
nn.Embedding.from_pretrained( 275 | get_sinusoid_encoding_table(1000, enc_out_dim, padding_idx=0), freeze=True) 276 | elif encoder == 'DeepLSTM': 277 | enc_out_dim = encoder_hidden 278 | self._isTrain = isTrain 279 | self._art_enc = DeepLSTM( 280 | 3*conv_hidden, encoder_hidden, encoder_layer, 0.1) 281 | 282 | # Decoder 283 | decoder_hidden = encoder_hidden 284 | decoder_layer = encoder_layer 285 | if decoder == 'PN': 286 | self._extractor = LSTMPointerNet( 287 | enc_out_dim, decoder_hidden, decoder_layer, 288 | dropout, n_hop 289 | ) 290 | else: 291 | self._ws = nn.Linear(enc_out_dim, 2) 292 | 293 | 294 | def forward(self, article_sents, sent_nums, target): 295 | enc_out = self._encode(article_sents, sent_nums) 296 | 297 | if self._decoder == 'PN': 298 | bs, nt = target.size() 299 | d = enc_out.size(2) 300 | ptr_in = torch.gather( 301 | enc_out, dim=1, index=target.unsqueeze(2).expand(bs, nt, d) 302 | ) 303 | output = self._extractor(enc_out, sent_nums, ptr_in) 304 | 305 | else: 306 | bs, seq_len, d = enc_out.size() 307 | output = self._ws(enc_out) 308 | assert output.size() == (bs, seq_len, 2) 309 | 310 | return output 311 | 312 | def extract(self, article_sents, sent_nums=None, k=4): 313 | enc_out = self._encode(article_sents, sent_nums) 314 | 315 | if self._decoder == 'PN': 316 | extract = self._extractor.extract(enc_out, sent_nums, k) 317 | else: 318 | seq_len = enc_out.size(1) 319 | output = self._ws(enc_out) 320 | assert output.size() == (1, seq_len, 2) 321 | _, indices = output[:, :, 1].sort(descending=True) 322 | extract = [] 323 | for i in range(k): 324 | extract.append(indices[0][i].item()) 325 | 326 | return extract 327 | 328 | def _encode(self, article_sents, sent_nums): 329 | 330 | hidden_size = self._art_enc.input_size 331 | 332 | if sent_nums is None: # test-time excode only 333 | if self._emb_type == 'W2V': 334 | enc_sent = self._sent_enc(article_sents[0]).unsqueeze(0) 335 | else: 336 | enc_sent = self._article_encode(article=article_sents[0], 337 
| device=article_sents[0].device).unsqueeze(0) 338 | else: 339 | max_n = max(sent_nums) 340 | if self._emb_type == 'W2V': 341 | enc_sents = [self._sent_enc(art_sent) 342 | for art_sent in article_sents] 343 | else: 344 | enc_sents = [self._article_encode(article=article, device=article.device) 345 | for article in article_sents] 346 | def zero(n, device): 347 | z = torch.zeros(n, hidden_size).to(device) 348 | return z 349 | enc_sent = torch.stack( 350 | [torch.cat([s, zero(max_n-n, s.get_device())], dim=0) 351 | if n != max_n 352 | else s 353 | for s, n in zip(enc_sents, sent_nums)], 354 | dim=0 355 | ) 356 | 357 | # Input for different encoder 358 | if self._encoder == 'BiLSTM': 359 | output = self._art_enc(enc_sent, sent_nums) 360 | 361 | elif self._encoder == 'Transformer': 362 | batch_size, seq_len = enc_sent.size(0), enc_sent.size(1) 363 | 364 | # prepare mask 365 | if sent_nums != None: 366 | input_len = len_mask(sent_nums, enc_sent.get_device()).float() # [batch_size, seq_len] 367 | else: 368 | input_len = torch.ones(batch_size, seq_len).float().cuda() 369 | 370 | attn_mask = input_len.eq(0.0).unsqueeze(1).expand(batch_size, 371 | seq_len, seq_len).cuda() # [batch_size, seq_len, seq_len] 372 | non_pad_mask = input_len.unsqueeze(-1).cuda() # [batch, seq_len, 1] 373 | 374 | # add postional embedding 375 | if sent_nums != None: 376 | sent_pos = torch.LongTensor([np.hstack((np.arange(1, doclen + 1), 377 | np.zeros(seq_len - doclen))) for doclen in sent_nums]).cuda() 378 | else: 379 | sent_pos = torch.LongTensor([np.arange(1, seq_len + 1)]).cuda() 380 | 381 | inputs = self._emb_w(enc_sent) + self.sent_pos_embed(sent_pos) 382 | 383 | assert attn_mask.size() == (batch_size, seq_len, seq_len) 384 | assert non_pad_mask.size() == (batch_size, seq_len, 1) 385 | 386 | output = self._art_enc(inputs, non_pad_mask, attn_mask) 387 | 388 | elif self._encoder == 'DeepLSTM': 389 | batch_size, seq_len = enc_sent.size(0), enc_sent.size(1) 390 | inputs = [enc_sent.transpose(0, 
1)] 391 | 392 | # prepare mask 393 | if sent_nums != None: 394 | inputs_mask = [len_mask(sent_nums, enc_sent.get_device()).transpose(0, 1).unsqueeze(-1)] 395 | else: 396 | inputs_mask = [torch.ones(seq_len, batch_size, 1).cuda()] 397 | 398 | for _ in range(self._art_enc.num_layers): 399 | inputs.append([None]) 400 | inputs_mask.append([None]) 401 | 402 | assert inputs[0].size() == (seq_len, batch_size, hidden_size) 403 | assert inputs_mask[0].size() == (seq_len, batch_size, 1) 404 | 405 | output = self._art_enc(inputs, inputs_mask, self._isTrain) 406 | 407 | return output 408 | 409 | def _article_encode(self, article, device, pad_idx=0): 410 | sent_num, sent_len = article.size() 411 | tokens_id = [101] # [CLS] 412 | for i in range(sent_num): 413 | for j in range(sent_len): 414 | if article[i][j] != pad_idx: 415 | tokens_id.append(article[i][j]) 416 | else: 417 | break 418 | tokens_id.append(102) # [SEP] 419 | input_mask = [1] * len(tokens_id) 420 | total_len = len(tokens_id) - 2 421 | while len(tokens_id) < MAX_ARTICLE_LEN: 422 | tokens_id.append(0) 423 | input_mask.append(0) 424 | 425 | assert len(tokens_id) == MAX_ARTICLE_LEN 426 | assert len(input_mask) == MAX_ARTICLE_LEN 427 | 428 | input_ids = torch.LongTensor(tokens_id).unsqueeze(0).to(device) 429 | input_mask = torch.LongTensor(input_mask).unsqueeze(0).to(device) 430 | 431 | # concat last 4 layers 432 | out, _ = self._bert(input_ids, token_type_ids=None, attention_mask=input_mask) 433 | out = torch.cat([out[-1], out[-2], out[-3], out[-4]], dim=-1) 434 | 435 | assert out.size() == (1, MAX_ARTICLE_LEN, 4096) 436 | 437 | emb_out = self._bert_w(out).squeeze(0) 438 | emb_dim = emb_out.size(-1) 439 | 440 | emb_input = torch.zeros(sent_num, sent_len, emb_dim).to(device) 441 | cur_idx = 1 # after [CLS] 442 | for i in range(sent_num): 443 | for j in range(sent_len): 444 | if article[i][j] != pad_idx: 445 | emb_input[i][j] = emb_out[cur_idx] 446 | cur_idx += 1 447 | else: 448 | break 449 | assert cur_idx - 1 == 
total_len 450 | 451 | cnn_out = self._sent_enc(emb_input) 452 | assert cnn_out.size() == (sent_num, 300) # 300 = 3 * conv_hidden 453 | 454 | return cnn_out 455 | 456 | def set_embedding(self, embedding): 457 | self._sent_enc.set_embedding(embedding) 458 | 459 | --------------------------------------------------------------------------------