├── .gitignore ├── Readme.md ├── src ├── evaluate.py ├── model.py ├── predict.py ├── saved_model │ └── model_20.pt └── train.py └── utils ├── config.py ├── data_loader.py ├── multi_proc_utils.py ├── params_utils.py └── preprocess.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | /data/ 131 | -------------------------------------------------------------------------------- /Readme.md: -------------------------------------------------------------------------------- 1 | ### 一、介绍 2 | 本项目实现了生成式自动文本摘要的模型训练及预测,项目用到的主要架构如下: 3 | 1. seq2seq:利用 Encoder-Decoder 架构,结合 Attention,对文本摘要生成任务进行建模和训练。Encoder 和 Decoder 中,分别使用双层和单层 GRU 结构来抽取文本特征。 4 | 2. PGN:使用PGN(Pointer Generator Network),可以直接 copy 原文中的重要单词和短语,缓解文本摘要生成过程中可能出现 OOV 问题。 5 | 3. 
Coverage 机制:在模型中,加入 Coverage Loss,对过往时刻已经生成的单词进行惩罚,缓解文本摘要生成过程中可能出现的重复问题。 6 | 7 | 关于本项目中的各项细节可以参考以下文章: 8 | 9 | [文本摘要(一):任务介绍](https://zhuanlan.zhihu.com/p/451808468) 10 | 11 | [文本摘要(二):TextRank](https://zhuanlan.zhihu.com/p/452359234) 12 | 13 | [文本摘要(三):数据处理](https://zhuanlan.zhihu.com/p/452429994) 14 | 15 | [文本摘要(四):seq2seq 介绍及实现](https://zhuanlan.zhihu.com/p/452475603) 16 | 17 | [文本摘要(五):seq2seq 训练及预测](https://zhuanlan.zhihu.com/p/452703432) 18 | 19 | [文本摘要(六):生成任务中的采样方法](https://zhuanlan.zhihu.com/p/453286395) 20 | 21 | [文本摘要(七):PGN 模型架构](https://zhuanlan.zhihu.com/p/453600830) 22 | 23 | ### 二、 框架 24 | ``` 25 | ├── data 26 | │ ├── sina-article-test.txt 27 | │ ├── sina-article-train.txt 28 | │ ├── train_label.txt 29 | │ └── train_text.txt 30 | ├── src 31 | │ ├── saved_model 32 | │ ├── evaluate.py 33 | │ ├── model.py 34 | │ ├── predict.py 35 | │ └── train.py 36 | └── utils 37 | ├── config.py 38 | ├── data_loader.py 39 | ├── multi_proc_utils.py 40 | ├── params_utils.py 41 | └── preprocess.py 42 | ``` 43 | 44 | ### 三、使用 45 | 1. 数据预处理,生成分词后的 'data/sina-article-train.txt' 及 'data/sina-article-test.txt' 文件. 46 | ```bash 47 | python utils/preprocess.py 48 | ``` 49 | 2. 模型训练,训练好的模型储存在 'src/saved_model' 文件夹中。 50 | ```bash 51 | python src/train.py 52 | ``` 53 | 3. 模型预测,修改 config 中的模型加载路径 model_load_path 54 | ```bash 55 | python src/predict.py 56 | ``` 57 | 58 | ### 四、TODO 59 | - [ ] ROUGE 指标评测 60 | - [ ] 参数调整 61 | - [ ] 模型部署 62 | -------------------------------------------------------------------------------- /src/evaluate.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | root_path = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) 4 | sys.path.append(root_path) 5 | 6 | import numpy as np 7 | import torch 8 | from utils import config 9 | 10 | root_path = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) 11 | sys.path.append(root_path) 12 | 13 | 14 | # 模型验证过程 15 | def evaluate(model, val_dataloader, loss_fn): 16 | val_loss = [] 17 | model.eval() 18 | with torch.no_grad(): 19 | val_loss = [] 20 | for i, (text, text_len, title_in, title_out, oovs, len_oovs) in enumerate(val_dataloader): 21 | text = text.to(config.device) 22 | title_in = title_in.to(config.device) 23 | title_out = title_out.to(config.device) 24 | title_pred, _, _ = model(text, title_in, text_len, len_oovs) 25 | loss = loss_fn(title_pred.transpose(1, 2).to(config.device), title_out) 26 | val_loss.append(loss.item()) 27 | return np.mean(val_loss) 28 | -------------------------------------------------------------------------------- /src/model.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | root_path = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) 4 | sys.path.append(root_path) 5 | 6 | import random 7 | import torch 8 | from torch import nn 9 | from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence 10 | 11 | from utils import config 12 | from utils.data_loader import replace_oovs, load_data, Vocab 13 | 14 | 15 | class Encoder(nn.Module): 16 | def __init__(self, vocab): 17 | super(Encoder, self).__init__() 18 | self.vocab_size = len(vocab) 19 | # Embedding 层设定 padding_idx 值 20 | self.embedding = nn.Embedding(self.vocab_size, config.emb_size, padding_idx=vocab[config.PAD_TOKEN]) 21 | self.gru = nn.GRU(config.emb_size, config.hidden_size, num_layers=config.num_layers, batch_first=True, dropout=config.dropout) 22 | self.dropout = nn.Dropout(config.dropout) 
23 | self.linear = nn.Linear(config.hidden_size, config.hidden_size) 24 | self.relu = nn.ReLU() 25 | 26 | def forward(self, enc_input, text_lengths): 27 | # enc_input: [batch_size, seq_len], 经过 padding 处理过的输入token_id 28 | # text_lengths: [batch_size], padding 之前,输入 tokens 的长度 29 | embedded = self.dropout(self.embedding(enc_input)) # [batch_size, seq_len, emb_size] 30 | # 输入 GRU 前,将 padded_sequence 打包,去掉 padding token,加快训练速度 31 | embedded = pack_padded_sequence(embedded, text_lengths, batch_first=True, enforce_sorted=False) 32 | # GRU 的输出 分为两部分,output 对应每个 token 的最后一层隐状态,hidden 对应最后一个字符的所有层的隐状态 33 | # output: [batch_size, seq_len, hidden_size] 34 | # hidden: [num_layers, batch_size, hidden_size] 35 | output, hidden = self.gru(embedded) 36 | # GRU 训练完成后,再恢复 padding token 状态 37 | output, _ = pad_packed_sequence(output, batch_first=True) 38 | # 输出再经过一个 linear 层,增加模型复杂度 39 | output = self.relu(self.linear(output)) 40 | return output, hidden[-1].detach() 41 | 42 | 43 | class Decoder(nn.Module): 44 | def __init__(self, vocab, attention): 45 | super(Decoder, self).__init__() 46 | self.vocab_size = len(vocab) 47 | self.embedding = nn.Embedding(self.vocab_size, config.emb_size, padding_idx=vocab[config.PAD_TOKEN]) 48 | self.attention = attention 49 | self.gru = nn.GRU(config.emb_size + config.hidden_size, config.hidden_size, batch_first=True) 50 | self.linear = nn.Linear(config.emb_size + 2 * config.hidden_size, self.vocab_size) 51 | self.dropout = nn.Dropout(config.dropout) 52 | # 设置 PGN 网络架构的参数,用于计算 p_gen 53 | if config.pointer: 54 | self.w_gen = nn.Linear(config.hidden_size * 2 + config.emb_size, 1) 55 | 56 | def forward(self, dec_input, prev_hidden, enc_output, text_lengths, coverage_vector): 57 | # 与 Encoder 不同,Decoder 的计算是分步进行的,每次输入一个时间步的 dec_input,同时输出这个时间步的 dec_output 58 | # dec_input = [batch_size] 59 | # prev_hidden = [batch_size, hidden_size] 60 | # enc_output = [batch_size, src_len, hidden_size] 61 | dec_input = dec_input.unsqueeze(1) # [batch_size, 1] 62 | embedded = self.embedding(dec_input) # [batch_size, 1, dec_len] 63 | # 加入 coverage 机制后,attention 的计算公式参考 https://zhuanlan.zhihu.com/p/453600830 64 | attention_weights, coverage_vector = self.attention(embedded, enc_output, text_lengths, coverage_vector) 65 | attention_weights = attention_weights.unsqueeze(1) # [batch_size, 1, enc_len] 66 | # 根据 attention weights,计算 context vector 67 | c = torch.bmm(attention_weights, enc_output) # [batch_size, 1, hidden_size] 68 | # 将经过 embedding 处理过的 decoder 输入,和上下文向量一起送入到 GRU 网络中 69 | gru_input = torch.cat([embedded, c], dim=2) 70 | # dec_output: [batch_size, 1, hidden_size] 71 | # dec_hidden: [1, batch_size, hidden_size] 72 | # prev_hidden 是上个时间步的隐状态,作为 decoder 的参数传入进来 73 | dec_output, dec_hidden = self.gru(gru_input, prev_hidden.unsqueeze(0)) 74 | # 将输出映射到 vocab_size 维度,以便计算每个 vocab 的生成概率 75 | dec_output = self.linear(torch.cat((dec_output.squeeze(1), c.squeeze(1), embedded.squeeze(1)), dim=1)) # [batch_size, vocab_size] 76 | dec_hidden = dec_hidden.squeeze(0) 77 | p_gen = None 78 | # 计算 p_gen 79 | if config.pointer: 80 | x_gen = torch.cat([dec_hidden, c.squeeze(1), embedded.squeeze(1)], dim=1) 81 | p_gen = torch.sigmoid(self.w_gen(x_gen)) 82 | return dec_output, dec_hidden, attention_weights.squeeze(1), p_gen, coverage_vector 83 | 84 | 85 | class Attention(nn.Module): 86 | def __init__(self): 87 | super(Attention, self).__init__() 88 | self.linear = nn.Linear(config.hidden_size * 2 + config.emb_size, config.hidden_size) 89 | self.v = nn.Linear(config.hidden_size, 1) 90 | self.softmax = 
nn.Softmax(dim=-1) 91 | 92 | def forward(self, dec_input, enc_output, text_lengths, coverage_vector): 93 | # enc_output = [batch_size, seq_len, hidden_size] 94 | # dec_input = [batch_size, hidden_size] 95 | # text_lengths = [batch_size] 96 | # coverage_vector = [batch_size, seq_len] 97 | seq_len = enc_output.shape[1] 98 | hidden_size = enc_output.shape[-1] 99 | s = dec_input.repeat(1, seq_len, 1) # [batch_size, seq_len, hidden_size] 100 | coverage_vector_copy = coverage_vector.unsqueeze(2).repeat(1, 1, hidden_size) # [batch_size, seq_len, hidden_size] 101 | # enc_output, s, coverage_vector_copy 维度统一,用于计算 attention 102 | x = torch.tanh(self.linear(torch.cat([enc_output, s, coverage_vector_copy], dim=2))) 103 | attention = self.v(x).squeeze(-1) # [batch_size, seq_len] 104 | max_len = enc_output.shape[1] 105 | # mask = [batch_size, seq_len],遮蔽掉 Decoder 当前时间步之后的单词 106 | mask = torch.arange(max_len).expand(text_lengths.shape[0], max_len) >= text_lengths.unsqueeze(1) 107 | attention.masked_fill_(mask.to(config.device), float('-inf')) 108 | attention_weights = self.softmax(attention) 109 | # 更新 coverage_vector 110 | coverage_vector += attention_weights 111 | return attention_weights, coverage_vector # [batch, seq_len], [batch_size, seq_len] 112 | 113 | 114 | # seq2seq 模型架构 115 | class Seq2seq(nn.Module): 116 | def __init__(self, vocab): 117 | super(Seq2seq, self).__init__() 118 | attention = Attention() 119 | self.encoder = Encoder(vocab) 120 | self.decoder = Decoder(vocab, attention) 121 | 122 | def forward(self, src, tgt, src_lengths, teacher_forcing_ratio=0.5): 123 | # src = [batch_size, src_len] 124 | # tgt = [batch_size, tgt_len] 125 | batch_size = tgt.shape[0] 126 | tgt_len = tgt.shape[1] 127 | vocab_size = self.decoder.vocab_size 128 | enc_output, prev_hidden = self.encoder(src, src_lengths) 129 | dec_input = tgt[:, 0] 130 | dec_outputs = torch.zeros(batch_size, tgt_len, vocab_size) 131 | for t in range(tgt_len - 1): 132 | dec_output, prev_hidden, _, _ = self.decoder(dec_input, prev_hidden, enc_output, src_lengths) 133 | dec_outputs[:, t, :] = dec_output 134 | teacher_force = random.random() < teacher_forcing_ratio 135 | top1 = dec_output.argmax(1) 136 | dec_input = tgt[:, t] if teacher_force else top1 137 | return dec_outputs 138 | 139 | 140 | # PGN 模型架构 141 | class PGN(nn.Module): 142 | def __init__(self, vocab): 143 | super(PGN, self).__init__() 144 | self.vocab = vocab 145 | self.vocab_size = len(vocab) 146 | self.device = config.device 147 | 148 | attention = Attention() 149 | self.encoder = Encoder(vocab) 150 | self.decoder = Decoder(vocab, attention) 151 | 152 | def get_final_distribution(self, x, p_gen, p_vocab, attention_weights, max_oov): 153 | # 应用 PGN 公式,计算最终单词的概率分布。由于PGN网络会copy原文中的单词,因此需要考虑原文中 OOV 单词的影响 154 | if not config.pointer: 155 | return p_vocab 156 | batch_size = x.shape[0] 157 | p_gen = torch.clamp(p_gen, 0.001, 0.999) 158 | p_vocab_weighted = p_gen * p_vocab 159 | attention_weighted = (1 - p_gen) * attention_weights 160 | # 加入 max_oov 维度,将原文中的 OOV 单词考虑进来 161 | extension = torch.zeros((batch_size, max_oov), dtype=torch.float).to(self.device) 162 | p_vocab_extended = torch.cat([p_vocab_weighted, extension], dim=-1) 163 | # p_gen * p_vocab + (1 - p_gen) * attention_weights, 将 attention weights 中的每个位置 idx 映射成该位置的 token_id 164 | final_distribution = p_vocab_extended.scatter_add_(dim=1, index=x, src=attention_weighted) 165 | # 输出最终的 vocab distribution [batch_size, vocab_size + len(oov)] 166 | return final_distribution 167 | 168 | def forward(self, src, tgt, src_lengths, 
len_oovs, teacher_forcing_ratio=0.5):
169 |         # src = [batch_size, src_len], source-text ids fed to the Encoder
170 |         # tgt = [batch_size, tgt_len], reference-summary ids fed to the Decoder
171 |         # src_lengths = [batch_size], lengths of the source texts before padding
172 |         # len_oovs = [batch_size], number of OOV words in each source text
173 | 
174 |         # Map extended OOV ids back to <UNK> so the Encoder embedding can handle them
175 |         src_copy = replace_oovs(src, self.vocab)
176 |         batch_size = tgt.shape[0]
177 |         tgt_len = tgt.shape[1]
178 |         vocab_size = self.vocab_size
179 |         # Encoder pass
180 |         enc_output, prev_hidden = self.encoder(src_copy, src_lengths)
181 |         # First Decoder input
182 |         dec_input = tgt[:, 0]
183 |         dec_outputs = torch.zeros(batch_size, tgt_len, vocab_size + max(len_oovs)).to(config.device)
184 |         coverage_vector = torch.zeros_like(src, dtype=torch.float32).to(config.device)
185 |         # Run the Decoder one time step at a time
186 |         for t in range(tgt_len - 1):
187 |             # Map extended OOV ids back to <UNK> so the Decoder embedding can handle them
188 |             dec_input = replace_oovs(dec_input, self.vocab)
189 |             dec_output, prev_hidden, attention_weights, p_gen, coverage_vector = self.decoder(dec_input, prev_hidden, enc_output, src_lengths, coverage_vector)
190 |             final_distribution = self.get_final_distribution(src, p_gen, dec_output, attention_weights, max(len_oovs))
191 |             # Use teacher forcing at random to stabilise training
192 |             teacher_force = random.random() < teacher_forcing_ratio
193 |             # Store this time step's word distribution in dec_outputs
194 |             dec_outputs[:, t, :] = final_distribution
195 |             top1 = final_distribution.argmax(1)
196 |             dec_input = tgt[:, t + 1] if teacher_force else top1  # next ground-truth token under teacher forcing, otherwise the model's own prediction
197 |         return dec_outputs, attention_weights, coverage_vector
198 | 
199 | 
200 | if __name__ == '__main__':
201 |     train_text, train_title = load_data(config.train_save_path)
202 |     train_text = train_text[:1000]
203 |     train_title = train_title[:1000]
204 |     vocab = Vocab(train_text + train_title, reserved_tokens=config.reserved_tokens)
205 |     model = PGN(vocab)
206 |     print(model)
207 | 
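The `scatter_add_` call in `get_final_distribution` above is the core of the copy mechanism and is easy to misread, so here is a small self-contained sketch of the same computation. The vocabulary size, token ids and probabilities are toy values made up for illustration, not taken from the project's data.

```python
import torch

# Toy setup: a 6-word vocabulary and one source sentence of 4 tokens,
# two of which are OOV and therefore received the extended ids 6 and 7.
vocab_size, max_oov = 6, 2
src_ids = torch.tensor([[2, 6, 3, 7]])                          # [batch=1, src_len=4]
p_vocab = torch.softmax(torch.randn(1, vocab_size), dim=-1)     # generator distribution
attention_weights = torch.softmax(torch.randn(1, 4), dim=-1)    # attention over source positions
p_gen = torch.tensor([[0.7]])                                   # probability of generating vs. copying

# Weight the two distributions by p_gen, as PGN.get_final_distribution does (minus the p_gen clamping).
p_vocab_weighted = p_gen * p_vocab                              # [1, vocab_size]
attention_weighted = (1 - p_gen) * attention_weights            # [1, src_len]

# Extend the vocab distribution with slots for the source OOVs, then scatter each
# attention weight onto the entry of the token id it points at (the copy probability).
extension = torch.zeros((1, max_oov))
p_vocab_extended = torch.cat([p_vocab_weighted, extension], dim=-1)  # [1, vocab_size + max_oov]
final_distribution = p_vocab_extended.scatter_add_(dim=1, index=src_ids, src=attention_weighted)

print(final_distribution.sum())  # ~1.0: still a proper distribution over vocab + source OOVs
```

Because the attention mass is added per token id rather than per position, a word that occurs several times in the source accumulates its copy probability, and source OOVs can receive probability even though the generator head alone could never produce them.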
--------------------------------------------------------------------------------
/src/predict.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | 
4 | root_path = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
5 | sys.path.append(root_path)
6 | 
7 | import torch
8 | import random
9 | from utils import config
10 | from src.model import Seq2seq, PGN
11 | from utils.data_loader import *
12 | 
13 | 
14 | # Inference loop
15 | def predict(model, vocab, text, max_len=20):
16 |     # Stop once the output reaches max_len tokens or an EOS_TOKEN is generated
17 |     model.eval()
18 |     dec_words = []
19 |     with torch.no_grad():
20 |         # Prepare the input
21 |         src, oovs = vocab.convert_text_to_ids(text)
22 |         src_lengths = torch.tensor([len(src)])
23 |         src = torch.tensor(src).reshape(1, -1).to(config.device)
24 |         src_copy = replace_oovs(src, vocab)
25 |         enc_output, prev_hidden = model.encoder(src_copy, src_lengths)
26 |         # The first Decoder input is BOS_TOKEN
27 |         dec_input = torch.tensor([vocab[config.BOS_TOKEN]]).to(config.device)
28 |         coverage_vector = torch.zeros_like(src, dtype=torch.float32).to(config.device)
29 |         # Run the Decoder one time step at a time
30 |         for t in range(max_len):
31 |             dec_output, prev_hidden, attention_weights, p_gen, coverage_vector = model.decoder(dec_input, prev_hidden, enc_output, src_lengths, coverage_vector)
32 |             final_distribution = model.get_final_distribution(src, p_gen, dec_output, attention_weights, len(oovs))
33 |             dec_output = final_distribution.argmax(-1)
34 |             token_id = dec_output.item()
35 |             # Decode token_id back into a word
36 |             # Stop when EOS_TOKEN is generated
37 |             if token_id == vocab[config.EOS_TOKEN]:
38 |                 dec_words.append(config.EOS_TOKEN)
39 |                 break
40 |             # token_id is inside the vocab: output the word directly
41 |             elif token_id < len(vocab):
42 |                 dec_words.append(vocab.idx2token[token_id])
43 |             # token_id refers to a source OOV: output the corresponding word from oovs (collected from the source text)
44 |             elif token_id < len(vocab) + len(oovs):
45 |                 dec_words.append(oovs[token_id - len(vocab)])
46 |             # Anything else falls back to UNK_TOKEN
47 |             else:
48 |                 dec_words.append(vocab.UNK_TOKEN)
49 |             # Feed the generated token back as the next Decoder input, mapping extended OOV ids to UNK_TOKEN
50 |             dec_input = replace_oovs(dec_output, vocab)
51 |     return dec_words
52 | 
53 | 
54 | if __name__ == '__main__':
55 |     train_text, train_title = load_data(config.train_save_path)
56 |     if config.train_sample > 0:
57 |         train_text = train_text[:config.train_sample]
58 |         train_title = train_title[:config.train_sample]
59 |     vocab = Vocab(train_text + train_title, reserved_tokens=config.reserved_tokens)
60 | 
61 |     # Load the trained model
62 |     model = PGN(vocab)
63 |     model.load_state_dict(torch.load(config.model_load_path, map_location=config.device))
64 |     model = model.to(config.device)
65 | 
66 |     # Print predictions for a few random training samples
67 |     for i in range(10):
68 |         idx = random.randint(0, len(train_text) - 1)
69 |         text = train_text[idx].split()
70 |         title = train_title[idx].split()
71 |         print('>', ''.join(text))
72 |         print('=', ''.join(title))
73 |         output_words = predict(model, vocab, text)
74 |         print('<', ''.join(output_words))
75 |         print('')
76 | 
--------------------------------------------------------------------------------
/src/saved_model/model_20.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/minlik/TextSummarization/fbc52fb63f0c06fa5ac155e7dc431caec08227fd/src/saved_model/model_20.pt
--------------------------------------------------------------------------------
/src/train.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | 
4 | root_path = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
5 | sys.path.append(root_path)
6 | 
7 | from torch.nn.utils.rnn import pad_sequence
8 | from torch.utils.data import DataLoader
9 | from torch.nn.utils import clip_grad_norm_
10 | from matplotlib import pyplot as plt
11 | import torch
12 | from torch import nn
13 | import datetime
14 | 
15 | from src.model import Seq2seq, PGN
16 | from src.evaluate import evaluate
17 | from utils.data_loader import Vocab, MyDataset, load_data
18 | from utils import config
19 | 
20 | 
21 | def print_bar():
22 |     now_time = datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')
23 |     print("=========="*8 + now_time)
24 | 
25 | 
26 | def train(model, train_dataloader, val_dataloader, loss_fn, optimizer, epochs):
27 |     # Training loop
28 |     model = model.to(config.device)
29 |     model.train()
30 |     print_bar()
31 |     print('Start Training...')
32 |     train_loss = []
33 |     val_loss = []
34 |     for epoch in range(epochs):
35 |         total_loss = 0
36 |         for i, (text, text_len, title_in, title_out, oovs, len_oovs) in enumerate(train_dataloader):
37 |             text = text.to(config.device)
38 |             title_in = title_in.to(config.device)
39 |             title_out = title_out.to(config.device)
40 |             optimizer.zero_grad()
41 |             title_pred, attention_weights, coverage_vector = model(text, title_in, text_len, len_oovs)
42 |             # Cross-entropy loss
43 |             ce_loss = loss_fn(title_pred.transpose(1, 2).to(config.device), title_out)
44 |             loss = ce_loss
45 |             if config.coverage:
46 |                 # Coverage loss, added to the total loss with weight cov_lambda
47 |                 c_t = torch.min(attention_weights, coverage_vector)
48 |                 cov_loss = torch.mean(torch.sum(c_t, dim=1))
49 |                 loss = ce_loss + config.cov_lambda * cov_loss
50 |             total_loss += loss.item()
51 |             loss.backward()
52 |             # Gradient clipping
53 |             clip_grad_norm_(model.parameters(), config.max_grad_norm)
54 |             optimizer.step()
55 | 
56 |         avg_train_loss = total_loss / len(train_dataloader)
57 |         # Validate the model at the end of every epoch
58 |         avg_val_loss = evaluate(model, val_dataloader, loss_fn)
59 |         print_bar()
60 | 
print(f'epoch: {epoch+1}/{epochs}, training loss: {avg_train_loss:.4f}, validation loss: {avg_val_loss:.4f}') 61 | # if epoch == 0 or avg_val_loss < min_val_loss: 62 | if (epoch + 1) % 20 == 0: 63 | model_path = root_path + '/src/saved_model/' + 'model_' + str(epoch+1) + '.pt' 64 | torch.save(model.state_dict(), model_path) 65 | print(f'The model has been saved for epoch {epoch + 1}') 66 | # min_val_loss = avg_val_loss 67 | 68 | train_loss.append(avg_train_loss) 69 | val_loss.append(avg_val_loss) 70 | return train_loss, val_loss 71 | 72 | 73 | def collate_fn(batch): 74 | # 将 dataset 中的数据进行整理,得到 dataloader 需要的格式 75 | # 1. text 和 title 加入 padding 处理,统一每个 batch 中的句子长度 76 | # 2. 统计原文中的 oov 单词 77 | is_train = 'title_ids' in batch[0] 78 | text = [torch.tensor(example['text_ids']) for example in batch] 79 | text_len = torch.tensor([len(example['text_ids']) for example in batch]) 80 | padded_text = pad_sequence(text, batch_first=True, padding_value=vocab[config.PAD_TOKEN]) 81 | oovs = [example['oovs'] for example in batch] 82 | len_oovs = [example['len_oovs'] for example in batch] 83 | if is_train: 84 | title_in = [torch.tensor(example['title_ids'][:-1]) for example in batch] 85 | title_out = [torch.tensor(example['title_ids'][1:]) for example in batch] 86 | padded_title_in = pad_sequence(title_in, batch_first=True, padding_value=vocab[config.PAD_TOKEN]) 87 | padded_title_out = pad_sequence(title_out, batch_first=True, padding_value=vocab[config.PAD_TOKEN]) 88 | return padded_text, text_len, padded_title_in, padded_title_out, oovs, len_oovs 89 | return padded_text, text_len, oovs, len_oovs 90 | 91 | 92 | if __name__ == '__main__': 93 | train_text, train_title = load_data(config.train_save_path) 94 | if config.train_sample > 0: 95 | train_text = train_text[:config.train_sample] 96 | train_title = train_title[:config.train_sample] 97 | vocab = Vocab(train_text + train_title, reserved_tokens=config.reserved_tokens) 98 | train_dataset = MyDataset(vocab, train_text, train_title) 99 | train_dataloader = DataLoader(train_dataset, batch_size=config.batch_size, collate_fn=collate_fn, shuffle=True) 100 | 101 | val_text, val_title = load_data(config.val_save_path) 102 | if config.val_sample > 0: 103 | val_text = val_text[:config.val_sample] 104 | val_title = val_title[:config.val_sample] 105 | val_dataset = MyDataset(vocab, val_text, val_title) 106 | val_dataloader = DataLoader(val_dataset, batch_size=config.batch_size, collate_fn=collate_fn, shuffle=True) 107 | 108 | model = PGN(vocab) 109 | 110 | loss_fn = nn.CrossEntropyLoss(ignore_index=vocab[config.PAD_TOKEN]) 111 | optimizer = torch.optim.Adam(model.parameters(), lr=config.lr) 112 | 113 | train_loss, val_loss = train(model, train_dataloader, val_dataloader, loss_fn, optimizer, config.epochs) 114 | plt.plot(train_loss, label='training loss') 115 | plt.plot(val_loss, label='validation loss') 116 | plt.legend() 117 | plt.title('Training and validation loss') 118 | plt.xlabel('epoch') 119 | plt.ylabel('loss') 120 | plt.show() 121 | -------------------------------------------------------------------------------- /utils/config.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | 4 | # 将路径加入到环境变量中 5 | root_path = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) 6 | 7 | # 设置模型参数 8 | emb_size = 128 9 | hidden_size = 256 10 | num_layers = 2 11 | dropout = 0.5 12 | # 设置 PGN + coverage 13 | pointer = True 14 | coverage = True 15 | cov_lambda = 1 # 计算总体loss时,设置coverage loss的权重。 16 | 17 | 
# Training parameters
18 | batch_size = 16
19 | epochs = 10
20 | lr = 1e-3
21 | max_grad_norm = 2  # maximum gradient norm for clipping, to avoid exploding gradients
22 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
23 | 
24 | 
25 | train_sample = 100  # number of training samples to use while debugging; -1 uses the full training set
26 | val_sample = 100  # number of validation samples to use while debugging; -1 uses the full validation set
27 | 
28 | # File paths
29 | content_path = os.path.join(root_path, 'data', 'train_text.txt')
30 | title_path = os.path.join(root_path, 'data', 'train_label.txt')
31 | train_save_path = os.path.join(root_path, 'data', 'sina-article-train.txt')
32 | val_save_path = os.path.join(root_path, 'data', 'sina-article-test.txt')
33 | test_save_path = val_save_path  # referenced by preprocess.py; the held-out split doubles as the validation set
34 | model_load_path = os.path.join(root_path, 'src', 'saved_model', 'model.pt')
35 | 
36 | # Reserved vocabulary tokens
37 | PAD_TOKEN = "<PAD>"
38 | UNK_TOKEN = "<UNK>"
39 | BOS_TOKEN = "<BOS>"
40 | EOS_TOKEN = "<EOS>"
41 | reserved_tokens = [PAD_TOKEN, BOS_TOKEN, EOS_TOKEN]
42 | 
--------------------------------------------------------------------------------
/utils/data_loader.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | 
4 | from utils import config
5 | 
6 | root_path = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
7 | sys.path.append(root_path)
8 | 
9 | from collections import defaultdict
10 | from torch.utils.data import Dataset, DataLoader
11 | 
12 | from utils.config import *
13 | 
14 | 
15 | # Vocabulary built from the tokenised corpus
16 | class Vocab:
17 |     def __init__(self, sentences, min_freq=1, reserved_tokens=None):
18 |         self.idx2token = list()
19 |         self.token2idx = {}
20 |         token_freqs = defaultdict(int)
21 |         self.UNK_TOKEN = '<UNK>'
22 |         for sentence in sentences:
23 |             for token in sentence.split(' '):
24 |                 token_freqs[token] += 1
25 |         unique_tokens = list(reserved_tokens) if reserved_tokens else []  # copy, to avoid mutating the caller's list
26 |         unique_tokens += [token for token, freq in token_freqs.items() if freq >= min_freq]
27 |         if self.UNK_TOKEN not in unique_tokens:
28 |             unique_tokens = [self.UNK_TOKEN] + unique_tokens
29 |         for token in unique_tokens:
30 |             self.idx2token.append(token)
31 |             self.token2idx[token] = len(self.idx2token) - 1
32 |         self.unk = self.token2idx[self.UNK_TOKEN]
33 | 
34 |     def __len__(self):
35 |         return len(self.idx2token)
36 | 
37 |     def __getitem__(self, token):
38 |         return self.token2idx.get(token, self.unk)
39 | 
40 |     def convert_tokens_to_ids(self, tokens):
41 |         return [self[token] for token in tokens]
42 | 
43 |     def convert_ids_to_tokens(self, ids):
44 |         return [self.idx2token[idx] for idx in ids]
45 | 
46 |     # Convert source tokens to ids; unknown tokens are collected into oovs and given extended ids
47 |     def convert_text_to_ids(self, text_tokens):
48 |         ids = []
49 |         oovs = []
50 |         unk_id = self.unk
51 |         for token in text_tokens:
52 |             i = self[token]
53 |             if i == unk_id:
54 |                 if token not in oovs:
55 |                     oovs.append(token)
56 |                 oov_idx = oovs.index(token)
57 |                 ids.append(oov_idx + len(self))
58 |             else:
59 |                 ids.append(i)
60 |         return ids, oovs
61 | 
62 |     # Convert title tokens to ids, reusing the extended ids of OOVs that appear in the source text
63 |     def convert_title_to_ids(self, title_tokens, oovs):
64 |         ids = []
65 |         unk_id = self.unk
66 |         for token in title_tokens:
67 |             i = self[token]
68 |             if i == unk_id:
69 |                 if token in oovs:
70 |                     token_idx = oovs.index(token) + len(self)
71 |                     ids.append(token_idx)
72 |                 else:
73 |                     ids.append(unk_id)
74 |             else:
75 |                 ids.append(i)
76 |         return ids
77 | 
78 | 
79 | class MyDataset(Dataset):
80 |     def __init__(self, vocab, text, title=None):
81 |         self.is_train = True if title is not None else False
82 |         self.vocab = vocab
83 |         self.text = text
84 |         self.title = title
85 | 
86 |     def __getitem__(self, i):
87 |         # Get the source-text token ids, plus its OOV words
88 |         text_ids, oovs = 
self.vocab.convert_text_to_ids(self.text[i].split()) 89 | if not self.is_train: 90 | return {'text_ids': text_ids, 91 | 'oovs': oovs, 92 | 'len_oovs': len(oovs)} 93 | else: 94 | # title 的首尾分别加入 BOS_TOKEN 和 EOS_TOKEN 95 | title_ids = [self.vocab[BOS_TOKEN]] + self.vocab.convert_title_to_ids(self.title[i].split(), oovs) + [self.vocab[EOS_TOKEN]] 96 | return {'text_ids': text_ids, 97 | 'oovs': oovs, 98 | 'len_oovs': len(oovs), 99 | 'title_ids': title_ids} 100 | 101 | def __len__(self): 102 | return len(self.text) 103 | 104 | 105 | def load_data(path): 106 | # 数据的加载 107 | with open(path, 'r') as f: 108 | lines = f.readlines() 109 | xs, ys = [], [] 110 | for line in lines: 111 | x, y = line.split('\t') 112 | xs.append(x.strip()) 113 | ys.append(y.strip()) 114 | return xs, ys 115 | 116 | 117 | def replace_oovs(in_tensor, vocab): 118 | # 将文本张量中所有OOV单词的id, 全部替换成 UNK_TOKEN 对应的 id,以便模型可以直接处理 119 | oov_token = torch.full(in_tensor.shape, vocab.unk, dtype=torch.long).to(config.device) 120 | out_tensor = torch.where(in_tensor > len(vocab) - 1, oov_token, in_tensor) 121 | return out_tensor 122 | 123 | 124 | -------------------------------------------------------------------------------- /utils/multi_proc_utils.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import numpy as np 3 | from multiprocessing import cpu_count, Pool 4 | 5 | cores = cpu_count() 6 | 7 | 8 | def parallelize(df, func): 9 | data_split = np.array_split(df, cores) 10 | pool = Pool(cores) 11 | data = pd.concat(pool.map(func, data_split)) 12 | pool.close() 13 | pool.join() 14 | return data 15 | 16 | 17 | if __name__ == '__main__': 18 | print(cores) 19 | -------------------------------------------------------------------------------- /utils/params_utils.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | 4 | def get_params(): 5 | parser = argparse.ArgumentParser() 6 | parser.add_argument("--batch_size", default=16, help='Batch size', type=int) 7 | parser.add_argument("--emb_size", default=128, help="Embedding size", type=int) 8 | parser.add_argument("--hidden_size", default=256, help="Hidden size", type=int) 9 | parser.add_argument("--epochs", default=10, help="Training epochs", type=int) 10 | parser.add_argument("--lr", default=1e-3, help="Learning rate", type=float) 11 | parser.add_argument("--num_layers", default=2, help="Num layers of encoder GRU module", type=int) 12 | parser.add_argument("--dropout", default=0.5, help="Dropout ratio", type=float) 13 | 14 | args = parser.parse_args() 15 | params = vars(args) 16 | return params 17 | 18 | 19 | if __name__ == '__main__': 20 | params = get_params() 21 | print(params) -------------------------------------------------------------------------------- /utils/preprocess.py: -------------------------------------------------------------------------------- 1 | import re 2 | from functools import partial 3 | from multiprocessing import Pool, cpu_count 4 | from tqdm import tqdm 5 | import random 6 | import jieba 7 | import config 8 | 9 | 10 | def clean_weibo_title(title: str): 11 | """ 12 | 对微博数据中的标题内容(待生成)进行清洗 13 | Args: 14 | title: 标题 15 | Returns: 16 | """ 17 | # 去除##符号(一般为微博数据的话题标记) 18 | title = re.sub(r"#", "", title) 19 | # 去除[]中间的文字(一般为微博数据中的表情) 20 | title = re.sub(r"(\[{1,2})(.*?)(\]{1,2})", "", title) 21 | # 合并标题中过多的空格 22 | title = re.sub(r"\s+", " ", title) 23 | return title 24 | 25 | 26 | def clean_weibo_content(content: str): 27 | """ 28 | 对微博数据中的文本内容进行清洗 29 | 
Args: 30 | content: 文本 31 | Returns: 32 | """ 33 | # 去除网址 34 | content = re.sub(r"(https|http)?:\/\/(\w|\.|\/|\?|\=|\&|\%)*\b", "", content) 35 | # 合并正文中过多的空格 36 | content = re.sub(r"\s+", " ", content) 37 | # 去除\u200b字符 38 | content = content.replace("\u200b", "") 39 | return content 40 | 41 | 42 | def clean_data(sample): 43 | """ 44 | 整体清洗函数,为了方便多线程使用 45 | Args: 46 | sample: 一个元组,包含正文内容和标题内容 47 | Returns: 48 | """ 49 | (content, title) = sample 50 | sample = dict() 51 | # 清洗数据 52 | sample["title"] = clean_weibo_title(title.strip()) 53 | sample["content"] = clean_weibo_content(content.strip()) 54 | sample["content"] = ' '.join(jieba.cut(sample['content'], cut_all=False)) 55 | sample["title"] = ' '.join(jieba.cut(sample['title'], cut_all=False)) 56 | return sample 57 | 58 | 59 | def build_news_data(content_path, title_path, train_save_path, test_save_path): 60 | """ 61 | 对微博数据进行清洗,构建训练集和测试集 62 | Args: 63 | content_path: 正文内容文件路径 64 | title_path: 标题内容文件路径 65 | train_save_path: 训练集文件路径 66 | test_save_path: 测试集文件路径 67 | Returns: 68 | """ 69 | # 打开文件,并将其zip成一个文件 70 | content_data = open(content_path, "r", encoding="utf-8") 71 | title_data = open(title_path, "r", encoding="utf-8") 72 | data = zip(content_data.readlines(), title_data.readlines()) 73 | # 使用多进程处理数据 74 | threads = min(8, cpu_count()) 75 | with Pool(threads) as p: 76 | annoate_ = partial(clean_data) 77 | data = list(tqdm(p.imap(annoate_, data, chunksize=8), 78 | desc="build data" 79 | ) 80 | ) 81 | # 对数据进行过滤,去除重复数据、正文内容字长小于100的数据和标题内容字长小于100的数据 82 | data_set = set() 83 | data_new = [] 84 | for d in data: 85 | if d["content"] in data_set or len(d["content"]) < 100 or len(d["title"]) < 2: 86 | continue 87 | else: 88 | data_set.add(d["content"]) 89 | 90 | data_new.append(d['content'] + '\t' + d['title']) 91 | # 分割数据,构建训练集和测试集 92 | random.shuffle(data_new) 93 | train_data = data_new[:-3000] 94 | test_data = data_new[-3000:] 95 | print('writing train data to file ...') 96 | fin = open(train_save_path, "w", encoding="utf-8") 97 | fin.write('\n'.join(train_data)) 98 | fin.close() 99 | print('writing test data to file ...') 100 | fin = open(test_save_path, "w", encoding="utf-8") 101 | fin.write('\n'.join(test_data)) 102 | fin.close() 103 | 104 | 105 | if __name__ == '__main__': 106 | build_news_data(content_path=config.content_path, 107 | title_path=config.title_path, 108 | train_save_path=config.train_save_path, 109 | test_save_path=config.test_save_path) 110 | --------------------------------------------------------------------------------
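As a closing usage note, the OOV bookkeeping that ties `preprocess.py`, `data_loader.py` and the PGN model together can be checked in isolation. Below is a minimal round-trip sketch; the toy sentences are made up, and it assumes the snippet is run from the repository root so that `utils` is importable.

```python
import torch
from utils import config
from utils.data_loader import Vocab, replace_oovs

# Two toy whitespace-tokenised "articles" standing in for the real corpus.
sentences = ['今天 天气 很 好', '天气 预报 说 明天 有 雨']
vocab = Vocab(sentences, reserved_tokens=config.reserved_tokens)

# '台风' never appears in the corpus, so it becomes a source OOV with an extended id >= len(vocab).
src_tokens = '明天 有 台风'.split()
src_ids, oovs = vocab.convert_text_to_ids(src_tokens)
print(src_ids, oovs)

# A title can copy the OOV: it is mapped to the same extended id instead of <UNK>.
title_ids = vocab.convert_title_to_ids('台风 明天 到'.split(), oovs)
print(title_ids)

# Before the ids reach the Encoder/Decoder embeddings, extended ids are folded back to <UNK>.
src_tensor = torch.tensor(src_ids).unsqueeze(0).to(config.device)
print(replace_oovs(src_tensor, vocab))
```

If the copied word comes back as `<UNK>` here, the usual culprit is a mismatch between `config.reserved_tokens` and the `<UNK>` token hard-coded in `Vocab`.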