├── .idea ├── .gitignore ├── ailabner.iml ├── encodings.xml ├── inspectionProfiles │ └── profiles_settings.xml ├── misc.xml └── modules.xml ├── README.md ├── __init__.py ├── __pycache__ ├── config.cpython-37.pyc ├── data.cpython-37.pyc ├── evaluate.cpython-37.pyc ├── evaluating.cpython-37.pyc ├── operate_bilstm.cpython-37.pyc └── utils.cpython-37.pyc ├── ckpts └── bilstm_crf.pkl ├── config.py ├── data.py ├── data ├── crf_tag2id.pkl ├── crf_word2id.pkl ├── dev.char ├── lables.char ├── test.char └── train.char ├── evaluate.py ├── evaluating.py ├── main.py ├── modelgraph ├── BILSTM.py ├── BILSTM_CRF.py ├── __init__.py └── __pycache__ │ ├── BILSTM.cpython-37.pyc │ ├── BILSTM_CRF.cpython-37.pyc │ └── __init__.cpython-37.pyc ├── operate_bilstm.py ├── predict.py ├── requirements.txt ├── result.txt └── utils.py /.idea/.gitignore: -------------------------------------------------------------------------------- 1 | # Default ignored files 2 | /shelf/ 3 | /workspace.xml 4 | -------------------------------------------------------------------------------- /.idea/ailabner.iml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 11 | -------------------------------------------------------------------------------- /.idea/encodings.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | -------------------------------------------------------------------------------- /.idea/inspectionProfiles/profiles_settings.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 6 | -------------------------------------------------------------------------------- /.idea/misc.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 7 | -------------------------------------------------------------------------------- /.idea/modules.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # PytorchBilstmCRF-Information-Extraction 2 | 基于Bilstm + CRF的信息抽取模型 3 | 4 | 5 | 6 | 运行:python main.py 7 | 8 | 预测:python predict.py 9 | 10 | [博客链接:基于BiLSTM+CRF的信息抽取模型](https://blog.csdn.net/qq_44193969/article/details/116008734?spm=1001.2014.3001.5502) 11 | 12 | 有任何问题,随时私信 13 | 14 | 有任何建议,随时私信 15 | 16 | -------------------------------------------------------------------------------- /__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/seanzhang-zhichen/PytorchBilstmCRF-Information-Extraction/0c6e9bc0d8aaec28e6ecc5e2b6efbc194356833d/__init__.py -------------------------------------------------------------------------------- /__pycache__/config.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/seanzhang-zhichen/PytorchBilstmCRF-Information-Extraction/0c6e9bc0d8aaec28e6ecc5e2b6efbc194356833d/__pycache__/config.cpython-37.pyc -------------------------------------------------------------------------------- /__pycache__/data.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/seanzhang-zhichen/PytorchBilstmCRF-Information-Extraction/0c6e9bc0d8aaec28e6ecc5e2b6efbc194356833d/__pycache__/data.cpython-37.pyc -------------------------------------------------------------------------------- /__pycache__/evaluate.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/seanzhang-zhichen/PytorchBilstmCRF-Information-Extraction/0c6e9bc0d8aaec28e6ecc5e2b6efbc194356833d/__pycache__/evaluate.cpython-37.pyc -------------------------------------------------------------------------------- /__pycache__/evaluating.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/seanzhang-zhichen/PytorchBilstmCRF-Information-Extraction/0c6e9bc0d8aaec28e6ecc5e2b6efbc194356833d/__pycache__/evaluating.cpython-37.pyc -------------------------------------------------------------------------------- /__pycache__/operate_bilstm.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/seanzhang-zhichen/PytorchBilstmCRF-Information-Extraction/0c6e9bc0d8aaec28e6ecc5e2b6efbc194356833d/__pycache__/operate_bilstm.cpython-37.pyc -------------------------------------------------------------------------------- /__pycache__/utils.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/seanzhang-zhichen/PytorchBilstmCRF-Information-Extraction/0c6e9bc0d8aaec28e6ecc5e2b6efbc194356833d/__pycache__/utils.cpython-37.pyc -------------------------------------------------------------------------------- /ckpts/bilstm_crf.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/seanzhang-zhichen/PytorchBilstmCRF-Information-Extraction/0c6e9bc0d8aaec28e6ecc5e2b6efbc194356833d/ckpts/bilstm_crf.pkl -------------------------------------------------------------------------------- /config.py: -------------------------------------------------------------------------------- 1 | # 设置lstm训练参数 2 | class TrainingConfig(object): 3 | batch_size = 16 4 | # 学习速率 5 | lr = 0.0005 6 | epoches = 10 7 | print_step = 100 8 | 9 | class LSTMConfig(object): 10 | emb_size = 256 # 词向量的维数 11 | hidden_size = 256 # lstm隐向量的维数 -------------------------------------------------------------------------------- /data.py: -------------------------------------------------------------------------------- 1 | from codecs import open 2 | import os 3 | 4 | 5 | def build_corpus(split, make_vocab=True, data_dir='./data'): 6 | assert split.lower() in ["train","dev","test"] 7 | word_lists = [] 8 | tag_lists = [] 9 | with open(os.path.join(data_dir,split+".char"),'r',encoding='utf-8') as f: 10 | word_list = [] 11 | tag_list = [] 12 | for line in f: 13 | if line != '\n': 14 | word,tag = line.strip('\n').split() 15 | word_list.append(word) 16 | tag_list.append(tag) 17 | else: 18 | word_lists.append(word_list) 19 | tag_lists.append(tag_list) 20 | word_list = [] 21 | tag_list = [] 22 | if make_vocab: 23 | word2id = build_map(word_lists) 24 | tag2id = build_map(tag_lists) 25 | return word_lists,tag_lists,word2id,tag2id 26 | else: 27 | return word_lists,tag_lists 28 | 29 | 30 | def build_map(lists): 31 | maps = {} 32 | for list_ in lists: 33 | for e in list_: 34 | if e not in maps: 35 | maps[e] = len(maps) 36 | return maps -------------------------------------------------------------------------------- /data/crf_tag2id.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/seanzhang-zhichen/PytorchBilstmCRF-Information-Extraction/0c6e9bc0d8aaec28e6ecc5e2b6efbc194356833d/data/crf_tag2id.pkl -------------------------------------------------------------------------------- /data/crf_word2id.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/seanzhang-zhichen/PytorchBilstmCRF-Information-Extraction/0c6e9bc0d8aaec28e6ecc5e2b6efbc194356833d/data/crf_word2id.pkl -------------------------------------------------------------------------------- /data/lables.char: -------------------------------------------------------------------------------- 1 | B-NAME 2 | E-NAME 3 | O 4 | B-CONT 5 | M-CONT 6 | E-CONT 7 | B-RACE 8 | E-RACE 9 | B-TITLE 10 | M-TITLE 11 | E-TITLE 12 | B-EDU 13 | M-EDU 14 | E-EDU 15 | B-ORG 16 | M-ORG 17 | E-ORG 18 | M-NAME 19 | B-PRO 20 | M-PRO 21 | E-PRO 22 | S-RACE 23 | S-NAME 24 | B-LOC 25 | M-LOC 26 | E-LOC 27 | M-RACE 28 | S-ORG 29 | B-ID 30 | M-ID 31 | E-ID 32 | 33 | -------------------------------------------------------------------------------- /evaluate.py: -------------------------------------------------------------------------------- 1 | import time 2 | from collections import Counter 3 | import pickle 4 | 5 | from operate_bilstm import BiLSTM_operator 6 | from evaluating import Metrics 7 | from utils import save_model 8 | 9 | 10 | def bilstm_train_and_eval(train_data,dev_data,test_data,word2id,tag2id,crf=True,remove_0=False): 11 | train_word_lists, train_tag_lists = train_data 12 | dev_word_lists, dev_tag_lists = dev_data 13 | test_word_lists, test_tag_lists = test_data 14 | 15 | start = time.time() 16 | vocab_size = len(word2id) 17 | out_size = len(tag2id) 18 | 19 | bilstm_operator = BiLSTM_operator(vocab_size,out_size,crf=crf) 20 | model_name = "bilstm_crf" if crf else "bilstm" 21 | 22 | print("start to train the {} ...".format(model_name)) 23 | bilstm_operator.train(train_word_lists,train_tag_lists,dev_word_lists,dev_tag_lists,word2id,tag2id) 24 | save_model(bilstm_operator, "./ckpts/" + model_name + ".pkl") 25 | 26 | print("训练完毕,共用时{}秒.".format(int(time.time() - start))) 27 | print("评估{}模型中...".format(model_name)) 28 | pred_tag_lists, test_tag_lists = bilstm_operator.test( 29 | test_word_lists, test_tag_lists, word2id, tag2id) 30 | 31 | metrics = Metrics(test_tag_lists, pred_tag_lists, remove_0=remove_0) 32 | dtype = 'Bi_LSTM+CRF' if crf else 'Bi_LSTM' 33 | metrics.report_scores(dtype=dtype) 34 | 35 | return pred_tag_lists -------------------------------------------------------------------------------- /evaluating.py: -------------------------------------------------------------------------------- 1 | from collections import Counter 2 | from utils import flatten_lists 3 | 4 | class Metrics(object): 5 | """评价模型,计算每个标签的精确率、召回率、F1分数""" 6 | def __init__(self,gloden_tags,predict_tags,remove_0=False): 7 | self.golden_tags = flatten_lists(gloden_tags) 8 | self.predict_tags = flatten_lists(predict_tags) 9 | 10 | if remove_0: # 不统计非实体标记 11 | self._remove_Otags() 12 | 13 | # 所有的tag总数 14 | self.tagset = set(self.golden_tags) 15 | self.correct_tags_number = self.count_correct_tags() 16 | # print(self.correct_tags_number) 17 | self.predict_tags_count = Counter(self.predict_tags) 18 | self.golden_tags_count = Counter(self.golden_tags) 19 | 20 | # 精确率 21 | self.precision_scores = self.cal_precision() 22 | # 召回率 23 | self.recall_scores = self.cal_recall() 24 | # F1 25 | self.f1_scores = self.cal_f1() 26 | 27 | def cal_precision(self): 28 | """计算每个标签的精确率""" 29 | precision_scores = {} 30 | for tag in self.tagset: 31 | precision_scores[tag] = 0 if self.correct_tags_number.get(tag,0)==0 else \ 32 | self.correct_tags_number.get(tag,0) / self.predict_tags_count[tag] 33 | 34 | return precision_scores 35 | 36 | def cal_recall(self): 37 | """计算每个标签的召回率""" 38 | recall_scores = {} 39 | for tag in self.tagset: 40 | recall_scores[tag] = self.correct_tags_number.get(tag,0) / self.golden_tags_count[tag] 41 | 42 | return recall_scores 43 | 44 | def cal_f1(self): 45 | """计算f1分数""" 46 | f1_scores = {} 47 | for tag in self.tagset: 48 | f1_scores[tag] = 2*self.precision_scores[tag]*self.recall_scores[tag] / \ 49 | (self.precision_scores[tag] + self.recall_scores[tag] + 1e-10) 50 | return f1_scores 51 | 52 | def count_correct_tags(self): 53 | """计算每种标签预测正确的个数(对应精确率、召回率计算公式上的tp),用于后面精确率以及召回率的计算""" 54 | correct_dict = {} 55 | for gold_tag, predict_tag in zip(self.golden_tags, self.predict_tags): 56 | if gold_tag == predict_tag: 57 | if gold_tag not in correct_dict: 58 | correct_dict[gold_tag] = 1 59 | else: 60 | correct_dict[gold_tag] += 1 61 | 62 | return correct_dict 63 | 64 | def _remove_Otags(self): 65 | 66 | length = len(self.golden_tags) 67 | O_tag_indices = [i for i in range(length) 68 | if self.golden_tags[i] == 'O'] 69 | 70 | self.golden_tags = [tag for i, tag in enumerate(self.golden_tags) 71 | if i not in O_tag_indices] 72 | 73 | self.predict_tags = [tag for i, tag in enumerate(self.predict_tags) 74 | if i not in O_tag_indices] 75 | print("原总标记数为{},移除了{}个O标记,占比{:.2f}%".format( 76 | length, 77 | len(O_tag_indices), 78 | len(O_tag_indices) / length * 100 79 | )) 80 | 81 | def report_scores(self,dtype='HMM'): 82 | """将结果用表格的形式打印出来,像这个样子: 83 | 84 | precision recall f1-score support 85 | B-LOC 0.775 0.757 0.766 1084 86 | I-LOC 0.601 0.631 0.616 325 87 | B-MISC 0.698 0.499 0.582 339 88 | I-MISC 0.644 0.567 0.603 557 89 | B-ORG 0.795 0.801 0.798 1400 90 | I-ORG 0.831 0.773 0.801 1104 91 | B-PER 0.812 0.876 0.843 735 92 | I-PER 0.873 0.931 0.901 634 93 | 94 | avg/total 0.779 0.764 0.770 6178 95 | """ 96 | # 打印表头 97 | header_format = '{:>9s} {:>9} {:>9} {:>9} {:>9}' 98 | header = ['precision', 'recall', 'f1-score', 'support'] 99 | with open('result.txt','a') as fout: 100 | fout.write('\n') 101 | fout.write('=========='*10) 102 | fout.write('\n') 103 | fout.write('模型:{},test结果如下:'.format(dtype)) 104 | fout.write('\n') 105 | fout.write(header_format.format('', *header)) 106 | print(header_format.format('', *header)) 107 | 108 | row_format = '{:>9s} {:>9.4f} {:>9.4f} {:>9.4f} {:>9}' 109 | # 打印每个标签的 精确率、召回率、f1分数 110 | for tag in self.tagset: 111 | print(row_format.format( 112 | tag, 113 | self.precision_scores[tag], 114 | self.recall_scores[tag], 115 | self.f1_scores[tag], 116 | self.golden_tags_count[tag] 117 | )) 118 | fout.write('\n') 119 | fout.write(row_format.format( 120 | tag, 121 | self.precision_scores[tag], 122 | self.recall_scores[tag], 123 | self.f1_scores[tag], 124 | self.golden_tags_count[tag] 125 | )) 126 | 127 | # 计算并打印平均值 128 | avg_metrics = self._cal_weighted_average() 129 | print(row_format.format( 130 | 'avg/total', 131 | avg_metrics['precision'], 132 | avg_metrics['recall'], 133 | avg_metrics['f1_score'], 134 | len(self.golden_tags) 135 | )) 136 | fout.write('\n') 137 | fout.write(row_format.format( 138 | 'avg/total', 139 | avg_metrics['precision'], 140 | avg_metrics['recall'], 141 | avg_metrics['f1_score'], 142 | len(self.golden_tags) 143 | )) 144 | fout.write('\n') 145 | 146 | 147 | def _cal_weighted_average(self): 148 | 149 | weighted_average = {} 150 | total = len(self.golden_tags) 151 | 152 | # 计算weighted precisions: 153 | weighted_average['precision'] = 0. 154 | weighted_average['recall'] = 0. 155 | weighted_average['f1_score'] = 0. 156 | for tag in self.tagset: 157 | size = self.golden_tags_count[tag] 158 | weighted_average['precision'] += self.precision_scores[tag] * size 159 | weighted_average['recall'] += self.recall_scores[tag] * size 160 | weighted_average['f1_score'] += self.f1_scores[tag] * size 161 | 162 | for metric in weighted_average.keys(): 163 | weighted_average[metric] /= total 164 | 165 | return weighted_average 166 | 167 | def report_confusion_matrix(self): 168 | """计算混淆矩阵""" 169 | 170 | print("\nConfusion Matrix:") 171 | tag_list = list(self.tagset) 172 | # 初始化混淆矩阵 matrix[i][j]表示第i个tag被模型预测成第j个tag的次数 173 | tags_size = len(tag_list) 174 | matrix = [] 175 | for i in range(tags_size): 176 | matrix.append([0] * tags_size) 177 | 178 | # 遍历tags列表 179 | for golden_tag, predict_tag in zip(self.golden_tags, self.predict_tags): 180 | try: 181 | row = tag_list.index(golden_tag) 182 | col = tag_list.index(predict_tag) 183 | matrix[row][col] += 1 184 | except ValueError: # 有极少数标记没有出现在golden_tags,但出现在predict_tags,跳过这些标记 185 | continue 186 | 187 | # 输出矩阵 188 | row_format_ = '{:>7} ' * (tags_size+1) 189 | print(row_format_.format("", *tag_list)) 190 | for i, row in enumerate(matrix): 191 | print(row_format_.format(tag_list[i], *row)) -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | from data import build_corpus 2 | from evaluate import bilstm_train_and_eval 3 | from utils import extend_maps,prepocess_data_for_lstmcrf, save_obj, load_obj 4 | 5 | 6 | print("读取数据中...") 7 | train_word_lists,train_tag_lists,word2id,tag2id = build_corpus("train") 8 | dev_word_lists,dev_tag_lists = build_corpus("dev",make_vocab=False) 9 | test_word_lists,test_tag_lists = build_corpus("test",make_vocab=False) 10 | 11 | 12 | print("正在训练评估Bi-LSTM+CRF模型...") 13 | crf_word2id, crf_tag2id = extend_maps(word2id, tag2id, for_crf=True) 14 | save_obj(crf_word2id, 'crf_word2id') 15 | save_obj(crf_tag2id, 'crf_tag2id') 16 | # import os 17 | # #保存word2id 18 | # if os.path.exists('data/crf_word2id.pkl'): 19 | # crf_word2id = load_obj('crf_word2id') 20 | # else: 21 | # save_obj(crf_word2id, 'crf_word2id') 22 | # 23 | # #保存tag2id 24 | # if os.path.exists('data/crf_tag2id.pkl'): 25 | # crf_tag2id = load_obj('crf_tag2id') 26 | # else: 27 | # save_obj(crf_tag2id, 'crf_tag2id') 28 | 29 | 30 | print(' '.join([i[0] for i in crf_tag2id.items()])) 31 | 32 | train_word_lists, train_tag_lists = prepocess_data_for_lstmcrf( 33 | train_word_lists, train_tag_lists 34 | ) 35 | 36 | 37 | dev_word_lists, dev_tag_lists = prepocess_data_for_lstmcrf( 38 | dev_word_lists, dev_tag_lists 39 | ) 40 | test_word_lists, test_tag_lists = prepocess_data_for_lstmcrf( 41 | test_word_lists, test_tag_lists, test=True 42 | ) 43 | 44 | 45 | lstmcrf_pred = bilstm_train_and_eval( 46 | (train_word_lists, train_tag_lists), 47 | (dev_word_lists, dev_tag_lists), 48 | (test_word_lists, test_tag_lists), 49 | crf_word2id, crf_tag2id 50 | ) -------------------------------------------------------------------------------- /modelgraph/BILSTM.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | import torch.nn as nn 4 | 5 | class BiLSTM(nn.Module): 6 | def __init__(self, vocab_size, emb_size, hidden_size, out_size, dropout=0.1): 7 | super(BiLSTM, self).__init__() 8 | self.embedding = nn.Embedding(vocab_size, emb_size) 9 | self.bilstm = nn.LSTM(emb_size, hidden_size, batch_first=True, bidirectional=True) 10 | self.fc =nn.Linear(2*hidden_size, out_size) 11 | self.dropout =nn.Dropout(dropout) 12 | 13 | def forward(self, x, lengths): 14 | emb = self.dropout(self.embedding(x)) 15 | emb = nn.utils.rnn.pack_padded_sequence(emb, lengths, batch_first=True) 16 | emb, _ = self.bilstm(emb) 17 | # print("shape of x: ") 18 | # print(x.shape) 19 | emb, _ = nn.utils.rnn.pad_packed_sequence(emb, batch_first=True, padding_value=0., total_length=x.shape[1]) 20 | scores = self.fc(emb) 21 | 22 | return scores 23 | 24 | def test(self, x, lengths, _): 25 | logits = self.forward(x, lengths) 26 | _, batch_tagids = torch.max(logits, dim=2) 27 | return batch_tagids 28 | 29 | def cal_loss(logits, targets, tag2id): 30 | PAD = tag2id.get('') 31 | assert PAD is not None 32 | mask = (targets != PAD) 33 | targets = targets[mask] 34 | out_size = logits.size(2) 35 | logits = logits.masked_select( 36 | mask.unsqueeze(2).expand(-1, -1, out_size) 37 | ).contiguous().view(-1, out_size) 38 | assert logits.size(0) == targets.size(0) 39 | loss = F.cross_entropy(logits, targets) 40 | return loss -------------------------------------------------------------------------------- /modelgraph/BILSTM_CRF.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from modelgraph.BILSTM import BiLSTM 4 | from itertools import zip_longest 5 | 6 | class BiLSTM_CRF(nn.Module): 7 | def __init__(self, vocab_size, emb_size, hidden_size, out_size): 8 | super(BiLSTM_CRF, self).__init__() 9 | self.bilstm = BiLSTM(vocab_size, emb_size, hidden_size, out_size) 10 | self.transition = nn.Parameter(torch.ones(out_size, out_size) * 1 / out_size) 11 | 12 | def forward(self, sents_tensor, lengths): 13 | emission = self.bilstm(sents_tensor, lengths) 14 | batch_size, max_len, out_size = emission.size() 15 | crf_scores = emission.unsqueeze(2).expand(-1, -1, out_size, -1) + self.transition.unsqueeze(0) 16 | return crf_scores 17 | 18 | def test(self, test_sents_tensor, lengths, tag2id): 19 | start_id = tag2id[''] 20 | end_id = tag2id[''] 21 | pad = tag2id[''] 22 | tagset_size = len(tag2id) 23 | 24 | crf_scores =self.forward(test_sents_tensor, lengths) 25 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 26 | B , L , T, _ =crf_scores.size() 27 | 28 | viterbi = torch.zeros(B, L, T).to(device) 29 | backpointer = (torch.zeros(B, L, T).long() * end_id).to(device) 30 | 31 | lengths = torch.LongTensor(lengths).to(device) 32 | 33 | for step in range(L): 34 | batch_size_t =(lengths > step).sum().item() 35 | if step == 0: 36 | viterbi[:batch_size_t, step, :] = crf_scores[: batch_size_t, step, start_id, :] 37 | backpointer[:batch_size_t, step, :] = start_id 38 | else: 39 | max_scores, prev_tags = torch.max(viterbi[:batch_size_t, step-1, :].unsqueeze(2) + crf_scores[:batch_size_t, step, :, :], dim=1) 40 | viterbi[:batch_size_t, step, :] = max_scores 41 | backpointer[:batch_size_t, step, :] = prev_tags 42 | 43 | backpointer = backpointer.view(B, -1) 44 | tagids = [] 45 | tags_t = None 46 | for step in range(L-1, 0, -1): 47 | batch_size_t = (lengths > step).sum().item() 48 | if step == L-1: 49 | index = torch.ones(batch_size_t).long() * (step * tagset_size) 50 | index = index.to(device) 51 | index += end_id 52 | else: 53 | prev_batch_size_t = len(tags_t) 54 | new_in_batch = torch.LongTensor([end_id] * (batch_size_t - prev_batch_size_t)).to(device) 55 | offset = torch.cat([tags_t, new_in_batch], dim=0) 56 | index = torch.ones(batch_size_t).long() * (step *tagset_size) 57 | index = index.to(device) 58 | index += offset.long() 59 | 60 | try: 61 | tags_t = backpointer[:batch_size_t].gather(dim=1, index=index.unsqueeze(1).long()) 62 | except RuntimeError: 63 | import pdb 64 | pdb.set_trace() 65 | tags_t = tags_t.squeeze(1) 66 | tagids.append(tags_t.tolist()) 67 | tagids = list(zip_longest(*reversed(tagids), fillvalue=pad)) 68 | tagids = torch.Tensor(tagids).long() 69 | 70 | return tagids 71 | 72 | 73 | def cal_lstm_crf_loss(crf_scores, targets, tag2id): 74 | pad_id = tag2id.get('') 75 | start_id = tag2id.get('') 76 | end_id = tag2id.get('') 77 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 78 | batch_size, max_len = targets.size() 79 | target_size = len(tag2id) 80 | mask = (targets != pad_id) 81 | lengths = mask.sum(dim=1) 82 | targets = indexed(targets, target_size, start_id) 83 | targets = targets.masked_select(mask) 84 | flatten_scores = crf_scores.masked_select( 85 | mask.view(batch_size, max_len, 1, 1).expand_as(crf_scores) 86 | ).view(-1, target_size*target_size).contiguous() 87 | golden_scores = flatten_scores.gather( 88 | dim=1, index=targets.unsqueeze(1)).sum() 89 | scores_upto_t = torch.zeros(batch_size, target_size).to(device) 90 | for t in range(max_len): 91 | batch_size_t = (lengths > t).sum().item() 92 | if t == 0: 93 | scores_upto_t[:batch_size_t] = crf_scores[:batch_size_t, 94 | t, start_id, :] 95 | else: 96 | scores_upto_t[:batch_size_t] = torch.logsumexp( 97 | crf_scores[:batch_size_t, t, :, :] + 98 | scores_upto_t[:batch_size_t].unsqueeze(2), 99 | dim=1 100 | ) 101 | all_path_scores = scores_upto_t[:, end_id].sum() 102 | loss = (all_path_scores - golden_scores) / batch_size 103 | return loss 104 | 105 | def indexed(targets, tagset_size, start_id): 106 | batch_size, max_len = targets.size() 107 | for col in range(max_len-1, 0, -1): 108 | targets[:, col] += (targets[:, col-1] * tagset_size) 109 | targets[:, 0] += (start_id * tagset_size) 110 | return targets -------------------------------------------------------------------------------- /modelgraph/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/seanzhang-zhichen/PytorchBilstmCRF-Information-Extraction/0c6e9bc0d8aaec28e6ecc5e2b6efbc194356833d/modelgraph/__init__.py -------------------------------------------------------------------------------- /modelgraph/__pycache__/BILSTM.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/seanzhang-zhichen/PytorchBilstmCRF-Information-Extraction/0c6e9bc0d8aaec28e6ecc5e2b6efbc194356833d/modelgraph/__pycache__/BILSTM.cpython-37.pyc -------------------------------------------------------------------------------- /modelgraph/__pycache__/BILSTM_CRF.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/seanzhang-zhichen/PytorchBilstmCRF-Information-Extraction/0c6e9bc0d8aaec28e6ecc5e2b6efbc194356833d/modelgraph/__pycache__/BILSTM_CRF.cpython-37.pyc -------------------------------------------------------------------------------- /modelgraph/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/seanzhang-zhichen/PytorchBilstmCRF-Information-Extraction/0c6e9bc0d8aaec28e6ecc5e2b6efbc194356833d/modelgraph/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /operate_bilstm.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from modelgraph.BILSTM import BiLSTM, cal_loss 6 | from modelgraph.BILSTM_CRF import BiLSTM_CRF, cal_lstm_crf_loss 7 | from config import TrainingConfig, LSTMConfig 8 | from utils import sort_by_lengths, tensorized 9 | 10 | from copy import deepcopy 11 | from tqdm import tqdm, trange 12 | 13 | 14 | class BiLSTM_operator(object): 15 | def __init__(self, vocab_size, out_size, crf=True): 16 | self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 17 | self.emb_size = LSTMConfig.emb_size 18 | self.hidden_size = LSTMConfig.hidden_size 19 | self.crf = crf 20 | if self.crf: 21 | self.model = BiLSTM_CRF(vocab_size,self.emb_size,self.hidden_size,out_size).to(self.device) 22 | self.cal_loss_func = cal_lstm_crf_loss 23 | else: 24 | self.model = BiLSTM(vocab_size,self.emb_size,self.hidden_size,out_size).to(self.device) 25 | self.cal_loss_func = cal_loss 26 | 27 | # 加载训练参数: 28 | self.epoches = TrainingConfig.epoches 29 | self.print_step = TrainingConfig.print_step 30 | self.lr = TrainingConfig.lr 31 | self.batch_size = TrainingConfig.batch_size 32 | 33 | self.optimizer = torch.optim.Adam(self.model.parameters(), lr=self.lr) 34 | 35 | self.step = 0 36 | self._best_val_loss = 1e18 37 | self.best_model = None 38 | 39 | def train(self, word_lists, tag_lists, dev_word_lists, dev_tag_lists, word2id, tag2id): 40 | word_lists, tag_lists, _ = sort_by_lengths(word_lists, tag_lists) 41 | dev_word_lists, dev_tag_lists, _ = sort_by_lengths(dev_word_lists, dev_tag_lists) 42 | print("训练数据总量:{}".format(len(word_lists))) 43 | 44 | batch_size = self.batch_size 45 | epoch_iterator = trange(1, self.epoches + 1, desc="Epoch") 46 | for epoch in epoch_iterator: 47 | self.step = 0 48 | losses = 0. 49 | for idx in trange(0,len(word_lists),batch_size,desc="Iteration"): 50 | batch_sents = word_lists[idx:idx+batch_size] 51 | batch_tags = tag_lists[idx:idx+batch_size] 52 | losses += self.train_step(batch_sents,batch_tags,word2id,tag2id) 53 | 54 | if self.step%TrainingConfig.print_step == 0: 55 | total_step = (len(word_lists)//batch_size + 1) 56 | print("Epoch {}, step/total_step: {}/{} {:.2f}% Loss:{:.4f}".format( 57 | epoch, self.step, total_step, 58 | 100. * self.step / total_step, 59 | losses / self.print_step 60 | )) 61 | losses = 0. 62 | 63 | val_loss = self.validate( 64 | dev_word_lists, dev_tag_lists, word2id, tag2id) 65 | print("Epoch {}, Val Loss:{:.4f}".format(epoch, val_loss)) 66 | 67 | def train_step(self,batch_sents,batch_tags,word2id,tag2id): 68 | self.model.train() 69 | self.step+=1 70 | 71 | # 数据转tensor 72 | tensorized_sents,lengths = tensorized(batch_sents,word2id) 73 | targets,_ = tensorized(batch_tags,tag2id) 74 | tensorized_sents,targets = tensorized_sents.to(self.device),targets.to(self.device) 75 | 76 | scores = self.model(tensorized_sents,lengths) 77 | 78 | # 计算损失,反向传递 79 | self.model.zero_grad() 80 | loss = self.cal_loss_func(scores,targets,tag2id) 81 | loss.backward() 82 | self.optimizer.step() 83 | 84 | return loss.item() 85 | 86 | def validate(self, dev_word_lists, dev_tag_lists, word2id, tag2id): 87 | self.model.eval() 88 | with torch.no_grad(): 89 | val_losses = 0. 90 | val_step = 0 91 | for ind in range(0, len(dev_word_lists), self.batch_size): 92 | val_step += 1 93 | # 准备batch数据 94 | batch_sents = dev_word_lists[ind:ind+self.batch_size] 95 | batch_tags = dev_tag_lists[ind:ind+self.batch_size] 96 | tensorized_sents, lengths = tensorized(batch_sents, word2id) 97 | tensorized_sents = tensorized_sents.to(self.device) 98 | targets, lengths = tensorized(batch_tags, tag2id) 99 | targets = targets.to(self.device) 100 | 101 | # forward 102 | scores = self.model(tensorized_sents, lengths) 103 | 104 | # 计算损失 105 | loss = self.cal_loss_func(scores, targets, tag2id).to(self.device) 106 | val_losses += loss.item() 107 | val_loss = val_losses / val_step 108 | 109 | if val_loss < self._best_val_loss: 110 | print("保存模型...") 111 | self.best_model = deepcopy(self.model) 112 | self._best_val_loss = val_loss 113 | 114 | return val_loss 115 | 116 | def test(self,word_lists,tag_lists,word2id,tag2id): 117 | word_lists,tag_lists,indices = sort_by_lengths(word_lists,tag_lists) 118 | tensorized_sents, lengths = tensorized(word_lists, word2id) 119 | tensorized_sents = tensorized_sents.to(self.device) 120 | 121 | self.best_model.eval() 122 | with torch.no_grad(): 123 | batch_tagids = self.best_model.test(tensorized_sents,lengths,tag2id) 124 | pred_tag_lists = [] 125 | id2tag = dict((id_, tag) for tag, id_ in tag2id.items()) 126 | for i, ids in enumerate(batch_tagids): 127 | tag_list = [] 128 | if self.crf: 129 | for j in range(lengths[i] - 1): 130 | tag_list.append(id2tag[ids[j].item()]) 131 | else: 132 | for j in range(lengths[i]): 133 | tag_list.append(id2tag[ids[j].item()]) 134 | pred_tag_lists.append(tag_list) 135 | ind_maps = sorted(list(enumerate(indices)), key=lambda e: e[1]) 136 | indices, _ = list(zip(*ind_maps)) 137 | pred_tag_lists = [pred_tag_lists[i] for i in indices] 138 | tag_lists = [tag_lists[i] for i in indices] 139 | 140 | return pred_tag_lists, tag_lists 141 | 142 | def predict(self, word_lists, word2id, tag2id): 143 | """返回最佳模型在测试集上的预测结果""" 144 | # 数据准备 145 | # word_lists,tag_lists,indices = sort_by_lengths(word_lists,tag_lists) 146 | 147 | tensorized_sents, lengths = tensorized(word_lists, word2id) 148 | tensorized_sents = tensorized_sents.to(self.device) 149 | 150 | self.best_model.eval() 151 | with torch.no_grad(): 152 | batch_tagids = self.best_model.test(tensorized_sents, lengths, tag2id) 153 | 154 | # 将id转化为标注 155 | pred_tag_lists = [] 156 | id2tag = dict((id_, tag) for tag, id_ in tag2id.items()) 157 | for i, ids in enumerate(batch_tagids): 158 | tag_list = [] 159 | if self.crf: 160 | for j in range(lengths[i] - 1): 161 | tag_list.append(id2tag[ids[j].item()]) 162 | else: 163 | for j in range(lengths[i]): 164 | tag_list.append(id2tag[ids[j].item()]) 165 | pred_tag_lists.append(tag_list) 166 | 167 | return pred_tag_lists -------------------------------------------------------------------------------- /predict.py: -------------------------------------------------------------------------------- 1 | #%% 2 | import torch 3 | import pickle 4 | from utils import load_obj, tensorized 5 | 6 | 7 | def predict(model, text): 8 | text_list = list(text) 9 | text_list.append("") 10 | text_list = [text_list] 11 | crf_word2id = load_obj('crf_word2id') 12 | crf_tag2id = load_obj('crf_tag2id') 13 | # vocab_size = len(crf_word2id) 14 | # out_size = len(crf_tag2id) 15 | pred_tag_lists = model.predict(text_list, crf_word2id, crf_tag2id) 16 | return pred_tag_lists[0] 17 | 18 | 19 | def result_process(text_list, tag_list): 20 | tuple_result = zip(text_list, tag_list) 21 | sent_out = [] 22 | tags_out = [] 23 | outputs = [] 24 | words = "" 25 | for s, t in tuple_result: 26 | if t.startswith('B-') or t == 'O': 27 | if len(words): 28 | sent_out.append(words) 29 | # print(sent_out) 30 | if t != 'O': 31 | tags_out.append(t.split('-')[1]) 32 | else: 33 | tags_out.append(t) 34 | words = s 35 | # print(words) 36 | else: 37 | words += s 38 | # %% 39 | if len(sent_out) < len(tags_out): 40 | sent_out.append(words) 41 | outputs.append(''.join([str((s, t)) for s, t in zip(sent_out, tags_out)])) 42 | return outputs, [*zip(sent_out, tags_out)] 43 | 44 | 45 | 46 | #%% 47 | if __name__ == '__main__': 48 | 49 | modelpath = './ckpts/bilstm_crf.pkl' 50 | f = open(modelpath, 'rb') 51 | s = f.read() 52 | model = pickle.loads(s) 53 | 54 | text = '法外狂徒张三丰,身份证号362502190211032345' 55 | tag_res = predict(model, text) 56 | result, tuple_re = result_process(list(text), tag_res) 57 | 58 | print(text) 59 | # #%% 60 | #print(tuple_re) 61 | # print(result) 62 | result = [] 63 | tag = [] 64 | for s,t in tuple_re: 65 | if t !='O': 66 | result.append(s) 67 | tag.append(t) 68 | print([*zip(result, tag)]) 69 | 70 | 71 | 72 | 73 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch==1.7.1 2 | tqdm==4.55.1 -------------------------------------------------------------------------------- /result.txt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/seanzhang-zhichen/PytorchBilstmCRF-Information-Extraction/0c6e9bc0d8aaec28e6ecc5e2b6efbc194356833d/result.txt -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | import torch 3 | 4 | 5 | def sort_by_lengths(word_lists,tag_lists): 6 | pairs = list(zip(word_lists, tag_lists)) 7 | indices = sorted(range(len(pairs)), key=lambda x: len(pairs[x][0]), reverse=True) 8 | 9 | pairs = [pairs[i] for i in indices] 10 | word_lists, tag_lists = list(zip(*pairs)) 11 | return word_lists, tag_lists, indices 12 | 13 | 14 | def tensorized(batch, maps): 15 | PAD = maps.get('') 16 | UNK = maps.get('') 17 | 18 | max_len = len(batch[0]) 19 | batch_size = len(batch) 20 | 21 | batch_tensor = torch.ones(batch_size, max_len).long() * PAD 22 | for i, l in enumerate(batch): 23 | for j, e in enumerate(l): 24 | batch_tensor[i][j] = maps.get(e, UNK) 25 | 26 | lengths = [len(l) for l in batch] 27 | return batch_tensor, lengths 28 | 29 | 30 | def save_obj(obj, name): 31 | with open('data/'+ name + '.pkl', 'wb') as f: 32 | pickle.dump(obj, f, pickle.HIGHEST_PROTOCOL) 33 | 34 | 35 | def load_obj(name): 36 | with open('data/' + name + '.pkl', 'rb') as f: 37 | return pickle.load(f) 38 | 39 | 40 | def prepocess_data_for_lstmcrf(word_lists, tag_lists, test=False): 41 | assert len(word_lists) == len(tag_lists) 42 | for i in range(len(word_lists)): 43 | word_lists[i].append("") 44 | if not test: # 如果是测试数据,就不需要加end token了 45 | tag_lists[i].append("") 46 | 47 | return word_lists, tag_lists 48 | 49 | 50 | def flatten_lists(lists): 51 | """将list of list 压平成list""" 52 | flatten_list = [] 53 | for list_ in lists: 54 | if type(list_) == list: 55 | flatten_list.extend(list_) 56 | else: 57 | flatten_list.append(list_) 58 | return flatten_list 59 | 60 | 61 | def extend_maps(word2id, tag2id, for_crf=True): 62 | word2id[''] = len(word2id) 63 | word2id[''] = len(word2id) 64 | tag2id[''] = len(tag2id) 65 | tag2id[''] = len(tag2id) 66 | # 如果是加了CRF的bilstm 那么还要加入token 67 | if for_crf: 68 | word2id[''] = len(word2id) 69 | word2id[''] = len(word2id) 70 | tag2id[''] = len(tag2id) 71 | tag2id[''] = len(tag2id) 72 | 73 | return word2id, tag2id 74 | 75 | 76 | def save_model(model,file_name): 77 | with open(file_name,'wb') as f: 78 | pickle.dump(model,f) --------------------------------------------------------------------------------