├── .idea
├── .gitignore
├── ailabner.iml
├── encodings.xml
├── inspectionProfiles
│ └── profiles_settings.xml
├── misc.xml
└── modules.xml
├── README.md
├── __init__.py
├── __pycache__
├── config.cpython-37.pyc
├── data.cpython-37.pyc
├── evaluate.cpython-37.pyc
├── evaluating.cpython-37.pyc
├── operate_bilstm.cpython-37.pyc
└── utils.cpython-37.pyc
├── ckpts
└── bilstm_crf.pkl
├── config.py
├── data.py
├── data
├── crf_tag2id.pkl
├── crf_word2id.pkl
├── dev.char
├── lables.char
├── test.char
└── train.char
├── evaluate.py
├── evaluating.py
├── main.py
├── modelgraph
├── BILSTM.py
├── BILSTM_CRF.py
├── __init__.py
└── __pycache__
│ ├── BILSTM.cpython-37.pyc
│ ├── BILSTM_CRF.cpython-37.pyc
│ └── __init__.cpython-37.pyc
├── operate_bilstm.py
├── predict.py
├── requirements.txt
├── result.txt
└── utils.py
/.idea/.gitignore:
--------------------------------------------------------------------------------
1 | # Default ignored files
2 | /shelf/
3 | /workspace.xml
4 |
--------------------------------------------------------------------------------
/.idea/ailabner.iml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 |
--------------------------------------------------------------------------------
/.idea/encodings.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
--------------------------------------------------------------------------------
/.idea/inspectionProfiles/profiles_settings.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
--------------------------------------------------------------------------------
/.idea/misc.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
--------------------------------------------------------------------------------
/.idea/modules.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # PytorchBilstmCRF-Information-Extraction
2 | 基于Bilstm + CRF的信息抽取模型
3 |
4 |
5 |
6 | 运行:python main.py
7 |
8 | 预测:python predict.py
9 |
10 | [博客链接:基于BiLSTM+CRF的信息抽取模型](https://blog.csdn.net/qq_44193969/article/details/116008734?spm=1001.2014.3001.5502)
11 |
12 | 有任何问题,随时私信
13 |
14 | 有任何建议,随时私信
15 |
16 |
--------------------------------------------------------------------------------
/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/seanzhang-zhichen/PytorchBilstmCRF-Information-Extraction/0c6e9bc0d8aaec28e6ecc5e2b6efbc194356833d/__init__.py
--------------------------------------------------------------------------------
/__pycache__/config.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/seanzhang-zhichen/PytorchBilstmCRF-Information-Extraction/0c6e9bc0d8aaec28e6ecc5e2b6efbc194356833d/__pycache__/config.cpython-37.pyc
--------------------------------------------------------------------------------
/__pycache__/data.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/seanzhang-zhichen/PytorchBilstmCRF-Information-Extraction/0c6e9bc0d8aaec28e6ecc5e2b6efbc194356833d/__pycache__/data.cpython-37.pyc
--------------------------------------------------------------------------------
/__pycache__/evaluate.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/seanzhang-zhichen/PytorchBilstmCRF-Information-Extraction/0c6e9bc0d8aaec28e6ecc5e2b6efbc194356833d/__pycache__/evaluate.cpython-37.pyc
--------------------------------------------------------------------------------
/__pycache__/evaluating.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/seanzhang-zhichen/PytorchBilstmCRF-Information-Extraction/0c6e9bc0d8aaec28e6ecc5e2b6efbc194356833d/__pycache__/evaluating.cpython-37.pyc
--------------------------------------------------------------------------------
/__pycache__/operate_bilstm.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/seanzhang-zhichen/PytorchBilstmCRF-Information-Extraction/0c6e9bc0d8aaec28e6ecc5e2b6efbc194356833d/__pycache__/operate_bilstm.cpython-37.pyc
--------------------------------------------------------------------------------
/__pycache__/utils.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/seanzhang-zhichen/PytorchBilstmCRF-Information-Extraction/0c6e9bc0d8aaec28e6ecc5e2b6efbc194356833d/__pycache__/utils.cpython-37.pyc
--------------------------------------------------------------------------------
/ckpts/bilstm_crf.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/seanzhang-zhichen/PytorchBilstmCRF-Information-Extraction/0c6e9bc0d8aaec28e6ecc5e2b6efbc194356833d/ckpts/bilstm_crf.pkl
--------------------------------------------------------------------------------
/config.py:
--------------------------------------------------------------------------------
class TrainingConfig(object):
    """Hyper-parameters used while training the (Bi)LSTM tagger."""

    # Sentences per mini-batch.
    batch_size = 16
    # Number of passes over the training data (attribute name kept as "epoches").
    epoches = 10
    # Learning rate given to the optimizer.
    lr = 0.0005
    # Training progress is printed every `print_step` steps.
    print_step = 100
8 |
class LSTMConfig(object):
    """Layer sizes for the BiLSTM network."""

    emb_size = 256  # dimensionality of the token embeddings
    hidden_size = 256  # hidden size of each LSTM direction
--------------------------------------------------------------------------------
/data.py:
--------------------------------------------------------------------------------
1 | from codecs import open
2 | import os
3 |
4 |
def build_corpus(split, make_vocab=True, data_dir='./data'):
    """Read ``<data_dir>/<split>.char`` into parallel word / tag sequences.

    The file holds one ``<token> <tag>`` pair per line; sentences are
    separated by blank lines.

    Args:
        split: one of "train", "dev" or "test".
        make_vocab: when True, also build ``word2id`` / ``tag2id`` with
            ``build_map``.
        data_dir: directory that contains the ``.char`` files.

    Returns:
        ``(word_lists, tag_lists, word2id, tag2id)`` when ``make_vocab`` is
        True, otherwise ``(word_lists, tag_lists)``.

    Raises:
        AssertionError: if ``split`` is not a known split name.
        ValueError: if a non-blank line does not hold exactly two fields.
    """
    assert split.lower() in ["train", "dev", "test"]
    word_lists = []
    tag_lists = []
    word_list = []
    tag_list = []
    with open(os.path.join(data_dir, split + ".char"), 'r', encoding='utf-8') as f:
        for line in f:
            # strip() rather than comparing with '\n': codecs.open does no
            # newline translation, so CRLF separator lines arrive as '\r\n'.
            if line.strip():
                word, tag = line.split()
                word_list.append(word)
                tag_list.append(tag)
            elif word_list:
                # Only close non-empty sentences; repeated blank lines are skipped.
                word_lists.append(word_list)
                tag_lists.append(tag_list)
                word_list = []
                tag_list = []
    # Keep the final sentence when the file does not end with a blank line.
    if word_list:
        word_lists.append(word_list)
        tag_lists.append(tag_list)

    if make_vocab:
        word2id = build_map(word_lists)
        tag2id = build_map(tag_lists)
        return word_lists, tag_lists, word2id, tag2id
    return word_lists, tag_lists
28 |
29 |
def build_map(lists):
    """Give every distinct element a consecutive integer id in first-seen order.

    Args:
        lists: an iterable of sequences (e.g. sentences of characters or tags).

    Returns:
        dict mapping each distinct element to its id (0, 1, 2, ...).
    """
    mapping = {}
    for element in (item for seq in lists for item in seq):
        # len(mapping) is evaluated before insertion, so new ids are consecutive.
        mapping.setdefault(element, len(mapping))
    return mapping
--------------------------------------------------------------------------------
/data/crf_tag2id.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/seanzhang-zhichen/PytorchBilstmCRF-Information-Extraction/0c6e9bc0d8aaec28e6ecc5e2b6efbc194356833d/data/crf_tag2id.pkl
--------------------------------------------------------------------------------
/data/crf_word2id.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/seanzhang-zhichen/PytorchBilstmCRF-Information-Extraction/0c6e9bc0d8aaec28e6ecc5e2b6efbc194356833d/data/crf_word2id.pkl
--------------------------------------------------------------------------------
/data/lables.char:
--------------------------------------------------------------------------------
1 | B-NAME
2 | E-NAME
3 | O
4 | B-CONT
5 | M-CONT
6 | E-CONT
7 | B-RACE
8 | E-RACE
9 | B-TITLE
10 | M-TITLE
11 | E-TITLE
12 | B-EDU
13 | M-EDU
14 | E-EDU
15 | B-ORG
16 | M-ORG
17 | E-ORG
18 | M-NAME
19 | B-PRO
20 | M-PRO
21 | E-PRO
22 | S-RACE
23 | S-NAME
24 | B-LOC
25 | M-LOC
26 | E-LOC
27 | M-RACE
28 | S-ORG
29 | B-ID
30 | M-ID
31 | E-ID
32 |
33 |
--------------------------------------------------------------------------------
/evaluate.py:
--------------------------------------------------------------------------------
1 | import time
2 | from collections import Counter
3 | import pickle
4 |
5 | from operate_bilstm import BiLSTM_operator
6 | from evaluating import Metrics
7 | from utils import save_model
8 |
9 |
def bilstm_train_and_eval(train_data, dev_data, test_data, word2id, tag2id, crf=True, remove_0=False):
    """Train a BiLSTM(-CRF) tagger, checkpoint it and report test-set scores.

    Args:
        train_data: ``(word_lists, tag_lists)`` used for training.
        dev_data: ``(word_lists, tag_lists)`` used for validation.
        test_data: ``(word_lists, tag_lists)`` used for the final report.
        word2id: word vocabulary; its size sets the embedding table size.
        tag2id: tag vocabulary; its size sets the number of output tags.
        crf: train the CRF variant when True, the plain BiLSTM otherwise.
        remove_0: when True, 'O' tags are excluded from the metrics.

    Returns:
        The predicted tag lists for the test set.
    """
    import os  # only needed to prepare the checkpoint directory

    train_word_lists, train_tag_lists = train_data
    dev_word_lists, dev_tag_lists = dev_data
    test_word_lists, test_tag_lists = test_data

    start = time.time()
    vocab_size = len(word2id)
    out_size = len(tag2id)

    bilstm_operator = BiLSTM_operator(vocab_size, out_size, crf=crf)
    model_name = "bilstm_crf" if crf else "bilstm"

    # Create the checkpoint directory before training: otherwise a missing
    # ckpts/ only surfaces as FileNotFoundError after the whole training run.
    ckpt_dir = "./ckpts/"
    os.makedirs(ckpt_dir, exist_ok=True)

    print("start to train the {} ...".format(model_name))
    bilstm_operator.train(train_word_lists, train_tag_lists, dev_word_lists, dev_tag_lists, word2id, tag2id)
    save_model(bilstm_operator, ckpt_dir + model_name + ".pkl")

    print("训练完毕,共用时{}秒.".format(int(time.time() - start)))
    print("评估{}模型中...".format(model_name))
    pred_tag_lists, test_tag_lists = bilstm_operator.test(
        test_word_lists, test_tag_lists, word2id, tag2id)

    metrics = Metrics(test_tag_lists, pred_tag_lists, remove_0=remove_0)
    dtype = 'Bi_LSTM+CRF' if crf else 'Bi_LSTM'
    metrics.report_scores(dtype=dtype)

    return pred_tag_lists
--------------------------------------------------------------------------------
/evaluating.py:
--------------------------------------------------------------------------------
1 | from collections import Counter
2 | from utils import flatten_lists
3 |
class Metrics(object):
    """Token-level precision, recall and F1 per tag.

    Args:
        gloden_tags: gold tag sequences (list of lists, or a flat list). The
            misspelt parameter name is kept so keyword callers keep working.
        predict_tags: predicted tag sequences aligned with ``gloden_tags``.
        remove_0: when True, positions whose gold tag is 'O' (non-entity)
            are dropped before scoring.
    """

    def __init__(self, gloden_tags, predict_tags, remove_0=False):
        self.golden_tags = flatten_lists(gloden_tags)
        self.predict_tags = flatten_lists(predict_tags)

        if remove_0:
            self._remove_Otags()

        # Only tags that occur in the gold data are scored.
        self.tagset = set(self.golden_tags)
        self.correct_tags_number = self.count_correct_tags()
        self.predict_tags_count = Counter(self.predict_tags)
        self.golden_tags_count = Counter(self.golden_tags)

        self.precision_scores = self.cal_precision()
        self.recall_scores = self.cal_recall()
        self.f1_scores = self.cal_f1()

    def cal_precision(self):
        """Return ``{tag: correct / predicted}``; 0 when a tag has no correct hit."""
        precision_scores = {}
        for tag in self.tagset:
            correct = self.correct_tags_number.get(tag, 0)
            # With no correct hit the predicted count may be 0 as well.
            precision_scores[tag] = correct / self.predict_tags_count[tag] if correct else 0
        return precision_scores

    def cal_recall(self):
        """Return ``{tag: correct / gold count}``; every scored tag has gold count >= 1."""
        recall_scores = {}
        for tag in self.tagset:
            recall_scores[tag] = self.correct_tags_number.get(tag, 0) / self.golden_tags_count[tag]
        return recall_scores

    def cal_f1(self):
        """Return the per-tag harmonic mean of precision and recall."""
        f1_scores = {}
        for tag in self.tagset:
            p = self.precision_scores[tag]
            r = self.recall_scores[tag]
            # 1e-10 avoids 0 / 0 when precision and recall are both 0.
            f1_scores[tag] = 2 * p * r / (p + r + 1e-10)
        return f1_scores

    def count_correct_tags(self):
        """Count, per gold tag, the positions where the prediction matches (TP)."""
        correct_dict = {}
        for gold_tag, predict_tag in zip(self.golden_tags, self.predict_tags):
            if gold_tag == predict_tag:
                correct_dict[gold_tag] = correct_dict.get(gold_tag, 0) + 1
        return correct_dict

    def _remove_Otags(self):
        """Drop every position whose gold tag is 'O' from both tag lists."""
        length = len(self.golden_tags)
        # A set keeps the membership test O(1); a list made this quadratic.
        O_tag_indices = {i for i, tag in enumerate(self.golden_tags) if tag == 'O'}

        self.golden_tags = [tag for i, tag in enumerate(self.golden_tags)
                            if i not in O_tag_indices]
        self.predict_tags = [tag for i, tag in enumerate(self.predict_tags)
                             if i not in O_tag_indices]
        ratio = len(O_tag_indices) / length * 100 if length else 0.0
        print("原总标记数为{},移除了{}个O标记,占比{:.2f}%".format(
            length,
            len(O_tag_indices),
            ratio
        ))

    def report_scores(self, dtype='HMM'):
        """Print a per-tag score table and append the same table to result.txt.

        Columns are precision, recall, f1-score and support (gold count); the
        final 'avg/total' row is weighted by support.

        Args:
            dtype: model name written into the result.txt section header.
        """
        header_format = '{:>9s} {:>9} {:>9} {:>9} {:>9}'
        header = ['precision', 'recall', 'f1-score', 'support']
        row_format = '{:>9s} {:>9.4f} {:>9.4f} {:>9.4f} {:>9}'

        # Explicit UTF-8: the section header contains Chinese text and the
        # platform default encoding may not be able to represent it.
        with open('result.txt', 'a', encoding='utf-8') as fout:
            fout.write('\n')
            fout.write('=========='*10)
            fout.write('\n')
            fout.write('模型:{},test结果如下:'.format(dtype))
            fout.write('\n')
            header_line = header_format.format('', *header)
            fout.write(header_line)
            print(header_line)

            # Sorted so that the row order is stable from run to run.
            for tag in sorted(self.tagset):
                row = row_format.format(
                    tag,
                    self.precision_scores[tag],
                    self.recall_scores[tag],
                    self.f1_scores[tag],
                    self.golden_tags_count[tag]
                )
                print(row)
                fout.write('\n')
                fout.write(row)

            avg_metrics = self._cal_weighted_average()
            avg_row = row_format.format(
                'avg/total',
                avg_metrics['precision'],
                avg_metrics['recall'],
                avg_metrics['f1_score'],
                len(self.golden_tags)
            )
            print(avg_row)
            fout.write('\n')
            fout.write(avg_row)
            fout.write('\n')

    def _cal_weighted_average(self):
        """Return precision / recall / F1 averaged over tags, weighted by gold support."""
        weighted_average = {'precision': 0., 'recall': 0., 'f1_score': 0.}
        total = len(self.golden_tags)
        if total == 0:
            # Nothing to score (e.g. every gold tag was 'O' with remove_0=True).
            return weighted_average

        for tag in self.tagset:
            size = self.golden_tags_count[tag]
            weighted_average['precision'] += self.precision_scores[tag] * size
            weighted_average['recall'] += self.recall_scores[tag] * size
            weighted_average['f1_score'] += self.f1_scores[tag] * size

        for metric in weighted_average:
            weighted_average[metric] /= total
        return weighted_average

    def report_confusion_matrix(self):
        """Print the confusion matrix: matrix[i][j] counts gold tag i predicted as tag j.

        Predicted tags that never occur in the gold data have no column and
        are skipped.
        """
        print("\nConfusion Matrix:")
        tag_list = sorted(self.tagset)
        tag_index = {tag: i for i, tag in enumerate(tag_list)}
        tags_size = len(tag_list)
        matrix = [[0] * tags_size for _ in range(tags_size)]

        for golden_tag, predict_tag in zip(self.golden_tags, self.predict_tags):
            row = tag_index.get(golden_tag)
            col = tag_index.get(predict_tag)
            if row is None or col is None:
                continue
            matrix[row][col] += 1

        row_format_ = '{:>7} ' * (tags_size + 1)
        print(row_format_.format("", *tag_list))
        for tag, counts in zip(tag_list, matrix):
            print(row_format_.format(tag, *counts))
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | from data import build_corpus
2 | from evaluate import bilstm_train_and_eval
3 | from utils import extend_maps,prepocess_data_for_lstmcrf, save_obj, load_obj
4 |
5 |
# Entry-point script: train the BiLSTM+CRF tagger and evaluate it on the test split.
print("读取数据中...")
train_word_lists, train_tag_lists, word2id, tag2id = build_corpus("train")
dev_word_lists, dev_tag_lists = build_corpus("dev", make_vocab=False)
test_word_lists, test_tag_lists = build_corpus("test", make_vocab=False)

print("正在训练评估Bi-LSTM+CRF模型...")
# Add the special tokens the CRF model needs, and persist both vocabularies:
# predict.py reloads them with load_obj() at inference time.
crf_word2id, crf_tag2id = extend_maps(word2id, tag2id, for_crf=True)
save_obj(crf_word2id, 'crf_word2id')
save_obj(crf_tag2id, 'crf_tag2id')

# Show the tag inventory in id order (ids were assigned in insertion order).
print(' '.join(crf_tag2id))

# Append the end token to every sentence; test tags are left without it.
train_word_lists, train_tag_lists = prepocess_data_for_lstmcrf(
    train_word_lists, train_tag_lists
)
dev_word_lists, dev_tag_lists = prepocess_data_for_lstmcrf(
    dev_word_lists, dev_tag_lists
)
test_word_lists, test_tag_lists = prepocess_data_for_lstmcrf(
    test_word_lists, test_tag_lists, test=True
)

lstmcrf_pred = bilstm_train_and_eval(
    (train_word_lists, train_tag_lists),
    (dev_word_lists, dev_tag_lists),
    (test_word_lists, test_tag_lists),
    crf_word2id, crf_tag2id
)
--------------------------------------------------------------------------------
/modelgraph/BILSTM.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | import torch.nn as nn
4 |
class BiLSTM(nn.Module):
    """Embedding -> bidirectional LSTM -> linear layer producing per-tag scores.

    Args:
        vocab_size: number of rows in the embedding table.
        emb_size: embedding dimension.
        hidden_size: hidden size of each LSTM direction.
        out_size: number of tags (scores produced per token).
        dropout: dropout probability applied to the embeddings.
    """

    def __init__(self, vocab_size, emb_size, hidden_size, out_size, dropout=0.1):
        super(BiLSTM, self).__init__()
        self.embedding = nn.Embedding(vocab_size, emb_size)
        self.bilstm = nn.LSTM(emb_size, hidden_size, batch_first=True, bidirectional=True)
        # Forward and backward states are concatenated, hence 2 * hidden_size.
        self.fc = nn.Linear(2 * hidden_size, out_size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, lengths):
        """Return scores of shape (batch, seq_len, out_size).

        Args:
            x: batch-first tensor of token ids, (batch, seq_len).
            lengths: unpadded length of every sequence in ``x``.
        """
        emb = self.dropout(self.embedding(x))
        # enforce_sorted=False accepts batches in any length order; for batches
        # already sorted by decreasing length the output is unchanged.
        packed = nn.utils.rnn.pack_padded_sequence(emb, lengths, batch_first=True, enforce_sorted=False)
        packed_out, _ = self.bilstm(packed)
        # total_length keeps the padded width equal to x.shape[1].
        out, _ = nn.utils.rnn.pad_packed_sequence(
            packed_out, batch_first=True, padding_value=0., total_length=x.shape[1])
        return self.fc(out)

    def test(self, x, lengths, _):
        """Greedy decoding: return the highest-scoring tag id at every position."""
        logits = self.forward(x, lengths)
        _, batch_tagids = torch.max(logits, dim=2)
        return batch_tagids
28 |
def cal_loss(logits, targets, tag2id):
    """Mean token-level cross-entropy that ignores padded positions.

    Args:
        logits: (batch, seq_len, out_size) scores, e.g. from ``BiLSTM.forward``.
        targets: (batch, seq_len) gold tag ids, padded with ``tag2id['']``.
        tag2id: tag vocabulary; must contain the padding key.

    Returns:
        Scalar cross-entropy averaged over the non-padding tokens.
    """
    # NOTE(review): the padding id is looked up under the empty-string key;
    # confirm it matches the key registered by extend_maps.
    PAD = tag2id.get('')
    assert PAD is not None
    keep = targets != PAD
    # Boolean indexing over the first two dims keeps row-major order, which is
    # what masked_select(...).view(-1, out_size) produced.
    selected_logits = logits[keep]
    selected_targets = targets[keep]
    assert selected_logits.size(0) == selected_targets.size(0)
    return F.cross_entropy(selected_logits, selected_targets)
--------------------------------------------------------------------------------
/modelgraph/BILSTM_CRF.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from modelgraph.BILSTM import BiLSTM
4 | from itertools import zip_longest
5 |
class BiLSTM_CRF(nn.Module):
    """BiLSTM emission scores combined with a learned tag-transition matrix.

    Args:
        vocab_size, emb_size, hidden_size: forwarded to ``BiLSTM``.
        out_size: number of tags; also the size of the square transition matrix.
    """

    def __init__(self, vocab_size, emb_size, hidden_size, out_size):
        super(BiLSTM_CRF, self).__init__()
        self.bilstm = BiLSTM(vocab_size, emb_size, hidden_size, out_size)
        # transition[i, j]: score of moving from tag i to tag j (uniform init).
        self.transition = nn.Parameter(torch.ones(out_size, out_size) * 1 / out_size)

    def forward(self, sents_tensor, lengths):
        """Return CRF scores of shape (batch, seq_len, out_size, out_size).

        ``crf_scores[b, t, i, j]`` = emission of tag j at step t + transition[i, j].
        """
        emission = self.bilstm(sents_tensor, lengths)
        batch_size, max_len, out_size = emission.size()
        crf_scores = emission.unsqueeze(2).expand(-1, -1, out_size, -1) + self.transition.unsqueeze(0)
        return crf_scores

    def test(self, test_sents_tensor, lengths, tag2id):
        """Viterbi-decode the best tag path of every sentence in the batch.

        Sentences must be ordered by non-increasing length: each step only
        processes the first ``batch_size_t`` rows. Every path is forced to end
        in ``end_id`` at the sentence's last position, and that final position
        is not included in the output.

        Returns:
            LongTensor of shape (batch, max_len - 1), right-padded with ``pad``.
        """
        # NOTE(review): start, end and pad are all looked up under the key ''
        # and therefore resolve to the same id; confirm the intended keys.
        start_id = tag2id['']
        end_id = tag2id['']
        pad = tag2id['']
        tagset_size = len(tag2id)

        crf_scores = self.forward(test_sents_tensor, lengths)
        # Use the scores' own device instead of guessing from CUDA availability.
        device = crf_scores.device
        B, L, T, _ = crf_scores.size()

        # viterbi[b, t, j]: best score of a path ending in tag j at step t.
        viterbi = torch.zeros(B, L, T, device=device)
        # backpointer[b, t, j]: previous tag on that best path.
        backpointer = torch.zeros(B, L, T, dtype=torch.long, device=device)
        lengths = torch.LongTensor(lengths).to(device)

        # Forward pass: only sentences longer than `step` take part.
        for step in range(L):
            batch_size_t = (lengths > step).sum().item()
            if step == 0:
                viterbi[:batch_size_t, step, :] = crf_scores[:batch_size_t, step, start_id, :]
                backpointer[:batch_size_t, step, :] = start_id
            else:
                max_scores, prev_tags = torch.max(
                    viterbi[:batch_size_t, step - 1, :].unsqueeze(2) + crf_scores[:batch_size_t, step, :, :],
                    dim=1,
                )
                viterbi[:batch_size_t, step, :] = max_scores
                backpointer[:batch_size_t, step, :] = prev_tags

        # Backward pass over pointers flattened to (B, L * T).
        backpointer = backpointer.view(B, -1)
        tagids = []
        tags_t = None
        for step in range(L - 1, 0, -1):
            batch_size_t = (lengths > step).sum().item()
            if step == L - 1:
                offset = torch.full((batch_size_t,), end_id, dtype=torch.long, device=device)
            else:
                # Sentences whose last position is `step` join here and start
                # backtracking from end_id.
                prev_batch_size_t = len(tags_t)
                new_in_batch = torch.full(
                    (batch_size_t - prev_batch_size_t,), end_id, dtype=torch.long, device=device)
                offset = torch.cat([tags_t, new_in_batch], dim=0)
            index = offset + step * tagset_size
            tags_t = backpointer[:batch_size_t].gather(dim=1, index=index.unsqueeze(1)).squeeze(1)
            tagids.append(tags_t.tolist())

        # Steps were collected last-to-first: transpose into per-sentence paths
        # and right-pad the shorter sentences.
        tagids = list(zip_longest(*reversed(tagids), fillvalue=pad))
        return torch.tensor(tagids, dtype=torch.long)
71 |
72 |
def cal_lstm_crf_loss(crf_scores, targets, tag2id):
    """Negative log-likelihood of the gold tag paths under the CRF, averaged per sentence.

    Args:
        crf_scores: (batch, max_len, T, T) scores from ``BiLSTM_CRF.forward``.
        targets: (batch, max_len) gold tag ids padded with ``tag2id['']``;
            rows ordered by non-increasing length. The tensor is not modified.
        tag2id: tag vocabulary with T entries.

    Returns:
        Scalar tensor ``(sum of log-partitions - sum of gold scores) / batch``.
    """
    # NOTE(review): pad, start and end are all looked up under the key '' and
    # so share a single id; confirm the intended special-token keys.
    pad_id = tag2id.get('')
    start_id = tag2id.get('')
    end_id = tag2id.get('')
    device = crf_scores.device
    batch_size, max_len = targets.size()
    target_size = len(tag2id)

    mask = (targets != pad_id)
    lengths = mask.sum(dim=1)
    # indexed() rewrites its argument in place; clone so the caller's tensor survives.
    targets = indexed(targets.clone(), target_size, start_id)
    targets = targets.masked_select(mask)

    # Gold path score: crf_scores[b, t, y_{t-1}, y_t] summed over real tokens.
    flatten_scores = crf_scores.masked_select(
        mask.view(batch_size, max_len, 1, 1).expand_as(crf_scores)
    ).view(-1, target_size * target_size).contiguous()
    golden_scores = flatten_scores.gather(
        dim=1, index=targets.unsqueeze(1)).sum()

    # Forward algorithm: log-sum-exp over all paths, rows dropping out as they end.
    scores_upto_t = torch.zeros(batch_size, target_size, device=device)
    for t in range(max_len):
        batch_size_t = (lengths > t).sum().item()
        if t == 0:
            scores_upto_t[:batch_size_t] = crf_scores[:batch_size_t, t, start_id, :]
        else:
            scores_upto_t[:batch_size_t] = torch.logsumexp(
                crf_scores[:batch_size_t, t, :, :] +
                scores_upto_t[:batch_size_t].unsqueeze(2),
                dim=1
            )
    all_path_scores = scores_upto_t[:, end_id].sum()
    loss = (all_path_scores - golden_scores) / batch_size
    return loss


def indexed(targets, tagset_size, start_id):
    """Encode each tag as ``prev_tag * tagset_size + tag``, a flat index into a T x T grid.

    Position 0 uses ``start_id`` as its previous tag. ``targets`` is modified
    in place and also returned.
    """
    batch_size, max_len = targets.size()
    # Right-to-left so each column still reads its left neighbour's raw tag.
    for col in range(max_len - 1, 0, -1):
        targets[:, col] += (targets[:, col - 1] * tagset_size)
    targets[:, 0] += (start_id * tagset_size)
    return targets
--------------------------------------------------------------------------------
/modelgraph/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/seanzhang-zhichen/PytorchBilstmCRF-Information-Extraction/0c6e9bc0d8aaec28e6ecc5e2b6efbc194356833d/modelgraph/__init__.py
--------------------------------------------------------------------------------
/modelgraph/__pycache__/BILSTM.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/seanzhang-zhichen/PytorchBilstmCRF-Information-Extraction/0c6e9bc0d8aaec28e6ecc5e2b6efbc194356833d/modelgraph/__pycache__/BILSTM.cpython-37.pyc
--------------------------------------------------------------------------------
/modelgraph/__pycache__/BILSTM_CRF.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/seanzhang-zhichen/PytorchBilstmCRF-Information-Extraction/0c6e9bc0d8aaec28e6ecc5e2b6efbc194356833d/modelgraph/__pycache__/BILSTM_CRF.cpython-37.pyc
--------------------------------------------------------------------------------
/modelgraph/__pycache__/__init__.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/seanzhang-zhichen/PytorchBilstmCRF-Information-Extraction/0c6e9bc0d8aaec28e6ecc5e2b6efbc194356833d/modelgraph/__pycache__/__init__.cpython-37.pyc
--------------------------------------------------------------------------------
/operate_bilstm.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 | from modelgraph.BILSTM import BiLSTM, cal_loss
6 | from modelgraph.BILSTM_CRF import BiLSTM_CRF, cal_lstm_crf_loss
7 | from config import TrainingConfig, LSTMConfig
8 | from utils import sort_by_lengths, tensorized
9 |
10 | from copy import deepcopy
11 | from tqdm import tqdm, trange
12 |
13 |
class BiLSTM_operator(object):
    """Train / validate / decode driver around ``BiLSTM`` or ``BiLSTM_CRF``.

    Args:
        vocab_size: number of entries in the word vocabulary.
        out_size: number of entries in the tag vocabulary.
        crf: use ``BiLSTM_CRF`` with ``cal_lstm_crf_loss`` when True, otherwise
            ``BiLSTM`` with ``cal_loss``.
    """

    def __init__(self, vocab_size, out_size, crf=True):
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.emb_size = LSTMConfig.emb_size
        self.hidden_size = LSTMConfig.hidden_size
        self.crf = crf
        if self.crf:
            self.model = BiLSTM_CRF(vocab_size, self.emb_size, self.hidden_size, out_size).to(self.device)
            self.cal_loss_func = cal_lstm_crf_loss
        else:
            self.model = BiLSTM(vocab_size, self.emb_size, self.hidden_size, out_size).to(self.device)
            self.cal_loss_func = cal_loss

        # Training hyper-parameters.
        self.epoches = TrainingConfig.epoches
        self.print_step = TrainingConfig.print_step
        self.lr = TrainingConfig.lr
        self.batch_size = TrainingConfig.batch_size

        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=self.lr)

        self.step = 0
        # Lowest dev loss so far, and a deep copy of the model that reached it.
        self._best_val_loss = 1e18
        self.best_model = None

    def train(self, word_lists, tag_lists, dev_word_lists, dev_tag_lists, word2id, tag2id):
        """Train for ``self.epoches`` epochs, validating (and snapshotting) after each one."""
        # Sort by decreasing length: the models' per-batch slicing relies on it.
        word_lists, tag_lists, _ = sort_by_lengths(word_lists, tag_lists)
        dev_word_lists, dev_tag_lists, _ = sort_by_lengths(dev_word_lists, dev_tag_lists)
        print("训练数据总量:{}".format(len(word_lists)))

        batch_size = self.batch_size
        # Ceiling division: a trailing partial batch is a step too.
        total_step = (len(word_lists) + batch_size - 1) // batch_size
        for epoch in trange(1, self.epoches + 1, desc="Epoch"):
            self.step = 0
            losses = 0.
            for idx in trange(0, len(word_lists), batch_size, desc="Iteration"):
                batch_sents = word_lists[idx:idx + batch_size]
                batch_tags = tag_lists[idx:idx + batch_size]
                losses += self.train_step(batch_sents, batch_tags, word2id, tag2id)

                if self.step % self.print_step == 0:
                    print("Epoch {}, step/total_step: {}/{} {:.2f}% Loss:{:.4f}".format(
                        epoch, self.step, total_step,
                        100. * self.step / total_step,
                        losses / self.print_step
                    ))
                    losses = 0.

            val_loss = self.validate(
                dev_word_lists, dev_tag_lists, word2id, tag2id)
            print("Epoch {}, Val Loss:{:.4f}".format(epoch, val_loss))

    def train_step(self, batch_sents, batch_tags, word2id, tag2id):
        """Run one optimisation step on a batch and return its loss as a float."""
        self.model.train()
        self.step += 1

        tensorized_sents, lengths = tensorized(batch_sents, word2id)
        targets, _ = tensorized(batch_tags, tag2id)
        tensorized_sents, targets = tensorized_sents.to(self.device), targets.to(self.device)

        scores = self.model(tensorized_sents, lengths)

        self.model.zero_grad()
        loss = self.cal_loss_func(scores, targets, tag2id)
        loss.backward()
        self.optimizer.step()

        return loss.item()

    def validate(self, dev_word_lists, dev_tag_lists, word2id, tag2id):
        """Return the mean dev loss; keep a deep copy of the model when it improves."""
        self.model.eval()
        with torch.no_grad():
            val_losses = 0.
            val_step = 0
            for ind in range(0, len(dev_word_lists), self.batch_size):
                val_step += 1
                batch_sents = dev_word_lists[ind:ind + self.batch_size]
                batch_tags = dev_tag_lists[ind:ind + self.batch_size]
                tensorized_sents, lengths = tensorized(batch_sents, word2id)
                tensorized_sents = tensorized_sents.to(self.device)
                # Discard the tag lengths: the model must get the sentence lengths.
                targets, _ = tensorized(batch_tags, tag2id)
                targets = targets.to(self.device)

                scores = self.model(tensorized_sents, lengths)
                loss = self.cal_loss_func(scores, targets, tag2id)
                val_losses += loss.item()
            val_loss = val_losses / val_step

            if val_loss < self._best_val_loss:
                print("保存模型...")
                self.best_model = deepcopy(self.model)
                self._best_val_loss = val_loss

            return val_loss

    def _ids_to_tags(self, batch_tagids, lengths, tag2id):
        """Map decoded id rows to tag strings, trimmed to each sentence's length.

        For the CRF model the last position of each sentence (the appended end
        token) is dropped.
        """
        id2tag = {id_: tag for tag, id_ in tag2id.items()}
        trim = 1 if self.crf else 0
        return [[id2tag[ids[j].item()] for j in range(lengths[i] - trim)]
                for i, ids in enumerate(batch_tagids)]

    def test(self, word_lists, tag_lists, word2id, tag2id):
        """Decode a labelled set with the best model.

        Returns:
            ``(pred_tag_lists, tag_lists)``, both restored to the caller's
            original sentence order.
        """
        word_lists, tag_lists, indices = sort_by_lengths(word_lists, tag_lists)
        tensorized_sents, lengths = tensorized(word_lists, word2id)
        tensorized_sents = tensorized_sents.to(self.device)

        self.best_model.eval()
        with torch.no_grad():
            batch_tagids = self.best_model.test(tensorized_sents, lengths, tag2id)
        pred_tag_lists = self._ids_to_tags(batch_tagids, lengths, tag2id)

        # indices[k] is the original position of sorted item k; invert it.
        order = sorted(range(len(indices)), key=lambda k: indices[k])
        pred_tag_lists = [pred_tag_lists[i] for i in order]
        tag_lists = [tag_lists[i] for i in order]

        return pred_tag_lists, tag_lists

    def predict(self, word_lists, word2id, tag2id):
        """Return the best model's tag predictions for ``word_lists``.

        The input is not re-sorted; with more than one sentence it should
        already be ordered by non-increasing length.
        """
        tensorized_sents, lengths = tensorized(word_lists, word2id)
        tensorized_sents = tensorized_sents.to(self.device)

        self.best_model.eval()
        with torch.no_grad():
            batch_tagids = self.best_model.test(tensorized_sents, lengths, tag2id)

        return self._ids_to_tags(batch_tagids, lengths, tag2id)
