├── .gitattributes
├── LICENSE
├── README.md
├── extract_convert.py
├── extract_model.py
├── extract_vectorize.py
├── seq2seq_convert.py
├── seq2seq_model.py
├── snippets.py
└── test_model
    ├── lawformer.py
    └── test_function.py

--------------------------------------------------------------------------------
/.gitattributes:
--------------------------------------------------------------------------------
# Auto detect text files and perform LF normalization
* text=auto

--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2021 eryihaha

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish,
distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# SPACES-Pytorch

A PyTorch reproduction of Su Jianlin (苏神)'s SPACES. Original Keras version:
https://github.com/bojone/SPACES

--------------------------------------------------------------------------------
/extract_convert.py:
--------------------------------------------------------------------------------
import os
import json
import numpy as np
from tqdm import tqdm
from snippets import *

# Initialization
maxlen = 256


def text_split(text, limited=True):
    """Split a long text into sub-sentences at punctuation marks.
    """
    texts = text_segmentate(text, 1, u'\n。;:,')
    if limited:
        texts = texts[-maxlen:]  # keep only the last `maxlen` sub-sentences
    return texts


def extract_matching(texts, summaries, start_i=0, start_j=0):
    """Find sentences in `texts` whose concatenation is as similar as
    possible to `summaries`.
    Algorithm: both texts and summaries are already split into sentences.
    Take the longest summary sentence, match it to the most similar text
    sentence, then recurse on the parts to the left and right of the match.
    """
    if len(texts) == 0 or len(summaries) == 0:
        return []
    i = np.argmax([len(s) for s in summaries])
    j = np.argmax([compute_main_metric(t, summaries[i], 'char') for t in texts])
    lm = extract_matching(texts[:j + 1], summaries[:i], start_i, start_j)
    rm = extract_matching(
        texts[j:], summaries[i + 1:], start_i + i + 1, start_j + j
    )
    return lm + [(start_i + i, start_j + j)] + rm


def extract_flow(inputs):
    """Conversion flow for a single sample (used by parallel_apply).
    """
    text, summary = inputs
    texts = text_split(text, True)  # keep the last `maxlen` sentences
    summaries = text_split(summary, False)
    mapping = extract_matching(texts, summaries)
    labels = sorted(set([int(i[1]) for i in mapping]))
    pred_summary = ''.join([texts[i] for i in labels])
    metric = compute_main_metric(pred_summary, summary)
    return texts, labels, summary, metric


def load_data(filename):
    """Load the data.
    Returns: [(text, summary)]
    """
    D = []
    with open(filename, encoding='utf-8') as f:
        for l in f:
            l = json.loads(l)
            text = '\n'.join([d['sentence'] for d in l['text']])
            D.append((text, l['summary']))
    return D


def convert(data):
    """Split into sentences and convert to extractive-summarization labels.
    """
    D = parallel_apply(
        func=extract_flow,
        iterable=tqdm(data, desc=u'converting data'),
        workers=100,
        max_queue_size=200
    )
    total_metric = sum([d[3] for d in D])
    D = [d[:3] for d in D]
    print(u'average metric of the extractive result: %s' % (total_metric / len(D)))
    return D
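
# A toy run of the matching recursion above. It mirrors extract_matching but
# swaps a plain character-overlap ratio in for compute_main_metric, so it runs
# without the ROUGE dependency; the pairing logic is identical and the numbers
# are illustrative only.
def _toy_match(texts, summaries, start_i=0, start_j=0):
    if len(texts) == 0 or len(summaries) == 0:
        return []
    sim = lambda a, b: len(set(a) & set(b)) / (len(set(b)) or 1)  # stand-in metric
    i = int(np.argmax([len(s) for s in summaries]))
    j = int(np.argmax([sim(t, summaries[i]) for t in texts]))
    lm = _toy_match(texts[:j + 1], summaries[:i], start_i, start_j)
    rm = _toy_match(texts[j:], summaries[i + 1:], start_i + i + 1, start_j + j)
    return lm + [(start_i + i, start_j + j)] + rm

# Each (i, j) pair aligns summary sentence i with text sentence j, so text
# sentences 0 and 2 form the oracle extract here:
# _toy_match(['甲诉乙借款', '乙未出庭', '判决乙还款'], ['甲诉乙借款', '判决还款'])
# -> [(0, 0), (1, 2)]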
if __name__ == '__main__':

    data_random_order_json = data_json[:-5] + '_random_order.json'
    data_extract_json = data_json[:-5] + '_extract.json'

    data = load_data(data_json)
    data = convert(data)

    if os.path.exists(data_random_order_json):
        idxs = json.load(open(data_random_order_json))
    else:
        idxs = list(range(len(data)))
        np.random.shuffle(idxs)
        json.dump(idxs, open(data_random_order_json, 'w'))

    data = [data[i] for i in idxs]

    with open(data_extract_json, 'w', encoding='utf-8') as f:
        for d in data:
            f.write(json.dumps(d, ensure_ascii=False) + '\n')

    print(u'input data: %s' % data_json)
    print(u'data order: %s' % data_random_order_json)
    print(u'output path: %s' % data_extract_json)

--------------------------------------------------------------------------------
/extract_model.py:
--------------------------------------------------------------------------------
import datetime
import json
import os

import argparse
import numpy as np
import torch
import torch.nn as nn
from tqdm import tqdm
from torch.utils.data import Dataset, DataLoader
from snippets import *
import logging

parser = argparse.ArgumentParser()
parser.add_argument('--batch_size', type=int, default=6, help='batch size')
parser.add_argument('--epoch_num', type=int, default=20, help='number of epochs')
parser.add_argument('--each_test_epoch', type=int, default=1)
parser.add_argument('--lr', type=float, default=0.001, help='learning rate')
parser.add_argument('--weight_decay', type=float, default=0., help='decay weight of optimizer')
parser.add_argument('--model_name', type=str, default='bert', help='matching model')
parser.add_argument('--checkpoint', type=str, default="./checkpoint/", help='checkpoint path')
parser.add_argument('--max_length', type=int, default=512, help='max length of each case')
parser.add_argument('--input_size', type=int, default=768)
parser.add_argument('--hidden_size', type=int, default=384)
parser.add_argument('--kernel_size', type=int, default=3)
parser.add_argument('--threshold', type=float, default=0.3)
parser.add_argument('--cuda_pos', type=str, default='1', help='which GPU to use')
parser.add_argument('--seed', type=int, default=42, help='random seed')
parser.add_argument('--fold', type=int, default=0, help='which fold to hold out for validation')
args = parser.parse_args()

np.random.seed(args.seed)
torch.manual_seed(args.seed)
torch.cuda.manual_seed_all(args.seed)

log_name = "log_train"
os.makedirs('./logs', exist_ok=True)  # make sure the log directory exists
logging.basicConfig(level=logging.INFO,  # log level
                    filename='./logs/{}.log'.format(log_name),
                    filemode='a',  # 'w' rewrites the log on every run; 'a' (the default) appends
                    format=
                    '%(asctime)s - %(pathname)s[line:%(lineno)d] - %(levelname)s: %(message)s'
                    # log format
                    )


# Configuration

data_extract_json = data_json[:-5] + '_extract.json'
data_extract_npy = data_json[:-5] + '_extract.npy'

device = torch.device('cuda:'+args.cuda_pos) if torch.cuda.is_available() else torch.device('cpu')

# Reading the fold index from bare sys.argv (as the Keras original did) conflicts
# with argparse, so it is taken from the --fold argument instead.
fold = args.fold


def load_checkpoint(model, optimizer, trained_epoch):
    filename = args.checkpoint + '/' + f"extract-{trained_epoch}.pkl"
    save_params = torch.load(filename)
    model.load_state_dict(save_params["model"])
    # optimizer.load_state_dict(save_params["optimizer"])


def save_checkpoint(model, optimizer, trained_epoch):
    save_params = {
        "model": model.state_dict(),
        "optimizer": optimizer.state_dict(),
        "trained_epoch": trained_epoch,
    }
    if not os.path.exists(args.checkpoint):
        # create the checkpoint directory if it does not exist
        os.mkdir(args.checkpoint)
    filename = args.checkpoint + '/' + f"extract-{trained_epoch}.pkl"
    torch.save(save_params, filename)


def load_data(filename):
    """Load the data.
    Returns: [(texts, labels, summary)]
    """
    D = []
    with open(filename, encoding='utf-8') as f:
        for l in f:
            D.append(json.loads(l))
    return D

class ResidualGatedConv1D(nn.Module):
    """Gated convolution with a residual connection.
    """
    def __init__(self, filters, kernel_size, dilation_rate=1):
        super(ResidualGatedConv1D, self).__init__()
        self.filters = filters  # output dimension
        self.kernel_size = kernel_size
        self.dilation_rate = dilation_rate
        self.supports_masking = True
        self.padding = self.dilation_rate*(self.kernel_size - 1)//2
        self.conv1d = nn.Conv1d(filters, 2*filters, self.kernel_size, padding=self.padding, dilation=self.dilation_rate)
        self.layernorm = nn.LayerNorm(self.filters)
        self.alpha = nn.Parameter(torch.zeros(1))

    def forward(self, inputs):
        input_cov1d = inputs.permute([0, 2, 1])  # (batch, dim, seq) for Conv1d
        outputs = self.conv1d(input_cov1d)
        outputs = outputs.permute([0, 2, 1])
        gate = torch.sigmoid(outputs[..., self.filters:])
        outputs = outputs[..., :self.filters] * gate
        outputs = self.layernorm(outputs)

        # In the Keras original, a Dense projection is built when the input and
        # output dimensions differ; here they are always equal, so this branch
        # is never taken.
        if hasattr(self, 'dense'):
            inputs = self.dense(inputs)

        return inputs + self.alpha * outputs


class Selector2(nn.Module):
    def __init__(self, input_size, filters, kernel_size, dilation_rate):
        """
        :param input_size: dimension of each sentence vector
        """
        super(Selector2, self).__init__()
        self.dense1 = nn.Linear(input_size, filters, bias=False)
        self.ResidualGatedConv1D_1 = ResidualGatedConv1D(filters, kernel_size, dilation_rate=dilation_rate[0])
        self.ResidualGatedConv1D_2 = ResidualGatedConv1D(filters, kernel_size, dilation_rate=dilation_rate[1])
        self.ResidualGatedConv1D_3 = ResidualGatedConv1D(filters, kernel_size, dilation_rate=dilation_rate[2])
        self.ResidualGatedConv1D_4 = ResidualGatedConv1D(filters, kernel_size, dilation_rate=dilation_rate[3])
        self.ResidualGatedConv1D_5 = ResidualGatedConv1D(filters, kernel_size, dilation_rate=dilation_rate[4])
        self.ResidualGatedConv1D_6 = ResidualGatedConv1D(filters, kernel_size, dilation_rate=dilation_rate[5])
        self.dense2 = nn.Linear(filters, 1)

    def forward(self, inputs):
        mask = inputs.ge(0.00001)
        mask = torch.sum(mask, axis=-1).bool()  # a sentence is real if any feature is non-zero
        x1 = self.dense1(nn.Dropout(0.1)(inputs))
        x2 = self.ResidualGatedConv1D_1(nn.Dropout(0.1)(x1))
        x3 = self.ResidualGatedConv1D_2(nn.Dropout(0.1)(x2))
        x4 = self.ResidualGatedConv1D_3(nn.Dropout(0.1)(x3))
        x5 = self.ResidualGatedConv1D_4(nn.Dropout(0.1)(x4))
        x6 = self.ResidualGatedConv1D_5(nn.Dropout(0.1)(x5))
        x7 = self.ResidualGatedConv1D_6(nn.Dropout(0.1)(x6))
        output = nn.Sigmoid()(self.dense2(nn.Dropout(0.1)(x7)))
        return output, mask
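
# A minimal shape check for Selector2 (illustrative; sizes match the defaults
# above, and the dilation schedule mirrors the one passed in __main__). Padded
# sentence positions are all-zero vectors, so the returned mask marks them False.
def _selector_shape_check():
    x = torch.randn(2, 7, 768)          # (batch, num_sentences, sentence_dim)
    x[:, 5:] = 0.                       # pretend the last two sentences are padding
    net = Selector2(768, 384, kernel_size=3, dilation_rate=[1, 2, 4, 8, 1, 1])
    probs, mask = net(x)
    assert probs.shape == (2, 7, 1)     # one selection probability per sentence
    assert mask.shape == (2, 7) and not mask[:, 5:].any()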

class Selector_Dataset(Dataset):
    def __init__(self, data_x, data_y):
        super(Selector_Dataset, self).__init__()
        self.data_x_tensor = torch.from_numpy(data_x)
        self.data_y_tensor = torch.from_numpy(data_y)

    def __len__(self):
        return len(self.data_x_tensor)

    def __getitem__(self, idx):
        return self.data_x_tensor[idx], self.data_y_tensor[idx]


def train(model, train_dataloader, valid_dataloader):
    model = model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
    criterion = nn.BCELoss(reduction='none')
    for epoch in range(args.epoch_num):
        epoch_loss = 0.0
        current_step = 0
        model.train()
        pbar = tqdm(train_dataloader, desc="Iteration", postfix='train')
        for batch_data in pbar:
            x_batch, label_batch = batch_data
            x_batch = x_batch.to(device)
            label_batch = label_batch.to(device)
            output_batch, batch_mask = model(x_batch)
            output_batch = output_batch.permute([0, 2, 1])
            loss = criterion(output_batch.squeeze(), label_batch.squeeze())
            loss = torch.div(torch.sum(loss*batch_mask), torch.sum(batch_mask))  # average over real sentences only
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            loss_item = loss.cpu().detach().item()
            epoch_loss += loss_item
            current_step += 1
            pbar.set_description("train loss {}".format(epoch_loss / current_step))
            if current_step % 100 == 0:
                logging.info("train step {} loss {}".format(current_step, epoch_loss / current_step))

        epoch_loss = epoch_loss / current_step
        time_str = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
        print('{} train epoch {} loss: {:.4f}'.format(time_str, epoch, epoch_loss))
        logging.info('train epoch {} loss: {:.4f}'.format(epoch, epoch_loss))
        save_checkpoint(model, optimizer, epoch)
        model.eval()
        with torch.no_grad():
            correct = 0
            total = 0
            current_step = 0
            pbar = tqdm(valid_dataloader, desc="Iteration", postfix='valid')
            for batch_data in pbar:
                x_batch, label_batch = batch_data
                x_batch = x_batch.to(device)
                label_batch = label_batch.to(device).long()
                output_batch, batch_mask = model(x_batch)
                total += torch.sum(batch_mask)
                vec_correct = ((output_batch.squeeze() > args.threshold).long() == label_batch.squeeze().long()) * batch_mask
                correct += torch.sum(vec_correct).cpu().item()
                pbar.set_description("valid acc {}".format(correct / total))
                current_step += 1
                if current_step % 100 == 0:
                    logging.info('valid epoch {} acc {}/{}={:.4f}'.format(epoch, correct, total, correct / total))
            time_str = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
            print('{} valid epoch {} acc {}/{}={:.4f}'.format(time_str, epoch, correct, total, correct / total))
            logging.info('valid epoch {} acc {}/{}={:.4f}'.format(epoch, correct, total, correct / total))

if __name__ == '__main__':

    # Load the data
    data = load_data(data_extract_json)
    data_x = np.load(data_extract_npy)
    data_y = np.zeros_like(data_x[..., :1])

    for i, d in enumerate(data):
        for j in d[1]:
            data_y[i, j] = 1

    train_data = data_split(data, fold, num_folds, 'train')
    valid_data = data_split(data, fold, num_folds, 'valid')
    train_x = data_split(data_x, fold, num_folds, 'train')
    valid_x = data_split(data_x, fold, num_folds, 'valid')
    train_y = data_split(data_y, fold, num_folds, 'train')
    valid_y = data_split(data_y, fold, num_folds, 'valid')

    train_dataloader = DataLoader(Selector_Dataset(train_x, train_y), batch_size=args.batch_size, shuffle=True, drop_last=True)
    valid_dataloader = DataLoader(Selector_Dataset(valid_x, valid_y), batch_size=len(valid_x), shuffle=False)

    model = Selector2(args.input_size, args.hidden_size, kernel_size=args.kernel_size, dilation_rate=[1, 2, 4, 8, 1, 1])

    train(model, train_dataloader, valid_dataloader)

--------------------------------------------------------------------------------
/extract_vectorize.py:
--------------------------------------------------------------------------------
#! -*- coding: utf-8 -*-
# CAIL 2020 judicial summarization
# Extractive stage: sentence vectorization
# Scientific Spaces: https://kexue.fm

import json
import numpy as np
from tqdm import tqdm
from transformers import BertTokenizer, BertModel, BertConfig
from transformers import AutoModel, AutoTokenizer
from snippets import *
import torch.nn as nn
import torch


class GlobalAveragePooling1D(nn.Module):
    """Custom global average pooling.
    Averages the token vectors of a sentence; with a mask, padded positions
    are excluded from the average.
    """
    def __init__(self):
        super(GlobalAveragePooling1D, self).__init__()

    def forward(self, inputs, mask=None):
        if mask is not None:
            mask = mask.to(torch.float)[:, :, None]
            return torch.sum(inputs * mask, dim=1) / torch.sum(mask, dim=1)
        else:
            return torch.mean(inputs, dim=1)


class Selector_1(nn.Module):
    def __init__(self):
        super(Selector_1, self).__init__()
        self.tokenizer = BertTokenizer.from_pretrained(pretrained_bert_fold, mirror='tuna', do_lower_case=True)
        self.Pooling = GlobalAveragePooling1D()
        self.encoder = BertModel.from_pretrained(pretrained_bert_fold)
        self.max_seq_len = 512

    def predict(self, texts):
        """Convert a list of sentences into sentence vectors.
        """
        with torch.no_grad():
            bert_output = self.tokenizer.batch_encode_plus(texts, padding=True, truncation=True, max_length=self.max_seq_len, return_tensors="pt")
            output_1 = self.encoder(**bert_output)["last_hidden_state"]
            # pass the attention mask so that padded tokens are excluded from the average
            outputs = self.Pooling(output_1, bert_output['attention_mask'])
        return outputs


def load_data(filename):
    """Load the data.
    Returns: [texts]
    """
    D = []
    with open(filename) as f:
        for l in f:
            texts = json.loads(l)[0]
            D.append(texts)
    return D


def convert(data):
    """Vectorize all samples.
    """
    embeddings = []
    model = Selector_1()
    for texts in tqdm(data, desc=u'vectorizing'):
        outputs = model.predict(texts)
        embeddings.append(outputs)
    embeddings = sequence_padding(embeddings)
    return embeddings
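
# A small illustration of the masked pooling above: with a mask, padded token
# positions contribute nothing to the sentence vector.
def _pooling_check():
    pool = GlobalAveragePooling1D()
    x = torch.ones(1, 4, 2)              # (batch, tokens, dim)
    mask = torch.tensor([[1, 1, 0, 0]])  # last two tokens are padding
    x[0, 2:] = 100.                      # garbage in the padded positions
    assert torch.allclose(pool(x, mask), torch.ones(1, 2))   # pads ignored
    assert not torch.allclose(pool(x), torch.ones(1, 2))     # unmasked mean is skewed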

if __name__ == '__main__':

    data_extract_json = data_json[:-5] + '_extract.json'
    data_extract_npy = data_json[:-5] + '_extract'

    data = load_data(data_extract_json)
    embeddings = convert(data)
    np.save(data_extract_npy, embeddings)
    print(u'output path: %s.npy' % data_extract_npy)

--------------------------------------------------------------------------------
/seq2seq_convert.py:
--------------------------------------------------------------------------------
from extract_model import *
from snippets import open
import torch


def fold_convert(data, data_x, fold):
    """Convert each fold's validation split with the model trained on that fold.
    """
    valid_data = data_split(data, fold, num_folds, 'valid')
    valid_x = data_split(data_x, fold, num_folds, 'valid')
    with torch.no_grad():
        model = Selector2(args.input_size, args.hidden_size, kernel_size=args.kernel_size, dilation_rate=[1, 2, 4, 8, 1, 1])
        load_checkpoint(model, None, 19)
        model.eval()  # disable dropout for inference
        model_output = model(torch.tensor(valid_x))[0]
        y_pred = model_output.cpu().numpy()

    results = []
    for d, yp in tqdm(zip(valid_data, y_pred), desc=u'converting'):
        yp = yp[:len(d[0])]
        yp = np.where(yp > args.threshold)[0]
        source_1 = ''.join([d[0][i] for i in yp])   # sentences selected by the model
        source_2 = ''.join([d[0][i] for i in d[1]])  # gold extractive sentences
        result = {
            'source_1': source_1,
            'source_2': source_2,
            'target': d[2],
        }
        results.append(result)

    return results


def convert(filename, data, data_x):
    """Convert to data for the generative (seq2seq) stage.
    """
    F = open(filename, 'w', encoding='utf-8')
    total_results = []
    for fold in range(num_folds):
        total_results.append(fold_convert(data, data_x, fold))

    # Write the results back to the file in the original sample order
    n = 0
    while True:
        i, j = n % num_folds, n // num_folds
        try:
            d = total_results[i][j]
        except IndexError:
            break
        F.write(json.dumps(d, ensure_ascii=False) + '\n')
        n += 1

    F.close()
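
# A toy run of the reordering loop above: data_split assigns sample n to fold
# n % num_folds, so walking (n % num_folds, n // num_folds) recovers the
# original order. Illustrative only (3 folds, 7 samples).
def _order_check():
    num_folds = 3
    samples = list(range(7))
    folds = [[s for s in samples if s % num_folds == f] for f in range(num_folds)]
    out, n = [], 0
    while True:
        i, j = n % num_folds, n // num_folds
        try:
            out.append(folds[i][j])
        except IndexError:
            break
        n += 1
    assert out == samples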

if __name__ == '__main__':

    data = load_data(data_extract_json)
    data_x = np.load(data_extract_npy)
    data_seq2seq_json = data_json[:-5] + '_seq2seq.json'
    convert(data_seq2seq_json, data, data_x)
    print(u'output path: %s' % data_seq2seq_json)

--------------------------------------------------------------------------------
/seq2seq_model.py:
--------------------------------------------------------------------------------
import datetime
import os, json
import numpy as np
from tqdm import tqdm
from transformers import BertTokenizer, AutoTokenizer
import argparse
import torch
from transformers import AdamW
import torch.nn as nn
import copy
from torch.utils.data import Dataset, DataLoader
import logging
from snippets import *
from bert_seq2seq import Tokenizer, load_chinese_base_vocab
from bert_seq2seq import load_bert

# Basic arguments
parser = argparse.ArgumentParser()
parser.add_argument('--batch_size', type=int, default=2, help='batch size')
parser.add_argument('--epochs', type=int, default=50, help='number of epochs')
parser.add_argument('--each_test_epoch', type=int, default=1)
parser.add_argument('--lr', type=float, default=2e-5, help='learning rate')
parser.add_argument('--weight_decay', type=float, default=0., help='decay weight of optimizer')
parser.add_argument('--model_name', type=str, default='nezha', help='matching model')
parser.add_argument('--checkpoint', type=str, default="./checkpoint/", help='checkpoint path')
parser.add_argument('--bert_maxlen', type=int, default=512, help='max length of each case')
parser.add_argument('--maxlen', type=int, default=1024, help='max length of each case')
parser.add_argument('--input_size', type=int, default=768)
parser.add_argument('--hidden_size', type=int, default=384)
parser.add_argument('--kernel_size', type=int, default=3)
parser.add_argument('--threshold', type=float, default=0.3)
parser.add_argument('--k_sparse', type=int, default=10)
parser.add_argument('--cuda_pos', type=str, default='0', help='which GPU to use')
parser.add_argument('--seed', type=int, default=42, help='random seed')
parser.add_argument('--fold', type=int, default=0, help='which fold to hold out for validation')
args = parser.parse_args()

np.random.seed(args.seed)
torch.manual_seed(args.seed)
torch.cuda.manual_seed_all(args.seed)
device = torch.device('cuda:'+args.cuda_pos) if torch.cuda.is_available() else torch.device('cpu')
log_name = "log_train"
os.makedirs('./logs', exist_ok=True)  # make sure the log directory exists
logging.basicConfig(level=logging.INFO,  # log level
                    filename='./logs/{}.log'.format(log_name),
                    filemode='a',  # 'w' rewrites the log on every run; 'a' (the default) appends
                    format=
                    '%(asctime)s - %(pathname)s[line:%(lineno)d] - %(levelname)s: %(message)s'
                    # log format
                    )


data_seq2seq_json = data_json[:-5] + '_seq2seq.json'
seq2seq_config_json = data_json[:-10] + 'seq2seq_config.json'

# Reading the fold index from bare sys.argv (as the Keras original did) conflicts
# with argparse, so it is taken from the --fold argument instead.
fold = args.fold


def load_data(filename):
    """Load the data.
    Returns: [{...}]
    """
    D = []
    with open(filename) as f:
        for l in f:
            D.append(json.loads(l))
    return D


def generate_copy_labels(source, target):
    """Build the labels for the copy mechanism.
    Tokens on the longest common subsequence get label 1 (start of a copied
    span) or 2 (continuation of a copied span); everything else stays 0.
    """
    mapping = longest_common_subsequence(source, target)[1]
    source_labels = [0] * len(source)
    target_labels = [0] * len(target)
    i0, j0 = -2, -2
    for i, j in mapping:
        if i == i0 + 1 and j == j0 + 1:
            source_labels[i] = 2
            target_labels[j] = 2
        else:
            source_labels[i] = 1
            target_labels[j] = 1
        i0, j0 = i, j
    return source_labels, target_labels


def random_masking(token_ids_all):
    """Randomly replace input tokens to improve generalization.
    """
    result = []
    for token_ids in token_ids_all:
        rands = np.random.random(len(token_ids))
        result.append([
            t if r > 0.15 else np.random.choice(token_ids)
            for r, t in zip(rands, token_ids)
        ])
    return result
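
# A worked example of the copy labels, using character tokens for readability
# (the real pipeline applies this to BERT token ids).
def _copy_label_demo():
    src = list('abXcd')
    tgt = list('abcd')
    s_labels, t_labels = generate_copy_labels(src, tgt)
    # 'ab' and 'cd' are two copied spans: each opens with a 1 and continues with 2.
    assert s_labels == [1, 2, 0, 1, 2]
    assert t_labels == [1, 2, 1, 2]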

class DataGenerator(Dataset):
    def __init__(self, input_data, random=True):
        super(DataGenerator, self).__init__()
        self.input_data = input_data
        self.random = random

    def __len__(self):
        return len(self.input_data)

    def __getitem__(self, idx):
        # During training, randomly pick either the model-selected sentences
        # (source_1) or the gold extractive sentences (source_2) as input.
        i = np.random.choice(2) + 1 if self.random else 1
        source, target = self.input_data[idx]['source_%s' % i], self.input_data[idx]['target']
        return [source, target]


class Collate:
    def __init__(self):
        self.tokenizer = BertTokenizer.from_pretrained(pretrained_nezha_fold)
        self.max_seq_len = args.maxlen

    def __call__(self, batch):
        dic_data = self.tokenizer.batch_encode_plus(batch, padding=True, truncation=True,
                                                    max_length=self.max_seq_len)
        mask_dic_data = copy.deepcopy(dic_data)

        token_ids = dic_data["input_ids"]

        masked_token_ids = random_masking(token_ids)
        # NB: as written, the masked ids are only used to build the copy labels
        # below; the unmasked ids are what gets returned as model input.
        mask_dic_data['input_ids'] = masked_token_ids
        labels = []
        for item_masked_token_ids, item_token_ids in zip(masked_token_ids, token_ids):
            idx = item_token_ids.index(self.tokenizer.sep_token_id) + 1
            source_labels, target_labels = generate_copy_labels(
                item_masked_token_ids[:idx], item_token_ids[idx:]
            )
            """
            layout: [CLS] source ... [SEP] target ... [SEP]
            """
            labels.append(source_labels[1:] + target_labels)  # shifted left by one because of next-token prediction

        return torch.tensor(dic_data["input_ids"]), torch.tensor(dic_data["token_type_ids"]), torch.tensor(labels)


def build_pretrain_dataloader(data, batch_size, shuffle=True, num_workers=0):
    data_generator = DataGenerator(data, random=True)
    collate = Collate()
    return DataLoader(
        data_generator,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
        collate_fn=collate
    )


def compute_seq2seq_loss(predictions, token_type_id, input_ids, vocab_size):
    predictions = predictions[:, :-1].contiguous()
    target_mask = token_type_id[:, 1:].contiguous()
    """
    target_mask: 0 for sentence-a and padding positions, 1 for sentence-b positions
    """
    predictions = predictions.view(-1, vocab_size)
    labels = input_ids[:, 1:].contiguous()
    labels = labels.view(-1)
    target_mask = target_mask.view(-1).float()
    # positive term: logit of the gold token
    pos_loss = predictions[list(range(predictions.shape[0])), labels]
    # negative term: sparse partition function over only the top k_sparse logits
    y_pred = torch.topk(predictions, k=args.k_sparse)[0]
    neg_loss = torch.logsumexp(y_pred, dim=-1)

    loss = neg_loss - pos_loss
    return (loss * target_mask).sum() / target_mask.sum()  # the mask removes padding and sentence-a positions from the loss


def compute_copy_loss(predictions, token_type_id, labels):
    predictions = predictions[:, :-1].contiguous()
    target_mask = token_type_id[:, 1:].contiguous()
    """
    target_mask: 0 for sentence-a and padding positions, 1 for sentence-b positions
    """
    predictions = predictions.view(-1, 3)
    labels = labels.view(-1)
    target_mask = target_mask.view(-1).float()
    loss = nn.CrossEntropyLoss(ignore_index=0, reduction="none")
    return (loss(predictions, labels) * target_mask).sum() / target_mask.sum()  # the mask removes padding and sentence-a positions from the loss
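
# A tiny numeric sketch of the sparse cross entropy used above: the partition
# function is computed over only the top-k logits instead of the full vocab
# (k_sparse stands in for args.k_sparse here; values are illustrative).
def _sparse_loss_demo():
    k_sparse = 2
    logits = torch.tensor([[4.0, 3.0, 0.1, 0.1]])   # one position, vocab of 4
    label = torch.tensor([0])
    pos = logits[0, label[0]]
    neg = torch.logsumexp(torch.topk(logits, k=k_sparse)[0], dim=-1)
    sparse_ce = (neg - pos).item()
    full_ce = nn.functional.cross_entropy(logits, label).item()
    # the sparse version ignores the two small logits, giving a slightly lower loss
    assert sparse_ce < full_ce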

class GenerateModel(nn.Module):
    def __init__(self):
        super(GenerateModel, self).__init__()
        self.word2idx = load_chinese_base_vocab(pretrained_nezha_fold+"vocab.txt", simplfied=False)
        self.tokenizer = BertTokenizer.from_pretrained(pretrained_nezha_fold)
        self.bert_model = load_bert(self.word2idx, model_name=args.model_name, model_class="seq2seq")
        # load the pretrained weights
        self.bert_model.load_pretrain_params(pretrained_nezha_fold+"pytorch_model.bin")
        self.bert_model.set_device(device)
        self.configuration = self.bert_model.config
        self.linear = nn.Linear(self.configuration.hidden_size, 3)

    def forward(self, token_ids, token_type_ids):
        seq2seq_predictions, hidden_state = self.bert_model(token_ids, token_type_ids)
        copy_predictions = self.linear(nn.GELU()(hidden_state))
        return seq2seq_predictions, copy_predictions


def load_checkpoint(model, optimizer, trained_epoch):
    filename = args.checkpoint + '/' + f"seq2seq-{trained_epoch}.pkl"
    save_params = torch.load(filename)
    model.load_state_dict(save_params["model"])
    optimizer.load_state_dict(save_params["optimizer"])


def save_checkpoint(model, optimizer, trained_epoch):
    save_params = {
        "model": model.state_dict(),
        "optimizer": optimizer.state_dict(),
        "trained_epoch": trained_epoch,
    }
    if not os.path.exists(args.checkpoint):
        # create the checkpoint directory if it does not exist
        os.mkdir(args.checkpoint)
    filename = args.checkpoint + '/' + f"seq2seq-{trained_epoch}.pkl"
    torch.save(save_params, filename)


def train_valid(train_data, valid_data, model):
    optimizer = AdamW(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
    # ema = EMA(model, 0.9999)
    # ema.register()
    for epoch in range(args.epochs):
        epoch_loss = 0.
        current_step = 0
        model.train()
        pbar = tqdm(train_data, desc="Iteration", postfix='train')
        for batch_data in pbar:
            input_ids, token_type_ids, labels = batch_data
            input_ids, token_type_ids, labels = input_ids.to(device), token_type_ids.to(device), labels.to(device)
            seq2seq_predictions, copy_predictions = model(input_ids, token_type_ids)

            seq2seq_loss = compute_seq2seq_loss(seq2seq_predictions, token_type_ids, input_ids,
                                                model.configuration.vocab_size)
            copy_loss = compute_copy_loss(copy_predictions, token_type_ids, labels)
            loss = seq2seq_loss + 2 * copy_loss
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            # ema.update()
            loss_item = loss.cpu().detach().item()
            epoch_loss += loss_item
            current_step += 1
            pbar.set_description("train loss {}".format(epoch_loss / current_step))
            if current_step % 100 == 0:
                logging.info("train step {} loss {}".format(current_step, epoch_loss / current_step))

        epoch_loss = epoch_loss / current_step
        time_str = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
        print('{} train epoch {} loss: {:.4f}'.format(time_str, epoch, epoch_loss))
        logging.info('train epoch {} loss: {:.4f}'.format(epoch, epoch_loss))
        # TODO: check whether EMA improves accuracy; if so, apply it before saving
        save_checkpoint(model, optimizer, epoch)
        with torch.no_grad():
            model.eval()
            # ema.apply_shadow()
            os.makedirs('./result', exist_ok=True)  # make sure the output directory exists
            evaluate(valid_data, model, filename=r'./result/abstract.txt')
            # ema.restore()
        model.train()
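
# A small, self-contained illustration of the EMA helper from snippets.py that
# the commented-out hooks above refer to: after apply_shadow the parameters are
# the exponential moving average; restore brings the raw weights back. Sizes
# and the decay value are illustrative.
def _ema_demo():
    net = nn.Linear(2, 2)
    ema = EMA(net, decay=0.5)
    ema.register()
    with torch.no_grad():
        net.weight += 1.0               # simulate an optimizer update
    raw = net.weight.detach().clone()
    ema.update()                        # shadow = 0.5*raw + 0.5*old
    ema.apply_shadow()
    assert torch.allclose(net.weight, raw - 0.5)
    ema.restore()
    assert torch.allclose(net.weight, raw)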

class AutoSummary(AutoRegressiveDecoder):
    """seq2seq decoder.
    """
    def get_ngram_set(self, x, n):
        """Build the set of n-grams of x, returned in the format
        {(n-1)-gram: set of tokens that can follow it}.
        """
        result = {}
        for i in range(len(x) - n + 1):
            k = tuple(x[i:i + n])
            if k[:-1] not in result:
                result[k[:-1]] = set()
            result[k[:-1]].add(k[-1])
        return result

    @AutoRegressiveDecoder.wraps(default_rtype='logits', use_states=True)
    def predict(self, inputs, output_ids, states):
        token_ids, segment_ids = inputs
        token_ids = np.concatenate([token_ids, output_ids], 1)
        segment_ids = np.concatenate([segment_ids, np.ones_like(output_ids)], 1)
        seq2seq_predictions, copy_predictions = self.model(torch.tensor(token_ids, device=device), torch.tensor(segment_ids, device=device))
        # predictions for the last position only: (1, vocab_size) and (1, 3)
        # TODO: a softmax may also be needed here for the generator head
        prediction = [seq2seq_predictions[:, -1].cpu().numpy(), torch.softmax(copy_predictions[:, -1], dim=-1).cpu().numpy()]
        # states caches the n of the current ngram
        if states is None:
            states = [0]
        elif len(states) == 1 and len(token_ids) > 1:
            states = states * len(token_ids)
        # adjust the distribution according to the copy labels
        probas = np.zeros_like(prediction[0]) - 1000  # the distribution to return
        for i, token_ids in enumerate(inputs[0]):  # token_ids is now the source sequence of beam i
            if states[i] == 0:
                prediction[1][i, 2] *= -1  # label 0 cannot be followed by 2
            label = prediction[1][i].argmax()  # current label
            if label < 2:
                states[i] = label
            else:
                states[i] += 1  # label 2 may follow anything
            if states[i] > 0:
                ngrams = self.get_ngram_set(token_ids, states[i])
                prefix = tuple(output_ids[i, 1 - states[i]:])
                if prefix in ngrams:  # the prefix really is an (n-1)-gram of the source
                    candidates = ngrams[prefix]
                else:  # otherwise fall back to 1-grams
                    ngrams = self.get_ngram_set(token_ids, 1)
                    candidates = ngrams[tuple()]
                    states[i] = 1
                candidates = list(candidates)
                probas[i, candidates] = prediction[0][i, candidates]
            else:
                probas[i] = prediction[0][i]
                idxs = probas[i].argpartition(-args.k_sparse)
                probas[i, idxs[:-args.k_sparse]] = -1000
        return probas, states

    def generate(self, text, topk=1):
        max_c_len = args.maxlen - self.maxlen
        encode_text = self.model.tokenizer(text, padding=True, truncation=True,
                                           max_length=max_c_len)
        token_ids, segment_ids = encode_text['input_ids'], encode_text['token_type_ids']
        output_ids = self.beam_search([token_ids, segment_ids],
                                      topk)  # beam search decoding
        return ''.join(self.model.tokenizer.convert_ids_to_tokens(output_ids))


def evaluate(data, model, topk=1, filename=None):
    """Evaluation on the validation set.
    """
    autosummary = AutoSummary(
        start_id=model.tokenizer.cls_token_id,
        end_id=model.tokenizer.sep_token_id,
        maxlen=args.maxlen // 2,
        model=model
    )
    if filename is not None:
        F = open(filename, 'w', encoding='utf-8')
    total_metrics = {k: 0.0 for k in metric_keys}
    for d in tqdm(data, desc=u'evaluating'):
        pred_summary = autosummary.generate(d['source_1'], topk)
        metrics = compute_metrics(pred_summary, d['target'])
        for k, v in metrics.items():
            total_metrics[k] += v
        if filename is not None:
            F.write(d['target'] + '\t' + pred_summary + '\n')
            F.flush()
    if filename is not None:
        F.close()
    print(total_metrics)
    return {k: v / len(data) for k, v in total_metrics.items()}
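
# A small illustration of the ngram constraint used in AutoSummary.predict:
# while the decoder is inside a copied span of length n, the next token must
# extend an n-gram that actually occurs in the source (toy ids below).
def _ngram_demo():
    source = [7, 8, 9, 7, 8, 5]
    ngrams = AutoSummary.get_ngram_set(None, source, 2)  # self is unused here
    # after token 7, only tokens that followed 7 somewhere in the source are allowed
    assert ngrams[(7,)] == {8}
    assert ngrams[(8,)] == {9, 5}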

if __name__ == '__main__':
    # Load the data
    data = load_data(data_seq2seq_json)
    train_data = data_split(data, fold, num_folds, 'train')
    valid_data = data_split(data, fold, num_folds, 'valid')
    train_data_loader = build_pretrain_dataloader(train_data, args.batch_size)
    G_model = GenerateModel()
    print(G_model)
    G_model = G_model.to(device)
    train_valid(train_data_loader, valid_data, G_model)

--------------------------------------------------------------------------------
/snippets.py:
--------------------------------------------------------------------------------
import numpy as np
import json
from rouge import Rouge
import os, sys
import jieba
import six
from collections import defaultdict


# Custom dictionaries
user_dict_path = '/new_disk2/zhongxiang_sun/code/explanation_project/SPACES_torch/datasets/user_dict.txt'
user_dict_path_2 = '/new_disk2/zhongxiang_sun/code/explanation_project/SPACES_torch/datasets/user_dict_2.txt'
jieba.load_userdict(user_dict_path)
jieba.initialize()

# Recursion limit (extract_matching and the LCS below recurse deeply)
sys.setrecursionlimit(1000000)

# Labeled data
data_json = '/new_disk2/zhongxiang_sun/code/explanation_project/SPACES_torch/datasets/train.json'

# Directory for saving weights
if not os.path.exists('weights'):
    os.mkdir('weights')

# BERT configuration
pretrained_bert_fold = "/new_disk2/zhongxiang_sun/code/pretrain_model/bert_legal/"
pretrained_nezha_fold = "/new_disk2/zhongxiang_sun/code/pretrain_model/NEZHA/"
# NEZHA configuration
nezha_config_path = '/root/kg/bert/nezha_base/bert_config.json'
nezha_checkpoint_path = '/root/kg/bert/nezha_base/model.ckpt-900000'
nezha_dict_path = '/root/kg/bert/nezha_base/vocab.txt'

# Split the data into N folds, one of which serves as the validation set
num_folds = 15

# Metric names
metric_keys = ['main', 'rouge-1', 'rouge-2', 'rouge-l']

# For computing ROUGE
rouge = Rouge()


def softmax(x, axis=-1):
    """Numpy version of softmax.
    """
    x = x - x.max(axis=axis, keepdims=True)
    x = np.exp(x)
    return x / x.sum(axis=axis, keepdims=True)


class AutoRegressiveDecoder(object):
    """Base class for autoregressive decoding,
    with both beam search and random sampling strategies.
    """
    def __init__(self, start_id, end_id, maxlen, minlen=1, model=None, tokenizer=None):
        self.start_id = start_id
        self.end_id = end_id
        self.maxlen = maxlen
        self.minlen = minlen
        self.model = model
        self.tokenizer = tokenizer
        if start_id is None:
            self.first_output_ids = np.empty((1, 0), dtype=int)
        else:
            self.first_output_ids = np.array([[self.start_id]])

    @staticmethod
    def wraps(default_rtype='probas', use_states=False):
        """Decorator that completes the user-defined predict function.
        It currently: 1. handles the rtype argument;
        2. handles the use of states;
        3. applies the temperature parameter.
        """
        def actual_decorator(predict):
            def new_predict(
                self,
                inputs,
                output_ids,
                states,
                temperature=1,
                rtype=default_rtype
            ):
                assert rtype in ['probas', 'logits']
                prediction = predict(self, inputs, output_ids, states)

                if not use_states:
                    prediction = (prediction, None)

                if default_rtype == 'logits':
                    prediction = (
                        softmax(prediction[0] / temperature), prediction[1]
                    )
                elif temperature != 1:
                    probas = np.power(prediction[0], 1.0 / temperature)
                    probas = probas / probas.sum(axis=-1, keepdims=True)
                    prediction = (probas, prediction[1])

                if rtype == 'probas':
                    return prediction
                else:
                    return np.log(prediction[0] + 1e-12), prediction[1]

            return new_predict

        return actual_decorator

    def predict(self, inputs, output_ids, states=None):
        """Recursive prediction function to be defined by the user.
        Note: the definition must be decorated with the wraps method, passing
        default_rtype and use_states. default_rtype is the string 'logits' or
        'probas': 'probas' returns normalized probabilities, while
        rtype='logits' returns the pre-softmax scores or log-probabilities.
        Returns: a tuple (scores or probabilities, states)
        """
        raise NotImplementedError
    def beam_search(self, inputs, topk, states=None, temperature=1, min_ends=1):
        """Beam search decoding.
        Note: topk here is the beam size.
        Returns: the best decoded sequence.
        """
        inputs = [np.array([i]) for i in inputs]
        output_ids, output_scores = self.first_output_ids, np.zeros(1)
        for step in range(self.maxlen):
            scores, states = self.predict(
                inputs, output_ids, states, temperature, 'logits'
            )  # scores for the current step
            if step == 0:  # after the first step, repeat the inputs topk times
                inputs = [np.repeat(i, topk, axis=0) for i in inputs]
            scores = output_scores.reshape((-1, 1)) + scores  # accumulated scores
            indices = scores.argpartition(-topk, axis=None)[-topk:]  # keep only the topk
            indices_1 = indices // scores.shape[1]  # row indices
            indices_2 = (indices % scores.shape[1]).reshape((-1, 1))  # column indices
            output_ids = np.concatenate([output_ids[indices_1], indices_2],
                                        1)  # update the outputs
            output_scores = np.take_along_axis(
                scores, indices, axis=None
            )  # update the scores
            is_end = output_ids[:, -1] == self.end_id  # does each sequence end with the end token?
            end_counts = (output_ids == self.end_id).sum(1)  # count the end tokens
            if output_ids.shape[1] >= self.minlen:  # minimum-length check
                best = output_scores.argmax()  # the highest-scoring sequence
                if is_end[best] and end_counts[best] >= min_ends:  # if it has terminated
                    return output_ids[best]  # return it directly
                else:  # otherwise keep only the unfinished sequences
                    flag = ~is_end | (end_counts < min_ends)  # mark unfinished sequences
                    if not flag.all():  # if some have finished
                        inputs = [i[flag] for i in inputs]  # drop the finished sequences
                        output_ids = output_ids[flag]
                        output_scores = output_scores[flag]
                        end_counts = end_counts[flag]
                        topk = flag.sum()  # shrink topk accordingly
        # maximum length reached: return the best sequence
        return output_ids[output_scores.argmax()]

    def random_sample(
        self,
        inputs,
        n,
        topk=None,
        topp=None,
        states=None,
        temperature=1,
        min_ends=1
    ):
        """Draw n sampled sequences.
        Note: a non-None topk samples only from the topk most probable tokens
        at each step; a non-None topp samples only from the most probable
        tokens whose cumulative probability just reaches topp.
        Returns: a list of n decoded sequences.
        """
        inputs = [np.array([i]) for i in inputs]
        output_ids = self.first_output_ids
        results = []
        for step in range(self.maxlen):
            probas, states = self.predict(
                inputs, output_ids, states, temperature, 'probas'
            )  # probabilities for the current step
            probas /= probas.sum(axis=1, keepdims=True)  # make sure they are normalized
            if step == 0:  # after the first step, repeat everything n times
                probas = np.repeat(probas, n, axis=0)
                inputs = [np.repeat(i, n, axis=0) for i in inputs]
                output_ids = np.repeat(output_ids, n, axis=0)
            if topk is not None:
                k_indices = probas.argpartition(-topk,
                                                axis=1)[:, -topk:]  # keep only the topk
                probas = np.take_along_axis(probas, k_indices, axis=1)  # topk probabilities
                probas /= probas.sum(axis=1, keepdims=True)  # renormalize
            if topp is not None:
                p_indices = probas.argsort(axis=1)[:, ::-1]  # sort from high to low
                probas = np.take_along_axis(probas, p_indices, axis=1)  # sorted probabilities
                cumsum_probas = np.cumsum(probas, axis=1)  # cumulative probabilities
                flag = np.roll(cumsum_probas >= topp, 1, axis=1)  # mark the part beyond topp
                flag[:, 0] = False  # together with np.roll, shifts the flags by one position
                probas[flag] = 0  # zero out everything beyond topp
                probas /= probas.sum(axis=1, keepdims=True)  # renormalize
            sample_func = lambda p: np.random.choice(len(p), p=p)  # sampling function
            sample_ids = np.apply_along_axis(sample_func, 1, probas)  # draw the samples
            sample_ids = sample_ids.reshape((-1, 1))  # align shapes
            if topp is not None:
                sample_ids = np.take_along_axis(
                    p_indices, sample_ids, axis=1
                )  # map back to the original ids
            if topk is not None:
                sample_ids = np.take_along_axis(
                    k_indices, sample_ids, axis=1
                )  # map back to the original ids
            output_ids = np.concatenate([output_ids, sample_ids], 1)  # update the outputs
            is_end = output_ids[:, -1] == self.end_id  # does each sequence end with the end token?
            end_counts = (output_ids == self.end_id).sum(1)  # count the end tokens
            if output_ids.shape[1] >= self.minlen:  # minimum-length check
                flag = is_end & (end_counts >= min_ends)  # mark the finished sequences
                if flag.any():  # if some have finished
                    for ids in output_ids[flag]:  # store the finished sequences
                        results.append(ids)
                    flag = (flag == False)  # mark the unfinished sequences
                    inputs = [i[flag] for i in inputs]  # keep only the unfinished inputs
                    output_ids = output_ids[flag]  # keep only the unfinished candidates
                    end_counts = end_counts[flag]  # keep only the unfinished end counts
                    if len(output_ids) == 0:
                        break
        # any sequences still unfinished go straight into the results
        for ids in output_ids:
            results.append(ids)
        # return the results
        return results

class EMA():
    """Exponential moving average of model parameters.
    """
    def __init__(self, model, decay):
        self.model = model
        self.decay = decay
        self.shadow = {}
        self.backup = {}

    def register(self):
        for name, param in self.model.named_parameters():
            if param.requires_grad:
                self.shadow[name] = param.data.clone()

    def update(self):
        for name, param in self.model.named_parameters():
            if param.requires_grad:
                assert name in self.shadow
                new_average = (1.0 - self.decay) * param.data + self.decay * self.shadow[name]
                self.shadow[name] = new_average.clone()

    def apply_shadow(self):
        for name, param in self.model.named_parameters():
            if param.requires_grad:
                assert name in self.shadow
                self.backup[name] = param.data
                param.data = self.shadow[name]

    def restore(self):
        for name, param in self.model.named_parameters():
            if param.requires_grad:
                assert name in self.backup
                param.data = self.backup[name]
        self.backup = {}


def convert_to_unicode(text, encoding='utf-8', errors='ignore'):
    """Convert a string to unicode (assuming utf-8 input).
    """
    if is_py2:
        if isinstance(text, str):
            text = text.decode(encoding, errors=errors)
    else:
        if isinstance(text, bytes):
            text = text.decode(encoding, errors=errors)
    return text


def convert_to_str(text, encoding='utf-8', errors='ignore'):
    """Convert a string to str (assuming utf-8 input).
    """
    if isinstance(text, bytes):
        text = text.decode(encoding, errors=errors)
    return text


def is_string(s):
    """Check whether s is a string.
    """
    return isinstance(s, str)


_open_ = open
is_py2 = six.PY2


class open:
    """Mimics Python's built-in open function.
    Purpose: 1. compatibility with both py2 and py3; 2. adds an index so that
    large files can be read by line number.
    """
    def __init__(
        self, name, mode='r', encoding=None, errors='strict', indexable=False
    ):
        self.name = name
        if is_py2:
            self.file = _open_(name, mode)
        else:
            self.file = _open_(name, mode, encoding=encoding, errors=errors)
        self.encoding = encoding
        self.errors = errors
        self.iterator = None
        if indexable:
            if is_string(indexable) and os.path.exists(indexable):
                self.offsets = json.load(_open_(indexable))
            else:
                self.create_indexes()
                if is_string(indexable):
                    json.dump(self.offsets, _open_(indexable, 'w'))

    def create_indexes(self):
        print('creating indexes ...')
        self.offsets, offset = [], 0
        while self.readline():
            self.offsets.append(offset)
            offset = self.tell()
        self.seek(0)
        print('indexes created.')

    def __getitem__(self, key):
        self.seek(self.offsets[key])
        l = self.readline()
        if self.encoding:
            l = convert_to_unicode(l, self.encoding, self.errors)
        return l

    def __len__(self):
        return len(self.offsets)

    def __iter__(self):
        if hasattr(self, 'offsets'):
            for i in range(len(self)):
                yield self[i]
        else:
            for l in self.file:
                if self.encoding:
                    l = convert_to_unicode(l, self.encoding, self.errors)
                yield l

    def next(self):
        if self.iterator is None:
            self.iterator = self.__iter__()
        return next(self.iterator)

    def __next__(self):
        return self.next()

    def read(self):
        text = self.file.read()
        if self.encoding:
            text = convert_to_unicode(text, self.encoding, self.errors)
        return text

    def readline(self):
        text = self.file.readline()
        if self.encoding:
            text = convert_to_unicode(text, self.encoding, self.errors)
        return text

    def readlines(self):
        if self.encoding:
            return [
                convert_to_unicode(text, self.encoding, self.errors)
                for text in self.file.readlines()
            ]
        else:
            return self.file.readlines()

    def write(self, text):
        if self.encoding:
            text = convert_to_str(text, self.encoding, self.errors)
        self.file.write(text)

    def flush(self):
        self.file.flush()

    def close(self):
        self.file.close()

    def tell(self):
        return self.file.tell()

    def seek(self, offset=0):
        return self.file.seek(offset)

    def __enter__(self):
        return self

    def __exit__(self, type, value, tb):
        self.close()
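
# A short demonstration of the indexable open above: building the line-offset
# index allows random access into a large file by line number. The temp file
# here is purely illustrative.
def _indexed_open_demo():
    import tempfile
    path = tempfile.mktemp()
    with _open_(path, 'w', encoding='utf-8') as f:
        f.write('first\nsecond\nthird\n')
    f = open(path, encoding='utf-8', indexable=True)
    assert len(f) == 3
    assert f[1] == 'second\n'
    f.close()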

def parallel_apply(
    func,
    iterable,
    workers,
    max_queue_size,
    callback=None,
    dummy=False,
    random_seeds=True,
    unordered=True
):
    """Apply func to every element of iterable, using multiple processes or threads.
    Note that this apply is asynchronous and unordered: for inputs a, b, c the
    outputs may come back as func(c), func(a), func(b).
    Arguments:
        callback: callback applied to each single output;
        dummy: False for multiprocessing, True for multithreading;
        random_seeds: random seed for each worker;
        unordered: if False, return results in input order (only effective
            when callback is None).
    """
    generator = parallel_apply_generator(
        func, iterable, workers, max_queue_size, dummy, random_seeds
    )

    if callback is None:
        if unordered:
            return [d for i, d in generator]
        else:
            results = sorted(generator, key=lambda d: d[0])
            return [d for i, d in results]
    else:
        for d in generator:
            callback(d)


def sequence_padding(inputs, length=None, value=0, seq_dims=1, mode='post'):
    """Numpy function that pads sequences to the same length.
    """
    if length is None:
        length = np.max([np.shape(x)[:seq_dims] for x in inputs], axis=0)
    elif not hasattr(length, '__getitem__'):
        length = [length]

    slices = [np.s_[:length[i]] for i in range(seq_dims)]
    slices = tuple(slices) if len(slices) > 1 else slices[0]
    pad_width = [(0, 0) for _ in np.shape(inputs[0])]

    outputs = []
    for x in inputs:
        x = x[slices]
        for i in range(seq_dims):
            if mode == 'post':
                pad_width[i] = (0, length[i] - np.shape(x)[i])
            elif mode == 'pre':
                pad_width[i] = (length[i] - np.shape(x)[i], 0)
            else:
                raise ValueError('"mode" argument must be "post" or "pre".')
        x = np.pad(x, pad_width, 'constant', constant_values=value)
        outputs.append(x)

    return np.array(outputs)


def parallel_apply_generator(
    func, iterable, workers, max_queue_size, dummy=False, random_seeds=True
):
    """Apply func to every element of iterable, using multiple processes or threads.
    Note that this apply is asynchronous and unordered: for inputs a, b, c the
    outputs may come back as func(c), func(a), func(b). The results are
    returned as a generator whose items are (input index, result for that
    input).
    Arguments:
        dummy: False for multiprocessing, True for multithreading;
        random_seeds: random seed for each worker.
    """
    if dummy:
        from multiprocessing.dummy import Pool, Queue
    else:
        from multiprocessing import Pool, Queue

    in_queue, out_queue, seed_queue = Queue(max_queue_size), Queue(), Queue()
    if random_seeds is True:
        random_seeds = [None] * workers
    elif random_seeds is None or random_seeds is False:
        random_seeds = []
    for seed in random_seeds:
        seed_queue.put(seed)

    def worker_step(in_queue, out_queue):
        """Wrap the single-step function into a loop.
        """
        if not seed_queue.empty():
            np.random.seed(seed_queue.get())
        while True:
            i, d = in_queue.get()
            r = func(d)
            out_queue.put((i, r))

    # start the processes/threads
    pool = Pool(workers, worker_step, (in_queue, out_queue))

    # feed in the inputs and collect the results
    in_count, out_count = 0, 0
    for i, d in enumerate(iterable):
        in_count += 1
        while True:
            try:
                in_queue.put((i, d), block=False)
                break
            except six.moves.queue.Full:
                for _ in range(out_queue.qsize()):
                    yield out_queue.get()
                    out_count += 1
        if in_count % max_queue_size == 0:
            for _ in range(out_queue.qsize()):
                yield out_queue.get()
                out_count += 1

    while out_count != in_count:
        for _ in range(out_queue.qsize()):
            yield out_queue.get()
            out_count += 1

    pool.terminate()
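
# A quick look at sequence_padding, which extract_vectorize.py uses to pad the
# per-document sentence-embedding matrices to a common sentence count:
def _padding_demo():
    a = np.ones((2, 3))
    b = np.ones((4, 3))
    out = sequence_padding([a, b])      # pads along the first axis only
    assert out.shape == (2, 4, 3)
    assert (out[0, 2:] == 0).all()      # the shorter input is zero-padded at the end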

def text_segmentate(text, maxlen, seps='\n', strips=None):
    """Split text into short segments at the given separators.
    """
    text = text.strip().strip(strips)
    if seps and len(text) > maxlen:
        pieces = text.split(seps[0])
        text, texts = '', []
        for i, p in enumerate(pieces):
            if text and p and len(text) + len(p) > maxlen - 1:
                texts.extend(text_segmentate(text, maxlen, seps[1:], strips))
                text = ''
            if i + 1 == len(pieces):
                text = text + p
            else:
                text = text + p + seps[0]
        if text:
            texts.extend(text_segmentate(text, maxlen, seps[1:], strips))
        return texts
    else:
        return [text]


def load_user_dict(filename):
    """Load a user dictionary.
    """
    user_dict = []
    with open(filename, encoding='utf-8') as f:
        for l in f:
            w = l.split()[0]
            user_dict.append(w)
    return user_dict


def data_split(data, fold, num_folds, mode):
    """Split the data into a training set and a validation set.
    Sample i goes to the validation set iff i % num_folds == fold.
    """
    if mode == 'train':
        D = [d for i, d in enumerate(data) if i % num_folds != fold]
    else:
        D = [d for i, d in enumerate(data) if i % num_folds == fold]

    if isinstance(data, np.ndarray):
        return np.array(D)
    else:
        return D


def compute_rouge(source, target, unit='word'):
    """Compute rouge-1, rouge-2 and rouge-l.
    Note: with the jieba word segmentation commented out below, the scores are
    effectively computed at the character level regardless of `unit`.
    """
    # if unit == 'word':
    #     source = jieba.cut(source, HMM=False)
    #     target = jieba.cut(target, HMM=False)
    source, target = ' '.join(source), ' '.join(target)
    try:
        scores = rouge.get_scores(hyps=source, refs=target)
        return {
            'rouge-1': scores[0]['rouge-1']['f'],
            'rouge-2': scores[0]['rouge-2']['f'],
            'rouge-l': scores[0]['rouge-l']['f'],
        }
    except ValueError:
        return {
            'rouge-1': 0.0,
            'rouge-2': 0.0,
            'rouge-l': 0.0,
        }


def compute_metrics(source, target, unit='word'):
    """Compute all metrics.
    """
    metrics = compute_rouge(source, target, unit)
    metrics['main'] = (
        metrics['rouge-1'] * 0.2 + metrics['rouge-2'] * 0.4 +
        metrics['rouge-l'] * 0.4
    )
    return metrics


def compute_main_metric(source, target, unit='word'):
    """Compute the main metric, a weighted sum of the ROUGE scores.
    """
    return compute_metrics(source, target, unit)['main']
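
# A small check of the fold split used throughout the pipeline: with
# num_folds = 3 and fold = 1, every third sample (offset 1) is held out.
def _split_demo():
    data = list(range(7))
    assert data_split(data, 1, 3, 'valid') == [1, 4]
    assert data_split(data, 1, 3, 'train') == [0, 2, 3, 5, 6]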

def longest_common_subsequence(source, target):
    """Longest common subsequence of source and target (not necessarily contiguous).
    Returns: the subsequence length and the mapping (a list of index pairs).
    Note: the longest common subsequence may not be unique; the returned
    mapping represents just one of them.
    """
    c = defaultdict(int)
    for i, si in enumerate(source, 1):
        for j, tj in enumerate(target, 1):
            if si == tj:
                c[i, j] = c[i - 1, j - 1] + 1
            elif c[i, j - 1] > c[i - 1, j]:
                c[i, j] = c[i, j - 1]
            else:
                c[i, j] = c[i - 1, j]
    l, mapping = c[len(source), len(target)], []
    i, j = len(source) - 1, len(target) - 1
    while len(mapping) < l:
        if source[i] == target[j]:
            mapping.append((i, j))
            i, j = i - 1, j - 1
        elif c[i + 1, j] > c[i, j + 1]:
            j = j - 1
        else:
            i = i - 1
    return l, mapping[::-1]

--------------------------------------------------------------------------------
/test_model/lawformer.py:
--------------------------------------------------------------------------------
from transformers import AutoModel, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("hfl/chinese-roberta-wwm-ext", cache_dir="/new_disk2/zhongxiang_sun/code/pretrain_model/lawformer/")
model = AutoModel.from_pretrained("thunlp/Lawformer", cache_dir="/new_disk2/zhongxiang_sun/code/pretrain_model/lawformer/")
inputs = tokenizer("任某提起诉讼,请求判令解除婚姻关系并对夫妻共同财产进行分割。", return_tensors="pt")
outputs = model(**inputs)
print(outputs)

--------------------------------------------------------------------------------
/test_model/test_function.py:
--------------------------------------------------------------------------------
import os

# os.mkdir requires a path argument; the directory name below is only a
# placeholder for this scratch test.
os.mkdir('test_dir')

--------------------------------------------------------------------------------