├── src
│   ├── data
│   │   ├── __init__.py
│   │   ├── path.py
│   │   └── data_process.py
│   ├── model
│   │   ├── __init__.py
│   │   ├── path.py
│   │   ├── bilstm.py
│   │   ├── bilstm_crf.py
│   │   └── bert_bilstm_crf.py
│   ├── tools
│   │   ├── __init__.py
│   │   ├── help.py
│   │   ├── get_ner_level_acc.py
│   │   └── get_pretrained_vec.py
│   ├── train
│   │   ├── __init__.py
│   │   ├── train.py
│   │   └── train_helper.py
│   └── main.py
├── .gitignore
└── README.md

/src/data/__init__.py:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/src/model/__init__.py:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/src/tools/__init__.py:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/src/train/__init__.py:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
__pycache__
data/model
data/model/bert
data/model/bilstm-crf
data/model/data_obj
.DS_Store
.idea
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# BiLSTM-CRF-NER

Blog post:

[https://blog.csdn.net/qq_44193969/article/details/123998548](https://blog.csdn.net/qq_44193969/article/details/123998548)

If you find any bug in the code, please open an `issue` on `GitHub`; if you spot any mistake in the article, corrections are very welcome. Thank you for your valuable feedback 🌹

If this project helps you, a `star` would be much appreciated. Thanks!
--------------------------------------------------------------------------------
/src/data/path.py:
--------------------------------------------------------------------------------
import os


def get_data_dir():
    data_dir = "../data/dataset"
    if not os.path.exists(data_dir):
        os.makedirs(data_dir)
    return data_dir


def get_train_data_path():
    train_path = os.path.join(get_data_dir(), "train.txt")
    return train_path


def get_eval_data_path():
    eval_path = os.path.join(get_data_dir(), "eval.txt")
    return eval_path


def get_test_data_path():
    test_path = os.path.join(get_data_dir(), "test.txt")
    return test_path
--------------------------------------------------------------------------------
/src/main.py:
--------------------------------------------------------------------------------
from train.train import Train
from tools.get_ner_level_acc import precision

if __name__ == "__main__":

    use_pretrained_w2v = False
    model_type = "bert-bilstm-crf"

    model_train = Train()
    model_train.train(use_pretrained_w2v=use_pretrained_w2v, model_type=model_type)

    text = "张铁柱,大学本科,毕业于东华理工大学,汉族。"

    result = model_train.predict(text, use_pretrained_w2v, model_type)
    print(result[0])

    result_dic = model_train.get_ner_list_dic(text, use_pretrained_w2v, model_type)
    print(result_dic)
--------------------------------------------------------------------------------
/src/model/path.py:
--------------------------------------------------------------------------------
import os

def get_bert_dir():
    return '../data/model/bert'


def get_model_dir():
    model_dir = '../data/model'
    if not os.path.exists(model_dir):
        os.makedirs(model_dir)
    return model_dir


def get_chinese_wwm_ext_pytorch_path():
    bert_path = os.path.join(get_bert_dir(), 'chinese_wwm_ext_pytorch')
    return bert_path



def get_data_obj_dir():
    data_obj_dir = '../data/model/data_obj'
    if not os.path.exists(data_obj_dir):
        os.makedirs(data_obj_dir)
    return data_obj_dir


def get_word2id_path():
    word2id_path = os.path.join(get_data_obj_dir(), 'word2id.pkl')
    return word2id_path


def get_id2word_path():
    id2word_path = os.path.join(get_data_obj_dir(), 'id2word.pkl')
    return id2word_path


def get_tag2id_path():
    tag2id_path = os.path.join(get_data_obj_dir(), "tag2id.pkl")
    return tag2id_path


def get_pretrained_char_vec_path():
    pretained_char_vec_path = os.path.join(get_data_obj_dir(), "pretrained_char_vec.txt")
    return pretained_char_vec_path
--------------------------------------------------------------------------------
/src/tools/help.py:
--------------------------------------------------------------------------------
import os
import torch
import pickle


def save_as_pickle(obj, file_path):
    with open(file_path, 'wb') as f:
        pickle.dump(obj, f, pickle.HIGHEST_PROTOCOL)


def load_pickle_obj(file_path):
    with open(file_path, 'rb') as f:
        return pickle.load(f)


def sort_by_lengths(word_lists, tag_lists):
    pairs = list(zip(word_lists, tag_lists))
    indices = sorted(range(len(pairs)), key=lambda x: len(pairs[x][0]), reverse=True)

    pairs = [pairs[i] for i in indices]
    word_lists, tag_lists = list(zip(*pairs))
    return word_lists, tag_lists, indices


def batch_sents_to_tensorized(batch, maps):
    PAD = maps.get('<pad>')
    UNK = maps.get('<unk>')

    max_len = len(batch[0])
    batch_size = len(batch)

    batch_tensor = torch.ones(batch_size, max_len).long() * PAD
    for i, l in enumerate(batch):
        for j, e in enumerate(l):
            batch_tensor[i][j] = maps.get(e, UNK)

    lengths = [len(l) for l in batch]
    return batch_tensor, lengths


def flatten_lists(lists):
    """Flatten a list of lists into a single flat list."""
    flatten_list = []
    for list_ in lists:
        if type(list_) == list:
            flatten_list.extend(list_)
        else:
            flatten_list.append(list_)
    return flatten_list
--------------------------------------------------------------------------------
/src/data/data_process.py:
--------------------------------------------------------------------------------

def load_data(file_path):
    """
    Load the data.
    word_lists: [[]]
    tag_lists: [[]]
    """
    word_lists = []
    tag_lists = []
    with open(file_path, 'r') as f:
        word_list = []
        tag_list = []
        for line in f:
            if line != '\n':
                word, tag = line.strip('\n').split()
                word_list.append(word)
                tag_list.append(tag)
            else:
                word_lists.append(word_list)
                tag_lists.append(tag_list)
                word_list = []
                tag_list = []
    return word_lists, tag_lists


def get_word2id(lists):
    """
    Build the word2id dict.
    """
    maps = {}
    for list_ in lists:
        for e in list_:
            if e not in maps:
                maps[e] = len(maps)
    return maps


def extend_vocab(word2id, tag2id, for_crf=True):
    """
    Extend word2id with the special tokens:
    unknown word: <unk>
    padding: <pad>
    start of sentence: <start>
    end of sentence: <end>
    """
    word2id['<unk>'] = len(word2id)
    word2id['<pad>'] = len(word2id)
    tag2id['<unk>'] = len(tag2id)
    tag2id['<pad>'] = len(tag2id)
    # for the BiLSTM with a CRF layer, also add the <start> and <end> tokens
    if for_crf:
        word2id['<start>'] = len(word2id)
        word2id['<end>'] = len(word2id)
        tag2id['<start>'] = len(tag2id)
        tag2id['<end>'] = len(tag2id)

    return word2id, tag2id


def add_end_token(word_lists, tag_lists, test=False):
    '''
    Append the end-of-sentence token <end>.
    '''
    assert len(word_lists) == len(tag_lists)
    for i in range(len(word_lists)):
        # append <end> to the end of every sentence
        word_lists[i].append("<end>")
        if not test:  # test data does not need the end token on the tags
            tag_lists[i].append("<end>")

    return word_lists, tag_lists

--------------------------------------------------------------------------------
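For reference, a minimal, hypothetical usage sketch (not part of the repository): it shows the input format that load_data expects, one character and its BMES-style tag per line, whitespace-separated, with a blank line between sentences, and how get_word2id, extend_vocab and batch_sents_to_tensorized from tools/help.py turn sentences into padded id tensors. It assumes the working directory is src/ and that torch is installed; the sample sentence, its tags and the file name sample.txt are made up.

from data.data_process import load_data, get_word2id, extend_vocab
from tools.help import batch_sents_to_tensorized

# one "char tag" pair per line, blank line between sentences (made-up example)
with open("sample.txt", "w", encoding="utf-8") as f:
    f.write("张 B-NAME\n三 E-NAME\n, O\n汉 B-RACE\n族 E-RACE\n。 O\n\n")

word_lists, tag_lists = load_data("sample.txt")
print(word_lists[0])   # ['张', '三', ',', '汉', '族', '。']
print(tag_lists[0])    # ['B-NAME', 'E-NAME', 'O', 'B-RACE', 'E-RACE', 'O']

word2id = get_word2id(word_lists)
tag2id = get_word2id(tag_lists)              # the same helper also builds tag2id
word2id, tag2id = extend_vocab(word2id, tag2id, for_crf=True)

batch_tensor, lengths = batch_sents_to_tensorized(word_lists, word2id)
print(batch_tensor.shape, lengths)           # torch.Size([1, 6]) [6]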
/src/model/bilstm.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F
from tools.get_pretrained_vec import GetPretrainedVec
from model.path import get_pretrained_char_vec_path
from model.path import get_word2id_path
from model.path import get_id2word_path


class BiLSTM(nn.Module):
    def __init__(self, vocab_size, emb_size, hidden_size, out_size, drop_out=0.5, use_pretrained_w2v=False):
        super(BiLSTM, self).__init__()
        self.embedding = nn.Embedding(vocab_size, emb_size)
        if use_pretrained_w2v:
            print("Loading pretrained character vectors ...")
            vec_path = get_pretrained_char_vec_path()
            word2id_path = get_word2id_path()
            id2word_path = get_id2word_path()
            embedding_pretrained = GetPretrainedVec.get_w2v_weight(emb_size, vec_path, word2id_path, id2word_path)
            self.embedding.weight.data.copy_(embedding_pretrained)
            self.embedding.weight.requires_grad = True
        self.bilstm = nn.LSTM(emb_size, hidden_size, batch_first=True, bidirectional=True)
        self.fc = nn.Linear(hidden_size*2, out_size)
        self.dropout = nn.Dropout(drop_out)
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    def forward(self, x, lengths):
        x = x.to(self.device)
        emb = self.embedding(x)
        emb = self.dropout(emb)
        emb = nn.utils.rnn.pack_padded_sequence(emb, lengths, batch_first=True)
        emb, _ = self.bilstm(emb)
        output, _ = nn.utils.rnn.pad_packed_sequence(emb, batch_first=True, padding_value=0., total_length=x.shape[1])
        scores = self.fc(output)
        return scores

    def predict(self, x, lengths, _):
        scores = self.forward(x, lengths)
        _, batch_tagids = torch.max(scores, dim=2)
        return batch_tagids


def cal_bilstm_loss(logits, targets, tag2id):
    PAD = tag2id.get('<pad>')
    assert PAD is not None
    mask = (targets != PAD)
    targets = targets[mask]
    out_size = logits.size(2)
    logits = logits.masked_select(
        mask.unsqueeze(2).expand(-1, -1, out_size)
    ).contiguous().view(-1, out_size)
    assert logits.size(0) == targets.size(0)
    loss = F.cross_entropy(logits, targets)
    return loss
--------------------------------------------------------------------------------
/src/tools/get_ner_level_acc.py:
--------------------------------------------------------------------------------

from .help import flatten_lists

def _find_tag(labels, B_label="B-COM", I_label="M-COM", E_label="E-COM", S_label="S-COM"):
    result = []
    lenth = 0
    for num in range(len(labels)):
        if labels[num] == B_label:
            song_pos0 = num
        if labels[num] == B_label and labels[num+1] == E_label:
            lenth = 2
            result.append((song_pos0, lenth))

        if labels[num] == I_label and labels[num-1] == B_label:
            lenth = 2
            for num2 in range(num, len(labels)):
                if labels[num2] == I_label and labels[num2-1] == I_label:
                    lenth += 1
                if labels[num2] == E_label:
                    lenth += 1
                    result.append((song_pos0, lenth))
                    break
        if labels[num] == S_label:
            lenth = 1
            song_pos0 = num
            result.append((song_pos0, lenth))

    return result


tags = [("B-NAME", "M-NAME", "E-NAME", "S-NAME"),
        ("B-TITLE", "M-TITLE", "E-TITLE", "S-TITLE"),
        ("B-ORG", "M-ORG", "E-ORG", "S-ORG"),
        ("B-RACE", "M-RACE", "E-RACE", "S-RACE"),
        ("B-EDU", "M-EDU", "E-EDU", "S-EDU"),
        ("B-CONT", "M-CONT", "E-CONT", "S-CONT"),
        ("B-LOC", "M-LOC", "E-LOC", "S-LOC"),
        ("B-PRO", "M-PRO", "E-PRO", "S-PRO")]


def find_all_tag(labels):
    result = {}
    for tag in tags:
        res = _find_tag(labels, B_label=tag[0], I_label=tag[1], E_label=tag[2], S_label=tag[3])
        result[tag[0].split("-")[1]] = res
    return result

def precision(pre_labels, true_labels):
    '''
    :param pre_labels: list
    :param true_labels: list
    :return:
    '''
    pre = []
    pre_labels = flatten_lists(pre_labels)
    true_labels = flatten_lists(true_labels)

    pre_result = find_all_tag(pre_labels)
    true_result = find_all_tag(true_labels)

    result_dic = {}
    for name in pre_result:
        for x in pre_result[name]:
            if result_dic.get(name) is None:
                result_dic[name] = []
            if x:
                if pre_labels[x[0]:x[0]+x[1]] == true_labels[x[0]:x[0]+x[1]]:
                    result_dic[name].append(1)
                else:
                    result_dic[name].append(0)
        # print(f'tag: {name} , length: {len(result_dic[name])}')

    sum_result = 0
    for name in result_dic:
        sum_result += sum(result_dic[name])
        # print(f'tag2: {name} , length2: {len(result_dic[name])}')
        result_dic[name] = sum(result_dic[name]) / len(result_dic[name])

    for name in pre_result:
        for x in pre_result[name]:
            if x:
                if pre_labels[x[0]:x[0]+x[1]] == true_labels[x[0]:x[0]+x[1]]:
                    pre.append(1)
                else:
                    pre.append(0)
    total_precision = sum(pre)/len(pre)

    return total_precision, result_dic
--------------------------------------------------------------------------------
/src/tools/get_pretrained_vec.py:
--------------------------------------------------------------------------------
# coding = utf-8
import jieba
import torch
import logging
import numpy as np
from tqdm import tqdm
from .help import load_pickle_obj
from gensim.models import KeyedVectors
from transformers import BertModel, BertTokenizerFast
from model.path import get_chinese_wwm_ext_pytorch_path


jieba.setLogLevel(logging.INFO)


class GetPretrainedVec:
    def __init__(self):
        self.bert_path = get_chinese_wwm_ext_pytorch_path()

    def load(self):
        self.bert = BertModel.from_pretrained(self.bert_path)
        self.token = BertTokenizerFast.from_pretrained(self.bert_path)

    # generate character vectors with BERT
    def get_data(self, path, char=False):
        words = []
        with open(path, "r", encoding="utf-8") as f:
            sentences = f.readlines()
            if char:
                for sent in sentences:
                    words.extend([word.strip() for word in sent.strip().replace(" ", "") if word not in words])
            else:
                for sentence in sentences:
                    cut_word = jieba.lcut(sentence.strip().replace(" ", ""))
                    words.extend([w for w in cut_word if w not in words])
        return words


    def get_bert_embed(self, src_path, vec_save_path, char=False):
        words = self.get_data(src_path, char)
        words.append("<unk>")
        words.append("<pad>")
        words.append("<start>")
        words.append("<end>")
        # character vectors
        if char:
            file_char = open(vec_save_path, "a+", encoding="utf-8")
            file_char.write(str(len(words)) + " " + "768" + "\n")
            for word in tqdm(words, desc="Encoding character vectors:"):
                inputs = self.token.encode_plus(word, padding="max_length", truncation=True, max_length=10,
                                                add_special_tokens=True,
                                                return_tensors="pt")
                out = self.bert(**inputs)
                out = out[0].detach().numpy().tolist()
                out_str = " ".join("%s" % embed for embed in out[0][1])
                embed_out = word + " " + out_str + "\n"
                file_char.write(embed_out)
            file_char.close()
        else:
            file_word = open(vec_save_path, "a+", encoding="utf-8")
            file_word.write(str(len(words)) + " " + "768" + "\n")
            # word vectors (mean of the character vectors)
            for word in tqdm(words, desc="Encoding word vectors:"):
                words_embed = np.zeros(768)  # bert tensor shape is 768
                inputs = self.token.encode_plus(word, padding="max_length", truncation=True, max_length=50, add_special_tokens=True,
                                                return_tensors="pt")
                out = self.bert(**inputs)
                word_len = len(word)
                out_ = out[0].detach().numpy()
                for i in range(1, word_len + 1):
                    out_str = out_[0][i]
                    words_embed += out_str
                words_embed = words_embed / word_len
                words_embedding = words_embed.tolist()
                result = word + " " + " ".join("%s" % embed for embed in words_embedding) + "\n"
                file_word.write(result)

            file_word.close()

    @staticmethod
    def get_w2v_weight(embedding_size, vec_path, word2id_path, id2word_path):
        w2v_model = KeyedVectors.load_word2vec_format(vec_path, binary=False)

        word2id = load_pickle_obj(word2id_path)
        id2word = load_pickle_obj(id2word_path)
        vocab_size = len(word2id)
        embedding_size = embedding_size
        weight = torch.zeros(vocab_size, embedding_size)
        for i in range(len(w2v_model.index2word)):
            try:
                index = word2id[w2v_model.index2word[i]]
            except:
                continue
            weight[index, :] = torch.from_numpy(w2v_model.get_vector(id2word[word2id[w2v_model.index2word[i]]]))

        return weight

--------------------------------------------------------------------------------
/src/model/bilstm_crf.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
from .bilstm import BiLSTM
from itertools import zip_longest

class BiLSTM_CRF(nn.Module):
    def __init__(self, vocab_size, emb_size, hidden_size, out_size, dropout, use_pretrained_w2v):
        super(BiLSTM_CRF, self).__init__()
        self.bilstm = BiLSTM(vocab_size, emb_size, hidden_size, out_size, dropout, use_pretrained_w2v)
        self.transition = nn.Parameter(torch.ones(out_size, out_size) * 1 / out_size)
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    def forward(self, sents_tensor, lengths):
        emission = self.bilstm(sents_tensor, lengths).to(self.device)
        batch_size, max_len, out_size = emission.size()
        crf_scores = emission.unsqueeze(2).expand(-1, -1, out_size, -1) + self.transition.unsqueeze(0)
        return crf_scores

    def predict(self, test_sents_tensor, lengths, tag2id):
        start_id = tag2id['<start>']
        end_id = tag2id['<end>']
        pad = tag2id['<pad>']
        tagset_size = len(tag2id)

        crf_scores = self.forward(test_sents_tensor, lengths)

        B, L, T, _ = crf_scores.size()

        viterbi = torch.zeros(B, L, T).to(self.device)
        backpointer = (torch.zeros(B, L, T).long() * end_id).to(self.device)

        lengths = torch.LongTensor(lengths).to(self.device)

        for step in range(L):
            batch_size_t = (lengths > step).sum().item()
            if step == 0:
                viterbi[:batch_size_t, step, :] = crf_scores[:batch_size_t, step, start_id, :]
                backpointer[:batch_size_t, step, :] = start_id
            else:
                max_scores, prev_tags = torch.max(viterbi[:batch_size_t, step-1, :].unsqueeze(2) + crf_scores[:batch_size_t, step, :, :], dim=1)
                viterbi[:batch_size_t, step, :] = max_scores
                backpointer[:batch_size_t, step, :] = prev_tags

        backpointer = backpointer.view(B, -1)
        tagids = []
        tags_t = None
        for step in range(L-1, 0, -1):
            batch_size_t = (lengths > step).sum().item()
            if step == L-1:
                index = torch.ones(batch_size_t).long() * (step * tagset_size)
                index = index.to(self.device)
                index += end_id
            else:
                prev_batch_size_t = len(tags_t)
                new_in_batch = torch.LongTensor([end_id] * (batch_size_t - prev_batch_size_t)).to(self.device)
                offset = torch.cat([tags_t, new_in_batch], dim=0)
                index = torch.ones(batch_size_t).long() * (step * tagset_size)
                index = index.to(self.device)
                index += offset.long()

            tags_t = backpointer[:batch_size_t].gather(dim=1, index=index.unsqueeze(1).long())
            tags_t = tags_t.squeeze(1)
            tagids.append(tags_t.tolist())
        tagids = list(zip_longest(*reversed(tagids), fillvalue=pad))
        tagids = torch.Tensor(tagids).long()

        return tagids


def cal_bilstm_crf_loss(crf_scores, targets, tag2id):
    pad_id = tag2id.get('<pad>')
    start_id = tag2id.get('<start>')
    end_id = tag2id.get('<end>')
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    batch_size, max_len = targets.size()
    target_size = len(tag2id)
    mask = (targets != pad_id)
    lengths = mask.sum(dim=1)
    targets = indexed(targets, target_size, start_id)
    targets = targets.masked_select(mask)
    flatten_scores = crf_scores.masked_select(
        mask.view(batch_size, max_len, 1, 1).expand_as(crf_scores)
    ).view(-1, target_size*target_size).contiguous()
    golden_scores = flatten_scores.gather(
        dim=1, index=targets.unsqueeze(1)).sum()
    scores_upto_t = torch.zeros(batch_size, target_size).to(device)
    for t in range(max_len):
        batch_size_t = (lengths > t).sum().item()
        if t == 0:
            scores_upto_t[:batch_size_t] = crf_scores[:batch_size_t,
                                                      t, start_id, :]
        else:
            scores_upto_t[:batch_size_t] = torch.logsumexp(
                crf_scores[:batch_size_t, t, :, :] +
                scores_upto_t[:batch_size_t].unsqueeze(2),
                dim=1
            )
    all_path_scores = scores_upto_t[:, end_id].sum()
    loss = (all_path_scores - golden_scores) / batch_size
    return loss

def indexed(targets, tagset_size, start_id):
    batch_size, max_len = targets.size()
    for col in range(max_len-1, 0, -1):
        targets[:, col] += (targets[:, col-1] * tagset_size)
    targets[:, 0] += (start_id * tagset_size)
    return targets
--------------------------------------------------------------------------------
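A quick, hypothetical shape check (not part of the repository) makes the CRF scoring above concrete: forward() adds the learned transition matrix to every emission, so it returns one score per (previous tag, current tag) pair, a tensor of shape (batch, seq_len, tagset_size, tagset_size). The sketch assumes the repository's dependencies are installed, that it is run from the src/ directory, and uses made-up sizes.

import torch
from model.bilstm_crf import BiLSTM_CRF

model = BiLSTM_CRF(vocab_size=100, emb_size=32, hidden_size=16,
                   out_size=8, dropout=0.5, use_pretrained_w2v=False)
model.to(model.device)                  # NerModel moves the model the same way

batch = torch.randint(0, 100, (2, 5))   # two sentences, the second padded to length 5
lengths = [5, 4]                        # lengths must be sorted in descending order
crf_scores = model(batch, lengths)
print(crf_scores.shape)                 # torch.Size([2, 5, 8, 8])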
/src/train/train.py:
--------------------------------------------------------------------------------

import os
import time
from data.data_process import load_data
from data.data_process import get_word2id
from data.data_process import extend_vocab
from data.data_process import add_end_token
from data.path import get_train_data_path
from data.path import get_eval_data_path
from data.path import get_test_data_path
from model.path import get_word2id_path
from model.path import get_tag2id_path
from model.path import get_id2word_path
from model.path import get_pretrained_char_vec_path
from tools.get_ner_level_acc import find_all_tag
from tools.help import save_as_pickle
from tools.help import load_pickle_obj
from tools.get_pretrained_vec import GetPretrainedVec
from .train_helper import NerModel

class Train:
    def __init__(self) -> None:
        self.train_data_path = get_train_data_path()
        self.eval_data_path = get_eval_data_path()
        self.test_data_path = get_test_data_path()
        self.word2id_path = get_word2id_path()
        self.tag2id_path = get_tag2id_path()
        self.id2word_path = get_id2word_path()
        self.vec_path = get_pretrained_char_vec_path()
        self.word2id = None
        self.tag2id = None
        self.id2word = None
        self.get_pretrained_vec = GetPretrainedVec()

    def load(self):
        self.get_pretrained_vec.load()

    def prepare_data(self):
        self.train_word_lists, self.train_tag_lists = load_data(self.train_data_path)
        self.eval_word_lists, self.eval_tag_list = load_data(self.eval_data_path)
        self.test_word_lists, self.test_tag_list = load_data(self.test_data_path)

        self.word2id = get_word2id(self.train_word_lists)
        self.tag2id = get_word2id(self.train_tag_lists)
        self.word2id, self.tag2id = extend_vocab(self.word2id, self.tag2id)

        self.id2word = {self.word2id[w]: w for w in self.word2id}

        save_as_pickle(self.word2id, self.word2id_path)
        save_as_pickle(self.tag2id, self.tag2id_path)
        save_as_pickle(self.id2word, self.id2word_path)

        if not os.path.exists(self.vec_path):
            print('Generating pretrained vectors with BERT')
            self.get_pretrained_vec.get_bert_embed(self.train_data_path, self.vec_path, char=True)

        self.train_word_lists, self.train_tag_lists = add_end_token(self.train_word_lists, self.train_tag_lists)

        self.eval_word_lists, self.eval_tag_list = add_end_token(self.eval_word_lists, self.eval_tag_list)

        self.test_word_lists, self.test_tag_list = add_end_token(self.test_word_lists, self.test_tag_list, test=True)

        return (self.train_word_lists, self.train_tag_lists,
                self.eval_word_lists, self.eval_tag_list,
                self.test_word_lists, self.test_tag_list)


    def train(self, use_pretrained_w2v=False, model_type="bilstm-crf"):
        self.get_pretrained_vec.load()
        train_word_lists, train_tag_lists, dev_word_lists, dev_tag_lists, test_word_lists, test_tag_lists = self.prepare_data()

        word2id = load_pickle_obj(self.word2id_path)
        tag2id = load_pickle_obj(self.tag2id_path)

        print(f"tag2id: {tag2id}")
        vocab_size = len(word2id)
        out_size = len(tag2id)
        ner_model = NerModel(vocab_size, out_size, use_pretrained_w2v=use_pretrained_w2v, model_type=model_type)
        print(f"vocab_size: {vocab_size}, out_size: {out_size}")
        print("start to train the {} model ...".format(model_type))

        ner_model.train(train_word_lists, train_tag_lists, dev_word_lists, dev_tag_lists, test_word_lists, test_tag_lists, word2id, tag2id)


    def predict(self, text, use_pretrained_w2v, model_type):
        word2id = load_pickle_obj(self.word2id_path)
        tag2id = load_pickle_obj(self.tag2id_path)
        vocab_size = len(word2id)
        out_size = len(tag2id)
        ner_model = NerModel(vocab_size, out_size, use_pretrained_w2v=use_pretrained_w2v, model_type=model_type)
        result = ner_model.predict(text)
        return result

    def get_ner_list_dic(self, text, use_pretrained_w2v, model_type):

        text_list = list(text)
        tag_list = self.predict(text, use_pretrained_w2v, model_type)[0]
        tag_dic = find_all_tag(tag_list)

        print("tag_dic: ", tag_dic)

        result_dic = {}
        for name in tag_dic:
            for x in tag_dic[name]:
                if result_dic.get(name) is None:
                    result_dic[name] = []
                if x:
                    ner_name = ''.join(text_list[x[0]:x[0]+x[1]])
                    result_dic[name].append(ner_name)
        for name in result_dic:
            result_dic[name] = list(set(result_dic[name]))

        return result_dic
--------------------------------------------------------------------------------
/src/model/bert_bilstm_crf.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F
from itertools import zip_longest
from transformers import BertConfig, BertModel
from .path import get_chinese_wwm_ext_pytorch_path

class BertBiLstmCrf(nn.Module):
    def __init__(self, vocab_size, emb_size, hidden_size, out_size, drop_out=0.1, use_pretrained_w2v=False):
        super(BertBiLstmCrf, self).__init__()
        self.bert_path = get_chinese_wwm_ext_pytorch_path()
        self.bert_config = BertConfig.from_pretrained(self.bert_path)
        self.bert = BertModel.from_pretrained(self.bert_path)
        emb_size = 768
        for param in self.bert.parameters():
            param.requires_grad = True
        self.bilstm = nn.LSTM(emb_size, hidden_size, batch_first=True, bidirectional=True)
        self.fc = nn.Linear(hidden_size*2, out_size)
        self.dropout = nn.Dropout(drop_out)
        self.transition = nn.Parameter(torch.ones(out_size, out_size) * 1 / out_size)
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    def forward(self, x, lengths):
        emb = self.bert(x)[0]
        emb = nn.utils.rnn.pack_padded_sequence(emb, lengths, batch_first=True)
        emb, _ = self.bilstm(emb)
        output, _ = nn.utils.rnn.pad_packed_sequence(emb, batch_first=True, padding_value=0., total_length=x.shape[1])
        output = self.dropout(output)
        emission = self.fc(output)
        batch_size, max_len, out_size = emission.size()
        crf_scores = emission.unsqueeze(2).expand(-1, -1, out_size, -1) + self.transition.unsqueeze(0)
        return crf_scores

    def predict(self, test_sents_tensor, lengths, tag2id):
        start_id = tag2id['<start>']
        end_id = tag2id['<end>']
        pad = tag2id['<pad>']
        tagset_size = len(tag2id)

        crf_scores = self.forward(test_sents_tensor, lengths)

        B, L, T, _ = crf_scores.size()

        viterbi = torch.zeros(B, L, T).to(self.device)
        backpointer = (torch.zeros(B, L, T).long() * end_id).to(self.device)

        lengths = torch.LongTensor(lengths).to(self.device)

        for step in range(L):
            batch_size_t = (lengths > step).sum().item()
            if step == 0:
                viterbi[:batch_size_t, step, :] = crf_scores[:batch_size_t, step, start_id, :]
                backpointer[:batch_size_t, step, :] = start_id
            else:
                max_scores, prev_tags = torch.max(viterbi[:batch_size_t, step-1, :].unsqueeze(2) + crf_scores[:batch_size_t, step, :, :], dim=1)
                viterbi[:batch_size_t, step, :] = max_scores
                backpointer[:batch_size_t, step, :] = prev_tags

        backpointer = backpointer.view(B, -1)
        tagids = []
        tags_t = None
        for step in range(L-1, 0, -1):
            batch_size_t = (lengths > step).sum().item()
            if step == L-1:
                index = torch.ones(batch_size_t).long() * (step * tagset_size)
                index = index.to(self.device)
                index += end_id
            else:
                prev_batch_size_t = len(tags_t)
                new_in_batch = torch.LongTensor([end_id] * (batch_size_t - prev_batch_size_t)).to(self.device)
                offset = torch.cat([tags_t, new_in_batch], dim=0)
                index = torch.ones(batch_size_t).long() * (step * tagset_size)
                index = index.to(self.device)
                index += offset.long()

            tags_t = backpointer[:batch_size_t].gather(dim=1, index=index.unsqueeze(1).long())
            tags_t = tags_t.squeeze(1)
            tagids.append(tags_t.tolist())
        tagids = list(zip_longest(*reversed(tagids), fillvalue=pad))
        tagids = torch.Tensor(tagids).long()

        return tagids


def cal_bert_bilstm_crf_loss(crf_scores, targets, tag2id):
    pad_id = tag2id.get('<pad>')
    start_id = tag2id.get('<start>')
    end_id = tag2id.get('<end>')
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    batch_size, max_len = targets.size()
    target_size = len(tag2id)
    mask = (targets != pad_id)
    lengths = mask.sum(dim=1)
    targets = indexed(targets, target_size, start_id)
    targets = targets.masked_select(mask)
    flatten_scores = crf_scores.masked_select(
        mask.view(batch_size, max_len, 1, 1).expand_as(crf_scores)
    ).view(-1, target_size*target_size).contiguous()
    golden_scores = flatten_scores.gather(
        dim=1, index=targets.unsqueeze(1)).sum()
    scores_upto_t = torch.zeros(batch_size, target_size).to(device)
    for t in range(max_len):
        batch_size_t = (lengths > t).sum().item()
        if t == 0:
            scores_upto_t[:batch_size_t] = crf_scores[:batch_size_t,
                                                      t, start_id, :]
        else:
            scores_upto_t[:batch_size_t] = torch.logsumexp(
                crf_scores[:batch_size_t, t, :, :] +
                scores_upto_t[:batch_size_t].unsqueeze(2),
                dim=1
            )
    all_path_scores = scores_upto_t[:, end_id].sum()
    loss = (all_path_scores - golden_scores) / batch_size
    return loss

def indexed(targets, tagset_size, start_id):
    batch_size, max_len = targets.size()
    for col in range(max_len-1, 0, -1):
        targets[:, col] += (targets[:, col-1] * tagset_size)
    targets[:, 0] += (start_id * tagset_size)
    return targets
--------------------------------------------------------------------------------
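The indexed() helper above is the least obvious part of the CRF loss: for tagset_size T it rewrites each gold tag id as prev_tag * T + cur_tag, the flattened index of the (previous tag, current tag) cell that the loss then gathers from the T x T transition-augmented scores. A small, hypothetical sketch (not part of the repository), assuming it is run from the src/ directory:

import torch
from model.bert_bilstm_crf import indexed  # the copy in bilstm_crf.py is identical

targets = torch.tensor([[0, 1, 3]])        # gold tag ids for one sentence
print(indexed(targets.clone(), tagset_size=4, start_id=2))
# tensor([[8, 1, 7]])  i.e. (start=2 -> 0), (0 -> 1), (1 -> 3)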
/src/train/train_helper.py:
--------------------------------------------------------------------------------
import json
import os
import torch

import torch.nn as nn
import torch.nn.functional as F
from copy import deepcopy
from tqdm import tqdm, trange
from torch.optim.lr_scheduler import ExponentialLR
from model.bert_bilstm_crf import BertBiLstmCrf, cal_bert_bilstm_crf_loss
from model.bilstm import BiLSTM, cal_bilstm_loss
from model.bilstm_crf import BiLSTM_CRF, cal_bilstm_crf_loss
from tools.get_ner_level_acc import precision
from tools.help import load_pickle_obj, sort_by_lengths, batch_sents_to_tensorized
from model.path import get_model_dir, get_tag2id_path, get_word2id_path

torch.manual_seed(1234)

class NerModel(object):
    def __init__(self, vocab_size, out_size, use_pretrained_w2v=False, model_type="bilstm-crf"):
        self.model_dir = get_model_dir()
        self.model_type_dir = os.path.join(self.model_dir, model_type)
        if not os.path.exists(self.model_type_dir):
            os.makedirs(self.model_type_dir)
        self.model_type = model_type
        self.vocab_size = vocab_size
        self.out_size = out_size
        self.batch_size = 64
        self.lr = 0.01
        self.epoches = 20
        self.print_step = 20
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print(f"Using device: {self.device} ...")
        self.emb_size = 768
        self.hidden_size = 256
        self.dropout = 0.5
        self.use_pretrained_w2v = use_pretrained_w2v
        if self.model_type == "bilstm-crf":
            self.model = BiLSTM_CRF(self.vocab_size, self.emb_size, self.hidden_size, self.out_size, self.dropout, self.use_pretrained_w2v)
            self.loss_cal_fun = cal_bilstm_crf_loss
        elif self.model_type == "bilstm":
            self.model = BiLSTM(self.vocab_size, self.emb_size, self.hidden_size, self.out_size, self.dropout, self.use_pretrained_w2v)
            self.loss_cal_fun = cal_bilstm_loss
        elif self.model_type == "bert-bilstm-crf":
            self.model = BertBiLstmCrf(self.vocab_size, self.emb_size, self.hidden_size, self.out_size, self.dropout, self.use_pretrained_w2v)
            self.loss_cal_fun = cal_bert_bilstm_crf_loss

        self.model.to(self.device)

        self.optimizer = torch.optim.Adam(self.model.parameters(), self.lr, weight_decay=0.005)
        self.scheduler = ExponentialLR(self.optimizer, gamma=0.8)

        self.step = 0
        self.best_val_loss = 1e18
        if self.use_pretrained_w2v:
            model_name = f"{self.model_type}-pretrained.pt"
        else:
            model_name = f"{self.model_type}.pt"
        self.model_save_path = os.path.join(self.model_type_dir, model_name)

    def train(self, train_word_lists, train_tag_lists, dev_word_lists, dev_tag_lists, test_word_lists, test_tag_lists, word2id, tag2id):
        train_word_lists, train_tag_lists, _ = sort_by_lengths(train_word_lists, train_tag_lists)
        dev_word_lists, dev_tag_lists, _ = sort_by_lengths(dev_word_lists, dev_tag_lists)

        total_step = (len(train_word_lists)//self.batch_size + 1)

        epoch_iterator = trange(1, self.epoches + 1, desc="Epoch")
        for epoch in epoch_iterator:
            self.step = 0
            loss_sum = 0.
            for idx in trange(0, len(train_word_lists), self.batch_size, desc="Iteration:"):
                batch_sents = train_word_lists[idx : idx + self.batch_size]
                batch_tags = train_tag_lists[idx : idx + self.batch_size]
                loss_sum += self.train_step(batch_sents, batch_tags, word2id, tag2id)
                if self.step == total_step:
                    print("\nEpoch {}, step/total_step: {}/{} {:.2f}% Loss:{:.4f}".format(
                        epoch, self.step, total_step,
                        100. * self.step / total_step,
                        loss_sum / self.print_step
                    ))
                    loss_sum = 0.
            self.validate(epoch, dev_word_lists, dev_tag_lists, word2id, tag2id)
            self.scheduler.step()
            if epoch > 10 and self.model_type != "bert-bilstm-crf":
                self.test(test_word_lists, test_tag_lists, word2id, tag2id)
            elif epoch > 15:
                self.test(test_word_lists, test_tag_lists, word2id, tag2id)


    def train_step(self, batch_sents, batch_tags, word2id, tag2id):
        self.model.train()
        self.step += 1
        batch_sents_tensor, sents_lengths = batch_sents_to_tensorized(batch_sents, word2id)
        labels_tensor, _ = batch_sents_to_tensorized(batch_tags, tag2id)

        batch_sents_tensor, labels_tensor = batch_sents_tensor.to(self.device), labels_tensor.to(self.device)
        scores = self.model(batch_sents_tensor, sents_lengths)

        self.model.zero_grad()
        loss = self.loss_cal_fun(scores, labels_tensor, tag2id)
        loss.backward()
        self.optimizer.step()

        return loss.item()


    def validate(self, epoch, dev_word_lists, dev_tag_lists, word2id, tag2id):
        self.model.eval()
        with torch.no_grad():
            val_loss = 0.
            val_step = 0
            for idx in range(0, len(dev_word_lists), self.batch_size):

                val_step += 1
                batch_sents = dev_word_lists[idx : idx + self.batch_size]
                batch_tags = dev_tag_lists[idx : idx + self.batch_size]
                batch_sents_tensor, sents_lengths = batch_sents_to_tensorized(batch_sents, word2id)
                labels_tensor, _ = batch_sents_to_tensorized(batch_tags, tag2id)
                batch_sents_tensor, labels_tensor = batch_sents_tensor.to(self.device), labels_tensor.to(self.device)
                scores = self.model(batch_sents_tensor, sents_lengths)

                loss = self.loss_cal_fun(scores, labels_tensor, tag2id).item()

                val_loss += loss

            print(f"------epoch: {epoch}, val loss: {val_loss}")

            if val_loss < self.best_val_loss:
                self.best_val_loss = val_loss
                self.best_model = deepcopy(self.model)
                print(f"Saving model, path: {self.model_save_path}")
                torch.save(self.best_model.state_dict(), self.model_save_path)
                print(f"current best val loss: {self.best_val_loss}")

    def test(self, test_word_lists, test_tag_lists, word2id, tag2id):
        test_word_lists, test_tag_lists, indices = sort_by_lengths(test_word_lists, test_tag_lists)
        batch_sents_tensor, sents_lengths = batch_sents_to_tensorized(test_word_lists, word2id)
        batch_sents_tensor = batch_sents_tensor.to(self.device)
        self.best_model.eval()
        with torch.no_grad():
            batch_tagids = self.best_model.predict(batch_sents_tensor, sents_lengths, tag2id)
        pre_tag_lists = []
        id2tag = dict((id_, tag) for tag, id_ in tag2id.items())
        for i, ids in enumerate(batch_tagids):
            tag_list = []
            if self.model_type.find("bilstm-crf") != -1:
                for j in range(sents_lengths[i] - 1):
                    tag_list.append(id2tag[ids[j].item()])
            else:
                for j in range(sents_lengths[i]):
                    tag_list.append(id2tag[ids[j].item()])
            pre_tag_lists.append(tag_list)
        ind_maps = sorted(list(enumerate(indices)), key=lambda e: e[1])
        indices, _ = list(zip(*ind_maps))
        pre_tag_lists = [pre_tag_lists[i] for i in indices]
        tag_lists = [test_tag_lists[i] for i in indices]

        total_precision, result_dic = precision(pre_tag_lists, tag_lists)
        print(f"Entity-level precision: {total_precision}")
        print(f"Precision for each entity type: {json.dumps(result_dic, ensure_ascii=False, indent=4)}")


    def predict(self, text):
        text_list = list(text)
        if self.model_type.find("bilstm-crf") != -1:
            text_list.append("<end>")
        text_list = [text_list]
        word2id_path = get_word2id_path()
        tag2id_path = get_tag2id_path()
        word2id = load_pickle_obj(word2id_path)
        tag2id = load_pickle_obj(tag2id_path)

        tensorized_sents, lengths = batch_sents_to_tensorized(text_list, word2id)
        tensorized_sents = tensorized_sents.to(self.device)
        model_path = self.model_save_path
        self.model.load_state_dict(torch.load(model_path))
        self.model.eval()
        with torch.no_grad():
            batch_tagids = self.model.predict(tensorized_sents, lengths, tag2id)
        pre_tag_lists = []
        id2tag = dict((id_, tag) for tag, id_ in tag2id.items())
        for i, ids in enumerate(batch_tagids):
            tag_list = []
            if self.model_type.find("bilstm-crf") != -1:
                for j in range(lengths[i] - 1):
                    tag_list.append(id2tag[ids[j].item()])
            else:
                for j in range(lengths[i]):
                    tag_list.append(id2tag[ids[j].item()])
            pre_tag_lists.append(tag_list)
        return pre_tag_lists
--------------------------------------------------------------------------------
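Finally, a small, hypothetical sketch (not part of the repository) of the entity-level metric used in test(): precision() compares every predicted entity span against the gold tags and returns an overall score together with a per-entity-type breakdown. It assumes the working directory is src/; the toy tag sequences are made up.

from tools.get_ner_level_acc import precision

pred = [["B-NAME", "E-NAME", "O", "B-ORG", "M-ORG", "E-ORG"]]
gold = [["B-NAME", "E-NAME", "O", "B-ORG", "M-ORG", "E-ORG"]]

total, per_entity = precision(pred, gold)
print(total)       # 1.0, every predicted span matches the gold tags
print(per_entity)  # {'NAME': 1.0, 'ORG': 1.0}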