--------------------------------------------------------------------------------
/predict.py:
--------------------------------------------------------------------------------
1 | #%%
2 | import torch
3 | import pickle
4 | from utils import load_obj, tensorized
5 |
6 |
def predict(model, text, word2id=None, tag2id=None):
    """Tag every character of ``text`` with a trained model.

    Args:
        model: object exposing ``predict(word_lists, word2id, tag2id)``.
        text: raw input string; each character is one token.
        word2id: word vocabulary; loaded from ``crf_word2id`` via load_obj
            when omitted. Pass it explicitly to avoid re-reading the pickle
            on every call.
        tag2id: tag vocabulary; loaded from ``crf_tag2id`` when omitted.

    Returns:
        The tag list predicted for ``text``.
    """
    # Append the end token the CRF model was trained with
    # (BiLSTM_operator.predict drops that position again for CRF models).
    text_list = [list(text) + [""]]
    if word2id is None:
        word2id = load_obj('crf_word2id')
    if tag2id is None:
        tag2id = load_obj('crf_tag2id')
    pred_tag_lists = model.predict(text_list, word2id, tag2id)
    return pred_tag_lists[0]
17 |
18 |
def result_process(text_list, tag_list):
    """Group per-character tags into ``(chunk_text, label)`` spans.

    A new chunk starts at every 'O' character, at every 'B-' or 'S-' tag, and
    at a continuation tag (e.g. 'M-' / 'E-') when no chunk is open yet; any
    other tag extends the current chunk. The label is the part after the
    first '-' ('O' stays 'O').

    Returns:
        ``(outputs, pairs)``: ``outputs`` is a one-element list holding the
        concatenated ``str()`` of every pair; ``pairs`` is the list of
        ``(chunk_text, label)`` tuples.
    """
    chunks = []  # each item: [chunk_text, label]
    for char, tag in zip(text_list, tag_list):
        # 'S-' marks a single-character entity, so it must open its own chunk.
        if tag == 'O' or tag.startswith(('B-', 'S-')) or not chunks:
            label = 'O' if tag == 'O' else tag.split('-', 1)[-1]
            chunks.append([char, label])
        else:
            chunks[-1][0] += char
    pairs = [(text, label) for text, label in chunks]
    outputs = [''.join(str(pair) for pair in pairs)]
    return outputs, pairs
