├── .gitignore ├── README.md ├── ResumeNER ├── dev.char.bmes ├── test.char.bmes └── train.char.bmes ├── ckpts ├── bilstm.pkl ├── bilstm_crf.pkl ├── crf.pkl └── hmm.pkl ├── data.py ├── evaluate.py ├── evaluating.py ├── imgs ├── biLSTM_NER.png ├── decode_crf.png ├── func_set.png ├── log_likehood_crf.png ├── log_linear_crf.png └── w_crf.png ├── main.py ├── models ├── __init__.py ├── bilstm.py ├── bilstm_crf.py ├── config.py ├── crf.py ├── hmm.py └── util.py ├── output.txt ├── requirement.txt ├── test.py └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__ 2 | models/__pycache__ 3 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # 中文命名实体识别 2 | 3 | 4 | 5 | ## 数据集 6 | 7 | 本项目尝试使用了多种不同的模型(包括HMM,CRF,Bi-LSTM,Bi-LSTM+CRF)来解决中文命名实体识别问题,数据集用的是论文ACL 2018[Chinese NER using Lattice LSTM](https://github.com/jiesutd/LatticeLSTM)中收集的简历数据,数据的格式如下,它的每一行由一个字及其对应的标注组成,标注集采用BIOES,句子之间用一个空行隔开。 8 | 9 | ``` 10 | 美 B-LOC 11 | 国 E-LOC 12 | 的 O 13 | 华 B-PER 14 | 莱 I-PER 15 | 士 E-PER 16 | 17 | 我 O 18 | 跟 O 19 | 他 O 20 | 谈 O 21 | 笑 O 22 | 风 O 23 | 生 O 24 | ``` 25 | 26 | 该数据集就位于项目目录下的`ResumeNER`文件夹里。 27 | 28 | ## 运行结果 29 | 30 | 下面是四种不同的模型以及这Ensemble这四个模型预测结果的准确率(取最好): 31 | 32 | | | HMM | CRF | BiLSTM | BiLSTM+CRF | Ensemble | 33 | | ---- | ------ | ------ | ------ | ---------- | -------- | 34 | | 召回率 | 91.22% | 95.43% | 95.32% | 95.72% | 95.65% | 35 | | 准确率 | 91.49% | 95.43% | 95.37% | 95.74% | 95.69% | 36 | | F1分数 | 91.30% | 95.42% | 95.32% | 95.70% | 95.64% | 37 | 38 | 最后一列Ensemble是将这四个模型的预测结果结合起来,使用“投票表决”的方法得出最后的预测结果。 39 | 40 | (Ensemble的三个指标均不如BiLSTM+CRF,可以认为在Ensemble过程中,是其他三个模型拖累了BiLSTM+CRF) 41 | 42 | 具体的输出可以查看`output.txt`文件。 43 | 44 | 45 | 46 | ## 快速开始 47 | 48 | 首先安装依赖项: 49 | 50 | ``` 51 | pip3 install -r requirement.txt 52 | ``` 53 | 54 | 安装完毕之后,直接使用 55 | 56 | ``` 57 | python3 main.py 58 | ``` 59 | 60 | 即可训练以及评估模型,评估模型将会打印出模型的精确率、召回率、F1分数值以及混淆矩阵,如果想要修改相关模型参数或者是训练参数,可以在`./models/config.py`文件中进行设置。 61 | 62 | 训练完毕之后,如果想要加载并评估模型,运行如下命令: 63 | 64 | ```shell 65 | python3 test.py 66 | ``` 67 | 68 | 下面是这些模型的简单介绍(github网页对数学公式的支持不太好,涉及公式的部分无法正常显示,[我的博客](https://zhuanlan.zhihu.com/p/61227299) 有对这些模型以及代码更加详细的介绍): 69 | 70 | 71 | 72 | ## 隐马尔可夫模型(Hidden Markov Model,HMM) 73 | 74 | 隐马尔可夫模型描述由一个隐藏的马尔科夫链随机生成不可观测的状态随机序列,再由各个状态生成一个观测而产生观测随机序列的过程(李航 统计学习方法)。隐马尔可夫模型由初始状态分布,状态转移概率矩阵以及观测概率矩阵所确定。 75 | 76 | 命名实体识别本质上可以看成是一种序列标注问题,在使用HMM解决命名实体识别这种序列标注问题的时候,我们所能观测到的是字组成的序列(观测序列),观测不到的是每个字对应的标注(状态序列)。 77 | 78 | **初始状态分布**就是每一个标注的初始化概率,**状态转移概率矩阵**就是由某一个标注转移到下一个标注的概率(就是若前一个词的标注为$tag_i$ ,则下一个词的标注为$tag_j$的概率为 $M_{ij}$),**观测概率矩阵**就是指在 79 | 80 | 某个标注下,生成某个词的概率。 81 | 82 | HMM模型的训练过程对应隐马尔可夫模型的学习问题(李航 统计学习方法), 83 | 84 | 实际上就是根据训练数据根据最大似然的方法估计模型的三个要素,即上文提到的初始状态分布、状态转移概率矩阵以及观测概率矩阵,模型训练完毕之后,利用模型进行解码,即对给定观测序列,求它对应的状态序列,这里就是对给定的句子,求句子中的每个字对应的标注,针对这个解码问题,我们使用的是维特比(viterbi)算法。 85 | 86 | 具体的细节可以查看 `models/hmm.py`文件。 87 | 88 | 89 | 90 | 91 | 92 | ## 条件随机场(Conditional Random Field, CRF) 93 | 94 | HMM模型中存在两个假设,一是输出观察值之间严格独立,二是状态转移过程中当前状态只与前一状态有关。也就是说,在命名实体识别的场景下,HMM认为观测到的句子中的每个字都是相互独立的,而且当前时刻的标注只与前一时刻的标注相关。但实际上,命名实体识别往往需要更多的特征,比如词性,词的上下文等等,同时当前时刻的标注应该与前一时刻以及后一时刻的标注都相关联。由于这两个假设的存在,显然HMM模型在解决命名实体识别的问题上是存在缺陷的。 95 | 96 | 条件随机场通过引入自定义的特征函数,不仅可以表达观测之间的依赖,还可表示当前观测与前后多个状态之间的复杂依赖,可以有效克服HMM模型面临的问题。 97 | 98 | 为了建立一个条件随机场,我们首先要定义一个特征函数集,该函数集内的每个特征函数都以标注序列作为输入,提取的特征作为输出。假设该函数集为: 99 | 100 | ![函数集](./imgs/func_set.png) 101 | 102 | 其中$x=(x_1, ..., x_m)$表示观测序列,$s = (s_1, ...., 
s_m)$表示状态序列。然后,条件随机场使用对数线性模型来计算给定观测序列下状态序列的条件概率: 103 | 104 | ![log_linear_crf](./imgs/log_linear_crf.png) 105 | 106 | 其中$s^{'}$是是所有可能的状态序列,$w$是条件随机场模型的参数,可以把它看成是每个特征函数的权重。CRF模型的训练其实就是对参数$w$的估计。假设我们有$n$个已经标注好的数据$\{(x^i, s^i)\}_{i=1}^n$, 107 | 108 | 则其对数似然函数的正则化形式如下: 109 | 110 | ![log_likehood_crf](./imgs/log_likehood_crf.png) 111 | 112 | 那么,最优参数$w^*$就是: 113 | 114 | ![w_crf](./imgs/w_crf.png) 115 | 116 | 模型训练结束之后,对给定的观测序列$x$,它对应的最优状态序列应该是: 117 | 118 | ![decode_crf](./imgs/decode_crf.png) 119 | 120 | 解码的时候与HMM类似,也可以采用维特比算法。 121 | 122 | 具体的细节可以查看 `models/crf.py`文件。 123 | 124 | 125 | 126 | 127 | 128 | ## Bi-LSTM 129 | 130 | 除了以上两种基于概率图模型的方法,LSTM也常常被用来解决序列标注问题。和HMM、CRF不同的是,LSTM是依靠神经网络超强的非线性拟合能力,在训练时将样本通过高维空间中的复杂非线性变换,学习到从样本到标注的函数,之后使用这个函数为指定的样本预测每个token的标注。下方就是使用双向LSTM(双向能够更好的捕捉序列之间的依赖关系)进行序列标注的示意图: 131 | 132 | 133 | 134 | ![biLSTM_NER](./imgs/biLSTM_NER.png) 135 | 136 | 137 | 138 | 基于双向LSTM的序列标注模型实现可以查看`models/bilstm.py`文件。 139 | 140 | 141 | 142 | ## Bi-LSTM+CRF 143 | 144 | LSTM的优点是能够通过双向的设置学习到观测序列(输入的字)之间的依赖,在训练过程中,LSTM能够根据目标(比如识别实体)自动提取观测序列的特征,但是缺点是无法学习到状态序列(输出的标注)之间的关系,要知道,在命名实体识别任务中,标注之间是有一定的关系的,比如B类标注(表示某实体的开头)后面不会再接一个B类标注,所以LSTM在解决NER这类序列标注任务时,虽然可以省去很繁杂的特征工程,但是也存在无法学习到标注上下文的缺点。 145 | 146 | 相反,CRF的优点就是能对隐含状态建模,学习状态序列的特点,但它的缺点是需要手动提取序列特征。所以一般的做法是,在LSTM后面再加一层CRF,以获得两者的优点。 147 | 148 | 具体的实现请查看`models/bilstm_crf.py` 149 | 150 | 151 | 152 | ## 代码中一些需要注意的点 153 | 154 | * HMM模型中要处理OOV(Out of vocabulary)的问题,就是测试集里面有些字是不在训练集里面的, 155 | 这个时候通过观测概率矩阵是无法查询到OOV对应的各种状态的概率的,处理这个问题可以将OOV对应的状态的概率分布设为均匀分布。 156 | * HMM的三个参数(即状态转移概率矩阵、观测概率矩阵以及初始状态概率矩阵)在使用监督学习方法进行估计的过程中,如果有些项从未出现,那么该项对应的位置就为0,而在使用维特比算法进行解码的时候,计算过程需要将这些值相乘,那么如果其中有为0的项,那么整条路径的概率也变成0了。此外,解码过程中多个小概率相乘很可能出现下溢的情况,为了解决这两个问题,我们给那些从未出现过的项赋予一个很小的数(如0.00000001),同时在进行解码的时候将模型的三个参数都映射到对数空间,这样既可以避免下溢,又可以简化乘法运算。 157 | * CRF中将训练数据以及测试数据作为模型的输入之前,都需要先用特征函数提取特征! 
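下面是一个简化的特征提取示意(仅作说明用的草稿,思路与 `models/util.py` 中的 `word2features`/`sent2features` 一致,句首、句尾占位符 `<s>`、`</s>` 为这里的假设写法,细节以源码为准):

```python
# 示意:对句子中第 i 个字抽取上下文特征
# 使用的特征:前一个字、当前字、后一个字,以及相邻二元组
def word2features(sent, i):
    word = sent[i]
    prev_word = "<s>" if i == 0 else sent[i - 1]                  # 句首用占位符(假设)
    next_word = "</s>" if i == (len(sent) - 1) else sent[i + 1]   # 句尾用占位符(假设)
    return {
        'w': word,                    # 当前字
        'w-1': prev_word,             # 前一个字
        'w+1': next_word,             # 后一个字
        'w-1:w': prev_word + word,    # 前一个字 + 当前字
        'w:w+1': word + next_word,    # 当前字 + 后一个字
        'bias': 1
    }

def sent2features(sent):
    """把整个句子(字的列表)转化为特征序列"""
    return [word2features(sent, i) for i in range(len(sent))]

# 训练和预测之前都要先做这一步转换,例如:
# train_features = [sent2features(s) for s in train_word_lists]
# crf_model.fit(train_features, train_tag_lists)
```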
158 | * Bi-LSTM+CRF模型可以参考:[Neural Architectures for Named Entity Recognition](https://arxiv.org/pdf/1603.01360.pdf),可以重点看一下里面的损失函数的定义。代码里面关于损失函数的计算采用的是类似动态规划的方法,不是很好理解,这里推荐看一下以下这些博客: 159 | 160 | * [CRF Layer on the Top of BiLSTM - 5](https://createmomo.github.io/2017/11/11/CRF-Layer-on-the-Top-of-BiLSTM-5/) 161 | * [Bi-LSTM-CRF for Sequence Labeling PENG](https://zhuanlan.zhihu.com/p/27338210) 162 | * [Pytorch Bi-LSTM + CRF 代码详解](https://blog.csdn.net/cuihuijun1hao/article/details/79405740) 163 | 164 | 165 | 166 | 167 | ## TODO 168 | 169 | * BI-LSTM+CRF 比起Bi-LSTM效果并没有好很多,一种可能的解释是: 170 | - 数据集太小,不足够让模型学习到转移矩阵(后续尝试在更大的数据集上测试一下结果) 171 | * 尝试更加复杂的模型,参考论文[Chinese NER using Lattice LSTM](https://github.com/jiesutd/LatticeLSTM) 172 | * 更详细的评估结果:打印混淆矩阵,同时输出每种类别的召回率、准确率、F1指标,便于分析。 173 | 174 | 175 | 176 | 177 | 178 | 179 | 180 | 181 | 182 | 183 | 184 | 185 | 186 | 187 | 188 | 189 | 190 | 191 | -------------------------------------------------------------------------------- /ckpts/bilstm.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/luopeixiang/named_entity_recognition/8073771816613d84274e7528e07026f620041558/ckpts/bilstm.pkl -------------------------------------------------------------------------------- /ckpts/bilstm_crf.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/luopeixiang/named_entity_recognition/8073771816613d84274e7528e07026f620041558/ckpts/bilstm_crf.pkl -------------------------------------------------------------------------------- /ckpts/crf.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/luopeixiang/named_entity_recognition/8073771816613d84274e7528e07026f620041558/ckpts/crf.pkl -------------------------------------------------------------------------------- /ckpts/hmm.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/luopeixiang/named_entity_recognition/8073771816613d84274e7528e07026f620041558/ckpts/hmm.pkl -------------------------------------------------------------------------------- /data.py: -------------------------------------------------------------------------------- 1 | from os.path import join 2 | from codecs import open 3 | 4 | 5 | def build_corpus(split, make_vocab=True, data_dir="./ResumeNER"): 6 | """读取数据""" 7 | assert split in ['train', 'dev', 'test'] 8 | 9 | word_lists = [] 10 | tag_lists = [] 11 | with open(join(data_dir, split+".char.bmes"), 'r', encoding='utf-8') as f: 12 | word_list = [] 13 | tag_list = [] 14 | for line in f: 15 | if line != '\n': 16 | word, tag = line.strip('\n').split() 17 | word_list.append(word) 18 | tag_list.append(tag) 19 | else: 20 | word_lists.append(word_list) 21 | tag_lists.append(tag_list) 22 | word_list = [] 23 | tag_list = [] 24 | 25 | # 如果make_vocab为True,还需要返回word2id和tag2id 26 | if make_vocab: 27 | word2id = build_map(word_lists) 28 | tag2id = build_map(tag_lists) 29 | return word_lists, tag_lists, word2id, tag2id 30 | else: 31 | return word_lists, tag_lists 32 | 33 | 34 | def build_map(lists): 35 | maps = {} 36 | for list_ in lists: 37 | for e in list_: 38 | if e not in maps: 39 | maps[e] = len(maps) 40 | 41 | return maps 42 | -------------------------------------------------------------------------------- /evaluate.py: -------------------------------------------------------------------------------- 1 | import time 2 | from collections import 
Counter 3 | 4 | from models.hmm import HMM 5 | from models.crf import CRFModel 6 | from models.bilstm_crf import BILSTM_Model 7 | from utils import save_model, flatten_lists 8 | from evaluating import Metrics 9 | 10 | 11 | def hmm_train_eval(train_data, test_data, word2id, tag2id, remove_O=False): 12 | """训练并评估hmm模型""" 13 | # 训练HMM模型 14 | train_word_lists, train_tag_lists = train_data 15 | test_word_lists, test_tag_lists = test_data 16 | 17 | hmm_model = HMM(len(tag2id), len(word2id)) 18 | hmm_model.train(train_word_lists, 19 | train_tag_lists, 20 | word2id, 21 | tag2id) 22 | save_model(hmm_model, "./ckpts/hmm.pkl") 23 | 24 | # 评估hmm模型 25 | pred_tag_lists = hmm_model.test(test_word_lists, 26 | word2id, 27 | tag2id) 28 | 29 | metrics = Metrics(test_tag_lists, pred_tag_lists, remove_O=remove_O) 30 | metrics.report_scores() 31 | metrics.report_confusion_matrix() 32 | 33 | return pred_tag_lists 34 | 35 | 36 | def crf_train_eval(train_data, test_data, remove_O=False): 37 | 38 | # 训练CRF模型 39 | train_word_lists, train_tag_lists = train_data 40 | test_word_lists, test_tag_lists = test_data 41 | 42 | crf_model = CRFModel() 43 | crf_model.train(train_word_lists, train_tag_lists) 44 | save_model(crf_model, "./ckpts/crf.pkl") 45 | 46 | pred_tag_lists = crf_model.test(test_word_lists) 47 | 48 | metrics = Metrics(test_tag_lists, pred_tag_lists, remove_O=remove_O) 49 | metrics.report_scores() 50 | metrics.report_confusion_matrix() 51 | 52 | return pred_tag_lists 53 | 54 | 55 | def bilstm_train_and_eval(train_data, dev_data, test_data, 56 | word2id, tag2id, crf=True, remove_O=False): 57 | train_word_lists, train_tag_lists = train_data 58 | dev_word_lists, dev_tag_lists = dev_data 59 | test_word_lists, test_tag_lists = test_data 60 | 61 | start = time.time() 62 | vocab_size = len(word2id) 63 | out_size = len(tag2id) 64 | bilstm_model = BILSTM_Model(vocab_size, out_size, crf=crf) 65 | bilstm_model.train(train_word_lists, train_tag_lists, 66 | dev_word_lists, dev_tag_lists, word2id, tag2id) 67 | 68 | model_name = "bilstm_crf" if crf else "bilstm" 69 | save_model(bilstm_model, "./ckpts/"+model_name+".pkl") 70 | 71 | print("训练完毕,共用时{}秒.".format(int(time.time()-start))) 72 | print("评估{}模型中...".format(model_name)) 73 | pred_tag_lists, test_tag_lists = bilstm_model.test( 74 | test_word_lists, test_tag_lists, word2id, tag2id) 75 | 76 | metrics = Metrics(test_tag_lists, pred_tag_lists, remove_O=remove_O) 77 | metrics.report_scores() 78 | metrics.report_confusion_matrix() 79 | 80 | return pred_tag_lists 81 | 82 | 83 | def ensemble_evaluate(results, targets, remove_O=False): 84 | """ensemble多个模型""" 85 | for i in range(len(results)): 86 | results[i] = flatten_lists(results[i]) 87 | 88 | pred_tags = [] 89 | for result in zip(*results): 90 | ensemble_tag = Counter(result).most_common(1)[0][0] 91 | pred_tags.append(ensemble_tag) 92 | 93 | targets = flatten_lists(targets) 94 | assert len(pred_tags) == len(targets) 95 | 96 | print("Ensemble 四个模型的结果如下:") 97 | metrics = Metrics(targets, pred_tags, remove_O=remove_O) 98 | metrics.report_scores() 99 | metrics.report_confusion_matrix() 100 | -------------------------------------------------------------------------------- /evaluating.py: -------------------------------------------------------------------------------- 1 | from collections import Counter 2 | 3 | from utils import flatten_lists 4 | 5 | 6 | class Metrics(object): 7 | """用于评价模型,计算每个标签的精确率,召回率,F1分数""" 8 | 9 | def __init__(self, golden_tags, predict_tags, remove_O=False): 10 | 11 | # [[t1, t2], [t3, t4]...] 
--> [t1, t2, t3, t4...] 12 | self.golden_tags = flatten_lists(golden_tags) 13 | self.predict_tags = flatten_lists(predict_tags) 14 | 15 | if remove_O: # 将O标记移除,只关心实体标记 16 | self._remove_Otags() 17 | 18 | # 辅助计算的变量 19 | self.tagset = set(self.golden_tags) 20 | self.correct_tags_number = self.count_correct_tags() 21 | self.predict_tags_counter = Counter(self.predict_tags) 22 | self.golden_tags_counter = Counter(self.golden_tags) 23 | 24 | # 计算精确率 25 | self.precision_scores = self.cal_precision() 26 | 27 | # 计算召回率 28 | self.recall_scores = self.cal_recall() 29 | 30 | # 计算F1分数 31 | self.f1_scores = self.cal_f1() 32 | 33 | def cal_precision(self): 34 | 35 | precision_scores = {} 36 | for tag in self.tagset: 37 | precision_scores[tag] = self.correct_tags_number.get(tag, 0) / \ 38 | self.predict_tags_counter[tag] 39 | 40 | return precision_scores 41 | 42 | def cal_recall(self): 43 | 44 | recall_scores = {} 45 | for tag in self.tagset: 46 | recall_scores[tag] = self.correct_tags_number.get(tag, 0) / \ 47 | self.golden_tags_counter[tag] 48 | return recall_scores 49 | 50 | def cal_f1(self): 51 | f1_scores = {} 52 | for tag in self.tagset: 53 | p, r = self.precision_scores[tag], self.recall_scores[tag] 54 | f1_scores[tag] = 2*p*r / (p+r+1e-10) # 加上一个特别小的数,防止分母为0 55 | return f1_scores 56 | 57 | def report_scores(self): 58 | """将结果用表格的形式打印出来,像这个样子: 59 | 60 | precision recall f1-score support 61 | B-LOC 0.775 0.757 0.766 1084 62 | I-LOC 0.601 0.631 0.616 325 63 | B-MISC 0.698 0.499 0.582 339 64 | I-MISC 0.644 0.567 0.603 557 65 | B-ORG 0.795 0.801 0.798 1400 66 | I-ORG 0.831 0.773 0.801 1104 67 | B-PER 0.812 0.876 0.843 735 68 | I-PER 0.873 0.931 0.901 634 69 | 70 | avg/total 0.779 0.764 0.770 6178 71 | """ 72 | # 打印表头 73 | header_format = '{:>9s} {:>9} {:>9} {:>9} {:>9}' 74 | header = ['precision', 'recall', 'f1-score', 'support'] 75 | print(header_format.format('', *header)) 76 | 77 | row_format = '{:>9s} {:>9.4f} {:>9.4f} {:>9.4f} {:>9}' 78 | # 打印每个标签的 精确率、召回率、f1分数 79 | for tag in self.tagset: 80 | print(row_format.format( 81 | tag, 82 | self.precision_scores[tag], 83 | self.recall_scores[tag], 84 | self.f1_scores[tag], 85 | self.golden_tags_counter[tag] 86 | )) 87 | 88 | # 计算并打印平均值 89 | avg_metrics = self._cal_weighted_average() 90 | print(row_format.format( 91 | 'avg/total', 92 | avg_metrics['precision'], 93 | avg_metrics['recall'], 94 | avg_metrics['f1_score'], 95 | len(self.golden_tags) 96 | )) 97 | 98 | def count_correct_tags(self): 99 | """计算每种标签预测正确的个数(对应精确率、召回率计算公式上的tp),用于后面精确率以及召回率的计算""" 100 | correct_dict = {} 101 | for gold_tag, predict_tag in zip(self.golden_tags, self.predict_tags): 102 | if gold_tag == predict_tag: 103 | if gold_tag not in correct_dict: 104 | correct_dict[gold_tag] = 1 105 | else: 106 | correct_dict[gold_tag] += 1 107 | 108 | return correct_dict 109 | 110 | def _cal_weighted_average(self): 111 | 112 | weighted_average = {} 113 | total = len(self.golden_tags) 114 | 115 | # 计算weighted precisions: 116 | weighted_average['precision'] = 0. 117 | weighted_average['recall'] = 0. 118 | weighted_average['f1_score'] = 0. 
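# 加权平均:每个标签的精确率、召回率、F1 按其支持度(support,即该标签在 golden_tags 中出现的次数)加权累加,最后除以总标记数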
119 | for tag in self.tagset: 120 | size = self.golden_tags_counter[tag] 121 | weighted_average['precision'] += self.precision_scores[tag] * size 122 | weighted_average['recall'] += self.recall_scores[tag] * size 123 | weighted_average['f1_score'] += self.f1_scores[tag] * size 124 | 125 | for metric in weighted_average.keys(): 126 | weighted_average[metric] /= total 127 | 128 | return weighted_average 129 | 130 | def _remove_Otags(self): 131 | 132 | length = len(self.golden_tags) 133 | O_tag_indices = [i for i in range(length) 134 | if self.golden_tags[i] == 'O'] 135 | 136 | self.golden_tags = [tag for i, tag in enumerate(self.golden_tags) 137 | if i not in O_tag_indices] 138 | 139 | self.predict_tags = [tag for i, tag in enumerate(self.predict_tags) 140 | if i not in O_tag_indices] 141 | print("原总标记数为{},移除了{}个O标记,占比{:.2f}%".format( 142 | length, 143 | len(O_tag_indices), 144 | len(O_tag_indices) / length * 100 145 | )) 146 | 147 | def report_confusion_matrix(self): 148 | """计算混淆矩阵""" 149 | 150 | print("\nConfusion Matrix:") 151 | tag_list = list(self.tagset) 152 | # 初始化混淆矩阵 matrix[i][j]表示第i个tag被模型预测成第j个tag的次数 153 | tags_size = len(tag_list) 154 | matrix = [] 155 | for i in range(tags_size): 156 | matrix.append([0] * tags_size) 157 | 158 | # 遍历tags列表 159 | for golden_tag, predict_tag in zip(self.golden_tags, self.predict_tags): 160 | try: 161 | row = tag_list.index(golden_tag) 162 | col = tag_list.index(predict_tag) 163 | matrix[row][col] += 1 164 | except ValueError: # 有极少数标记没有出现在golden_tags,但出现在predict_tags,跳过这些标记 165 | continue 166 | 167 | # 输出矩阵 168 | row_format_ = '{:>7} ' * (tags_size+1) 169 | print(row_format_.format("", *tag_list)) 170 | for i, row in enumerate(matrix): 171 | print(row_format_.format(tag_list[i], *row)) 172 | -------------------------------------------------------------------------------- /imgs/biLSTM_NER.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/luopeixiang/named_entity_recognition/8073771816613d84274e7528e07026f620041558/imgs/biLSTM_NER.png -------------------------------------------------------------------------------- /imgs/decode_crf.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/luopeixiang/named_entity_recognition/8073771816613d84274e7528e07026f620041558/imgs/decode_crf.png -------------------------------------------------------------------------------- /imgs/func_set.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/luopeixiang/named_entity_recognition/8073771816613d84274e7528e07026f620041558/imgs/func_set.png -------------------------------------------------------------------------------- /imgs/log_likehood_crf.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/luopeixiang/named_entity_recognition/8073771816613d84274e7528e07026f620041558/imgs/log_likehood_crf.png -------------------------------------------------------------------------------- /imgs/log_linear_crf.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/luopeixiang/named_entity_recognition/8073771816613d84274e7528e07026f620041558/imgs/log_linear_crf.png -------------------------------------------------------------------------------- /imgs/w_crf.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/luopeixiang/named_entity_recognition/8073771816613d84274e7528e07026f620041558/imgs/w_crf.png -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | 2 | from data import build_corpus 3 | from utils import extend_maps, prepocess_data_for_lstmcrf 4 | from evaluate import hmm_train_eval, crf_train_eval, \ 5 | bilstm_train_and_eval, ensemble_evaluate 6 | 7 | 8 | def main(): 9 | """训练模型,评估结果""" 10 | 11 | # 读取数据 12 | print("读取数据...") 13 | train_word_lists, train_tag_lists, word2id, tag2id = \ 14 | build_corpus("train") 15 | dev_word_lists, dev_tag_lists = build_corpus("dev", make_vocab=False) 16 | test_word_lists, test_tag_lists = build_corpus("test", make_vocab=False) 17 | 18 | # 训练评估hmm模型 19 | print("正在训练评估HMM模型...") 20 | hmm_pred = hmm_train_eval( 21 | (train_word_lists, train_tag_lists), 22 | (test_word_lists, test_tag_lists), 23 | word2id, 24 | tag2id 25 | ) 26 | 27 | # 训练评估CRF模型 28 | print("正在训练评估CRF模型...") 29 | crf_pred = crf_train_eval( 30 | (train_word_lists, train_tag_lists), 31 | (test_word_lists, test_tag_lists) 32 | ) 33 | 34 | # 训练评估BI-LSTM模型 35 | print("正在训练评估双向LSTM模型...") 36 | # LSTM模型训练的时候需要在word2id和tag2id加入PAD和UNK 37 | bilstm_word2id, bilstm_tag2id = extend_maps(word2id, tag2id, for_crf=False) 38 | lstm_pred = bilstm_train_and_eval( 39 | (train_word_lists, train_tag_lists), 40 | (dev_word_lists, dev_tag_lists), 41 | (test_word_lists, test_tag_lists), 42 | bilstm_word2id, bilstm_tag2id, 43 | crf=False 44 | ) 45 | 46 | print("正在训练评估Bi-LSTM+CRF模型...") 47 | # 如果是加了CRF的lstm还要加入 (解码的时候需要用到) 48 | crf_word2id, crf_tag2id = extend_maps(word2id, tag2id, for_crf=True) 49 | # 还需要额外的一些数据处理 50 | train_word_lists, train_tag_lists = prepocess_data_for_lstmcrf( 51 | train_word_lists, train_tag_lists 52 | ) 53 | dev_word_lists, dev_tag_lists = prepocess_data_for_lstmcrf( 54 | dev_word_lists, dev_tag_lists 55 | ) 56 | test_word_lists, test_tag_lists = prepocess_data_for_lstmcrf( 57 | test_word_lists, test_tag_lists, test=True 58 | ) 59 | lstmcrf_pred = bilstm_train_and_eval( 60 | (train_word_lists, train_tag_lists), 61 | (dev_word_lists, dev_tag_lists), 62 | (test_word_lists, test_tag_lists), 63 | crf_word2id, crf_tag2id 64 | ) 65 | 66 | ensemble_evaluate( 67 | [hmm_pred, crf_pred, lstm_pred, lstmcrf_pred], 68 | test_tag_lists 69 | ) 70 | 71 | 72 | if __name__ == "__main__": 73 | main() 74 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/luopeixiang/named_entity_recognition/8073771816613d84274e7528e07026f620041558/models/__init__.py -------------------------------------------------------------------------------- /models/bilstm.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.nn.utils.rnn import pad_packed_sequence, pack_padded_sequence 4 | 5 | 6 | class BiLSTM(nn.Module): 7 | def __init__(self, vocab_size, emb_size, hidden_size, out_size): 8 | """初始化参数: 9 | vocab_size:字典的大小 10 | emb_size:词向量的维数 11 | hidden_size:隐向量的维数 12 | out_size:标注的种类 13 | """ 14 | super(BiLSTM, self).__init__() 15 | self.embedding = nn.Embedding(vocab_size, emb_size) 16 | self.bilstm = nn.LSTM(emb_size, hidden_size, 17 | batch_first=True, 18 | bidirectional=True) 19 | 20 | self.lin = nn.Linear(2*hidden_size, 
out_size) 21 | 22 | def forward(self, sents_tensor, lengths): 23 | emb = self.embedding(sents_tensor) # [B, L, emb_size] 24 | 25 | packed = pack_padded_sequence(emb, lengths, batch_first=True) 26 | rnn_out, _ = self.bilstm(packed) 27 | # rnn_out:[B, L, hidden_size*2] 28 | rnn_out, _ = pad_packed_sequence(rnn_out, batch_first=True) 29 | 30 | scores = self.lin(rnn_out) # [B, L, out_size] 31 | 32 | return scores 33 | 34 | def test(self, sents_tensor, lengths, _): 35 | """第三个参数不会用到,加它是为了与BiLSTM_CRF保持同样的接口""" 36 | logits = self.forward(sents_tensor, lengths) # [B, L, out_size] 37 | _, batch_tagids = torch.max(logits, dim=2) 38 | 39 | return batch_tagids 40 | -------------------------------------------------------------------------------- /models/bilstm_crf.py: -------------------------------------------------------------------------------- 1 | from itertools import zip_longest 2 | from copy import deepcopy 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.optim as optim 7 | 8 | from .util import tensorized, sort_by_lengths, cal_loss, cal_lstm_crf_loss 9 | from .config import TrainingConfig, LSTMConfig 10 | from .bilstm import BiLSTM 11 | 12 | 13 | class BILSTM_Model(object): 14 | def __init__(self, vocab_size, out_size, crf=True): 15 | """功能:对LSTM的模型进行训练与测试 16 | 参数: 17 | vocab_size:词典大小 18 | out_size:标注种类 19 | crf选择是否添加CRF层""" 20 | self.device = torch.device( 21 | "cuda" if torch.cuda.is_available() else "cpu") 22 | 23 | # 加载模型参数 24 | self.emb_size = LSTMConfig.emb_size 25 | self.hidden_size = LSTMConfig.hidden_size 26 | 27 | self.crf = crf 28 | # 根据是否添加crf初始化不同的模型 选择不一样的损失计算函数 29 | if not crf: 30 | self.model = BiLSTM(vocab_size, self.emb_size, 31 | self.hidden_size, out_size).to(self.device) 32 | self.cal_loss_func = cal_loss 33 | else: 34 | self.model = BiLSTM_CRF(vocab_size, self.emb_size, 35 | self.hidden_size, out_size).to(self.device) 36 | self.cal_loss_func = cal_lstm_crf_loss 37 | 38 | # 加载训练参数: 39 | self.epoches = TrainingConfig.epoches 40 | self.print_step = TrainingConfig.print_step 41 | self.lr = TrainingConfig.lr 42 | self.batch_size = TrainingConfig.batch_size 43 | 44 | # 初始化优化器 45 | self.optimizer = optim.Adam(self.model.parameters(), lr=self.lr) 46 | 47 | # 初始化其他指标 48 | self.step = 0 49 | self._best_val_loss = 1e18 50 | self.best_model = None 51 | 52 | def train(self, word_lists, tag_lists, 53 | dev_word_lists, dev_tag_lists, 54 | word2id, tag2id): 55 | # 对数据集按照长度进行排序 56 | word_lists, tag_lists, _ = sort_by_lengths(word_lists, tag_lists) 57 | dev_word_lists, dev_tag_lists, _ = sort_by_lengths( 58 | dev_word_lists, dev_tag_lists) 59 | 60 | B = self.batch_size 61 | for e in range(1, self.epoches+1): 62 | self.step = 0 63 | losses = 0. 64 | for ind in range(0, len(word_lists), B): 65 | batch_sents = word_lists[ind:ind+B] 66 | batch_tags = tag_lists[ind:ind+B] 67 | 68 | losses += self.train_step(batch_sents, 69 | batch_tags, word2id, tag2id) 70 | 71 | if self.step % TrainingConfig.print_step == 0: 72 | total_step = (len(word_lists) // B + 1) 73 | print("Epoch {}, step/total_step: {}/{} {:.2f}% Loss:{:.4f}".format( 74 | e, self.step, total_step, 75 | 100. * self.step / total_step, 76 | losses / self.print_step 77 | )) 78 | losses = 0. 
79 | 80 | # 每轮结束测试在验证集上的性能,保存最好的一个 81 | val_loss = self.validate( 82 | dev_word_lists, dev_tag_lists, word2id, tag2id) 83 | print("Epoch {}, Val Loss:{:.4f}".format(e, val_loss)) 84 | 85 | def train_step(self, batch_sents, batch_tags, word2id, tag2id): 86 | self.model.train() 87 | self.step += 1 88 | # 准备数据 89 | tensorized_sents, lengths = tensorized(batch_sents, word2id) 90 | tensorized_sents = tensorized_sents.to(self.device) 91 | targets, lengths = tensorized(batch_tags, tag2id) 92 | targets = targets.to(self.device) 93 | 94 | # forward 95 | scores = self.model(tensorized_sents, lengths) 96 | 97 | # 计算损失 更新参数 98 | self.optimizer.zero_grad() 99 | loss = self.cal_loss_func(scores, targets, tag2id).to(self.device) 100 | loss.backward() 101 | self.optimizer.step() 102 | 103 | return loss.item() 104 | 105 | def validate(self, dev_word_lists, dev_tag_lists, word2id, tag2id): 106 | self.model.eval() 107 | with torch.no_grad(): 108 | val_losses = 0. 109 | val_step = 0 110 | for ind in range(0, len(dev_word_lists), self.batch_size): 111 | val_step += 1 112 | # 准备batch数据 113 | batch_sents = dev_word_lists[ind:ind+self.batch_size] 114 | batch_tags = dev_tag_lists[ind:ind+self.batch_size] 115 | tensorized_sents, lengths = tensorized( 116 | batch_sents, word2id) 117 | tensorized_sents = tensorized_sents.to(self.device) 118 | targets, lengths = tensorized(batch_tags, tag2id) 119 | targets = targets.to(self.device) 120 | 121 | # forward 122 | scores = self.model(tensorized_sents, lengths) 123 | 124 | # 计算损失 125 | loss = self.cal_loss_func( 126 | scores, targets, tag2id).to(self.device) 127 | val_losses += loss.item() 128 | val_loss = val_losses / val_step 129 | 130 | if val_loss < self._best_val_loss: 131 | print("保存模型...") 132 | self.best_model = deepcopy(self.model) 133 | self._best_val_loss = val_loss 134 | 135 | return val_loss 136 | 137 | def test(self, word_lists, tag_lists, word2id, tag2id): 138 | """返回最佳模型在测试集上的预测结果""" 139 | # 准备数据 140 | word_lists, tag_lists, indices = sort_by_lengths(word_lists, tag_lists) 141 | tensorized_sents, lengths = tensorized(word_lists, word2id) 142 | tensorized_sents = tensorized_sents.to(self.device) 143 | 144 | self.best_model.eval() 145 | with torch.no_grad(): 146 | batch_tagids = self.best_model.test( 147 | tensorized_sents, lengths, tag2id) 148 | 149 | # 将id转化为标注 150 | pred_tag_lists = [] 151 | id2tag = dict((id_, tag) for tag, id_ in tag2id.items()) 152 | for i, ids in enumerate(batch_tagids): 153 | tag_list = [] 154 | if self.crf: 155 | for j in range(lengths[i] - 1): # crf解码过程中,end被舍弃 156 | tag_list.append(id2tag[ids[j].item()]) 157 | else: 158 | for j in range(lengths[i]): 159 | tag_list.append(id2tag[ids[j].item()]) 160 | pred_tag_lists.append(tag_list) 161 | 162 | # indices存有根据长度排序后的索引映射的信息 163 | # 比如若indices = [1, 2, 0] 则说明原先索引为1的元素映射到的新的索引是0, 164 | # 索引为2的元素映射到新的索引是1... 
165 | # 下面根据indices将pred_tag_lists和tag_lists转化为原来的顺序 166 | ind_maps = sorted(list(enumerate(indices)), key=lambda e: e[1]) 167 | indices, _ = list(zip(*ind_maps)) 168 | pred_tag_lists = [pred_tag_lists[i] for i in indices] 169 | tag_lists = [tag_lists[i] for i in indices] 170 | 171 | return pred_tag_lists, tag_lists 172 | 173 | 174 | class BiLSTM_CRF(nn.Module): 175 | def __init__(self, vocab_size, emb_size, hidden_size, out_size): 176 | """初始化参数: 177 | vocab_size:字典的大小 178 | emb_size:词向量的维数 179 | hidden_size:隐向量的维数 180 | out_size:标注的种类 181 | """ 182 | super(BiLSTM_CRF, self).__init__() 183 | self.bilstm = BiLSTM(vocab_size, emb_size, hidden_size, out_size) 184 | 185 | # CRF实际上就是多学习一个转移矩阵 [out_size, out_size] 初始化为均匀分布 186 | self.transition = nn.Parameter( 187 | torch.ones(out_size, out_size) * 1/out_size) 188 | # self.transition.data.zero_() 189 | 190 | def forward(self, sents_tensor, lengths): 191 | # [B, L, out_size] 192 | emission = self.bilstm(sents_tensor, lengths) 193 | 194 | # 计算CRF scores, 这个scores大小为[B, L, out_size, out_size] 195 | # 也就是每个字对应对应一个 [out_size, out_size]的矩阵 196 | # 这个矩阵第i行第j列的元素的含义是:上一时刻tag为i,这一时刻tag为j的分数 197 | batch_size, max_len, out_size = emission.size() 198 | crf_scores = emission.unsqueeze( 199 | 2).expand(-1, -1, out_size, -1) + self.transition.unsqueeze(0) 200 | 201 | return crf_scores 202 | 203 | def test(self, test_sents_tensor, lengths, tag2id): 204 | """使用维特比算法进行解码""" 205 | start_id = tag2id[''] 206 | end_id = tag2id[''] 207 | pad = tag2id[''] 208 | tagset_size = len(tag2id) 209 | 210 | crf_scores = self.forward(test_sents_tensor, lengths) 211 | device = crf_scores.device 212 | # B:batch_size, L:max_len, T:target set size 213 | B, L, T, _ = crf_scores.size() 214 | # viterbi[i, j, k]表示第i个句子,第j个字对应第k个标记的最大分数 215 | viterbi = torch.zeros(B, L, T).to(device) 216 | # backpointer[i, j, k]表示第i个句子,第j个字对应第k个标记时前一个标记的id,用于回溯 217 | backpointer = (torch.zeros(B, L, T).long() * end_id).to(device) 218 | lengths = torch.LongTensor(lengths).to(device) 219 | # 向前递推 220 | for step in range(L): 221 | batch_size_t = (lengths > step).sum().item() 222 | if step == 0: 223 | # 第一个字它的前一个标记只能是start_id 224 | viterbi[:batch_size_t, step, 225 | :] = crf_scores[: batch_size_t, step, start_id, :] 226 | backpointer[: batch_size_t, step, :] = start_id 227 | else: 228 | max_scores, prev_tags = torch.max( 229 | viterbi[:batch_size_t, step-1, :].unsqueeze(2) + 230 | crf_scores[:batch_size_t, step, :, :], # [B, T, T] 231 | dim=1 232 | ) 233 | viterbi[:batch_size_t, step, :] = max_scores 234 | backpointer[:batch_size_t, step, :] = prev_tags 235 | 236 | # 在回溯的时候我们只需要用到backpointer矩阵 237 | backpointer = backpointer.view(B, -1) # [B, L * T] 238 | tagids = [] # 存放结果 239 | tags_t = None 240 | for step in range(L-1, 0, -1): 241 | batch_size_t = (lengths > step).sum().item() 242 | if step == L-1: 243 | index = torch.ones(batch_size_t).long() * (step * tagset_size) 244 | index = index.to(device) 245 | index += end_id 246 | else: 247 | prev_batch_size_t = len(tags_t) 248 | 249 | new_in_batch = torch.LongTensor( 250 | [end_id] * (batch_size_t - prev_batch_size_t)).to(device) 251 | offset = torch.cat( 252 | [tags_t, new_in_batch], 253 | dim=0 254 | ) # 这个offset实际上就是前一时刻的 255 | index = torch.ones(batch_size_t).long() * (step * tagset_size) 256 | index = index.to(device) 257 | index += offset.long() 258 | 259 | try: 260 | tags_t = backpointer[:batch_size_t].gather( 261 | dim=1, index=index.unsqueeze(1).long()) 262 | except RuntimeError: 263 | import pdb 264 | pdb.set_trace() 265 | tags_t = tags_t.squeeze(1) 
266 | tagids.append(tags_t.tolist()) 267 | 268 | # tagids:[L-1](L-1是因为扣去了end_token),大小的liebiao 269 | # 其中列表内的元素是该batch在该时刻的标记 270 | # 下面修正其顺序,并将维度转换为 [B, L] 271 | tagids = list(zip_longest(*reversed(tagids), fillvalue=pad)) 272 | tagids = torch.Tensor(tagids).long() 273 | 274 | # 返回解码的结果 275 | return tagids 276 | -------------------------------------------------------------------------------- /models/config.py: -------------------------------------------------------------------------------- 1 | # 设置lstm训练参数 2 | class TrainingConfig(object): 3 | batch_size = 64 4 | # 学习速率 5 | lr = 0.001 6 | epoches = 30 7 | print_step = 5 8 | 9 | 10 | class LSTMConfig(object): 11 | emb_size = 128 # 词向量的维数 12 | hidden_size = 128 # lstm隐向量的维数 13 | -------------------------------------------------------------------------------- /models/crf.py: -------------------------------------------------------------------------------- 1 | from sklearn_crfsuite import CRF 2 | 3 | from .util import sent2features 4 | 5 | 6 | class CRFModel(object): 7 | def __init__(self, 8 | algorithm='lbfgs', 9 | c1=0.1, 10 | c2=0.1, 11 | max_iterations=100, 12 | all_possible_transitions=False 13 | ): 14 | 15 | self.model = CRF(algorithm=algorithm, 16 | c1=c1, 17 | c2=c2, 18 | max_iterations=max_iterations, 19 | all_possible_transitions=all_possible_transitions) 20 | 21 | def train(self, sentences, tag_lists): 22 | features = [sent2features(s) for s in sentences] 23 | self.model.fit(features, tag_lists) 24 | 25 | def test(self, sentences): 26 | features = [sent2features(s) for s in sentences] 27 | pred_tag_lists = self.model.predict(features) 28 | return pred_tag_lists 29 | -------------------------------------------------------------------------------- /models/hmm.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | class HMM(object): 5 | def __init__(self, N, M): 6 | """Args: 7 | N: 状态数,这里对应存在的标注的种类 8 | M: 观测数,这里对应有多少不同的字 9 | """ 10 | self.N = N 11 | self.M = M 12 | 13 | # 状态转移概率矩阵 A[i][j]表示从i状态转移到j状态的概率 14 | self.A = torch.zeros(N, N) 15 | # 观测概率矩阵, B[i][j]表示i状态下生成j观测的概率 16 | self.B = torch.zeros(N, M) 17 | # 初始状态概率 Pi[i]表示初始时刻为状态i的概率 18 | self.Pi = torch.zeros(N) 19 | 20 | def train(self, word_lists, tag_lists, word2id, tag2id): 21 | """HMM的训练,即根据训练语料对模型参数进行估计, 22 | 因为我们有观测序列以及其对应的状态序列,所以我们 23 | 可以使用极大似然估计的方法来估计隐马尔可夫模型的参数 24 | 参数: 25 | word_lists: 列表,其中每个元素由字组成的列表,如 ['担','任','科','员'] 26 | tag_lists: 列表,其中每个元素是由对应的标注组成的列表,如 ['O','O','B-TITLE', 'E-TITLE'] 27 | word2id: 将字映射为ID 28 | tag2id: 字典,将标注映射为ID 29 | """ 30 | 31 | assert len(tag_lists) == len(word_lists) 32 | 33 | # 估计转移概率矩阵 34 | for tag_list in tag_lists: 35 | seq_len = len(tag_list) 36 | for i in range(seq_len - 1): 37 | current_tagid = tag2id[tag_list[i]] 38 | next_tagid = tag2id[tag_list[i+1]] 39 | self.A[current_tagid][next_tagid] += 1 40 | # 问题:如果某元素没有出现过,该位置为0,这在后续的计算中是不允许的 41 | # 解决方法:我们将等于0的概率加上很小的数 42 | self.A[self.A == 0.] = 1e-10 43 | self.A = self.A / self.A.sum(dim=1, keepdim=True) 44 | 45 | # 估计观测概率矩阵 46 | for tag_list, word_list in zip(tag_lists, word_lists): 47 | assert len(tag_list) == len(word_list) 48 | for tag, word in zip(tag_list, word_list): 49 | tag_id = tag2id[tag] 50 | word_id = word2id[word] 51 | self.B[tag_id][word_id] += 1 52 | self.B[self.B == 0.] = 1e-10 53 | self.B = self.B / self.B.sum(dim=1, keepdim=True) 54 | 55 | # 估计初始状态概率 56 | for tag_list in tag_lists: 57 | init_tagid = tag2id[tag_list[0]] 58 | self.Pi[init_tagid] += 1 59 | self.Pi[self.Pi == 0.] 
= 1e-10 60 | self.Pi = self.Pi / self.Pi.sum() 61 | 62 | def test(self, word_lists, word2id, tag2id): 63 | pred_tag_lists = [] 64 | for word_list in word_lists: 65 | pred_tag_list = self.decoding(word_list, word2id, tag2id) 66 | pred_tag_lists.append(pred_tag_list) 67 | return pred_tag_lists 68 | 69 | def decoding(self, word_list, word2id, tag2id): 70 | """ 71 | 使用维特比算法对给定观测序列求状态序列, 这里就是对字组成的序列,求其对应的标注。 72 | 维特比算法实际是用动态规划解隐马尔可夫模型预测问题,即用动态规划求概率最大路径(最优路径) 73 | 这时一条路径对应着一个状态序列 74 | """ 75 | # 问题:整条链很长的情况下,十分多的小概率相乘,最后可能造成下溢 76 | # 解决办法:采用对数概率,这样源空间中的很小概率,就被映射到对数空间的大的负数 77 | # 同时相乘操作也变成简单的相加操作 78 | A = torch.log(self.A) 79 | B = torch.log(self.B) 80 | Pi = torch.log(self.Pi) 81 | 82 | # 初始化 维比特矩阵viterbi 它的维度为[状态数, 序列长度] 83 | # 其中viterbi[i, j]表示标注序列的第j个标注为i的所有单个序列(i_1, i_2, ..i_j)出现的概率最大值 84 | seq_len = len(word_list) 85 | viterbi = torch.zeros(self.N, seq_len) 86 | # backpointer是跟viterbi一样大小的矩阵 87 | # backpointer[i, j]存储的是 标注序列的第j个标注为i时,第j-1个标注的id 88 | # 等解码的时候,我们用backpointer进行回溯,以求出最优路径 89 | backpointer = torch.zeros(self.N, seq_len).long() 90 | 91 | # self.Pi[i] 表示第一个字的标记为i的概率 92 | # Bt[word_id]表示字为word_id的时候,对应各个标记的概率 93 | # self.A.t()[tag_id]表示各个状态转移到tag_id对应的概率 94 | 95 | # 所以第一步为 96 | start_wordid = word2id.get(word_list[0], None) 97 | Bt = B.t() 98 | if start_wordid is None: 99 | # 如果字不再字典里,则假设状态的概率分布是均匀的 100 | bt = torch.log(torch.ones(self.N) / self.N) 101 | else: 102 | bt = Bt[start_wordid] 103 | viterbi[:, 0] = Pi + bt 104 | backpointer[:, 0] = -1 105 | 106 | # 递推公式: 107 | # viterbi[tag_id, step] = max(viterbi[:, step-1]* self.A.t()[tag_id] * Bt[word]) 108 | # 其中word是step时刻对应的字 109 | # 由上述递推公式求后续各步 110 | for step in range(1, seq_len): 111 | wordid = word2id.get(word_list[step], None) 112 | # 处理字不在字典中的情况 113 | # bt是在t时刻字为wordid时,状态的概率分布 114 | if wordid is None: 115 | # 如果字不再字典里,则假设状态的概率分布是均匀的 116 | bt = torch.log(torch.ones(self.N) / self.N) 117 | else: 118 | bt = Bt[wordid] # 否则从观测概率矩阵中取bt 119 | for tag_id in range(len(tag2id)): 120 | max_prob, max_id = torch.max( 121 | viterbi[:, step-1] + A[:, tag_id], 122 | dim=0 123 | ) 124 | viterbi[tag_id, step] = max_prob + bt[tag_id] 125 | backpointer[tag_id, step] = max_id 126 | 127 | # 终止, t=seq_len 即 viterbi[:, seq_len]中的最大概率,就是最优路径的概率 128 | best_path_prob, best_path_pointer = torch.max( 129 | viterbi[:, seq_len-1], dim=0 130 | ) 131 | 132 | # 回溯,求最优路径 133 | best_path_pointer = best_path_pointer.item() 134 | best_path = [best_path_pointer] 135 | for back_step in range(seq_len-1, 0, -1): 136 | best_path_pointer = backpointer[best_path_pointer, back_step] 137 | best_path_pointer = best_path_pointer.item() 138 | best_path.append(best_path_pointer) 139 | 140 | # 将tag_id组成的序列转化为tag 141 | assert len(best_path) == len(word_list) 142 | id2tag = dict((id_, tag) for tag, id_ in tag2id.items()) 143 | tag_list = [id2tag[id_] for id_ in reversed(best_path)] 144 | 145 | return tag_list 146 | -------------------------------------------------------------------------------- /models/util.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | 4 | # ******** CRF 工具函数************* 5 | 6 | 7 | def word2features(sent, i): 8 | """抽取单个字的特征""" 9 | word = sent[i] 10 | prev_word = "" if i == 0 else sent[i-1] 11 | next_word = "" if i == (len(sent)-1) else sent[i+1] 12 | # 使用的特征: 13 | # 前一个词,当前词,后一个词, 14 | # 前一个词+当前词, 当前词+后一个词 15 | features = { 16 | 'w': word, 17 | 'w-1': prev_word, 18 | 'w+1': next_word, 19 | 'w-1:w': prev_word+word, 20 | 'w:w+1': word+next_word, 21 | 'bias': 1 22 | 
} 23 | return features 24 | 25 | 26 | def sent2features(sent): 27 | """抽取序列特征""" 28 | return [word2features(sent, i) for i in range(len(sent))] 29 | 30 | 31 | # ******** LSTM模型 工具函数************* 32 | 33 | def tensorized(batch, maps): 34 | PAD = maps.get('') 35 | UNK = maps.get('') 36 | 37 | max_len = len(batch[0]) 38 | batch_size = len(batch) 39 | 40 | batch_tensor = torch.ones(batch_size, max_len).long() * PAD 41 | for i, l in enumerate(batch): 42 | for j, e in enumerate(l): 43 | batch_tensor[i][j] = maps.get(e, UNK) 44 | # batch各个元素的长度 45 | lengths = [len(l) for l in batch] 46 | 47 | return batch_tensor, lengths 48 | 49 | 50 | def sort_by_lengths(word_lists, tag_lists): 51 | pairs = list(zip(word_lists, tag_lists)) 52 | indices = sorted(range(len(pairs)), 53 | key=lambda k: len(pairs[k][0]), 54 | reverse=True) 55 | pairs = [pairs[i] for i in indices] 56 | # pairs.sort(key=lambda pair: len(pair[0]), reverse=True) 57 | 58 | word_lists, tag_lists = list(zip(*pairs)) 59 | 60 | return word_lists, tag_lists, indices 61 | 62 | 63 | def cal_loss(logits, targets, tag2id): 64 | """计算损失 65 | 参数: 66 | logits: [B, L, out_size] 67 | targets: [B, L] 68 | lengths: [B] 69 | """ 70 | PAD = tag2id.get('') 71 | assert PAD is not None 72 | 73 | mask = (targets != PAD) # [B, L] 74 | targets = targets[mask] 75 | out_size = logits.size(2) 76 | logits = logits.masked_select( 77 | mask.unsqueeze(2).expand(-1, -1, out_size) 78 | ).contiguous().view(-1, out_size) 79 | 80 | assert logits.size(0) == targets.size(0) 81 | loss = F.cross_entropy(logits, targets) 82 | 83 | return loss 84 | 85 | # FOR BiLSTM-CRF 86 | 87 | 88 | def cal_lstm_crf_loss(crf_scores, targets, tag2id): 89 | """计算双向LSTM-CRF模型的损失 90 | 该损失函数的计算可以参考:https://arxiv.org/pdf/1603.01360.pdf 91 | """ 92 | pad_id = tag2id.get('') 93 | start_id = tag2id.get('') 94 | end_id = tag2id.get('') 95 | 96 | device = crf_scores.device 97 | 98 | # targets:[B, L] crf_scores:[B, L, T, T] 99 | batch_size, max_len = targets.size() 100 | target_size = len(tag2id) 101 | 102 | # mask = 1 - ((targets == pad_id) + (targets == end_id)) # [B, L] 103 | mask = (targets != pad_id) 104 | lengths = mask.sum(dim=1) 105 | targets = indexed(targets, target_size, start_id) 106 | 107 | # # 计算Golden scores方法1 108 | # import pdb 109 | # pdb.set_trace() 110 | targets = targets.masked_select(mask) # [real_L] 111 | 112 | flatten_scores = crf_scores.masked_select( 113 | mask.view(batch_size, max_len, 1, 1).expand_as(crf_scores) 114 | ).view(-1, target_size*target_size).contiguous() 115 | 116 | golden_scores = flatten_scores.gather( 117 | dim=1, index=targets.unsqueeze(1)).sum() 118 | 119 | # 计算golden_scores方法2:利用pack_padded_sequence函数 120 | # targets[targets == end_id] = pad_id 121 | # scores_at_targets = torch.gather( 122 | # crf_scores.view(batch_size, max_len, -1), 2, targets.unsqueeze(2)).squeeze(2) 123 | # scores_at_targets, _ = pack_padded_sequence( 124 | # scores_at_targets, lengths-1, batch_first=True 125 | # ) 126 | # golden_scores = scores_at_targets.sum() 127 | 128 | # 计算all path scores 129 | # scores_upto_t[i, j]表示第i个句子的第t个词被标注为j标记的所有t时刻事前的所有子路径的分数之和 130 | scores_upto_t = torch.zeros(batch_size, target_size).to(device) 131 | for t in range(max_len): 132 | # 当前时刻 有效的batch_size(因为有些序列比较短) 133 | batch_size_t = (lengths > t).sum().item() 134 | if t == 0: 135 | scores_upto_t[:batch_size_t] = crf_scores[:batch_size_t, 136 | t, start_id, :] 137 | else: 138 | # We add scores at current timestep to scores accumulated up to previous 139 | # timestep, and log-sum-exp Remember, the cur_tag of 
the previous 140 | # timestep is the prev_tag of this timestep 141 | # So, broadcast prev. timestep's cur_tag scores 142 | # along cur. timestep's cur_tag dimension 143 | scores_upto_t[:batch_size_t] = torch.logsumexp( 144 | crf_scores[:batch_size_t, t, :, :] + 145 | scores_upto_t[:batch_size_t].unsqueeze(2), 146 | dim=1 147 | ) 148 | all_path_scores = scores_upto_t[:, end_id].sum() 149 | 150 | # 训练大约两个epoch loss变成负数,从数学的角度上来说,loss = -logP 151 | loss = (all_path_scores - golden_scores) / batch_size 152 | return loss 153 | 154 | 155 | def indexed(targets, tagset_size, start_id): 156 | """将targets中的数转化为在[T*T]大小序列中的索引,T是标注的种类""" 157 | batch_size, max_len = targets.size() 158 | for col in range(max_len-1, 0, -1): 159 | targets[:, col] += (targets[:, col-1] * tagset_size) 160 | targets[:, 0] += (start_id * tagset_size) 161 | return targets 162 | -------------------------------------------------------------------------------- /output.txt: -------------------------------------------------------------------------------- 1 | 读取数据... 2 | 加载并评估hmm模型... 3 | precision recall f1-score support 4 | E-EDU 0.9167 0.9821 0.9483 112 5 | B-RACE 1.0000 0.9286 0.9630 14 6 | E-TITLE 0.9514 0.9637 0.9575 772 7 | B-NAME 0.9800 0.8750 0.9245 112 8 | M-NAME 0.9459 0.8537 0.8974 82 9 | M-CONT 0.9815 1.0000 0.9907 53 10 | M-ORG 0.9002 0.9327 0.9162 4325 11 | B-CONT 0.9655 1.0000 0.9825 28 12 | B-EDU 0.9000 0.9643 0.9310 112 13 | B-LOC 0.3333 0.3333 0.3333 6 14 | B-ORG 0.8422 0.8879 0.8644 553 15 | B-TITLE 0.8811 0.8925 0.8867 772 16 | E-CONT 0.9655 1.0000 0.9825 28 17 | E-ORG 0.8262 0.8680 0.8466 553 18 | E-NAME 0.9000 0.8036 0.8491 112 19 | M-TITLE 0.9038 0.8751 0.8892 1922 20 | E-LOC 0.5000 0.5000 0.5000 6 21 | B-PRO 0.5581 0.7273 0.6316 33 22 | M-LOC 0.5833 0.3333 0.4242 21 23 | O 0.9568 0.9177 0.9369 5190 24 | M-PRO 0.4490 0.6471 0.5301 68 25 | M-EDU 0.9348 0.9609 0.9477 179 26 | E-PRO 0.6512 0.8485 0.7368 33 27 | E-RACE 1.0000 0.9286 0.9630 14 28 | avg/total 0.9149 0.9122 0.9130 15100 29 | 30 | Confusion Matrix: 31 | E-EDU B-RACE E-TITLE B-NAME M-NAME M-CONT M-ORG B-CONT B-EDU B-LOC B-ORG B-TITLE E-CONT E-ORG E-NAME M-TITLE E-LOC B-PRO M-LOC O M-PRO M-EDU E-PRO E-RACE 32 | E-EDU 110 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 0 33 | B-RACE 0 13 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0 0 0 0 34 | E-TITLE 1 0 744 0 0 0 15 0 0 0 4 0 0 0 0 2 0 0 0 6 0 0 0 0 35 | B-NAME 0 0 0 98 0 0 2 0 0 0 1 0 0 0 0 0 0 0 0 8 0 0 0 0 36 | M-NAME 0 0 0 0 70 0 3 0 0 0 0 0 0 0 6 0 0 0 0 3 0 0 0 0 37 | M-CONT 0 0 0 0 0 53 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 38 | M-ORG 3 0 4 1 2 1 4034 0 3 0 38 17 1 42 2 53 3 10 5 70 25 1 7 0 39 | B-CONT 0 0 0 0 0 0 0 28 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 40 | B-EDU 0 0 0 0 0 0 1 0 108 0 0 0 0 0 0 0 0 0 0 0 0 3 0 0 41 | B-LOC 0 0 0 0 0 0 0 0 0 2 3 0 0 0 0 0 0 0 0 1 0 0 0 0 42 | B-ORG 0 0 0 1 0 0 23 1 0 3 491 6 0 0 0 0 0 0 0 28 0 0 0 0 43 | B-TITLE 0 0 1 0 0 0 23 0 2 0 6 689 0 1 0 28 0 2 0 20 0 0 0 0 44 | E-CONT 0 0 0 0 0 0 0 0 0 0 0 0 28 0 0 0 0 0 0 0 0 0 0 0 45 | E-ORG 0 0 1 0 0 0 30 0 1 0 0 9 0 480 0 18 0 1 0 10 3 0 0 0 46 | E-NAME 0 0 0 0 2 0 0 0 0 0 0 0 0 3 90 0 0 0 0 16 0 0 0 0 47 | M-TITLE 3 0 6 0 0 0 115 0 2 0 3 35 0 17 0 1682 0 1 0 44 7 4 3 0 48 | E-LOC 0 0 0 0 0 0 1 0 0 0 0 0 0 0 0 0 3 0 0 2 0 0 0 0 49 | B-PRO 0 0 0 0 0 0 5 0 1 0 0 0 0 0 0 0 0 24 0 0 3 0 0 0 50 | M-LOC 0 0 0 0 0 0 7 0 0 1 0 0 0 2 0 0 0 0 7 4 0 0 0 0 51 | O 2 0 26 0 0 0 204 0 1 0 37 26 0 30 2 78 0 3 0 4763 12 1 4 0 52 | M-PRO 0 0 0 0 0 0 18 0 1 0 0 0 0 3 0 0 0 1 0 0 44 1 0 0 53 | M-EDU 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0 0 0 1 0 1 4 172 0 0 54 
| E-PRO 1 0 0 0 0 0 0 0 1 0 0 0 0 2 0 0 0 0 0 0 0 1 28 0 55 | E-RACE 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0 0 0 13 56 | 加载并评估crf模型... 57 | precision recall f1-score support 58 | E-EDU 0.9910 0.9821 0.9865 112 59 | B-RACE 1.0000 1.0000 1.0000 14 60 | E-TITLE 0.9857 0.9819 0.9838 772 61 | B-NAME 1.0000 0.9821 0.9910 112 62 | M-NAME 1.0000 0.9756 0.9877 82 63 | M-CONT 1.0000 1.0000 1.0000 53 64 | M-ORG 0.9523 0.9563 0.9543 4325 65 | B-CONT 1.0000 1.0000 1.0000 28 66 | B-EDU 0.9820 0.9732 0.9776 112 67 | B-LOC 1.0000 0.8333 0.9091 6 68 | B-ORG 0.9636 0.9566 0.9601 553 69 | B-TITLE 0.9376 0.9339 0.9358 772 70 | E-CONT 1.0000 1.0000 1.0000 28 71 | E-ORG 0.9199 0.9132 0.9165 553 72 | E-NAME 1.0000 0.9821 0.9910 112 73 | M-TITLE 0.9248 0.9022 0.9134 1922 74 | E-LOC 1.0000 0.8333 0.9091 6 75 | B-PRO 0.9091 0.9091 0.9091 33 76 | M-LOC 1.0000 0.8095 0.8947 21 77 | O 0.9630 0.9732 0.9681 5190 78 | M-PRO 0.8354 0.9706 0.8980 68 79 | M-EDU 0.9824 0.9330 0.9570 179 80 | E-PRO 0.9091 0.9091 0.9091 33 81 | E-RACE 1.0000 1.0000 1.0000 14 82 | avg/total 0.9543 0.9543 0.9542 15100 83 | 84 | Confusion Matrix: 85 | E-EDU B-RACE E-TITLE B-NAME M-NAME M-CONT M-ORG B-CONT B-EDU B-LOC B-ORG B-TITLE E-CONT E-ORG E-NAME M-TITLE E-LOC B-PRO M-LOC O M-PRO M-EDU E-PRO E-RACE 86 | E-EDU 110 0 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0 87 | B-RACE 0 14 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 88 | E-TITLE 1 0 758 0 0 0 2 0 0 0 0 0 0 1 0 1 0 0 0 9 0 0 0 0 89 | B-NAME 0 0 0 110 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 2 0 0 0 0 90 | M-NAME 0 0 0 0 80 0 0 0 0 0 0 0 0 0 0 0 0 0 0 2 0 0 0 0 91 | M-CONT 0 0 0 0 0 53 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 92 | M-ORG 0 0 2 0 0 0 4136 0 0 0 1 11 0 12 0 65 0 1 0 91 5 0 1 0 93 | B-CONT 0 0 0 0 0 0 0 28 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 94 | B-EDU 0 0 0 0 0 0 1 0 109 0 1 0 0 0 0 0 0 0 0 0 0 1 0 0 95 | B-LOC 0 0 0 0 0 0 0 0 0 5 1 0 0 0 0 0 0 0 0 0 0 0 0 0 96 | B-ORG 0 0 0 0 0 0 1 0 0 0 529 12 0 0 0 0 0 0 0 11 0 0 0 0 97 | B-TITLE 0 0 0 0 0 0 12 0 0 0 7 721 0 0 0 22 0 1 0 9 0 0 0 0 98 | E-CONT 0 0 0 0 0 0 0 0 0 0 0 0 28 0 0 0 0 0 0 0 0 0 0 0 99 | E-ORG 0 0 1 0 0 0 14 0 0 0 0 0 0 505 0 20 0 0 0 13 0 0 0 0 100 | E-NAME 0 0 0 0 0 0 0 0 0 0 0 0 0 0 110 0 0 0 0 2 0 0 0 0 101 | M-TITLE 0 0 3 0 0 0 89 0 1 0 1 19 0 17 0 1734 0 0 0 54 2 1 1 0 102 | E-LOC 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0 0 5 0 0 0 0 0 0 0 103 | B-PRO 0 0 0 0 0 0 1 0 1 0 0 0 0 0 0 0 0 30 0 0 1 0 0 0 104 | M-LOC 0 0 0 0 0 0 4 0 0 0 0 0 0 0 0 0 0 0 17 0 0 0 0 0 105 | O 0 0 5 0 0 0 75 0 0 0 9 6 0 11 0 33 0 0 0 5051 0 0 0 0 106 | M-PRO 0 0 0 0 0 0 2 0 0 0 0 0 0 0 0 0 0 0 0 0 66 0 0 0 107 | M-EDU 0 0 0 0 0 0 5 0 0 0 0 0 0 1 0 0 0 1 0 1 4 167 0 0 108 | E-PRO 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0 0 0 0 0 0 1 1 30 0 109 | E-RACE 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 14 110 | 加载并评估bilstm模型... 
111 | precision recall f1-score support 112 | E-EDU 0.9732 0.9732 0.9732 112 113 | B-RACE 1.0000 0.9286 0.9630 14 114 | E-TITLE 0.9754 0.9754 0.9754 772 115 | B-NAME 1.0000 0.8929 0.9434 112 116 | M-NAME 0.9186 0.9634 0.9405 82 117 | M-CONT 0.9815 1.0000 0.9907 53 118 | M-ORG 0.9631 0.9535 0.9583 4325 119 | B-CONT 1.0000 1.0000 1.0000 28 120 | B-EDU 0.9649 0.9821 0.9735 112 121 | B-LOC 1.0000 0.8333 0.9091 6 122 | B-ORG 0.9402 0.9675 0.9537 553 123 | B-TITLE 0.9457 0.9249 0.9352 772 124 | E-CONT 1.0000 1.0000 1.0000 28 125 | E-ORG 0.9194 0.9078 0.9136 553 126 | E-NAME 1.0000 0.9464 0.9725 112 127 | M-TITLE 0.9409 0.8871 0.9132 1922 128 | E-LOC 1.0000 1.0000 1.0000 6 129 | B-PRO 0.8182 0.8182 0.8182 33 130 | M-LOC 1.0000 1.0000 1.0000 21 131 | O 0.9541 0.9819 0.9678 5190 132 | M-PRO 0.7159 0.9265 0.8077 68 133 | M-EDU 0.9716 0.9553 0.9634 179 134 | E-PRO 0.8857 0.9394 0.9118 33 135 | E-RACE 1.0000 1.0000 1.0000 14 136 | avg/total 0.9537 0.9532 0.9532 15100 137 | 138 | Confusion Matrix: 139 | E-EDU B-RACE E-TITLE B-NAME M-NAME M-CONT M-ORG B-CONT B-EDU B-LOC B-ORG B-TITLE E-CONT E-ORG E-NAME M-TITLE E-LOC B-PRO M-LOC O M-PRO M-EDU E-PRO E-RACE 140 | E-EDU 109 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0 1 1 0 141 | B-RACE 0 13 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0 0 0 0 142 | E-TITLE 1 0 753 0 0 0 0 0 0 0 0 0 0 2 0 2 0 0 0 14 0 0 0 0 143 | B-NAME 0 0 0 100 5 0 0 0 0 0 0 0 0 0 0 0 0 0 0 7 0 0 0 0 144 | M-NAME 0 0 0 0 79 0 0 0 0 0 0 0 0 0 0 0 0 0 0 3 0 0 0 0 145 | M-CONT 0 0 0 0 0 53 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 146 | M-ORG 1 0 0 0 0 1 4124 0 0 0 21 11 0 16 0 46 0 2 0 94 8 0 1 0 147 | B-CONT 0 0 0 0 0 0 0 28 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 148 | B-EDU 0 0 0 0 0 0 1 0 110 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0 149 | B-LOC 0 0 0 0 0 0 0 0 0 5 0 1 0 0 0 0 0 0 0 0 0 0 0 0 150 | B-ORG 0 0 0 0 0 0 3 0 1 0 535 4 0 0 0 0 0 0 0 10 0 0 0 0 151 | B-TITLE 0 0 0 0 0 0 10 0 0 0 7 714 0 1 0 24 0 1 0 15 0 0 0 0 152 | E-CONT 0 0 0 0 0 0 0 0 0 0 0 0 28 0 0 0 0 0 0 0 0 0 0 0 153 | E-ORG 0 0 1 0 0 0 14 0 0 0 0 1 0 502 0 21 0 0 0 11 3 0 0 0 154 | E-NAME 0 0 0 0 0 0 0 0 0 0 0 0 0 0 106 0 0 0 0 2 0 0 0 0 155 | M-TITLE 1 0 6 0 2 0 81 0 1 0 0 17 0 16 0 1705 0 0 0 86 3 3 1 0 156 | E-LOC 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 6 0 0 0 0 0 0 0 157 | B-PRO 0 0 0 0 0 0 1 0 1 0 1 0 0 0 0 0 0 27 0 0 3 0 0 0 158 | M-LOC 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 21 0 0 0 0 0 159 | O 0 0 12 0 0 0 44 0 0 0 5 7 0 7 0 13 0 2 0 5096 4 0 0 0 160 | M-PRO 0 0 0 0 0 0 4 0 0 0 0 0 0 1 0 0 0 0 0 0 63 0 0 0 161 | M-EDU 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0 1 0 1 0 1 4 171 0 0 162 | E-PRO 0 0 0 0 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 1 31 0 163 | E-RACE 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 14 164 | 加载并评估bilstm+crf模型... 
165 | precision recall f1-score support 166 | E-EDU 0.9820 0.9732 0.9776 112 167 | B-RACE 1.0000 0.9286 0.9630 14 168 | E-TITLE 0.9921 0.9767 0.9843 772 169 | B-NAME 1.0000 0.9196 0.9581 112 170 | M-NAME 0.9753 0.9634 0.9693 82 171 | M-CONT 1.0000 0.9623 0.9808 53 172 | M-ORG 0.9525 0.9635 0.9579 4325 173 | B-CONT 1.0000 0.9643 0.9818 28 174 | B-EDU 0.9820 0.9732 0.9776 112 175 | B-LOC 1.0000 0.8333 0.9091 6 176 | B-ORG 0.9555 0.9711 0.9632 553 177 | B-TITLE 0.9420 0.9262 0.9340 772 178 | E-CONT 1.0000 0.9643 0.9818 28 179 | E-ORG 0.9234 0.9150 0.9192 553 180 | E-NAME 1.0000 0.9375 0.9677 112 181 | M-TITLE 0.9528 0.8918 0.9213 1922 182 | E-LOC 1.0000 0.8333 0.9091 6 183 | B-PRO 0.9412 0.9697 0.9552 33 184 | M-LOC 1.0000 0.8095 0.8947 21 185 | O 0.9605 0.9827 0.9714 5190 186 | M-PRO 0.8684 0.9706 0.9167 68 187 | M-EDU 0.9767 0.9385 0.9573 179 188 | E-PRO 0.9118 0.9394 0.9254 33 189 | E-RACE 1.0000 1.0000 1.0000 14 190 | avg/total 0.9574 0.9572 0.9570 15100 191 | 192 | Confusion Matrix: 193 | E-EDU B-RACE E-TITLE B-NAME M-NAME M-CONT M-ORG B-CONT B-EDU B-LOC B-ORG B-TITLE E-CONT E-ORG E-NAME M-TITLE E-LOC B-PRO M-LOC O M-PRO M-EDU E-PRO E-RACE 194 | E-EDU 109 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0 1 1 0 195 | B-RACE 0 13 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0 0 0 0 196 | E-TITLE 1 0 754 0 0 0 3 0 0 0 0 0 0 2 0 1 0 0 0 11 0 0 0 0 197 | B-NAME 0 0 0 103 2 0 0 0 0 0 0 0 0 0 0 0 0 0 0 7 0 0 0 0 198 | M-NAME 0 0 0 0 79 0 0 0 0 0 0 0 0 0 0 0 0 0 0 3 0 0 0 0 199 | M-CONT 0 0 0 0 0 51 0 0 0 0 0 0 0 0 0 0 0 0 0 2 0 0 0 0 200 | M-ORG 0 0 0 0 0 0 4167 0 0 0 9 14 0 16 0 37 0 0 0 77 4 0 1 0 201 | B-CONT 0 0 0 0 0 0 0 27 0 0 0 0 0 0 0 0 0 0 0 1 0 0 0 0 202 | B-EDU 0 0 0 0 0 0 2 0 109 0 0 0 0 0 0 0 0 0 0 0 0 1 0 0 203 | B-LOC 0 0 0 0 0 0 0 0 0 5 1 0 0 0 0 0 0 0 0 0 0 0 0 0 204 | B-ORG 0 0 0 0 0 0 1 0 0 0 537 5 0 0 0 0 0 0 0 10 0 0 0 0 205 | B-TITLE 0 0 0 0 0 0 14 0 0 0 7 715 0 2 0 20 0 1 0 13 0 0 0 0 206 | E-CONT 0 0 0 0 0 0 0 0 0 0 0 0 27 0 0 0 0 0 0 1 0 0 0 0 207 | E-ORG 0 0 0 0 0 0 20 0 0 0 0 0 0 506 0 17 0 1 0 9 0 0 0 0 208 | E-NAME 0 0 0 0 0 0 0 0 0 0 0 0 0 0 105 0 0 0 0 6 0 0 0 0 209 | M-TITLE 0 0 2 0 0 0 102 0 1 0 2 18 0 13 0 1714 0 0 0 66 2 1 1 0 210 | E-LOC 0 0 0 0 0 0 1 0 0 0 0 0 0 0 0 0 5 0 0 0 0 0 0 0 211 | B-PRO 0 0 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0 32 0 0 0 0 0 0 212 | M-LOC 0 0 0 0 0 0 4 0 0 0 0 0 0 0 0 0 0 0 17 0 0 0 0 0 213 | O 0 0 4 0 0 0 57 0 0 0 5 7 0 7 0 10 0 0 0 5100 0 0 0 0 214 | M-PRO 0 0 0 0 0 0 1 0 0 0 0 0 0 1 0 0 0 0 0 0 66 0 0 0 215 | M-EDU 1 0 0 0 0 0 2 0 0 0 1 0 0 1 0 0 0 0 0 2 4 168 0 0 216 | E-PRO 0 0 0 0 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 1 31 0 217 | E-RACE 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 14 218 | Ensemble 四个模型的结果如下: 219 | precision recall f1-score support 220 | E-EDU 0.9910 0.9821 0.9865 112 221 | B-RACE 1.0000 0.9286 0.9630 14 222 | E-TITLE 0.9832 0.9832 0.9832 772 223 | B-NAME 1.0000 0.9286 0.9630 112 224 | M-NAME 0.9756 0.9756 0.9756 82 225 | M-CONT 1.0000 1.0000 1.0000 53 226 | M-ORG 0.9434 0.9667 0.9549 4325 227 | B-CONT 1.0000 1.0000 1.0000 28 228 | B-EDU 0.9735 0.9821 0.9778 112 229 | B-LOC 1.0000 0.8333 0.9091 6 230 | B-ORG 0.9747 0.9747 0.9747 553 231 | B-TITLE 0.9426 0.9365 0.9396 772 232 | E-CONT 1.0000 1.0000 1.0000 28 233 | E-ORG 0.9305 0.9204 0.9255 553 234 | E-NAME 1.0000 0.9464 0.9725 112 235 | M-TITLE 0.9499 0.8975 0.9230 1922 236 | E-LOC 1.0000 0.8333 0.9091 6 237 | B-PRO 0.9000 0.8182 0.8571 33 238 | M-LOC 1.0000 0.8095 0.8947 21 239 | O 0.9679 0.9707 0.9693 5190 240 | M-PRO 0.7586 0.9706 0.8516 68 241 | M-EDU 0.9773 0.9609 0.9690 179 242 | 
E-PRO 0.8857 0.9394 0.9118 33 243 | E-RACE 1.0000 1.0000 1.0000 14 244 | avg/total 0.9569 0.9565 0.9564 15100 245 | 246 | Confusion Matrix: 247 | E-EDU B-RACE E-TITLE B-NAME M-NAME M-CONT M-ORG B-CONT B-EDU B-LOC B-ORG B-TITLE E-CONT E-ORG E-NAME M-TITLE E-LOC B-PRO M-LOC O M-PRO M-EDU E-PRO E-RACE 248 | E-EDU 110 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 0 249 | B-RACE 0 13 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0 0 0 0 250 | E-TITLE 1 0 759 0 0 0 2 0 0 0 0 0 0 1 0 1 0 0 0 8 0 0 0 0 251 | B-NAME 0 0 0 104 2 0 0 0 0 0 0 0 0 0 0 0 0 0 0 6 0 0 0 0 252 | M-NAME 0 0 0 0 80 0 0 0 0 0 0 0 0 0 0 0 0 0 0 2 0 0 0 0 253 | M-CONT 0 0 0 0 0 53 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 254 | M-ORG 0 0 1 0 0 0 4181 0 0 0 4 10 0 14 0 38 0 0 0 68 8 0 1 0 255 | B-CONT 0 0 0 0 0 0 0 28 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 256 | B-EDU 0 0 0 0 0 0 1 0 110 0 0 0 0 0 0 0 0 0 0 0 0 1 0 0 257 | B-LOC 0 0 0 0 0 0 0 0 0 5 1 0 0 0 0 0 0 0 0 0 0 0 0 0 258 | B-ORG 0 0 0 0 0 0 0 0 0 0 539 5 0 0 0 0 0 0 0 9 0 0 0 0 259 | B-TITLE 0 0 0 0 0 0 13 0 0 0 6 723 0 0 0 19 0 1 0 10 0 0 0 0 260 | E-CONT 0 0 0 0 0 0 0 0 0 0 0 0 28 0 0 0 0 0 0 0 0 0 0 0 261 | E-ORG 0 0 1 0 0 0 15 0 0 0 0 1 0 509 0 15 0 1 0 9 2 0 0 0 262 | E-NAME 0 0 0 0 0 0 0 0 0 0 0 0 0 0 106 0 0 0 0 5 0 0 0 0 263 | M-TITLE 0 0 3 0 0 0 106 0 1 0 0 21 0 13 0 1725 0 0 0 48 3 1 1 0 264 | E-LOC 0 0 0 0 0 0 1 0 0 0 0 0 0 0 0 0 5 0 0 0 0 0 0 0 265 | B-PRO 0 0 0 0 0 0 2 0 1 0 0 0 0 0 0 0 0 27 0 0 3 0 0 0 266 | M-LOC 0 0 0 0 0 0 4 0 0 0 0 0 0 0 0 0 0 0 17 0 0 0 0 0 267 | O 0 0 8 0 0 0 106 0 0 0 3 7 0 8 0 18 0 0 0 5038 1 0 1 0 268 | M-PRO 0 0 0 0 0 0 1 0 0 0 0 0 0 1 0 0 0 0 0 0 66 0 0 0 269 | M-EDU 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0 0 0 1 0 1 4 172 0 0 270 | E-PRO 0 0 0 0 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 1 31 0 271 | E-RACE 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 14 272 | -------------------------------------------------------------------------------- /requirement.txt: -------------------------------------------------------------------------------- 1 | numpy==1.16.2 2 | python-crfsuite==0.9.6 3 | six==1.12.0 4 | sklearn-crfsuite==0.3.6 5 | tabulate==0.8.3 6 | torch==1.0.1.post2 7 | tqdm==4.31.1 8 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | from utils import load_model, extend_maps, prepocess_data_for_lstmcrf 2 | from data import build_corpus 3 | from evaluating import Metrics 4 | from evaluate import ensemble_evaluate 5 | 6 | HMM_MODEL_PATH = './ckpts/hmm.pkl' 7 | CRF_MODEL_PATH = './ckpts/crf.pkl' 8 | BiLSTM_MODEL_PATH = './ckpts/bilstm.pkl' 9 | BiLSTMCRF_MODEL_PATH = './ckpts/bilstm_crf.pkl' 10 | 11 | REMOVE_O = False # 在评估的时候是否去除O标记 12 | 13 | 14 | def main(): 15 | print("读取数据...") 16 | train_word_lists, train_tag_lists, word2id, tag2id = \ 17 | build_corpus("train") 18 | dev_word_lists, dev_tag_lists = build_corpus("dev", make_vocab=False) 19 | test_word_lists, test_tag_lists = build_corpus("test", make_vocab=False) 20 | 21 | print("加载并评估hmm模型...") 22 | hmm_model = load_model(HMM_MODEL_PATH) 23 | hmm_pred = hmm_model.test(test_word_lists, 24 | word2id, 25 | tag2id) 26 | metrics = Metrics(test_tag_lists, hmm_pred, remove_O=REMOVE_O) 27 | metrics.report_scores() # 打印每个标记的精确度、召回率、f1分数 28 | metrics.report_confusion_matrix() # 打印混淆矩阵 29 | 30 | # 加载并评估CRF模型 31 | print("加载并评估crf模型...") 32 | crf_model = load_model(CRF_MODEL_PATH) 33 | crf_pred = crf_model.test(test_word_lists) 34 | metrics = Metrics(test_tag_lists, crf_pred, remove_O=REMOVE_O) 35 | 
metrics.report_scores() 36 | metrics.report_confusion_matrix() 37 | 38 | # bilstm模型 39 | print("加载并评估bilstm模型...") 40 | bilstm_word2id, bilstm_tag2id = extend_maps(word2id, tag2id, for_crf=False) 41 | bilstm_model = load_model(BiLSTM_MODEL_PATH) 42 | bilstm_model.model.bilstm.flatten_parameters() # remove warning 43 | lstm_pred, target_tag_list = bilstm_model.test(test_word_lists, test_tag_lists, 44 | bilstm_word2id, bilstm_tag2id) 45 | metrics = Metrics(target_tag_list, lstm_pred, remove_O=REMOVE_O) 46 | metrics.report_scores() 47 | metrics.report_confusion_matrix() 48 | 49 | print("加载并评估bilstm+crf模型...") 50 | crf_word2id, crf_tag2id = extend_maps(word2id, tag2id, for_crf=True) 51 | bilstm_model = load_model(BiLSTMCRF_MODEL_PATH) 52 | bilstm_model.model.bilstm.bilstm.flatten_parameters() # remove warning 53 | test_word_lists, test_tag_lists = prepocess_data_for_lstmcrf( 54 | test_word_lists, test_tag_lists, test=True 55 | ) 56 | lstmcrf_pred, target_tag_list = bilstm_model.test(test_word_lists, test_tag_lists, 57 | crf_word2id, crf_tag2id) 58 | metrics = Metrics(target_tag_list, lstmcrf_pred, remove_O=REMOVE_O) 59 | metrics.report_scores() 60 | metrics.report_confusion_matrix() 61 | 62 | ensemble_evaluate( 63 | [hmm_pred, crf_pred, lstm_pred, lstmcrf_pred], 64 | test_tag_lists 65 | ) 66 | 67 | 68 | if __name__ == "__main__": 69 | main() 70 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | 3 | 4 | def merge_maps(dict1, dict2): 5 | """用于合并两个word2id或者两个tag2id""" 6 | for key in dict2.keys(): 7 | if key not in dict1: 8 | dict1[key] = len(dict1) 9 | return dict1 10 | 11 | 12 | def save_model(model, file_name): 13 | """用于保存模型""" 14 | with open(file_name, "wb") as f: 15 | pickle.dump(model, f) 16 | 17 | 18 | def load_model(file_name): 19 | """用于加载模型""" 20 | with open(file_name, "rb") as f: 21 | model = pickle.load(f) 22 | return model 23 | 24 | 25 | # LSTM模型训练的时候需要在word2id和tag2id加入PAD和UNK 26 | # 如果是加了CRF的lstm还要加入 (解码的时候需要用到) 27 | def extend_maps(word2id, tag2id, for_crf=True): 28 | word2id[''] = len(word2id) 29 | word2id[''] = len(word2id) 30 | tag2id[''] = len(tag2id) 31 | tag2id[''] = len(tag2id) 32 | # 如果是加了CRF的bilstm 那么还要加入token 33 | if for_crf: 34 | word2id[''] = len(word2id) 35 | word2id[''] = len(word2id) 36 | tag2id[''] = len(tag2id) 37 | tag2id[''] = len(tag2id) 38 | 39 | return word2id, tag2id 40 | 41 | 42 | def prepocess_data_for_lstmcrf(word_lists, tag_lists, test=False): 43 | assert len(word_lists) == len(tag_lists) 44 | for i in range(len(word_lists)): 45 | word_lists[i].append("") 46 | if not test: # 如果是测试数据,就不需要加end token了 47 | tag_lists[i].append("") 48 | 49 | return word_lists, tag_lists 50 | 51 | 52 | def flatten_lists(lists): 53 | flatten_list = [] 54 | for l in lists: 55 | if type(l) == list: 56 | flatten_list += l 57 | else: 58 | flatten_list.append(l) 59 | return flatten_list 60 | --------------------------------------------------------------------------------