├── CPM ├── config.py ├── cpm_pretrain │ ├── chinese_vocab.model │ ├── chinese_vocab.vocab │ ├── cpm-medium.json │ ├── cpm-small.json │ └── vocab.json ├── data_helper.py ├── inference.py └── run_train.py ├── GPT2 ├── config.py ├── data_helper.py ├── inference.py └── run_train.py ├── MT5 ├── config.py ├── data_helper.py ├── inference.py ├── run_train.py └── tokenizer.py ├── README.md ├── UniLM ├── bert_model.py ├── config.py ├── data_helper.py ├── inference.py ├── model.py ├── run_train.py └── tokenizer.py ├── seq2seq_rnn ├── config.py ├── data_helper.py ├── inference.py ├── model.py ├── run_train.py └── start.sh ├── seq2seq_rnn_pointer_network ├── config.py ├── data_helper.py ├── inference.py ├── model.py ├── run_train.py └── start.sh ├── seq2seq_transformer ├── config.py ├── data_helper.py ├── inference.py ├── model.py ├── run_train.py └── start.sh └── seq2seq_transformer_pointer_network ├── config.py ├── data_helper.py ├── inference.py ├── model.py ├── run_train.py └── start.sh /CPM/config.py: -------------------------------------------------------------------------------- 1 | """ 2 | @file : config.py 3 | @author : xiaolu 4 | @email : luxiaonlp@163.com 5 | @time : 2022-08-03 6 | """ 7 | import argparse 8 | 9 | 10 | def set_args(): 11 | parser = argparse.ArgumentParser() 12 | parser.add_argument('--seed', type=int, default=1234, help='设置随机种子') 13 | parser.add_argument('--train_data_path', default='../data/article.json', type=str, required=False, help='经过预处理之后的数据存放路径') 14 | parser.add_argument('--train_data_path_processed', default='./data/train.json', type=str, required=False, help='经过预处理之后的数据存放路径') 15 | parser.add_argument('--cpm_model_config', default='./cpm_pretrain/cpm-small.json', type=str, help='模型配置') 16 | parser.add_argument('--cpm_model_vocab', default='./cpm_pretrain/chinese_vocab.model', type=str, help='cpm模型的词表') 17 | 18 | parser.add_argument('--max_len', default=200, type=int, required=False, help='训练数据最大长度') 19 | parser.add_argument('--train_batch_size', default=16, type=int, required=False, help='训练的batch size') 20 | parser.add_argument('--gradient_accumulation_steps', default=1, type=int, required=False, help='梯度积累的步数') 21 | parser.add_argument('--epochs', default=50, type=int, required=False, help='训练的最大轮次') 22 | parser.add_argument('--lr', default=1.5e-4, type=float, required=False, help='学习率') 23 | parser.add_argument('--eps', default=1.0e-09, type=float, required=False, help='AdamW优化器的衰减率') 24 | parser.add_argument('--warmup_steps', type=int, default=4000, help='warm up步数') 25 | parser.add_argument('--max_grad_norm', default=1.0, type=float, required=False) 26 | parser.add_argument('--output_dir', default='./output', type=str, required=False, help='输出路径') 27 | args = parser.parse_args() 28 | return args 29 | -------------------------------------------------------------------------------- /CPM/cpm_pretrain/chinese_vocab.model: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shawroad/Text-Generation-Chinese-Pytorch/d776d75058fcbbd4a847934480cf49194f80298d/CPM/cpm_pretrain/chinese_vocab.model -------------------------------------------------------------------------------- /CPM/cpm_pretrain/cpm-medium.json: -------------------------------------------------------------------------------- 1 | { 2 | "activation_function": "gelu_new", 3 | "architectures": [ 4 | "GPT2LMHeadModel" 5 | ], 6 | "attn_pdrop": 0.1, 7 | "bos_token_id": 1, 8 | "embd_pdrop": 0.1, 9 | "eos_token_id": 2, 10 | "initializer_range": 0.02, 11 
| "layer_norm_epsilon": 1e-05, 12 | "model_type": "gpt2", 13 | "n_ctx": 1024, 14 | "n_embd": 1024, 15 | "n_head": 16, 16 | "n_layer": 24, 17 | "n_positions": 1024, 18 | "n_special": 0, 19 | "predict_special_tokens": true, 20 | "resid_pdrop": 0.1, 21 | "summary_activation": null, 22 | "summary_first_dropout": 0.1, 23 | "summary_proj_to_labels": true, 24 | "summary_type": "cls_index", 25 | "summary_use_proj": true, 26 | "task_specific_params": { 27 | "text-generation": { 28 | "do_sample": true, 29 | "max_length": 50 30 | } 31 | }, 32 | "vocab_size": 30000 33 | } -------------------------------------------------------------------------------- /CPM/cpm_pretrain/cpm-small.json: -------------------------------------------------------------------------------- 1 | { 2 | "activation_function": "gelu_new", 3 | "architectures": [ 4 | "GPT2LMHeadModel" 5 | ], 6 | "attn_pdrop": 0.1, 7 | "bos_token_id": 50256, 8 | "embd_pdrop": 0.1, 9 | "eos_token_id": 50256, 10 | "initializer_range": 0.02, 11 | "layer_norm_epsilon": 1e-05, 12 | "model_type": "gpt2", 13 | "n_ctx": 1024, 14 | "n_embd": 768, 15 | "n_head": 12, 16 | "n_layer": 12, 17 | "n_positions": 1024, 18 | "resid_pdrop": 0.1, 19 | "summary_activation": null, 20 | "summary_first_dropout": 0.1, 21 | "summary_proj_to_labels": true, 22 | "summary_type": "cls_index", 23 | "summary_use_proj": true, 24 | "task_specific_params": { 25 | "text-generation": { 26 | "do_sample": true, 27 | "max_length": 50 28 | } 29 | }, 30 | "vocab_size": 30000 31 | } -------------------------------------------------------------------------------- /CPM/data_helper.py: -------------------------------------------------------------------------------- 1 | """ 2 | @file : data_helper.py 3 | @author : xiaolu 4 | @email : luxiaonlp@163.com 5 | @time : 2022-08-03 6 | """ 7 | from torch.utils.data import Dataset 8 | import torch 9 | import json 10 | from tqdm import tqdm 11 | 12 | 13 | def load_data(path, tokenizer, save_path): 14 | # 加载数据 并转为对应的id序列 15 | bos_id = tokenizer.bos_token_id # 开始 16 | eos_id = tokenizer.eos_token_id # 结束 17 | sep_id = tokenizer.sep_token_id # 分割 18 | 19 | win_size = 200 # 窗口大小 20 | step_length = 128 # 步长 21 | train_list = [] 22 | with open(path, 'r', encoding='utf8') as f: 23 | lines = f.readlines() 24 | for line in tqdm(lines): 25 | line = json.loads(line) 26 | title = line['title'] 27 | article = line['article'] 28 | 29 | title_ids = tokenizer.encode(title, add_special_tokens=False) 30 | article_ids = tokenizer.encode(article, add_special_tokens=False) 31 | token_ids = [bos_id] + title_ids + [sep_id] + article_ids + [eos_id] 32 | 33 | # 如果数据过长 滑动窗口处理 34 | start_index = 0 35 | end_index = win_size 36 | data = token_ids[start_index:end_index] 37 | train_list.append(data) 38 | 39 | start_index += step_length 40 | end_index += step_length 41 | while end_index + 50 < len(token_ids): # 剩下的数据长度,大于或等于50,才加入训练数据集 42 | data = token_ids[start_index:end_index] 43 | train_list.append(data) 44 | start_index += step_length 45 | end_index += step_length 46 | json.dump(train_list, open(save_path, 'w', encoding='utf8')) 47 | return train_list 48 | 49 | 50 | class CPMDataset(Dataset): 51 | def __init__(self, input_list, max_len): 52 | self.input_list = input_list 53 | self.max_len = max_len 54 | 55 | def __getitem__(self, index): 56 | input_ids = self.input_list[index] 57 | input_ids = input_ids[:self.max_len] 58 | input_ids = torch.tensor(input_ids, dtype=torch.long) 59 | return input_ids 60 | 61 | def __len__(self): 62 | return len(self.input_list) 63 | 64 | 65 | 66 | 67 | 
-------------------------------------------------------------------------------- /CPM/inference.py: -------------------------------------------------------------------------------- 1 | """ 2 | @file : inference.py 3 | @author : xiaolu 4 | @email : luxiaonlp@163.com 5 | @time : 2022-08-03 6 | """ 7 | import torch 8 | import os 9 | from config import set_args 10 | from transformers.models.gpt2 import GPT2LMHeadModel 11 | import warnings 12 | from transformers import CpmTokenizer 13 | warnings.filterwarnings("ignore") 14 | 15 | 16 | def greedy_decode(input_ids): 17 | max_length = len(input_ids) + 200 18 | input_ids = torch.tensor(input_ids, dtype=torch.long) 19 | if torch.cuda.is_available(): 20 | input_ids = input_ids.cuda() 21 | output_greedy = model.generate(input_ids, max_length=max_length, do_sample=False, eos_token_id=tokenizer.eos_token_id) 22 | res = tokenizer.decode(output_greedy[0]) 23 | vocab = res.split('[SEP]')[1:-1] 24 | vocab = [v.replace(' ', '') for v in vocab] 25 | return vocab 26 | 27 | 28 | def beamsearch_decode(input_ids): 29 | max_length = len(input_ids) + 200 30 | input_ids = torch.tensor(input_ids, dtype=torch.long) 31 | if torch.cuda.is_available(): 32 | input_ids = input_ids.cuda() 33 | output_beam = model.generate(input_ids, max_length=max_length, num_beams=3, do_sample=False, 34 | no_repeat_ngram_size=2, eos_token_id=tokenizer.eos_token_id) 35 | res = tokenizer.decode(output_beam[0]) 36 | vocab = res.split('[SEP]')[1:-1] 37 | vocab = [v.replace(' ', '') for v in vocab] 38 | return vocab 39 | 40 | 41 | def greedy_sample_decode(input_ids): 42 | max_length = len(input_ids) + 200 43 | repetition_penalty = 1.1 44 | temperature = 1 45 | topk = 5 46 | topp = 0.95 47 | input_ids = torch.tensor(input_ids, dtype=torch.long) 48 | if torch.cuda.is_available(): 49 | input_ids = input_ids.cuda() 50 | 51 | output_greedy_random = model.generate(input_ids, max_length=max_length, do_sample=True, 52 | temperature=temperature, top_k=topk, top_p=topp, 53 | repetition_penalty=repetition_penalty, 54 | eos_token_id=tokenizer.eos_token_id) 55 | res = tokenizer.decode(output_greedy_random[0]) 56 | vocab = res.split('[SEP]')[1:-1] 57 | vocab = [v.replace(' ', '') for v in vocab] 58 | return vocab 59 | 60 | 61 | def beamsearch_sample_decode(input_ids): 62 | max_length = len(input_ids) + 200 63 | repetition_penalty = 1.1 64 | num_beams = 3 65 | temperature = 1 66 | topk = 5 67 | topp = 0.95 68 | input_ids = torch.tensor(input_ids, dtype=torch.long) 69 | if torch.cuda.is_available(): 70 | input_ids = input_ids.cuda() 71 | 72 | output_beamsearch_random = model.generate(input_ids, max_length=max_length, do_sample=True, 73 | num_beams=num_beams, temperature=temperature, top_k=topk, 74 | top_p=topp, repetition_penalty=repetition_penalty, 75 | eos_token_id=tokenizer.eos_token_id, pad_token_id=tokenizer.pad_token_id) 76 | res = tokenizer.decode(output_beamsearch_random[0]) 77 | article = res.split('[SEP]')[-1] 78 | article = [v.replace(' ', '') for v in article] 79 | return article 80 | 81 | 82 | if __name__ == '__main__': 83 | args = set_args() 84 | tokenizer = CpmTokenizer(vocab_file="cpm_pretrain/chinese_vocab.model") 85 | 86 | # 加载模型 87 | model_path = os.path.join(args.output_dir, 'model_epoch_{}'.format(9)) 88 | model = GPT2LMHeadModel.from_pretrained(model_path) 89 | 90 | if torch.cuda.is_available(): 91 | model.cuda() 92 | # model.half() 93 | model.eval() 94 | 95 | title = '家乡的四季' 96 | prefix = '在我的家乡,春天是万物复苏的季节。' 97 | bos_id = tokenizer.bos_token_id # 开始 98 | eos_id = tokenizer.eos_token_id # 
结束 99 | sep_id = tokenizer.sep_token_id # 分割 100 | 101 | title_ids = tokenizer.encode(title, add_special_tokens=False) 102 | article_ids = tokenizer.encode(prefix, add_special_tokens=False) 103 | input_ids = [bos_id] + title_ids + [sep_id] + article_ids 104 | 105 | gen_article_greedy = greedy_decode(input_ids) 106 | gen_article_beamsearch = beamsearch_decode(input_ids) 107 | gen_article_greedy_sample = greedy_sample_decode(input_ids) 108 | gen_article_beamsearch_sample = beamsearch_sample_decode(input_ids) 109 | print(gen_article_greedy) 110 | print(gen_article_beamsearch) 111 | print(gen_article_greedy_sample) 112 | print(gen_article_beamsearch_sample) 113 | -------------------------------------------------------------------------------- /CPM/run_train.py: -------------------------------------------------------------------------------- 1 | """ 2 | @file : run_train.py 3 | @author : xiaolu 4 | @email : luxiaonlp@163.com 5 | @time : 2022-08-03 6 | """ 7 | import os 8 | import torch 9 | import random 10 | import numpy as np 11 | import json 12 | import torch.nn.functional as F 13 | from config import set_args 14 | from torch.optim import AdamW 15 | from datetime import datetime 16 | from torch.nn.utils.rnn import pad_sequence 17 | from data_helper import CPMDataset, load_data 18 | from torch.utils.data import DataLoader 19 | from transformers import GPT2LMHeadModel, GPT2Config, CpmTokenizer 20 | from transformers import get_linear_schedule_with_warmup 21 | 22 | 23 | def caculate_loss(logit, target, pad_idx, smoothing=True): 24 | if smoothing: 25 | logit = logit[..., :-1, :].contiguous().view(-1, logit.size(2)) 26 | target = target[..., 1:].contiguous().view(-1) 27 | 28 | eps = 0.1 29 | n_class = logit.size(-1) 30 | 31 | one_hot = torch.zeros_like(logit).scatter(1, target.view(-1, 1), 1) 32 | one_hot = one_hot * (1 - eps) + (1 - one_hot) * eps / (n_class - 1) 33 | log_prb = F.log_softmax(logit, dim=1) 34 | 35 | non_pad_mask = target.ne(pad_idx) 36 | loss = -(one_hot * log_prb).sum(dim=1) 37 | loss = loss.masked_select(non_pad_mask).mean() # average later 38 | else: 39 | logit = logit[..., :-1, :].contiguous().view(-1, logit.size(-1)) 40 | labels = target[..., 1:].contiguous().view(-1) 41 | loss = F.cross_entropy(logit, labels, ignore_index=pad_idx) 42 | return loss 43 | 44 | 45 | def calculate_acc(logit, labels, ignore_index): 46 | logit = logit[..., :-1, :].contiguous().view(-1, logit.size(-1)) 47 | labels = labels[..., 1:].contiguous().view(-1) 48 | 49 | _, logit = logit.max(dim=-1) # 对于每条数据,返回最大的index 50 | # 进行非运算,返回一个tensor,若labels的第i个位置为pad_id,则置为0,否则为1 51 | non_pad_mask = labels.ne(ignore_index) 52 | n_correct = logit.eq(labels).masked_select(non_pad_mask).sum().item() 53 | n_word = non_pad_mask.sum().item() 54 | acc = n_correct / n_word 55 | return acc 56 | 57 | 58 | def set_seed(): 59 | random.seed(args.seed) 60 | np.random.seed(args.seed) 61 | torch.manual_seed(args.seed) 62 | if torch.cuda.is_available(): 63 | torch.cuda.manual_seed_all(args.seed) 64 | 65 | 66 | def get_model(pretrain_model_path=None): 67 | # 有预训练模型 则加载 否则从头开始训练 68 | if pretrain_model_path is not None: 69 | model = GPT2LMHeadModel.from_pretrained(pretrain_model_path) # 加载预训练模型 70 | else: 71 | model_config = GPT2Config.from_json_file(args.cpm_model_config) 72 | model = GPT2LMHeadModel(config=model_config) 73 | if torch.cuda.is_available(): 74 | model.cuda() 75 | return model 76 | 77 | 78 | def collate_fn(batch): 79 | input_ids = pad_sequence(batch, batch_first=True, padding_value=tokenizer.pad_token_id) 80 | return 
input_ids 81 | 82 | 83 | if __name__ == '__main__': 84 | # 初始化参数 85 | args = set_args() 86 | # os.makedirs(args.output_dir, exist_ok=True) 87 | set_seed() 88 | # 初始化tokenizer 89 | tokenizer = CpmTokenizer(vocab_file=args.cpm_model_vocab) 90 | model = get_model() # 这里其实还是用gpt2 只是分词采用的是CPM 91 | assert model.config.vocab_size == tokenizer.vocab_size 92 | 93 | # 计算模型参数数量 94 | num_parameters = 0 95 | parameters = model.parameters() 96 | for parameter in parameters: 97 | num_parameters += parameter.numel() 98 | print('total parameters nums:', num_parameters) 99 | 100 | # 加载训练集 如果有预处理数据 直接加载 否则 直接预处理 101 | try: 102 | train_data_list = json.load(open(args.train_data_path_processed, 'r', encoding='utf8')) 103 | except: 104 | train_data_list = load_data(args.train_data_path, tokenizer, args.train_data_path_processed) 105 | 106 | train_dataset = CPMDataset(train_data_list, args.max_len) 107 | 108 | train_dataloader = DataLoader(train_dataset, batch_size=args.train_batch_size, shuffle=True, collate_fn=collate_fn) 109 | 110 | t_total = len(train_dataloader) // args.gradient_accumulation_steps * args.epochs 111 | optimizer = AdamW(model.parameters(), lr=args.lr, eps=args.eps) 112 | scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=args.warmup_steps, num_training_steps=t_total) 113 | 114 | for epoch in range(args.epochs): 115 | model.train() 116 | epoch_loss, epoch_acc = 0, 0 117 | epoch_start_time = datetime.now() 118 | for step, input_ids in enumerate(train_dataloader): 119 | if torch.cuda.is_available(): 120 | input_ids = input_ids.cuda() 121 | outputs = model(input_ids=input_ids) 122 | logits = outputs.logits 123 | loss = caculate_loss(logits, input_ids, tokenizer.pad_token_id, smoothing=True) 124 | accuracy = calculate_acc(logits, input_ids, ignore_index=tokenizer.pad_token_id) 125 | 126 | if args.gradient_accumulation_steps > 1: 127 | loss = loss / args.gradient_accumulation_steps 128 | accuracy = accuracy / args.gradient_accumulation_steps 129 | loss.backward() 130 | print('epoch:{}, step:{}, loss:{:10f}, accuracy:{:10f}, lr:{:10f}'.format(epoch, step, loss, accuracy, scheduler.get_last_lr()[0])) 131 | 132 | # 梯度裁剪 一般梯度裁剪尽量别用 会影响效果 133 | # torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm) 134 | epoch_loss += loss.item() 135 | epoch_acc += accuracy 136 | 137 | if (step + 1) % args.gradient_accumulation_steps == 0: 138 | optimizer.step() 139 | optimizer.zero_grad() 140 | scheduler.step() # 进行warm_up 141 | 142 | avg_loss = epoch_loss / len(train_dataloader) 143 | avg_acc = epoch_acc / len(train_dataloader) 144 | 145 | ss = 'epoch:{}, loss:{:10f}, accuracy:{:10f}'.format(epoch, avg_loss, avg_acc) 146 | loss_path = os.path.join(args.output_dir, 'logs.txt') 147 | with open(loss_path, 'a+', encoding='utf8') as f: 148 | f.write(ss + '\n') 149 | 150 | # 一个epoch跑完保存一下模型 151 | model_save_path = os.path.join(args.output_dir, 'model_epoch_{}'.format(epoch + 1)) 152 | if not os.path.exists(model_save_path): 153 | os.mkdir(model_save_path) 154 | model_to_save = model.module if hasattr(model, 'module') else model 155 | model_to_save.save_pretrained(model_save_path) 156 | tokenizer.save_pretrained(model_save_path) 157 | epoch_finish_time = datetime.now() 158 | print('per epoch cost time: {}'.format(epoch_finish_time - epoch_start_time)) 159 | -------------------------------------------------------------------------------- /GPT2/config.py: -------------------------------------------------------------------------------- 1 | """ 2 | @file : config.py 3 | @author : xiaolu 4 | 
@email : luxiaonlp@163.com 5 | @time : 2022-08-03 6 | """ 7 | import argparse 8 | 9 | 10 | def set_args(): 11 | parser = argparse.ArgumentParser('--小作文生成') 12 | parser.add_argument('--pretrain_model_path', default='./gpt2_pretrain', type=str, help='预训练模型') 13 | parser.add_argument('--pretrain_model_config_path', default='./gpt2_pretrain/config.json', type=str, help='预训练模型') 14 | parser.add_argument('--train_data_path', default='../data/article.json', type=str, help='训练数据') 15 | parser.add_argument('--train_data_path_processed', default='./data/train.json', type=str, help='训练数据') 16 | parser.add_argument('--output_dir', default='./output', type=str, required=False, help='多少步汇报一次loss') 17 | 18 | parser.add_argument('--max_len', default=200, type=int, required=False, help='输入的最大长度') 19 | parser.add_argument('--batch_size', default=16, type=int, required=False, help='训练batch size') 20 | parser.add_argument('--epochs', default=50, type=int, required=False, help='训练的轮次') 21 | parser.add_argument('--gradient_accumulation', default=1, type=int, required=False, help='梯度积累') 22 | parser.add_argument('--learning_rate', default=1.5e-4, type=float, required=False, help='学习率') 23 | parser.add_argument('--warmup_steps_rate', default=0.05, type=float, required=False, help='warm up步数占总步数的比例') 24 | parser.add_argument('--max_grad_norm', default=1.0, type=float, required=False) 25 | return parser.parse_args() 26 | -------------------------------------------------------------------------------- /GPT2/data_helper.py: -------------------------------------------------------------------------------- 1 | """ 2 | @file : data_helper.py 3 | @author : xiaolu 4 | @email : luxiaonlp@163.com 5 | @time : 2022-08-03 6 | """ 7 | import re 8 | import torch 9 | import json 10 | import pandas as pd 11 | from tqdm import tqdm 12 | from torch.utils.data import Dataset 13 | 14 | 15 | def load_data(path, tokenizer, save_path): 16 | # 起始和结束和分割特殊字符 17 | bos_id = tokenizer.bos_token_id # 开始 18 | sep_id = tokenizer.sep_token_id # 标题和文章分割符 19 | eos_id = tokenizer.eos_token_id # 结束 20 | 21 | win_size = 200 # 窗口大小 22 | step_length = 128 # 步长 23 | train_list = [] 24 | with open(path, 'r', encoding='utf8') as f: 25 | lines = f.readlines() 26 | for line in tqdm(lines): 27 | line = json.loads(line) 28 | title = line['title'] 29 | article = line['article'] 30 | 31 | title_ids = tokenizer.encode(title, add_special_tokens=False) 32 | article_ids = tokenizer.encode(article, add_special_tokens=False) 33 | token_ids = [bos_id] + title_ids + [sep_id] + article_ids + [eos_id] 34 | 35 | # 如果数据过长 滑动窗口处理 36 | start_index = 0 37 | end_index = win_size 38 | data = token_ids[start_index:end_index] 39 | train_list.append(data) 40 | 41 | start_index += step_length 42 | end_index += step_length 43 | while end_index + 50 < len(token_ids): # 剩下的数据长度,大于或等于50,才加入训练数据集 44 | data = token_ids[start_index:end_index] 45 | train_list.append(data) 46 | start_index += step_length 47 | end_index += step_length 48 | json.dump(train_list, open(save_path, 'w', encoding='utf8')) 49 | return train_list 50 | 51 | 52 | class GPT2Dataset(Dataset): 53 | def __init__(self, input_list, max_len): 54 | self.input_list = input_list 55 | self.max_len = max_len 56 | 57 | def __getitem__(self, index): 58 | input_ids = self.input_list[index] 59 | input_ids = input_ids[:self.max_len] 60 | input_ids = torch.tensor(input_ids, dtype=torch.long) 61 | return input_ids 62 | 63 | def __len__(self): 64 | return len(self.input_list) 65 | 
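
Before the inference script, a toy check (assumed tensor shapes, not code from this repo) of the next-token objective that both `CPM/run_train.py` and `GPT2/run_train.py` compute in `caculate_loss`: logits are truncated at the last position, labels are shifted left by one, and padded positions are excluded. With smoothing disabled this reduces to a plain cross-entropy with `ignore_index`, which is what the sketch below prints.

```python
# Hypothetical toy example of the shifted, pad-masked LM loss
# (the smoothing=False branch of caculate_loss in run_train.py).
import torch
import torch.nn.functional as F

pad_idx = 0
logits = torch.randn(2, 6, 11)                       # (batch, seq_len, vocab)
target = torch.tensor([[5, 3, 9, 4, 0, 0],
                       [7, 2, 0, 0, 0, 0]])          # 0 = padding from collate_fn

shift_logits = logits[..., :-1, :].reshape(-1, logits.size(-1))  # drop last position
shift_labels = target[..., 1:].reshape(-1)                       # shift labels left by one
loss = F.cross_entropy(shift_logits, shift_labels, ignore_index=pad_idx)
print(loss)   # same value the smoothing=False branch returns for these inputs
```

The label-smoothing branch builds the same shifted view, then replaces the hard one-hot targets with `1 - eps` on the gold token and `eps / (n_class - 1)` elsewhere before masking out pad positions.
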
-------------------------------------------------------------------------------- /GPT2/inference.py: -------------------------------------------------------------------------------- 1 | """ 2 | @file : inference.py 3 | @author : xiaolu 4 | @email : luxiaonlp@163.com 5 | @time : 2022-08-03 6 | """ 7 | import pandas as pd 8 | import torch 9 | import os 10 | import re 11 | import json 12 | import requests 13 | from config import set_args 14 | from transformers.models.gpt2 import GPT2LMHeadModel 15 | from transformers import BertTokenizer 16 | import warnings 17 | warnings.filterwarnings("ignore") 18 | 19 | 20 | def greedy_decode(input_ids): 21 | max_length = len(input_ids) + 200 22 | input_ids = torch.tensor(input_ids, dtype=torch.long) 23 | if torch.cuda.is_available(): 24 | input_ids = input_ids.cuda() 25 | output_greedy = model.generate(input_ids, max_length=max_length, do_sample=False, eos_token_id=tokenizer.eos_token_id) 26 | res = tokenizer.decode(output_greedy[0]) 27 | vocab = res.split('[SEP]')[1:-1] 28 | vocab = [v.replace(' ', '') for v in vocab] 29 | return vocab 30 | 31 | 32 | def beamsearch_decode(input_ids): 33 | max_length = len(input_ids) + 200 34 | input_ids = torch.tensor(input_ids, dtype=torch.long) 35 | if torch.cuda.is_available(): 36 | input_ids = input_ids.cuda() 37 | output_beam = model.generate(input_ids, max_length=max_length, num_beams=3, do_sample=False, 38 | no_repeat_ngram_size=2, eos_token_id=tokenizer.eos_token_id) 39 | res = tokenizer.decode(output_beam[0]) 40 | vocab = res.split('[SEP]')[1:-1] 41 | vocab = [v.replace(' ', '') for v in vocab] 42 | return vocab 43 | 44 | 45 | def greedy_sample_decode(input_ids): 46 | max_length = len(input_ids) + 200 47 | repetition_penalty = 1.1 48 | temperature = 1 49 | topk = 5 50 | topp = 0.95 51 | input_ids = torch.tensor(input_ids, dtype=torch.long) 52 | if torch.cuda.is_available(): 53 | input_ids = input_ids.cuda() 54 | 55 | output_greedy_random = model.generate(input_ids, max_length=max_length, do_sample=True, 56 | temperature=temperature, top_k=topk, top_p=topp, 57 | repetition_penalty=repetition_penalty, 58 | eos_token_id=tokenizer.eos_token_id) 59 | res = tokenizer.decode(output_greedy_random[0]) 60 | vocab = res.split('[SEP]')[1:-1] 61 | vocab = [v.replace(' ', '') for v in vocab] 62 | return vocab 63 | 64 | 65 | def beamsearch_sample_decode(input_ids): 66 | max_length = len(input_ids) + 200 67 | repetition_penalty = 1.1 68 | num_beams = 3 69 | temperature = 1 70 | topk = 5 71 | topp = 0.95 72 | input_ids = torch.tensor(input_ids, dtype=torch.long) 73 | if torch.cuda.is_available(): 74 | input_ids = input_ids.cuda() 75 | 76 | output_beamsearch_random = model.generate(input_ids, max_length=max_length, do_sample=True, 77 | num_beams=num_beams, temperature=temperature, top_k=topk, 78 | top_p=topp, repetition_penalty=repetition_penalty, 79 | eos_token_id=tokenizer.eos_token_id, pad_token_id=tokenizer.pad_token_id) 80 | res = tokenizer.decode(output_beamsearch_random[0]) 81 | article = res.split('[SEP]')[-1] 82 | article = [v.replace(' ', '') for v in article] 83 | return article 84 | 85 | 86 | if __name__ == '__main__': 87 | args = set_args() 88 | tokenizer = BertTokenizer.from_pretrained(args.pretrain_model_path) 89 | num_added_toks = tokenizer.add_special_tokens({'bos_token': '[BOS]', 'eos_token': '[EOS]'}) 90 | 91 | # 加载模型 92 | model_path = os.path.join(args.output_dir, 'model_epoch_{}'.format(9)) 93 | model = GPT2LMHeadModel.from_pretrained(model_path) 94 | 95 | if torch.cuda.is_available(): 96 | model.cuda() 97 | 
# model.half() 98 | model.eval() 99 | 100 | title = '家乡的四季' 101 | prefix = '在我的家乡,春天是万物复苏的季节。' 102 | bos_id = tokenizer.bos_token_id # 开始 103 | sep_id = tokenizer.sep_token_id # 标题和文章分割符 104 | eos_id = tokenizer.eos_token_id # 结束 105 | title_ids = tokenizer.encode(title, add_special_tokens=False) 106 | article_ids = tokenizer.encode(prefix, add_special_tokens=False) 107 | input_ids = [bos_id] + title_ids + [sep_id] + article_ids 108 | 109 | gen_article_greedy = greedy_decode(input_ids) 110 | gen_article_beamsearch = beamsearch_decode(input_ids) 111 | gen_article_greedy_sample = greedy_sample_decode(input_ids) 112 | gen_article_beamsearch_sample = beamsearch_sample_decode(input_ids) 113 | print(gen_article_greedy) 114 | print(gen_article_beamsearch) 115 | print(gen_article_greedy_sample) 116 | print(gen_article_beamsearch_sample) 117 | -------------------------------------------------------------------------------- /GPT2/run_train.py: -------------------------------------------------------------------------------- 1 | """ 2 | @file : run_train.py 3 | @author : xiaolu 4 | @email : luxiaonlp@163.com 5 | @time : 2022-08-03 6 | """ 7 | import os 8 | import torch 9 | import random 10 | import numpy as np 11 | import json 12 | import torch.nn.functional as F 13 | from datetime import datetime 14 | from config import set_args 15 | from torch.optim import AdamW 16 | from torch.utils.data import DataLoader 17 | from torch.nn.utils.rnn import pad_sequence 18 | from data_helper import GPT2Dataset, load_data 19 | from transformers.models.bert import BertTokenizer 20 | from transformers import get_linear_schedule_with_warmup 21 | from transformers.models.gpt2 import GPT2LMHeadModel, modeling_gpt2 22 | 23 | 24 | def set_seed(): 25 | random.seed(args.seed) 26 | np.random.seed(args.seed) 27 | torch.manual_seed(args.seed) 28 | if torch.cuda.is_available(): 29 | torch.cuda.manual_seed_all(args.seed) 30 | 31 | 32 | def get_model(pretrain_model_path=None): 33 | # 有预训练模型 则加载 否则从头开始训练 34 | if pretrain_model_path is not None: 35 | model = GPT2LMHeadModel.from_pretrained(pretrain_model_path) # 加载预训练模型 36 | else: 37 | model_config = modeling_gpt2.GPT2Config.from_json_file(args.pretrain_model_config_path) 38 | model = GPT2LMHeadModel(config=model_config) 39 | if torch.cuda.is_available(): 40 | model.cuda() 41 | return model 42 | 43 | 44 | def collate_fn(batch): 45 | input_ids = pad_sequence(batch, batch_first=True, padding_value=tokenizer.pad_token_id) 46 | return input_ids 47 | 48 | 49 | def caculate_loss(logit, target, pad_idx, smoothing=True): 50 | if smoothing: 51 | logit = logit[..., :-1, :].contiguous().view(-1, logit.size(2)) 52 | target = target[..., 1:].contiguous().view(-1) 53 | 54 | eps = 0.1 55 | n_class = logit.size(-1) 56 | 57 | one_hot = torch.zeros_like(logit).scatter(1, target.view(-1, 1), 1) 58 | one_hot = one_hot * (1 - eps) + (1 - one_hot) * eps / (n_class - 1) 59 | log_prb = F.log_softmax(logit, dim=1) 60 | 61 | non_pad_mask = target.ne(pad_idx) 62 | loss = -(one_hot * log_prb).sum(dim=1) 63 | loss = loss.masked_select(non_pad_mask).mean() # average later 64 | else: 65 | logit = logit[..., :-1, :].contiguous().view(-1, logit.size(-1)) 66 | labels = target[..., 1:].contiguous().view(-1) 67 | loss = F.cross_entropy(logit, labels, ignore_index=pad_idx) 68 | return loss 69 | 70 | 71 | def calculate_acc(logit, labels, ignore_index): 72 | logit = logit[..., :-1, :].contiguous().view(-1, logit.size(-1)) 73 | labels = labels[..., 1:].contiguous().view(-1) 74 | _, logit = logit.max(dim=-1) # 
对于每条数据,返回最大的index 75 | # 进行非运算,返回一个tensor,若labels的第i个位置为pad_id,则置为0,否则为1 76 | non_pad_mask = labels.ne(ignore_index) 77 | n_correct = logit.eq(labels).masked_select(non_pad_mask).sum().item() 78 | n_word = non_pad_mask.sum().item() 79 | acc = n_correct / n_word 80 | return acc 81 | 82 | 83 | if __name__ == '__main__': 84 | args = set_args() 85 | os.makedirs(args.output_dir, exist_ok=True) 86 | tokenizer = BertTokenizer.from_pretrained(args.pretrain_model_path) # 词表 87 | 88 | # 自定义加了一个起始和终止字符 89 | num_added_toks = tokenizer.add_special_tokens({'bos_token': '[BOS]', 'eos_token': '[EOS]'}) 90 | print("We have added", num_added_toks, "tokens") 91 | # 加载GPT2模型 92 | model = get_model(args.pretrain_model_path) 93 | model.resize_token_embeddings(len(tokenizer)) # 重置embed matrix 94 | 95 | num_parameters = 0 96 | parameters = model.parameters() 97 | for parameter in parameters: 98 | num_parameters += parameter.numel() 99 | print('total parameters nums:', num_parameters) 100 | # 加载训练集 如果有预处理数据 直接加载 否则 直接预处理 101 | try: 102 | train_data_list = json.load(open(args.train_data_path_processed, 'r', encoding='utf8')) 103 | except: 104 | train_data_list = load_data(args.train_data_path, tokenizer, args.train_data_path_processed) 105 | 106 | train_dataset = GPT2Dataset(train_data_list, args.max_len) 107 | 108 | train_dataloader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, collate_fn=collate_fn) 109 | 110 | # 计算所有epoch进行参数优化的总步数total_steps 111 | total_steps = int(len(train_dataset) * args.epochs / args.batch_size / args.gradient_accumulation) 112 | print("total train step num: {}".format(total_steps)) 113 | 114 | # 设置优化器,并且在初始训练时,使用warmup策略 115 | optimizer = AdamW(model.parameters(), lr=args.learning_rate) 116 | scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=10000, 117 | num_training_steps=total_steps) 118 | 119 | print("start training ...") 120 | model.train() 121 | for epoch in range(args.epochs): 122 | epoch_loss, epoch_acc = 0, 0 123 | epoch_start_time = datetime.now() 124 | for step, input_ids in enumerate(train_dataloader): 125 | if torch.cuda.is_available(): 126 | input_ids = input_ids.cuda() 127 | outputs = model(input_ids=input_ids) 128 | 129 | logits = outputs.logits 130 | loss = caculate_loss(logits, input_ids, tokenizer.pad_token_id, smoothing=True) 131 | accuracy = calculate_acc(logits, input_ids, ignore_index=tokenizer.pad_token_id) 132 | 133 | if args.gradient_accumulation > 1: 134 | loss = loss / args.gradient_accumulation 135 | accuracy = accuracy / args.gradient_accumulation 136 | loss.backward() 137 | print('epoch:{}, step:{}, loss:{:10f}, accuracy:{:10f}, lr:{:10f}'.format(epoch, step, loss, accuracy, scheduler.get_last_lr()[0])) 138 | 139 | # 梯度裁剪 一般梯度裁剪尽量别用 会影响效果 140 | # torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm) 141 | epoch_loss += loss.item() 142 | epoch_acc += accuracy 143 | 144 | if (step + 1) % args.gradient_accumulation == 0: 145 | optimizer.step() 146 | optimizer.zero_grad() 147 | scheduler.step() # 进行warm_up 148 | 149 | avg_loss = epoch_loss / len(train_dataloader) 150 | avg_acc = epoch_acc / len(train_dataloader) 151 | 152 | ss = 'epoch:{}, loss:{:10f}, accuracy:{:10f}'.format(epoch, avg_loss, avg_acc) 153 | loss_path = os.path.join(args.output_dir, 'logs.txt') 154 | with open(loss_path, 'a+', encoding='utf8') as f: 155 | f.write(ss + '\n') 156 | 157 | # 一个epoch跑完保存一下模型 158 | model_save_path = os.path.join(args.output_dir, 'model_epoch_{}'.format(epoch + 1)) 159 | if not os.path.exists(model_save_path): 
160 | os.mkdir(model_save_path) 161 | model_to_save = model.module if hasattr(model, 'module') else model 162 | model_to_save.save_pretrained(model_save_path) 163 | tokenizer.save_pretrained(model_save_path) 164 | epoch_finish_time = datetime.now() 165 | print('per epoch cost time: {}'.format(epoch_finish_time - epoch_start_time)) 166 | -------------------------------------------------------------------------------- /MT5/config.py: -------------------------------------------------------------------------------- 1 | """ 2 | @file : config.py 3 | @author : xiaolu 4 | @email : luxiaonlp@163.com 5 | @time : 2022-08-04 6 | """ 7 | import argparse 8 | 9 | 10 | def set_args(): 11 | parser = argparse.ArgumentParser('--小作文生成') 12 | parser.add_argument('--batch_size', default=8, type=int, help='批次大小') 13 | parser.add_argument('--pretrain_model_path', default='./t5_pretrain', type=str, help='预训练模型') 14 | parser.add_argument('--data_path', default='../data/article.json', type=str, help='处理过的训练数据') 15 | parser.add_argument('--epochs', default=10, type=int, required=False, help='训练的轮次') 16 | parser.add_argument('--gradient_accumulation', default=1, type=int, required=False, help='梯度积累') 17 | parser.add_argument('--learning_rate', default=2e-4, type=float, required=False, help='学习率') 18 | parser.add_argument('--output_dir', default='./output', type=str, required=False, help='多少步汇报一次loss') 19 | parser.add_argument('--warmup_steps_rate', default=0.005, type=float, required=False, help='warm up步数占总步数的比例') 20 | return parser.parse_args() 21 | -------------------------------------------------------------------------------- /MT5/data_helper.py: -------------------------------------------------------------------------------- 1 | """ 2 | @file : data_helper.py 3 | @author : xiaolu 4 | @email : luxiaonlp@163.com 5 | @time : 2022-08-04 6 | """ 7 | import torch 8 | import json 9 | import pandas as pd 10 | from torch.utils.data import Dataset 11 | 12 | 13 | def load_data(path): 14 | input_text, target_text = [], [] 15 | with open(path, 'r', encoding='utf8') as f: 16 | lines = f.readlines() 17 | for line in lines: 18 | line = line.strip() 19 | line = json.loads(line) 20 | input_text.append(line['title']) 21 | target_text.append(line['article'][:256]) # 人为截断前256个token 22 | df = pd.DataFrame({'input_text': input_text, 'target_text': target_text}) 23 | return df 24 | 25 | 26 | class MT5Dataset(Dataset): 27 | def __init__(self, data, tokenizer): 28 | self.input_text = data.input_text 29 | self.target_text = data.target_text 30 | self.tokenizer = tokenizer 31 | 32 | def __len__(self): 33 | return len(self.target_text) 34 | 35 | def __getitem__(self, index): 36 | input_ids = self.tokenizer.encode(self.input_text[index]) 37 | output_ids = self.tokenizer.encode(self.target_text[index]) 38 | return {'input_ids': input_ids, 'decoder_input_ids': output_ids, 39 | 'attention_mask': [1] * len(input_ids), 'decoder_attention_mask': [1] * len(output_ids)} 40 | 41 | 42 | def pad_to_maxlen(input_ids, max_len, pad_value=0): 43 | if len(input_ids) >= max_len: 44 | input_ids = input_ids[:max_len] 45 | else: 46 | input_ids = input_ids + [pad_value] * (max_len - len(input_ids)) 47 | return input_ids 48 | 49 | 50 | def collate_fn(batch): 51 | # 将输入和输出进行padding 52 | input_max_len = max([len(d['input_ids']) for d in batch]) 53 | output_max_len = max([len(d['decoder_input_ids']) for d in batch]) 54 | 55 | input_ids, attention_mask, decoder_input_ids, decoder_attention_mask = [], [], [], [] 56 | for item in batch: 57 | 
input_ids.append(pad_to_maxlen(item['input_ids'], max_len=input_max_len)) 58 | attention_mask.append(pad_to_maxlen(item['attention_mask'], max_len=input_max_len)) 59 | 60 | decoder_input_ids.append(pad_to_maxlen(item['decoder_input_ids'], max_len=output_max_len)) 61 | decoder_attention_mask.append(pad_to_maxlen(item['decoder_attention_mask'], max_len=output_max_len)) 62 | 63 | all_input_ids = torch.tensor(input_ids, dtype=torch.long) 64 | all_input_mask = torch.tensor(attention_mask, dtype=torch.long) 65 | all_decoder_input_ids = torch.tensor(decoder_input_ids, dtype=torch.long) 66 | all_decoder_attention_mask = torch.tensor(decoder_attention_mask, dtype=torch.long) 67 | return all_input_ids, all_input_mask, all_decoder_input_ids, all_decoder_attention_mask 68 | 69 | 70 | 71 | -------------------------------------------------------------------------------- /MT5/inference.py: -------------------------------------------------------------------------------- 1 | """ 2 | @file : inference.py 3 | @author : xiaolu 4 | @email : luxiaonlp@163.com 5 | @time : 2022-08-04 6 | """ 7 | from transformers.models.mt5.modeling_mt5 import MT5ForConditionalGeneration 8 | from tokenizer import T5PegasusTokenizer 9 | import pandas as pd 10 | import torch 11 | import os 12 | import re 13 | import json 14 | import requests 15 | from config import set_args 16 | from transformers.models.gpt2 import GPT2LMHeadModel 17 | from transformers import BertTokenizer 18 | import warnings 19 | warnings.filterwarnings("ignore") 20 | 21 | 22 | def greedy_decode(input_ids): 23 | max_length = 300 24 | input_ids = torch.tensor([input_ids], dtype=torch.long) 25 | if torch.cuda.is_available(): 26 | input_ids = input_ids.cuda() 27 | output_greedy = model.generate(input_ids, decoder_start_token_id=tokenizer.cls_token_id, max_length=max_length, do_sample=False, eos_token_id=tokenizer.sep_token_id) 28 | res = tokenizer.decode(output_greedy[0]) 29 | res = [v.replace(' ', '') for v in res] 30 | res = ''.join(res) 31 | res = res.replace('[CLS]', '').replace('[SEP]', '') 32 | 33 | return res 34 | 35 | 36 | def beamsearch_decode(input_ids): 37 | max_length = 300 38 | input_ids = torch.tensor([input_ids], dtype=torch.long) 39 | if torch.cuda.is_available(): 40 | input_ids = input_ids.cuda() 41 | 42 | output_beam = model.generate(input_ids, max_length=max_length, num_beams=3, 43 | do_sample=False, 44 | no_repeat_ngram_size=2, 45 | decoder_start_token_id=tokenizer.cls_token_id, 46 | eos_token_id=tokenizer.sep_token_id) 47 | res = tokenizer.decode(output_beam[0]) 48 | res = [v.replace(' ', '') for v in res] 49 | res = ''.join(res) 50 | res = res.replace('[CLS]', '').replace('[SEP]', '') 51 | return res 52 | 53 | 54 | def greedy_sample_decode(input_ids): 55 | max_length = 300 56 | repetition_penalty = 1.1 57 | temperature = 1 58 | topk = 5 59 | topp = 0.95 60 | input_ids = torch.tensor([input_ids], dtype=torch.long) 61 | if torch.cuda.is_available(): 62 | input_ids = input_ids.cuda() 63 | 64 | output_greedy_random = model.generate(input_ids, max_length=max_length, do_sample=True, 65 | temperature=temperature, top_k=topk, top_p=topp, 66 | repetition_penalty=repetition_penalty, 67 | decoder_start_token_id=tokenizer.cls_token_id, 68 | eos_token_id=tokenizer.sep_token_id) 69 | res = tokenizer.decode(output_greedy_random[0]) 70 | res = [v.replace(' ', '') for v in res] 71 | res = ''.join(res) 72 | res = res.replace('[CLS]', '').replace('[SEP]', '') 73 | return res 74 | 75 | 76 | def beamsearch_sample_decode(input_ids): 77 | max_length = 300 78 | 
repetition_penalty = 1.1 79 | num_beams = 3 80 | temperature = 1 81 | topk = 5 82 | topp = 0.95 83 | input_ids = torch.tensor([input_ids], dtype=torch.long) 84 | if torch.cuda.is_available(): 85 | input_ids = input_ids.cuda() 86 | 87 | output_beamsearch_random = model.generate(input_ids, max_length=max_length, do_sample=True, 88 | num_beams=num_beams, temperature=temperature, top_k=topk, 89 | top_p=topp, repetition_penalty=repetition_penalty, 90 | decoder_start_token_id=tokenizer.cls_token_id, 91 | eos_token_id=tokenizer.sep_token_id, 92 | pad_token_id=tokenizer.pad_token_id) 93 | res = tokenizer.decode(output_beamsearch_random[0]) 94 | res = [v.replace(' ', '') for v in res] 95 | res = ''.join(res) 96 | res = res.replace('[CLS]', '').replace('[SEP]', '') 97 | return res 98 | 99 | 100 | if __name__ == '__main__': 101 | args = set_args() 102 | checkpoint = '/usr/home/xiaolu10/xiaolu10/gpu_task/text_generation/MT5/checkpoint/model_epoch_10' 103 | tokenizer = T5PegasusTokenizer.from_pretrained(checkpoint) 104 | model = MT5ForConditionalGeneration.from_pretrained(checkpoint) 105 | 106 | if torch.cuda.is_available(): 107 | model.cuda() 108 | # model.half() 109 | model.eval() 110 | 111 | title = '我的祖国' 112 | title_ids = tokenizer.encode(title) 113 | 114 | gen_article_greedy = greedy_decode(title_ids) 115 | gen_article_beamsearch = beamsearch_decode(title_ids) 116 | gen_article_greedy_sample = greedy_sample_decode(title_ids) 117 | gen_article_beamsearch_sample = beamsearch_sample_decode(title_ids) 118 | print(gen_article_greedy) 119 | print(gen_article_beamsearch) 120 | print(gen_article_greedy_sample) 121 | print(gen_article_beamsearch_sample) 122 | -------------------------------------------------------------------------------- /MT5/run_train.py: -------------------------------------------------------------------------------- 1 | """ 2 | @file : run_train.py 3 | @author : xiaolu 4 | @email : luxiaonlp@163.com 5 | @time : 2022-08-04 6 | """ 7 | import os 8 | import torch.cuda 9 | from datetime import datetime 10 | from torch.optim import AdamW 11 | from config import set_args 12 | from torch.nn import CrossEntropyLoss 13 | from tokenizer import T5PegasusTokenizer 14 | from torch.utils.data import DataLoader 15 | from data_helper import MT5Dataset, load_data, collate_fn 16 | from transformers import get_linear_schedule_with_warmup 17 | from transformers.models.mt5.modeling_mt5 import MT5ForConditionalGeneration 18 | 19 | 20 | def calc_loss(logits, decoder_input_ids, decoder_attention_mask): 21 | # 计算损失 22 | decoder_mask = decoder_attention_mask[:, 1:].reshape(-1).bool() 23 | logits = logits[:, :-1] 24 | logits = logits.reshape((-1, logits.size(-1)))[decoder_mask] 25 | labels = decoder_input_ids[:, 1:].reshape(-1)[decoder_mask] 26 | loss = loss_fct(logits, labels) 27 | return loss 28 | 29 | 30 | def calculate_acc(logit, labels, ignore_index): 31 | logit = logit[..., :-1, :].contiguous().view(-1, logit.size(-1)) 32 | labels = labels[..., 1:].contiguous().view(-1) 33 | _, logit = logit.max(dim=-1) # 对于每条数据,返回最大的index 34 | # 进行非运算,返回一个tensor,若labels的第i个位置为pad_id,则置为0,否则为1 35 | non_pad_mask = labels.ne(ignore_index) 36 | n_correct = logit.eq(labels).masked_select(non_pad_mask).sum().item() 37 | n_word = non_pad_mask.sum().item() 38 | acc = n_correct / n_word 39 | return acc 40 | 41 | 42 | if __name__ == '__main__': 43 | args = set_args() 44 | tokenizer = T5PegasusTokenizer.from_pretrained(args.pretrain_model_path) 45 | train_df = load_data(args.data_path) 46 | train_dataset = MT5Dataset(train_df, 
tokenizer) 47 | train_dataloader = DataLoader(train_dataset, batch_size=args.batch_size, collate_fn=collate_fn) 48 | 49 | model = MT5ForConditionalGeneration.from_pretrained(args.pretrain_model_path) 50 | 51 | num_parameters = 0 52 | parameters = model.parameters() 53 | for parameter in parameters: 54 | num_parameters += parameter.numel() 55 | print('total parameters nums:', num_parameters) 56 | 57 | if torch.cuda.is_available(): 58 | model.cuda() 59 | 60 | total_steps = int(len(train_dataset) * args.epochs / args.batch_size / args.gradient_accumulation) 61 | optimizer = AdamW(model.parameters(), lr=args.learning_rate) 62 | scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=args.warmup_steps_rate * total_steps, 63 | num_training_steps=total_steps) 64 | 65 | loss_fct = CrossEntropyLoss(ignore_index=0) 66 | 67 | for epoch in range(args.epochs): 68 | model.train() 69 | epoch_start_time = datetime.now() 70 | 71 | epoch_loss, epoch_acc = 0, 0 72 | for step, batch in enumerate(train_dataloader): 73 | if torch.cuda.is_available(): 74 | batch = (t.cuda() for t in batch) 75 | input_ids, input_mask, decoder_input_ids, decoder_attention_mask = batch 76 | output = model(input_ids=input_ids, attention_mask=input_mask, decoder_input_ids=decoder_input_ids, decoder_attention_mask=decoder_attention_mask) 77 | logits = output.logits 78 | 79 | loss = calc_loss(logits, decoder_input_ids, decoder_attention_mask) 80 | accuracy = calculate_acc(logits, decoder_input_ids, ignore_index=tokenizer.pad_token_id) 81 | 82 | loss.backward() 83 | print('epoch:{}, step:{}, loss:{:10f}, accuracy:{:10f}, lr:{:10f}'.format(epoch, step, loss, accuracy, scheduler.get_last_lr()[0])) 84 | # torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm) 85 | epoch_loss += loss.item() 86 | epoch_acc += accuracy 87 | 88 | if (step + 1) % args.gradient_accumulation == 0: 89 | optimizer.step() 90 | optimizer.zero_grad() 91 | scheduler.step() # 进行warm_up 92 | 93 | avg_loss = epoch_loss / len(train_dataloader) 94 | avg_acc = epoch_acc / len(train_dataloader) 95 | 96 | ss = 'epoch:{}, loss:{:10f}, accuracy:{:10f}'.format(epoch, avg_loss, avg_acc) 97 | loss_path = os.path.join(args.output_dir, 'loss.txt'.format(epoch + 1)) 98 | with open(loss_path, 'a+', encoding='utf8') as f: 99 | f.write(ss + '\n') 100 | 101 | # 一个epoch跑完保存一下模型 102 | model_save_path = os.path.join(args.output_dir, 'model_epoch_{}'.format(epoch + 1)) 103 | if not os.path.exists(model_save_path): 104 | os.mkdir(model_save_path) 105 | model_to_save = model.module if hasattr(model, 'module') else model 106 | model_to_save.save_pretrained(model_save_path) 107 | tokenizer.save_pretrained(model_save_path) 108 | epoch_finish_time = datetime.now() 109 | print('per epoch cost time: {}'.format(epoch_finish_time - epoch_start_time)) 110 | -------------------------------------------------------------------------------- /MT5/tokenizer.py: -------------------------------------------------------------------------------- 1 | """ 2 | @file : tokenizer.py 3 | @author : xiaolu 4 | @email : luxiaonlp@163.com 5 | @time : 2022-08-04 6 | """ 7 | import jieba 8 | from transformers import BertTokenizer 9 | 10 | 11 | class T5PegasusTokenizer(BertTokenizer): 12 | def __init__(self, pre_tokenizer=lambda x: jieba.cut(x, HMM=False), *args, **kwargs): 13 | super().__init__(*args, **kwargs) 14 | self.pre_tokenizer = pre_tokenizer 15 | 16 | def _tokenize(self, text, *arg, **kwargs): 17 | split_tokens = [] 18 | for text in self.pre_tokenizer(text): 19 | if text in self.vocab: 
20 | split_tokens.append(text) 21 | else: 22 | split_tokens.extend(super()._tokenize(text)) 23 | return split_tokens -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Text-Generation-Chinese-Pytorch 2 | 3 | ## 数据集描述 4 | 数据集下载路径: 链接: https://pan.baidu.com/s/1qBlgAvGHYEbjmyawcW1zcA?pwd=g7jf 提取码: g7jf 5 | 6 | 数据集由https://github.com/yangjianxin1/CPM 该开源项目提供。 总共包含269370条数据。每条数据包含两部分: 作文题目、作文内容。感谢提供。 7 | 格式: 8 | ``` 9 | {"id": 1, "title": "妈妈我想对你说", "article": "“世上只有妈妈好,有妈的孩子想块宝,投进妈妈的怀抱,幸福享不了……”我们经常哼唱“世上只有妈妈好”这首歌。是的,当我生病时,是谁在我的身边守候我?当我遇到困难时,是谁在我身边支持我?当我伤心难过时又是谁在我身边安慰我?那就是妈妈。每次都是妈妈在我身边给予我无尽的爱,就像海滩上的沙粒一样。我真想对我那亲爱的妈妈说一声:“感谢你,我亲爱的妈妈。”\n妈妈的爱,无论是问候、关心、还是训斥,都是那么亲切、那么温暖。那么您知道吗,女儿也有许多话想对您说……\n妈妈,您真的好累。我见到的总是您忙碌的身影,忙农活,忙家务。每次我想来帮助您,您总说:“凤宝,不用你来做,多去学习、去看书。”但是亲爱的妈妈,您歇歇吧。每当看到您头上又添了银发,看到,看到您额头上的皱纹,女儿的心总会很心疼,这是岁月的痕迹,更是操劳的印记。有时忙农活时,您会不小心地弄伤自己,浑身是伤,女儿看着很心酸。妈妈,您歇歇吧,不要这么累了,我长大了,可以帮助您了。我会在回家时帮您做家务,我会在您工作时送上杯热茶,我会把自己的学习搞好,不再让您操心,我会时刻体谅您的,妈妈……\n妈妈,谢谢您给了我生命,让我来到这个世界,感受着这世界的美好。\n妈妈,谢谢您含辛茹苦地培养我长大,并让我无忧无虑地成长;谢谢您为我做的每一口香甜的饭菜;谢谢您在我每一次生病时给我的呵护和关怀;谢谢您每次下雨时给我送伞;谢谢您在我成绩不理想或受到挫折时,总是给我一次次的鼓励。\n妈妈我想对您说:“你不要太辛苦,一定要注意身体。”\n妈妈我想对您说:“我懂事了,我不会再惹您生气了。”\n"} 10 | {"id": 2, "title": "徜徉在书籍的阳光世界", "article": "一本书是一个人的眼睛,它可以让你看到另一个世界的奇妙;一本书是一个人的耳朵,它可以让你听到大自然的呼唤,听到社会的声音。\n《森林报》是苏联著名科普作家维。比安基的代表作品,他以春夏秋冬四季为序,有层次、有类别地将森林里动植物的新鲜事描写得栩栩如生,引人入胜。这本书教会我们如何去观察、认识大自然。这本书教会我们感悟生命,体验生动的愉快探速之旅,激发我们热爱科学的兴趣。\n《三字经》、《弟子规》、《论语》这样的国学经典,我也会拿来阅读,虽然似懂非懂,但读起来朗朗上口,觉得挺有趣。读着读着,好似开始了一场时空之旅,与古代圣贤结为知己,进行心与心之间的倾听与问候。这些书籍让我们在阅读中品味高雅。\n在成长的过程中,每个人都有着自己不一样的心路历程。阳光姐姐写的《成长的秘密》一书让我们获益不浅。作者用简单生动的文字,把温馨感人、新鲜快乐、爆笑的校园生活展现在我们眼前。书中的人物宁佳鑫看上去弱小,但她实际却很坚强,在她身上,我看到了她散发出的正能量和她在逆境中奋起的精神。她的经历告诉我:无论遇到什么样的挫折与坎坷,都不要气馁。阳光总在风雨后,只要我们坚持不懈地去想办法克服困难,并付诸行动,就一定会柳暗花明!\n法国作家德尔伦曾说过“智慧可以转化成力量,完成你认为不可能完成的事。”是啊,智慧的力量很强大,这些力量隐藏在书中。当我们在阅读之际,这些知识就偷偷地跑进我们的脑海里,渐渐地,渐渐地,它们就永远地保存下来,显示出无穷的魅力,让我们的未来畅通无阻。\n书籍,用爱和勇气唤醒每个孩子的心灵;书籍让我们感受到温暖与力量;书籍,教我们用心灵在文字间快乐舞蹈。\n让我们走进书籍的阳光世界,获取成长的力量。\n"} 11 | ``` 12 | 13 | ## GPT2 14 | 15 | 16 | ## CPM 17 | 18 | 19 | ## MT5 20 | 预训练模型下载: https://huggingface.co/imxly/t5-pegasus 21 | 22 | ## UniLM 23 | 24 | -------------------------------------------------------------------------------- /UniLM/bert_model.py: -------------------------------------------------------------------------------- 1 | """ 2 | @file : bert_model.py 3 | @author : xiaolu 4 | @email : luxiaonlp@163.com 5 | @time : 2022-02-21 6 | """ 7 | import math 8 | import torch 9 | from torch import nn 10 | # self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) 11 | 12 | 13 | def swish(x): 14 | return x * torch.sigmoid(x) 15 | 16 | 17 | def gelu(x): 18 | """ 19 | """ 20 | return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0))) 21 | 22 | 23 | def mish(x): 24 | return x * torch.tanh(nn.functional.softplus(x)) 25 | 26 | 27 | ACT2FN = {"gelu": gelu, "relu": torch.nn.functional.relu, "swish": swish, "mish": mish} 28 | 29 | 30 | class BertConfig(object): 31 | def __init__( 32 | self, 33 | vocab_size, 34 | hidden_size=768, 35 | num_hidden_layers=12, 36 | num_attention_heads=12, 37 | intermediate_size=3072, 38 | hidden_act="gelu", 39 | hidden_dropout_prob=0.1, 40 | attention_probs_dropout_prob=0.1, 41 | max_position_embeddings=512, 42 | type_vocab_size=2, 43 | initializer_range=0.02, 44 | layer_norm_eps=1e-12, 45 | ): 46 | self.vocab_size = vocab_size 47 | self.hidden_size = hidden_size 48 | self.num_hidden_layers = num_hidden_layers 49 | self.num_attention_heads = num_attention_heads 50 | 
self.hidden_act = hidden_act 51 | self.intermediate_size = intermediate_size 52 | self.hidden_dropout_prob = hidden_dropout_prob 53 | self.attention_probs_dropout_prob = attention_probs_dropout_prob 54 | self.max_position_embeddings = max_position_embeddings 55 | self.type_vocab_size = type_vocab_size 56 | self.initializer_range = initializer_range 57 | self.layer_norm_eps = layer_norm_eps 58 | 59 | 60 | class BertEmbeddings(nn.Module): 61 | """ 62 | Construct the embeddings from word, position and token_type embeddings. 63 | """ 64 | 65 | def __init__(self, config): 66 | super().__init__() 67 | self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=0) 68 | self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size) 69 | self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size) 70 | 71 | self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) 72 | self.dropout = nn.Dropout(config.hidden_dropout_prob) 73 | 74 | def forward(self, input_ids=None, token_type_ids=None, position_ids=None): 75 | # if input_ids is not None: 76 | # input_shape = input_ids.size() 77 | # else: 78 | # input_shape = inputs_embeds.size()[:-1] 79 | input_shape = input_ids.size() 80 | 81 | seq_length = input_shape[1] 82 | device = input_ids.device 83 | if position_ids is None: 84 | position_ids = torch.arange(seq_length, dtype=torch.long, device=device) 85 | position_ids = position_ids.unsqueeze(0).expand(input_shape) 86 | if token_type_ids is None: 87 | token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device) 88 | 89 | inputs_embeds = self.word_embeddings(input_ids) 90 | position_embeddings = self.position_embeddings(position_ids) 91 | token_type_embeddings = self.token_type_embeddings(token_type_ids) 92 | 93 | embeddings = inputs_embeds + position_embeddings + token_type_embeddings 94 | embeddings = self.LayerNorm(embeddings) 95 | embeddings = self.dropout(embeddings) 96 | return embeddings 97 | 98 | 99 | class BertSelfAttention(nn.Module): 100 | def __init__(self, config: BertConfig): 101 | super().__init__() 102 | if config.hidden_size % config.num_attention_heads != 0: 103 | raise ValueError( 104 | "The hidden size (%d) is not a multiple of the number of attention " 105 | "heads (%d)" % (config.hidden_size, config.num_attention_heads) 106 | ) 107 | 108 | self.num_attention_heads = config.num_attention_heads 109 | self.attention_head_size = int(config.hidden_size / config.num_attention_heads) 110 | self.all_head_size = self.num_attention_heads * self.attention_head_size 111 | 112 | self.query = nn.Linear(config.hidden_size, self.all_head_size) 113 | self.key = nn.Linear(config.hidden_size, self.all_head_size) 114 | self.value = nn.Linear(config.hidden_size, self.all_head_size) 115 | 116 | self.dropout = nn.Dropout(config.attention_probs_dropout_prob) 117 | 118 | def transpose_for_scores(self, x): 119 | new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) 120 | x = x.view(*new_x_shape) 121 | # 最后xshape (batch_size, num_attention_heads, seq_len, head_size) 122 | return x.permute(0, 2, 1, 3) 123 | 124 | def forward( 125 | self, 126 | hidden_states, 127 | attention_mask, 128 | output_attentions=False 129 | ): 130 | mixed_query_layer = self.query(hidden_states) 131 | mixed_key_layer = self.key(hidden_states) 132 | mixed_value_layer = self.value(hidden_states) 133 | 134 | query_layer = self.transpose_for_scores(mixed_query_layer) 135 | key_layer = 
self.transpose_for_scores(mixed_key_layer) 136 | value_layer = self.transpose_for_scores(mixed_value_layer) 137 | 138 | # Take the dot product between "query" and "key" to get the raw attention scores. 139 | attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) 140 | attention_scores = attention_scores / math.sqrt(self.attention_head_size) 141 | 142 | # Apply the attention mask is (precomputed for all layers in BertModel forward() function) 143 | attention_scores = attention_scores + attention_mask 144 | 145 | # Normalize the attention scores to probabilities. 146 | attention_probs = nn.Softmax(dim=-1)(attention_scores) 147 | 148 | # This is actually dropping out entire tokens to attend to, which might 149 | # seem a bit unusual, but is taken from the original Transformer paper. 150 | attention_probs = self.dropout(attention_probs) 151 | 152 | # 注意力加权 153 | context_layer = torch.matmul(attention_probs, value_layer) 154 | # 把加权后的V reshape, 得到[batch_size, length, embedding_dimension] 155 | context_layer = context_layer.permute(0, 2, 1, 3).contiguous() 156 | new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) 157 | context_layer = context_layer.view(*new_context_layer_shape) 158 | 159 | # 得到输出 160 | if output_attentions: 161 | return context_layer, attention_probs 162 | return context_layer, None 163 | 164 | 165 | class BertSelfOutput(nn.Module): 166 | def __init__(self, config): 167 | super().__init__() 168 | self.dense = nn.Linear(config.hidden_size, config.hidden_size) 169 | self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) 170 | self.dropout = nn.Dropout(config.hidden_dropout_prob) 171 | 172 | def forward(self, hidden_states, input_tensor): 173 | hidden_states = self.dense(hidden_states) 174 | hidden_states = self.dropout(hidden_states) 175 | hidden_states = self.LayerNorm(hidden_states + input_tensor) 176 | return hidden_states 177 | 178 | 179 | class BertAttention(nn.Module): 180 | def __init__(self, config): 181 | super().__init__() 182 | self.self = BertSelfAttention(config) 183 | self.output = BertSelfOutput(config) 184 | 185 | def forward( 186 | self, 187 | hidden_states, 188 | attention_mask, 189 | output_attentions=False 190 | ): 191 | self_outputs, attention_metrix = self.self(hidden_states, attention_mask, output_attentions=output_attentions) 192 | attention_output = self.output(self_outputs, hidden_states) 193 | 194 | return attention_output, attention_metrix 195 | 196 | 197 | class BertIntermediate(nn.Module): 198 | def __init__(self, config): 199 | super().__init__() 200 | self.dense = nn.Linear(config.hidden_size, config.intermediate_size) 201 | self.intermediate_act_fn = ACT2FN[config.hidden_act] 202 | 203 | def forward(self, hidden_states): 204 | hidden_states = self.dense(hidden_states) 205 | hidden_states = self.intermediate_act_fn(hidden_states) 206 | return hidden_states 207 | 208 | 209 | class BertOutput(nn.Module): 210 | def __init__(self, config): 211 | super().__init__() 212 | self.dense = nn.Linear(config.intermediate_size, config.hidden_size) 213 | self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) 214 | self.dropout = nn.Dropout(config.hidden_dropout_prob) 215 | 216 | def forward(self, hidden_states, input_tensor): 217 | hidden_states = self.dense(hidden_states) 218 | hidden_states = self.dropout(hidden_states) 219 | hidden_states = self.LayerNorm(hidden_states + input_tensor) 220 | return hidden_states 221 | 222 | 223 | class BertLayer(nn.Module): 224 | def 
__init__(self, config): 225 | super().__init__() 226 | self.attention = BertAttention(config) 227 | self.intermediate = BertIntermediate(config) 228 | self.output = BertOutput(config) 229 | 230 | def forward( 231 | self, 232 | hidden_states, 233 | attention_mask, 234 | output_attentions=False 235 | ): 236 | attention_output, attention_matrix = self.attention(hidden_states, attention_mask, 237 | output_attentions=output_attentions) 238 | intermediate_output = self.intermediate(attention_output) 239 | layer_output = self.output(intermediate_output, attention_output) 240 | return layer_output, attention_matrix 241 | 242 | 243 | class BertEncoder(nn.Module): 244 | def __init__(self, config): 245 | super().__init__() 246 | self.layer = nn.ModuleList([BertLayer(config) for _ in range(config.num_hidden_layers)]) 247 | 248 | def forward( 249 | self, 250 | hidden_states, 251 | attention_mask, 252 | output_all_encoded_layers=True, 253 | output_attentions=False 254 | ): 255 | all_encoder_layers = [] 256 | all_attention_matrices = [] 257 | for i, layer_module in enumerate(self.layer): 258 | 259 | layer_output, attention_matrix = layer_module( 260 | hidden_states, attention_mask, output_attentions=output_attentions 261 | ) 262 | hidden_states = layer_output 263 | if output_all_encoded_layers: 264 | all_encoder_layers.append(hidden_states) 265 | all_attention_matrices.append(attention_matrix) 266 | if not output_all_encoded_layers: 267 | all_encoder_layers.append(hidden_states) 268 | all_attention_matrices.append(attention_matrix) 269 | 270 | return all_encoder_layers, all_attention_matrices 271 | 272 | 273 | class BertPooler(nn.Module): 274 | def __init__(self, config): 275 | super().__init__() 276 | self.dense = nn.Linear(config.hidden_size, config.hidden_size) 277 | self.activation = nn.Tanh() 278 | 279 | def forward(self, hidden_states): 280 | # We "pool" the model by simply taking the hidden state corresponding 281 | # to the first token. 282 | first_token_tensor = hidden_states[:, 0] 283 | pooled_output = self.dense(first_token_tensor) 284 | pooled_output = self.activation(pooled_output) 285 | return pooled_output 286 | 287 | 288 | class BertPredictionHeadTransform(nn.Module): 289 | def __init__(self, config): 290 | super().__init__() 291 | self.dense = nn.Linear(config.hidden_size, config.hidden_size) 292 | self.transform_act_fn = ACT2FN[config.hidden_act] 293 | self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) 294 | 295 | def forward(self, hidden_states): 296 | hidden_states = self.dense(hidden_states) 297 | hidden_states = self.transform_act_fn(hidden_states) 298 | hidden_states = self.LayerNorm(hidden_states) 299 | return hidden_states 300 | 301 | 302 | class BertLMPredictionHead(nn.Module): 303 | def __init__(self, config, bert_model_embedding_weights): 304 | super().__init__() 305 | self.transform = BertPredictionHeadTransform(config) 306 | 307 | # The output weights are the same as the input embeddings, but there is 308 | # an output-only bias for each token. 
309 | self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False) 310 | 311 | self.decoder.weight = bert_model_embedding_weights 312 | self.bias = nn.Parameter(torch.zeros(config.vocab_size)) 313 | 314 | # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings` 315 | self.decoder.bias = self.bias 316 | 317 | def forward(self, hidden_states): 318 | hidden_states = self.transform(hidden_states) 319 | hidden_states = self.decoder(hidden_states) 320 | return hidden_states 321 | 322 | 323 | class BertOnlyMLMHead(nn.Module): 324 | def __init__(self, config): 325 | super().__init__() 326 | self.predictions = BertLMPredictionHead(config) 327 | 328 | def forward(self, sequence_output): 329 | prediction_scores = self.predictions(sequence_output) 330 | return prediction_scores 331 | 332 | 333 | class BertOnlyNSPHead(nn.Module): 334 | def __init__(self, config): 335 | super().__init__() 336 | self.seq_relationship = nn.Linear(config.hidden_size, 2) 337 | 338 | def forward(self, pooled_output): 339 | seq_relationship_score = self.seq_relationship(pooled_output) 340 | return seq_relationship_score 341 | 342 | 343 | class BertPreTrainingHeads(nn.Module): 344 | def __init__(self, config): 345 | super().__init__() 346 | self.predictions = BertLMPredictionHead(config) 347 | self.seq_relationship = nn.Linear(config.hidden_size, 2) 348 | 349 | def forward(self, sequence_output, pooled_output): 350 | prediction_scores = self.predictions(sequence_output) 351 | seq_relationship_score = self.seq_relationship(pooled_output) 352 | return prediction_scores, seq_relationship_score 353 | 354 | 355 | class BertPreTrainedModel(nn.Module): 356 | """ An abstract class to handle weights initialization and 357 | a simple interface for downloading and loading pretrained models. 358 | """ 359 | 360 | def __init__(self, config, *inputs, **kwargs): 361 | super(BertPreTrainedModel, self).__init__() 362 | if not isinstance(config, BertConfig): 363 | raise ValueError( 364 | "Parameter config in `{}(config)` should be an instance of class `BertConfig`. " 365 | "To create a model from a Google pretrained model use " 366 | "`model = {}.from_pretrained(PRETRAINED_MODEL_NAME)`".format( 367 | self.__class__.__name__, self.__class__.__name__ 368 | )) 369 | self.config = config 370 | 371 | def init_bert_weights(self, module): 372 | """ Initialize the weights. 373 | """ 374 | if isinstance(module, (nn.Linear)): 375 | # 初始线性映射层的参数为正态分布 376 | module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) 377 | if isinstance(module, nn.Linear) and module.bias is not None: 378 | # 初始化偏置为0 379 | module.bias.data.zero_() 380 | 381 | 382 | class BertModel(BertPreTrainedModel): 383 | """ 384 | The model can behave as an encoder (with only self-attention) as well 385 | as a decoder, in which case a layer of cross-attention is added between 386 | the self-attention layers, following the architecture described in `Attention is all you need`_ by Ashish Vaswani, 387 | Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin. 388 | To behave as an decoder the model needs to be initialized with the 389 | :obj:`is_decoder` argument of the configuration set to :obj:`True`; an 390 | :obj:`encoder_hidden_states` is expected as an input to the forward pass. 391 | .. 
_`Attention is all you need`: 392 | https://arxiv.org/abs/1706.03762 393 | """ 394 | 395 | def __init__(self, config): 396 | super().__init__(config) 397 | self.config = config 398 | 399 | self.embeddings = BertEmbeddings(config) 400 | self.encoder = BertEncoder(config) 401 | self.pooler = BertPooler(config) 402 | 403 | self.apply(self.init_bert_weights) 404 | 405 | def forward( 406 | self, 407 | input_ids, 408 | attention_mask=None, 409 | token_type_ids=None, 410 | position_ids=None, 411 | output_all_encoded_layers=True, 412 | output_attentions=False 413 | ): 414 | 415 | extended_attention_mask = (input_ids > 0).float() 416 | # 注意力矩阵mask: [batch_size, 1, 1, seq_length] 417 | extended_attention_mask = extended_attention_mask.unsqueeze(1).unsqueeze(2) 418 | # print(extended_attention_mask.size()) # torch.Size([2, 1, 1, 44]) 419 | 420 | if attention_mask is not None: 421 | # 如果传进来的注意力mask不是null,那就直接用传进来的注意力mask 乘 原始mask 422 | # 注意 原始mask是extended_attention_mask,这个是用来把pad部分置为0,去掉pad部分影响 423 | extended_attention_mask = attention_mask * extended_attention_mask 424 | # print(extended_attention_mask.size()) # torch.Size([2, 1, 25, 25]) 425 | 426 | if token_type_ids is None: 427 | token_type_ids = torch.zeros_like(input_ids) 428 | 429 | # extended_attention_mask = extended_attention_mask.to(dtype=next(self.parameters()).dtype) # fp16 compatibility 430 | extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0 431 | 432 | embedding_output = self.embeddings( 433 | input_ids=input_ids, position_ids=position_ids, token_type_ids=token_type_ids 434 | ) 435 | encoder_layers, all_attention_matrices = self.encoder( 436 | embedding_output, 437 | attention_mask=extended_attention_mask, 438 | output_all_encoded_layers=output_all_encoded_layers, 439 | output_attentions=output_attentions 440 | ) 441 | sequence_output = encoder_layers[-1] 442 | pooled_output = self.pooler(sequence_output) 443 | 444 | if output_attentions: 445 | return all_attention_matrices 446 | if not output_all_encoded_layers: 447 | # 如果不用输出所有encoder层 448 | encoder_layers = encoder_layers[-1] 449 | return encoder_layers, pooled_output 450 | -------------------------------------------------------------------------------- /UniLM/config.py: -------------------------------------------------------------------------------- 1 | """ 2 | @file : config.py 3 | @author : xiaolu 4 | @email : luxiaonlp@163.com 5 | @time : 2022-08-05 6 | """ 7 | import argparse 8 | 9 | 10 | def set_args(): 11 | parser = argparse.ArgumentParser('--小作文生成') 12 | parser.add_argument('--corpus_path', default='../data/article.json', type=str, help='训练数据') 13 | parser.add_argument('--bert_vocab_path', default='./roberta_pretrain/vocab.txt', type=str, help='bert的词表') 14 | parser.add_argument('--bert_pretrain_weight_path', default='./roberta_pretrain/pytorch_model.bin', type=str, help='unilm的权重') 15 | 16 | parser.add_argument('--output_dir', default='./outputs', type=str, help='模型的输出') 17 | parser.add_argument('--batch_size', default=16, type=int, help='训练批次的大小') 18 | parser.add_argument('--learning_rate', default=1e-4, type=float, help='学习率的大小') 19 | parser.add_argument('--warmup_proportion', default=0.01, type=float, help='warm up概率,即训练总步长的百分之多少,进行warm up') 20 | parser.add_argument('--gradient_accumulation_steps', default=1, type=int, help='梯度积累') 21 | parser.add_argument('--adam_epsilon', default=1e-8, type=float, help='Adam优化器的epsilon值') 22 | parser.add_argument('--max_length', default=300, type=int, help='最大长度') 23 | parser.add_argument('--num_train_epochs', 
default=10, type=int, help='训练几轮') 24 | return parser.parse_args() 25 | -------------------------------------------------------------------------------- /UniLM/data_helper.py: -------------------------------------------------------------------------------- 1 | """ 2 | @file : data_helper.py 3 | @author : xiaolu 4 | @email : luxiaonlp@163.com 5 | @time : 2022-08-05 6 | """ 7 | import torch 8 | from torch.utils.data.dataset import Dataset 9 | from tokenizer import load_bert_vocab, Tokenizer 10 | import json 11 | 12 | 13 | def load_data(data_path): 14 | sents_src = [] 15 | sents_tgt = [] 16 | with open(data_path, 'r', encoding='utf8') as f: 17 | lines = f.readlines() 18 | for line in lines: 19 | line = line.strip() 20 | line = json.loads(line) 21 | title = line['title'] 22 | article = line['article'][:256] # 直接截断到256 23 | sents_src.append(title) 24 | sents_tgt.append(article) 25 | return sents_src, sents_tgt 26 | 27 | 28 | class UniLMDataset(Dataset): 29 | def __init__(self, data_path, bert_vocab_path): 30 | super(UniLMDataset, self).__init__() 31 | self.sents_src, self.sents_tgt = load_data(data_path) 32 | self.word2idx = load_bert_vocab(bert_vocab_path) 33 | self.idx2word = {k: v for v, k in self.word2idx.items()} 34 | self.tokenizer = Tokenizer(self.word2idx) 35 | 36 | def __getitem__(self, i): 37 | # 得到单个数据 38 | src = self.sents_src[i] 39 | tgt = self.sents_tgt[i] 40 | 41 | token_ids, token_type_ids = self.tokenizer.encode(src, tgt) 42 | output = { 43 | "token_ids": token_ids, 44 | "token_type_ids": token_type_ids, 45 | } 46 | return output 47 | 48 | def __len__(self): 49 | return len(self.sents_src) 50 | 51 | 52 | def padding(indice, max_length, pad_idx=0): 53 | temp = [] 54 | for item in indice: 55 | if len(item) >= max_length: 56 | item = item[:max_length] 57 | temp.append(item) 58 | else: 59 | item = item + [pad_idx] * (max_length - len(item)) 60 | temp.append(item) 61 | return torch.tensor(temp) 62 | 63 | 64 | def collate_fn(batch): 65 | token_ids = [data["token_ids"] for data in batch] 66 | max_length = max([len(t) for t in token_ids]) # 计算当前batch的最大长度 67 | if max_length > 300: 68 | max_length = 300 69 | token_type_ids = [data["token_type_ids"] for data in batch] 70 | token_ids_padded = padding(token_ids, max_length) 71 | token_type_ids_padded = padding(token_type_ids, max_length) 72 | target_ids_padded = token_ids_padded[:, 1:].contiguous() 73 | return token_ids_padded, token_type_ids_padded, target_ids_padded 74 | -------------------------------------------------------------------------------- /UniLM/inference.py: -------------------------------------------------------------------------------- 1 | """ 2 | @file : inference.py 3 | @author : xiaolu 4 | @email : luxiaonlp@163.com 5 | @time : 2022-02-22 6 | """ 7 | import torch 8 | from model import Model 9 | from config import set_args 10 | import torch.nn.functional as F 11 | from bert_model import BertConfig 12 | from tokenizer import load_bert_vocab, Tokenizer 13 | 14 | 15 | def top_k_top_p_filtering(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')): 16 | assert logits.dim() == 1 # 注意 这里只能支持单样本的处理 17 | top_k = min(top_k, logits.size(-1)) # 也算是一个检查吧,logits.size(-1)说明你有多少个vocab, 如果你top_k都大于vocab_num, 那topk还有意义吗,所以选最小的 18 | 19 | if top_k > 0: 20 | indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None] # 首先得到topk个最大概率中的最小概率。 然后小于这个最小概率的,都是我们不在我们采样的序列中 21 | logits[indices_to_remove] = filter_value # 将这些不在采样序列中的token的概率置成负无穷 22 | 23 | if top_p > 0.0: 24 | sorted_logits, sorted_indices = torch.sort(logits, 
descending=True) # 对logits递减排序 返回两个东西: 1. 排序后的概率序列 2. 排序后每个概率原先在的位置 25 | cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1) # 对排序后的概率softmax,然后累加 26 | # 举个例子说明上面torch.cumsum()函数 假设有序列 [5, 4, 2, 1] 执行torch.cumsum()后变成: [5, 9, 11, 12] 27 | 28 | sorted_indices_to_remove = cumulative_probs > top_p # 输出: [False, Flase, True, True, True, True] 29 | 30 | sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone() # 将索引向右移动,使第一个位置保持在阈值之上 31 | sorted_indices_to_remove[..., 0] = 0 # 至少要确保概率最高的满足情况 不做这一步处理可能通过卡阈值 当前的概率都不满足 32 | 33 | indices_to_remove = sorted_indices[sorted_indices_to_remove] # 取出不在采样序列中的token索引 34 | logits[indices_to_remove] = filter_value # 然后将其概率置为负无穷 35 | return logits 36 | 37 | 38 | def sample_decode(encode_output): 39 | # {'token_ids': [101, 2769, 4638, 1959, 1959, 102, 2769, 4638, 1959, 1959, 3221, 671, 702, 1249, 1227, 3320, 2141, 4638, 1093, 3333, 1967, 1957, 102], 'token_type_ids': [0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]} 40 | 41 | max_len = 50 # 最多生成50个token 42 | repetition_penalty = 1.1 # 惩罚项 主要是为了让生成的token在后面尽量少出现 43 | temperature = 1 # 控制样本的生成尽可能多样性 44 | topk = 5 # 最高k选1 45 | topp = 0.95 # 最高累计概率 46 | 47 | input_ids, token_type_ids = encode_output['token_ids'], encode_output['token_type_ids'] 48 | curr_input_tensor = torch.tensor([input_ids]).long() # 把输入整理成: [CLS] input [SEP] 49 | token_type_ids = torch.tensor([token_type_ids]).long() 50 | if torch.cuda.is_available(): 51 | curr_input_tensor = curr_input_tensor.cuda() 52 | token_type_ids = token_type_ids.cuda() 53 | 54 | generated = [] 55 | for _ in range(max_len): 56 | with torch.no_grad(): 57 | outputs = model(curr_input_tensor, token_type_ids) 58 | print(outputs.size()) # torch.Size([1, 4, 50000]) 59 | 60 | next_token_logits = outputs[0][-1, :] # 这里相当于取得是序列最后的位置的logits 我们认为其是预测的下一个token的概率 61 | # print(next_token_logits.size()) 62 | 63 | # 如果某个token之前已经生成过了,给其一个惩罚 就是在其对应的概率上出一个惩罚因子, 显然这个惩罚因子要>=1. 64 | for id in set(generated): 65 | next_token_logits[id] /= repetition_penalty 66 | 67 | # 这里对所有的概率都处于一个temperature,是为了将概率整体的放小,然后通过softmax后,它们之间的差距就不是很大, 68 | # 这样低概率的就更有可能采样到。满足其多样性 69 | next_token_logits = next_token_logits / temperature 70 | # print(next_token_logits.size()) 71 | 72 | # 对于[UNK]的概率设为无穷小,也就是说模型的预测结果不可能是[UNK]这个token 73 | next_token_logits[word2idx['[UNK]']] = -float('Inf') 74 | 75 | # 可以指定topk 也可以指定topp, topk是选概率最大的前k个做采样, topp卡的时累计概率,如果越大,说明你采样的序列越多,概率越小,采样序列越小。 76 | filtered_logits = top_k_top_p_filtering(next_token_logits, top_k=topk, top_p=topp) 77 | 78 | # 按概率采样 概率越高 越容易被采样到. 
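# Worked example of the sampling step below (numbers are illustrative only, not model outputs):
# with next_token_logits = [2.0, 1.0, 0.5, -1.0] and top_k=2, the filtering keeps [2.0, 1.0] and
# sets the rest to -inf; softmax then gives roughly [0.73, 0.27, 0.0, 0.0], so torch.multinomial
# draws index 0 about 73% of the time and index 1 about 27% of the time, while filtered tokens
# can never be drawn because their probability is exactly 0.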
79 | next_token = torch.multinomial(F.softmax(filtered_logits, dim=-1), num_samples=1) # 按概率采样 采出当前解码出的token索引 80 | 81 | if next_token == sep_token_id: # 遇到[SEP]则表明response生成结束 82 | break 83 | 84 | generated.append(next_token.item()) 85 | 86 | next_token = next_token.unsqueeze(-1) 87 | next_token_type_ids = torch.tensor([[1]], dtype=torch.long) 88 | if torch.cuda.is_available(): 89 | next_token_type_ids = next_token_type_ids.cuda() 90 | curr_input_tensor = torch.cat((curr_input_tensor, next_token), dim=1) 91 | token_type_ids = torch.cat((token_type_ids, next_token_type_ids), dim=1) 92 | 93 | result = [] 94 | for t in generated: 95 | result.append(idx2word.get(t, word2idx['[UNK]'])) 96 | result = ''.join(result) 97 | result = result.replace('[UNK]', '') 98 | return result 99 | 100 | 101 | if __name__ == '__main__': 102 | args = set_args() 103 | word2idx = load_bert_vocab(args.bert_vocab_path) 104 | idx2word = {} 105 | for word, idx in word2idx.items(): 106 | idx2word[idx] = word 107 | 108 | cls_token_id = word2idx['[CLS]'] 109 | sep_token_id = word2idx['[SEP]'] 110 | tokenizer = Tokenizer(word2idx) 111 | 112 | config = BertConfig(len(word2idx)) 113 | model = Model(config) 114 | 115 | # 加载模型 116 | model.load_state_dict(torch.load('./checkpoint/Epoch-9.bin', map_location='cpu')) 117 | model.eval() 118 | if torch.cuda.is_available(): 119 | model.cuda() 120 | 121 | title = '我的奶奶' 122 | prefix = '我的奶奶是一个勤劳朴实的农村妇女' 123 | token_ids, token_type_ids = tokenizer.encode(title, prefix) 124 | output = { 125 | "token_ids": token_ids, 126 | "token_type_ids": token_type_ids, 127 | } 128 | # {'token_ids': [101, 2769, 4638, 1959, 1959, 102, 2769, 4638, 1959, 1959, 3221, 671, 702, 1249, 1227, 3320, 2141, 4638, 1093, 3333, 1967, 1957, 102], 'token_type_ids': [0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]} 129 | 130 | res = sample_decode(output) 131 | print(res) 132 | 133 | -------------------------------------------------------------------------------- /UniLM/model.py: -------------------------------------------------------------------------------- 1 | """ 2 | @file : model.py 3 | @author : xiaolu 4 | @email : luxiaonlp@163.com 5 | @time : 2022-08-05 6 | """ 7 | import torch 8 | import torch.nn as nn 9 | from bert_model import BertModel, BertConfig, BertLMPredictionHead 10 | from tokenizer import Tokenizer, load_bert_vocab 11 | from config import set_args 12 | 13 | 14 | args = set_args() 15 | 16 | 17 | class Model(nn.Module): 18 | def __init__(self, config: BertConfig): 19 | super(Model, self).__init__() 20 | # 获取配置信息 21 | self.hidden_dim = config.hidden_size 22 | self.vocab_size = config.vocab_size 23 | 24 | self.bert = BertModel(config) 25 | 26 | # 解码 27 | self.decoder = BertLMPredictionHead(config, self.bert.embeddings.word_embeddings.weight) 28 | 29 | # 加载字典和分词器 30 | self.word2ix = load_bert_vocab(args.bert_vocab_path) 31 | self.tokenizer = Tokenizer(self.word2ix) 32 | 33 | def compute_loss(self, predictions, labels, target_mask): 34 | """ 35 | target_mask : 句子a部分和pad部分全为0, 而句子b部分为1 36 | """ 37 | predictions = predictions.view(-1, self.vocab_size) 38 | labels = labels.view(-1) 39 | target_mask = target_mask.view(-1).float() 40 | loss = nn.CrossEntropyLoss(ignore_index=0, reduction="none") 41 | return (loss(predictions, labels) * target_mask).sum() / target_mask.sum() # 通过mask 取消 pad 和句子a部分预测的影响 42 | 43 | def forward(self, input_tensor, token_type_id=None, position_enc=None, labels=None): 44 | input_shape = input_tensor.size() # batch_size, max_len 45 | seq_len = input_shape[1] 46 | # 
构建特殊的mask 47 | if torch.cuda.is_available(): 48 | ones = torch.ones((1, 1, seq_len, seq_len), dtype=torch.float32).cuda() 49 | else: 50 | ones = torch.ones((1, 1, seq_len, seq_len), dtype=torch.float32) 51 | 52 | a_mask = ones.tril() # 下三角矩阵 53 | 54 | s_ex12 = token_type_id.unsqueeze(1).unsqueeze(2).float() # batch_size, 1, 1, max_len 55 | s_ex13 = token_type_id.unsqueeze(1).unsqueeze(3).float() # batch_size, 1, max_len, 1 56 | 57 | a_mask = (1.0 - s_ex12) * (1.0 - s_ex13) + s_ex13 * a_mask 58 | 59 | # print(a_mask.size()) # torch.Size([2, 1, 33, 33]) 60 | enc_layers, _ = self.bert(input_tensor, 61 | position_ids=position_enc, 62 | token_type_ids=token_type_id, 63 | attention_mask=a_mask, 64 | output_all_encoded_layers=True) 65 | 66 | sequence_out = enc_layers[-1] # 取出来最后一层输出 67 | # print(sequence_out.size()) # torch.Size([2, 267, 768]) 68 | 69 | predictions = self.decoder(sequence_out) 70 | 71 | if labels is not None: 72 | # 计算loss 需要构建特殊的输出mask 才能计算正确的loss 73 | predictions = predictions[:, :-1].contiguous() # 错位预测 那最后以为的预测 就没有意义 也就不用进行损失计算 74 | target_mask = token_type_id[:, 1:].contiguous() # 从第二位开始 计算loss 75 | loss = self.compute_loss(predictions, labels, target_mask) 76 | 77 | return predictions, loss 78 | else: 79 | return predictions 80 | 81 | def generate(self, text, out_max_length=10, beam_size=3): 82 | token_ids, token_type_ids = self.tokenizer.encode(text) 83 | token_ids = torch.tensor(token_ids).view(1, -1) 84 | token_type_ids = torch.tensor(token_type_ids).view(1, -1) 85 | if torch.cuda.is_available(): 86 | token_ids = token_ids.cuda() 87 | token_type_ids = token_type_ids.cuda() 88 | 89 | output_ids = self.beam_search(token_ids, token_type_ids, self.word2ix, beam_size=beam_size, 90 | device=token_ids.device, out_max_length=out_max_length) 91 | return self.tokenizer.decode(output_ids) 92 | 93 | def beam_search(self, token_ids, token_type_ids, word2ix, beam_size=1, device='cpu', out_max_length=10): 94 | sep_id = word2ix['[SEP]'] 95 | # 用来保存输出序列 96 | output_ids = torch.empty(1, 0, device=device, dtype=torch.long) 97 | with torch.no_grad(): 98 | output_scores = torch.zeros(token_ids.shape[0], device=device) 99 | for step in range(out_max_length): 100 | if step == 0: 101 | scores = self.forward(token_ids, token_type_ids) 102 | # 重复beam_size次 103 | token_ids = token_ids.view(1, -1).repeat(beam_size, 1) 104 | token_type_ids = token_type_ids.view(1, -1).repeat(beam_size, 1) 105 | else: 106 | scores = self.forward(new_input_ids, new_token_type_ids) 107 | 108 | logit_score = torch.log_softmax(scores[:, -1], dim=-1) # 取出当前步对所有词的预测 109 | logit_score = output_scores.view(-1, 1) + logit_score # 累计概率 110 | 111 | # 展开 取topk 112 | logit_score = logit_score.view(-1) 113 | hype_score, hype_pos = torch.topk(logit_score, beam_size) 114 | # print(hype_score) # tensor([-5.6436, -5.7821, -5.8964]) 115 | # print(hype_pos) # tensor([4743, 4131, 2115]) 116 | indice1 = torch.div(hype_pos, scores.shape[-1], rounding_mode='floor') 117 | indice2 = (hype_pos % scores.shape[-1]).long().reshape(-1, 1) # 列索引 118 | 119 | # 更新分数 120 | output_scores = hype_score 121 | output_ids = torch.cat([output_ids[indice1], indice2], dim=1).long() 122 | new_input_ids = torch.cat([token_ids, output_ids], dim=1) # 下一步的输入 123 | new_token_type_ids = torch.cat([token_type_ids, torch.ones_like(output_ids)], dim=1) 124 | end_counts = (output_ids == sep_id).sum(1) # 统计出现end的标记 125 | best_one = output_scores.argmax() 126 | if end_counts[best_one] == 1: 127 | # 当前概率累加最大的 出现了结束标记 那就终止了 128 | return output_ids[best_one][:-1] 129 | else: 
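# Beam pruning sketch (values are assumed purely for illustration): with beam_size=3 and
# end_counts = [0, 1, 0], flag = (end_counts < 1) = [True, False, True], so the finished
# beam 1 is dropped, beams 0 and 2 continue decoding, and beam_size shrinks to flag.sum() = 2.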
130 | # 保留未完成的部分 131 | flag = (end_counts < 1) # 标记未完成的序列 132 | if not flag.all(): 133 | # 如果有已经完成的 134 | token_ids = token_ids[flag] 135 | token_type_ids = token_type_ids[flag] 136 | 137 | new_input_ids = new_input_ids[flag] 138 | new_token_type_ids = new_token_type_ids[flag] 139 | 140 | output_ids = output_ids[flag] # 扔掉已完成序列 141 | output_scores = output_scores[flag] # 扔掉已完成序列 142 | end_counts = end_counts[flag] # 扔掉已完成end计数 143 | beam_size = flag.sum() # topk相应变化 144 | return output_ids[output_scores.argmax()] 145 | -------------------------------------------------------------------------------- /UniLM/run_train.py: -------------------------------------------------------------------------------- 1 | """ 2 | @file : run_train.py 3 | @author : xiaolu 4 | @email : luxiaonlp@163.com 5 | @time : 2022-08-05 6 | """ 7 | import torch 8 | import time 9 | import os 10 | import datetime 11 | from model import Model 12 | from torch.optim import AdamW 13 | from bert_model import BertConfig 14 | from data_helper import UniLMDataset, collate_fn 15 | from config import set_args 16 | from tokenizer import load_bert_vocab 17 | from torch.utils.data.dataloader import DataLoader 18 | from transformers import get_linear_schedule_with_warmup 19 | 20 | 21 | if __name__ == '__main__': 22 | args = set_args() 23 | # os.makedirs(args.output_dir, exist_ok=True) 24 | 25 | # 加载数据集 26 | train_dataset = UniLMDataset(data_path=args.corpus_path, bert_vocab_path=args.bert_vocab_path) 27 | train_data_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, collate_fn=collate_fn) 28 | 29 | total_steps = int(len(train_data_loader) * args.num_train_epochs / args.gradient_accumulation_steps) 30 | 31 | # 实例化模型 32 | word2idx = load_bert_vocab(args.bert_vocab_path) 33 | bertconfig = BertConfig(vocab_size=len(word2idx)) 34 | model = Model(config=bertconfig) 35 | 36 | # 加载预训练模型 37 | model.load_state_dict(torch.load(args.bert_pretrain_weight_path), strict=False) 38 | 39 | if torch.cuda.is_available(): 40 | model.cuda() 41 | 42 | # 声明需要优化的参数 并定义相关优化器 43 | param_optimizer = list(model.named_parameters()) 44 | no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight'] 45 | optimizer_grouped_parameters = [ 46 | {'params': [p for n, p in param_optimizer if not any( 47 | nd in n for nd in no_decay)], 'weight_decay': 0.01}, 48 | {'params': [p for n, p in param_optimizer if any( 49 | nd in n for nd in no_decay)], 'weight_decay': 0.0} 50 | ] 51 | 52 | # 设置优化器 53 | optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon) 54 | scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=int(args.warmup_proportion * total_steps), 55 | num_training_steps=total_steps) 56 | 57 | for epoch in range(args.num_train_epochs): 58 | model.train() 59 | total_loss, count = 0.0, 0 60 | for step, batch in enumerate(train_data_loader): 61 | if torch.cuda.is_available(): 62 | batch = (t.cuda() for t in batch) 63 | token_ids, token_type_ids, target_ids = batch 64 | 65 | # 因为传入了target标签,因此会计算loss并且返回 66 | predictions, loss = model(token_ids, token_type_ids, labels=target_ids) 67 | total_loss += loss.item() 68 | count += 1 69 | print('epoch:{}, step:{}, loss:{:8f}'.format(epoch, step, loss)) 70 | 71 | if args.gradient_accumulation_steps > 1: 72 | loss = loss / args.gradient_accumulation_steps 73 | 74 | loss.backward() 75 | # torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm) 76 | 77 | if (step + 1) % args.gradient_accumulation_steps == 0: 78 | optimizer.step() 79 | 
scheduler.step() 80 | optimizer.zero_grad() 81 | 82 | s = 'Epoch: {} | Train_AvgLoss: {:10f} '.format(epoch, total_loss / count) 83 | logs_path = os.path.join(args.output_dir, 'logs.txt') 84 | with open(logs_path, 'a+') as f: 85 | s += '\n' 86 | f.write(s) 87 | 88 | output_dir = os.path.join(args.output_dir, "Epoch-{}.bin".format(epoch)) 89 | model_to_save = model.module if hasattr(model, "module") else model 90 | torch.save(model_to_save.state_dict(), output_dir) 91 | # 清空cuda缓存 92 | torch.cuda.empty_cache() 93 | -------------------------------------------------------------------------------- /UniLM/tokenizer.py: -------------------------------------------------------------------------------- 1 | """ 2 | @file : tokenizer.py 3 | @author : xiaolu 4 | @email : luxiaonlp@163.com 5 | @time : 2022-02-21 6 | """ 7 | import unicodedata 8 | from typing import List, Dict 9 | from config import set_args 10 | 11 | args = set_args() 12 | 13 | 14 | def load_bert_vocab(bert_vocab_path) -> Dict[str, int]: 15 | """ 16 | 加载官方中文bert模型字典 17 | """ 18 | with open(bert_vocab_path, "r", encoding="utf-8") as f: 19 | lines = f.readlines() 20 | word2idx = {} 21 | for index, line in enumerate(lines): 22 | word2idx[line.strip("\n")] = index 23 | return word2idx 24 | 25 | 26 | class BasicTokenizer(object): 27 | def __init__(self): 28 | """初始化 29 | """ 30 | self._token_pad = '[PAD]' 31 | self._token_cls = '[CLS]' 32 | self._token_sep = '[SEP]' 33 | self._token_unk = '[UNK]' 34 | self._token_mask = '[MASK]' 35 | 36 | def tokenize(self, text: str, add_cls=True, add_sep=True, max_length=None): 37 | """ 38 | 分词函数 39 | """ 40 | tokens = self._tokenize(text) 41 | if add_cls: 42 | tokens.insert(0, self._token_cls) 43 | if add_sep: 44 | tokens.append(self._token_sep) 45 | 46 | if max_length is not None: 47 | self.truncate_sequence(max_length, tokens, None, -2) 48 | 49 | return tokens 50 | 51 | def token_to_id(self, token): 52 | """ 53 | token转换为对应的id 54 | """ 55 | raise NotImplementedError 56 | 57 | def tokens_to_ids(self, tokens): 58 | """ 59 | token序列转换为对应的id序列 60 | """ 61 | return [self.token_to_id(token) for token in tokens] 62 | 63 | def truncate_sequence(self, 64 | max_length, 65 | first_sequence: List[str], 66 | second_sequence=None, 67 | pop_index=-1): 68 | """ 69 | 截断总长度 70 | """ 71 | if second_sequence is None: 72 | second_sequence = [] 73 | 74 | while True: 75 | total_length = len(first_sequence) + len(second_sequence) 76 | if total_length <= max_length: 77 | break 78 | elif len(first_sequence) > len(second_sequence): 79 | first_sequence.pop(pop_index) 80 | else: 81 | second_sequence.pop(pop_index) 82 | 83 | def encode(self, 84 | first_text, 85 | second_text=None, 86 | max_length=None, 87 | first_length=None, 88 | second_length=None): 89 | """ 90 | 输出文本对应token id和segment id 91 | 如果传入first_length,则强行padding第一个句子到指定长度; 92 | 同理,如果传入second_length,则强行padding第二个句子到指定长度。 93 | """ 94 | first_tokens = self.tokenize(first_text) 95 | 96 | if second_text is None: 97 | second_tokens = None 98 | else: 99 | second_tokens = self.tokenize(second_text, add_cls=False) 100 | 101 | if max_length is not None: 102 | self.truncate_sequence(max_length, first_tokens, second_tokens, -2) 103 | 104 | first_token_ids = self.tokens_to_ids(first_tokens) 105 | if first_length is not None: 106 | first_token_ids = first_token_ids[:first_length] 107 | first_token_ids.extend([self._token_pad_id] * 108 | (first_length - len(first_token_ids))) 109 | first_segment_ids = [0] * len(first_token_ids) 110 | 111 | if second_text is not None: 112 | 
second_token_ids = self.tokens_to_ids(second_tokens) 113 | if second_length is not None: 114 | second_token_ids = second_token_ids[:second_length] 115 | second_token_ids.extend( 116 | [self._token_pad_id] * 117 | (second_length - len(second_token_ids))) 118 | second_segment_ids = [1] * len(second_token_ids) 119 | 120 | first_token_ids.extend(second_token_ids) 121 | first_segment_ids.extend(second_segment_ids) 122 | 123 | return first_token_ids, first_segment_ids 124 | 125 | def id_to_token(self, i): 126 | """ 127 | id序列为对应的token 128 | """ 129 | raise NotImplementedError 130 | 131 | def ids_to_tokens(self, ids): 132 | """ 133 | id序列转换为对应的token序列 134 | """ 135 | return [self.id_to_token(i) for i in ids.cpu().numpy().tolist()] 136 | 137 | def decode(self, ids): 138 | """ 139 | 转为可读文本 140 | """ 141 | raise NotImplementedError 142 | 143 | def _tokenize(self, text): 144 | """ 145 | 基本分词函数 146 | """ 147 | raise NotImplementedError 148 | 149 | 150 | class Tokenizer(BasicTokenizer): 151 | def __init__(self, token_dict): 152 | """初始化 153 | """ 154 | super(Tokenizer, self).__init__() 155 | 156 | self._token_dict = token_dict 157 | 158 | self._token_dict_inv = {v: k for k, v in token_dict.items()} 159 | for token in ['pad', 'cls', 'sep', 'unk', 'mask']: 160 | try: 161 | 162 | _token_id = token_dict[getattr(self, "_token_" + str(token))] 163 | setattr(self, "_token_" + str(token) + "_id", _token_id) 164 | except Exception as e: 165 | 166 | pass 167 | self._vocab_size = len(token_dict) 168 | 169 | def token_to_id(self, token): 170 | """token转换为对应的id 171 | """ 172 | return self._token_dict.get(token, self._token_unk_id) 173 | 174 | def id_to_token(self, i): 175 | """id转换为对应的token 176 | """ 177 | return self._token_dict_inv[i] 178 | 179 | def decode(self, ids): 180 | """转为可读文本 181 | """ 182 | tokens = self.ids_to_tokens(ids) 183 | 184 | return "".join(tokens).strip() 185 | 186 | def _tokenize(self, text): 187 | """基本分词函数 188 | """ 189 | spaced = '' 190 | for ch in text: 191 | if self._is_punctuation(ch) or self._is_cjk_character(ch): 192 | spaced += ' ' + ch + ' ' 193 | elif self._is_space(ch): 194 | spaced += ' ' 195 | elif ord(ch) == 0 or ord(ch) == 0xfffd or self._is_control(ch): 196 | continue 197 | else: 198 | spaced += ch 199 | 200 | return spaced.strip().split() 201 | 202 | @staticmethod 203 | def _is_space(ch): 204 | """空格类字符判断 205 | """ 206 | return ch == ' ' or ch == '\n' or ch == '\r' or ch == '\t' or \ 207 | unicodedata.category(ch) == 'Zs' 208 | 209 | @staticmethod 210 | def _is_punctuation(ch): 211 | """标点符号类字符判断(全/半角均在此内) 212 | """ 213 | code = ord(ch) 214 | return 33 <= code <= 47 or \ 215 | 58 <= code <= 64 or \ 216 | 91 <= code <= 96 or \ 217 | 123 <= code <= 126 or \ 218 | unicodedata.category(ch).startswith('P') 219 | 220 | @staticmethod 221 | def _cjk_punctuation(): 222 | return u'\uff02\uff03\uff04\uff05\uff06\uff07\uff08\uff09\uff0a\uff0b\uff0c\uff0d\uff0f\uff1a\uff1b\uff1c\uff1d\uff1e\uff20\uff3b\uff3c\uff3d\uff3e\uff3f\uff40\uff5b\uff5c\uff5d\uff5e\uff5f\uff60\uff62\uff63\uff64\u3000\u3001\u3003\u3008\u3009\u300a\u300b\u300c\u300d\u300e\u300f\u3010\u3011\u3014\u3015\u3016\u3017\u3018\u3019\u301a\u301b\u301c\u301d\u301e\u301f\u3030\u303e\u303f\u2013\u2014\u2018\u2019\u201b\u201c\u201d\u201e\u201f\u2026\u2027\ufe4f\ufe51\ufe54\xb7\uff01\uff1f\uff61\u3002' 223 | 224 | @staticmethod 225 | def _is_cjk_character(ch): 226 | """CJK类字符判断(包括中文字符也在此列) 227 | 参考:https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block) 228 | """ 229 | code = ord(ch) 230 | return 0x4E00 <= code <= 
0x9FFF or \ 231 | 0x3400 <= code <= 0x4DBF or \ 232 | 0x20000 <= code <= 0x2A6DF or \ 233 | 0x2A700 <= code <= 0x2B73F or \ 234 | 0x2B740 <= code <= 0x2B81F or \ 235 | 0x2B820 <= code <= 0x2CEAF or \ 236 | 0xF900 <= code <= 0xFAFF or \ 237 | 0x2F800 <= code <= 0x2FA1F 238 | 239 | @staticmethod 240 | def _is_control(ch): 241 | """控制类字符判断 242 | """ 243 | return unicodedata.category(ch) in ('Cc', 'Cf') 244 | 245 | @staticmethod 246 | def _is_special(ch): 247 | """判断是不是有特殊含义的符号 248 | """ 249 | return bool(ch) and (ch[0] == '[') and (ch[-1] == ']') 250 | 251 | 252 | if __name__ == "__main__": 253 | # 加载bert当前预训练模型对应的词表 254 | word2idx = load_bert_vocab('./roberta_pretrain/vocab.txt') 255 | # print(word2idx) 256 | 257 | tokenizer = Tokenizer(word2idx) 258 | input_ids = tokenizer.encode('王浩泽荣') 259 | print(input_ids) 260 | exit() 261 | input_ids, segment_ids = tokenizer.encode("你好啊,今天过的怎么样?", "我很好,谢谢你啦") 262 | text = tokenizer.decode(input_ids) 263 | print(input_ids) 264 | print(text) 265 | print(segment_ids) 266 | print(tokenizer.encode("今天天气真好啊")) 267 | -------------------------------------------------------------------------------- /seq2seq_rnn/config.py: -------------------------------------------------------------------------------- 1 | """ 2 | @file : config.py 3 | @author : Henry 4 | @email : luxiaonlp@163.com 5 | @time : 2024-01-02 6 | """ 7 | import argparse 8 | 9 | 10 | def set_args(): 11 | parser = argparse.ArgumentParser() 12 | parser.add_argument('--seed', type=int, default=42, help='随机种子') 13 | parser.add_argument('--output_dir', default='output_dir/', type=str, help='模型输出路径') 14 | parser.add_argument('--train_data_src', default='../data/train.src.txt', type=str, help='训练文章') 15 | parser.add_argument('--train_data_tgt', default='../data/train.tgt.txt', type=str, help='训练摘要') 16 | parser.add_argument('--valid_data_src', default='../data/test.src.txt', type=str, help='验证文章') 17 | parser.add_argument('--valid_data_tgt', default='../data/test.tgt.txt', type=str, help='验证摘要') 18 | 19 | parser.add_argument('--test_data_src', default='../data/test.src.txt', type=str, help='测试文章') 20 | parser.add_argument('--test_data_tgt', default='../data/test.tgt.txt', type=str, help='测试摘要') 21 | 22 | parser.add_argument('--vocab2id_path', default='../data/vocab2id.json', type=str, help='词表') 23 | parser.add_argument('--emb_size', default=256, type=int, help='词嵌入大小') 24 | parser.add_argument('--hidden_size', default=512, type=int, help='隐层大小') 25 | parser.add_argument('--num_layers', default=2, type=int, help='层大小') 26 | parser.add_argument('--dropout', default=0.5, type=float, help='dropout') 27 | parser.add_argument('--num_train_epochs', default=50, type=int, help='模型训练的轮数') 28 | parser.add_argument('--batch_size', default=64, type=int, help='训练时每个batch的大小') 29 | parser.add_argument('--learning_rate', default=1e-4, type=float, help='模型训练时的学习率') 30 | parser.add_argument('--max_grad_norm', default=1.0, type=float, help='') 31 | parser.add_argument('--gradient_accumulation_steps', default=1, type=int, help='梯度积累') 32 | parser.add_argument('--pointer', default=False, type=bool, help='是否只用指针网络') 33 | parser.add_argument('--logging_steps', default=5, type=int, help='保存训练日志的步数') 34 | return parser.parse_args() 35 | -------------------------------------------------------------------------------- /seq2seq_rnn/data_helper.py: -------------------------------------------------------------------------------- 1 | """ 2 | @file : data_helper.py 3 | @author : Henry 4 | @email : luxiaonlp@163.com 5 | @time : 2024-01-02 
6 | """ 7 | import torch 8 | from torch.utils.data import Dataset 9 | 10 | 11 | def load_data(path): 12 | data = [] 13 | with open(path, 'r', encoding='utf8') as f: 14 | lines = f.readlines() 15 | for line in lines: 16 | line = line.strip() 17 | data.append(line) 18 | return data 19 | 20 | 21 | class SummaryDataset(Dataset): 22 | def __init__(self, context, summary, tokenizer, vocab2id): 23 | self.context = context 24 | self.summary = summary 25 | self.tokenizer = tokenizer 26 | self.vocab2id = vocab2id 27 | 28 | def __len__(self): 29 | return len(self.context) 30 | 31 | def __getitem__(self, item): 32 | # 输入 33 | context = self.context[item] 34 | # context_token = self.tokenizer.tokenize(context) 35 | context_token = list(context) 36 | context_seq_len = len(context_token) 37 | context_input_ids = [] 38 | for word in context_token: 39 | context_input_ids.append(self.vocab2id.get(word, self.vocab2id.get(''))) 40 | 41 | # 输出 42 | summary = self.summary[item] 43 | # summary_token = self.tokenizer.tokenize(summary) 44 | summary_token = list(summary) 45 | summary_seq_len = len(summary_token) 46 | summary_input_ids = [] 47 | summary_input_ids.append(self.vocab2id.get('')) 48 | for word in summary_token: 49 | summary_input_ids.append(self.vocab2id.get(word, self.vocab2id.get(''))) 50 | summary_input_ids.append(self.vocab2id.get('')) 51 | return {'context_input_ids': context_input_ids, 'context_seq_len': context_seq_len, 52 | 'summary_input_ids': summary_input_ids, 'summary_seq_len': summary_seq_len} 53 | 54 | 55 | class Collator: 56 | def __init__(self, pad_id, is_train=True): 57 | self.pad_id = pad_id 58 | self.is_train = is_train 59 | 60 | def __call__(self, batch): 61 | context_max_len = max([d['context_seq_len'] for d in batch]) 62 | summary_max_len = max([d['summary_seq_len'] for d in batch]) 63 | 64 | if context_max_len > 256: 65 | context_max_len = 256 66 | if summary_max_len > 64: 67 | summary_max_len = 64 68 | 69 | context_input_ids_list, context_seq_len_list = [], [] 70 | summary_input_ids_list, summary_seq_len_list = [], [] 71 | for item in batch: 72 | context_input_ids_list.append(self.pad_to_maxlen(item['context_input_ids'], max_len=context_max_len)) 73 | summary_input_ids_list.append(self.pad_to_maxlen(item['summary_input_ids'], max_len=summary_max_len)) 74 | context_seq_len_list.append(item['context_seq_len']) 75 | summary_seq_len_list.append(item['summary_seq_len']) 76 | 77 | context_input_ids_tensor = torch.tensor(context_input_ids_list, dtype=torch.long) 78 | summary_input_ids_tensor = torch.tensor(summary_input_ids_list, dtype=torch.long) 79 | context_seq_len_tensor = torch.tensor(context_seq_len_list, dtype=torch.long) 80 | summary_seq_len_tensor = torch.tensor(summary_seq_len_list, dtype=torch.long) 81 | return context_input_ids_tensor, context_seq_len_tensor, summary_input_ids_tensor, summary_seq_len_tensor 82 | 83 | def pad_to_maxlen(self, input_ids, max_len): 84 | if len(input_ids) >= max_len: 85 | input_ids = input_ids[:max_len] 86 | else: 87 | input_ids = input_ids + [self.pad_id] * (max_len - len(input_ids)) 88 | return input_ids 89 | -------------------------------------------------------------------------------- /seq2seq_rnn/inference.py: -------------------------------------------------------------------------------- 1 | """ 2 | @file : inference.py 3 | @author : Henry 4 | @email : luxiaonlp@163.com 5 | @time : 2024-01-02 6 | """ 7 | import os 8 | import json 9 | import torch 10 | import random 11 | import pandas as pd 12 | from model import Model 13 | from rouge 
import Rouge 14 | from config import set_args 15 | from tqdm.contrib import tzip 16 | from data_helper import load_data 17 | from transformers.models.bert.tokenization_bert import BasicTokenizer 18 | 19 | 20 | # 模型预测过程 21 | def predict(model, vocab, text, max_len=60): 22 | model.eval() 23 | with torch.no_grad(): 24 | # context_token = tokenizer.tokenize(text) 25 | context_token = list(text) 26 | context_input_ids = [] 27 | for word in context_token: 28 | context_input_ids.append(vocab2id.get(word, vocab2id.get(''))) 29 | 30 | src_lengths = torch.tensor([len(context_input_ids)]) 31 | src_input_ids = torch.tensor([context_input_ids]) 32 | if torch.cuda.is_available(): 33 | src_lengths = src_lengths.cuda() 34 | src_input_ids = src_input_ids.cuda() 35 | 36 | 37 | src_emb = model.embedding(src_input_ids) 38 | enc_output, prev_hidden = model.encoder(src_emb, src_lengths) 39 | 40 | # enc_output, prev_hidden = model.encoder(src_input_ids, src_lengths) 41 | # print(enc_output.size()) # torch.Size([1, 94, 256]) 42 | decoder_input_ids = torch.tensor([vocab['']]).to(src_input_ids.device) 43 | result = [] 44 | for t in range(max_len): 45 | dec_input = decoder_input_ids.unsqueeze(1) 46 | decoder_emb = model.embedding(dec_input) 47 | 48 | logits, prev_hidden, attention_weights = model.decoder(decoder_emb, prev_hidden, enc_output, src_lengths) 49 | _, max_ids = logits.max(dim=-1) 50 | 51 | # decoder_input_ids = torch.cat(max_ids, dim=-1) 52 | decoder_input_ids = torch.tensor(max_ids).to(src_input_ids.device) 53 | 54 | ids = max_ids.cpu().numpy().tolist()[0] 55 | 56 | if ids == vocab2id.get(''): 57 | break 58 | result.append(ids) 59 | return result 60 | 61 | 62 | if __name__ == '__main__': 63 | args = set_args() 64 | # tokenizer = BasicTokenizer() 65 | test_context = load_data(args.test_data_src) 66 | test_summary = load_data(args.test_data_tgt) 67 | vocab2id = json.load(open(args.vocab2id_path, 'r', encoding='utf8')) 68 | 69 | id2vocab = {id: vocab for vocab, id in vocab2id.items()} 70 | 71 | model = Model(vocab2id) 72 | # epoch_11.bin 73 | # epoch_2.bin 74 | # model.load_state_dict(torch.load('./output_dir/epoch_14.bin', map_location='cpu')) 75 | 76 | model.load_state_dict(torch.load('./output_dir/epoch_19.bin', map_location='cpu')) 77 | 78 | if torch.cuda.is_available(): 79 | model.cuda() 80 | 81 | final_context, final_summary, final_gen_summary = [], [], [] 82 | for context, summary in tzip(test_context, test_summary): 83 | gen_summary = predict(model, vocab2id, context) 84 | gen_summary = ''.join([id2vocab[idx] for idx in gen_summary]) 85 | final_context.append(context) 86 | final_summary.append(summary) 87 | final_gen_summary.append(gen_summary) 88 | df = pd.DataFrame({'context': final_context, 'summary': final_summary, 'gen_summary': final_gen_summary}) 89 | df.to_csv('./result.csv', index=False) 90 | 91 | # 计算指标 92 | rouge = Rouge() 93 | hyps, refs = [], [] 94 | for context, summary, gen_summary in zip(df['context'], df['summary'], df['gen_summary']): 95 | refs.append(' '.join(list(summary))) 96 | hyps.append(' '.join(list(gen_summary))) 97 | scores = rouge.get_scores(hyps, refs, avg=True) 98 | print(scores) 99 | 100 | 101 | 102 | 103 | 104 | -------------------------------------------------------------------------------- /seq2seq_rnn/model.py: -------------------------------------------------------------------------------- 1 | """ 2 | @file : model.py 3 | @author : Henry 4 | @email : luxiaonlp@163.com 5 | @time : 2024-01-02 6 | """ 7 | import torch 8 | import random 9 | from torch import nn 10 | 
from config import set_args 11 | from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence 12 | 13 | 14 | args = set_args() 15 | 16 | 17 | class Encoder(nn.Module): 18 | def __init__(self, vocab): 19 | super(Encoder, self).__init__() 20 | self.vocab_size = len(vocab) 21 | self.gru = nn.GRU(args.emb_size, args.hidden_size, num_layers=args.num_layers, batch_first=True, dropout=args.dropout) 22 | self.dropout = nn.Dropout(args.dropout) 23 | self.linear = nn.Linear(args.hidden_size, args.hidden_size) 24 | self.relu = nn.ReLU() 25 | 26 | def forward(self, enc_input, text_lengths): 27 | text_lengths = text_lengths.to("cpu") 28 | # print(text_lengths) 29 | # embedded = self.dropout(self.embedding(enc_input)) # [batch_size, seq_len, emb_size] 30 | embedded = self.dropout(enc_input) 31 | embedded = pack_padded_sequence(embedded, text_lengths, batch_first=True, enforce_sorted=False) 32 | output, hidden = self.gru(embedded) 33 | output, _ = pad_packed_sequence(output, batch_first=True) 34 | output = self.relu(self.linear(output)) 35 | return output, hidden[-1].detach() 36 | 37 | 38 | class Attention(nn.Module): 39 | def __init__(self): 40 | super(Attention, self).__init__() 41 | self.linear = nn.Linear(args.hidden_size + args.emb_size, args.hidden_size) 42 | self.v = nn.Linear(args.hidden_size, 1) 43 | self.softmax = nn.Softmax(dim=-1) 44 | 45 | def forward(self, dec_input, enc_output, text_lengths): 46 | # print(dec_input.size(), enc_output.size(), text_lengths.size(), coverage_vector.size()) 47 | # torch.Size([8, 1, 128]) torch.Size([8, 107, 256]) torch.Size([8]) torch.Size([8, 107]) 48 | 49 | seq_len, hidden_size = enc_output.size(1), enc_output.size(-1) 50 | 51 | s = dec_input.repeat(1, seq_len, 1) # [batch_size, seq_len, hidden_size] torch.Size([8, 107, 128]) 52 | # print(s.size()) # torch.Size([8, 107, 128]) 53 | 54 | # coverage_vector_copy = coverage_vector.unsqueeze(2).repeat(1, 1, hidden_size) 55 | # print(coverage_vector_copy.size()) # torch.Size([8, 107, 256]) 56 | 57 | x = torch.tanh(self.linear(torch.cat([enc_output, s], dim=2))) 58 | # print(x.size()) # torch.Size([8, 107, 256]) 59 | 60 | attention = self.v(x).squeeze(-1) # [batch_size, seq_len] 61 | max_len = enc_output.size(1) 62 | 63 | text_lengths = text_lengths.to('cpu') 64 | mask = torch.arange(max_len).expand(text_lengths.shape[0], max_len) >= text_lengths.unsqueeze(1) 65 | 66 | attention.masked_fill_(mask.to(dec_input.device), float('-inf')) 67 | attention_weights = self.softmax(attention) 68 | # 更新 coverage_vector 69 | # coverage_vector += attention_weights 70 | return attention_weights # [batch, seq_len], [batch_size, seq_len] 71 | 72 | 73 | class Decoder(nn.Module): 74 | def __init__(self, vocab, attention): 75 | super(Decoder, self).__init__() 76 | self.vocab_size = len(vocab) 77 | # self.embedding = nn.Embedding(self.vocab_size, args.emb_size, padding_idx=vocab['']) 78 | self.attention = attention 79 | self.gru = nn.GRU(args.emb_size + args.hidden_size, args.hidden_size, batch_first=True) 80 | self.linear = nn.Linear(args.emb_size + 2 * args.hidden_size, self.vocab_size) 81 | self.dropout = nn.Dropout(args.dropout) 82 | # 设置 PGN 网络架构的参数,用于计算 p_gen 83 | # if args.pointer: 84 | # self.w_gen = nn.Linear(args.hidden_size * 2 + args.emb_size, 1) 85 | 86 | def forward(self, dec_input, prev_hidden, enc_output, text_lengths): 87 | # print(prev_hidden.size()) # torch.Size([8, 256]) 88 | # print(enc_output.size()) # torch.Size([8, 107, 256]) 89 | # embedded = self.embedding(dec_input) 90 | embedded = dec_input 91 | 92 
| # 加入 coverage 机制后,attention 的计算公式参考 https://zhuanlan.zhihu.com/p/453600830 93 | attention_weights = self.attention(embedded, enc_output, text_lengths) 94 | # print(attention_weights.size()) # torch.Size([8, 107]) 95 | # print(coverage_vector.size()) # torch.Size([8, 107]) 96 | 97 | attention_weights = attention_weights.unsqueeze(1) # [batch_size, 1, enc_len] 98 | c = torch.bmm(attention_weights, enc_output) # [batch_size, 1, hidden_size] 99 | 100 | # 将经过 embedding 处理过的 decoder 输入,和上下文向量一起送入到 GRU 网络中 101 | gru_input = torch.cat([embedded, c], dim=2) 102 | # print(gru_input.size()) # torch.Size([8, 1, 384]) 103 | 104 | # prev_hidden 是上个时间步的隐状态,作为 decoder 的参数传入进来 105 | dec_output, dec_hidden = self.gru(gru_input, prev_hidden.unsqueeze(0)) 106 | dec_output = self.linear(torch.cat((dec_output.squeeze(1), c.squeeze(1), embedded.squeeze(1)), dim=1)) # [batch_size, vocab_size] 107 | dec_hidden = dec_hidden.squeeze(0) 108 | return dec_output, dec_hidden, attention_weights.squeeze(1) 109 | 110 | 111 | class Model(nn.Module): 112 | def __init__(self, vocab): 113 | super(Model, self).__init__() 114 | self.vocab = vocab 115 | self.vocab_size = len(vocab) 116 | self.embedding = nn.Embedding(self.vocab_size, args.emb_size, padding_idx=vocab[""]) 117 | 118 | self.encoder = Encoder(vocab) 119 | 120 | attention = Attention() 121 | self.decoder = Decoder(vocab, attention) 122 | 123 | def forward(self, src, tgt, src_lengths, teacher_forcing_ratio=0.5): 124 | src_emb = self.embedding(src) 125 | # batch_size, max_len 126 | enc_output, prev_hidden = self.encoder(src_emb, src_lengths) 127 | 128 | # print(enc_output.size()) # torch.Size([8, 115, 256]) 129 | # print(prev_hidden.size()) # torch.Size([8, 256]) 130 | 131 | batch_size = tgt.size(0) 132 | tgt_len = tgt.size(1) 133 | 134 | dec_input = tgt[:, 0] 135 | dec_outputs = torch.zeros(batch_size, tgt_len, len(self.vocab)).to(src.device) 136 | 137 | for t in range(tgt_len - 1): 138 | dec_input = dec_input.unsqueeze(1) # torch.Size([8, 1]) 139 | dec_emb = self.embedding(dec_input) 140 | dec_output, prev_hidden, _ = self.decoder(dec_emb, prev_hidden, enc_output, src_lengths) 141 | dec_outputs[:, t, :] = dec_output 142 | teacher_force = random.random() < teacher_forcing_ratio 143 | top1 = dec_output.argmax(1) 144 | dec_input = tgt[:, t] if teacher_force else top1 145 | return dec_outputs 146 | 147 | 148 | -------------------------------------------------------------------------------- /seq2seq_rnn/run_train.py: -------------------------------------------------------------------------------- 1 | """ 2 | @file : run_train.py 3 | @author : Henry 4 | @email : luxiaonlp@163.com 5 | @time : 2024-01-02 6 | """ 7 | import os 8 | import json 9 | import torch 10 | import random 11 | import numpy as np 12 | from model import Model 13 | from torch import nn 14 | from config import set_args 15 | import torch.nn.functional as F 16 | from transformers import get_linear_schedule_with_warmup 17 | from torch.utils.data import DataLoader 18 | from data_helper import load_data, SummaryDataset, Collator 19 | from transformers.models.bert.tokenization_bert import BasicTokenizer 20 | from loguru import logger 21 | try: 22 | from torch.utils.tensorboard import SummaryWriter 23 | except ImportError: 24 | from tensorboard import SummaryWriter 25 | 26 | # python -m tensorboard.main --logdir=./runs --host=127.0.0.1 27 | 28 | def set_seed(args): 29 | torch.manual_seed(args.seed) 30 | random.seed(args.seed) 31 | np.random.seed(args.seed) 32 | 33 | 34 | logger.add('./logger/log.log') 35 | 36 | 
def calc_acc(logits, decoder_input_ids): 37 | decoder_attention_mask = torch.ne(decoder_input_ids, 0).to(logits.device) 38 | mask = decoder_attention_mask.view(-1).eq(1) 39 | labels = decoder_input_ids.view(-1) 40 | # print(mask.size()) # torch.Size([168]) 41 | # print(labels.size()) # torch.Size([168]) 42 | # print(logits.size()) # torch.Size([8, 21, 10002]) 43 | logits = logits.contiguous().view(-1, logits.size(-1)) 44 | # print(logits.size()) # torch.Size([168, 10002]) 45 | 46 | _, logits = logits.max(dim=-1) 47 | # print(logits.size()) # torch.Size([168]) 48 | n_correct = logits.eq(labels).masked_select(mask).sum().item() 49 | n_word = mask.sum().item() 50 | return n_correct, n_word 51 | 52 | 53 | def evaluate(dev_data_loader): 54 | total_loss, total = 0.0, 0.0 55 | total_correct, total_word = 0.0, 0.0 56 | 57 | # 进行测试 58 | model.eval() 59 | for step, batch in enumerate(dev_data_loader): 60 | with torch.no_grad(): 61 | if torch.cuda.is_available(): 62 | batch = (t.cuda() for t in batch) 63 | 64 | context_input_ids, context_seq_len, summary_input_ids, summary_seq_len = batch 65 | logits = model(context_input_ids, summary_input_ids, context_seq_len) 66 | loss = loss_func(logits, summary_input_ids, pad_id, smoothing=True) 67 | 68 | # 对loss进行累加 69 | total_loss += loss * context_input_ids.size(0) 70 | total += context_input_ids.size(0) 71 | 72 | n_correct, n_word = calc_acc(logits, summary_input_ids) 73 | total_correct += n_correct 74 | total_word += n_word 75 | # 计算最终测试集的loss和acc结果 76 | test_loss = total_loss / total 77 | test_acc = total_correct / total_word 78 | return test_loss, test_acc 79 | 80 | 81 | def loss_func(logits, labels, pad_id, smoothing=False): 82 | if smoothing: 83 | logit = logits[..., :-1, :].contiguous().view(-1, logits.size(2)) 84 | labels = labels[..., 1:].contiguous().view(-1) 85 | eps = 0.1 86 | n_class = logit.size(-1) 87 | one_hot = torch.zeros_like(logit).scatter(1, labels.view(-1, 1), 1) 88 | one_hot = one_hot * (1 - eps) + (1 - one_hot) * eps / (n_class - 1) 89 | log_prb = F.log_softmax(logit, dim=1) 90 | non_pad_mask = labels.ne(pad_id) 91 | loss = -(one_hot * log_prb).sum(dim=1) 92 | loss = loss.masked_select(non_pad_mask).mean() # average later 93 | else: 94 | logits = logits[..., :-1, :].contiguous().view(-1, logits.size(-1)) 95 | labels = labels[..., 1:].contiguous().view(-1) 96 | loss = F.cross_entropy(logits, labels, ignore_index=pad_id) 97 | return loss 98 | 99 | 100 | if __name__ == '__main__': 101 | args = set_args() 102 | set_seed(args) 103 | os.makedirs(args.output_dir, exist_ok=True) 104 | 105 | # 加载词表 106 | tokenizer = BasicTokenizer() 107 | vocab2id = json.load(open(args.vocab2id_path, 'r', encoding='utf8')) 108 | pad_id = vocab2id.get('') 109 | collate_fn = Collator(pad_id=pad_id, is_train=True) 110 | # 加载训练数据 111 | train_context = load_data(args.train_data_src) 112 | train_summary = load_data(args.train_data_tgt) 113 | train_dataset = SummaryDataset(context=train_context, summary=train_summary, tokenizer=tokenizer, vocab2id=vocab2id) 114 | train_data_loader = DataLoader(train_dataset, shuffle=True, batch_size=args.batch_size, collate_fn=collate_fn) 115 | 116 | # 加载验证集 117 | valid_context = load_data(args.valid_data_src) 118 | valid_summary = load_data(args.valid_data_tgt) 119 | valid_dataset = SummaryDataset(context=valid_context, summary=valid_summary, tokenizer=tokenizer, vocab2id=vocab2id) 120 | valid_data_loader = DataLoader(valid_dataset, shuffle=False, batch_size=args.batch_size, collate_fn=collate_fn) 121 | 122 | model = 
Model(vocab=vocab2id) 123 | 124 | if torch.cuda.is_available(): 125 | model.cuda() 126 | 127 | optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate) 128 | 129 | total_steps = len(train_data_loader) * args.num_train_epochs 130 | scheduler = get_linear_schedule_with_warmup(optimizer=optimizer, num_warmup_steps=0.05 * total_steps, 131 | num_training_steps=total_steps) 132 | 133 | 134 | tb_write = SummaryWriter() 135 | global_step = 0 136 | tr_loss, logging_loss, min_loss = 0.0, 0.0, 0.0 137 | for epoch in range(args.num_train_epochs): 138 | for step, batch in enumerate(train_data_loader): 139 | if torch.cuda.is_available(): 140 | batch = (t.cuda() for t in batch) 141 | context_input_ids, context_seq_len, summary_input_ids, summary_seq_len = batch 142 | logits = model(context_input_ids, summary_input_ids, context_seq_len) 143 | loss = loss_func(logits, summary_input_ids, pad_id, smoothing=False) 144 | tr_loss += loss.item() 145 | 146 | loss.backward() 147 | 148 | torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm) 149 | 150 | optimizer.step() 151 | scheduler.step() 152 | optimizer.zero_grad() 153 | 154 | global_step += 1 155 | # 如果步数整除logging_steps,则记录学习率和训练集损失值 156 | if args.logging_steps > 0 and global_step % args.logging_steps == 0: 157 | tb_write.add_scalar("train_loss", (tr_loss - logging_loss) / args.logging_steps, global_step) 158 | logging_loss = tr_loss 159 | 160 | loss = loss.item() 161 | logger.info('epoch:{}, step:{}, loss:{}'.format(epoch, step, round(loss, 4))) 162 | 163 | eval_loss, eval_acc = evaluate(valid_data_loader) 164 | tb_write.add_scalar("test_loss", eval_loss, global_step) 165 | tb_write.add_scalar("test_acc", eval_acc, global_step) 166 | print("test_loss: {}, test_acc:{}".format(eval_loss, eval_acc)) 167 | model.train() 168 | 169 | # 每个epoch进行完,则保存模型 170 | output_dir = os.path.join(args.output_dir, "epoch_{}.bin".format(epoch)) 171 | torch.save(model.state_dict(), output_dir) 172 | 173 | 174 | 175 | 176 | 177 | 178 | -------------------------------------------------------------------------------- /seq2seq_rnn/start.sh: -------------------------------------------------------------------------------- 1 | 2 | 3 | nohup python3 run_train.py > ./log.log 2>&1 & 4 | 5 | -------------------------------------------------------------------------------- /seq2seq_rnn_pointer_network/config.py: -------------------------------------------------------------------------------- 1 | """ 2 | @file : config.py 3 | @author : Henry 4 | @email : luxiaonlp@163.com 5 | @time : 2024-01-02 6 | """ 7 | import argparse 8 | 9 | 10 | def set_args(): 11 | parser = argparse.ArgumentParser() 12 | parser.add_argument('--seed', type=int, default=42, help='随机种子') 13 | parser.add_argument('--output_dir', default='output_dir/', type=str, help='模型输出路径') 14 | parser.add_argument('--train_data_src', default='../data/train.src.txt', type=str, help='训练文章') 15 | parser.add_argument('--train_data_tgt', default='../data/train.tgt.txt', type=str, help='训练摘要') 16 | parser.add_argument('--valid_data_src', default='../data/test.src.txt', type=str, help='验证文章') 17 | parser.add_argument('--valid_data_tgt', default='../data/test.tgt.txt', type=str, help='验证摘要') 18 | 19 | parser.add_argument('--test_data_src', default='../data/test.src.txt', type=str, help='测试文章') 20 | parser.add_argument('--test_data_tgt', default='../data/test.tgt.txt', type=str, help='测试摘要') 21 | parser.add_argument('--cov_lambda', default=1, type=int) 22 | 23 | parser.add_argument('--vocab2id_path', 
default='../data/vocab2id.json', type=str, help='词表') 24 | parser.add_argument('--emb_size', default=256, type=int, help='词嵌入大小') 25 | parser.add_argument('--hidden_size', default=512, type=int, help='隐层大小') 26 | parser.add_argument('--num_layers', default=2, type=int, help='层大小') 27 | parser.add_argument('--dropout', default=0.5, type=float, help='dropout') 28 | parser.add_argument('--num_train_epochs', default=50, type=int, help='模型训练的轮数') 29 | parser.add_argument('--batch_size', default=64, type=int, help='训练时每个batch的大小') 30 | parser.add_argument('--learning_rate', default=1e-4, type=float, help='模型训练时的学习率') 31 | parser.add_argument('--max_grad_norm', default=1.0, type=float, help='') 32 | parser.add_argument('--gradient_accumulation_steps', default=1, type=int, help='梯度积累') 33 | parser.add_argument('--pointer', default=False, type=bool, help='是否只用指针网络') 34 | parser.add_argument('--logging_steps', default=5, type=int, help='保存训练日志的步数') 35 | return parser.parse_args() 36 | -------------------------------------------------------------------------------- /seq2seq_rnn_pointer_network/data_helper.py: -------------------------------------------------------------------------------- 1 | """ 2 | @file : data_helper.py 3 | @author : Henry 4 | @email : luxiaonlp@163.com 5 | @time : 2024-01-02 6 | """ 7 | import torch 8 | from torch.utils.data import Dataset 9 | 10 | 11 | def load_data(path): 12 | data = [] 13 | with open(path, 'r', encoding='utf8') as f: 14 | lines = f.readlines() 15 | for line in lines: 16 | line = line.strip() 17 | data.append(line) 18 | return data 19 | 20 | 21 | class SummaryDataset(Dataset): 22 | def __init__(self, context, summary, tokenizer, vocab2id): 23 | self.context = context 24 | self.summary = summary 25 | self.tokenizer = tokenizer 26 | self.vocab2id = vocab2id 27 | 28 | def __len__(self): 29 | return len(self.context) 30 | 31 | def __getitem__(self, item): 32 | # 输入 33 | context = self.context[item] 34 | # context_token = self.tokenizer.tokenize(context) 35 | context_token = list(context) 36 | context_seq_len = len(context_token) 37 | context_input_ids = [] 38 | for word in context_token: 39 | context_input_ids.append(self.vocab2id.get(word, self.vocab2id.get(''))) 40 | 41 | # 输出 42 | summary = self.summary[item] 43 | # summary_token = self.tokenizer.tokenize(summary) 44 | summary_token = list(summary) 45 | summary_seq_len = len(summary_token) 46 | summary_input_ids = [] 47 | summary_input_ids.append(self.vocab2id.get('')) 48 | for word in summary_token: 49 | summary_input_ids.append(self.vocab2id.get(word, self.vocab2id.get(''))) 50 | summary_input_ids.append(self.vocab2id.get('')) 51 | return {'context_input_ids': context_input_ids, 'context_seq_len': context_seq_len, 52 | 'summary_input_ids': summary_input_ids, 'summary_seq_len': summary_seq_len} 53 | 54 | 55 | class Collator: 56 | def __init__(self, pad_id, is_train=True): 57 | self.pad_id = pad_id 58 | self.is_train = is_train 59 | 60 | def __call__(self, batch): 61 | context_max_len = max([d['context_seq_len'] for d in batch]) 62 | summary_max_len = max([d['summary_seq_len'] for d in batch]) 63 | 64 | if context_max_len > 256: 65 | context_max_len = 256 66 | if summary_max_len > 64: 67 | summary_max_len = 64 68 | 69 | context_input_ids_list, context_seq_len_list = [], [] 70 | summary_input_ids_list, summary_seq_len_list = [], [] 71 | for item in batch: 72 | context_input_ids_list.append(self.pad_to_maxlen(item['context_input_ids'], max_len=context_max_len)) 73 | 
summary_input_ids_list.append(self.pad_to_maxlen(item['summary_input_ids'], max_len=summary_max_len)) 74 | context_seq_len_list.append(item['context_seq_len']) 75 | summary_seq_len_list.append(item['summary_seq_len']) 76 | 77 | context_input_ids_tensor = torch.tensor(context_input_ids_list, dtype=torch.long) 78 | summary_input_ids_tensor = torch.tensor(summary_input_ids_list, dtype=torch.long) 79 | context_seq_len_tensor = torch.tensor(context_seq_len_list, dtype=torch.long) 80 | summary_seq_len_tensor = torch.tensor(summary_seq_len_list, dtype=torch.long) 81 | return context_input_ids_tensor, context_seq_len_tensor, summary_input_ids_tensor, summary_seq_len_tensor 82 | 83 | def pad_to_maxlen(self, input_ids, max_len): 84 | if len(input_ids) >= max_len: 85 | input_ids = input_ids[:max_len] 86 | else: 87 | input_ids = input_ids + [self.pad_id] * (max_len - len(input_ids)) 88 | return input_ids 89 | -------------------------------------------------------------------------------- /seq2seq_rnn_pointer_network/inference.py: -------------------------------------------------------------------------------- 1 | """ 2 | @file : inference.py 3 | @author : Henry 4 | @email : luxiaonlp@163.com 5 | @time : 2024-01-02 6 | """ 7 | import os 8 | import json 9 | import torch 10 | import random 11 | import pandas as pd 12 | from model import Model 13 | from rouge import Rouge 14 | from config import set_args 15 | from tqdm.contrib import tzip 16 | from data_helper import load_data 17 | from transformers.models.bert.tokenization_bert import BasicTokenizer 18 | 19 | 20 | # 模型预测过程 21 | def predict(model, vocab, text, max_len=60): 22 | model.eval() 23 | with torch.no_grad(): 24 | # context_token = tokenizer.tokenize(text) 25 | context_token = list(text) 26 | context_input_ids = [] 27 | for word in context_token: 28 | context_input_ids.append(vocab2id.get(word, vocab2id.get(''))) 29 | 30 | src_lengths = torch.tensor([len(context_input_ids)]) 31 | src_input_ids = torch.tensor([context_input_ids]) 32 | if torch.cuda.is_available(): 33 | src_lengths = src_lengths.cuda() 34 | src_input_ids = src_input_ids.cuda() 35 | 36 | 37 | src_emb = model.embedding(src_input_ids) 38 | 39 | coverage_vector = torch.zeros_like(src_input_ids, dtype=torch.float32).to(src_emb.device) 40 | 41 | enc_output, prev_hidden = model.encoder(src_emb, src_lengths) 42 | 43 | # enc_output, prev_hidden = model.encoder(src_input_ids, src_lengths) 44 | # print(enc_output.size()) # torch.Size([1, 94, 256]) 45 | decoder_input_ids = torch.tensor([vocab['']]).to(src_input_ids.device) 46 | result = [] 47 | for t in range(max_len): 48 | dec_input = decoder_input_ids.unsqueeze(1) 49 | decoder_emb = model.embedding(dec_input) 50 | 51 | 52 | dec_output, prev_hidden, attention_weights, p_gen, coverage_vector = model.decoder(decoder_emb, prev_hidden, 53 | enc_output, src_lengths, 54 | coverage_vector) 55 | 56 | final_distribution = model.get_final_distribution(src_input_ids, p_gen, dec_output, attention_weights, 0) 57 | max_ids = final_distribution.argmax(1) 58 | 59 | # _, max_ids = logits.max(dim=-1) 60 | 61 | # decoder_input_ids = torch.cat(max_ids, dim=-1) 62 | decoder_input_ids = torch.tensor(max_ids).to(src_input_ids.device) 63 | 64 | ids = max_ids.cpu().numpy().tolist()[0] 65 | 66 | if ids == vocab2id.get(''): 67 | break 68 | result.append(ids) 69 | return result 70 | 71 | 72 | if __name__ == '__main__': 73 | args = set_args() 74 | # tokenizer = BasicTokenizer() 75 | test_context = load_data(args.test_data_src) 76 | test_summary = 
load_data(args.test_data_tgt) 77 | vocab2id = json.load(open(args.vocab2id_path, 'r', encoding='utf8')) 78 | 79 | id2vocab = {id: vocab for vocab, id in vocab2id.items()} 80 | 81 | model = Model(vocab2id) 82 | model.load_state_dict(torch.load('./output_dir/epoch_19.bin', map_location='cpu')) 83 | 84 | if torch.cuda.is_available(): 85 | model.cuda() 86 | 87 | final_context, final_summary, final_gen_summary = [], [], [] 88 | for context, summary in tzip(test_context, test_summary): 89 | gen_summary = predict(model, vocab2id, context) 90 | gen_summary = ''.join([id2vocab[idx] for idx in gen_summary]) 91 | final_context.append(context) 92 | final_summary.append(summary) 93 | final_gen_summary.append(gen_summary) 94 | df = pd.DataFrame({'context': final_context, 'summary': final_summary, 'gen_summary': final_gen_summary}) 95 | df.to_csv('./result.csv', index=False) 96 | 97 | # 计算指标 98 | rouge = Rouge() 99 | hyps, refs = [], [] 100 | for context, summary, gen_summary in zip(df['context'], df['summary'], df['gen_summary']): 101 | refs.append(' '.join(list(summary))) 102 | hyps.append(' '.join(list(gen_summary))) 103 | scores = rouge.get_scores(hyps, refs, avg=True) 104 | print(scores) 105 | 106 | 107 | 108 | 109 | -------------------------------------------------------------------------------- /seq2seq_rnn_pointer_network/model.py: -------------------------------------------------------------------------------- 1 | """ 2 | @file : model.py 3 | @author : Henry 4 | @email : luxiaonlp@163.com 5 | @time : 2024-01-02 6 | """ 7 | import torch 8 | import random 9 | from torch import nn 10 | from config import set_args 11 | from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence 12 | 13 | 14 | args = set_args() 15 | 16 | 17 | class Encoder(nn.Module): 18 | def __init__(self, vocab): 19 | super(Encoder, self).__init__() 20 | self.vocab_size = len(vocab) 21 | self.gru = nn.GRU(args.emb_size, args.hidden_size, num_layers=args.num_layers, batch_first=True, dropout=args.dropout) 22 | self.dropout = nn.Dropout(args.dropout) 23 | self.linear = nn.Linear(args.hidden_size, args.hidden_size) 24 | self.relu = nn.ReLU() 25 | 26 | def forward(self, enc_input, text_lengths): 27 | text_lengths = text_lengths.to("cpu") 28 | # print(text_lengths) 29 | # embedded = self.dropout(self.embedding(enc_input)) # [batch_size, seq_len, emb_size] 30 | embedded = self.dropout(enc_input) 31 | embedded = pack_padded_sequence(embedded, text_lengths, batch_first=True, enforce_sorted=False) 32 | output, hidden = self.gru(embedded) 33 | output, _ = pad_packed_sequence(output, batch_first=True) 34 | output = self.relu(self.linear(output)) 35 | return output, hidden[-1].detach() 36 | 37 | 38 | class Attention(nn.Module): 39 | def __init__(self): 40 | super(Attention, self).__init__() 41 | self.linear = nn.Linear(args.hidden_size * 2 + args.emb_size, args.hidden_size) 42 | self.v = nn.Linear(args.hidden_size, 1) 43 | self.softmax = nn.Softmax(dim=-1) 44 | 45 | def forward(self, dec_input, enc_output, text_lengths, coverage_vector): 46 | # print(dec_input.size(), enc_output.size(), text_lengths.size(), coverage_vector.size()) 47 | # torch.Size([8, 1, 128]) torch.Size([8, 107, 256]) torch.Size([8]) torch.Size([8, 107]) 48 | 49 | seq_len, hidden_size = enc_output.size(1), enc_output.size(-1) 50 | 51 | s = dec_input.repeat(1, seq_len, 1) # [batch_size, seq_len, hidden_size] torch.Size([8, 107, 128]) 52 | # print(s.size()) # torch.Size([8, 107, 128]) 53 | coverage_vector_copy = coverage_vector.unsqueeze(2).repeat(1, 1, 
hidden_size) 54 | 55 | x = torch.tanh(self.linear(torch.cat([enc_output, s, coverage_vector_copy], dim=2))) 56 | # print(x.size()) # torch.Size([8, 107, 256]) 57 | 58 | attention = self.v(x).squeeze(-1) # [batch_size, seq_len] 59 | max_len = enc_output.size(1) 60 | 61 | text_lengths = text_lengths.to('cpu') 62 | mask = torch.arange(max_len).expand(text_lengths.shape[0], max_len) >= text_lengths.unsqueeze(1) 63 | 64 | attention.masked_fill_(mask.to(dec_input.device), float('-inf')) 65 | attention_weights = self.softmax(attention) 66 | # 更新 coverage_vector 67 | coverage_vector += attention_weights 68 | return attention_weights, coverage_vector # [batch, seq_len], [batch_size, seq_len] 69 | 70 | 71 | class Decoder(nn.Module): 72 | def __init__(self, vocab, attention): 73 | super(Decoder, self).__init__() 74 | self.vocab_size = len(vocab) 75 | # self.embedding = nn.Embedding(self.vocab_size, args.emb_size, padding_idx=vocab['']) 76 | self.attention = attention 77 | self.gru = nn.GRU(args.emb_size + args.hidden_size, args.hidden_size, batch_first=True) 78 | self.linear = nn.Linear(args.emb_size + 2 * args.hidden_size, self.vocab_size) 79 | self.dropout = nn.Dropout(args.dropout) 80 | 81 | # PGN 网络 计算p_gen 82 | self.w_gen = nn.Linear(args.hidden_size * 2 + args.emb_size, 1) 83 | 84 | def forward(self, dec_input, prev_hidden, enc_output, text_lengths, coverage_vector): 85 | # print(prev_hidden.size()) # torch.Size([8, 256]) 86 | # print(enc_output.size()) # torch.Size([8, 107, 256]) 87 | # embedded = self.embedding(dec_input) 88 | embedded = dec_input 89 | 90 | attention_weights, coverage_vector = self.attention(embedded, enc_output, text_lengths, coverage_vector) 91 | # print(attention_weights.size()) # torch.Size([8, 107]) 92 | # print(coverage_vector.size()) # torch.Size([8, 107]) 93 | 94 | attention_weights = attention_weights.unsqueeze(1) # [batch_size, 1, enc_len] 95 | c = torch.bmm(attention_weights, enc_output) # [batch_size, 1, hidden_size] 96 | 97 | # 将经过 embedding 处理过的 decoder 输入,和上下文向量一起送入到 GRU 网络中 98 | gru_input = torch.cat([embedded, c], dim=2) 99 | # print(gru_input.size()) # torch.Size([8, 1, 384]) 100 | 101 | # prev_hidden 是上个时间步的隐状态,作为 decoder 的参数传入进来 102 | dec_output, dec_hidden = self.gru(gru_input, prev_hidden.unsqueeze(0)) 103 | dec_output = self.linear(torch.cat((dec_output.squeeze(1), c.squeeze(1), embedded.squeeze(1)), dim=1)) # [batch_size, vocab_size] 104 | dec_hidden = dec_hidden.squeeze(0) 105 | x_gen = torch.cat([dec_hidden, c.squeeze(1), embedded.squeeze(1)], dim=1) 106 | p_gen = torch.sigmoid(self.w_gen(x_gen)) 107 | return dec_output, dec_hidden, attention_weights.squeeze(1), p_gen, coverage_vector 108 | 109 | 110 | class Model(nn.Module): 111 | # PGN网络实现 112 | def __init__(self, vocab): 113 | super(Model, self).__init__() 114 | self.vocab = vocab 115 | self.vocab_size = len(vocab) 116 | self.embedding = nn.Embedding(self.vocab_size, args.emb_size, padding_idx=vocab[""]) 117 | 118 | self.encoder = Encoder(vocab) 119 | 120 | attention = Attention() 121 | self.decoder = Decoder(vocab, attention) 122 | 123 | def get_final_distribution(self, x, p_gen, p_vocab, attention_weights, max_oov): 124 | batch_size = x.shape[0] 125 | p_gen = torch.clamp(p_gen, 0.001, 0.999) 126 | p_vocab_weighted = p_gen * p_vocab 127 | attention_weighted = (1 - p_gen) * attention_weights 128 | # 加入 max_oov 维度,将原文中的 OOV 单词考虑进来 129 | extension = torch.zeros((batch_size, max_oov), dtype=torch.float).to(x.device) 130 | p_vocab_extended = torch.cat([p_vocab_weighted, extension], dim=-1) 131 | # 
p_gen * p_vocab + (1 - p_gen) * attention_weights, 将 attention weights 中的每个位置 idx 映射成该位置的 token_id 132 | final_distribution = p_vocab_extended.scatter_add_(dim=1, index=x, src=attention_weighted) 133 | # 输出最终的 vocab distribution [batch_size, vocab_size + len(oov)] 134 | return final_distribution 135 | 136 | def forward(self, src, tgt, src_lengths, teacher_forcing_ratio=0.5): 137 | src_emb = self.embedding(src) 138 | # batch_size, max_len 139 | enc_output, prev_hidden = self.encoder(src_emb, src_lengths) 140 | 141 | # print(enc_output.size()) # torch.Size([8, 115, 256]) 142 | # print(prev_hidden.size()) # torch.Size([8, 256]) 143 | batch_size = tgt.size(0) 144 | tgt_len = tgt.size(1) 145 | 146 | dec_input = tgt[:, 0] 147 | dec_outputs = torch.zeros(batch_size, tgt_len, len(self.vocab)).to(src.device) 148 | coverage_vector = torch.zeros_like(src, dtype=torch.float32).to(src.device) 149 | 150 | for t in range(tgt_len - 1): 151 | dec_input = dec_input.unsqueeze(1) # torch.Size([8, 1]) 152 | dec_emb = self.embedding(dec_input) 153 | 154 | # dec_output, prev_hidden, _ = self.decoder(dec_emb, prev_hidden, enc_output, src_lengths) 155 | dec_output, prev_hidden, attention_weights, p_gen, coverage_vector = self.decoder(dec_emb, prev_hidden, 156 | enc_output, src_lengths, 157 | coverage_vector) 158 | final_distribution = self.get_final_distribution(src, p_gen, dec_output, attention_weights, 0) 159 | teacher_force = random.random() < teacher_forcing_ratio 160 | dec_outputs[:, t, :] = final_distribution 161 | top1 = final_distribution.argmax(1) 162 | dec_input = tgt[:, t] if teacher_force else top1 163 | return dec_outputs, attention_weights, coverage_vector 164 | 165 | -------------------------------------------------------------------------------- /seq2seq_rnn_pointer_network/run_train.py: -------------------------------------------------------------------------------- 1 | """ 2 | @file : run_train.py 3 | @author : Henry 4 | @email : luxiaonlp@163.com 5 | @time : 2024-01-02 6 | """ 7 | import os 8 | import json 9 | import torch 10 | import random 11 | import numpy as np 12 | from model import Model 13 | from torch import nn 14 | from config import set_args 15 | import torch.nn.functional as F 16 | from transformers import get_linear_schedule_with_warmup 17 | from torch.utils.data import DataLoader 18 | from data_helper import load_data, SummaryDataset, Collator 19 | from transformers.models.bert.tokenization_bert import BasicTokenizer 20 | from loguru import logger 21 | try: 22 | from torch.utils.tensorboard import SummaryWriter 23 | except ImportError: 24 | from tensorboard import SummaryWriter 25 | 26 | # python -m tensorboard.main --logdir=./runs --host=127.0.0.1 27 | 28 | def set_seed(args): 29 | torch.manual_seed(args.seed) 30 | random.seed(args.seed) 31 | np.random.seed(args.seed) 32 | 33 | 34 | logger.add('./logger/log.log') 35 | 36 | def calc_acc(logits, decoder_input_ids): 37 | decoder_attention_mask = torch.ne(decoder_input_ids, 0).to(logits.device) 38 | mask = decoder_attention_mask.view(-1).eq(1) 39 | labels = decoder_input_ids.view(-1) 40 | # print(mask.size()) # torch.Size([168]) 41 | # print(labels.size()) # torch.Size([168]) 42 | # print(logits.size()) # torch.Size([8, 21, 10002]) 43 | logits = logits.contiguous().view(-1, logits.size(-1)) 44 | # print(logits.size()) # torch.Size([168, 10002]) 45 | 46 | _, logits = logits.max(dim=-1) 47 | # print(logits.size()) # torch.Size([168]) 48 | n_correct = logits.eq(labels).masked_select(mask).sum().item() 49 | n_word = mask.sum().item() 50 | 
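    # Editor's note: n_correct / n_word is a token-level accuracy over non-pad positions
    # (the mask above treats id 0 as padding, via torch.ne(decoder_input_ids, 0)).
    # Hypothetical example: labels = [5, 9, 0, 0], greedy argmax = [5, 2, 0, 0]
    #   -> the mask keeps the first two positions, n_correct = 1, n_word = 2, accuracy = 0.5.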
return n_correct, n_word 51 | 52 | 53 | def evaluate(dev_data_loader): 54 | total_loss, total = 0.0, 0.0 55 | total_correct, total_word = 0.0, 0.0 56 | 57 | # 进行测试 58 | model.eval() 59 | for step, batch in enumerate(dev_data_loader): 60 | with torch.no_grad(): 61 | if torch.cuda.is_available(): 62 | batch = (t.cuda() for t in batch) 63 | 64 | context_input_ids, context_seq_len, summary_input_ids, summary_seq_len = batch 65 | # logits = model(context_input_ids, summary_input_ids, context_seq_len) 66 | logits, attention_weights, coverage_vector = model(context_input_ids, summary_input_ids, context_seq_len) 67 | loss = loss_func(logits, summary_input_ids, pad_id, smoothing=False) 68 | 69 | 70 | # 对loss进行累加 71 | total_loss += loss * context_input_ids.size(0) 72 | total += context_input_ids.size(0) 73 | 74 | n_correct, n_word = calc_acc(logits, summary_input_ids) 75 | total_correct += n_correct 76 | total_word += n_word 77 | # 计算最终测试集的loss和acc结果 78 | test_loss = total_loss / total 79 | test_acc = total_correct / total_word 80 | return test_loss, test_acc 81 | 82 | 83 | def loss_func(logits, labels, pad_id, smoothing=False): 84 | if smoothing: 85 | logit = logits[..., :-1, :].contiguous().view(-1, logits.size(2)) 86 | labels = labels[..., 1:].contiguous().view(-1) 87 | eps = 0.1 88 | n_class = logit.size(-1) 89 | one_hot = torch.zeros_like(logit).scatter(1, labels.view(-1, 1), 1) 90 | one_hot = one_hot * (1 - eps) + (1 - one_hot) * eps / (n_class - 1) 91 | log_prb = F.log_softmax(logit, dim=1) 92 | non_pad_mask = labels.ne(pad_id) 93 | loss = -(one_hot * log_prb).sum(dim=1) 94 | loss = loss.masked_select(non_pad_mask).mean() # average later 95 | else: 96 | logits = logits[..., :-1, :].contiguous().view(-1, logits.size(-1)) 97 | labels = labels[..., 1:].contiguous().view(-1) 98 | loss = F.cross_entropy(logits, labels, ignore_index=pad_id) 99 | return loss 100 | 101 | 102 | if __name__ == '__main__': 103 | args = set_args() 104 | set_seed(args) 105 | os.makedirs(args.output_dir, exist_ok=True) 106 | 107 | # 加载词表 108 | tokenizer = BasicTokenizer() 109 | vocab2id = json.load(open(args.vocab2id_path, 'r', encoding='utf8')) 110 | pad_id = vocab2id.get('') 111 | collate_fn = Collator(pad_id=pad_id, is_train=True) 112 | # 加载训练数据 113 | train_context = load_data(args.train_data_src) 114 | train_summary = load_data(args.train_data_tgt) 115 | train_dataset = SummaryDataset(context=train_context, summary=train_summary, tokenizer=tokenizer, vocab2id=vocab2id) 116 | train_data_loader = DataLoader(train_dataset, shuffle=True, batch_size=args.batch_size, collate_fn=collate_fn) 117 | 118 | # 加载验证集 119 | valid_context = load_data(args.valid_data_src) 120 | valid_summary = load_data(args.valid_data_tgt) 121 | valid_dataset = SummaryDataset(context=valid_context, summary=valid_summary, tokenizer=tokenizer, vocab2id=vocab2id) 122 | valid_data_loader = DataLoader(valid_dataset, shuffle=False, batch_size=args.batch_size, collate_fn=collate_fn) 123 | 124 | model = Model(vocab=vocab2id) 125 | 126 | if torch.cuda.is_available(): 127 | model.cuda() 128 | 129 | optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate) 130 | 131 | total_steps = len(train_data_loader) * args.num_train_epochs 132 | scheduler = get_linear_schedule_with_warmup(optimizer=optimizer, num_warmup_steps=0.05 * total_steps, 133 | num_training_steps=total_steps) 134 | 135 | tb_write = SummaryWriter() 136 | global_step = 0 137 | tr_loss, logging_loss, min_loss = 0.0, 0.0, 0.0 138 | for epoch in range(args.num_train_epochs): 139 | for step, 
batch in enumerate(train_data_loader): 140 | if torch.cuda.is_available(): 141 | batch = (t.cuda() for t in batch) 142 | context_input_ids, context_seq_len, summary_input_ids, summary_seq_len = batch 143 | logits, attention_weights, coverage_vector = model(context_input_ids, summary_input_ids, context_seq_len) 144 | 145 | ce_loss = loss_func(logits, summary_input_ids, pad_id, smoothing=False) 146 | c_t = torch.min(attention_weights, coverage_vector) 147 | cov_loss = torch.mean(torch.sum(c_t, dim=1)) 148 | # 计算整体 loss 149 | loss = ce_loss + args.cov_lambda * cov_loss 150 | 151 | tr_loss += loss.item() 152 | loss.backward() 153 | 154 | torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm) 155 | 156 | optimizer.step() 157 | scheduler.step() 158 | optimizer.zero_grad() 159 | 160 | global_step += 1 161 | # 如果步数整除logging_steps,则记录学习率和训练集损失值 162 | if args.logging_steps > 0 and global_step % args.logging_steps == 0: 163 | tb_write.add_scalar("train_loss", (tr_loss - logging_loss) / args.logging_steps, global_step) 164 | logging_loss = tr_loss 165 | # print('epoch:{}, step:{}, loss:{}'.format(epoch, step, round(loss.item(), 4))) 166 | logger.info('epoch:{}, step:{}, loss:{}'.format(epoch, step, round(loss.item(), 4))) 167 | 168 | eval_loss, eval_acc = evaluate(valid_data_loader) 169 | tb_write.add_scalar("test_loss", eval_loss, global_step) 170 | tb_write.add_scalar("test_acc", eval_acc, global_step) 171 | print("test_loss: {}, test_acc:{}".format(eval_loss, eval_acc)) 172 | model.train() 173 | 174 | # 每个epoch进行完,则保存模型 175 | output_dir = os.path.join(args.output_dir, "epoch_{}.bin".format(epoch)) 176 | torch.save(model.state_dict(), output_dir) 177 | 178 | 179 | 180 | 181 | 182 | 183 | -------------------------------------------------------------------------------- /seq2seq_rnn_pointer_network/start.sh: -------------------------------------------------------------------------------- 1 | 2 | 3 | nohup python run_train.py > ./log.log 2>&1 & 4 | 5 | -------------------------------------------------------------------------------- /seq2seq_transformer/config.py: -------------------------------------------------------------------------------- 1 | """ 2 | @file : config.py 3 | @author : Henry 4 | @email : luxiaonlp@163.com 5 | @time : 2024-01-06 6 | """ 7 | import argparse 8 | 9 | 10 | def set_args(): 11 | parser = argparse.ArgumentParser('--rnn gen') 12 | parser.add_argument('--epochs', default=50, type=int, help='训练几轮') 13 | parser.add_argument('--learning_rate', default=1e-3, type=int, help='学习率') 14 | parser.add_argument('--batch_size', default=64, type=int, help='训练的批次大小') 15 | 16 | parser.add_argument('--train_data_src', default='../data/train.src.txt', type=str, help='训练文章') 17 | parser.add_argument('--train_data_tgt', default='../data/train.tgt.txt', type=str, help='训练摘要') 18 | parser.add_argument('--valid_data_src', default='../data/test.src.txt', type=str, help='验证文章') 19 | parser.add_argument('--valid_data_tgt', default='../data/test.tgt.txt', type=str, help='验证摘要') 20 | parser.add_argument('--test_data_src', default='../data/test.src.txt', type=str, help='测试文章') 21 | parser.add_argument('--test_data_tgt', default='../data/test.tgt.txt', type=str, help='测试摘要') 22 | parser.add_argument('--vocab2id_path', default='../data/vocab2id.json', type=str, help='词表') 23 | parser.add_argument('--logging_steps', default=5, type=int) 24 | 25 | # clip 26 | parser.add_argument('--clip', default=2.0, type=float, help='梯度裁剪') 27 | parser.add_argument('--seed', default=43, type=int, 
help='随机种子大小') 28 | parser.add_argument('--output_dir', default='./output', type=str, help='模型输出路径') 29 | parser.add_argument('--gradient_accumulation_steps', default=1, type=int, help='梯度积聚') 30 | 31 | parser.add_argument('--hidden_size', default=1024, type=int, help='隐层大小') 32 | parser.add_argument('--d_model', default=512, type=int, help='') 33 | parser.add_argument('--heads', default=8, type=int, help='') 34 | parser.add_argument('--num_layers', default=6, type=int, help='') 35 | return parser.parse_args() 36 | -------------------------------------------------------------------------------- /seq2seq_transformer/data_helper.py: -------------------------------------------------------------------------------- 1 | """ 2 | @file : data_helper.py 3 | @author : Henry 4 | @email : luxiaonlp@163.com 5 | @time : 2024-01-06 6 | """ 7 | import torch 8 | from torch.utils.data import Dataset 9 | 10 | 11 | def load_data(path): 12 | data = [] 13 | with open(path, 'r', encoding='utf8') as f: 14 | lines = f.readlines() 15 | for line in lines: 16 | line = line.strip() 17 | data.append(line) 18 | return data 19 | 20 | 21 | class SummaryDataset(Dataset): 22 | def __init__(self, context, summary, tokenizer, vocab2id): 23 | self.context = context 24 | self.summary = summary 25 | self.tokenizer = tokenizer 26 | self.vocab2id = vocab2id 27 | 28 | def __len__(self): 29 | return len(self.context) 30 | 31 | def __getitem__(self, item): 32 | # 输入 33 | context = self.context[item] 34 | # context_token = self.tokenizer.tokenize(context) 35 | context_token = list(context) 36 | context_seq_len = len(context_token) 37 | context_input_ids = [] 38 | for word in context_token: 39 | context_input_ids.append(self.vocab2id.get(word, self.vocab2id.get(''))) 40 | 41 | # 输出 42 | summary = self.summary[item] 43 | 44 | summary_token = list(summary) 45 | summary_seq_len = len(summary_token) 46 | summary_input_ids = [] 47 | 48 | summary_input_ids.append(self.vocab2id.get('')) 49 | for word in summary_token: 50 | summary_input_ids.append(self.vocab2id.get(word, self.vocab2id.get(''))) 51 | summary_input_ids.append(self.vocab2id.get('')) 52 | return {'context_input_ids': context_input_ids, 'context_seq_len': context_seq_len, 53 | 'summary_input_ids': summary_input_ids, 'summary_seq_len': summary_seq_len} 54 | 55 | 56 | class Collator: 57 | def __init__(self, pad_id, is_train=True): 58 | self.pad_id = pad_id 59 | self.is_train = is_train 60 | 61 | def __call__(self, batch): 62 | context_max_len = max([d['context_seq_len'] for d in batch]) 63 | summary_max_len = max([d['summary_seq_len'] for d in batch]) 64 | 65 | if context_max_len > 256: 66 | context_max_len = 256 67 | if summary_max_len > 64: 68 | summary_max_len = 64 69 | 70 | context_input_ids_list, context_seq_len_list = [], [] 71 | summary_input_ids_list, summary_seq_len_list = [], [] 72 | for item in batch: 73 | context_input_ids_list.append(self.pad_to_maxlen(item['context_input_ids'], max_len=context_max_len)) 74 | summary_input_ids_list.append(self.pad_to_maxlen(item['summary_input_ids'], max_len=summary_max_len)) 75 | context_seq_len_list.append(item['context_seq_len']) 76 | summary_seq_len_list.append(item['summary_seq_len']) 77 | 78 | context_input_ids_tensor = torch.tensor(context_input_ids_list, dtype=torch.long) 79 | summary_input_ids_tensor = torch.tensor(summary_input_ids_list, dtype=torch.long) 80 | context_seq_len_tensor = torch.tensor(context_seq_len_list, dtype=torch.long) 81 | summary_seq_len_tensor = torch.tensor(summary_seq_len_list, dtype=torch.long) 82 | 
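        # Editor's note: for a batch of size B the collator returns, in order,
        #   context_input_ids_tensor [B, context_max_len], context_seq_len_tensor [B],
        #   summary_input_ids_tensor [B, summary_max_len], summary_seq_len_tensor [B],
        # where context_max_len / summary_max_len are the per-batch maxima capped at 256 / 64.
        # The *_seq_len tensors keep the original lengths recorded in __getitem__, i.e. before
        # any truncation applied here.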
return context_input_ids_tensor, context_seq_len_tensor, summary_input_ids_tensor, summary_seq_len_tensor 83 | 84 | def pad_to_maxlen(self, input_ids, max_len): 85 | if len(input_ids) >= max_len: 86 | input_ids = input_ids[:max_len] 87 | else: 88 | input_ids = input_ids + [self.pad_id] * (max_len - len(input_ids)) 89 | return input_ids 90 | -------------------------------------------------------------------------------- /seq2seq_transformer/inference.py: -------------------------------------------------------------------------------- 1 | """ 2 | @file : inference.py 3 | @author : Henry 4 | @email : luxiaonlp@163.com 5 | @time : 2024-01-06 6 | """ 7 | import json 8 | import torch 9 | import torch.utils.data 10 | from model import Transformer 11 | from rouge import Rouge 12 | from config import set_args 13 | from tqdm.contrib import tzip 14 | from data_helper import load_data 15 | import torch.nn.functional as F 16 | import pandas as pd 17 | 18 | 19 | 20 | 21 | @torch.no_grad() 22 | def predict(model, vocab2id, context): 23 | input_text = str(context) 24 | input_ids = [vocab2id.get(v, vocab2id['']) for v in list(input_text)] 25 | input_ids = torch.tensor([input_ids], dtype=torch.long) 26 | input_mask = (input_ids != 0).unsqueeze(1).unsqueeze(1) 27 | 28 | start_token = vocab2id[''] 29 | if torch.cuda.is_available(): 30 | input_ids, input_mask = input_ids.cuda(), input_mask.cuda() 31 | 32 | encoded = model.encode(input_ids, input_mask) 33 | 34 | words = torch.tensor([[vocab2id['']]]) 35 | for step in range(max_len): 36 | size = words.size(1) 37 | # 下三角矩阵 38 | target_mask = torch.triu(torch.ones(size, size)).transpose(0, 1).type(dtype=torch.uint8) 39 | target_mask = target_mask.unsqueeze(0).unsqueeze(0) 40 | 41 | # 下三角 42 | if torch.cuda.is_available(): 43 | words, target_mask = words.cuda(), target_mask.cuda() 44 | 45 | decoded = model.decode(words, target_mask, encoded, input_mask) 46 | predictions = model.logit(decoded[:, -1]) 47 | _, next_word = torch.max(predictions, dim=1) 48 | 49 | next_word = next_word.item() 50 | if next_word == vocab2id['']: 51 | break 52 | words = torch.cat([words, torch.LongTensor([[next_word]]).cuda()], dim=1) # (1,step+2) 53 | 54 | # [1, 128] => [128] .squeeze(0) 55 | if words.dim() == 2: 56 | words = words.squeeze(0) 57 | words = words.tolist() 58 | 59 | sen_idx = [w for w in words if w not in {vocab2id['']}] 60 | sentence = ''.join([id2vocab[sen_idx[k]] for k in range(len(sen_idx))]) 61 | return sentence 62 | 63 | 64 | 65 | if __name__ == '__main__': 66 | args = set_args() 67 | test_context = load_data(args.test_data_src)[:100] 68 | test_summary = load_data(args.test_data_tgt)[:100] 69 | vocab2id = json.load(open(args.vocab2id_path, 'r', encoding='utf8')) 70 | max_len = 60 71 | 72 | id2vocab = {id: vocab for vocab, id in vocab2id.items()} 73 | model = Transformer(d_model=args.d_model, heads=args.heads, num_layers=args.num_layers, word_map=vocab2id) 74 | model.load_state_dict(torch.load('./output/epoch_{}.bin'.format(0), map_location='cpu')) 75 | if torch.cuda.is_available(): 76 | model.cuda() 77 | 78 | model.eval() 79 | 80 | final_context, final_summary, final_gen_summary = [], [], [] 81 | for context, summary in tzip(test_context, test_summary): 82 | gen_summary = predict(model, vocab2id, context) 83 | final_context.append(context) 84 | final_summary.append(summary) 85 | final_gen_summary.append(gen_summary) 86 | df = pd.DataFrame({'context': final_context, 'summary': final_summary, 'gen_summary': final_gen_summary}) 87 | df.to_csv('./result.csv', index=False) 
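    # Editor's note: predict() above is greedy decoding. The target self-attention mask is a
    # lower-triangular matrix, so position t can only attend to positions <= t; at each step
    # the single most probable next token is appended, and decoding stops at the
    # end-of-sequence id or after max_len steps. A quick sketch of that mask:
    #     >>> size = 3
    #     >>> torch.triu(torch.ones(size, size)).transpose(0, 1).type(torch.uint8)
    #     tensor([[1, 0, 0],
    #             [1, 1, 0],
    #             [1, 1, 1]], dtype=torch.uint8)
    # The ROUGE scores computed below are character-level: reference and hypothesis
    # characters are space-joined so the `rouge` package, which tokenizes on whitespace,
    # compares individual characters.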
88 | 89 | # 计算指标 90 | rouge = Rouge() 91 | hyps, refs = [], [] 92 | for context, summary, gen_summary in zip(df['context'], df['summary'], df['gen_summary']): 93 | refs.append(' '.join(list(summary))) 94 | hyps.append(' '.join(list(gen_summary))) 95 | scores = rouge.get_scores(hyps, refs, avg=True) 96 | print(scores) 97 | 98 | 99 | 100 | -------------------------------------------------------------------------------- /seq2seq_transformer/model.py: -------------------------------------------------------------------------------- 1 | """ 2 | @file : model.py 3 | @author : Henry 4 | @email : luxiaonlp@163.com 5 | @time : 2024-01-06 6 | """ 7 | import torch 8 | import torch.nn as nn 9 | import math 10 | import torch.nn.functional as F 11 | from torch.autograd import Variable 12 | 13 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 14 | 15 | 16 | class Embeddings(nn.Module): 17 | """ 18 | Implements embeddings of the words and adds their positional encodings. 19 | """ 20 | def __init__(self, vocab_size, d_model, max_len=256): 21 | super(Embeddings, self).__init__() 22 | self.d_model = d_model 23 | self.dropout = nn.Dropout(0.1) 24 | self.embed = nn.Embedding(vocab_size, d_model) 25 | self.pe = self.create_positinal_encoding(max_len, self.d_model) 26 | self.dropout = nn.Dropout(0.1) 27 | 28 | def create_positinal_encoding(self, max_len, d_model): 29 | pe = torch.zeros(max_len, d_model).to(device) 30 | for pos in range(max_len): # for each position of the word 31 | for i in range(0, d_model, 2): # for each dimension of the each position 32 | pe[pos, i] = math.sin(pos / (10000 ** ((2 * i) / d_model))) 33 | pe[pos, i + 1] = math.cos(pos / (10000 ** ((2 * (i + 1)) / d_model))) 34 | pe = pe.unsqueeze(0) # include the batch size 35 | return pe 36 | 37 | def forward(self, encoded_words): 38 | embedding = self.embed(encoded_words) * math.sqrt(self.d_model) 39 | embedding += self.pe[:, :embedding.size(1)] # pe will automatically be expanded with the same batch size as encoded_words 40 | embedding = self.dropout(embedding) 41 | return embedding 42 | 43 | 44 | class MultiHeadAttention(nn.Module): 45 | def __init__(self, heads, d_model): 46 | super(MultiHeadAttention, self).__init__() 47 | assert d_model % heads == 0 48 | self.d_k = d_model // heads 49 | self.heads = heads 50 | self.dropout = nn.Dropout(0.1) 51 | self.query = nn.Linear(d_model, d_model) 52 | self.key = nn.Linear(d_model, d_model) 53 | self.value = nn.Linear(d_model, d_model) 54 | self.concat = nn.Linear(d_model, d_model) 55 | 56 | def forward(self, query, key, value, mask): 57 | """ 58 | query, key, value of shape: (batch_size, max_len, 512) 59 | mask of shape: (batch_size, 1, 1, max_words) 60 | """ 61 | query = self.query(query) 62 | key = self.key(key) 63 | value = self.value(value) 64 | # print(query.size(), key.size(), value.size()) 65 | # torch.Size([2, 8, 512]) torch.Size([2, 8, 512]) torch.Size([2, 8, 512]) 66 | # query: (batch_size, max_len, 512) -> (batch_size, max_len, 8, 64) -> (batch_size, 8, max_len, 64) 67 | query = query.view(query.shape[0], -1, self.heads, self.d_k).permute(0, 2, 1, 3) 68 | key = key.view(key.shape[0], -1, self.heads, self.d_k).permute(0, 2, 1, 3) 69 | value = value.view(value.shape[0], -1, self.heads, self.d_k).permute(0, 2, 1, 3) 70 | # (batch_size, 8, max_len, 64) 71 | # (batch_size, 8, max_len, max_len) 72 | 73 | scores = torch.matmul(query, key.permute(0, 1, 3, 2)) / math.sqrt(query.size(-1)) 74 | # (batch_size, 8, max_len, max_len) 75 | 76 | # print(mask.size()) # torch.Size([2, 
1, 1, 14]) 77 | # [1, 1, 1, 0, 0, 0] 78 | scores = scores.masked_fill(mask == 0, -1e9) # (batch_size, h, max_len, max_len) 79 | weights = F.softmax(scores, dim=-1) # (batch_size, h, max_len, max_len) 80 | 81 | weights = self.dropout(weights) 82 | # print(weights.size()) # torch.Size([8, 8, 118, 118]) 83 | 84 | # (batch_size, h, max_len, max_len) matmul (batch_size, h, max_len, d_k) --> (batch_size, h, max_len, d_k) 85 | context = torch.matmul(weights, value) 86 | # (batch_size, h, max_len, d_k) --> (batch_size, max_len, h, d_k) --> (batch_size, max_len, h * d_k) 87 | context = context.permute(0, 2, 1, 3).contiguous().view(context.shape[0], -1, self.heads * self.d_k) 88 | # batch_size, max_len, hidden_size 89 | # print(context.size()) # torch.Size([8, 9, 512]) 90 | interacted = self.concat(context) 91 | return interacted, weights 92 | 93 | 94 | class FeedForward(nn.Module): 95 | def __init__(self, d_model, middle_dim=2048): 96 | super(FeedForward, self).__init__() 97 | self.fc1 = nn.Linear(d_model, middle_dim) 98 | self.fc2 = nn.Linear(middle_dim, d_model) 99 | self.dropout = nn.Dropout(0.1) 100 | 101 | def forward(self, x): 102 | out = F.relu(self.fc1(x)) 103 | out = self.fc2(self.dropout(out)) 104 | return out 105 | 106 | 107 | class EncoderLayer(nn.Module): 108 | def __init__(self, d_model, heads): 109 | super(EncoderLayer, self).__init__() 110 | self.layernorm = nn.LayerNorm(d_model) # d_model 111 | self.self_multihead = MultiHeadAttention(heads, d_model) 112 | self.feed_forward = FeedForward(d_model) 113 | self.dropout = nn.Dropout(0.1) 114 | 115 | def forward(self, embeddings, mask): 116 | # embeddings: batch_size, max_len, 512 117 | x, attn = self.self_multihead(embeddings, embeddings, embeddings, mask) 118 | interacted = self.dropout(x) 119 | # batch_size, max_len, hidden_size 120 | 121 | interacted = self.layernorm(interacted + embeddings) # 残差+归一化 122 | feed_forward_out = self.dropout(self.feed_forward(interacted)) 123 | # feed_forward_out: batch_size, max_len, hidden_size 124 | encoded = self.layernorm(feed_forward_out + interacted) 125 | return encoded 126 | 127 | 128 | class DecoderLayer(nn.Module): 129 | def __init__(self, d_model, heads): 130 | super(DecoderLayer, self).__init__() 131 | self.layernorm = nn.LayerNorm(d_model) 132 | self.self_multihead = MultiHeadAttention(heads, d_model) 133 | self.src_multihead = MultiHeadAttention(heads, d_model) 134 | self.feed_forward = FeedForward(d_model) 135 | self.dropout = nn.Dropout(0.1) 136 | 137 | def forward(self, embeddings, encoded, src_mask, target_mask): 138 | x, _ = self.self_multihead(embeddings, embeddings, embeddings, target_mask) 139 | query = self.dropout(x) 140 | query = self.layernorm(query + embeddings) 141 | 142 | x, attn = self.src_multihead(query, encoded, encoded, src_mask) 143 | interacted = self.dropout(x) 144 | interacted = self.layernorm(interacted + query) 145 | feed_forward_out = self.dropout(self.feed_forward(interacted)) 146 | decoded = self.layernorm(feed_forward_out + interacted) 147 | return decoded, attn 148 | 149 | 150 | class Transformer(nn.Module): 151 | def __init__(self, d_model, heads, num_layers, word_map): 152 | ''' 153 | d_model: 512 154 | heads: 8 155 | num_layers: 6 156 | word_map: vocab2id 157 | ''' 158 | super(Transformer, self).__init__() 159 | self.d_model = d_model 160 | self.vocab_size = len(word_map) 161 | self.embed = Embeddings(self.vocab_size, d_model) 162 | self.encoder = nn.ModuleList([EncoderLayer(d_model, heads) for _ in range(num_layers)]) 163 | self.decoder = 
nn.ModuleList([DecoderLayer(d_model, heads) for _ in range(num_layers)]) 164 | self.logit = nn.Linear(d_model, self.vocab_size) 165 | 166 | # Pointer work 指针网络 167 | self.switch = nn.Linear(self.vocab_size, 1) 168 | 169 | def encode(self, src_words, src_mask): 170 | src_embeddings = self.embed(src_words) 171 | # print(src_embeddings.size()) # (batch_size,max_len,512) 172 | 173 | for layer in self.encoder: 174 | src_embeddings = layer(src_embeddings, src_mask) 175 | return src_embeddings 176 | 177 | def decode(self, target_words, target_mask, src_embeddings, src_mask): 178 | tgt_embeddings = self.embed(target_words) 179 | # print(tgt_embeddings.size()) 180 | 181 | for layer in self.decoder: 182 | tgt_embeddings, attention = layer(tgt_embeddings, src_embeddings, src_mask, target_mask) 183 | 184 | return tgt_embeddings 185 | 186 | 187 | def forward(self, src_words, src_mask, target_words, target_mask): 188 | encoded = self.encode(src_words, src_mask) 189 | # print(encoded.size()) # batch_size, max_len, hidden_size 190 | 191 | # print(encoded.size()) # torch.Size([2, 6, 512]) 192 | # [SOS, xx, xxx, ] -》[2, xx,xx] 193 | 194 | out = self.decode(target_words, target_mask, encoded, src_mask) 195 | out = F.log_softmax(self.logit(out), dim=2) 196 | return out 197 | 198 | 199 | -------------------------------------------------------------------------------- /seq2seq_transformer/run_train.py: -------------------------------------------------------------------------------- 1 | """ 2 | @file : run_train.py 3 | @author : Henry 4 | @email : luxiaonlp@163.com 5 | @time : 2024-01-06 6 | """ 7 | import os 8 | import json 9 | import numpy as np 10 | from torch import nn 11 | import torch.utils.data 12 | from config import set_args 13 | from model import Transformer 14 | from loguru import logger 15 | import torch.nn.functional as F 16 | from torch.utils.data import DataLoader 17 | from transformers import BasicTokenizer 18 | from data_helper import SummaryDataset, load_data, Collator 19 | try: 20 | from torch.utils.tensorboard import SummaryWriter 21 | except ImportError: 22 | from tensorboard import SummaryWriter 23 | 24 | logger.add('./logger/log.log') 25 | 26 | 27 | 28 | def create_masks(question, reply_input, reply_target): 29 | def subsequent_mask(size): 30 | mask = torch.triu(torch.ones(size, size)).transpose(0, 1).type(dtype=torch.uint8) 31 | return mask.unsqueeze(0) 32 | question_mask = (question!=0) 33 | question_mask = question_mask.unsqueeze(1).unsqueeze(1) # (batch_size, 1, 1, max_words) 34 | 35 | reply_input_mask = reply_input!=0 36 | reply_input_mask = reply_input_mask.unsqueeze(1) # (batch_size, 1, max_words) 37 | reply_input_mask = reply_input_mask & subsequent_mask(reply_input.size(-1)).type_as(reply_input_mask.data) 38 | reply_input_mask = reply_input_mask.unsqueeze(1) # (batch_size, 1, max_words, max_words) 39 | reply_target_mask = reply_target!=0 # (batch_size, max_words) 40 | 41 | return question_mask, reply_input_mask, reply_target_mask 42 | 43 | 44 | class AdamWarmup: 45 | def __init__(self, model_size, warmup_steps, optimizer): 46 | self.model_size = model_size 47 | self.warmup_steps = warmup_steps 48 | self.optimizer = optimizer 49 | self.current_step = 0 50 | self.lr = 0 51 | 52 | def get_lr(self): 53 | return self.model_size ** (-0.5) * min(self.current_step ** (-0.5), 54 | self.current_step * self.warmup_steps ** (-1.5)) 55 | 56 | def step(self): 57 | # Increment the number of steps each time we call the step function 58 | self.current_step += 1 59 | lr = self.get_lr() 60 | for 
param_group in self.optimizer.param_groups: 61 | param_group['lr'] = lr 62 | # update the learning rate 63 | self.lr = lr 64 | self.optimizer.step() 65 | 66 | 67 | class LossWithLS(nn.Module): 68 | def __init__(self, size, smooth): 69 | super(LossWithLS, self).__init__() 70 | self.criterion = nn.KLDivLoss(size_average=False, reduce=False) 71 | self.confidence = 1.0 - smooth 72 | self.smooth = smooth 73 | self.size = size 74 | 75 | def forward(self, prediction, target, mask): 76 | """ 77 | prediction of shape: (batch_size, max_words, vocab_size) 78 | target and mask of shape: (batch_size, max_words) 79 | """ 80 | # prediction: batch_size, max_len, vocab_size 81 | prediction = prediction.view(-1, prediction.size(-1)) # (batch_size * max_words, vocab_size) 82 | # batch_size*max_len, vocab_size 83 | # batch_size*max_len 84 | target = target.contiguous().view(-1) # (batch_size * max_words) 85 | mask = mask.float() 86 | mask = mask.view(-1) # (batch_size * max_words) 87 | labels = prediction.data.clone() 88 | 89 | # 平滑 90 | labels.fill_(self.smooth / (self.size - 1)) 91 | labels.scatter_(1, target.data.unsqueeze(1), self.confidence) 92 | loss = self.criterion(prediction, labels) # (batch_size * max_words, vocab_size) 93 | loss = (loss.sum(1) * mask).sum() / mask.sum() 94 | return loss 95 | 96 | 97 | def calc_acc(logits, decoder_input_ids): 98 | decoder_attention_mask = torch.ne(decoder_input_ids, 0).to(logits.device) 99 | mask = decoder_attention_mask.view(-1).eq(1) 100 | labels = decoder_input_ids.reshape(-1) 101 | # labels = decoder_input_ids.view(-1) 102 | logits = logits.contiguous().view(-1, logits.size(-1)) 103 | 104 | _, logits = logits.max(dim=-1) 105 | n_correct = logits.eq(labels).masked_select(mask).sum().item() 106 | n_word = mask.sum().item() 107 | return n_correct, n_word 108 | 109 | 110 | def evaluate(valid_data_loader): 111 | model.eval() 112 | total_loss, total = 0.0, 0.0 113 | total_correct, total_word = 0.0, 0.0 114 | # 进行测试 115 | model.eval() 116 | for step, batch in enumerate(valid_data_loader): 117 | with torch.no_grad(): 118 | if torch.cuda.is_available(): 119 | batch = (t.cuda() for t in batch) 120 | 121 | s_input_ids, context_seq_len, t_input_ids, summary_seq_len = batch 122 | t_input_ids_input = t_input_ids[:, :-1] # [SOS ...] 123 | t_input_ids_output = t_input_ids[:, 1:] # [... 
EOS] 124 | 125 | # 自己看 126 | s_input_ids_mask, t_input_ids_input_mask, t_input_ids_output_mask = create_masks(s_input_ids, 127 | t_input_ids_input, 128 | t_input_ids_output) 129 | if torch.cuda.is_available(): 130 | s_input_ids_mask, t_input_ids_input_mask, t_input_ids_output_mask = s_input_ids_mask.cuda(), t_input_ids_input_mask.cuda(), t_input_ids_output_mask.cuda() 131 | output = model(s_input_ids, s_input_ids_mask, t_input_ids_input, t_input_ids_input_mask) 132 | loss = loss_func(output, t_input_ids_output, t_input_ids_output_mask) 133 | # 对loss进行累加 134 | total_loss += loss.item() * s_input_ids.size(0) 135 | total += s_input_ids.size(0) 136 | 137 | n_correct, n_word = calc_acc(output, t_input_ids_output) 138 | total_correct += n_correct 139 | total_word += n_word 140 | # 计算最终测试集的loss和acc结果 141 | test_loss = total_loss / total 142 | test_acc = total_correct / total_word 143 | return test_loss, test_acc 144 | 145 | 146 | if __name__ == '__main__': 147 | args = set_args() 148 | os.makedirs(args.output_dir, exist_ok=True) 149 | tokenizer = BasicTokenizer() 150 | vocab2id = json.load(open(args.vocab2id_path, 'r', encoding='utf8')) 151 | 152 | pad_id = vocab2id.get('') 153 | collate_fn = Collator(pad_id=pad_id, is_train=True) 154 | 155 | # 加载训练数据 156 | train_context = load_data(args.train_data_src) 157 | train_summary = load_data(args.train_data_tgt) 158 | train_dataset = SummaryDataset(context=train_context, summary=train_summary, tokenizer=tokenizer, vocab2id=vocab2id) 159 | train_data_loader = DataLoader(train_dataset, shuffle=True, batch_size=args.batch_size, collate_fn=collate_fn) 160 | 161 | # 加载验证集 162 | valid_context = load_data(args.valid_data_src) 163 | valid_summary = load_data(args.valid_data_tgt) 164 | valid_dataset = SummaryDataset(context=valid_context, summary=valid_summary, tokenizer=tokenizer, vocab2id=vocab2id) 165 | valid_data_loader = DataLoader(valid_dataset, shuffle=False, batch_size=args.batch_size, collate_fn=collate_fn) 166 | 167 | model = Transformer(d_model=args.d_model, heads=args.heads, num_layers=args.num_layers, word_map=vocab2id) 168 | if torch.cuda.is_available(): 169 | model.cuda() 170 | 171 | # 优化 172 | adam_optimizer = torch.optim.Adam(model.parameters(), lr=0, betas=(0.9, 0.98), eps=1e-9) 173 | transformer_optimizer = AdamWarmup(model_size=args.d_model, warmup_steps=4000, optimizer=adam_optimizer) 174 | 175 | # 损失函数 176 | loss_func = LossWithLS(len(vocab2id), 0.1) 177 | tb_write = SummaryWriter() 178 | global_step = 0 179 | tr_loss, logging_loss, min_loss = 0.0, 0.0, 0.0 180 | for epoch in range(args.epochs): 181 | loss_lists = [] 182 | for step, batch in enumerate(train_data_loader): 183 | if torch.cuda.is_available(): 184 | batch = (t.cuda() for t in batch) 185 | s_input_ids, context_seq_len, t_input_ids, summary_seq_len = batch 186 | t_input_ids_input = t_input_ids[:, :-1] # [SOS ...] 187 | t_input_ids_output = t_input_ids[:, 1:] # [... 
EOS] 188 | 189 | # 自己看 190 | s_input_ids_mask, t_input_ids_input_mask, t_input_ids_output_mask = create_masks(s_input_ids, t_input_ids_input, t_input_ids_output) 191 | 192 | if torch.cuda.is_available(): 193 | s_input_ids_mask, t_input_ids_input_mask, t_input_ids_output_mask = s_input_ids_mask.cuda(), t_input_ids_input_mask.cuda(), t_input_ids_output_mask.cuda() 194 | 195 | output = model(s_input_ids, s_input_ids_mask, t_input_ids_input, t_input_ids_input_mask) 196 | loss = loss_func(output, t_input_ids_output, t_input_ids_output_mask) 197 | 198 | tr_loss += loss.item() 199 | 200 | transformer_optimizer.optimizer.zero_grad() 201 | loss.backward() 202 | transformer_optimizer.step() 203 | 204 | global_step += 1 205 | # 如果步数整除logging_steps,则记录学习率和训练集损失值 206 | if args.logging_steps > 0 and global_step % args.logging_steps == 0: 207 | tb_write.add_scalar("train_loss", (tr_loss - logging_loss) / args.logging_steps, global_step) 208 | logging_loss = tr_loss 209 | loss = loss.item() 210 | logger.info('epoch:{}, step:{}, loss:{}'.format(epoch, step, round(loss, 4))) 211 | 212 | eval_loss, eval_acc = evaluate(valid_data_loader) 213 | tb_write.add_scalar("test_loss", eval_loss, global_step) 214 | tb_write.add_scalar("test_acc", eval_acc, global_step) 215 | print("test_loss: {}, test_acc:{}".format(eval_loss, eval_acc)) 216 | model.train() 217 | 218 | # 每个epoch进行完,则保存模型 219 | output_dir = os.path.join(args.output_dir, "epoch_{}.bin".format(epoch)) 220 | torch.save(model.state_dict(), output_dir) 221 | 222 | -------------------------------------------------------------------------------- /seq2seq_transformer/start.sh: -------------------------------------------------------------------------------- 1 | 2 | 3 | nohup python run_train.py > ./log.log 2>&1 & 4 | -------------------------------------------------------------------------------- /seq2seq_transformer_pointer_network/config.py: -------------------------------------------------------------------------------- 1 | """ 2 | @file : config.py 3 | @author : Henry 4 | @email : luxiaonlp@163.com 5 | @time : 2024-01-06 6 | """ 7 | import argparse 8 | 9 | 10 | def set_args(): 11 | parser = argparse.ArgumentParser('--rnn gen') 12 | parser.add_argument('--epochs', default=50, type=int, help='训练几轮') 13 | parser.add_argument('--learning_rate', default=1e-3, type=int, help='学习率') 14 | parser.add_argument('--batch_size', default=64, type=int, help='训练的批次大小') 15 | 16 | parser.add_argument('--train_data_src', default='../data/train.src.txt', type=str, help='训练文章') 17 | parser.add_argument('--train_data_tgt', default='../data/train.tgt.txt', type=str, help='训练摘要') 18 | parser.add_argument('--valid_data_src', default='../data/test.src.txt', type=str, help='验证文章') 19 | parser.add_argument('--valid_data_tgt', default='../data/test.tgt.txt', type=str, help='验证摘要') 20 | parser.add_argument('--test_data_src', default='../data/test.src.txt', type=str, help='测试文章') 21 | parser.add_argument('--test_data_tgt', default='../data/test.tgt.txt', type=str, help='测试摘要') 22 | parser.add_argument('--vocab2id_path', default='../data/vocab2id.json', type=str, help='词表') 23 | parser.add_argument('--logging_steps', default=5, type=int) 24 | 25 | # clip 26 | parser.add_argument('--clip', default=2.0, type=float, help='梯度裁剪') 27 | parser.add_argument('--seed', default=43, type=int, help='随机种子大小') 28 | parser.add_argument('--output_dir', default='./output', type=str, help='模型输出路径') 29 | parser.add_argument('--gradient_accumulation_steps', default=1, type=int, help='梯度积聚') 30 | 31 | 
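    # Editor's note: d_model must be divisible by heads (MultiHeadAttention asserts this), so
    # with the defaults below each attention head works on d_model // heads = 512 // 8 = 64
    # dimensions, and num_layers stacks that many EncoderLayer / DecoderLayer blocks.
    # Also note that --learning_rate above is declared type=int while its default 1e-3 is a
    # float; the default still works because argparse only applies type conversion to string
    # defaults, but passing --learning_rate on the command line would fail to parse as int.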
parser.add_argument('--hidden_size', default=1024, type=int, help='隐层大小') 32 | parser.add_argument('--d_model', default=512, type=int, help='') 33 | parser.add_argument('--heads', default=8, type=int, help='') 34 | parser.add_argument('--num_layers', default=6, type=int, help='') 35 | return parser.parse_args() 36 | -------------------------------------------------------------------------------- /seq2seq_transformer_pointer_network/data_helper.py: -------------------------------------------------------------------------------- 1 | """ 2 | @file : data_helper.py 3 | @author : Henry 4 | @email : luxiaonlp@163.com 5 | @time : 2024-01-06 6 | """ 7 | import torch 8 | from torch.utils.data import Dataset 9 | 10 | 11 | def load_data(path): 12 | data = [] 13 | with open(path, 'r', encoding='utf8') as f: 14 | lines = f.readlines() 15 | for line in lines: 16 | line = line.strip() 17 | data.append(line) 18 | return data 19 | 20 | 21 | class SummaryDataset(Dataset): 22 | def __init__(self, context, summary, tokenizer, vocab2id): 23 | self.context = context 24 | self.summary = summary 25 | self.tokenizer = tokenizer 26 | self.vocab2id = vocab2id 27 | 28 | def __len__(self): 29 | return len(self.context) 30 | 31 | def __getitem__(self, item): 32 | # 输入 33 | context = self.context[item] 34 | # context_token = self.tokenizer.tokenize(context) 35 | context_token = list(context) 36 | context_seq_len = len(context_token) 37 | context_input_ids = [] 38 | for word in context_token: 39 | context_input_ids.append(self.vocab2id.get(word, self.vocab2id.get(''))) 40 | 41 | # 输出 42 | summary = self.summary[item] 43 | # summary_token = self.tokenizer.tokenize(summary) 44 | summary_token = list(summary) 45 | summary_seq_len = len(summary_token) 46 | summary_input_ids = [] 47 | summary_input_ids.append(self.vocab2id.get('')) 48 | for word in summary_token: 49 | summary_input_ids.append(self.vocab2id.get(word, self.vocab2id.get(''))) 50 | summary_input_ids.append(self.vocab2id.get('')) 51 | return {'context_input_ids': context_input_ids, 'context_seq_len': context_seq_len, 52 | 'summary_input_ids': summary_input_ids, 'summary_seq_len': summary_seq_len} 53 | 54 | 55 | class Collator: 56 | def __init__(self, pad_id, is_train=True): 57 | self.pad_id = pad_id 58 | self.is_train = is_train 59 | 60 | def __call__(self, batch): 61 | context_max_len = max([d['context_seq_len'] for d in batch]) 62 | summary_max_len = max([d['summary_seq_len'] for d in batch]) 63 | 64 | if context_max_len > 256: 65 | context_max_len = 256 66 | if summary_max_len > 64: 67 | summary_max_len = 64 68 | 69 | context_input_ids_list, context_seq_len_list = [], [] 70 | summary_input_ids_list, summary_seq_len_list = [], [] 71 | for item in batch: 72 | context_input_ids_list.append(self.pad_to_maxlen(item['context_input_ids'], max_len=context_max_len)) 73 | summary_input_ids_list.append(self.pad_to_maxlen(item['summary_input_ids'], max_len=summary_max_len)) 74 | context_seq_len_list.append(item['context_seq_len']) 75 | summary_seq_len_list.append(item['summary_seq_len']) 76 | 77 | context_input_ids_tensor = torch.tensor(context_input_ids_list, dtype=torch.long) 78 | summary_input_ids_tensor = torch.tensor(summary_input_ids_list, dtype=torch.long) 79 | context_seq_len_tensor = torch.tensor(context_seq_len_list, dtype=torch.long) 80 | summary_seq_len_tensor = torch.tensor(summary_seq_len_list, dtype=torch.long) 81 | return context_input_ids_tensor, context_seq_len_tensor, summary_input_ids_tensor, summary_seq_len_tensor 82 | 83 | def pad_to_maxlen(self, 
input_ids, max_len): 84 | if len(input_ids) >= max_len: 85 | input_ids = input_ids[:max_len] 86 | else: 87 | input_ids = input_ids + [self.pad_id] * (max_len - len(input_ids)) 88 | return input_ids 89 | -------------------------------------------------------------------------------- /seq2seq_transformer_pointer_network/inference.py: -------------------------------------------------------------------------------- 1 | """ 2 | @file : inference.py 3 | @author : Henry 4 | @email : luxiaonlp@163.com 5 | @time : 2024-01-06 6 | """ 7 | import json 8 | import torch 9 | import torch.utils.data 10 | from model import Transformer 11 | from rouge import Rouge 12 | from config import set_args 13 | from tqdm.contrib import tzip 14 | from data_helper import load_data 15 | import pandas as pd 16 | 17 | 18 | 19 | def predict(model, vocab2id, context): 20 | input_text = str(context) 21 | input_ids = [vocab2id.get(v, vocab2id['']) for v in list(input_text)] 22 | input_ids = torch.tensor([input_ids], dtype=torch.long) 23 | input_mask = (input_ids != 0).unsqueeze(1).unsqueeze(1) 24 | 25 | start_token = vocab2id[''] 26 | if torch.cuda.is_available(): 27 | input_ids, input_mask = input_ids.cuda(), input_mask.cuda() 28 | 29 | encoded = model.encode(input_ids, input_mask) 30 | words = torch.tensor([[vocab2id['']]]) 31 | for step in range(max_len): 32 | size = words.size(1) 33 | # 下三角矩阵 34 | target_mask = torch.triu(torch.ones(size, size)).transpose(0, 1).type(dtype=torch.uint8) 35 | target_mask = target_mask.unsqueeze(0).unsqueeze(0) 36 | 37 | if torch.cuda.is_available(): 38 | words, target_mask = words.cuda(), target_mask.cuda() 39 | predictions = model.decode(words, target_mask, encoded, input_mask, input_ids, is_pointer=False) 40 | # predictions = model.decode(words, target_mask, encoded, input_mask, input_ids) 41 | predictions = predictions[:, -1] 42 | _, next_word = torch.max(predictions, dim=1) 43 | 44 | next_word = next_word.item() 45 | if next_word == vocab2id['']: 46 | break 47 | # words = torch.cat([words, torch.LongTensor([[next_word]]).cuda()], dim=1) # (1,step+2) 48 | words = torch.cat([words, torch.LongTensor([[next_word]]).cuda()], dim=1) # (1,step+2) 49 | 50 | # [1, 128] => [128] .squeeze(0) 51 | if words.dim() == 2: 52 | words = words.squeeze(0) 53 | words = words.tolist() 54 | 55 | sen_idx = [w for w in words if w not in {vocab2id['']}] 56 | sentence = ''.join([id2vocab[sen_idx[k]] for k in range(len(sen_idx))]) 57 | return sentence 58 | 59 | 60 | 61 | if __name__ == '__main__': 62 | args = set_args() 63 | test_context = load_data(args.test_data_src)[:100] 64 | test_summary = load_data(args.test_data_tgt)[:100] 65 | vocab2id = json.load(open(args.vocab2id_path, 'r', encoding='utf8')) 66 | max_len = 60 67 | 68 | id2vocab = {id: vocab for vocab, id in vocab2id.items()} 69 | model = Transformer(d_model=args.d_model, heads=args.heads, num_layers=args.num_layers, word_map=vocab2id) 70 | model.load_state_dict(torch.load('./output/epoch_{}.bin'.format(0), map_location='cpu')) 71 | # epoch_5.bin 72 | if torch.cuda.is_available(): 73 | model.cuda() 74 | 75 | final_context, final_summary, final_gen_summary = [], [], [] 76 | for context, summary in tzip(test_context, test_summary): 77 | print(context) 78 | print(summary) 79 | gen_summary = predict(model, vocab2id, context) 80 | print(gen_summary) 81 | exit() 82 | final_context.append(context) 83 | final_summary.append(summary) 84 | final_gen_summary.append(gen_summary) 85 | df = pd.DataFrame({'context': final_context, 'summary': final_summary, 
'gen_summary': final_gen_summary}) 86 | df.to_csv('./result.csv', index=False) 87 | 88 | # 计算指标 89 | rouge = Rouge() 90 | hyps, refs = [], [] 91 | for context, summary, gen_summary in zip(df['context'], df['summary'], df['gen_summary']): 92 | refs.append(' '.join(list(summary))) 93 | hyps.append(' '.join(list(gen_summary))) 94 | scores = rouge.get_scores(hyps, refs, avg=True) 95 | print(scores) 96 | 97 | 98 | 99 | -------------------------------------------------------------------------------- /seq2seq_transformer_pointer_network/model.py: -------------------------------------------------------------------------------- 1 | """ 2 | @file : model.py 3 | @author : Henry 4 | @email : luxiaonlp@163.com 5 | @time : 2024-01-06 6 | """ 7 | import torch 8 | import torch.nn as nn 9 | import math 10 | import torch.nn.functional as F 11 | from torch.autograd import Variable 12 | 13 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 14 | 15 | 16 | class Embeddings(nn.Module): 17 | """ 18 | Implements embeddings of the words and adds their positional encodings. 19 | """ 20 | def __init__(self, vocab_size, d_model, max_len=256): 21 | super(Embeddings, self).__init__() 22 | self.d_model = d_model 23 | self.dropout = nn.Dropout(0.1) 24 | self.embed = nn.Embedding(vocab_size, d_model) 25 | self.pe = self.create_positinal_encoding(max_len, self.d_model) 26 | self.dropout = nn.Dropout(0.1) 27 | 28 | def create_positinal_encoding(self, max_len, d_model): 29 | pe = torch.zeros(max_len, d_model).to(device) 30 | for pos in range(max_len): # for each position of the word 31 | for i in range(0, d_model, 2): # for each dimension of the each position 32 | pe[pos, i] = math.sin(pos / (10000 ** ((2 * i) / d_model))) 33 | pe[pos, i + 1] = math.cos(pos / (10000 ** ((2 * (i + 1)) / d_model))) 34 | pe = pe.unsqueeze(0) # include the batch size 35 | return pe 36 | 37 | def forward(self, encoded_words): 38 | embedding = self.embed(encoded_words) * math.sqrt(self.d_model) 39 | embedding += self.pe[:, :embedding.size(1)] # pe will automatically be expanded with the same batch size as encoded_words 40 | embedding = self.dropout(embedding) 41 | return embedding 42 | 43 | 44 | class MultiHeadAttention(nn.Module): 45 | def __init__(self, heads, d_model): 46 | super(MultiHeadAttention, self).__init__() 47 | assert d_model % heads == 0 48 | self.d_k = d_model // heads 49 | self.heads = heads 50 | self.dropout = nn.Dropout(0.1) 51 | self.query = nn.Linear(d_model, d_model) 52 | self.key = nn.Linear(d_model, d_model) 53 | self.value = nn.Linear(d_model, d_model) 54 | self.concat = nn.Linear(d_model, d_model) 55 | 56 | def forward(self, query, key, value, mask): 57 | """ 58 | query, key, value of shape: (batch_size, max_len, 512) 59 | mask of shape: (batch_size, 1, 1, max_words) 60 | """ 61 | query = self.query(query) 62 | key = self.key(key) 63 | value = self.value(value) 64 | # print(query.size(), key.size(), value.size()) 65 | # torch.Size([2, 8, 512]) torch.Size([2, 8, 512]) torch.Size([2, 8, 512]) 66 | # query: (batch_size, max_len, 512) -> (batch_size, max_len, 8, 64) -> (batch_size, 8, max_len, 64) 67 | query = query.view(query.shape[0], -1, self.heads, self.d_k).permute(0, 2, 1, 3) 68 | key = key.view(key.shape[0], -1, self.heads, self.d_k).permute(0, 2, 1, 3) 69 | value = value.view(value.shape[0], -1, self.heads, self.d_k).permute(0, 2, 1, 3) 70 | # (batch_size, 8, max_len, 64) 71 | # (batch_size, 8, max_len, max_len) 72 | 73 | scores = torch.matmul(query, key.permute(0, 1, 3, 2)) / 
        # print(mask.size())    # torch.Size([2, 1, 1, 14])
        # mask pattern along the key dimension, e.g. [1, 1, 1, 0, 0, 0]
        scores = scores.masked_fill(mask == 0, -1e9)    # (batch_size, h, max_len, max_len)
        weights = F.softmax(scores, dim=-1)   # (batch_size, h, max_len, max_len)

        weights = self.dropout(weights)
        # print(weights.size())   # torch.Size([8, 8, 118, 118])

        # (batch_size, h, max_len, max_len) matmul (batch_size, h, max_len, d_k) --> (batch_size, h, max_len, d_k)
        context = torch.matmul(weights, value)
        # (batch_size, h, max_len, d_k) --> (batch_size, max_len, h, d_k) --> (batch_size, max_len, h * d_k)
        context = context.permute(0, 2, 1, 3).contiguous().view(context.shape[0], -1, self.heads * self.d_k)
        # context: (batch_size, max_len, hidden_size)
        # print(context.size())    # torch.Size([8, 9, 512])
        interacted = self.concat(context)
        return interacted, weights


class FeedForward(nn.Module):
    def __init__(self, d_model, middle_dim=2048):
        super(FeedForward, self).__init__()
        self.fc1 = nn.Linear(d_model, middle_dim)
        self.fc2 = nn.Linear(middle_dim, d_model)
        self.dropout = nn.Dropout(0.1)

    def forward(self, x):
        out = F.relu(self.fc1(x))
        out = self.fc2(self.dropout(out))
        return out


class EncoderLayer(nn.Module):
    def __init__(self, d_model, heads):
        super(EncoderLayer, self).__init__()
        self.layernorm = nn.LayerNorm(d_model)
        self.self_multihead = MultiHeadAttention(heads, d_model)
        self.feed_forward = FeedForward(d_model)
        self.dropout = nn.Dropout(0.1)

    def forward(self, embeddings, mask):
        # embeddings: (batch_size, max_len, 512)
        x, attn = self.self_multihead(embeddings, embeddings, embeddings, mask)
        interacted = self.dropout(x)
        # interacted: (batch_size, max_len, hidden_size)

        interacted = self.layernorm(interacted + embeddings)   # residual connection + layer norm
        feed_forward_out = self.dropout(self.feed_forward(interacted))
        # feed_forward_out: (batch_size, max_len, hidden_size)
        encoded = self.layernorm(feed_forward_out + interacted)
        return encoded


class DecoderLayer(nn.Module):
    def __init__(self, d_model, heads):
        super(DecoderLayer, self).__init__()
        self.layernorm = nn.LayerNorm(d_model)
        self.self_multihead = MultiHeadAttention(heads, d_model)
        self.src_multihead = MultiHeadAttention(heads, d_model)
        self.feed_forward = FeedForward(d_model)
        self.dropout = nn.Dropout(0.1)

    def forward(self, embeddings, encoded, src_mask, target_mask):
        # masked self-attention over the already generated target tokens
        x, _ = self.self_multihead(embeddings, embeddings, embeddings, target_mask)
        query = self.dropout(x)
        query = self.layernorm(query + embeddings)

        # cross-attention over the encoder output; attn is reused by the pointer mechanism
        x, attn = self.src_multihead(query, encoded, encoded, src_mask)
        interacted = self.dropout(x)
        interacted = self.layernorm(interacted + query)
        feed_forward_out = self.dropout(self.feed_forward(interacted))
        decoded = self.layernorm(feed_forward_out + interacted)
        return decoded, attn


class Transformer(nn.Module):
    def __init__(self, d_model, heads, num_layers, word_map):
        '''
        d_model: 512
        heads: 8
        num_layers: 6
        word_map: vocab2id
        '''
        super(Transformer, self).__init__()
        self.d_model = d_model
        self.vocab_size = len(word_map)
        self.embed = Embeddings(self.vocab_size, d_model)
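        # A single embedding table (with sinusoidal positions) is shared between the
        # encoder and the decoder: encode() and decode() below both go through self.embed,
        # which works here because source and target texts use the same vocabulary.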
        self.encoder = nn.ModuleList([EncoderLayer(d_model, heads) for _ in range(num_layers)])
        self.decoder = nn.ModuleList([DecoderLayer(d_model, heads) for _ in range(num_layers)])
        self.logit = nn.Linear(d_model, self.vocab_size)

        # pointer network: produces the probability of copying from the source
        self.switch = nn.Linear(self.vocab_size, 1)

    def encode(self, src_words, src_mask):
        src_embeddings = self.embed(src_words)
        # print(src_embeddings.size())   # (batch_size, max_len, 512)

        for layer in self.encoder:
            src_embeddings = layer(src_embeddings, src_mask)
        return src_embeddings

    def decode(self, target_words, target_mask, src_embeddings, src_mask, src_words, is_pointer=True):
        tgt_embeddings = self.embed(target_words)
        # print(tgt_embeddings.size())

        for layer in self.decoder:
            tgt_embeddings, attention = layer(tgt_embeddings, src_embeddings, src_mask, target_mask)

        decoded = self.logit(tgt_embeddings)
        if is_pointer:
            # copy probability derived from the vocabulary logits (sigmoid / 10, so at most 0.1)
            x = self.switch(decoded)
            p_pointer = torch.sigmoid(x) / 10

            # if the source contains ids beyond the output vocabulary, extend the logits with zeros
            if torch.max(src_words) + 1 > decoded.shape[-1]:
                extended = Variable(
                    torch.zeros((decoded.shape[0], decoded.shape[1], torch.max(src_words) + 1 - decoded.shape[-1]))).to(
                    decoded.device)
                decoded = torch.cat((decoded, extended), dim=2)

            # mix the generation and copy distributions: (1 - p_pointer) * softmax(logits)
            # plus p_pointer * cross-attention weights of one head, scattered onto the source token ids
            output = ((1 - p_pointer) * F.softmax(decoded, dim=2)).scatter_add(
                2, src_words.unsqueeze(1).repeat(1, decoded.shape[1], 1),
                p_pointer * attention[:, 3]) + 1e-10
            out = torch.log(output)
            return out
        else:
            out = F.log_softmax(decoded, dim=2)
            return out

    def forward(self, src_words, src_mask, target_words, target_mask):
        encoded = self.encode(src_words, src_mask)
        # print(encoded.size())    # (batch_size, max_len, hidden_size), e.g. torch.Size([2, 6, 512])

        # target_words starts with the SOS token (teacher forcing)
        out = self.decode(target_words, target_mask, encoded, src_mask, src_words)
        return out


--------------------------------------------------------------------------------
/seq2seq_transformer_pointer_network/run_train.py:
--------------------------------------------------------------------------------
"""
@file : run_train.py
@author : Henry
@email : luxiaonlp@163.com
@time : 2024-01-06
"""
import os
import json
import numpy as np
from torch import nn
import torch.utils.data
from config import set_args
from model import Transformer
from loguru import logger
import torch.nn.functional as F
from torch.utils.data import DataLoader
from transformers import BasicTokenizer
from data_helper import SummaryDataset, load_data, Collator
try:
    from torch.utils.tensorboard import SummaryWriter
except ImportError:
    from tensorboard import SummaryWriter

logger.add('./logger/log.log')


def create_masks(question, reply_input, reply_target):
    def subsequent_mask(size):
        # lower-triangular causal mask
        mask = torch.triu(torch.ones(size, size)).transpose(0, 1).type(dtype=torch.uint8)
        return mask.unsqueeze(0)

    question_mask = (question != 0)
    # batch_size, max_len = question.size()
    # question_mask = question_mask.unsqueeze(1).expand(batch_size, max_len, max_len).unsqueeze(1)
    question_mask = question_mask.unsqueeze(1).unsqueeze(1)   # (batch_size, 1, 1, max_words)

    reply_input_mask = reply_input != 0
    reply_input_mask = reply_input_mask.unsqueeze(1)   # (batch_size, 1, max_words)
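    # The padding mask is combined with the causal (subsequent-token) mask below by a
    # logical AND. Illustrative example for a length-3 reply whose last position is padding:
    #   padding mask [1, 1, 0]  AND  causal mask [[1, 0, 0],      combined [[1, 0, 0],
    #                                             [1, 1, 0],  =>           [1, 1, 0],
    #                                             [1, 1, 1]]               [1, 1, 0]]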
    reply_input_mask = reply_input_mask & subsequent_mask(reply_input.size(-1)).type_as(reply_input_mask.data)
    reply_input_mask = reply_input_mask.unsqueeze(1)   # (batch_size, 1, max_words, max_words)
    reply_target_mask = reply_target != 0              # (batch_size, max_words)
    return question_mask, reply_input_mask, reply_target_mask


class AdamWarmup:
    """Inverse-square-root (Noam) learning-rate schedule with linear warmup."""
    def __init__(self, model_size, warmup_steps, optimizer):
        self.model_size = model_size
        self.warmup_steps = warmup_steps
        self.optimizer = optimizer
        self.current_step = 0
        self.lr = 0

    def get_lr(self):
        return self.model_size ** (-0.5) * min(self.current_step ** (-0.5),
                                               self.current_step * self.warmup_steps ** (-1.5))

    def step(self):
        # increment the step counter each time step() is called
        self.current_step += 1
        lr = self.get_lr()
        for param_group in self.optimizer.param_groups:
            param_group['lr'] = lr
        # update the learning rate, then take an optimizer step
        self.lr = lr
        self.optimizer.step()


class LossWithLS(nn.Module):
    """KL-divergence loss with label smoothing over a padded target sequence."""
    def __init__(self, size, smooth):
        super(LossWithLS, self).__init__()
        self.criterion = nn.KLDivLoss(reduction='none')
        self.confidence = 1.0 - smooth
        self.smooth = smooth
        self.size = size

    def forward(self, prediction, target, mask):
        """
        prediction of shape: (batch_size, max_words, vocab_size), already log-probabilities
        target and mask of shape: (batch_size, max_words)
        """
        prediction = prediction.view(-1, prediction.size(-1))   # (batch_size * max_words, vocab_size)
        target = target.contiguous().view(-1)                   # (batch_size * max_words)
        mask = mask.float()
        mask = mask.view(-1)                                    # (batch_size * max_words)
        labels = prediction.data.clone()

        # label smoothing: spread `smooth` over the non-target classes
        labels.fill_(self.smooth / (self.size - 1))
        labels.scatter_(1, target.data.unsqueeze(1), self.confidence)
        loss = self.criterion(prediction, labels)    # (batch_size * max_words, vocab_size)
        loss = (loss.sum(1) * mask).sum() / mask.sum()
        return loss


def calc_acc(logits, decoder_input_ids):
    decoder_attention_mask = torch.ne(decoder_input_ids, 0).to(logits.device)
    mask = decoder_attention_mask.view(-1).eq(1)
    labels = decoder_input_ids.reshape(-1)
    # labels = decoder_input_ids.view(-1)

    # print(mask.size())    # torch.Size([168])
    # print(labels.size())  # torch.Size([168])
    # print(logits.size())  # torch.Size([8, 21, 10002])
    logits = logits.contiguous().view(-1, logits.size(-1))
    # print(logits.size())   # torch.Size([168, 10002])

    _, logits = logits.max(dim=-1)
    # print(logits.size())   # torch.Size([168])
    n_correct = logits.eq(labels).masked_select(mask).sum().item()
    n_word = mask.sum().item()
    return n_correct, n_word


def evaluate(valid_data_loader):
    # run evaluation on the validation set
    model.eval()
    total_loss, total = 0.0, 0.0
    total_correct, total_word = 0.0, 0.0
    for step, batch in enumerate(valid_data_loader):
        with torch.no_grad():
            if torch.cuda.is_available():
                batch = (t.cuda() for t in batch)

            s_input_ids, context_seq_len, t_input_ids, summary_seq_len = batch
            t_input_ids_input = t_input_ids[:, :-1]    # [SOS ...]
            t_input_ids_output = t_input_ids[:, 1:]    # [... EOS]
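            # Teacher forcing: the decoder is fed the target shifted right (keeps SOS,
            # drops the final token) and is scored against the target shifted left
            # (drops SOS, keeps EOS), so each position predicts the next token.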
            # build the attention masks for this batch
            s_input_ids_mask, t_input_ids_input_mask, t_input_ids_output_mask = create_masks(s_input_ids,
                                                                                             t_input_ids_input,
                                                                                             t_input_ids_output)
            if torch.cuda.is_available():
                s_input_ids_mask, t_input_ids_input_mask, t_input_ids_output_mask = s_input_ids_mask.cuda(), t_input_ids_input_mask.cuda(), t_input_ids_output_mask.cuda()
            output = model(s_input_ids, s_input_ids_mask, t_input_ids_input, t_input_ids_input_mask)
            # loss = loss_func(output, t_input_ids_output, pad_id=0)
            loss = loss_func(output, t_input_ids_output, t_input_ids_output_mask)
            # accumulate the loss, weighted by the batch size
            total_loss += loss.item() * s_input_ids.size(0)
            total += s_input_ids.size(0)

            n_correct, n_word = calc_acc(output, t_input_ids_output)
            total_correct += n_correct
            total_word += n_word
    # final validation loss and token-level accuracy
    test_loss = total_loss / total
    test_acc = total_correct / total_word
    return test_loss, test_acc


def loss_func_ori(logits, labels, pad_id, smoothing=False):
    # unused alternative to LossWithLS: plain cross-entropy, optionally with label smoothing
    if smoothing:
        logit = logits[..., :-1, :].contiguous().view(-1, logits.size(2))
        labels = labels[..., 1:].contiguous().view(-1)
        eps = 0.1
        n_class = logit.size(-1)
        one_hot = torch.zeros_like(logit).scatter(1, labels.view(-1, 1), 1)
        one_hot = one_hot * (1 - eps) + (1 - one_hot) * eps / (n_class - 1)
        log_prb = F.log_softmax(logit, dim=1)
        non_pad_mask = labels.ne(pad_id)
        loss = -(one_hot * log_prb).sum(dim=1)
        loss = loss.masked_select(non_pad_mask).mean()  # average later
    else:
        logits = logits[..., :-1, :].contiguous().view(-1, logits.size(-1))
        labels = labels[..., 1:].contiguous().view(-1)
        loss = F.cross_entropy(logits, labels, ignore_index=0)
    return loss


if __name__ == '__main__':
    args = set_args()
    os.makedirs(args.output_dir, exist_ok=True)
    tokenizer = BasicTokenizer()
    vocab2id = json.load(open(args.vocab2id_path, 'r', encoding='utf8'))

    # the '<pad>' token name is assumed to match the one defined in data_helper.py
    pad_id = vocab2id.get('<pad>')
    collate_fn = Collator(pad_id=pad_id, is_train=True)

    # load training data
    train_context = load_data(args.train_data_src)
    train_summary = load_data(args.train_data_tgt)
    train_dataset = SummaryDataset(context=train_context, summary=train_summary, tokenizer=tokenizer, vocab2id=vocab2id)
    train_data_loader = DataLoader(train_dataset, shuffle=True, batch_size=args.batch_size, collate_fn=collate_fn)

    # load validation data
    valid_context = load_data(args.valid_data_src)
    valid_summary = load_data(args.valid_data_tgt)
    valid_dataset = SummaryDataset(context=valid_context, summary=valid_summary, tokenizer=tokenizer, vocab2id=vocab2id)
    valid_data_loader = DataLoader(valid_dataset, shuffle=False, batch_size=args.batch_size, collate_fn=collate_fn)

    model = Transformer(d_model=args.d_model, heads=args.heads, num_layers=args.num_layers, word_map=vocab2id)
    if torch.cuda.is_available():
        model.cuda()

    # optimizer: Adam wrapped in the warmup schedule
    adam_optimizer = torch.optim.Adam(model.parameters(), lr=0, betas=(0.9, 0.98), eps=1e-9)
    transformer_optimizer = AdamWarmup(model_size=args.d_model, warmup_steps=4000, optimizer=adam_optimizer)

    # loss function: label-smoothed KL divergence
    loss_func = LossWithLS(len(vocab2id), 0.1)

    tb_write = SummaryWriter()
    global_step = 0
    tr_loss, logging_loss = 0.0, 0.0
    for epoch in range(args.epochs):
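        # One epoch: teacher-forced forward pass, label-smoothed KL loss, an optimizer step
        # under the warmup schedule implemented by AdamWarmup, TensorBoard logging plus a
        # validation pass every `logging_steps` steps, and a checkpoint at the end.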
        for step, batch in enumerate(train_data_loader):
            if torch.cuda.is_available():
                batch = (t.cuda() for t in batch)
            s_input_ids, context_seq_len, t_input_ids, summary_seq_len = batch
            t_input_ids_input = t_input_ids[:, :-1]    # [SOS ...]
            t_input_ids_output = t_input_ids[:, 1:]    # [... EOS]

            # build the attention masks for this batch
            s_input_ids_mask, t_input_ids_input_mask, t_input_ids_output_mask = create_masks(s_input_ids,
                                                                                             t_input_ids_input,
                                                                                             t_input_ids_output)
            if torch.cuda.is_available():
                s_input_ids_mask, t_input_ids_input_mask, t_input_ids_output_mask = s_input_ids_mask.cuda(), t_input_ids_input_mask.cuda(), t_input_ids_output_mask.cuda()

            output = model(s_input_ids, s_input_ids_mask, t_input_ids_input, t_input_ids_input_mask)
            # print(s_input_ids.size())    # torch.Size([8, 123])
            # print(t_input_ids.size())    # torch.Size([8, 24])
            # print(t_input_ids_input.size())   # torch.Size([8, 23])
            # print(output.size())   # torch.Size([8, 23, 5121])
            loss = loss_func(output, t_input_ids_output, t_input_ids_output_mask)

            tr_loss += loss.item()

            transformer_optimizer.optimizer.zero_grad()
            loss.backward()
            transformer_optimizer.step()

            global_step += 1
            # every logging_steps steps, log the training loss and run validation
            if args.logging_steps > 0 and global_step % args.logging_steps == 0:
                tb_write.add_scalar("train_loss", (tr_loss - logging_loss) / args.logging_steps, global_step)
                logging_loss = tr_loss
                logger.info('epoch:{}, step:{}, loss:{}'.format(epoch, step, round(loss.item(), 4)))

                eval_loss, eval_acc = evaluate(valid_data_loader)
                tb_write.add_scalar("test_loss", eval_loss, global_step)
                tb_write.add_scalar("test_acc", eval_acc, global_step)
                print("test_loss: {}, test_acc:{}".format(eval_loss, eval_acc))
                model.train()

        # save a checkpoint at the end of every epoch
        output_dir = os.path.join(args.output_dir, "epoch_{}.bin".format(epoch))
        torch.save(model.state_dict(), output_dir)
--------------------------------------------------------------------------------
/seq2seq_transformer_pointer_network/start.sh:
--------------------------------------------------------------------------------
nohup python3 run_train.py > ./log.log 2>&1 &
--------------------------------------------------------------------------------
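Illustrative sketch (not part of the repository): how the pointer-generator mixture in
Transformer.decode() combines the generation distribution with the copy distribution via
scatter_add. All names and values below are made-up toy data, assuming the same tensor
shapes as the model above.

import torch
import torch.nn.functional as F

batch, tgt_len, src_len, vocab = 1, 2, 3, 5
logits = torch.randn(batch, tgt_len, vocab)            # decoder vocabulary scores
attention = torch.rand(batch, tgt_len, src_len)        # cross-attention weights of one head
attention = attention / attention.sum(-1, keepdim=True)
src_words = torch.tensor([[2, 4, 2]])                  # source token ids
p_pointer = torch.full((batch, tgt_len, 1), 0.1)       # probability mass given to copying

gen_dist = (1 - p_pointer) * F.softmax(logits, dim=-1)
copy_index = src_words.unsqueeze(1).repeat(1, tgt_len, 1)         # (batch, tgt_len, src_len)
mixture = gen_dist.scatter_add(2, copy_index, p_pointer * attention)
print(mixture.sum(-1))   # every position still sums to 1.0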