43 |
44 |
45 |
46 | #%%
if __name__ == '__main__':
    modelpath = './ckpts/bilstm_crf.pkl'
    # SECURITY: unpickling can execute arbitrary code - only load checkpoints
    # you produced yourself.
    with open(modelpath, 'rb') as f:
        model = pickle.load(f)

    text = '法外狂徒张三丰,身份证号362502190211032345'
    tag_res = predict(model, text)
    _, tuple_re = result_process(list(text), tag_res)

    print(text)
    # Keep only entity chunks (drop the plain 'O' spans).
    entities = [(s, t) for s, t in tuple_re if t != 'O']
    print(entities)
69 |
70 |
71 |
72 |
73 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | torch==1.7.1
2 | tqdm==4.55.1
--------------------------------------------------------------------------------
/result.txt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/seanzhang-zhichen/PytorchBilstmCRF-Information-Extraction/0c6e9bc0d8aaec28e6ecc5e2b6efbc194356833d/result.txt
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | import pickle
2 | import torch
3 |
4 |
def sort_by_lengths(word_lists, tag_lists):
    """Sort paired word / tag sequences by decreasing word-sequence length.

    The sort is stable, so equally long sentences keep their relative order.

    Returns:
        ``(word_lists, tag_lists, indices)``: the first two are tuples in sorted
        order, and ``indices[k]`` is the original position of sorted item k.
        Empty input returns ``((), (), [])`` instead of raising.
    """
    pairs = list(zip(word_lists, tag_lists))
    indices = sorted(range(len(pairs)), key=lambda i: len(pairs[i][0]), reverse=True)
    if not pairs:
        return (), (), indices
    sorted_words, sorted_tags = zip(*(pairs[i] for i in indices))
    return sorted_words, sorted_tags, indices
