├── pic
│   ├── pic1.png
│   └── pic2.png
├── src
│   ├── BILSTM_CRF.pyc
│   ├── data_helper.pyc
│   ├── __pycache__
│   │   ├── model.cpython-36.pyc
│   │   ├── BILSTM_CRF.cpython-36.pyc
│   │   ├── data_helper.cpython-36.pyc
│   │   └── model_helper.cpython-36.pyc
│   ├── data
│   │   ├── C_dict.txt
│   │   ├── B_dict.txt
│   │   ├── calc_f1.py
│   │   ├── 1.build_dict.py
│   │   ├── 2.build_file.py
│   │   └── dev
│   │       └── c_id.txt
│   ├── .idea
│   │   ├── misc.xml
│   │   ├── modules.xml
│   │   ├── nlp_template.iml
│   │   ├── remote-mappings.xml
│   │   ├── deployment.xml
│   │   ├── webServers.xml
│   │   └── workspace.xml
│   ├── model_helper.py
│   ├── main.py
│   ├── BILSTM_CRF.py
│   └── data_helper.py
├── LICENSE
├── README.md
└── calc_f1.py

/pic/pic1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Nrgeup/chinese_semantic_role_labeling/HEAD/pic/pic1.png
--------------------------------------------------------------------------------
/pic/pic2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Nrgeup/chinese_semantic_role_labeling/HEAD/pic/pic2.png
--------------------------------------------------------------------------------
/src/BILSTM_CRF.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Nrgeup/chinese_semantic_role_labeling/HEAD/src/BILSTM_CRF.pyc
--------------------------------------------------------------------------------
/src/data_helper.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Nrgeup/chinese_semantic_role_labeling/HEAD/src/data_helper.pyc
--------------------------------------------------------------------------------
/src/__pycache__/model.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Nrgeup/chinese_semantic_role_labeling/HEAD/src/__pycache__/model.cpython-36.pyc
--------------------------------------------------------------------------------
/src/__pycache__/BILSTM_CRF.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Nrgeup/chinese_semantic_role_labeling/HEAD/src/__pycache__/BILSTM_CRF.cpython-36.pyc
--------------------------------------------------------------------------------
/src/__pycache__/data_helper.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Nrgeup/chinese_semantic_role_labeling/HEAD/src/__pycache__/data_helper.cpython-36.pyc
--------------------------------------------------------------------------------
/src/__pycache__/model_helper.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Nrgeup/chinese_semantic_role_labeling/HEAD/src/__pycache__/model_helper.cpython-36.pyc
--------------------------------------------------------------------------------
/src/data/C_dict.txt:
--------------------------------------------------------------------------------
 1 | _PAD
 2 | O
 3 | ARG1
 4 | ARG0
 5 | rel
 6 | ARGM-ADV
 7 | ARGM-TMP
 8 | ARGM-LOC
 9 | ARG2
10 | ARGM-MNR
11 | ARGM-PRP
12 | ARG3
13 | ARGM-CND
14 | ARGM-DIR
15 | ARGM-BNF
16 | ARGM-TPC
17 | ARGM-EXT
18 | ARGM-DIS
19 | ARG4
20 | ARGM-FRQ
--------------------------------------------------------------------------------
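The dictionary files above (and `A_dict.txt`, which `1.build_dict.py` generates) are read one token per line, with the line index serving as the token id, so `_PAD` is always id 0. A condensed sketch of that convention, equivalent to the `build_dict` helper in `src/data_helper.py` further down in this dump:

```python
def build_dict(dict_file):
    """Map each token to its 0-based line index, and back."""
    token2id, id2token = {}, {}
    with open(dict_file) as infile:
        for line_id, row in enumerate(infile):
            token = row.strip()
            if token == "":
                break  # the dict files are written without gaps
            token2id[token] = line_id
            id2token[line_id] = token
    return token2id, id2token

# e.g. role2id, _ = build_dict("src/data/C_dict.txt")
# role2id["_PAD"] == 0, role2id["rel"] == 4
```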
/src/data/B_dict.txt:
--------------------------------------------------------------------------------
 1 | _PAD
 2 | NN
 3 | VV
 4 | PU
 5 | NR
 6 | AD
 7 | P
 8 | CD
 9 | JJ
10 | M
11 | DEC
12 | DEG
13 | CC
14 | NT
15 | VA
16 | LC
17 | DT
18 | PN
19 | AS
20 | VC
21 | VE
22 | OD
23 | ETC
24 | MSP
25 | BA
26 | CS
27 | DEV
28 | SB
29 | LB
30 | SP
31 | DER
32 | FW
33 | NP
--------------------------------------------------------------------------------
/src/model_helper.py:
--------------------------------------------------------------------------------
 1 | import tensorflow as tf
 2 | import collections
 3 | 
 4 | 
 5 | # TrainModel pairs a tf.Graph with the model instance built inside it.
 6 | class TrainModel(
 7 |         collections.namedtuple("TrainModel",
 8 |                                ("graph", "model"))):
 9 |     pass
10 | 
11 | 
12 | def create_train_model(
13 |         model_creator,
14 |         hparams):
15 |     """Build the training model in its own graph, inside a "train" container."""
16 |     graph = tf.Graph()
17 |     with graph.as_default(), tf.container("train"):
18 |         model = model_creator(
19 |             hparams,
20 |             tf.contrib.learn.ModeKeys.TRAIN,
21 |         )
22 |     return TrainModel(
23 |         graph=graph,
24 |         model=model,
25 |     )
--------------------------------------------------------------------------------
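`create_train_model` above isolates the training model in a dedicated `tf.Graph`. Note that `main.py` below builds its graph directly rather than through this helper, so the following is only a hedged usage sketch; `ToyModel` is a made-up stand-in for any class accepting `(hparams, mode)`:

```python
import tensorflow as tf
import model_helper


class ToyModel(object):
    # Stand-in model_creator: any callable taking (hparams, mode) works.
    def __init__(self, hparams, mode):
        self.mode = mode


hparams = tf.contrib.training.HParams(learning_rate=0.002)
train_model = model_helper.create_train_model(ToyModel, hparams)

# The namedtuple keeps the graph and the model built inside it together,
# so the session can be opened against exactly that graph.
with tf.Session(graph=train_model.graph) as sess:
    sess.run(tf.global_variables_initializer())
```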
/LICENSE:
--------------------------------------------------------------------------------
 1 | MIT License
 2 | 
 3 | Copyright (c) 2018 WangKe
 4 | 
 5 | Permission is hereby granted, free of charge, to any person obtaining a copy
 6 | of this software and associated documentation files (the "Software"), to deal
 7 | in the Software without restriction, including without limitation the rights
 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
 9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
  1 | # Semantic Role Labeling with LSTM and CRF
  2 | 
  3 | ```This repo is a course assignment; no paper was published from it. The dataset is a portion of CPB provided by the teaching assistant. Free to download and free to use.```
  4 | 
  5 | ## Task Description
  6 | 
  7 | ### Argument Recognition
  8 | 
  9 | Given a sentence from the Chinese Proposition Bank (CPB) and a target proposition
 10 | (marked /rel), identify the arguments of that proposition together with their left
 11 | and right boundaries. For example, in the sentence:
 12 | 
 13 | 我们/PN/O 希望/VV/O 台湾/NR/B-ARG0 当局/NN/E-ARG0 顺应/VV/O 历史/NN/O 发展/NN/O 潮流/NN/O ,/PU/O 把握/VV/rel 时机/NN/S-ARG1 ,/PU/O 就/P/O 两/CD/O 岸/NN/O 政治/NN/O 谈判/NN/O 作出/VV/O 积极/JJ/O 回应/NN/O 和/CC/O 明智/JJ/O 选择/NN/O 。/PU/O
 14 | 
 15 | the sentence is already word-segmented and part-of-speech (POS) tagged. In each
 16 | token "A/B/C", A is the word, B is the POS tag, and C is the argument label.
 17 | 
 18 | Here the target verb of the proposition is 把握 ("seize"); the proposition has two
 19 | arguments, 台湾当局 ("the Taiwan authorities") and 时机 ("the opportunity"), filling
 20 | the roles ARG0 and ARG1. A system must identify both the exact boundaries and the
 21 | role of each argument: recognizing only 台湾 out of 台湾当局 does not count as a
 22 | correctly identified argument.
 23 | 
 24 | ### Evaluation Metrics
 25 | 
 26 | Argument recognition is evaluated with P/R/F:
 27 | 
 28 | * Precision (P) = correctly identified arguments / all arguments output by the system * 100%
 29 | * Recall (R) = correctly identified arguments / all arguments in the gold standard * 100%
 30 | * F1 = 2*P*R/(P+R)
 31 | 
 32 | ## Method
 33 | 
 34 | ### Model Overview
 35 | 
 36 | We implement semantic role labeling with a bidirectional LSTM followed by a CRF.
 37 | Recurrent neural networks (RNNs) are a key family of sequence models with wide use
 38 | in natural language processing; unlike feed-forward networks, they can model
 39 | dependencies across positions of the input. The LSTM is an important RNN variant,
 40 | commonly used to capture long-range dependencies in long sequences, and running it
 41 | bidirectionally lets the model see both past and future context. The LSTM layers
 42 | learn a feature representation of the input, and a conditional random field (CRF)
 43 | at the end of the network performs the sequence labeling on top of those features.
 44 | The architecture:
 45 | 
 46 | ![model](./pic/pic1.png "model")
 47 | 
 48 | ### Pipeline
 49 | 
 50 | #### a) Preprocessing
 51 | ##### Words:
 52 | The corpus contains 18,418 distinct words. Using regular expressions, we replace
 53 | all numbers with _NUMBER, person names with _NAME, years with _YEAR, dates with
 54 | _DAY, and times of day with _TIME, which reduces the vocabulary to 16,314 entries.
 55 | We keep the 13,000 most frequent words as the dictionary and map all remaining
 56 | words to _UNK.
 57 | 
 58 | ##### POS tags:
 59 | This yields a POS vocabulary of size 32.
 60 | 
 61 | ##### Roles:
 62 | We first strip the 'B-', 'S-', 'I-', and 'E-' prefixes from the role labels
 63 | (restoring them at the very end), which leaves a role vocabulary of size 19.
 64 | 
 65 | #### b) Input Construction
 66 | Each input position is the concatenation of three parts: the word, its POS tag,
 67 | and an is-predicate indicator. Words and POS tags are looked up in embedding
 68 | tables to obtain dense vectors, which are then concatenated with a one-hot marker
 69 | for the predicate (rel).
 70 | 
 71 | #### c) Feature Representation
 72 | The embedded sequence is fed into the bidirectional LSTM, which learns a new
 73 | feature representation of the input sequence.
 74 | 
 75 | #### d) Sequence Labeling
 76 | The CRF takes the features learned by the LSTM as input and the gold label
 77 | sequence as supervision; Viterbi decoding then produces the final label sequence.
 78 | 
 79 | ### Experiments
 80 | 
 81 | 1. Environment
 82 | * TensorFlow: 1.4
 83 | * Python: 3.6
 84 | * Trained on 2 TITAN X GPUs (12 GB memory each)
 85 | 
 86 | 2. Hyperparameters
 87 | * LSTM hidden units: 120
 88 | * Word embedding dim: 100
 89 | * POS embedding dim: 19
 90 | * Optimizer: Adam
 91 | * Batch size: 128
 92 | 
 93 | 3. Results
 94 | 
 95 | Our results on the validation set are given in the first row of the table below:
 96 | 
 97 | (validation predictions: ./data/best_eval_dev.txt)
 98 | 
 99 | (test predictions: eval_test.txt in the repository root)
100 | 
101 | |  | Precision | Recall | F1 |
102 | | :----: | :----: | :----: | :----: |
103 | | Ours | 0.727667 | 0.736989 | 0.732299 |
104 | | w/o word substitution | 0.702766 | 0.713769 | 0.708225 |
105 | | w/o CRF | 0.686026 | 0.644268 | 0.664492 |
106 | 
107 | The second row is a comparison against a variant without word substitution:
108 | replacing person names, numbers, years, and times helps performance. We also
109 | compare against a variant that drops the CRF and reads the label sequence
110 | directly off the LSTM outputs; the CRF improves the results substantially, which
111 | suggests that, as a globally normalized structured model, it compensates for the
112 | label bias of a locally normalized neural tagger.
113 | 
114 | Training stabilized after about 1.4K epochs, taking roughly 4 hours; the loss
115 | curve is shown below:
116 | 
117 | ![learning curve](./pic/pic2.png)
118 | 
119 | 4. How to Run
120 | 
121 | > cd src
122 | >
123 | > python main.py
124 | 
125 | ```We recommend working under Linux; on Windows, open the files with an editor other than Notepad.```
126 | 
127 | ## References
128 | 1. Sun W, Sui Z, Wang M, et al. Chinese Semantic Role Labeling with Shallow
129 | Parsing[C]. Empirical Methods in Natural Language Processing, 2009: 1475-1483.
130 | 2. Zhou J, Xu W. End-to-end Learning of Semantic Role Labeling Using Recurrent
131 | Neural Networks[C]. Meeting of the Association for Computational Linguistics,
132 | 2015: 1127-1137.
133 | 3. Huang Z, Xu W, Yu K, et al. Bidirectional LSTM-CRF Models for Sequence
134 | Tagging[J]. arXiv: Computation and Language, 2015.
--------------------------------------------------------------------------------
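As a quick sanity check of the F-measure defined in the README, the first row of the results table is internally consistent (a throwaway snippet, not part of the repository):

```python
p, r = 0.727667, 0.736989  # precision/recall of the "Ours" row
f1 = 2 * p * r / (p + r)
print(round(f1, 5))  # 0.7323, consistent with the reported F1 of 0.732299
```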
/calc_f1.py:
--------------------------------------------------------------------------------
 1 | #!/usr/bin/env python
 2 | # -*- coding: utf-8 -*-
 3 | import sys, os
 4 | 
 5 | def calc_f1(pred_file, gold_file):
 6 |     case_true, case_recall, case_precision = 0, 0, 0
 7 |     golds = [gold.split() for gold in open(gold_file, 'r').read().strip().split('\n')]
 8 |     preds = [pred.split() for pred in open(pred_file, 'r').read().strip().split('\n')]
 9 |     assert len(golds) == len(preds), "length of prediction file and gold file should be the same."
10 |     for gold, pred in zip(golds, preds):
11 |         lastname = ''
12 |         keys_gold, keys_pred = {}, {}
13 |         for item in gold:
14 |             word, label = item.split('/')[0], item.split('/')[-1]
15 |             flag, name = label[:label.find('-')], label[label.find('-')+1:]
16 |             if flag == 'O':
17 |                 continue
18 |             if flag == 'S':
19 |                 if name not in keys_gold:
20 |                     keys_gold[name] = [word]
21 |                 else:
22 |                     keys_gold[name].append(word)
23 |             else:
24 |                 if flag == 'B':
25 |                     if name not in keys_gold:
26 |                         keys_gold[name] = [word]
27 |                     else:
28 |                         keys_gold[name].append(word)
29 |                     lastname = name
30 |                 elif flag == 'I' or flag == 'E':
31 |                     assert name == lastname, "the I-/E- labels are inconsistent with B- labels in gold file."
32 |                     keys_gold[name][-1] += ' ' + word
33 |         for item in pred:
34 |             word, label = item.split('/')[0], item.split('/')[-1]
35 |             flag, name = label[:label.find('-')], label[label.find('-')+1:]
36 |             if flag == 'O':
37 |                 continue
38 |             if flag == 'S':
39 |                 if name not in keys_pred:
40 |                     keys_pred[name] = [word]
41 |                 else:
42 |                     keys_pred[name].append(word)
43 |             else:
44 |                 if flag == 'B':
45 |                     if name not in keys_pred:
46 |                         keys_pred[name] = [word]
47 |                     else:
48 |                         keys_pred[name].append(word)
49 |                     lastname = name
50 |                 elif flag == 'I' or flag == 'E':
51 |                     assert name == lastname, "the I-/E- labels are inconsistent with B- labels in pred file."
52 |                     keys_pred[name][-1] += ' ' + word
53 | 
54 |         for key in keys_gold:
55 |             case_recall += len(keys_gold[key])
56 |         for key in keys_pred:
57 |             case_precision += len(keys_pred[key])
58 | 
59 |         for key in keys_pred:
60 |             if key in keys_gold:
61 |                 for word in keys_pred[key]:
62 |                     if word in keys_gold[key]:
63 |                         case_true += 1
64 |                         keys_gold[key].remove(word)  # avoid replicate words
65 |     assert case_recall != 0, "no labels in gold files!"
66 |     assert case_precision != 0, "no labels in pred files!"
67 | recall = 1.0 * case_true / case_recall 68 | precision = 1.0 * case_true / case_precision 69 | f1 = 2.0 * recall * precision / (recall + precision) 70 | result = "recall: %s precision: %s F: %s" % (str(recall), str(precision), str(f1)) 71 | return result 72 | # calc_f1('cpbtest1.txt', 'cpbtest_answer.txt') 73 | if __name__ == "__main__": 74 | if len(sys.argv[1:]) != 2: 75 | print('the function takes exactly two parameters: pred_file and gold_file') 76 | else: 77 | if not os.path.exists(sys.argv[1]): 78 | print('pred_file not exists!') 79 | elif not os.path.exists(sys.argv[2]): 80 | print('gold_file not exists!') 81 | else: 82 | print(calc_f1(sys.argv[1], sys.argv[2])) -------------------------------------------------------------------------------- /src/data/calc_f1.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | import sys, os 4 | 5 | def calc_f1(pred_file, gold_file): 6 | case_true, case_recall, case_precision = 0, 0, 0 7 | golds = [gold.split() for gold in open(gold_file, 'r').read().strip().split('\n')] 8 | preds = [pred.split() for pred in open(pred_file, 'r').read().strip().split('\n')] 9 | assert len(golds) == len(preds), "length of prediction file and gold file should be the same." 10 | for gold, pred in zip(golds, preds): 11 | lastname = '' 12 | keys_gold, keys_pred = {}, {} 13 | for item in gold: 14 | word, label = item.split('/')[0], item.split('/')[-1] 15 | flag, name = label[:label.find('-')], label[label.find('-')+1:] 16 | if flag == 'O': 17 | continue 18 | if flag == 'S': 19 | if name not in keys_gold: 20 | keys_gold[name] = [word] 21 | else: 22 | keys_gold[name].append(word) 23 | else: 24 | if flag == 'B': 25 | if name not in keys_gold: 26 | keys_gold[name] = [word] 27 | else: 28 | keys_gold[name].append(word) 29 | lastname = name 30 | elif flag == 'I' or flag == 'E': 31 | assert name == lastname, "the I-/E- labels are inconsistent with B- labels in gold file." 32 | keys_gold[name][-1] += ' ' + word 33 | for item in pred: 34 | word, label = item.split('/')[0], item.split('/')[-1] 35 | flag, name = label[:label.find('-')], label[label.find('-')+1:] 36 | if flag == 'O': 37 | continue 38 | if flag == 'S': 39 | if name not in keys_pred: 40 | keys_pred[name] = [word] 41 | else: 42 | keys_pred[name].append(word) 43 | else: 44 | if flag == 'B': 45 | if name not in keys_pred: 46 | keys_pred[name] = [word] 47 | else: 48 | keys_pred[name].append(word) 49 | lastname = name 50 | elif flag == 'I' or flag == 'E': 51 | assert name == lastname, "the I-/E- labels are inconsistent with B- labels in pred file." 52 | keys_pred[name][-1] += ' ' + word 53 | 54 | for key in keys_gold: 55 | case_recall += len(keys_gold[key]) 56 | for key in keys_pred: 57 | case_precision += len(keys_pred[key]) 58 | 59 | for key in keys_pred: 60 | if key in keys_gold: 61 | for word in keys_pred[key]: 62 | if word in keys_gold[key]: 63 | case_true += 1 64 | keys_gold[key].remove(word) # avoid replicate words 65 | assert case_recall != 0, "no labels in gold files!" 66 | assert case_precision != 0, "no labels in pred files!" 
67 | recall = 1.0 * case_true / case_recall 68 | precision = 1.0 * case_true / case_precision 69 | f1 = 2.0 * recall * precision / (recall + precision) 70 | result = "recall: %s precision: %s F: %s" % (str(recall), str(precision), str(f1)) 71 | return result 72 | # calc_f1('cpbtest1.txt', 'cpbtest_answer.txt') 73 | if __name__ == "__main__": 74 | if len(sys.argv[1:]) != 2: 75 | print('the function takes exactly two parameters: pred_file and gold_file') 76 | else: 77 | if not os.path.exists(sys.argv[1]): 78 | print('pred_file not exists!') 79 | elif not os.path.exists(sys.argv[2]): 80 | print('gold_file not exists!') 81 | else: 82 | print(calc_f1(sys.argv[1], sys.argv[2])) -------------------------------------------------------------------------------- /src/main.py: -------------------------------------------------------------------------------- 1 | # Copyright. All Rights Reserved. 2 | # Author: Wang Ke 3 | # Contact: wangke17[AT]pku.edu.cn 4 | # Discription: 5 | # role label 6 | # 7 | # ============================= 8 | 9 | import numpy as np 10 | import tensorflow as tf 11 | import data_helper 12 | import BILSTM_CRF 13 | import time 14 | import os 15 | 16 | 17 | # Hyper-parameters 18 | def create_hparams(): 19 | timestamp = str(int(time.time())) 20 | return tf.contrib.training.HParams( 21 | # file path 22 | word_dict_file="./data/A_dict.txt", 23 | pos_dict_file="./data/B_dict.txt", 24 | role_dict_file="./data/C_dict.txt", 25 | train_path="./data/train/", 26 | dev_path="./data/dev/", 27 | test_path="./data/test/", 28 | cpbtrain_file="./data/cpbtrain.txt", 29 | cpbdev_file="./data/cpbdev.txt", 30 | cpbtest_file="./data/cpbtest.txt", 31 | a_path='a.txt', 32 | b_path='b.txt', 33 | c_path='c.txt', 34 | a_id_path='a_id.txt', 35 | b_id_path='b_id.txt', 36 | c_id_path='c_id.txt', 37 | timestamp=timestamp, 38 | save_path="./runs/"+timestamp, 39 | 40 | # data params 41 | batch_size=128, 42 | seq_max_len=241, 43 | word_vocab_size=13000, 44 | pos_vocab_size=33, 45 | role_vocab_size=20, 46 | word2id={}, 47 | pos2id={}, 48 | role2id={}, 49 | id2word={}, 50 | id2pos={}, 51 | id2role={}, 52 | max_f1=0.0, 53 | 54 | # model params 55 | dropout_rate=0.5, 56 | hidden_dim=120, 57 | word_emb_dim=100, 58 | pos_emb_dim=19, 59 | learning_rate=0.002, 60 | num_layers=1, 61 | # train params 62 | num_epochs=20000, 63 | # divice 64 | gpu=1, 65 | ) 66 | 67 | 68 | def train(): 69 | # load parameters 70 | hparams = create_hparams() 71 | start_time = time.time() 72 | print("preparing train and dev data") 73 | # load dict message 74 | [hparams.word2id, hparams.pos2id, hparams.role2id, hparams.id2word, hparams.id2pos, hparams.id2role] = data_helper.load_dict(hparams=hparams) 75 | # [word, pos] ==> role 76 | Train_word, Train_pos, Train_role, Dev_word, Dev_pos, Dev_role = data_helper.get_train(hparams=hparams) 77 | 78 | print("building model...") 79 | with tf.Graph().as_default(): 80 | config = tf.ConfigProto(allow_soft_placement=True) 81 | os.environ["CUDA_VISIBLE_DEVICES"] = str(hparams.gpu) 82 | sess = tf.Session(config=config) 83 | with sess.as_default(): 84 | with tf.device("/gpu:" + str(hparams.gpu)): 85 | initializer = tf.random_uniform_initializer(-0.1, 0.1) 86 | with tf.variable_scope("model", reuse=None, initializer=initializer): 87 | model = BILSTM_CRF.bilstm_crf(hparams=hparams) 88 | print("training model...") 89 | sess.run(tf.global_variables_initializer()) 90 | model.train(sess, hparams, Train_word, Train_pos, Train_role, Dev_word, Dev_pos, Dev_role) 91 | print("final best f1 on valid dataset is: %f" % 
hparams.max_f1) 92 | 93 | end_time = time.time() 94 | print("time used %f (hour)" % ((end_time - start_time) / 3600)) 95 | return 96 | 97 | 98 | def eval(): 99 | hparams = create_hparams() 100 | # load dict message 101 | [hparams.word2id, hparams.pos2id, hparams.role2id, hparams.id2word, hparams.id2pos, 102 | hparams.id2role] = data_helper.load_dict(hparams=hparams) 103 | print("Evaluation model...") 104 | os.environ["CUDA_VISIBLE_DEVICES"] = str(hparams.gpu) 105 | name = "1513764280" 106 | checkpoint_dir = os.path.join('runs', name, "checkpoints") 107 | checkpoint_file = tf.train.latest_checkpoint(checkpoint_dir) 108 | graph = tf.Graph() 109 | with graph.as_default(): 110 | config = tf.ConfigProto(allow_soft_placement=True) 111 | sess = tf.Session(config=config) 112 | with sess.as_default(): 113 | initializer = tf.random_uniform_initializer(-0.1, 0.1) 114 | with tf.variable_scope("model", reuse=None, initializer=initializer): 115 | model = BILSTM_CRF.bilstm_crf(hparams=hparams) 116 | # model.saver = tf.train.import_meta_graph("{}.meta".format(checkpoint_file)) 117 | ckpt = tf.train.get_checkpoint_state(checkpoint_dir) 118 | model.saver.restore(sess, ckpt.model_checkpoint_path) 119 | sess.run(tf.tables_initializer()) 120 | print("Load ckpt %s file success!" % checkpoint_file) 121 | type = 'test' 122 | Test_word, Test_pos, Test_role = data_helper.get_test(hparams, type) 123 | model.eval(sess, hparams, Test_word, Test_pos, Test_role, type, name) 124 | return 125 | 126 | 127 | if __name__ == '__main__': 128 | train() 129 | eval() 130 | 131 | -------------------------------------------------------------------------------- /src/data/1.build_dict.py: -------------------------------------------------------------------------------- 1 | import os 2 | import re 3 | 4 | train_file = "./cpbtrain.txt" 5 | dev_file = "./cpbdev.txt" 6 | test_file = "./cpbtest.txt" 7 | 8 | A_dict_file = "./A_dict.txt" 9 | B_dict_file = "./B_dict.txt" 10 | C_dict_file = "./C_dict.txt" 11 | 12 | 13 | def add_word_to_dict(_dict, _word): 14 | if _word in _dict: 15 | _dict[_word] += 1 16 | else: 17 | _dict[_word] = 1 18 | return _dict 19 | 20 | 21 | def sub_word(_a): 22 | ans_a = _a 23 | if _a.find("年") != -1: 24 | if _a.find("一九")!=-1 or _a.find("二0")!=-1 or _a.find("19")!=-1 \ 25 | or _a.find("20")!=-1 or _a.find("一八")!=-1 or _a.find("二零")!=-1\ 26 | or _a.find("18")!=-1: 27 | # print(_a.split()) 28 | return '_YEAR' 29 | if _a.find("百分之")!=-1 or _a.find("%")!=-1: 30 | return '_PERCENT' 31 | 32 | if _a.find("www·")!=-1: 33 | return '_NET' 34 | if _a[-1] == "万" or _a[-1] == "亿": 35 | if _a.find("一")!=-1 or _a.find("二")!=-1 or _a.find("三")!=-1 or _a.find("四")!=-1 or \ 36 | _a.find("五") != -1 or _a.find("六")!=-1 or _a.find("七")!=-1 or\ 37 | _a.find("八")!=-1 or _a.find("九")!=-1 or _a.find("十")!=-1 or \ 38 | _a.find("0")!=-1 or _a.find("1")!=-1 or _a.find("2")!=-1 or \ 39 | _a.find("3") != -1 or _a.find("4")!=-1 or _a.find("5")!=-1 or\ 40 | _a.find("6")!=-1 or _a.find("7")!=-1 or _a.find("8")!=-1 or\ 41 | _a.find("9")!=-1 or _a.find("两")!=-1 or _a.find("百")!=-1: 42 | return '_NUMBER' 43 | if (_a[-1]=="1" or _a[-1]=="2" or _a[-1]=="3" or _a[-1]=="4" or _a[-1]=="5" or 44 | _a[-1] == "6" or _a[-1] == "7" or _a[-1] == "8" or _a[-1] == "9" or _a[-1] == "0") and (_a[0] == "1" or _a[0] == "2" or _a[0] == "3" or _a[0] == "4" or _a[0] == "5" or 45 | _a[0] == "6" or _a[0] == "7" or _a[0] == "8" or _a[0] == "9" or _a[0] == "0"): 46 | return '_NUMBER' 47 | 48 | if _a.find("·")!=-1 and _a != "·": 49 | if not (_a.find("0") != -1 or _a.find("1") 
!= -1 or _a.find("2") != -1 or \ 50 | _a.find("3") != -1 or _a.find("4") != -1 or _a.find("5") != -1 or \ 51 | _a.find("6") != -1 or _a.find("7") != -1 or _a.find("8") != -1 or \ 52 | _a.find("9") != -1): 53 | return '_NAME' 54 | 55 | if _a.find("月")!=-1: 56 | if _a.find("一") != -1 or _a.find("二") != -1 or _a.find("三") != -1 or _a.find("四") != -1 or \ 57 | _a.find("五") != -1 or _a.find("六") != -1 or _a.find("七") != -1 or \ 58 | _a.find("八") != -1 or _a.find("九") != -1 or _a.find("十") != -1 or \ 59 | _a.find("0") != -1 or _a.find("1") != -1 or _a.find("2") != -1 or \ 60 | _a.find("3") != -1 or _a.find("4") != -1 or _a.find("5") != -1 or \ 61 | _a.find("6") != -1 or _a.find("7") != -1 or _a.find("8") != -1 or \ 62 | _a.find("9") != -1: 63 | return '_MONTH' 64 | 65 | if _a[-1]=="分": 66 | if ((_a.find('时')!=-1)or(_a.find('点')!=-1)): 67 | if _a.find("一") != -1 or _a.find("二") != -1 or _a.find("三") != -1 or _a.find("四") != -1 or \ 68 | _a.find("五") != -1 or _a.find("六") != -1 or _a.find("七") != -1 or \ 69 | _a.find("八") != -1 or _a.find("九") != -1 or _a.find("十") != -1 or \ 70 | _a.find("0") != -1 or _a.find("1") != -1 or _a.find("2") != -1 or \ 71 | _a.find("3") != -1 or _a.find("4") != -1 or _a.find("5") != -1 or \ 72 | _a.find("6") != -1 or _a.find("7") != -1 or _a.find("8") != -1 or \ 73 | _a.find("9") != -1 or _a.find("两") != -1 or _a.find("百") != -1: 74 | return '_TIME' 75 | if _a[-1] == "日": 76 | if _a.find("0") != -1 or _a.find("1") != -1 or _a.find("2") != -1 or \ 77 | _a.find("3") != -1 or _a.find("4") != -1 or _a.find("5") != -1 or \ 78 | _a.find("6") != -1 or _a.find("7") != -1 or _a.find("8") != -1 or \ 79 | _a.find("9") != -1: 80 | return '_DAY' 81 | else: 82 | if _a.find("一") != -1 or _a.find("二") != -1 or _a.find("三") != -1 or _a.find("四") != -1 or \ 83 | _a.find("五") != -1 or _a.find("六") != -1 or _a.find("七") != -1 or \ 84 | _a.find("八") != -1 or _a.find("九") != -1 or _a.find("十") != -1: 85 | return '_DAY' 86 | # pass # pass # print(_a.split()) 87 | 88 | p1=re.compile('^[零一二三四五六七八九十0123456789百千万亿多]*$') 89 | number = p1.match(_a) 90 | if number and _a != '百' and _a != '万' and _a != '亿' and _a != '多': 91 | return '_NUMBER' 92 | 93 | if _a.find('点')!=-1: 94 | if _a.find("一") != -1 or _a.find("二") != -1 or _a.find("三") != -1 or _a.find("四") != -1 or \ 95 | _a.find("五") != -1 or _a.find("六") != -1 or _a.find("七") != -1 or \ 96 | _a.find("八") != -1 or _a.find("九") != -1 or _a.find("十") != -1: 97 | return '_NUMBER' 98 | 99 | return ans_a 100 | 101 | 102 | def add_file_to_dict(_file, _dict_a, _dict_b, _dict_c): 103 | with open(_file, 'r') as f_i: 104 | for _line in f_i: 105 | _line_items = _line.strip().split(' ') 106 | for _item in _line_items: 107 | if _item == "": 108 | continue 109 | _item_list = _item.split('/') 110 | flag = len(_item_list) 111 | _c = None 112 | if flag == 3: 113 | [_a, _b, _c] = _item_list 114 | _c = _c[_c.find('-')+1:] 115 | _dict_c = add_word_to_dict(_dict_c, _c) 116 | # print(_c) 117 | if flag == 2: 118 | [_a, _b] = _item.split('/') 119 | _a = sub_word(_a) 120 | _dict_a = add_word_to_dict(_dict_a, _a) 121 | _dict_b = add_word_to_dict(_dict_b, _b) 122 | return _dict_a, _dict_b, _dict_c 123 | 124 | 125 | def write_dict_in_file(_dict, _file, flag, max_num=None): 126 | word_list = sorted(_dict.items(), key=lambda d: d[1], reverse=True) 127 | with open(_file, 'w') as f_i: 128 | i = 0 129 | if flag == "A": 130 | f_i.write('_PAD\n') 131 | f_i.write('_UNK\n') 132 | i += 2 133 | if flag == "B": 134 | f_i.write('_PAD\n') 135 | i += 1 136 | if flag == "C": 137 | 
f_i.write('_PAD\n') 138 | i += 1 139 | 140 | for _item in word_list: 141 | if i < max_num: 142 | # f_i.write("%s:%d\n" % (_item[0], _item[1])) 143 | f_i.write("%s\n" % (_item[0])) 144 | i += 1 145 | else: 146 | break 147 | return 148 | 149 | 150 | def main(): 151 | A_dict = {} 152 | B_dict = {} 153 | C_dict = {} 154 | 155 | A_dict, B_dict, C_dict = add_file_to_dict(train_file, A_dict, B_dict, C_dict) 156 | A_dict, B_dict, C_dict = add_file_to_dict(dev_file, A_dict, B_dict, C_dict) 157 | A_dict, B_dict, C_dict = add_file_to_dict(test_file, A_dict, B_dict, C_dict) 158 | print("size: Word %d Pos %d Role %d"%(len(A_dict), len(B_dict), len(C_dict))) 159 | write_dict_in_file(A_dict, A_dict_file, flag="A", max_num=13000) 160 | write_dict_in_file(B_dict, B_dict_file, flag="B", max_num=10000) 161 | write_dict_in_file(C_dict, C_dict_file, flag="C", max_num=10000) 162 | 163 | 164 | main() 165 | -------------------------------------------------------------------------------- /src/data/2.build_file.py: -------------------------------------------------------------------------------- 1 | import os 2 | import re 3 | 4 | train_file = "./cpbtrain.txt" 5 | dev_file = "./cpbdev.txt" 6 | test_file = "./cpbtest.txt" 7 | 8 | A_dict_file = "./A_dict.txt" 9 | B_dict_file = "./B_dict.txt" 10 | C_dict_file = "./C_dict.txt" 11 | 12 | 13 | train_path = "./train/" 14 | dev_path = "./dev/" 15 | test_path = "./test/" 16 | 17 | a_path = 'a.txt' 18 | b_path = 'b.txt' 19 | c_path = 'c.txt' 20 | 21 | a_id_path = 'a_id.txt' 22 | b_id_path = 'b_id.txt' 23 | c_id_path = 'c_id.txt' 24 | 25 | max_seq_len = 0 26 | 27 | def load_dict(_file): 28 | _dict = {} 29 | i = 0 30 | with open(_file, 'r') as f: 31 | for item in f: 32 | item = item.strip() 33 | if item != "": 34 | _dict[item] = i 35 | i += 1 36 | return _dict 37 | 38 | 39 | def sub_word(_a): 40 | ans_a = _a 41 | if _a.find("年") != -1: 42 | if _a.find("一九")!=-1 or _a.find("二0")!=-1 or _a.find("19")!=-1 \ 43 | or _a.find("20")!=-1 or _a.find("一八")!=-1 or _a.find("二零")!=-1\ 44 | or _a.find("18")!=-1: 45 | # print(_a.split()) 46 | return '_YEAR' 47 | if _a.find("百分之")!=-1 or _a.find("%")!=-1: 48 | return '_PERCENT' 49 | 50 | if _a.find("www·")!=-1: 51 | return '_NET' 52 | if _a[-1] == "万" or _a[-1] == "亿": 53 | if _a.find("一")!=-1 or _a.find("二")!=-1 or _a.find("三")!=-1 or _a.find("四")!=-1 or \ 54 | _a.find("五") != -1 or _a.find("六")!=-1 or _a.find("七")!=-1 or\ 55 | _a.find("八")!=-1 or _a.find("九")!=-1 or _a.find("十")!=-1 or \ 56 | _a.find("0")!=-1 or _a.find("1")!=-1 or _a.find("2")!=-1 or \ 57 | _a.find("3") != -1 or _a.find("4")!=-1 or _a.find("5")!=-1 or\ 58 | _a.find("6")!=-1 or _a.find("7")!=-1 or _a.find("8")!=-1 or\ 59 | _a.find("9")!=-1 or _a.find("两")!=-1 or _a.find("百")!=-1: 60 | return '_NUMBER' 61 | if (_a[-1]=="1" or _a[-1]=="2" or _a[-1]=="3" or _a[-1]=="4" or _a[-1]=="5" or 62 | _a[-1] == "6" or _a[-1] == "7" or _a[-1] == "8" or _a[-1] == "9" or _a[-1] == "0") and (_a[0] == "1" or _a[0] == "2" or _a[0] == "3" or _a[0] == "4" or _a[0] == "5" or 63 | _a[0] == "6" or _a[0] == "7" or _a[0] == "8" or _a[0] == "9" or _a[0] == "0"): 64 | return '_NUMBER' 65 | 66 | if _a.find("·")!=-1 and _a != "·": 67 | if not (_a.find("0") != -1 or _a.find("1") != -1 or _a.find("2") != -1 or \ 68 | _a.find("3") != -1 or _a.find("4") != -1 or _a.find("5") != -1 or \ 69 | _a.find("6") != -1 or _a.find("7") != -1 or _a.find("8") != -1 or \ 70 | _a.find("9") != -1): 71 | return '_NAME' 72 | 73 | if _a.find("月")!=-1: 74 | if _a.find("一") != -1 or _a.find("二") != -1 or _a.find("三") != -1 or 
_a.find("四") != -1 or \ 75 | _a.find("五") != -1 or _a.find("六") != -1 or _a.find("七") != -1 or \ 76 | _a.find("八") != -1 or _a.find("九") != -1 or _a.find("十") != -1 or \ 77 | _a.find("0") != -1 or _a.find("1") != -1 or _a.find("2") != -1 or \ 78 | _a.find("3") != -1 or _a.find("4") != -1 or _a.find("5") != -1 or \ 79 | _a.find("6") != -1 or _a.find("7") != -1 or _a.find("8") != -1 or \ 80 | _a.find("9") != -1: 81 | return '_MONTH' 82 | 83 | if _a[-1]=="分": 84 | if ((_a.find('时')!=-1)or(_a.find('点')!=-1)): 85 | if _a.find("一") != -1 or _a.find("二") != -1 or _a.find("三") != -1 or _a.find("四") != -1 or \ 86 | _a.find("五") != -1 or _a.find("六") != -1 or _a.find("七") != -1 or \ 87 | _a.find("八") != -1 or _a.find("九") != -1 or _a.find("十") != -1 or \ 88 | _a.find("0") != -1 or _a.find("1") != -1 or _a.find("2") != -1 or \ 89 | _a.find("3") != -1 or _a.find("4") != -1 or _a.find("5") != -1 or \ 90 | _a.find("6") != -1 or _a.find("7") != -1 or _a.find("8") != -1 or \ 91 | _a.find("9") != -1 or _a.find("两") != -1 or _a.find("百") != -1: 92 | return '_TIME' 93 | if _a[-1] == "日": 94 | if _a.find("0") != -1 or _a.find("1") != -1 or _a.find("2") != -1 or \ 95 | _a.find("3") != -1 or _a.find("4") != -1 or _a.find("5") != -1 or \ 96 | _a.find("6") != -1 or _a.find("7") != -1 or _a.find("8") != -1 or \ 97 | _a.find("9") != -1: 98 | return '_DAY' 99 | else: 100 | if _a.find("一") != -1 or _a.find("二") != -1 or _a.find("三") != -1 or _a.find("四") != -1 or \ 101 | _a.find("五") != -1 or _a.find("六") != -1 or _a.find("七") != -1 or \ 102 | _a.find("八") != -1 or _a.find("九") != -1 or _a.find("十") != -1: 103 | return '_DAY' 104 | # pass # pass # print(_a.split()) 105 | 106 | p1=re.compile('^[零一二三四五六七八九十0123456789百千万亿多]*$') 107 | number = p1.match(_a) 108 | if number and _a != '百' and _a != '万' and _a != '亿' and _a != '多': 109 | return '_NUMBER' 110 | 111 | if _a.find('点')!=-1: 112 | if _a.find("一") != -1 or _a.find("二") != -1 or _a.find("三") != -1 or _a.find("四") != -1 or \ 113 | _a.find("五") != -1 or _a.find("六") != -1 or _a.find("七") != -1 or \ 114 | _a.find("八") != -1 or _a.find("九") != -1 or _a.find("十") != -1: 115 | return '_NUMBER' 116 | 117 | return ans_a 118 | 119 | 120 | def build_data(): 121 | # train 122 | A_lists = [] 123 | B_lists = [] 124 | C_lists = [] 125 | with open(train_file, 'r') as f_i: 126 | for _line in f_i: 127 | _line_items = _line.strip().split(' ') 128 | A_list = [] 129 | B_list = [] 130 | C_list = [] 131 | for _item in _line_items: 132 | if _item == "": 133 | continue 134 | [_a, _b, _c] = _item.split('/') 135 | 136 | _c = _c[_c.find('-')+1:] 137 | 138 | A_list.append(_a) 139 | B_list.append(_b) 140 | C_list.append(_c) 141 | global max_seq_len 142 | max_seq_len = max(max_seq_len, len(A_list)) 143 | 144 | A_lists.append(' '.join(A_list) + '\n') 145 | B_lists.append(' '.join(B_list) + '\n') 146 | C_lists.append(' '.join(C_list) + '\n') 147 | 148 | 149 | path = train_path 150 | with open(path + a_path, 'w') as f: 151 | f.writelines(A_lists) 152 | with open(path + b_path, 'w') as f: 153 | f.writelines(B_lists) 154 | with open(path + c_path, 'w') as f: 155 | f.writelines(C_lists) 156 | 157 | # dev 158 | A_lists = [] 159 | B_lists = [] 160 | C_lists = [] 161 | with open(dev_file, 'r') as f_i: 162 | for _line in f_i: 163 | _line_items = _line.strip().split(' ') 164 | A_list = [] 165 | B_list = [] 166 | C_list = [] 167 | for _item in _line_items: 168 | if _item == "": 169 | continue 170 | [_a, _b, _c] = _item.split('/') 171 | 172 | _c = _c[_c.find('-') + 1:] 173 | 174 | A_list.append(_a) 175 | 
B_list.append(_b) 176 | C_list.append(_c) 177 | 178 | A_lists.append(' '.join(A_list) + '\n') 179 | B_lists.append(' '.join(B_list) + '\n') 180 | C_lists.append(' '.join(C_list) + '\n') 181 | 182 | path = dev_path 183 | with open(path + a_path, 'w') as f: 184 | f.writelines(A_lists) 185 | with open(path + b_path, 'w') as f: 186 | f.writelines(B_lists) 187 | with open(path + c_path, 'w') as f: 188 | f.writelines(C_lists) 189 | 190 | 191 | # test 192 | A_lists = [] 193 | B_lists = [] 194 | with open(test_file, 'r') as f_i: 195 | for _line in f_i: 196 | _line_items = _line.strip().split(' ') 197 | A_list = [] 198 | B_list = [] 199 | for _item in _line_items: 200 | if _item == "": 201 | continue 202 | _item_list = _item.split('/') 203 | _a = _item_list[0] 204 | _b = _item_list[1] 205 | 206 | A_list.append(_a) 207 | B_list.append(_b) 208 | 209 | A_lists.append(' '.join(A_list) + '\n') 210 | B_lists.append(' '.join(B_list) + '\n') 211 | 212 | path = test_path 213 | with open(path + a_path, 'w') as f: 214 | f.writelines(A_lists) 215 | with open(path + b_path, 'w') as f: 216 | f.writelines(B_lists) 217 | 218 | 219 | def build_id_data(A, B, C): 220 | # train 221 | A_lists = [] 222 | B_lists = [] 223 | C_lists = [] 224 | with open(train_file, 'r') as f_i: 225 | for _line in f_i: 226 | _line_items = _line.strip().split(' ') 227 | A_list = [] 228 | B_list = [] 229 | C_list = [] 230 | for _item in _line_items: 231 | if _item == "": 232 | continue 233 | [_a, _b, _c] = _item.split('/') 234 | _a = sub_word(_a) 235 | if _a not in A: 236 | _a = '_UNK' 237 | _c = _c[_c.find('-')+1:] 238 | 239 | A_list.append(str(A[_a])) 240 | B_list.append(str(B[_b])) 241 | C_list.append(str(C[_c])) 242 | 243 | A_lists.append(' '.join(A_list) + '\n') 244 | B_lists.append(' '.join(B_list) + '\n') 245 | C_lists.append(' '.join(C_list) + '\n') 246 | 247 | path = train_path 248 | with open(path + a_id_path, 'w') as f: 249 | f.writelines(A_lists) 250 | with open(path + b_id_path, 'w') as f: 251 | f.writelines(B_lists) 252 | with open(path + c_id_path, 'w') as f: 253 | f.writelines(C_lists) 254 | 255 | # dev 256 | A_lists = [] 257 | B_lists = [] 258 | C_lists = [] 259 | with open(dev_file, 'r') as f_i: 260 | for _line in f_i: 261 | _line_items = _line.strip().split(' ') 262 | A_list = [] 263 | B_list = [] 264 | C_list = [] 265 | for _item in _line_items: 266 | if _item == "": 267 | continue 268 | [_a, _b, _c] = _item.split('/') 269 | _a = sub_word(_a) 270 | if _a not in A: 271 | _a = '_UNK' 272 | _c = _c[_c.find('-') + 1:] 273 | 274 | A_list.append(str(A[_a])) 275 | B_list.append(str(B[_b])) 276 | C_list.append(str(C[_c])) 277 | 278 | A_lists.append(' '.join(A_list) + '\n') 279 | B_lists.append(' '.join(B_list) + '\n') 280 | C_lists.append(' '.join(C_list) + '\n') 281 | 282 | path = dev_path 283 | with open(path + a_id_path, 'w') as f: 284 | f.writelines(A_lists) 285 | with open(path + b_id_path, 'w') as f: 286 | f.writelines(B_lists) 287 | with open(path + c_id_path, 'w') as f: 288 | f.writelines(C_lists) 289 | 290 | 291 | # test 292 | A_lists = [] 293 | B_lists = [] 294 | C_lists = [] 295 | with open(test_file, 'r') as f_i: 296 | for _line in f_i: 297 | _line_items = _line.strip().split(' ') 298 | A_list = [] 299 | B_list = [] 300 | C_list = [] 301 | for _item in _line_items: 302 | if _item == "": 303 | continue 304 | _item_list = _item.split('/') 305 | _a = _item_list[0] 306 | _b = _item_list[1] 307 | _c = 0 308 | _a = sub_word(_a) 309 | if _a not in A: 310 | _a = '_UNK' 311 | if len(_item_list) == 3: 312 | _c = 
C[_item_list[2]]
313 |                 A_list.append(str(A[_a]))
314 |                 B_list.append(str(B[_b]))
315 |                 C_list.append(str(_c))
316 | 
317 |             A_lists.append(' '.join(A_list) + '\n')
318 |             B_lists.append(' '.join(B_list) + '\n')
319 |             C_lists.append(' '.join(C_list) + '\n')
320 | 
321 |     path = test_path
322 |     with open(path + a_id_path, 'w') as f:
323 |         f.writelines(A_lists)
324 |     with open(path + b_id_path, 'w') as f:
325 |         f.writelines(B_lists)
326 |     with open(path + c_id_path, 'w') as f:
327 |         f.writelines(C_lists)
328 | 
329 | 
330 | def main():
331 |     A_dict = load_dict(A_dict_file)
332 |     B_dict = load_dict(B_dict_file)
333 |     C_dict = load_dict(C_dict_file)
334 | 
335 |     build_data()
336 | 
337 |     build_id_data(A_dict, B_dict, C_dict)
338 | 
339 |     print("max seq len = %d" % max_seq_len)
340 | 
341 | main()
342 | 
--------------------------------------------------------------------------------
/src/BILSTM_CRF.py:
--------------------------------------------------------------------------------
 1 | import math
 2 | import data_helper
 3 | import numpy as np
 4 | import tensorflow as tf
 5 | import time
 6 | import shutil
 7 | import os
 8 | 
 9 | 
10 | class bilstm_crf(object):
11 | 
12 |     def __init__(self, hparams, is_training=True):
13 |         # Parameter
14 |         self.num_layers = hparams.num_layers
15 |         self.learning_rate = hparams.learning_rate
16 |         self.hidden_dim = hparams.hidden_dim  # hidden_dim=120
17 |         self.word_emb_dim = hparams.word_emb_dim  # word_emb_dim=100
18 |         self.pos_emb_dim = hparams.pos_emb_dim  # pos_emb_dim=19
19 |         self.dropout_rate = hparams.dropout_rate  # 0.5
20 |         self.word_vocab_size = hparams.word_vocab_size  # 13000,
21 |         self.pos_vocab_size = hparams.pos_vocab_size  # 33,
22 |         self.num_classes = hparams.role_vocab_size  # 20,
23 | 
24 |         # placeholder of word\pos\role
25 |         self.inputs_word = tf.placeholder(tf.int32, shape=[None, None], name="inputs_word")
26 |         self.inputs_pos = tf.placeholder(tf.int32, shape=[None, None], name="inputs_pos")
27 |         self.predicts_role = tf.placeholder(tf.int32, shape=[None, None], name="predicts_role")
28 |         self.sequence_lengths = tf.placeholder(tf.int32, shape=[None], name="sequence_lengths")
29 |         self.rel_vector = tf.placeholder(tf.float32, shape=[None, None, 1], name="rel_vector")
30 | 
31 |         with tf.variable_scope("input-embedding"):
32 |             self.word_embedding = tf.get_variable("emb-word", [self.word_vocab_size, self.word_emb_dim])
33 |             self.pos_embedding = tf.get_variable("emb-pos", [self.pos_vocab_size, self.pos_emb_dim])
34 |             self.inputs_emb_word = tf.nn.embedding_lookup(self.word_embedding, self.inputs_word)
35 |             self.inputs_emb_pos = tf.nn.embedding_lookup(self.pos_embedding, self.inputs_pos)
36 | 
37 |         with tf.variable_scope("concat"):
38 |             self.inputs_emb = tf.concat([self.inputs_emb_word, self.inputs_emb_pos, self.rel_vector], axis=2)
39 | 
40 |         with tf.variable_scope("bi-lstm"):
41 |             # lstm cell
42 |             lstm_cell_fw = tf.nn.rnn_cell.BasicLSTMCell(self.hidden_dim)
43 |             lstm_cell_bw = tf.nn.rnn_cell.BasicLSTMCell(self.hidden_dim)
44 | 
45 |             # dropout
46 |             if is_training:
47 |                 lstm_cell_fw = tf.nn.rnn_cell.DropoutWrapper(lstm_cell_fw, output_keep_prob=(1 - self.dropout_rate))
48 |                 lstm_cell_bw = tf.nn.rnn_cell.DropoutWrapper(lstm_cell_bw, output_keep_prob=(1 - self.dropout_rate))
49 | 
50 |             lstm_cell_fw = tf.nn.rnn_cell.MultiRNNCell([lstm_cell_fw] * self.num_layers)
51 |             lstm_cell_bw = tf.nn.rnn_cell.MultiRNNCell([lstm_cell_bw] * self.num_layers)
52 | 
53 |             # forward and backward
54 |             (output_fw, output_bw), _ = tf.nn.bidirectional_dynamic_rnn(
55 |                 cell_fw=lstm_cell_fw,
cell_bw=lstm_cell_bw, 57 | inputs=self.inputs_emb, 58 | sequence_length=self.sequence_lengths, 59 | dtype=tf.float32, 60 | ) 61 | self.output = tf.concat([output_fw, output_bw], axis=-1) 62 | 63 | # project 64 | with tf.variable_scope("project"): 65 | W = tf.get_variable("W", shape=[self.hidden_dim * 2, self.num_classes], 66 | dtype=tf.float32) 67 | b = tf.get_variable("b", shape=[self.num_classes], dtype=tf.float32, 68 | initializer=tf.zeros_initializer()) 69 | nsteps = tf.shape(self.output)[1] 70 | output = tf.reshape(self.output, [-1, 2 * self.hidden_dim]) 71 | pred = tf.matmul(output, W) + b 72 | self.logits = tf.reshape(pred, [-1, nsteps, self.num_classes]) 73 | 74 | log_likelihood, trans_params = tf.contrib.crf.crf_log_likelihood( 75 | self.logits, self.predicts_role, self.sequence_lengths) 76 | self.trans_params = trans_params # need to evaluate it for decoding 77 | self.loss = tf.reduce_mean(-log_likelihood) 78 | 79 | # for tensorboard 80 | self.train_summary = tf.summary.scalar("loss", self.loss) 81 | self.optimizer = tf.train.AdamOptimizer(learning_rate=self.learning_rate).minimize(self.loss) 82 | 83 | self.saver = tf.train.Saver(tf.global_variables()) 84 | 85 | def train(self, sess, hparams, Train_word, Train_pos, Train_role, Dev_word, Dev_pos, Dev_role): 86 | Test_word, Test_pos, Test_role = data_helper.get_test(hparams, type='test') 87 | checkpoint_dir = hparams.save_path + "/checkpoints" 88 | checkpoint_prefix = checkpoint_dir + "/model" 89 | if not os.path.exists(checkpoint_dir): 90 | os.makedirs(checkpoint_dir) 91 | 92 | merged = tf.summary.merge_all() 93 | summary_writer_train = tf.summary.FileWriter(hparams.save_path + '/train_loss', sess.graph) 94 | 95 | num_iterations = int(math.ceil(1.0 * len(Train_word) / hparams.batch_size)) 96 | 97 | cnt = 0 98 | for epoch in range(hparams.num_epochs): 99 | print("current epoch: %d" % (epoch)) 100 | 101 | for iteration in range(num_iterations): 102 | # train 103 | X_word_train_batch, X_pos_train_batch, y_role_train_batch = data_helper.next_batch(Train_word, Train_pos, Train_role, 104 | start_index=iteration * hparams.batch_size, 105 | batch_size=hparams.batch_size) 106 | X_rel_train_batch = self.get_one_hot_rel(y_role_train_batch, hparams.role2id['rel']) 107 | 108 | X_train_sequence_lengths = data_helper.get_length_by_vec(X_word_train_batch) 109 | _, loss_train, logits, train_summary = \ 110 | sess.run([ 111 | self.optimizer, 112 | self.loss, 113 | self.logits, 114 | self.train_summary 115 | ], 116 | feed_dict={ 117 | self.inputs_word: X_word_train_batch, 118 | self.inputs_pos: X_pos_train_batch, 119 | self.rel_vector: X_rel_train_batch, 120 | self.sequence_lengths: X_train_sequence_lengths, 121 | self.predicts_role: y_role_train_batch, 122 | }) 123 | 124 | if iteration % 10 == 0: 125 | cnt += 1 126 | feed_dict = { 127 | self.inputs_word: X_word_train_batch, 128 | self.inputs_pos: X_pos_train_batch, 129 | self.rel_vector: X_rel_train_batch, 130 | self.sequence_lengths: X_train_sequence_lengths, 131 | # self.predicts_role: y_role_train_batch, 132 | } 133 | predicts_train = self.predict(sess, feed_dict, X_train_sequence_lengths) 134 | precision_train, recall_train, f1_train = self.evaluate(X_train_sequence_lengths, X_word_train_batch, X_pos_train_batch, y_role_train_batch, predicts_train, hparams.id2word, hparams.id2pos, hparams.id2role) 135 | summary_writer_train.add_summary(train_summary, cnt) 136 | print("iteration: %3d, train loss: %5f, train precision: %.5f, train recall: %.5f, train f1: %.5f" % (iteration, loss_train, 
precision_train, recall_train, f1_train)) 137 | 138 | # validation 139 | if iteration % 100 == 0 and f1_train > 0.6: 140 | self.eval(sess, hparams, Dev_word, Dev_pos, Dev_role, eval_type='dev', name=hparams.timestamp) 141 | precision_dev, recall_dev, f1_dev = data_helper.calc_f1(hparams.cpbdev_file, hparams.save_path + "/eval_dev.txt") 142 | print( 143 | "iteration: %3d, valid precision: %.5f, valid recall: %.5f, valid f1: %.5f" % ( 144 | iteration, precision_dev, recall_dev, f1_dev)) 145 | 146 | if f1_dev >= hparams.max_f1: 147 | hparams.max_f1 = f1_dev 148 | save_name = self.saver.save(sess, checkpoint_prefix, global_step=cnt) 149 | shutil.copyfile(hparams.save_path + "/eval_dev.txt", hparams.save_path + "/best_eval_dev.txt") 150 | 151 | self.eval(sess, hparams, Test_word, Test_pos, Test_role, eval_type='test', name=hparams.timestamp) 152 | 153 | str_out = "saved the best model with f1: %.5f save path:%s" % (hparams.max_f1, save_name) 154 | print(str_out) 155 | data_helper.log(str_out, hparams.save_path) 156 | 157 | def predict(self, sess, fd, sequence_lengths): 158 | # get tag scores and transition params of CRF 159 | viterbi_sequences = [] 160 | logits, trans_params = sess.run( 161 | [self.logits, self.trans_params], feed_dict=fd 162 | ) 163 | 164 | # iterate over the sentences because no batching in vitervi_decode 165 | for logit, sequence_length in zip(logits, sequence_lengths): 166 | logit = logit[:sequence_length] # keep only the valid steps 167 | viterbi_seq, viterbi_score = tf.contrib.crf.viterbi_decode( 168 | logit, trans_params) 169 | viterbi_sequences += [viterbi_seq] 170 | return viterbi_sequences 171 | 172 | def evaluate(self, lengths, X_word, X_pos, y_true, y_pred, id2word, id2pos, id2role): 173 | 174 | case_true, case_recall, case_precision = 0, 0, 0 175 | 176 | x_word_id = data_helper.unpadding(X_word) 177 | x_pos_id = data_helper.unpadding(X_pos) 178 | y_true_id = data_helper.unpadding(y_true) 179 | y_pred_id = y_pred 180 | 181 | for i in range(len(lengths)): 182 | 183 | x_word = [id2word[val] for val in x_word_id[i]] 184 | x_pos = [id2pos[val] for val in x_pos_id[i]] 185 | y = [id2role[val] for val in y_true_id[i]] 186 | y_hat = [id2role[val] for val in y_pred_id[i]] 187 | 188 | true_labels = data_helper.extract_entity(x_word, y) 189 | pred_labels = data_helper.extract_entity(x_word, y_hat) 190 | 191 | for key in true_labels: 192 | case_recall += len(true_labels[key]) 193 | for key in pred_labels: 194 | case_precision += len(pred_labels[key]) 195 | 196 | for key in pred_labels: 197 | if key in true_labels: 198 | for word in pred_labels[key]: 199 | if word in true_labels[key]: 200 | case_true += 1 201 | true_labels[key].remove(word) # avoid replicate words 202 | recall = -1.0 203 | precision = -1.0 204 | f1 = -1.0 205 | if case_recall != 0: 206 | recall = 1.0 * case_true / case_recall 207 | if case_precision != 0: 208 | precision = 1.0 * case_true / case_precision 209 | if recall > 0 and precision > 0: 210 | f1 = 2.0 * recall * precision / (recall + precision) 211 | return precision, recall, f1 212 | 213 | def reconstruct(self, lens, roles, id2role): 214 | ans_seq = [] 215 | for i in range(lens): 216 | role_list = [id2role[val] for val in roles[i]] 217 | role_list = data_helper.recover_role(role_list) 218 | ans_seq.append(role_list) 219 | return ans_seq 220 | 221 | def get_one_hot_rel(self, vec, ref_id): 222 | ans = np.zeros(shape=[len(vec), len(vec[0]), 1], dtype=float) 223 | for i in range(len(vec)): 224 | j = np.where(vec[i] == ref_id) 225 | ans[i][j] = [1.0] 226 | 
return ans 227 | 228 | def eval(self, sess, hparams, Test_word, Test_pos, Test_role, eval_type, name): 229 | num_iterations = int(math.ceil(1.0 * len(Test_word) / hparams.batch_size)) 230 | outputs_role = [] 231 | for iteration in range(num_iterations): 232 | X_word_test_batch, X_pos_test_batch, y_role_test_batch, full_size = data_helper.next_test_batch(Test_word, Test_pos, 233 | Test_role, 234 | start_index=iteration * hparams.batch_size, 235 | batch_size=hparams.batch_size) 236 | X_rel_test_batch = self.get_one_hot_rel(y_role_test_batch, hparams.role2id['rel']) 237 | X_test_sequence_lengths = data_helper.get_length_by_vec(X_word_test_batch) 238 | 239 | feed_dict = { 240 | self.inputs_word: X_word_test_batch, 241 | self.inputs_pos: X_pos_test_batch, 242 | self.rel_vector: X_rel_test_batch, 243 | self.sequence_lengths: X_test_sequence_lengths, 244 | } 245 | predicts_dev = self.predict(sess, feed_dict, X_test_sequence_lengths) 246 | 247 | outputs_role += self.reconstruct(full_size, predicts_dev, hparams.id2role) 248 | 249 | if eval_type == 'dev': 250 | eval_file = hparams.cpbdev_file 251 | if eval_type == 'test': 252 | eval_file = hparams.cpbtest_file 253 | 254 | outputs = data_helper.recover_eval(eval_file, outputs_role) 255 | 256 | save_path = "./runs/%s/eval_%s.txt" % (name, eval_type) 257 | with open(save_path, 'w') as f: 258 | f.writelines(outputs) 259 | f.write('\n') # for consistance with cpttest.txt 260 | print("eval success!, size: %d save at %s" % (len(outputs), save_path)) 261 | return 262 | 263 | -------------------------------------------------------------------------------- /src/data_helper.py: -------------------------------------------------------------------------------- 1 | import re 2 | import os 3 | import csv 4 | import pandas as pd 5 | import numpy as np 6 | 7 | 8 | def build_dict(dict_file): 9 | line_id = 0 10 | token2id = {} 11 | id2token = {} 12 | with open(dict_file) as infile: 13 | for row in infile: 14 | token = row.strip() 15 | if token == "": 16 | break 17 | token_id = line_id 18 | token2id[token] = token_id 19 | id2token[token_id] = token 20 | line_id += 1 21 | return token2id, id2token 22 | 23 | 24 | def load_dict(hparams): 25 | # word dict 26 | word2id, id2word = build_dict(hparams.word_dict_file) 27 | pos2id, id2pos = build_dict(hparams.pos_dict_file) 28 | role2id, id2role = build_dict(hparams.role_dict_file) 29 | return [word2id, pos2id, role2id, id2word, id2pos, id2role] 30 | 31 | 32 | def get_test(hparams, type): 33 | Test_word = [] 34 | Test_pos = [] 35 | Test_role = [] 36 | if type == 'test': 37 | path = hparams.test_path 38 | if type == 'dev': 39 | path = hparams.dev_path 40 | with open(path + hparams.a_id_path) as f_in: # word 41 | for row_item in f_in: 42 | _list = row_item.strip().split(' ') 43 | uu = [int(tmp) for tmp in _list if tmp != ''] 44 | if len(uu) == 0: 45 | continue 46 | Test_word.append(uu) 47 | with open(path + hparams.b_id_path) as f_in: # pos 48 | for row_item in f_in: 49 | _list = row_item.strip().split(' ') 50 | uu = [int(tmp) for tmp in _list if tmp != ''] 51 | if len(uu) == 0: 52 | continue 53 | Test_pos.append(uu) 54 | with open(path + hparams.c_id_path) as f_in: # role 55 | for row_item in f_in: 56 | _list = row_item.strip().split(' ') 57 | uu = [int(tmp) for tmp in _list if tmp != ''] 58 | if len(uu) == 0: 59 | continue 60 | Test_role.append(uu) 61 | assert len(Test_word) == len(Test_pos) 62 | assert len(Test_pos) == len(Test_role) 63 | print("Load %s size: %d" % (type, len(Test_word))) 64 | Test_word = 
np.array(padding(Test_word, seq_max_len=hparams.seq_max_len)) 65 | Test_pos = np.array(padding(Test_pos, seq_max_len=hparams.seq_max_len)) 66 | Test_role = np.array(padding(Test_role, seq_max_len=hparams.seq_max_len)) 67 | return Test_word, Test_pos, Test_role 68 | 69 | 70 | def get_train(hparams): 71 | # load train file 72 | Train_word = [] 73 | Train_pos = [] 74 | Train_role = [] 75 | path = hparams.train_path 76 | with open(path + hparams.a_id_path) as f_in: # word 77 | for row_item in f_in: 78 | _list = row_item.strip().split(' ') 79 | uu = [int(tmp) for tmp in _list if tmp != ''] 80 | if len(uu) == 0: 81 | continue 82 | Train_word.append(uu) 83 | with open(path + hparams.b_id_path) as f_in: # pos 84 | for row_item in f_in: 85 | _list = row_item.strip().split(' ') 86 | uu = [int(tmp) for tmp in _list if tmp != ''] 87 | if len(uu) == 0: 88 | continue 89 | Train_pos.append(uu) 90 | with open(path + hparams.c_id_path) as f_in: # role 91 | for row_item in f_in: 92 | _list = row_item.strip().split(' ') 93 | uu = [int(tmp) for tmp in _list if tmp != ''] 94 | if len(uu) == 0: 95 | continue 96 | Train_role.append(uu) 97 | 98 | # load dev file 99 | Dev_word = [] 100 | Dev_pos = [] 101 | Dev_role = [] 102 | path = hparams.dev_path 103 | with open(path + hparams.a_id_path) as f_in: # word 104 | for row_item in f_in: 105 | _list = row_item.strip().split(' ') 106 | uu = [int(tmp) for tmp in _list if tmp != ''] 107 | if len(uu) == 0: 108 | continue 109 | Dev_word.append(uu) 110 | with open(path + hparams.b_id_path) as f_in: # pos 111 | for row_item in f_in: 112 | _list = row_item.strip().split(' ') 113 | uu = [int(tmp) for tmp in _list if tmp != ''] 114 | if len(uu) == 0: 115 | continue 116 | Dev_pos.append(uu) 117 | with open(path + hparams.c_id_path) as f_in: # role 118 | for row_item in f_in: 119 | _list = row_item.strip().split(' ') 120 | uu = [int(tmp) for tmp in _list if tmp != ''] 121 | if len(uu) == 0: 122 | continue 123 | Dev_role.append(uu) 124 | 125 | print("train size: %d, validation size: %d" % (len(Train_word), len(Dev_word))) 126 | 127 | # padding 128 | Train_word = np.array(padding(Train_word, seq_max_len=hparams.seq_max_len)) 129 | Train_pos = np.array(padding(Train_pos, seq_max_len=hparams.seq_max_len)) 130 | Train_role = np.array(padding(Train_role, seq_max_len=hparams.seq_max_len)) 131 | 132 | Dev_word = np.array(padding(Dev_word, seq_max_len=hparams.seq_max_len)) 133 | Dev_pos = np.array(padding(Dev_pos, seq_max_len=hparams.seq_max_len)) 134 | Dev_role = np.array(padding(Dev_role, seq_max_len=hparams.seq_max_len)) 135 | 136 | return Train_word, Train_pos, Train_role, Dev_word, Dev_pos, Dev_role 137 | 138 | 139 | def padding(sample, seq_max_len): 140 | """use '0' to padding the sentence""" 141 | for i in range(len(sample)): 142 | if len(sample[i]) < seq_max_len: 143 | sample[i] += [0 for _ in range(seq_max_len - len(sample[i]))] 144 | return sample 145 | 146 | 147 | def unpadding(sample): 148 | """delete '0' from padding sentence""" 149 | sample_new = [] 150 | for item in sample: 151 | _list = [] 152 | _list_tmp = [] 153 | for ii in item: 154 | _list_tmp.append(ii) 155 | if ii != 0: 156 | _list = _list + _list_tmp 157 | _list_tmp = [] 158 | sample_new.append(_list) 159 | return sample_new 160 | 161 | 162 | def next_test_batch(X_word, X_pos, y_role, start_index, batch_size=128): 163 | full_size = batch_size 164 | last_index = start_index + batch_size 165 | X_word_batch = list(X_word[start_index:min(last_index, len(X_word))]) 166 | X_pos_batch = 
list(X_pos[start_index:min(last_index, len(X_pos))]) 167 | y_role_batch = list(y_role[start_index:min(last_index, len(y_role))]) 168 | if last_index > len(X_word): 169 | full_size = len(X_word) - start_index 170 | left_size = last_index - (len(X_word)) 171 | for i in range(left_size): 172 | index = np.random.randint(len(X_word)) 173 | X_word_batch.append(X_word[index]) 174 | X_pos_batch.append(X_pos[index]) 175 | y_role_batch.append(y_role[index]) 176 | X_word_batch = np.array(X_word_batch) 177 | X_pos_batch = np.array(X_pos_batch) 178 | y_role_batch = np.array(y_role_batch) 179 | return X_word_batch, X_pos_batch, y_role_batch, full_size 180 | 181 | 182 | 183 | def next_batch(X_word, X_pos, y_role, start_index, batch_size=128): 184 | last_index = start_index + batch_size 185 | X_word_batch = list(X_word[start_index:min(last_index, len(X_word))]) 186 | X_pos_batch = list(X_pos[start_index:min(last_index, len(X_pos))]) 187 | y_role_batch = list(y_role[start_index:min(last_index, len(y_role))]) 188 | if last_index > len(X_word): 189 | left_size = last_index - (len(X_word)) 190 | for i in range(left_size): 191 | index = np.random.randint(len(X_word)) 192 | X_word_batch.append(X_word[index]) 193 | X_pos_batch.append(X_pos[index]) 194 | y_role_batch.append(y_role[index]) 195 | X_word_batch = np.array(X_word_batch) 196 | X_pos_batch = np.array(X_pos_batch) 197 | y_role_batch = np.array(y_role_batch) 198 | return X_word_batch, X_pos_batch, y_role_batch 199 | 200 | 201 | def extract_entity(seqs, labels): 202 | entitys = {} 203 | for id, item in enumerate(labels): 204 | if item == 'O' or item == '_PAD' or item == 'rel': 205 | continue 206 | if item in entitys: 207 | entitys[item].append(seqs[id]) 208 | else: 209 | entitys[item] = [seqs[id]] 210 | return entitys 211 | 212 | 213 | def get_length_by_vec(seq_x): 214 | seq_len = [] 215 | for ii in seq_x: 216 | _len = len([jj for jj in ii if jj != 0]) 217 | seq_len.append(_len) 218 | # print(ii) 219 | assert _len != 0 220 | seq_len = np.array(seq_len) 221 | assert len(seq_len) == len(seq_x) 222 | return seq_len 223 | 224 | 225 | def next_random_batch(Dev_word, Dev_pos, Dev_role, batch_size): 226 | x_word_batch = [] 227 | x_pos_batch = [] 228 | y_role_batch = [] 229 | for i in range(batch_size): 230 | index = np.random.randint(len(Dev_word)) 231 | if len(Dev_word[index]) == 0: 232 | continue 233 | x_word_batch.append(Dev_word[index]) 234 | x_pos_batch.append(Dev_pos[index]) 235 | y_role_batch.append(Dev_role[index]) 236 | x_word_batch = np.array(x_word_batch) 237 | x_pos_batch = np.array(x_pos_batch) 238 | y_role_batch = np.array(y_role_batch) 239 | return x_word_batch, x_pos_batch, y_role_batch 240 | 241 | 242 | log_mode = 'w' 243 | 244 | 245 | def log(str, out_path): 246 | global log_mode 247 | log_path = out_path + "/log.txt" 248 | with open(log_path, log_mode) as f: 249 | f.write(str + '\n') 250 | if log_mode != 'a': 251 | log_mode = 'a' 252 | return 253 | 254 | 255 | def recover_role(role_list): 256 | ans_list = role_list 257 | good_list = ['O', 'rel', '_PAD'] 258 | last_item = None 259 | for i in range(len(role_list)): 260 | item = ans_list[i] 261 | next_item = None 262 | if i != len(role_list) - 1: 263 | next_item = ans_list[i+1] 264 | if item == '_PAD': 265 | print("Error, echo _PAD, %s" % str(role_list)) 266 | if item not in good_list: 267 | if item != last_item and item != next_item: 268 | ans_list[i] = 'S-' + ans_list[i] 269 | if item != last_item and item == next_item: 270 | ans_list[i] = 'B-' + ans_list[i] 271 | if item == last_item and item 
== next_item: 272 | ans_list[i] = 'I-' + ans_list[i] 273 | if item == last_item and item != next_item: 274 | ans_list[i] = 'E-' + ans_list[i] 275 | last_item = item 276 | return ans_list 277 | 278 | 279 | def recover_eval(test_file, outputs_role): 280 | with open(test_file, 'r') as f: 281 | outputs_lines = f.readlines() 282 | 283 | # assert len(outputs_lines) == len(outputs_role) 284 | outputs = [] 285 | for i in range(len(outputs_lines)): 286 | item_lists = [] 287 | _line = outputs_lines[i].strip() 288 | if _line == "": 289 | break 290 | _line_items = _line.split(' ') 291 | 292 | for j in range(len(_line_items)): 293 | _item = _line_items[j] 294 | 295 | _item_list = _item.split('/') 296 | _a = _item_list[0] 297 | _b = _item_list[1] 298 | _c = outputs_role[i][j] 299 | if len(_item_list) == 3 and _item_list[2] == 'rel': 300 | _c = _item_list[2] 301 | item_lists.append('/'.join([_a, _b, _c])) 302 | outputs.append(' '.join(item_lists) + '\n') 303 | return outputs 304 | 305 | 306 | def calc_f1(pred_file, gold_file): 307 | case_true, case_recall, case_precision = 0, 0, 0 308 | golds = [gold.split() for gold in open(gold_file, 'r').read().strip().split('\n')] 309 | preds = [pred.split() for pred in open(pred_file, 'r').read().strip().split('\n')] 310 | assert len(golds) == len(preds), "length of prediction file and gold file should be the same." 311 | for gold, pred in zip(golds, preds): 312 | lastname = '' 313 | keys_gold, keys_pred = {}, {} 314 | for item in gold: 315 | word, label = item.split('/')[0], item.split('/')[-1] 316 | flag, name = label[:label.find('-')], label[label.find('-') + 1:] 317 | if flag == 'O': 318 | continue 319 | if flag == 'S': 320 | if name not in keys_gold: 321 | keys_gold[name] = [word] 322 | else: 323 | keys_gold[name].append(word) 324 | else: 325 | if flag == 'B': 326 | if name not in keys_gold: 327 | keys_gold[name] = [word] 328 | else: 329 | keys_gold[name].append(word) 330 | lastname = name 331 | elif flag == 'I' or flag == 'E': 332 | # assert name == lastname, "the I-/E- labels are inconsistent with B- labels in gold file. %s" % str(gold) 333 | keys_gold[name][-1] += ' ' + word 334 | for item in pred: 335 | word, label = item.split('/')[0], item.split('/')[-1] 336 | flag, name = label[:label.find('-')], label[label.find('-') + 1:] 337 | if flag == 'O': 338 | continue 339 | if flag == 'S': 340 | if name not in keys_pred: 341 | keys_pred[name] = [word] 342 | else: 343 | keys_pred[name].append(word) 344 | else: 345 | if flag == 'B': 346 | if name not in keys_pred: 347 | keys_pred[name] = [word] 348 | else: 349 | keys_pred[name].append(word) 350 | lastname = name 351 | elif flag == 'I' or flag == 'E': 352 | # assert name == lastname, "the I-/E- labels are inconsistent with B- labels in pred file. %s" % str(pred) 353 | keys_pred[name][-1] += ' ' + word 354 | 355 | for key in keys_gold: 356 | case_recall += len(keys_gold[key]) 357 | for key in keys_pred: 358 | case_precision += len(keys_pred[key]) 359 | 360 | for key in keys_pred: 361 | if key in keys_gold: 362 | for word in keys_pred[key]: 363 | if word in keys_gold[key]: 364 | case_true += 1 365 | keys_gold[key].remove(word) # avoid replicate words 366 | assert case_recall != 0, "no labels in gold files!" 367 | assert case_precision != 0, "no labels in pred files!" 
368 |     recall = 1.0 * case_true / case_recall
369 |     precision = 1.0 * case_true / case_precision
370 |     f1 = 2.0 * recall * precision / (recall + precision)
371 |     return recall, precision, f1
372 | 
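To close, a small usage example for the BIES recovery step described in the README: `recover_role` in `src/data_helper.py` re-attaches the stripped `B-`/`I-`/`E-`/`S-` prefixes to a predicted role sequence. The label list here is made up for illustration (run from `src/`, in the same environment as `main.py`):

```python
import data_helper

# A predicted per-token role sequence with the BIES prefixes stripped:
roles = ['O', 'ARG0', 'ARG0', 'O', 'rel', 'ARG1', 'O']
print(data_helper.recover_role(roles))
# -> ['O', 'B-ARG0', 'E-ARG0', 'O', 'rel', 'S-ARG1', 'O']
```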