├── .gitattributes
├── .idea
│   ├── .gitignore
│   ├── SPACE_pytorch.iml
│   ├── deployment.xml
│   ├── inspectionProfiles
│   │   ├── Project_Default.xml
│   │   └── profiles_settings.xml
│   ├── misc.xml
│   ├── modules.xml
│   ├── remote-mappings.xml
│   └── webServers.xml
├── LICENSE
├── README.md
├── extract_convert.py
├── extract_model.py
├── extract_vectorize.py
├── seq2seq_convert.py
├── seq2seq_model.py
├── snippets.py
└── test_model
    ├── lawformer.py
    └── test_function.py
/.gitattributes:
--------------------------------------------------------------------------------
1 | # Auto detect text files and perform LF normalization
2 | * text=auto
3 |
--------------------------------------------------------------------------------
/.idea/.gitignore:
--------------------------------------------------------------------------------
1 | # Default ignored files
2 | /shelf/
3 | /workspace.xml
4 | # Datasource local storage ignored files
5 | /../../../../../../../../../../../:\Users\25505\Desktop\文件\科研\graduation_project\law_data\CAIL2020\SPACE_pytorch\.idea/dataSources/
6 | /dataSources.local.xml
7 | # Editor-based HTTP Client requests
8 | /httpRequests/
9 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2021 eryihaha
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # SPACES-Pytorch
2 | A PyTorch re-implementation of Su Jianlin (苏神)'s SPACES. Original Keras version:
3 | https://github.com/bojone/SPACES
4 |
--------------------------------------------------------------------------------
/extract_convert.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | import numpy as np
4 | from tqdm import tqdm
5 | from snippets import *
6 |
7 | # Initialization
8 | maxlen = 256
9 |
10 |
11 | def text_split(text, limited=True):
12 | """将长句按照标点分割为多个子句。
13 | """
14 | texts = text_segmentate(text, 1, u'\n。;:,')
15 | if limited:
16 | texts = texts[-maxlen:]
17 | return texts
18 |
19 |
20 | def extract_matching(texts, summaries, start_i=0, start_j=0):
21 | """在texts中找若干句子,使得它们连起来与summaries尽可能相似
22 | 算法:texts和summaries都分句,然后找出summaries最长的句子,在texts
23 | 中找与之最相似的句子作为匹配,剩下部分递归执行。
24 | """
25 | if len(texts) == 0 or len(summaries) == 0:
26 | return []
27 | i = np.argmax([len(s) for s in summaries])
28 | j = np.argmax([compute_main_metric(t, summaries[i], 'char') for t in texts])
29 | lm = extract_matching(texts[:j + 1], summaries[:i], start_i, start_j)
30 | rm = extract_matching(
31 | texts[j:], summaries[i + 1:], start_i + i + 1, start_j + j
32 | )
33 | return lm + [(start_i + i, start_j + j)] + rm
34 |
35 |
36 | def extract_flow(inputs):
37 | """单个样本的构建流(给parallel_apply用)
38 | """
39 | text, summary = inputs
40 |     texts = text_split(text, True)  # keep only the last maxlen clauses
41 | summaries = text_split(summary, False)
42 | mapping = extract_matching(texts, summaries)
43 | labels = sorted(set([i[1].item() for i in mapping]))
44 | pred_summary = ''.join([texts[i] for i in labels])
45 | metric = compute_main_metric(pred_summary, summary)
46 | return texts, labels, summary, metric
47 |
48 |
49 | def load_data(filename):
50 | """加载数据
51 | 返回:[(text, summary)]
52 | """
53 | D = []
54 | with open(filename, encoding='utf-8') as f:
55 | for l in f:
56 | l = json.loads(l)
57 | text = '\n'.join([d['sentence'] for d in l['text']])
58 | D.append((text, l['summary']))
59 | return D
60 |
61 |
62 | def convert(data):
63 | """分句,并转换为抽取式摘要
64 | """
65 | D = parallel_apply(
66 | func=extract_flow,
67 |         iterable=tqdm(data, desc=u'converting data'),
68 | workers=100,
69 | max_queue_size=200
70 | )
71 | total_metric = sum([d[3] for d in D])
72 | D = [d[:3] for d in D]
73 |     print(u'average metric of the extraction: %s' % (total_metric / len(D)))
74 | return D
75 |
76 |
77 | if __name__ == '__main__':
78 |
79 | data_random_order_json = data_json[:-5] + '_random_order.json'
80 | data_extract_json = data_json[:-5] + '_extract.json'
81 |
82 | data = load_data(data_json)
83 | data = convert(data)
84 |
85 | if os.path.exists(data_random_order_json):
86 | idxs = json.load(open(data_random_order_json))
87 | else:
88 | idxs = list(range(len(data)))
89 | np.random.shuffle(idxs)
90 | json.dump(idxs, open(data_random_order_json, 'w'))
91 |
92 | data = [data[i] for i in idxs]
93 |
94 | with open(data_extract_json, 'w', encoding='utf-8') as f:
95 | for d in data:
96 | f.write(json.dumps(d, ensure_ascii=False) + '\n')
97 |
98 |     print(u'input data: %s' % data_json)
99 |     print(u'data order: %s' % data_random_order_json)
100 |     print(u'output path: %s' % data_extract_json)
101 |
--------------------------------------------------------------------------------
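
The greedy recursion in `extract_matching` is easier to see on toy data. Below is a self-contained sketch of the same algorithm, with `compute_main_metric` swapped for a crude character-overlap score (`overlap` is a stand-in introduced here only so the example runs without the jieba/ROUGE setup in snippets.py):

import numpy as np

def overlap(a, b):
    # crude stand-in for compute_main_metric: fraction of b's characters found in a
    return sum(ch in a for ch in b) / max(len(b), 1)

def toy_extract_matching(texts, summaries, start_i=0, start_j=0):
    # same structure as extract_matching: anchor the longest summary sentence
    # to its most similar source sentence, then recurse on both sides
    if len(texts) == 0 or len(summaries) == 0:
        return []
    i = int(np.argmax([len(s) for s in summaries]))
    j = int(np.argmax([overlap(t, summaries[i]) for t in texts]))
    lm = toy_extract_matching(texts[:j + 1], summaries[:i], start_i, start_j)
    rm = toy_extract_matching(texts[j:], summaries[i + 1:], start_i + i + 1, start_j + j)
    return lm + [(start_i + i, start_j + j)] + rm

texts = ['原告与被告因合同发生纠纷', '原告请求赔偿损失十万元', '被告辩称已履行义务']
summaries = ['原告请求赔偿十万元', '被告称已履行']
print(toy_extract_matching(texts, summaries))  # [(0, 1), (1, 2)]
# each pair is (summary sentence index, source sentence index)
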
/extract_model.py:
--------------------------------------------------------------------------------
1 | import datetime
2 | import json
3 |
4 | import argparse
5 | import numpy as np
6 | import torch
7 | import torch.nn as nn
8 | from tqdm import tqdm
9 | from torch.utils.data import Dataset, DataLoader
10 | from snippets import *
11 | import logging
12 |
13 | parser = argparse.ArgumentParser()
14 | parser.add_argument('--batch_size', type=int, default=6, help='batch size')
15 | parser.add_argument('--epoch_num', type=int, default=20, help='number of epochs')
16 | parser.add_argument('--each_test_epoch', type=int, default=1)
17 | parser.add_argument('--lr', type=float, default=0.001, help='learning rate')
18 | parser.add_argument('--weight_decay', type=float, default=0., help='decay weight of optimizer')
19 | parser.add_argument('--model_name', type=str, default='bert', help='matching model')
20 | parser.add_argument('--checkpoint', type=str, default="./checkpoint/", help='checkpoint path')
21 | parser.add_argument('--max_length', type=int, default=512, help='max length of each case')
22 | parser.add_argument('--input_size', type=int, default=768)
23 | parser.add_argument('--hidden_size', type=int, default=384)
24 | parser.add_argument('--kernel_size', type=int, default=3)
25 | parser.add_argument('--threshold', type=float, default=0.3)
26 | parser.add_argument('--cuda_pos', type=str, default='1', help='which GPU to use')
27 | parser.add_argument('--seed', type=int, default=42, help='random seed')
28 | args = parser.parse_args()
29 |
30 | np.random.seed(args.seed)
31 | torch.manual_seed(args.seed)
32 | torch.cuda.manual_seed_all(args.seed)
33 |
34 | log_name = "log_train"
35 | logging.basicConfig(level=logging.INFO,  # log level printed to the console
36 |                     filename='./logs/{}.log'.format(log_name),
37 |                     filemode='a',  # 'w' rewrites the log file on every run,
38 |                     # 'a' (the default when omitted) appends to it
39 |                     format=
40 |                     '%(asctime)s - %(pathname)s[line:%(lineno)d] - %(levelname)s: %(message)s'
41 |                     # log format
42 |                     )
43 |
44 |
45 | # configuration
46 |
47 | data_extract_json = data_json[:-5] + '_extract.json'
48 | data_extract_npy = data_json[:-5] + '_extract.npy'
49 |
50 | device = torch.device('cuda:'+args.cuda_pos) if torch.cuda.is_available() else torch.device('cpu')
51 |
52 |
53 | try:  # argparse already consumes sys.argv, so read the fold index defensively
54 |     fold = int(sys.argv[1])
55 | except (IndexError, ValueError):
56 |     fold = 0
57 |
58 |
59 | def load_checkpoint(model, optimizer, trained_epoch):
60 | filename = args.checkpoint + '/' + f"extract-{trained_epoch}.pkl"
61 | save_params = torch.load(filename)
62 | model.load_state_dict(save_params["model"])
63 | #optimizer.load_state_dict(save_params["optimizer"])
64 |
65 | def save_checkpoint(model, optimizer, trained_epoch):
66 | save_params = {
67 | "model": model.state_dict(),
68 | "optimizer": optimizer.state_dict(),
69 | "trained_epoch": trained_epoch,
70 | }
71 | if not os.path.exists(args.checkpoint):
72 |         # create the checkpoint folder if it does not exist
73 | os.mkdir(args.checkpoint)
74 | filename = args.checkpoint + '/' + f"extract-{trained_epoch}.pkl"
75 | torch.save(save_params, filename)
76 |
77 | def load_data(filename):
78 | """加载数据
79 | 返回:[(texts, labels, summary)]
80 | """
81 | D = []
82 | with open(filename, encoding='utf-8') as f:
83 | for l in f:
84 | D.append(json.loads(l))
85 | return D
86 |
87 |
88 | class ResidualGatedConv1D(nn.Module):
89 | """门控卷积
90 | """
91 | def __init__(self, filters, kernel_size, dilation_rate=1):
92 | super(ResidualGatedConv1D, self).__init__()
93 |         self.filters = filters  # output dimension
94 | self.kernel_size = kernel_size
95 | self.dilation_rate = dilation_rate
96 | self.supports_masking = True
97 | self.padding = self.dilation_rate*(self.kernel_size - 1)//2
98 | self.conv1d = nn.Conv1d(filters, 2*filters, self.kernel_size, padding=self.padding, dilation=self.dilation_rate)
99 | self.layernorm = nn.LayerNorm(self.filters)
100 | self.alpha = nn.Parameter(torch.zeros(1))
101 |
102 |
103 | def forward(self, inputs):
104 | input_cov1d = inputs.permute([0, 2, 1])
105 | outputs = self.conv1d(input_cov1d)
106 | outputs = outputs.permute([0, 2, 1])
107 | gate = torch.sigmoid(outputs[..., self.filters:])
108 | outputs = outputs[..., :self.filters] * gate
109 | outputs = self.layernorm(outputs)
110 |
111 |         if hasattr(self, 'dense'):  # optional projection, only needed when input dim != filters (never set here)
112 | inputs = self.dense(inputs)
113 |
114 | return inputs + self.alpha * outputs
115 |
116 |
117 | class Selector2(nn.Module):
118 | def __init__(self, input_size, filters, kernel_size, dilation_rate):
119 | """
120 |         :param input_size: dimension of each input sentence vector
121 | """
122 | super(Selector2, self).__init__()
123 | self.dense1 = nn.Linear(input_size, filters, bias=False)
124 | self.ResidualGatedConv1D_1 = ResidualGatedConv1D(filters, kernel_size, dilation_rate=dilation_rate[0])
125 | self.ResidualGatedConv1D_2 = ResidualGatedConv1D(filters, kernel_size, dilation_rate=dilation_rate[1])
126 | self.ResidualGatedConv1D_3 = ResidualGatedConv1D(filters, kernel_size, dilation_rate=dilation_rate[2])
127 | self.ResidualGatedConv1D_4 = ResidualGatedConv1D(filters, kernel_size, dilation_rate=dilation_rate[3])
128 | self.ResidualGatedConv1D_5 = ResidualGatedConv1D(filters, kernel_size, dilation_rate=dilation_rate[4])
129 | self.ResidualGatedConv1D_6 = ResidualGatedConv1D(filters, kernel_size, dilation_rate=dilation_rate[5])
130 | self.dense2 = nn.Linear(filters, 1)
131 |
132 |
133 | def forward(self, inputs):
134 |         mask = inputs.ge(0.00001)
135 |         mask = torch.sum(mask, dim=-1).bool()
136 |         x1 = self.dense1(nn.functional.dropout(inputs, 0.1, self.training))  # functional dropout: an nn.Dropout built inside forward would ignore model.eval()
137 |         x2 = self.ResidualGatedConv1D_1(nn.functional.dropout(x1, 0.1, self.training))
138 |         x3 = self.ResidualGatedConv1D_2(nn.functional.dropout(x2, 0.1, self.training))
139 |         x4 = self.ResidualGatedConv1D_3(nn.functional.dropout(x3, 0.1, self.training))
140 |         x5 = self.ResidualGatedConv1D_4(nn.functional.dropout(x4, 0.1, self.training))
141 |         x6 = self.ResidualGatedConv1D_5(nn.functional.dropout(x5, 0.1, self.training))
142 |         x7 = self.ResidualGatedConv1D_6(nn.functional.dropout(x6, 0.1, self.training))
143 |         output = torch.sigmoid(self.dense2(nn.functional.dropout(x7, 0.1, self.training)))
144 | return output, mask
145 |
146 |
147 |
148 | class Selector_Dataset(Dataset):
149 | def __init__(self, data_x, data_y):
150 | super(Selector_Dataset, self).__init__()
151 | self.data_x_tensor = torch.from_numpy(data_x)
152 | self.data_y_tensor = torch.from_numpy(data_y)
153 | def __len__(self):
154 | return len(self.data_x_tensor)
155 | def __getitem__(self, idx):
156 | return self.data_x_tensor[idx], self.data_y_tensor[idx]
157 |
158 |
159 |
160 |
161 | def train(model, train_dataloader, valid_dataloader):
162 | model = model.to(device)
163 | optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
164 | criterion = nn.BCELoss(reduction='none')
165 | for epoch in range(args.epoch_num):
166 | epoch_loss = 0.0
167 | current_step = 0
168 | model.train()
169 | pbar = tqdm(train_dataloader, desc="Iteration", postfix='train')
170 | for batch_data in pbar:
171 | x_batch, label_batch = batch_data
172 | x_batch = x_batch.to(device)
173 | label_batch = label_batch.to(device)
174 | output_batch, batch_mask = model(x_batch)
175 | output_batch = output_batch.permute([0, 2, 1])
176 | loss = criterion(output_batch.squeeze(), label_batch.squeeze())
177 | loss = torch.div(torch.sum(loss*batch_mask), torch.sum(batch_mask))
178 | optimizer.zero_grad()
179 | loss.backward()
180 | optimizer.step()
181 |
182 | loss_item = loss.cpu().detach().item()
183 | epoch_loss += loss_item
184 | current_step += 1
185 | pbar.set_description("train loss {}".format(epoch_loss / current_step))
186 | if current_step % 100 == 0:
187 | logging.info("train step {} loss {}".format(current_step, epoch_loss / current_step))
188 |
189 | epoch_loss = epoch_loss / current_step
190 | time_str = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
191 | print('{} train epoch {} loss: {:.4f}'.format(time_str, epoch, epoch_loss))
192 | logging.info('train epoch {} loss: {:.4f}'.format(epoch, epoch_loss))
193 | save_checkpoint(model, optimizer, epoch)
194 | model.eval()
195 | with torch.no_grad():
196 | correct = 0
197 | total = 0
198 | current_step = 0
199 | pbar = tqdm(valid_dataloader, desc="Iteration", postfix='valid')
200 | for batch_data in pbar:
201 | x_batch, label_batch = batch_data
202 | x_batch = x_batch.to(device)
203 | label_batch = label_batch.to(device).long()
204 | output_batch, batch_mask = model(x_batch)
205 | label_batch = label_batch.to(device)
206 | total += torch.sum(batch_mask)
207 | vec_correct = ((output_batch.squeeze()>args.threshold).long() == label_batch.squeeze().long())*batch_mask
208 | correct += torch.sum(vec_correct).cpu().item()
209 | pbar.set_description("valid acc {}".format(correct / total))
210 | current_step += 1
211 | if current_step % 100 == 0:
212 | logging.info('valid epoch {} acc {}/{}={:.4f}'.format(epoch, correct, total, correct / total))
213 | time_str = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
214 | print('{} valid epoch {} acc {}/{}={:.4f}'.format(time_str, epoch, correct, total, correct / total))
215 | logging.info('valid epoch {} acc {}/{}={:.4f}'.format(epoch, correct, total, correct / total))
216 |
217 |
218 | if __name__ == '__main__':
219 |
220 |     # load the data
221 | data = load_data(data_extract_json)
222 | data_x = np.load(data_extract_npy)
223 | data_y = np.zeros_like(data_x[..., :1])
224 |
225 | for i, d in enumerate(data):
226 | for j in d[1]:
227 | data_y[i, j] = 1
228 |
229 | train_data = data_split(data, fold, num_folds, 'train')
230 | valid_data = data_split(data, fold, num_folds, 'valid')
231 | train_x = data_split(data_x, fold, num_folds, 'train')
232 | valid_x = data_split(data_x, fold, num_folds, 'valid')
233 | train_y = data_split(data_y, fold, num_folds, 'train')
234 | valid_y = data_split(data_y, fold, num_folds, 'valid')
235 |
236 | train_dataloader = DataLoader(Selector_Dataset(train_x, train_y), batch_size=args.batch_size, shuffle=True, drop_last=True)
237 | valid_dataloader = DataLoader(Selector_Dataset(valid_x, valid_y), batch_size=len(valid_x), shuffle=False)
238 |
239 | model = Selector2(args.input_size, args.hidden_size, kernel_size=args.kernel_size, dilation_rate=[1, 2, 4, 8, 1, 1])
240 |
241 | train(model, train_dataloader, valid_dataloader)
242 |
243 |
244 |
245 |
--------------------------------------------------------------------------------
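
As a quick shape check for `ResidualGatedConv1D`, this sketch reproduces its core computation (a Conv1d that doubles the channels, half of which drives a sigmoid gate) on random data. The sizes mirror the script's defaults (768-dim sentence vectors, 384 filters); everything else is arbitrary:

import torch
import torch.nn as nn

batch, seq_len, input_size, filters = 2, 256, 768, 384
x = torch.randn(batch, seq_len, input_size)            # sentence vectors
x = nn.Linear(input_size, filters, bias=False)(x)      # dense1 -> (2, 256, 384)

# the conv doubles the channel count; half is the value, half drives the gate
conv = nn.Conv1d(filters, 2 * filters, kernel_size=3, padding=1)
h = conv(x.permute(0, 2, 1)).permute(0, 2, 1)          # (2, 256, 768)
out = h[..., :filters] * torch.sigmoid(h[..., filters:])  # (2, 256, 384)
print(out.shape)  # torch.Size([2, 256, 384])
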
/extract_vectorize.py:
--------------------------------------------------------------------------------
1 | #! -*- coding: utf-8 -*-
2 | # CAIL 2020 (法研杯) legal judgment summarization
3 | # extractive stage: sentence vectorization
4 | # Scientific Spaces: https://kexue.fm
5 |
6 | import json
7 | import numpy as np
8 | from tqdm import tqdm
9 | from transformers import BertTokenizer, BertModel, BertConfig
10 | from transformers import AutoModel, AutoTokenizer
11 | from snippets import *
12 | import torch.nn as nn
13 | import torch
14 |
15 |
16 |
17 | class GlobalAveragePooling1D(nn.Module):
18 | """自定义全局池化
19 | 对一个句子的pooler取平均,一个长句子用短句的pooler平均代替
20 | """
21 | def __init__(self):
22 | super(GlobalAveragePooling1D, self).__init__()
23 |
24 |
25 | def forward(self, inputs, mask=None):
26 | if mask is not None:
27 | mask = mask.to(torch.float)[:, :, None]
28 | return torch.sum(inputs * mask, dim=1) / torch.sum(mask, dim=1)
29 | else:
30 | return torch.mean(inputs, dim=1)
31 |
32 |
33 | class Selector_1(nn.Module):
34 | def __init__(self):
35 | super(Selector_1, self).__init__()
36 | self.tokenizer = BertTokenizer.from_pretrained(pretrained_bert_fold, mirror='tuna', do_lower_case=True)
37 | self.Pooling = GlobalAveragePooling1D()
38 | self.encoder = BertModel.from_pretrained(pretrained_bert_fold)
39 | self.max_seq_len = 512
40 |
41 |
42 | def predict(self, texts):
43 | """句子列表转换为句向量
44 | """
45 | with torch.no_grad():
46 | bert_output = self.tokenizer.batch_encode_plus(texts, padding=True, truncation=True, max_length=self.max_seq_len, return_tensors="pt")
47 | output_1 = self.encoder(**bert_output)["last_hidden_state"]
48 |             outputs = self.Pooling(output_1, bert_output["attention_mask"])  # pass the mask so padding is excluded from the average
49 | return outputs
50 |
51 |
52 |
53 | def load_data(filename):
54 | """加载数据
55 | 返回:[texts]
56 | """
57 | D = []
58 | with open(filename) as f:
59 | for l in f:
60 | texts = json.loads(l)[0]
61 | D.append(texts)
62 | return D
63 |
64 |
65 |
66 |
67 | def convert(data):
68 | """转换所有样本
69 | """
70 | embeddings = []
71 | model = Selector_1()
72 |     for texts in tqdm(data, desc=u'vectorizing'):
73 | outputs = model.predict(texts)
74 |         embeddings.append(outputs.cpu().numpy())  # to numpy for sequence_padding
75 | embeddings = sequence_padding(embeddings)
76 | return embeddings
77 |
78 |
79 | if __name__ == '__main__':
80 |
81 | data_extract_json = data_json[:-5] + '_extract.json'
82 | data_extract_npy = data_json[:-5] + '_extract'
83 |
84 | data = load_data(data_extract_json)
85 | embeddings = convert(data)
86 | np.save(data_extract_npy, embeddings)
87 |     print(u'output path: %s.npy' % data_extract_npy)
88 |
--------------------------------------------------------------------------------
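
The effect of the mask in `GlobalAveragePooling1D` can be checked by hand. In this tiny example the third token is padding and is excluded from the average (note that `predict` must pass the attention mask through for this path to be taken):

import torch

inputs = torch.tensor([[[1.0, 1.0], [3.0, 3.0], [9.0, 9.0]]])  # (batch=1, seq=3, dim=2)
mask = torch.tensor([[1, 1, 0]])                                # third token is padding

m = mask.to(torch.float)[:, :, None]
pooled = torch.sum(inputs * m, dim=1) / torch.sum(m, dim=1)
print(pooled)  # tensor([[2., 2.]]) -- the padded token is ignored
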
/seq2seq_convert.py:
--------------------------------------------------------------------------------
1 | from extract_model import *
2 | from snippets import open
3 | import torch
4 |
5 | def fold_convert(data, data_x, fold):
6 | """每一fold用对应的模型做数据转换
7 | """
8 | valid_data = data_split(data, fold, num_folds, 'valid')
9 | valid_x = data_split(data_x, fold, num_folds, 'valid')
10 | with torch.no_grad():
11 | model = Selector2(args.input_size, args.hidden_size, kernel_size=args.kernel_size, dilation_rate=[1, 2, 4, 8, 1, 1])
12 | load_checkpoint(model, None, 19)
13 | model_output = model(torch.tensor(valid_x))[0]
14 | y_pred = model_output.cpu().numpy()
15 |
16 | results = []
17 |     for d, yp in tqdm(zip(valid_data, y_pred), desc=u'converting'):
18 | yp = yp[:len(d[0])]
19 | yp = np.where(yp > args.threshold)[0]
20 | source_1 = ''.join([d[0][i] for i in yp])
21 | source_2 = ''.join([d[0][i] for i in d[1]])
22 | result = {
23 | 'source_1': source_1,
24 | 'source_2': source_2,
25 | 'target': d[2],
26 | }
27 | results.append(result)
28 |
29 | return results
30 |
31 |
32 | def convert(filename, data, data_x):
33 | """转换为生成式数据
34 | """
35 | F = open(filename, 'w', encoding='utf-8')
36 | total_results = []
37 | for fold in range(num_folds):
38 | total_results.append(fold_convert(data, data_x, fold))
39 |
40 |     # write the results back in the original sample order
41 | n = 0
42 | while True:
43 | i, j = n % num_folds, n // num_folds
44 | try:
45 | d = total_results[i][j]
46 |         except IndexError:
47 | break
48 | F.write(json.dumps(d, ensure_ascii=False) + '\n')
49 | n += 1
50 |
51 | F.close()
52 |
53 |
54 | if __name__ == '__main__':
55 |
56 | data = load_data(data_extract_json)
57 | data_x = np.load(data_extract_npy)
58 | data_seq2seq_json = data_json[:-5] + '_seq2seq.json'
59 | convert(data_seq2seq_json, data, data_x)
60 |     print(u'output path: %s' % data_seq2seq_json)
61 |
--------------------------------------------------------------------------------
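
The write-out loop in `convert` depends on how `data_split` deals samples out: sample k lands in the validation split of fold `k % num_folds` at position `k // num_folds`, so reading `total_results[n % num_folds][n // num_folds]` restores the original order. A self-contained check, using 3 folds instead of the project's 15:

num_folds = 3
data = list(range(10))

# the 'valid' split of fold f, exactly as data_split(..., f, num_folds, 'valid') builds it
folds = [[d for i, d in enumerate(data) if i % num_folds == f]
         for f in range(num_folds)]

restored, n = [], 0
while True:
    i, j = n % num_folds, n // num_folds
    try:
        restored.append(folds[i][j])
    except IndexError:
        break
    n += 1
print(restored == data)  # True
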
/seq2seq_model.py:
--------------------------------------------------------------------------------
1 | import datetime
2 | import os, json
3 | import numpy as np
4 | from tqdm import tqdm
5 | from transformers import BertTokenizer, AutoTokenizer
6 | import argparse
7 | import torch
8 | from transformers import AdamW
9 | import torch.nn as nn
10 | from tqdm import tqdm
11 | import copy
12 | from torch.utils.data import Dataset, DataLoader
13 | import logging
14 | from snippets import *
15 | from bert_seq2seq import Tokenizer, load_chinese_base_vocab
16 | from bert_seq2seq import load_bert
17 | # basic parameters
18 | parser = argparse.ArgumentParser()
19 | parser.add_argument('--batch_size', type=int, default=2, help='batch size')
20 | parser.add_argument('--epochs', type=int, default=50, help='number of epochs')
21 | parser.add_argument('--each_test_epoch', type=int, default=1)
22 | parser.add_argument('--lr', type=float, default=2e-5, help='learning rate')
23 | parser.add_argument('--weight_decay', type=float, default=0., help='decay weight of optimizer')
24 | parser.add_argument('--model_name', type=str, default='nezha', help='matching model')
25 | parser.add_argument('--checkpoint', type=str, default="./checkpoint/", help='checkpoint path')
26 | parser.add_argument('--bert_maxlen', type=int, default=512, help='max length of each case')
27 | parser.add_argument('--maxlen', type=int, default=1024, help='max length of each case')
28 | parser.add_argument('--input_size', type=int, default=768)
29 | parser.add_argument('--hidden_size', type=int, default=384)
30 | parser.add_argument('--kernel_size', type=int, default=3)
31 | parser.add_argument('--threshold', type=float, default=0.3)
32 | parser.add_argument('--k_sparse', type=int, default=10)
33 | parser.add_argument('--cuda_pos', type=str, default='0', help='which GPU to use')
34 | parser.add_argument('--seed', type=int, default=42, help='random seed')
35 | args = parser.parse_args()
36 |
37 | np.random.seed(args.seed)
38 | torch.manual_seed(args.seed)
39 | torch.cuda.manual_seed_all(args.seed)
40 | device = torch.device('cuda:'+args.cuda_pos) if torch.cuda.is_available() else torch.device('cpu')
41 | log_name = "log_train"
42 | logging.basicConfig(level=logging.INFO,  # log level printed to the console
43 |                     filename='./logs/{}.log'.format(log_name),
44 |                     filemode='a',  # 'w' rewrites the log file on every run,
45 |                     # 'a' (the default when omitted) appends to it
46 |                     format=
47 |                     '%(asctime)s - %(pathname)s[line:%(lineno)d] - %(levelname)s: %(message)s'
48 |                     # log format
49 |                     )
50 |
51 |
52 |
53 | data_seq2seq_json = data_json[:-5] + '_seq2seq.json'
54 | seq2seq_config_json = data_json[:-10] + 'seq2seq_config.json'
55 |
56 | try:  # argparse already consumes sys.argv, so read the fold index defensively
57 |     fold = int(sys.argv[1])
58 | except (IndexError, ValueError):
59 |     fold = 0
60 |
61 |
62 | def load_data(filename):
63 | """加载数据
64 | 返回:[{...}]
65 | """
66 | D = []
67 | with open(filename) as f:
68 | for l in f:
69 | D.append(json.loads(l))
70 | return D
71 |
72 |
73 |
74 |
75 | def generate_copy_labels(source, target):
76 | """构建copy机制对应的label
77 | """
78 | mapping = longest_common_subsequence(source, target)[1]
79 | source_labels = [0] * len(source)
80 | target_labels = [0] * len(target)
81 | i0, j0 = -2, -2
82 | for i, j in mapping:
83 | if i == i0 + 1 and j == j0 + 1:
84 | source_labels[i] = 2
85 | target_labels[j] = 2
86 | else:
87 | source_labels[i] = 1
88 | target_labels[j] = 1
89 | i0, j0 = i, j
90 | return source_labels, target_labels
91 |
92 |
93 | def random_masking(token_ids_all):
94 | """对输入进行随机mask,增加泛化能力
95 | """
96 | result = []
97 | for token_ids in token_ids_all:
98 | rands = np.random.random(len(token_ids))
99 | result.append([
100 | t if r > 0.15 else np.random.choice(token_ids)
101 | for r, t in zip(rands, token_ids)
102 | ])
103 | return result
104 |
105 |
106 | class DataGenerator(Dataset):
107 | def __init__(self, input_data, random=True):
108 | super(DataGenerator, self).__init__()
109 | self.input_data = input_data
110 | self.random = random
111 |
112 | def __len__(self):
113 | return len(self.input_data)
114 |
115 | def __getitem__(self, idx):
116 |
117 | i = np.random.choice(2) + 1 if self.random else 1
118 | source, target = self.input_data[idx]['source_%s' % i], self.input_data[idx]['target']
119 | return [source, target]
120 |
121 |
122 | class Collate:
123 | def __init__(self):
124 | self.tokenizer = BertTokenizer.from_pretrained(pretrained_nezha_fold)
125 |
126 | self.max_seq_len = args.maxlen
127 |
128 | def __call__(self, batch):
129 | # assert len(A_batch) == 1
130 | # print("A_batch: ", A_batch)
131 | dic_data = self.tokenizer.batch_encode_plus(batch, padding=True, truncation=True,
132 | max_length=self.max_seq_len)
133 | mask_dic_data = copy.deepcopy(dic_data)
134 |
135 | token_ids = dic_data["input_ids"]
136 |
137 | masked_token_ids = random_masking(token_ids)
138 | mask_dic_data['input_ids'] = masked_token_ids
139 | labels = []
140 | for item_masked_token_ids, item_token_ids in zip(masked_token_ids, token_ids):
141 | idx = item_token_ids.index(self.tokenizer.sep_token_id) + 1
142 | source_labels, target_labels = generate_copy_labels(
143 | item_masked_token_ids[:idx], item_token_ids[idx:]
144 | )
145 | """
146 | [CLS]...[SEP] ... [SEP]
147 | """
148 |             labels.append(source_labels[1:] + target_labels)  # shifted left by one because the model predicts the next position
149 |
150 |
151 | return torch.tensor(dic_data["input_ids"]), torch.tensor(dic_data["token_type_ids"]), torch.tensor(labels)
152 |
153 |
154 |
155 | def build_pretrain_dataloader(data, batch_size, shuffle=True, num_workers=0,):
156 | data_generator =DataGenerator(data, random=True)
157 | collate = Collate()
158 | return DataLoader(
159 | data_generator,
160 | batch_size=batch_size,
161 | shuffle=shuffle,
162 | num_workers=num_workers,
163 | collate_fn=collate
164 | )
165 |
166 |
167 | def compute_seq2seq_loss(predictions, token_type_id, input_ids, vocab_size):
168 |
169 | predictions = predictions[:, :-1].contiguous()
170 | target_mask = token_type_id[:, 1:].contiguous()
171 | """
172 |     target_mask: 0 over the sentence-A part and padding, 1 over the sentence-B part
173 | """
174 | predictions = predictions.view(-1, vocab_size)
175 | labels = input_ids[:, 1:].contiguous()
176 | labels = labels.view(-1)
177 | target_mask = target_mask.view(-1).float()
178 |     # positive term: score of the gold token
179 | pos_loss = predictions[list(range(predictions.shape[0])), labels]
180 |     # negative term: log-normalizer over the top-k logits
181 | y_pred = torch.topk(predictions, k=args.k_sparse)[0]
182 | neg_loss = torch.logsumexp(y_pred, dim=-1)
183 |
184 | loss = neg_loss - pos_loss
185 |     return (loss * target_mask).sum() / target_mask.sum()  # the mask removes the influence of padding and the sentence-A part
186 |
187 |
188 | def compute_copy_loss(predictions, token_type_id, labels):
189 | predictions = predictions[:, :-1].contiguous()
190 | target_mask = token_type_id[:, 1:].contiguous()
191 | """
192 |     target_mask: 0 over the sentence-A part and padding, 1 over the sentence-B part
193 | """
194 | predictions = predictions.view(-1, 3)
195 | labels = labels.view(-1)
196 | target_mask = target_mask.view(-1).float()
197 | loss = nn.CrossEntropyLoss(ignore_index=0, reduction="none")
198 |     return (loss(predictions, labels) * target_mask).sum() / target_mask.sum()  # the mask removes the influence of padding and the sentence-A part
199 |
200 | class GenerateModel(nn.Module):
201 | def __init__(self):
202 | super(GenerateModel, self).__init__()
203 | self.word2idx = load_chinese_base_vocab(pretrained_nezha_fold+"vocab.txt", simplfied=False)
204 | self.tokenizer = BertTokenizer.from_pretrained(pretrained_nezha_fold)
205 | self.bert_model = load_bert(self.word2idx, model_name=args.model_name, model_class="seq2seq")
206 |         # load the pretrained model parameters
207 | self.bert_model.load_pretrain_params(pretrained_nezha_fold+"pytorch_model.bin")
208 | self.bert_model.set_device(device)
209 | self.configuration = self.bert_model.config
210 | self.linear = nn.Linear(self.configuration.hidden_size, 3)
211 |
212 | def forward(self, token_ids, token_type_ids):
213 | seq2seq_predictions, hidden_state = self.bert_model(token_ids, token_type_ids)
214 | copy_predictions = self.linear(nn.GELU()(hidden_state))
215 |
216 | return seq2seq_predictions, copy_predictions
217 |
218 |
219 | def load_checkpoint(model, optimizer, trained_epoch):
220 | filename = args.checkpoint + '/' + f"seq2seq-{trained_epoch}.pkl"
221 | save_params = torch.load(filename)
222 | model.load_state_dict(save_params["model"])
223 | optimizer.load_state_dict(save_params["optimizer"])
224 |
225 |
226 | def save_checkpoint(model, optimizer, trained_epoch):
227 | save_params = {
228 | "model": model.state_dict(),
229 | "optimizer": optimizer.state_dict(),
230 | "trained_epoch": trained_epoch,
231 | }
232 | if not os.path.exists(args.checkpoint):
233 |         # create the checkpoint folder if it does not exist
234 | os.mkdir(args.checkpoint)
235 | filename = args.checkpoint + '/' + f"seq2seq-{trained_epoch}.pkl"
236 | torch.save(save_params, filename)
237 |
238 |
239 | def train_valid(train_data, valid_data, model):
240 | optimizer = AdamW(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
241 | # ema = EMA(model, 0.9999)
242 | # ema.register()
243 | for epoch in range(args.epochs):
244 | epoch_loss = 0.
245 | current_step = 0
246 | model.train()
247 | # for batch_data in tqdm(train_data_loader, ncols=0):
248 | pbar = tqdm(train_data, desc="Iteration", postfix='train')
249 | for batch_data in pbar:
250 | input_ids, token_type_ids, labels = batch_data
251 | input_ids, token_type_ids, labels = input_ids.to(device), token_type_ids.to(device), labels.to(device)
252 | seq2seq_predictions, copy_predictions = model(input_ids, token_type_ids)
253 |
254 | seq2seq_loss = compute_seq2seq_loss(seq2seq_predictions, token_type_ids, input_ids,
255 | model.configuration.vocab_size)
256 | copy_loss = compute_copy_loss(copy_predictions, token_type_ids, labels)
257 | loss = seq2seq_loss + 2 * copy_loss
258 | optimizer.zero_grad()
259 | loss.backward()
260 | optimizer.step()
261 | # ema.update()
262 | loss_item = loss.cpu().detach().item()
263 | epoch_loss += loss_item
264 | current_step += 1
265 | pbar.set_description("train loss {}".format(epoch_loss / current_step))
266 | if current_step % 100 == 0:
267 | logging.info("train step {} loss {}".format(current_step, epoch_loss / current_step))
268 |
269 | epoch_loss = epoch_loss / current_step
270 | time_str = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
271 | print('{} train epoch {} loss: {:.4f}'.format(time_str, epoch, epoch_loss))
272 | logging.info('train epoch {} loss: {:.4f}'.format(epoch, epoch_loss))
273 |         # TODO: check whether EMA improves accuracy; if so, apply it before saving the model
274 | save_checkpoint(model, optimizer, epoch)
275 | with torch.no_grad():
276 | model.eval()
277 | # ema.apply_shadow()
278 | evaluate(valid_data, model, filename=r'./result/abstract.txt')
279 | # ema.restore()
280 | model.train()
281 |
282 | class AutoSummary(AutoRegressiveDecoder):
283 | """seq2seq解码器
284 | """
285 | def get_ngram_set(self, x, n):
286 | """生成ngram合集,返回结果格式是:
287 | {(n-1)-gram: set([n-gram的第n个字集合])}
288 | """
289 | result = {}
290 | for i in range(len(x) - n + 1):
291 | k = tuple(x[i:i + n])
292 | if k[:-1] not in result:
293 | result[k[:-1]] = set()
294 | result[k[:-1]].add(k[-1])
295 | return result
296 |
297 | @AutoRegressiveDecoder.wraps(default_rtype='logits', use_states=True)
298 | def predict(self, inputs, output_ids, states):
299 | token_ids, segment_ids = inputs
300 | token_ids = np.concatenate([token_ids, output_ids], 1)
301 | segment_ids = np.concatenate([segment_ids, np.ones_like(output_ids)], 1)
302 | seq2seq_predictions, copy_predictions = self.model(torch.tensor(token_ids, device=device), torch.tensor(segment_ids, device=device))
303 |         prediction = [seq2seq_predictions[:, -1].cpu().numpy(), torch.softmax(copy_predictions[:, -1], dim=-1).cpu().numpy()]  # predictions for the last position, shapes (1, vocab_size) and (1, 3); TODO: the seq2seq output needs a softmax here as well
304 |         # states caches the n of the current ngram
305 | if states is None:
306 | states = [0]
307 | elif len(states) == 1 and len(token_ids) > 1:
308 | states = states * len(token_ids)
309 |         # adjust the probability distribution according to the copy labels
310 |         probas = np.zeros_like(prediction[0]) - 1000  # the distribution to be returned
311 | for i, token_ids in enumerate(inputs[0]):
312 | if states[i] == 0:
313 |                 prediction[1][i, 2] *= -1  # label 0 cannot be followed by 2
314 |             label = prediction[1][i].argmax()  # current label
315 | if label < 2:
316 | states[i] = label
317 | else:
318 |                 states[i] += 1  # anything may follow label 2
319 | if states[i] > 0:
320 | ngrams = self.get_ngram_set(token_ids, states[i])
321 | prefix = tuple(output_ids[i, 1 - states[i]:])
322 |                 if prefix in ngrams:  # the prefix matches an observed ngram
323 |                     candidates = ngrams[prefix]
324 |                 else:  # otherwise fall back to 1-grams
325 | ngrams = self.get_ngram_set(token_ids, 1)
326 | candidates = ngrams[tuple()]
327 | states[i] = 1
328 | candidates = list(candidates)
329 | probas[i, candidates] = prediction[0][i, candidates]
330 | else:
331 | probas[i] = prediction[0][i]
332 | idxs = probas[i].argpartition(-args.k_sparse)
333 | probas[i, idxs[:-args.k_sparse]] = -1000
334 | return probas, states
335 |
336 | def generate(self, text, topk=1):
337 | max_c_len = args.maxlen - self.maxlen
338 | encode_text = self.model.tokenizer(text, padding=True, truncation=True,
339 | max_length=max_c_len)
340 | token_ids, segment_ids = encode_text['input_ids'], encode_text['token_type_ids']
341 | output_ids = self.beam_search([token_ids, segment_ids],
342 |                                       topk)  # beam search
343 | return ''.join(self.model.tokenizer.convert_ids_to_tokens(output_ids))
344 |
345 |
346 |
347 |
348 |
349 | def evaluate(data, model, topk=1, filename=None):
350 | """验证集评估
351 | """
352 | autosummary = AutoSummary(
353 | start_id=model.tokenizer.cls_token_id,
354 | end_id=model.tokenizer.sep_token_id,
355 | maxlen=args.maxlen // 2,
356 | model=model
357 | )
358 | if filename is not None:
359 | F = open(filename, 'w', encoding='utf-8')
360 | total_metrics = {k: 0.0 for k in metric_keys}
361 |     for d in tqdm(data, desc=u'evaluating'):
362 | pred_summary = autosummary.generate(d['source_1'], topk)
363 | metrics = compute_metrics(pred_summary, d['target'])
364 | for k, v in metrics.items():
365 | total_metrics[k] += v
366 | if filename is not None:
367 | F.write(d['target'] + '\t' + pred_summary + '\n')
368 | F.flush()
369 | if filename is not None:
370 | F.close()
371 | print(total_metrics)
372 | return {k: v / len(data) for k, v in total_metrics.items()}
373 |
374 | if __name__ == '__main__':
375 |     # load the data
376 | data = load_data(data_seq2seq_json)
377 | train_data = data_split(data, fold, num_folds, 'train')
378 | valid_data = data_split(data, fold, num_folds, 'valid')
379 | train_data_loader = build_pretrain_dataloader(train_data, args.batch_size)
380 | G_model = GenerateModel()
381 | print(G_model)
382 | G_model = G_model.to(device)
383 | train_valid(train_data_loader, valid_data, G_model)
384 |
385 |
386 |
387 |
388 |
--------------------------------------------------------------------------------
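
To see what `generate_copy_labels` produces: positions on the common subsequence get label 1 where a copied span starts and label 2 where it continues, everything else stays 0. A minimal sketch, with `difflib.SequenceMatcher` standing in for `snippets.longest_common_subsequence` (difflib matches contiguous blocks rather than a true LCS; the substitution is only to keep the example standalone):

from difflib import SequenceMatcher

def lcs_like_mapping(source, target):
    # stand-in for snippets.longest_common_subsequence: collect aligned
    # index pairs from difflib's contiguous matching blocks
    mapping = []
    for block in SequenceMatcher(None, source, target).get_matching_blocks():
        mapping.extend((block.a + k, block.b + k) for k in range(block.size))
    return mapping

def toy_generate_copy_labels(source, target):
    # same labeling rule as generate_copy_labels above
    source_labels = [0] * len(source)
    target_labels = [0] * len(target)
    i0, j0 = -2, -2
    for i, j in lcs_like_mapping(source, target):
        if i == i0 + 1 and j == j0 + 1:    # extends the previous match
            source_labels[i], target_labels[j] = 2, 2
        else:                              # starts a new copied span
            source_labels[i], target_labels[j] = 1, 1
        i0, j0 = i, j
    return source_labels, target_labels

src, tgt = '原告请求判令赔偿', '判令被告赔偿'
print(toy_generate_copy_labels(src, tgt))
# ([0, 0, 0, 0, 1, 2, 1, 2], [1, 2, 0, 0, 1, 2])
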
/snippets.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import json
3 | from rouge import Rouge
4 | import os, sys
5 | import jieba
6 | import six
7 | from collections import defaultdict
8 |
9 |
10 |
11 | # custom dictionaries
12 | user_dict_path = '/new_disk2/zhongxiang_sun/code/explanation_project/SPACES_torch/datasets/user_dict.txt'
13 | user_dict_path_2 = '/new_disk2/zhongxiang_sun/code/explanation_project/SPACES_torch/datasets/user_dict_2.txt'
14 | jieba.load_userdict(user_dict_path)
15 | jieba.initialize()
16 |
17 | # recursion limit
18 | sys.setrecursionlimit(1000000)
19 |
20 | # annotated data
21 | data_json = '/new_disk2/zhongxiang_sun/code/explanation_project/SPACES_torch/datasets/train.json'
22 |
23 | # folder for saved weights
24 | if not os.path.exists('weights'):
25 | os.mkdir('weights')
26 |
27 | # BERT paths
28 | pretrained_bert_fold = "/new_disk2/zhongxiang_sun/code/pretrain_model/bert_legal/"
29 | pretrained_nezha_fold = "/new_disk2/zhongxiang_sun/code/pretrain_model/NEZHA/"
30 | # NEZHA paths
31 | nezha_config_path = '/root/kg/bert/nezha_base/bert_config.json'
32 | nezha_checkpoint_path = '/root/kg/bert/nezha_base/model.ckpt-900000'
33 | nezha_dict_path = '/root/kg/bert/nezha_base/vocab.txt'
34 |
35 | # split the data into N folds, one of which serves as the validation set
36 | num_folds = 15
37 |
38 | # metric names
39 | metric_keys = ['main', 'rouge-1', 'rouge-2', 'rouge-l']
40 |
41 | # for ROUGE computation
42 | rouge = Rouge()
43 |
44 | def softmax(x, axis=-1):
45 | """numpy版softmax
46 | """
47 | x = x - x.max(axis=axis, keepdims=True)
48 | x = np.exp(x)
49 | return x / x.sum(axis=axis, keepdims=True)
50 |
51 | class AutoRegressiveDecoder(object):
52 | """通用自回归生成模型解码基类
53 | 包含beam search和random sample两种策略
54 | """
55 |     def __init__(self, start_id, end_id, maxlen, minlen=1, model=None, tokenizer=None):
56 | self.start_id = start_id
57 | self.end_id = end_id
58 | self.maxlen = maxlen
59 | self.minlen = minlen
60 | self.model = model
61 | self.tokenizer = tokenizer
62 | if start_id is None:
63 | self.first_output_ids = np.empty((1, 0), dtype=int)
64 | else:
65 | self.first_output_ids = np.array([[self.start_id]])
66 |
67 | @staticmethod
68 | def wraps(default_rtype='probas', use_states=False):
69 | """用来进一步完善predict函数
70 | 目前包含:1. 设置rtype参数,并做相应处理;
71 | 2. 确定states的使用,并做相应处理;
72 | 3. 设置温度参数,并做相应处理。
73 | """
74 | def actual_decorator(predict):
75 | def new_predict(
76 | self,
77 | inputs,
78 | output_ids,
79 | states,
80 | temperature=1,
81 | rtype=default_rtype
82 | ):
83 | assert rtype in ['probas', 'logits']
84 | prediction = predict(self, inputs, output_ids, states)
85 |
86 | if not use_states:
87 | prediction = (prediction, None)
88 |
89 | if default_rtype == 'logits':
90 | prediction = (
91 | softmax(prediction[0] / temperature), prediction[1]
92 | )
93 | elif temperature != 1:
94 | probas = np.power(prediction[0], 1.0 / temperature)
95 | probas = probas / probas.sum(axis=-1, keepdims=True)
96 | prediction = (probas, prediction[1])
97 |
98 | if rtype == 'probas':
99 | return prediction
100 | else:
101 | return np.log(prediction[0] + 1e-12), prediction[1]
102 |
103 | return new_predict
104 |
105 | return actual_decorator
106 |
107 |
108 |
109 | def predict(self, inputs, output_ids, states=None):
110 | """用户需自定义递归预测函数
111 | 说明:定义的时候,需要用wraps方法进行装饰,传入default_rtype和use_states,
112 | 其中default_rtype为字符串logits或probas,probas时返回归一化的概率,
113 | rtype=logits时则返回softmax前的结果或者概率对数。
114 | 返回:二元组 (得分或概率, states)
115 | """
116 | raise NotImplementedError
117 |
118 | def beam_search(self, inputs, topk, states=None, temperature=1, min_ends=1):
119 | """beam search解码
120 | 说明:这里的topk即beam size;
121 | 返回:最优解码序列。
122 | """
123 | inputs = [np.array([i]) for i in inputs]
124 | output_ids, output_scores = self.first_output_ids, np.zeros(1)
125 | for step in range(self.maxlen):
126 | scores, states = self.predict(
127 | inputs, output_ids, states, temperature, 'logits'
128 | ) # 计算当前得分
129 | if step == 0: # 第1步预测后将输入重复topk次
130 | inputs = [np.repeat(i, topk, axis=0) for i in inputs]
131 | scores = output_scores.reshape((-1, 1)) + scores # 综合累积得分
132 | indices = scores.argpartition(-topk, axis=None)[-topk:] # 仅保留topk
133 | indices_1 = indices // scores.shape[1] # 行索引
134 | indices_2 = (indices % scores.shape[1]).reshape((-1, 1)) # 列索引
135 | output_ids = np.concatenate([output_ids[indices_1], indices_2],
136 | 1) # 更新输出
137 | output_scores = np.take_along_axis(
138 | scores, indices, axis=None
139 | ) # 更新得分
140 | is_end = output_ids[:, -1] == self.end_id # 标记是否以end标记结束
141 | end_counts = (output_ids == self.end_id).sum(1) # 统计出现的end标记
142 | if output_ids.shape[1] >= self.minlen: # 最短长度判断
143 | best = output_scores.argmax() # 得分最大的那个
144 | if is_end[best] and end_counts[best] >= min_ends: # 如果已经终止
145 | return output_ids[best] # 直接输出
146 | else: # 否则,只保留未完成部分
147 | flag = ~is_end | (end_counts < min_ends) # 标记未完成序列
148 | if not flag.all(): # 如果有已完成的
149 | inputs = [i[flag] for i in inputs] # 扔掉已完成序列
150 | output_ids = output_ids[flag] # 扔掉已完成序列
151 | output_scores = output_scores[flag] # 扔掉已完成序列
152 | end_counts = end_counts[flag] # 扔掉已完成end计数
153 | topk = flag.sum() # topk相应变化
154 | # 达到长度直接输出
155 | return output_ids[output_scores.argmax()]
156 |
157 | def random_sample(
158 | self,
159 | inputs,
160 | n,
161 | topk=None,
162 | topp=None,
163 | states=None,
164 | temperature=1,
165 | min_ends=1
166 | ):
167 | """随机采样n个结果
168 | 说明:非None的topk表示每一步只从概率最高的topk个中采样;而非None的topp
169 | 表示每一步只从概率最高的且概率之和刚好达到topp的若干个token中采样。
170 | 返回:n个解码序列组成的list。
171 | """
172 | inputs = [np.array([i]) for i in inputs]
173 | output_ids = self.first_output_ids
174 | results = []
175 | for step in range(self.maxlen):
176 | probas, states = self.predict(
177 | inputs, output_ids, states, temperature, 'probas'
178 | ) # 计算当前概率
179 | probas /= probas.sum(axis=1, keepdims=True) # 确保归一化
180 | if step == 0: # 第1步预测后将结果重复n次
181 | probas = np.repeat(probas, n, axis=0)
182 | inputs = [np.repeat(i, n, axis=0) for i in inputs]
183 | output_ids = np.repeat(output_ids, n, axis=0)
184 | if topk is not None:
185 | k_indices = probas.argpartition(-topk,
186 | axis=1)[:, -topk:] # 仅保留topk
187 | probas = np.take_along_axis(probas, k_indices, axis=1) # topk概率
188 | probas /= probas.sum(axis=1, keepdims=True) # 重新归一化
189 | if topp is not None:
190 | p_indices = probas.argsort(axis=1)[:, ::-1] # 从高到低排序
191 | probas = np.take_along_axis(probas, p_indices, axis=1) # 排序概率
192 | cumsum_probas = np.cumsum(probas, axis=1) # 累积概率
193 | flag = np.roll(cumsum_probas >= topp, 1, axis=1) # 标记超过topp的部分
194 | flag[:, 0] = False # 结合上面的np.roll,实现平移一位的效果
195 | probas[flag] = 0 # 后面的全部置零
196 | probas /= probas.sum(axis=1, keepdims=True) # 重新归一化
197 | sample_func = lambda p: np.random.choice(len(p), p=p) # 按概率采样函数
198 | sample_ids = np.apply_along_axis(sample_func, 1, probas) # 执行采样
199 | sample_ids = sample_ids.reshape((-1, 1)) # 对齐形状
200 | if topp is not None:
201 | sample_ids = np.take_along_axis(
202 | p_indices, sample_ids, axis=1
203 | ) # 对齐原id
204 | if topk is not None:
205 | sample_ids = np.take_along_axis(
206 | k_indices, sample_ids, axis=1
207 | ) # 对齐原id
208 | output_ids = np.concatenate([output_ids, sample_ids], 1) # 更新输出
209 | is_end = output_ids[:, -1] == self.end_id # 标记是否以end标记结束
210 | end_counts = (output_ids == self.end_id).sum(1) # 统计出现的end标记
211 | if output_ids.shape[1] >= self.minlen: # 最短长度判断
212 | flag = is_end & (end_counts >= min_ends) # 标记已完成序列
213 | if flag.any(): # 如果有已完成的
214 | for ids in output_ids[flag]: # 存好已完成序列
215 | results.append(ids)
216 | flag = (flag == False) # 标记未完成序列
217 | inputs = [i[flag] for i in inputs] # 只保留未完成部分输入
218 | output_ids = output_ids[flag] # 只保留未完成部分候选集
219 | end_counts = end_counts[flag] # 只保留未完成部分end计数
220 | if len(output_ids) == 0:
221 | break
222 | # 如果还有未完成序列,直接放入结果
223 | for ids in output_ids:
224 | results.append(ids)
225 | # 返回结果
226 | return results
227 |
228 | class EMA():
229 | def __init__(self, model, decay):
230 | self.model = model
231 | self.decay = decay
232 | self.shadow = {}
233 | self.backup = {}
234 |
235 | def register(self):
236 | for name, param in self.model.named_parameters():
237 | if param.requires_grad:
238 | self.shadow[name] = param.data.clone()
239 |
240 | def update(self):
241 | for name, param in self.model.named_parameters():
242 | if param.requires_grad:
243 | assert name in self.shadow
244 | new_average = (1.0 - self.decay) * param.data + self.decay * self.shadow[name]
245 | self.shadow[name] = new_average.clone()
246 |
247 | def apply_shadow(self):
248 | for name, param in self.model.named_parameters():
249 | if param.requires_grad:
250 | assert name in self.shadow
251 | self.backup[name] = param.data
252 | param.data = self.shadow[name]
253 |
254 | def restore(self):
255 | for name, param in self.model.named_parameters():
256 | if param.requires_grad:
257 | assert name in self.backup
258 | param.data = self.backup[name]
259 | self.backup = {}
260 |
261 |
262 | def convert_to_unicode(text, encoding='utf-8', errors='ignore'):
263 | """字符串转换为unicode格式(假设输入为utf-8格式)
264 | """
265 | if is_py2:
266 | if isinstance(text, str):
267 | text = text.decode(encoding, errors=errors)
268 | else:
269 | if isinstance(text, bytes):
270 | text = text.decode(encoding, errors=errors)
271 | return text
272 |
273 |
274 | def convert_to_str(text, encoding='utf-8', errors='ignore'):
275 | """字符串转换为str格式(假设输入为utf-8格式)
276 | """
277 |
278 | if isinstance(text, bytes):
279 | text = text.decode(encoding, errors=errors)
280 | return text
281 |
282 |
283 | def is_string(s):
284 | """判断是否是字符串
285 | """
286 | return isinstance(s, str)
287 |
288 | _open_ = open
289 | is_py2 = six.PY2
290 | class open:
291 | """模仿python自带的open函数
292 | 作用:1.主要是为了同时兼容py2和py3;2.增加了索引功能,方便读取大文件。
293 | """
294 | def __init__(
295 | self, name, mode='r', encoding=None, errors='strict', indexable=False
296 | ):
297 | self.name = name
298 | if is_py2:
299 | self.file = _open_(name, mode)
300 | else:
301 | self.file = _open_(name, mode, encoding=encoding, errors=errors)
302 | self.encoding = encoding
303 | self.errors = errors
304 | self.iterator = None
305 | if indexable:
306 | if is_string(indexable) and os.path.exists(indexable):
307 | self.offsets = json.load(_open_(indexable))
308 | else:
309 | self.create_indexes()
310 | if is_string(indexable):
311 | json.dump(self.offsets, _open_(indexable, 'w'))
312 |
313 | def create_indexes(self):
314 | print('creating indexes ...')
315 | self.offsets, offset = [], 0
316 |
317 | while self.readline():
318 | self.offsets.append(offset)
319 | offset = self.tell()
320 | self.seek(0)
321 | print('indexes created.')
322 |
323 | def __getitem__(self, key):
324 | self.seek(self.offsets[key])
325 | l = self.readline()
326 | if self.encoding:
327 | l = convert_to_unicode(l, self.encoding, self.errors)
328 | return l
329 |
330 | def __len__(self):
331 | return len(self.offsets)
332 |
333 | def __iter__(self):
334 | if hasattr(self, 'offsets'):
335 | for i in range(len(self)):
336 | yield self[i]
337 | else:
338 | for l in self.file:
339 | if self.encoding:
340 | l = convert_to_unicode(l, self.encoding, self.errors)
341 | yield l
342 |
343 | def next(self):
344 | if self.iterator is None:
345 | self.iterator = self.__iter__()
346 | return next(self.iterator)
347 |
348 | def __next__(self):
349 | return self.next()
350 |
351 | def read(self):
352 | text = self.file.read()
353 | if self.encoding:
354 | text = convert_to_unicode(text, self.encoding, self.errors)
355 | return text
356 |
357 | def readline(self):
358 | text = self.file.readline()
359 | if self.encoding:
360 | text = convert_to_unicode(text, self.encoding, self.errors)
361 | return text
362 |
363 | def readlines(self):
364 | if self.encoding:
365 | return [
366 | convert_to_unicode(text, self.encoding, self.errors)
367 | for text in self.file.readlines()
368 | ]
369 | else:
370 | return self.file.readlines()
371 |
372 | def write(self, text):
373 | if self.encoding:
374 | text = convert_to_str(text, self.encoding, self.errors)
375 | self.file.write(text)
376 |
377 | def flush(self):
378 | self.file.flush()
379 |
380 | def close(self):
381 | self.file.close()
382 |
383 | def tell(self):
384 | return self.file.tell()
385 |
386 | def seek(self, offset=0):
387 | return self.file.seek(offset)
388 |
389 | def __enter__(self):
390 | return self
391 |
392 | def __exit__(self, type, value, tb):
393 | self.close()
394 |
395 |
396 | def parallel_apply(
397 | func,
398 | iterable,
399 | workers,
400 | max_queue_size,
401 | callback=None,
402 | dummy=False,
403 | random_seeds=True,
404 | unordered=True
405 | ):
406 | """多进程或多线程地将func应用到iterable的每个元素中。
407 | 注意这个apply是异步且无序的,也就是说依次输入a,b,c,但是
408 | 输出可能是func(c), func(a), func(b)。
409 | 参数:
410 | callback: 处理单个输出的回调函数;
411 | dummy: False是多进程/线性,True则是多线程/线性;
412 | random_seeds: 每个进程的随机种子;
413 | unordered: 若为False,则按照输入顺序返回,仅当callback为None时生效。
414 | """
415 | generator = parallel_apply_generator(
416 | func, iterable, workers, max_queue_size, dummy, random_seeds
417 | )
418 |
419 | if callback is None:
420 | if unordered:
421 | return [d for i, d in generator]
422 | else:
423 | results = sorted(generator, key=lambda d: d[0])
424 | return [d for i, d in results]
425 | else:
426 | for d in generator:
427 | callback(d)
428 |
429 | def sequence_padding(inputs, length=None, value=0, seq_dims=1, mode='post'):
430 | """Numpy函数,将序列padding到同一长度
431 | """
432 | if length is None:
433 | length = np.max([np.shape(x)[:seq_dims] for x in inputs], axis=0)
434 | elif not hasattr(length, '__getitem__'):
435 | length = [length]
436 |
437 | slices = [np.s_[:length[i]] for i in range(seq_dims)]
438 | slices = tuple(slices) if len(slices) > 1 else slices[0]
439 | pad_width = [(0, 0) for _ in np.shape(inputs[0])]
440 |
441 | outputs = []
442 | for x in inputs:
443 | x = x[slices]
444 | for i in range(seq_dims):
445 | if mode == 'post':
446 | pad_width[i] = (0, length[i] - np.shape(x)[i])
447 | elif mode == 'pre':
448 | pad_width[i] = (length[i] - np.shape(x)[i], 0)
449 | else:
450 | raise ValueError('"mode" argument must be "post" or "pre".')
451 | x = np.pad(x, pad_width, 'constant', constant_values=value)
452 | outputs.append(x)
453 |
454 | return np.array(outputs)
455 |
456 | def parallel_apply_generator(
457 | func, iterable, workers, max_queue_size, dummy=False, random_seeds=True
458 | ):
459 | """多进程或多线程地将func应用到iterable的每个元素中。
460 | 注意这个apply是异步且无序的,也就是说依次输入a,b,c,但是
461 | 输出可能是func(c), func(a), func(b)。结果将作为一个
462 | generator返回,其中每个item是输入的序号以及该输入对应的
463 | 处理结果。
464 | 参数:
465 | dummy: False是多进程/线性,True则是多线程/线性;
466 | random_seeds: 每个进程的随机种子。
467 | """
468 | if dummy:
469 | from multiprocessing.dummy import Pool, Queue
470 | else:
471 | from multiprocessing import Pool, Queue
472 |
473 | in_queue, out_queue, seed_queue = Queue(max_queue_size), Queue(), Queue()
474 | if random_seeds is True:
475 | random_seeds = [None] * workers
476 | elif random_seeds is None or random_seeds is False:
477 | random_seeds = []
478 | for seed in random_seeds:
479 | seed_queue.put(seed)
480 |
481 | def worker_step(in_queue, out_queue):
482 | """单步函数包装成循环执行
483 | """
484 | if not seed_queue.empty():
485 | np.random.seed(seed_queue.get())
486 | while True:
487 | i, d = in_queue.get()
488 | r = func(d)
489 | out_queue.put((i, r))
490 |
491 |     # start the worker processes/threads
492 | pool = Pool(workers, worker_step, (in_queue, out_queue))
493 |
494 |     # feed in the data and collect the results
495 | in_count, out_count = 0, 0
496 | for i, d in enumerate(iterable):
497 | in_count += 1
498 | while True:
499 | try:
500 | in_queue.put((i, d), block=False)
501 | break
502 | except six.moves.queue.Full:
503 | for _ in range(out_queue.qsize()):
504 | yield out_queue.get()
505 | out_count += 1
506 | if in_count % max_queue_size == 0:
507 | for _ in range(out_queue.qsize()):
508 | yield out_queue.get()
509 | out_count += 1
510 |
511 | while out_count != in_count:
512 | for _ in range(out_queue.qsize()):
513 | yield out_queue.get()
514 | out_count += 1
515 |
516 | pool.terminate()
517 |
518 | def text_segmentate(text, maxlen, seps='\n', strips=None):
519 | """将文本按照标点符号划分为若干个短句
520 | """
521 | text = text.strip().strip(strips)
522 | if seps and len(text) > maxlen:
523 | pieces = text.split(seps[0])
524 | text, texts = '', []
525 | for i, p in enumerate(pieces):
526 | if text and p and len(text) + len(p) > maxlen - 1:
527 | texts.extend(text_segmentate(text, maxlen, seps[1:], strips))
528 | text = ''
529 | if i + 1 == len(pieces):
530 | text = text + p
531 | else:
532 | text = text + p + seps[0]
533 | if text:
534 | texts.extend(text_segmentate(text, maxlen, seps[1:], strips))
535 | return texts
536 | else:
537 | return [text]
538 |
539 |
540 | def load_user_dict(filename):
541 | """加载用户词典
542 | """
543 | user_dict = []
544 | with open(filename, encoding='utf-8') as f:
545 | for l in f:
546 | w = l.split()[0]
547 | user_dict.append(w)
548 | return user_dict
549 |
550 |
551 | def data_split(data, fold, num_folds, mode):
552 | """划分训练集和验证集
553 | """
554 | if mode == 'train':
555 | D = [d for i, d in enumerate(data) if i % num_folds != fold]
556 | else:
557 | D = [d for i, d in enumerate(data) if i % num_folds == fold]
558 |
559 | if isinstance(data, np.ndarray):
560 | return np.array(D)
561 | else:
562 | return D
563 |
564 |
565 | def compute_rouge(source, target, unit='word'):
566 | """计算rouge-1、rouge-2、rouge-l
567 | """
568 | # if unit == 'word':
569 | # source = jieba.cut(source, HMM=False)
570 | # target = jieba.cut(target, HMM=False)
571 | source, target = ' '.join(source), ' '.join(target)
572 | try:
573 | scores = rouge.get_scores(hyps=source, refs=target)
574 | return {
575 | 'rouge-1': scores[0]['rouge-1']['f'],
576 | 'rouge-2': scores[0]['rouge-2']['f'],
577 | 'rouge-l': scores[0]['rouge-l']['f'],
578 | }
579 | except ValueError:
580 | return {
581 | 'rouge-1': 0.0,
582 | 'rouge-2': 0.0,
583 | 'rouge-l': 0.0,
584 | }
585 |
586 |
587 | def compute_metrics(source, target, unit='word'):
588 | """计算所有metrics
589 | """
590 | metrics = compute_rouge(source, target, unit)
591 | metrics['main'] = (
592 | metrics['rouge-1'] * 0.2 + metrics['rouge-2'] * 0.4 +
593 | metrics['rouge-l'] * 0.4
594 | )
595 | return metrics
596 |
597 |
598 | def compute_main_metric(source, target, unit='word'):
599 | """计算主要metric
600 | """
601 | return compute_metrics(source, target, unit)['main']
602 |
603 |
604 | def longest_common_subsequence(source, target):
605 | """最长公共子序列(source和target的最长非连续子序列)
606 | 返回:子序列长度, 映射关系(映射对组成的list)
607 | 注意:最长公共子序列可能不止一个,所返回的映射只代表其中一个。
608 | """
609 | c = defaultdict(int)
610 | for i, si in enumerate(source, 1):
611 | for j, tj in enumerate(target, 1):
612 | if si == tj:
613 | c[i, j] = c[i - 1, j - 1] + 1
614 | elif c[i, j - 1] > c[i - 1, j]:
615 | c[i, j] = c[i, j - 1]
616 | else:
617 | c[i, j] = c[i - 1, j]
618 | l, mapping = c[len(source), len(target)], []
619 | i, j = len(source) - 1, len(target) - 1
620 | while len(mapping) < l:
621 | if source[i] == target[j]:
622 | mapping.append((i, j))
623 | i, j = i - 1, j - 1
624 | elif c[i + 1, j] > c[i, j + 1]:
625 | j = j - 1
626 | else:
627 | i = i - 1
628 | return l, mapping[::-1]
--------------------------------------------------------------------------------
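
A small check of the fold logic in `data_split`: sample i belongs to the validation set of fold `i % num_folds`, so a fold's train and valid splits are disjoint and together cover the data. The function body is repeated here so the snippet runs on its own:

import numpy as np

def data_split(data, fold, num_folds, mode):
    # same logic as snippets.data_split
    if mode == 'train':
        D = [d for i, d in enumerate(data) if i % num_folds != fold]
    else:
        D = [d for i, d in enumerate(data) if i % num_folds == fold]
    return np.array(D) if isinstance(data, np.ndarray) else D

data = list(range(10))
train = data_split(data, 1, 4, 'train')
valid = data_split(data, 1, 4, 'valid')
print(valid)                          # [1, 5, 9]
print(sorted(train + valid) == data)  # True
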
/test_model/lawformer.py:
--------------------------------------------------------------------------------
1 | from transformers import AutoModel, AutoTokenizer
2 | from transformers import LongformerForMaskedLM,RobertaForMaskedLM,AutoModelForMaskedLM,AutoTokenizer
3 | tokenizer = AutoTokenizer.from_pretrained("hfl/chinese-roberta-wwm-ext", cache_dir="/new_disk2/zhongxiang_sun/code/pretrain_model/lawformer/")
4 | model = AutoModel.from_pretrained("thunlp/Lawformer", cache_dir="/new_disk2/zhongxiang_sun/code/pretrain_model/lawformer/")
5 | inputs = tokenizer("任某提起诉讼,请求判令解除婚姻关系并对夫妻共同财产进行分割。", return_tensors="pt")
6 | outputs = model(**inputs)
7 | print(outputs)
8 | print()
9 |
--------------------------------------------------------------------------------
/test_model/test_function.py:
--------------------------------------------------------------------------------
1 | import os
2 | os.makedirs('./tmp_test_dir', exist_ok=True)  # os.mkdir() with no path raises TypeError; './tmp_test_dir' is a placeholder
--------------------------------------------------------------------------------