12 |
13 |
def tensorized(batch, maps):
    """Convert a batch of token sequences into a right-padded LongTensor of ids.

    Args:
        batch: sequence of token sequences, in any length order.
        maps: token -> id vocabulary. Unknown tokens and padding both use the
            id stored under the key '' (NOTE(review): confirm this is the
            intended special-token key).

    Returns:
        ``(batch_tensor, lengths)``: a (len(batch), max_len) LongTensor and the
        original length of every sequence.
    """
    PAD = maps.get('')
    UNK = maps.get('')

    lengths = [len(seq) for seq in batch]
    # max() instead of len(batch[0]) so unsorted batches are padded correctly.
    max_len = max(lengths, default=0)
    rows = [[maps.get(tok, UNK) for tok in seq] + [PAD] * (max_len - len(seq)) for seq in batch]
    batch_tensor = torch.tensor(rows, dtype=torch.long).view(len(batch), max_len)
    return batch_tensor, lengths
28 |
29 |
def save_obj(obj, name):
    """Pickle ``obj`` to ``data/<name>.pkl`` relative to the working directory."""
    path = ''.join(['data/', name, '.pkl'])
    with open(path, 'wb') as handle:
        pickle.dump(obj, handle, protocol=pickle.HIGHEST_PROTOCOL)
33 |
34 |
def load_obj(name):
    """Return the object unpickled from ``data/<name>.pkl``.

    Only load files you trust: unpickling can execute arbitrary code.
    """
    path = ''.join(['data/', name, '.pkl'])
    with open(path, 'rb') as handle:
        obj = pickle.load(handle)
    return obj
38 |
39 |
def prepocess_data_for_lstmcrf(word_lists, tag_lists, test=False):
    """Append the end token '' to every sentence, and to its tags unless ``test``.

    New lists are returned and the inputs are left untouched, so calling this
    twice on the same data no longer appends the token twice.

    Args:
        word_lists: list of token sequences.
        tag_lists: list of tag sequences, parallel to ``word_lists``.
        test: when True only the words get the end token; test tags do not
            need it.

    Returns:
        ``(word_lists, tag_lists)`` as new lists.

    Raises:
        AssertionError: if the two lists differ in length.
    """
    assert len(word_lists) == len(tag_lists)
    new_word_lists = [list(words) + [""] for words in word_lists]
    if test:
        new_tag_lists = [list(tags) for tags in tag_lists]
    else:
        new_tag_lists = [list(tags) + [""] for tags in tag_lists]
    return new_word_lists, new_tag_lists
48 |
49 |
def flatten_lists(lists):
    """Flatten one level: items whose type is exactly ``list`` are expanded,
    anything else is kept as a single element."""
    flat = []
    for item in lists:
        if type(item) is list:
            flat += item
        else:
            flat.append(item)
    return flat
59 |
60 |
def extend_maps(word2id, tag2id, for_crf=True):
    """Register the special tokens in both vocabularies (in place) and return them.

    Each token receives the next free id only when it is not already present.
    The previous unconditional ``maps[key] = len(maps)`` re-assigned existing
    keys to ``len(maps)``, an id one past the end of the vocabulary (an
    out-of-range embedding index).

    NOTE(review): every special-token key here is the empty string '', so the
    base tokens and the CRF-only tokens collapse into one entry; distinct
    names appear to have been intended - confirm.

    Args:
        word2id: word vocabulary, extended in place.
        tag2id: tag vocabulary, extended in place.
        for_crf: also add the two extra tokens used by the BiLSTM+CRF model.

    Returns:
        ``(word2id, tag2id)``, the same dict objects that were passed in.
    """
    base_tokens = ('', '')
    crf_tokens = ('', '')
    for mapping in (word2id, tag2id):
        for token in base_tokens:
            mapping.setdefault(token, len(mapping))
    # The BiLSTM+CRF model needs two further tokens.
    if for_crf:
        for mapping in (word2id, tag2id):
            for token in crf_tokens:
                mapping.setdefault(token, len(mapping))

    return word2id, tag2id
74 |
75 |
def save_model(model, file_name):
    """Pickle the whole ``model`` object to ``file_name``.

    The object is serialised before the file is opened, so a pickling failure
    leaves any existing checkpoint at ``file_name`` intact instead of
    truncating it.
    """
    data = pickle.dumps(model)
    with open(file_name, 'wb') as f:
        f.write(data)
--------------------------------------------------------------------------------