├── constant.py
├── seq2seq.py
├── base.py
├── data
│   ├── rel_names
│   ├── rel_pairs
│   ├── rel.py
│   └── rel
├── config.py
├── utils.py
├── mmtnn.py
├── transition_slstm.py
├── test_DNNBase.py
├── ReadMe.md
├── .gitignore
├── pipeline.py
├── init.py
├── word2vec.py
├── transform_data_w2v.py
├── crf.py
├── prepare_data_msr_ner.py
├── export_emr.py
├── preprocess_data.py
├── evaluate.py
├── prepare_data_semeval.py
├── prepare_data.py
├── dnn_base.py
├── re_cnn.py
├── dnn.py
└── prepare_data_emr.py
/constant.py:
--------------------------------------------------------------------------------
1 | #-*- coding: UTF-8 -*-
--------------------------------------------------------------------------------
/seq2seq.py:
--------------------------------------------------------------------------------
1 | #-*- coding: UTF-8 -*-
--------------------------------------------------------------------------------
/base.py:
--------------------------------------------------------------------------------
1 | #-*- coding: UTF-8 -*-
2 | class Base:
3 |   def __init__(self):
4 |     pass
--------------------------------------------------------------------------------
/data/rel_names:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/supercoderhawk/DeepNLP/HEAD/data/rel_names
--------------------------------------------------------------------------------
/data/rel_pairs:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/supercoderhawk/DeepNLP/HEAD/data/rel_pairs
--------------------------------------------------------------------------------
/data/rel.py:
--------------------------------------------------------------------------------
1 | #-*- coding: UTF-8 -*-
2 | import re
3 | import pickle
4 | def generate_rel():
5 |   relation_pairs = {}
6 |   relation_names = {}
7 |   with open('rel',encoding='utf-8') as file:
8 |     for line in file.readlines():
9 |       content = line.strip()
10 |       if len(content)>0:
11 |         sections = re.sub(r'[ ]+', ' ', content).split(' ')
12 |         rel_name = sections[0]
13 |         arg1 = sections[1][:-1].split(':')[1]
14 |         arg2 = sections[2].split(':')[1]
15 |         if relation_pairs.get(arg1) is None:
16 |           relation_pairs[arg1] = [arg2]
17 |         else:
18 |           relation_pairs[arg1].append(arg2)
19 |         relation_names[arg1+':'+arg2] = rel_name
20 |   with open('rel_pairs','wb') as pairs_file:
21 |     pickle.dump(relation_pairs,pairs_file)
22 |   with open('rel_names','wb') as names_file:
23 |     pickle.dump(relation_names,names_file)
24 | 
25 | generate_rel()
--------------------------------------------------------------------------------
/config.py:
--------------------------------------------------------------------------------
1 | # -*- coding: UTF-8 -*-
2 | from enum import Enum
3 | 
4 | 
5 | class CorpusType(Enum):
6 |   Train = 1
7 |   Test = 2
8 | 
9 | 
10 | class TrainMode(Enum):
11 |   Sentence = 1
12 |   Batch = 2
13 | 
14 | 
15 | class BaseConfig:
16 |   def __init__(self, learning_rate, vocab_size, embed_size, hidden_units):
17 |     self.learning_rate = learning_rate
18 |     self.vocab_size = vocab_size
19 |     self.embed_size = embed_size
20 |     self.hidden_units = hidden_units
21 | 
22 | 
23 | class DNNConfig(BaseConfig):
24 |   def __init__(self, learning_rate, skip_left, skip_right, vocab_size, embed_size, hidden_units):
25 |     BaseConfig.__init__(self, learning_rate, vocab_size, embed_size, hidden_units)
26 |     self.skip_window_left = skip_left
27 |     self.skip_window_right = skip_right
28 |     self.window_size = self.skip_window_left + self.skip_window_right + 
1 29 | self.concat_embed_size = self.embed_size * self.window_size 30 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | # -*- coding: UTF-8 -*- 2 | import matplotlib.pyplot as plt 3 | 4 | 5 | def strQ2B(ustring): 6 | '''全角转半角''' 7 | rstring = '' 8 | for uchar in ustring: 9 | inside_code = ord(uchar) 10 | if inside_code == 12288: # 全角空格直接转换 11 | inside_code = 32 12 | elif (inside_code >= 65281 and inside_code <= 65374): # 全角字符(除空格)根据关系转化 13 | inside_code -= 65248 14 | rstring += chr(inside_code) 15 | return rstring 16 | 17 | 18 | def plot_lengths(lengths): 19 | lengths = sorted(lengths) 20 | pre_i = lengths[0] 21 | count = [] 22 | x = [] 23 | j = 0 24 | for i in lengths: 25 | if pre_i == i: 26 | j += 1 27 | else: 28 | count.append(j) 29 | x.append(pre_i) 30 | j = 0 31 | pre_i = i 32 | 33 | # print(len(list(filter(lambda l: l > 300, lengths)))) 34 | print('count size: ' + str(len(lengths))) 35 | print('max length: ' + str(lengths[-1])) 36 | x = range(len(count)) 37 | plt.plot(x, count) 38 | plt.ylabel('长度') 39 | plt.show() 40 | -------------------------------------------------------------------------------- /mmtnn.py: -------------------------------------------------------------------------------- 1 | # -*- coding: UTF-8 -*- 2 | import tensorflow as tf 3 | import numpy as np 4 | import math 5 | from dnn_base import DNNBase 6 | from preprocess_data import PreprocessData 7 | from config import TrainMode 8 | 9 | class MMTNN(DNNBase): 10 | def __init__(self): 11 | DNNBase.__init__(self) 12 | self.dtype = tf.float32 13 | self.vocab_size = 4500 14 | self.embed_size = 50 15 | self.concat_embed_size = self.window_size * self.embed_size 16 | self.learning_rate = 0.2 17 | self.lam = 0.0002 18 | pre = PreprocessData('pku',TrainMode.Batch) 19 | self.dictionary = pre.dictionary 20 | 21 | self.embeddings = self.weight_variable([self.vocab_size, self.embed_size]) 22 | self.input = tf.placeholder(tf.int32,[None,2]) 23 | 24 | 25 | def weight_variable(self, shape): 26 | initial = tf.truncated_normal(shape, stddev=1.0/math.sqrt(shape[-1]), dtype=self.dtype) 27 | return tf.Variable(initial) 28 | 29 | def train(self): 30 | pass 31 | 32 | def seg(self): 33 | pass -------------------------------------------------------------------------------- /transition_slstm.py: -------------------------------------------------------------------------------- 1 | # -*- coding: UTF-8 -*- 2 | import tensorflow as tf 3 | import numpy as np 4 | 5 | class TransitionSLSTM: 6 | def __init__(self): 7 | # 参数初始化 8 | self.dtype = tf.float32 9 | self.alpha = 0.2 10 | self.embed_size = 100 11 | self.hidden_unit = 50 12 | self.action_count = 5 13 | # 数据初始化 14 | 15 | # 构建模型 16 | self.sess = tf.Session() 17 | # placeholder 18 | self.stack = tf.placeholder(self.dtype, [self.embed_size, None]) 19 | self.buffer = tf.placeholder(self.dtype, [self.embed_size, None]) 20 | self.history_action = tf.placeholder(self.dtype, [self.embed_size, None]) 21 | self.allowed_action = tf.placeholder(self.dtype, [self.embed_size, None]) 22 | # 变量 23 | self.action = tf.Variable(tf.random_uniform([self.action_count, self.hidden_unit], -1, 1, dtype=self.dtype)) 24 | self.action_bias = tf.Variable(tf.random_uniform([self.action_count, self.hidden_unit], -1, 1, dtype=self.dtype)) 25 | 26 | def train_exe(self): 27 | pass 28 | 29 | def train_sentence(self, sentence, label): 30 | pass 31 | 
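The class above only wires up its placeholders and action parameters; `train_exe` and `train_sentence` are still stubs. Below is a minimal sketch of how the column-major placeholders would be fed, purely as an illustration under assumptions: the feature matrices and their column counts are invented, and the original file defines no op that consumes them yet.

import numpy as np
from transition_slstm import TransitionSLSTM

model = TransitionSLSTM()
# Each configuration component is declared as [embed_size, None], i.e. one embedding per column.
stack_feats = np.random.rand(model.embed_size, 3).astype(np.float32)   # hypothetical: 3 items on the stack
buffer_feats = np.random.rand(model.embed_size, 8).astype(np.float32)  # hypothetical: 8 characters left in the buffer
feed = {model.stack: stack_feats, model.buffer: buffer_feats}
# Once train_sentence() is implemented, these inputs would be scored against
# model.action / model.action_bias and evaluated via model.sess.run(..., feed_dict=feed).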
--------------------------------------------------------------------------------
/test_DNNBase.py:
--------------------------------------------------------------------------------
1 | from unittest import TestCase
2 | import numpy as np
3 | from dnn_base import DNNBase
4 | 
5 | # -*- coding: UTF-8 -*-
6 | class TestDNNBase(TestCase):
7 |   def setUp(self):
8 |     self.dnn_base = DNNBase()
9 |   def test_viterbi(self):
10 |     score = np.arange(10, 170, 10).reshape(4, 4).T
11 |     A = np.array([[1, 1, 0, 0], [0, 0, 1, 1], [0, 0, 1, 1], [1, 1, 0, 0]])
12 |     init_A = np.array([1, 1, 0, 0])
13 |     labels = np.array([3, 3, 3, 3])
14 |     current_path = self.dnn_base.viterbi(score, A, init_A)
15 |     print(current_path)
16 | 
17 |   def test_viterbi_new(self):
18 |     score = np.arange(10, 170, 10).reshape(4, 4).T
19 |     A = np.array([[1, 1, 0, 0], [0, 0, 1, 1], [0, 0, 1, 1], [1, 1, 0, 0]])
20 |     init_A = np.array([1, 1, 0, 0])
21 |     labels = np.array([3, 3, 3, 3])
22 |     current_path = self.dnn_base.viterbi_new(score, A, init_A, labels)
23 |     # print(current_path)
24 |     # correct_path = np.array([1, 3, 1, 3])
25 |     # correct_score = np.array([21, 102, 203, 364])
26 |     # self.assertTrue(np.all(current_path == correct_path))
27 |     # self.assertTrue(np.all(current_score == correct_score))
28 | 
29 |   def test_generate_transition_update(self):
30 |     pass
31 | 
32 |   def test_generate_transition_update_index(self):
33 |     pass
34 | 
35 |   def test_sentence2index(self):
36 |     pass
37 | 
38 |   def test_index2seq(self):
39 |     pass
40 | 
41 |   def test_tags2words(self):
42 |     pass
43 | 
--------------------------------------------------------------------------------
/ReadMe.md:
--------------------------------------------------------------------------------
1 | # Deep Learning Toolkit
2 | 
3 | This project is based on `tensorflow` and implements deep-learning models from published papers for Chinese word segmentation, named entity recognition and entity relation extraction.
4 | 
5 | It is developed on the basis of [DNN_CWS](https://github.com/supercoderhawk/DNN_CWS), adding entity relation extraction.
6 | 
7 | **A refactoring of this project is currently planned.**
8 | 
9 | **This project has been migrated to [DeepLearning_NLP](https://github.com/supercoderhawk/DeepLearning_NLP), so maintenance of this repository is suspended for now.**
10 | 
11 | ## Features
12 | 
13 | * Chinese word segmentation
14 | * Named entity recognition
15 | * Entity relation extraction
16 | 
17 | ## Dependencies
18 | 1. python >= 3.5
19 | 2. tensorflow >= 1.2.0
20 | 3. matplotlib >= 1.5.3
21 | 
22 | ## Corpora
23 | 
24 | Under the `corpus` folder:
25 | 
26 | 1. pku_training.utf8, pku_test.utf8: SIGHAN 2005 bakeoff Peking University word segmentation corpus
27 | 2. msr_training.utf8, msr_test.utf8: SIGHAN 2005 bakeoff Microsoft Research Asia word segmentation corpus
28 | 3. msr_ner_training.utf8: SIGHAN 2006 bakeoff Microsoft Research Asia named entity recognition corpus
29 | 4. semeval_relation.utf8: International Workshop on Semantic Evaluation (SemEval)
30 | 2010 task 8 relation extraction dataset
31 | 
32 | ## References
33 | 
34 | ### Chinese word segmentation && named entity recognition
35 | * [Deep Learning for Chinese Word Segmentation and POS Tagging](http://www.aclweb.org/anthology/D13-1061) (fully implemented, file [`dnn.py`](https://github.com/supercoderhawk/DeepNLP/blob/master/dnn.py))
36 | * [Long Short-Term Memory Neural Networks for Chinese Word Segmentation](http://www.emnlp2015.org/proceedings/EMNLP/pdf/EMNLP141.pdf) (fully implemented, needs hyper-parameter tuning, file [`dnn.py`](https://github.com/supercoderhawk/DeepNLP/blob/master/dnn.py))
37 | * [Max-Margin Tensor Neural Network for Chinese Word Segmentation](http://www.aclweb.org/anthology/P14-1028) (in progress, file [`mmtnn.py`](https://github.com/supercoderhawk/DeepNLP/blob/master/mmtnn.py))
38 | 
39 | ### Entity relation extraction
40 | * [Relation Extraction: Perspective from Convolutional Neural Networks](http://aclweb.org/anthology/W15-1506) (fully implemented, file [`re_cnn.py`](https://github.com/supercoderhawk/DeepNLP/blob/master/re_cnn.py))
41 | ## TodoList
42 | 
43 | - [ ] `pip` support
44 | 
45 | 
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Created by .ignore support plugin (hsz.mobi)
2 | ### Python template
3 | # Byte-compiled / optimized / DLL files
4 | __pycache__/
5 | *.py[cod]
6 | *$py.class
7 | 
8 | # C extensions
9 | *.so
10 | 
11 | # Distribution / packaging
12 | .Python
13 | env/
14 | build/
15 | develop-eggs/
16 | dist/
17 | downloads/
18 | eggs/
19 | .eggs/
20 | lib/
21 | lib64/
22 | parts/
23 | sdist/
24 | var/
25 | wheels/
26 | *.egg-info/
27 | .installed.cfg
28 | *.egg
29 | 
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *,cover 49 | .hypothesis/ 50 | 51 | # Translations 52 | *.mo 53 | *.pot 54 | 55 | # Django stuff: 56 | *.log 57 | local_settings.py 58 | 59 | # Flask stuff: 60 | instance/ 61 | .webassets-cache 62 | 63 | # Scrapy stuff: 64 | .scrapy 65 | 66 | # Sphinx documentation 67 | docs/_build/ 68 | 69 | # PyBuilder 70 | target/ 71 | 72 | # Jupyter Notebook 73 | .ipynb_checkpoints 74 | 75 | # pyenv 76 | .python-version 77 | 78 | # celery beat schedule file 79 | celerybeat-schedule 80 | 81 | # SageMath parsed files 82 | *.sage.py 83 | 84 | # dotenv 85 | .env 86 | 87 | # virtualenv 88 | .venv 89 | venv/ 90 | ENV/ 91 | 92 | # Spyder project settings 93 | .spyderproject 94 | 95 | # Rope project settings 96 | .ropeproject 97 | ### Example user template template 98 | ### Example user template 99 | 100 | # IntelliJ project files 101 | .idea 102 | *.iml 103 | out 104 | gen 105 | corpus/* 106 | !corpus/msr_training.utf8 107 | !corpus/msr_test.utf8 108 | !corpus/pku_training.utf8 109 | !corpus/pku_test.utf8 110 | !corpus/msr_ner_training.utf8 111 | !corpus/semeval_relation.utf8 112 | tmp/ -------------------------------------------------------------------------------- /pipeline.py: -------------------------------------------------------------------------------- 1 | # -*- coding: UTF-8 -*- 2 | import numpy as np 3 | from dnn import DNN 4 | from config import TrainMode 5 | from re_cnn import RECNN 6 | from evaluate import estimate_ner 7 | 8 | 9 | def get_cws(content, model_name): 10 | dnn = DNN('mlp', mode=TrainMode.Sentence, task='ner') 11 | ner = dnn.seg(content, model_path=model_name, ner=True, trans=True)[1] 12 | return ner 13 | 14 | 15 | def get_ner(content, model_name): 16 | if model_name.startswith('tmp/mlp'): 17 | dnn = DNN('mlp', mode=TrainMode.Sentence, task='ner', is_seg=True) 18 | else: 19 | dnn = DNN('lstm', task='ner', is_seg=True) 20 | ner = dnn.seg(content, model_path=model_name, ner=True, trans=True) 21 | return ner[1] 22 | 23 | 24 | def get_relation(): 25 | re = RECNN(2) 26 | re.evaluate('cnn_emr_model3.ckpt') 27 | re.evaluate('cnn_emr_model3.ckpt') 28 | 29 | 30 | def evaluate_ner(model_name): 31 | base_folder = 'corpus/emr_ner_test_' 32 | labels = np.load(base_folder + 'labels.npy') 33 | characters = np.load(base_folder + 'characters.npy') 34 | corr_count = 0 35 | prec_count = 0 36 | recall_count = 0 37 | for ch, l in zip(characters, labels): 38 | c_count, p_count, r_count = estimate_ner(get_ner(ch, model_name), l) 39 | corr_count += c_count 40 | prec_count += p_count 41 | recall_count += r_count 42 | print(corr_count, prec_count, recall_count) 43 | prec = corr_count / prec_count 44 | recall = corr_count / recall_count 45 | f1 = 2 * prec * recall / (prec + recall) 46 | with open('corpus/ner.txt', 'a', encoding='utf8') as f: 47 | f.write(model_name + '\t{:.2f}\t{:.2f}\t{:.2f}\n'.format(prec * 100, recall * 100, f1 * 100)) 48 | print('precision:', prec) 49 | print('recall:', recall) 50 | print('F1 score:', f1) 51 | 52 | 53 | def evaluate_re(): 54 | window_size = [(2,), (3,), (4,), (2, 3), (3, 4), (2, 3, 4)] 55 | for w in window_size: 56 | print('window size:', w) 57 | re_two = RECNN(2, window_size=w, train=False) 58 | # re_multi = RECNN(29, window_size=w, train=False) 59 | name = 'cnn_emr_model100_{0}.ckpt'.format('_'.join(map(str, w))) 60 | 
re_two.evaluate(name) 61 | # re_multi.evaluate(name) 62 | 63 | 64 | if __name__ == '__main__': 65 | # 实体识别 66 | # print('mlp') 67 | # evaluate_ner('tmp/mlp/mlp-ner-model20.ckpt') 68 | # print('mlp+embed') 69 | # evaluate_ner('tmp/mlp/mlp-ner-embed-model50.ckpt') 70 | # print('lstm') 71 | # evaluate_ner('tmp/lstm/lstm-ner-model50.ckpt') 72 | # print('lstm+embed') 73 | # evaluate_ner('tmp/lstm/lstm-ner-embed-model50.ckpt') 74 | # 关系抽取 75 | evaluate_re() 76 | # re_two = RECNN(2, window_size=(4,), train=False) 77 | # re_two.evaluate('cnn_emr_model60_4.ckpt') 78 | -------------------------------------------------------------------------------- /init.py: -------------------------------------------------------------------------------- 1 | # -*- coding: UTF-8 -*- 2 | import os 3 | from shutil import copyfile 4 | import re 5 | 6 | 7 | def extract_cws_file(): 8 | source_folder = 'corpus/cws_unannotated/' 9 | dest_folder = 'corpus/emr/cws/' 10 | judge_folder = 'corpus/emr/' 11 | cws_ext_name = '.cws' 12 | files = set() 13 | for i in os.listdir(judge_folder): 14 | if os.path.isfile(judge_folder + i): 15 | files.add(os.path.splitext(i)[0]) 16 | for i in files: 17 | copyfile(source_folder + i + cws_ext_name, dest_folder + i + cws_ext_name) 18 | 19 | 20 | def transfer_cws_file(): 21 | source_dir = 'corpus/emr/cws/' 22 | dest_folder = 'corpus/emr/' 23 | cws_files = set() 24 | cws_ext_name = '.cws' 25 | origin_files = set() 26 | for i in os.listdir(source_dir): 27 | cws_files.add(os.path.splitext(i)[0]) 28 | 29 | for i in os.listdir(dest_folder): 30 | if os.path.isfile(dest_folder + i): 31 | origin_files.add(os.path.splitext(i)[0]) 32 | print(len(cws_files) == len(origin_files)) 33 | 34 | for i in cws_files: 35 | copyfile(source_dir + i + cws_ext_name, dest_folder + i + cws_ext_name) 36 | 37 | 38 | def merge_cws_file(): 39 | content = '' 40 | dest_folder = 'corpus/' 41 | judge_folder = 'corpus/emr/' 42 | cws_ext_name = '.cws' 43 | files = set() 44 | 45 | for i in os.listdir(judge_folder): 46 | if os.path.isfile(judge_folder + i): 47 | files.add(os.path.splitext(i)[0]) 48 | 49 | for i in files: 50 | with open(judge_folder + i + cws_ext_name, 'r', encoding='utf8') as f: 51 | content += f.read().replace('\n', '') + '\n' 52 | with open(dest_folder + 'emr_training.utf8', 'w', encoding='utf8') as f: 53 | f.write(content) 54 | 55 | 56 | def merge_emr(): 57 | base_folder = 'corpus/admission-annotation/' 58 | ext_name = '.txt' 59 | files = set() 60 | sentences = [] 61 | for i in os.listdir(base_folder): 62 | files.add(i[:i.index('.')]) 63 | for i in files: 64 | with open(base_folder + i + ext_name, encoding='utf8') as f: 65 | content = f.read().replace('\n', '') 66 | index = [m.start() for m in re.finditer('。', content)] 67 | l = len(content) 68 | sentence = '' 69 | for beg, end in zip([-1] + index[:-1], index): 70 | if end - beg <= 1: 71 | continue 72 | if end != l - 1: 73 | if content[end + 1] != '”': 74 | sentences.append(sentence + content[beg + 1:end + 1]) 75 | sentence = '' 76 | else: 77 | sentence = content[beg + 1:end + 1] 78 | else: 79 | sentences.append(sentence + content[beg + 1:end + 1]) 80 | l = [len(l) for l in sentences] 81 | print(max(l), min(l)) 82 | with open('corpus/emr.txt', 'w', encoding='utf8') as f: 83 | f.write('\n'.join(sentences)) 84 | 85 | 86 | if __name__ == '__main__': 87 | # extract_cws_file() 88 | # transfer_cws_file() 89 | # merge_cws_file() 90 | merge_emr() 91 | -------------------------------------------------------------------------------- /word2vec.py: 
-------------------------------------------------------------------------------- 1 | # -*- coding: UTF-8 -*- 2 | import tensorflow as tf 3 | import numpy as np 4 | import math 5 | from transform_data_w2v import TransformDataW2V 6 | 7 | 8 | class Word2Vec: 9 | def __init__(self, output, batch_size=128, skip_window=1, embed_size=100,dict_path='corpus/emr_ner_dict.utf8', 10 | num_sampled=64, steps=25000): 11 | self.output = output 12 | self.batch_size = batch_size 13 | self.skip_window = skip_window 14 | self.embed_size = embed_size 15 | self.num_sampled = num_sampled 16 | self.steps = steps 17 | self.tran = TransformDataW2V(self.batch_size, self.skip_window,dict_path=dict_path) 18 | self.vocab_size = len(self.tran.dictionary) 19 | self.embeddings = tf.Variable(tf.random_uniform([self.vocab_size, self.embed_size], -1.0, 1.0)) 20 | 21 | def train(self): 22 | train_inputs = tf.placeholder(tf.int32, shape=[self.batch_size]) 23 | train_labels = tf.placeholder(tf.int32, shape=[self.batch_size, 1]) 24 | 25 | embed = tf.nn.embedding_lookup(self.embeddings, train_inputs) 26 | 27 | nce_weights = tf.Variable( 28 | tf.truncated_normal([self.vocab_size, self.embed_size], 29 | stddev=1.0 / math.sqrt(self.embed_size))) 30 | nce_biases = tf.Variable(tf.zeros([self.vocab_size])) 31 | 32 | loss = tf.reduce_mean( 33 | tf.nn.nce_loss(weights=nce_weights, biases=nce_biases, labels=train_labels, inputs=embed, 34 | num_sampled=self.num_sampled, num_classes=self.vocab_size)) 35 | optimizer = tf.train.GradientDescentOptimizer(0.2).minimize(loss) 36 | 37 | with tf.Session() as sess: 38 | tf.global_variables_initializer().run() 39 | 40 | aver_loss = 0 41 | for step in range(self.steps): 42 | batch_inputs, batch_labels = self.tran.generate_batch() 43 | feed_dict = {train_inputs: batch_inputs, train_labels: batch_labels} 44 | _, loss_val = sess.run([optimizer, loss], feed_dict=feed_dict) 45 | aver_loss += loss_val 46 | 47 | if step % 2000 == 0: 48 | if step > 0: 49 | aver_loss /= 2000 50 | # The average loss is an estimate of the loss over the last 2000 batches. 
51 | print("Average loss at step ", step, ": ", aver_loss) 52 | aver_loss = 0 53 | np.save(self.output, self.embeddings.eval()) 54 | 55 | def test(self): 56 | valid_dataset = [3021] 57 | norm = tf.sqrt(tf.reduce_sum(tf.square(self.embeddings), 1, keep_dims=True)) 58 | normalized_embeddings = self.embeddings / norm 59 | valid_embeddings = tf.nn.embedding_lookup( 60 | normalized_embeddings, valid_dataset) 61 | similarity = tf.abs(tf.matmul( 62 | valid_embeddings, normalized_embeddings, transpose_b=True)) 63 | print(similarity.eval()) 64 | pair = zip(range(self.vocab_size), similarity.eval()[0]) 65 | spair = sorted(pair, key=lambda x: x[1]) 66 | print(spair[0:10]) 67 | 68 | 69 | if __name__ == '__main__': 70 | w2v = Word2Vec('corpus/embed/embeddings', embed_size=100) 71 | w2v.train() 72 | -------------------------------------------------------------------------------- /transform_data_w2v.py: -------------------------------------------------------------------------------- 1 | # -*- coding: UTF-8 -*- 2 | import numpy as np 3 | import collections 4 | import random 5 | 6 | 7 | class TransformDataW2V(object): 8 | def __init__(self, batch_size, skip_window, source_file='corpus/emr.txt', dict_path='corpus/emr_embed_dict.utf8'): 9 | self.dict_path = dict_path 10 | self.batch_size = batch_size 11 | self.skip_window = skip_window 12 | self.source_file = source_file 13 | self.dictionary = self.read_dictionary() 14 | self.span = 2 * self.skip_window + 1 15 | self.indices = self.get_indices('emr') 16 | self.input, self.output = self.generate_collections() 17 | self.size = len(self.input) 18 | self.start = 0 19 | 20 | def read_dictionary(self): 21 | dict_file = open(self.dict_path, 'r', encoding='utf-8') 22 | dict_content = dict_file.read().splitlines() 23 | dictionary = {} 24 | dict_arr = map(lambda item: item.split(' '), dict_content) 25 | for _, dict_item in enumerate(dict_arr): 26 | dictionary[dict_item[0]] = int(dict_item[1]) 27 | dict_file.close() 28 | 29 | return dictionary 30 | 31 | def build_dictionary(self, source_file): 32 | dictionary = {'UNK': 0} 33 | with open(source_file, 'r', encoding='utf8') as f: 34 | characters = set(''.join(f.readlines())) 35 | for i, ch in enumerate(characters): 36 | dictionary[ch] = i + 1 37 | with open(self.dict_path, 'w', encoding='utf8') as dict_file: 38 | for character in dictionary: 39 | dict_file.write(character + ' ' + str(dictionary[character]) + '\n') 40 | return dictionary 41 | 42 | def get_indices(self, name): 43 | lines = [] 44 | with open('corpus/' + name + '.txt', 'r', encoding='utf8') as file: 45 | sentences = file.readlines() 46 | for sentence in sentences: 47 | if sentence: 48 | lines.append(self.sentence2index(sentence)) 49 | 50 | return lines 51 | 52 | def sentence2index(self, sentence): 53 | index = [] 54 | for ch in sentence: 55 | if ch in self.dictionary: 56 | index.append(self.dictionary[ch]) 57 | else: 58 | index.append(self.dictionary['UNK']) 59 | return index 60 | 61 | def generate_collections(self): 62 | input = [] 63 | output = [] 64 | for index in self.indices: 65 | target = index[self.skip_window:-self.skip_window] 66 | input += [i for l in zip(*[target] * self.skip_window * 2) for i in l] 67 | target_index = range(self.skip_window, len(target) + self.skip_window) 68 | 69 | def shuffle(i): 70 | return random.sample(index[i - self.skip_window:i] + index[i + 1:i + self.skip_window + 1],2 * self.skip_window) 71 | 72 | output += [j for i in map(shuffle, target_index) for j in i] 73 | if len(input) != len(output): 74 | print(len(input) - 
len(output)) 75 | 76 | return input, output 77 | 78 | def generate_batch(self): 79 | if self.start + self.batch_size > self.size: 80 | input_batch = self.input[self.start:] + self.input[:self.batch_size + self.start - self.size] 81 | output_batch = self.output[self.start:] + self.output[:self.batch_size + self.start - self.size] 82 | else: 83 | input_batch = self.input[self.start:self.start + self.batch_size] 84 | output_batch = self.output[self.start:self.start + self.batch_size] 85 | 86 | self.start += self.batch_size 87 | self.start %= self.size 88 | 89 | return np.array(input_batch, dtype=np.int32), np.expand_dims(np.array(output_batch, dtype=np.int32), 1) 90 | -------------------------------------------------------------------------------- /crf.py: -------------------------------------------------------------------------------- 1 | # -*- coding: UTF-8 -*- 2 | import os 3 | 4 | 5 | def estimate_ner(current_labels, correct_labels): 6 | corr_dict = {} 7 | curr_dict = {} 8 | corr_start = -2 9 | curr_start = -2 10 | 11 | # print('curr',current_labels) 12 | # print('corr', correct_labels) 13 | for label_index, (curr_label, corr_label) in enumerate(zip(current_labels, correct_labels)): 14 | if corr_label == 1: 15 | corr_start = label_index 16 | if corr_start == label_index - 1: 17 | corr_dict[corr_start] = 1 18 | elif label_index > 0 and corr_label == 2 and correct_labels[label_index - 1] != 2: 19 | corr_dict[corr_start] = label_index - corr_start 20 | 21 | if curr_label == 1: 22 | curr_start = label_index 23 | if curr_start == label_index - 1: 24 | curr_dict[curr_start] = 1 25 | elif label_index > 0 and curr_label == 2 and current_labels[label_index - 1] != 2: 26 | curr_dict[curr_start] = label_index - curr_start 27 | 28 | corr_count = 0 29 | prec_length = len(curr_dict) 30 | recall_length = len(corr_dict) 31 | for curr_start in curr_dict: 32 | if curr_start in corr_dict and curr_dict[curr_start] == corr_dict[curr_start]: 33 | corr_count += 1 34 | 35 | return corr_count, prec_length,recall_length 36 | 37 | def prepare_for_crfpp(folder, output_name): 38 | content = [] 39 | filenames = set() 40 | for _, _, names in os.walk(folder): 41 | for filename in names: 42 | name, _ = os.path.splitext(filename) 43 | if name not in filenames: 44 | filenames.add(name) 45 | for filename in filenames: 46 | path = folder + filename 47 | with open(path + '.txt', encoding='utf-8') as src_file: 48 | raw_text = src_file.read().replace('\n', '\r\n') 49 | labels = len(raw_text) * ['O'] 50 | with open(path + '.ann', encoding='utf-8') as ann_file: 51 | ann_items = ann_file.read().splitlines() 52 | for item in ann_items: 53 | sections = item.split('\t') 54 | if sections[0].startswith('T'): 55 | pos = sections[1].split(' ') 56 | start, end = int(pos[1]), int(pos[2]) 57 | labels[start] = 'B' 58 | if end - start - 1 > 0: 59 | labels[start + 1:end] = ['I'] * (end - start - 1) 60 | for ch, l in zip(raw_text, labels): 61 | if ch == '\r': 62 | continue 63 | if ch == '。': 64 | content.append(ch + '\t' + l + '\n') 65 | else: 66 | content.append(ch + '\t' + l) 67 | with open(output_name, mode='w', encoding='utf-8') as o: 68 | o.write('\n'.join(content)) 69 | 70 | 71 | def evaluate_ner(path): 72 | with open(path, encoding='utf-8') as f: 73 | entries = map(lambda l: l.split('\t'), [l for l in f.read().splitlines() if l]) 74 | res = list(zip(*entries)) 75 | label_map = {'O': 0, 'B': 1, 'I': 2} 76 | correct = list(map(lambda l: label_map[l], res[1])) 77 | current = list(map(lambda l: label_map[l], res[2])) 78 | corr, p_count, 
r_count = estimate_ner(current, correct) 79 | p = corr / p_count 80 | r = corr / r_count 81 | f1 = 2 * p * r / (p + r) 82 | print('precision:', p) 83 | print('recall:', r) 84 | print('f1', f1) 85 | 86 | if __name__ == '__main__': 87 | # train_folder = 'corpus/emr_paper/train/' 88 | # test_folder = 'corpus/emr_paper/test/' 89 | # prepare_for_crfpp(test_folder,'corpus/test.data') 90 | # prepare_for_crfpp(train_folder, 'corpus/train.data') 91 | # evaluate_ner('D:\Learning\master_project\clinicalText\CRF++-0.58\\res.data') 92 | evaluate_ner('D:\Learning\master_project\clinicalText\CRF++-0.58\\res_slim.data') 93 | -------------------------------------------------------------------------------- /data/rel: -------------------------------------------------------------------------------- 1 | PartOf Arg1:Sign, Arg2:Part 2 | QualityValue Arg1:Sign, Arg2:Quality 3 | QuantityValue Arg1:Sign, Arg2:Quantity 4 | UnitOf Arg1:Sign, Arg2:Unit 5 | PropertyOf Arg1:Sign, Arg2:Property 6 | DegreeOf Arg1:Sign, Arg2:Degree 7 | DateOf Arg1:Sign, Arg2:Date 8 | TimeOf Arg1:Sign, Arg2:Time 9 | StartTime Arg1:Sign, Arg2:Time 10 | EndTime Arg1:Sign, Arg2:Time 11 | Moment Arg1:Sign, Arg2:Time 12 | LocationOf Arg1:Sign, Arg2:Location 13 | FamilyOf Arg1:Sign, Arg2:Family 14 | ModifierOf Arg1:Sign, Arg2:Modifier 15 | 16 | PartOf Arg1:Symptom, Arg2:Part 17 | QualityValue Arg1:Symptom, Arg2:Quality 18 | QuantityValue Arg1:Symptom, Arg2:Quantity 19 | UnitOf Arg1:Symptom, Arg2:Unit 20 | PropertyOf Arg1:Symptom, Arg2:Property 21 | DegreeOf Arg1:Symptom, Arg2:Degree 22 | ResultOf Arg1:Symptom, Arg2:Result 23 | DateOf Arg1:Symptom, Arg2:Date 24 | TimeOf Arg1:Symptom, Arg2:Time 25 | StartTime Arg1:Symptom, Arg2:Time 26 | EndTime Arg1:Symptom, Arg2:Time 27 | Moment Arg1:Symptom, Arg2:Time 28 | FamilyOf Arg1:Symptom, Arg2:Family 29 | LocationOf Arg1:Symptom, Arg2:Location 30 | ModifierOf Arg1:Symptom, Arg2:Modifier 31 | 32 | PartOf Arg1:Disease, Arg2:Part 33 | DiseaseTypeOf Arg1:Disease, Arg2:DiseaseType 34 | PropertyOf Arg1:Disease, Arg2:Property 35 | DegreeOf Arg1:Disease, Arg2:Degree 36 | TimeOf Arg1:Disease, Arg2:Time 37 | Moment Arg1:Disease, Arg2:Time 38 | StartTime Arg1:Disease, Arg2:Time 39 | EndTime Arg1:Disease, Arg2:Time 40 | QualityValue Arg1:Disease, Arg2:Quality 41 | FamilyOf Arg1:Disease, Arg2:Family 42 | ModifierOf Arg1:Disease, Arg2:Modifier 43 | 44 | QualityValue Arg1:Examination, Arg2:Quality 45 | QuantityValue Arg1:Examination, Arg2:Quantity 46 | UnitOf Arg1:Examination, Arg2:Unit 47 | PartOf Arg1:Examination, Arg2:Part 48 | ResultOf Arg1:Examination, Arg2:Result 49 | Moment Arg1:Examination, Arg2:Time 50 | TimeOf Arg1:Examination, Arg2:Time 51 | DateOf Arg1:Examination, Arg2:Date 52 | LocationOf Arg1:Examination, Arg2:Location 53 | ModifierOf Arg1:Examination, Arg2:Modifier 54 | 55 | PartOf Arg1:Medicine, Arg2:Part 56 | SpecOf Arg1:Medicine, Arg2:Spec 57 | Moment Arg1:Medicine, Arg2:Time 58 | TimeOf Arg1:Medicine, Arg2:Time 59 | StartTime Arg1:Medicine, Arg2:Time 60 | EndTime Arg1:Medicine, Arg2:Time 61 | QualityValue Arg1:Medicine, Arg2:Quality 62 | UsageOf Arg1:Medicine, Arg2:Usage 63 | DoseOf Arg1:Medicine, Arg2:Dose 64 | ModifierOf Arg1:Medicine, Arg2:Modifier 65 | 66 | PartOf Arg1:Treatment, Arg2:Part 67 | QualityValue Arg1:Treatment, Arg2:Quality 68 | PropertyOf Arg1:Treatment, Arg2:Property 69 | Moment Arg1:Treatment, Arg2:Time 70 | TimeOf Arg1:Treatment, Arg2:Time 71 | StartTime Arg1:Treatment, Arg2:Time 72 | EndTime Arg1:Treatment, Arg2:Time 73 | LocationOf Arg1:Treatment, Arg2:Location 74 | ResultOf 
Arg1:Treatment, Arg2:Result 75 | ModifierOf Arg1:Treatment, Arg2:Modifier 76 | 77 | 78 | UseMedicine Arg1:Treatment, Arg2:Medicine 79 | AlongWith Arg1:Treatment, Arg2:Treatment 80 | 81 | AlongWith Arg1:Disease, Arg2:Disease 82 | LeadTo Arg1:Disease, Arg2:Symptom 83 | LeadTo Arg1:Disease, Arg2:Sign 84 | Adopt Arg1:Disease, Arg2:Treatment 85 | Take Arg1:Disease, Arg2:Medicine 86 | 87 | Find Arg1:Examination, Arg2:Sign 88 | Confirm Arg1:Examination, Arg2:Disease 89 | 90 | AlongWith Arg1:Symptom, Arg2:Symptom 91 | 92 | AlongWith Arg1:Sign, Arg2:Sign 93 | Complement Arg1:Sign, Arg2:Sign 94 | 95 | Limit Aeg1:Part, Arg2:Part -------------------------------------------------------------------------------- /prepare_data_msr_ner.py: -------------------------------------------------------------------------------- 1 | # -*- coding: UTF-8 -*- 2 | import numpy as np 3 | from utils import strQ2B 4 | from collections import Counter 5 | 6 | 7 | class PrepareDataMSRNer: 8 | def __init__(self): 9 | self.labels_dict = {'O': 0, 'B': 1, 'I': 2} 10 | self.labels_count = len(self.labels_dict) 11 | self.ext_dict_path = ['corpus/msr_dict.utf8', 'corpus/pku_dict.utf8', 'corpus/emr_dict.utf8'] 12 | self.dict_path = 'corpus/msr_ner_dict.utf8' 13 | self.corpus_path = 'corpus/msr_ner_training.utf8' 14 | self.words, self.labels = self.read_content() 15 | self.dictionary, self.reverse_dictionary = self.build_dictionary() 16 | self.characters, self.character_labels = self.build_dataset() 17 | np.save('corpus/msr_ner_training_characters', self.characters) 18 | np.save('corpus/msr_ner_training_labels', self.character_labels) 19 | 20 | def read_content(self): 21 | words = [] 22 | labels = [] 23 | with open(self.corpus_path, 'r', encoding='utf8') as corpus_file: 24 | sentences = corpus_file.read().splitlines() 25 | for sentence in sentences: 26 | word = [] 27 | label = [] 28 | sections = sentence.strip().split(' ') 29 | for section in sections: 30 | pair = section.split('/') 31 | word.append(strQ2B(pair[0])) 32 | label.append(pair[1]) 33 | words.append(word) 34 | labels.append(label) 35 | return words, labels 36 | 37 | def build_dictionary(self): 38 | dictionary = {} 39 | characters = [] 40 | for dict_path in self.ext_dict_path: 41 | d = self.read_dictionary(dict_path) 42 | characters.extend(d.keys()) 43 | content = '' 44 | for line in self.words: 45 | for word in line: 46 | content += word 47 | characters.extend(list(Counter(content))) 48 | characters = list( 49 | filter(lambda ch: ch != 'UNK' and ch != 'STRT' and ch != 'END' and ch != 'BATCH_PAD', set(characters))) 50 | dictionary['BATCH_PAD'] = 0 51 | dictionary['UNK'] = 1 52 | dictionary['STRT'] = 2 53 | dictionary['END'] = 3 54 | for index, character in enumerate(characters, 4): 55 | dictionary[character] = index 56 | 57 | with open(self.dict_path, 'w', encoding='utf8') as dict_file: 58 | for character in dictionary: 59 | dict_file.write(character + ' ' + str(dictionary[character]) + '\n') 60 | return dictionary, dict(zip(dictionary.values(), dictionary.keys())) 61 | 62 | @staticmethod 63 | def read_dictionary(dict_path): 64 | dict_file = open(dict_path, 'r', encoding='utf-8') 65 | dict_content = dict_file.read().splitlines() 66 | dictionary = {} 67 | dict_arr = map(lambda item: item.split(' '), dict_content) 68 | for _, dict_item in enumerate(dict_arr): 69 | dictionary[dict_item[0]] = int(dict_item[1]) 70 | dict_file.close() 71 | return dictionary 72 | 73 | def build_dataset(self): 74 | seg_punctuation = ['。', '?', '!'] 75 | characters = [] 76 | labels = [] 77 | for 
line_word, line_label in zip(self.words, self.labels): 78 | line_characters = [] 79 | line_labels = [] 80 | for word, label in zip(line_word, line_label): 81 | for ch in word: 82 | if ch in self.dictionary: 83 | line_characters.append(self.dictionary[ch]) 84 | else: 85 | line_characters.append(1) 86 | if label == 'o': 87 | line_labels.extend([self.labels_dict['O']] * len(word)) 88 | else: 89 | line_labels.append(self.labels_dict['B']) 90 | line_labels.extend([self.labels_dict['I']] * (len(word) - 1)) 91 | if word in seg_punctuation: 92 | characters.append(np.array(line_characters, np.int32)) 93 | labels.append(np.array(line_labels, np.int32)) 94 | line_characters = [] 95 | line_labels = [] 96 | if len(line_characters) != 0: 97 | characters.append(np.array(line_characters, np.int32)) 98 | labels.append(np.array(line_labels, np.int32)) 99 | return np.array(characters), np.array(labels) 100 | 101 | 102 | if __name__ == '__main__': 103 | PrepareDataMSRNer() 104 | -------------------------------------------------------------------------------- /export_emr.py: -------------------------------------------------------------------------------- 1 | #-*- coding: UTF-8 -*- 2 | import re 3 | import os 4 | 5 | def read_single_file(ann_file,raw_file): 6 | with open(raw_file, encoding='utf-8') as r: 7 | sentence = r.read() 8 | rn_indices = [m.start() for m in re.finditer('\n',sentence)] 9 | spans_diff = {} 10 | if len(rn_indices): 11 | spans = zip([-1]+rn_indices,rn_indices+[len(sentence)+len(rn_indices)]) 12 | for i,(before,curr) in enumerate(spans): 13 | spans_diff[(before+2,curr)] = i*2 14 | raw_sentence = sentence 15 | sentence = sentence.replace('\n','') 16 | 17 | #periods = [m.start() for m in re.finditer('。', sentence)] 18 | periods = [] 19 | sentence_len = len(sentence) 20 | last = 0 21 | sentences = [] 22 | for i,ch in enumerate(sentence): 23 | if ch =='。': 24 | if i 1: 104 | labels[start+1:end] = ['I']*(end-start-1) 105 | sentence_dict[sentence]['label'] = labels 106 | return sentence_dict 107 | 108 | 109 | 110 | def read_emr(directory,dest_file): 111 | files = set() 112 | for f in os.listdir(directory): 113 | files.add(os.path.splitext(os.path.split(f)[1])[0]) 114 | sentences = [] 115 | for f in files: 116 | sentences.extend(read_single_file(directory+f+'.ann',directory+f+'.txt').values()) 117 | text = '' 118 | with open(dest_file, 'w',encoding='utf-8') as f: 119 | for sentence in sentences: 120 | text += '\n'.join([' '.join(l) for l in zip(sentence['text'],sentence['label'])]) 121 | text += '\n\n' 122 | 123 | f.write(text) 124 | 125 | if __name__ == '__main__': 126 | read_emr('corpus/emr_paper/train/','corpus/emr_paper/emr_training.conll') 127 | read_emr('corpus/emr_paper/test/', 'corpus/emr_paper/emr_test.conll') -------------------------------------------------------------------------------- /preprocess_data.py: -------------------------------------------------------------------------------- 1 | # -*- coding: UTF-8 -*- 2 | import numpy as np 3 | import os 4 | from config import CorpusType, TrainMode 5 | 6 | 7 | class PreprocessData: 8 | def __init__(self, corpus, mode, type=CorpusType.Train,force_generate=True): 9 | self.skip_window_left = 0 10 | self.skip_window_right = 1 11 | self.window_size = self.skip_window_left + self.skip_window_right + 1 12 | self.dict_path = 'corpus/' + corpus + '_dict.utf8' 13 | if type == CorpusType.Train: 14 | self.input_base = 'corpus/' + corpus + '_training' 15 | elif type == CorpusType.Test: 16 | self.input_base = 'corpus/' + corpus + '_test' 17 | if mode == 
TrainMode.Sentence: 18 | self.characters = np.load(self.input_base + '_characters.npy') 19 | self.labels = np.load(self.input_base + '_labels.npy') 20 | self.lengths = np.load(self.input_base + '_lengths.npy') 21 | self.character_batches, self.label_batches = self.generate_sentences() 22 | elif mode == TrainMode.Batch: 23 | self.characters = np.load(self.input_base + '_character_batches.npy') 24 | self.labels = np.load(self.input_base + '_label_batches.npy') 25 | self.lengths = np.load(self.input_base + '_lengths_batches.npy') 26 | self.output_base = 'corpus/dnn/' + corpus + '_training' 27 | self.ouput_suffix = '_' + str(self.skip_window_left) + '_' + str(self.skip_window_right) 28 | if os.path.exists(self.output_base + '_character_batches' + self.ouput_suffix + '.npy') and not force_generate: 29 | self.character_batches = np.load(self.output_base + '_character_batches' + self.ouput_suffix + '.npy') 30 | self.label_batches = np.load(self.output_base + '_label_batches' + self.ouput_suffix + '.npy') 31 | else: 32 | self.character_batches, self.label_batches = self.generate_batches() 33 | np.save(self.output_base + '_character_batches' + self.ouput_suffix, self.character_batches) 34 | np.save(self.output_base + '_label_batches' + self.ouput_suffix, self.label_batches) 35 | else: 36 | print('模式错误') 37 | exit(1) 38 | 39 | self.dictionary = self.read_dictionary() 40 | 41 | def read_dictionary(self): 42 | dict_file = open(self.dict_path, 'r', encoding='utf-8') 43 | dict_content = dict_file.read().splitlines() 44 | dictionary = {} 45 | dict_arr = map(lambda item: item.split(' '), dict_content) 46 | for _, dict_item in enumerate(dict_arr): 47 | dictionary[dict_item[0]] = int(dict_item[1]) 48 | dict_file.close() 49 | 50 | return dictionary 51 | 52 | def generate_sentences(self): 53 | characters_batch = [] 54 | labels_batch = [] 55 | for i, sentence_words in enumerate(self.characters): 56 | if len(sentence_words) < max(self.skip_window_left, self.skip_window_right): 57 | continue 58 | extend_words = [2] * self.skip_window_left 59 | extend_words.extend(sentence_words) 60 | extend_words.extend([3] * self.skip_window_right) 61 | word_batch = list( 62 | map(lambda item: extend_words[item[0] - self.skip_window_left:item[0] + self.skip_window_right + 1], 63 | enumerate(extend_words[self.skip_window_left:-self.skip_window_right], self.skip_window_left))) 64 | characters_batch.append(np.array(word_batch, dtype=np.int32)) 65 | labels_batch.append(np.array(self.labels[i], dtype=np.int32)) 66 | #print(characters_batch) 67 | return np.array(characters_batch), np.array(labels_batch) 68 | 69 | def generate_batches(self): 70 | character_batches = [] 71 | label_batches = [] 72 | for batch_index, batch in enumerate(self.characters): 73 | character_batch = [] 74 | label_batch = [] 75 | for sentence_index, sentence in enumerate(batch): 76 | extend_words = [2] * self.skip_window_left 77 | extend_words.extend(sentence) 78 | extend_words.extend([3] * self.skip_window_right) 79 | if self.skip_window_right != 0: 80 | word_batch = list( 81 | map(lambda item: extend_words[item[0] - self.skip_window_left:item[0] + self.skip_window_right + 1], 82 | enumerate(extend_words[self.skip_window_left:-self.skip_window_right], self.skip_window_left))) 83 | else: 84 | word_batch = list( 85 | map(lambda item: extend_words[item[0] - self.skip_window_left:item[0] + self.skip_window_right + 1], 86 | enumerate(extend_words[self.skip_window_left:], self.skip_window_left))) 87 | character_batch.append(word_batch) 88 | 
label_batch.append(self.labels[batch_index][sentence_index]) 89 | character_batches.append(character_batch) 90 | label_batches.append(label_batch) 91 | 92 | return np.array(character_batches, dtype=np.int32), np.array(label_batches, dtype=np.int32) 93 | -------------------------------------------------------------------------------- /evaluate.py: -------------------------------------------------------------------------------- 1 | # -*- coding: UTF-8 -*- 2 | import numpy as np 3 | from dnn import DNN 4 | from prepare_data import PrepareData 5 | from config import CorpusType, TrainMode 6 | from re_cnn import RECNN 7 | 8 | 9 | def evaluate_mlp(): 10 | cws = DNN('mlp', mode=TrainMode.Sentence) 11 | model = 'tmp/mlp-model20.ckpt' 12 | # print(cws.seg('小明来自南京师范大学', model, debug=True)) 13 | # print(cws.seg('小明是上海理工大学的学生', model)) 14 | # print(cws.seg('迈向充满希望的新世纪', model)) 15 | # print(cws.seg('我爱北京天安门', model)) 16 | # print(cws.seg('在中国致公党第十一次全国代表大会隆重召开之际,中国共产党中央委员会谨向大会表示热烈的祝贺,向致公党的同志们',model)) 17 | print(cws.seg('多饮多尿多食', model)) 18 | print(cws.seg('无明显小便泡沫增多,伴有夜尿3次。', model)) 19 | print(cws.seg('无明显小便泡沫增多,伴有夜尿3次。', model, ner=True)) 20 | print(cws.seg('无明显双脚疼痛,无间歇性后跛行,无明显足部红肿破溃', model)) 21 | # evaluate_model(cws, model) 22 | 23 | 24 | def evaluate_mlp_ner(): 25 | cws = DNN('mlp', mode=TrainMode.Sentence, is_seg=True, task='ner') 26 | model = 'tmp/mlp/mlp-ner-model1.ckpt' 27 | # print(cws.seg('在中国致公党第十一次全国代表大会隆重召开之际,中国共产党中央委员会谨向大会表示热烈的祝贺,向致公党的同志们', model,ner=True)) 28 | print(cws.seg('多饮多尿多食', model, ner=True)) 29 | print(cws.seg('无明显小便泡沫增多,伴有夜尿3次。', model, ner=True)) 30 | print(cws.seg('无明显双脚疼痛,无间歇性后跛行,无明显足部红肿破溃', model, ner=True, debug=False)) 31 | 32 | 33 | def evaluate_lstm(): 34 | cws = DNN('lstm', is_seg=True) 35 | model = 'tmp/lstm-model100.ckpt' 36 | print(cws.seg('小明来自南京师范大学', model, debug=True)) 37 | print(cws.seg('小明是上海理工大学的学生', model)) 38 | print(cws.seg('迈向充满希望的新世纪', model)) 39 | print(cws.seg('我爱北京天安门', model)) 40 | print(cws.seg('多饮多尿多食', model)) 41 | print(cws.seg('无明显小便泡沫增多,伴有夜尿3次。无明显双脚疼痛,无间歇性后跛行,无明显足部红肿破溃', model)) 42 | # evaluate_model(cws, model) 43 | 44 | 45 | def evaludate_RECNN(): 46 | reCNN = RECNN() 47 | reCNN.test() 48 | 49 | 50 | def evaluate_model(cws, model): 51 | pre = PrepareData(4000, 'pku', dict_path='corpus/pku_dict.utf8', type=CorpusType.Test) 52 | sentences = pre.raw_lines 53 | labels = pre.labels_index 54 | corr_count = 0 55 | re_count = 0 56 | total_count = 0 57 | 58 | for _, (sentence, label) in enumerate(zip(sentences, labels)): 59 | _, tag = cws.seg(sentence, model) 60 | cor_count, prec_count, recall_count = estimate_cws(tag, np.array(label)) 61 | corr_count += cor_count 62 | re_count += recall_count 63 | total_count += prec_count 64 | prec = corr_count / total_count 65 | recall = corr_count / re_count 66 | 67 | print(prec) 68 | print(recall) 69 | print(2 * prec * recall / (prec + recall)) 70 | 71 | 72 | def estimate_cws(current_labels, correct_labels): 73 | cor_dict = {} 74 | curt_dict = {} 75 | curt_start = 0 76 | cor_start = 0 77 | for label_index, (curt_label, cor_label) in enumerate(zip(current_labels, correct_labels)): 78 | if cor_label == 0: 79 | cor_dict[label_index] = label_index + 1 80 | elif cor_label == 1: 81 | cor_start = label_index 82 | elif cor_label == 3: 83 | cor_dict[cor_start] = label_index + 1 84 | 85 | if curt_label == 0: 86 | curt_dict[label_index] = label_index + 1 87 | elif curt_label == 1: 88 | curt_start = label_index 89 | elif curt_label == 3: 90 | curt_dict[curt_start] = label_index + 1 91 | 92 | cor_count = 0 93 | 
recall_length = len(curt_dict) 94 | prec_length = len(cor_dict) 95 | for curt_start in curt_dict.keys(): 96 | if curt_start in cor_dict and curt_dict[curt_start] == cor_dict[curt_start]: 97 | cor_count += 1 98 | 99 | return cor_count, prec_length, recall_length 100 | 101 | 102 | def estimate_ner(current_labels, correct_labels): 103 | corr_dict = {} 104 | curr_dict = {} 105 | corr_start = -2 106 | curr_start = -2 107 | 108 | # print('curr',current_labels) 109 | # print('corr', correct_labels) 110 | for label_index, (curr_label, corr_label) in enumerate(zip(current_labels, correct_labels)): 111 | if corr_label == 1: 112 | corr_start = label_index 113 | if corr_start == label_index - 1: 114 | corr_dict[corr_start] = 1 115 | elif label_index > 0 and corr_label == 2 and correct_labels[label_index - 1] != 2: 116 | corr_dict[corr_start] = label_index - corr_start 117 | 118 | if curr_label == 1: 119 | curr_start = label_index 120 | if curr_start == label_index - 1: 121 | curr_dict[curr_start] = 1 122 | elif label_index > 0 and curr_label == 2 and current_labels[label_index - 1] != 2: 123 | curr_dict[curr_start] = label_index - curr_start 124 | 125 | corr_count = 0 126 | prec_length = len(curr_dict) 127 | recall_length = len(corr_dict) 128 | for curr_start in curr_dict: 129 | if curr_start in corr_dict and curr_dict[curr_start] == corr_dict[curr_start]: 130 | corr_count += 1 131 | 132 | return corr_count, prec_length,recall_length 133 | 134 | if __name__ == '__main__': 135 | # evaluate_mlp() 136 | evaluate_mlp_ner() 137 | # evaluate_lstm() 138 | # evaludate_RECNN() 139 | -------------------------------------------------------------------------------- /prepare_data_semeval.py: -------------------------------------------------------------------------------- 1 | # -*- coding: UTF-8 -*- 2 | import re 3 | import numpy as np 4 | import pickle 5 | from functools import reduce 6 | 7 | class PrepareDataSemeval: 8 | def __init__(self, batch_length=95, batch_size=50): 9 | self.base_path = 'corpus/semeval_relation' 10 | self.path = self.base_path + '.utf8' 11 | self.batch_length = batch_length 12 | self.batch_size = batch_size 13 | self.relation_categories, self.relations = self.read_content() 14 | self.dictionary = self.build_dictionary() 15 | self.batches = self.build_batches() 16 | with open('corpus/semeval_relation_batches.rel', 'wb') as f: 17 | pickle.dump(self.batches, f) 18 | print(len(self.relation_categories)) 19 | 20 | def read_content(self): 21 | with open(self.path, 'r', encoding='utf8') as file: 22 | contents = file.read().split('\n\n') 23 | relation_categories = {'Other': 0} 24 | relation_count = {'Other': 0} 25 | length = [] 26 | relations = [] 27 | for content in contents: 28 | sections = content.splitlines() 29 | if len(sections) == 0: 30 | break 31 | 32 | idx = re.search(r'^[0-9]+', sections[0]) 33 | if idx: 34 | idx = idx.group(0) 35 | else: 36 | idx = '1' 37 | 38 | sentence = sections[0][len(idx) + 2:-1] 39 | if idx == '1': 40 | sentence = sentence[1:] 41 | 42 | reduce_lengths = [4, 9, 13] 43 | e1_starttag_pos = sentence.find('') 44 | e1_endtag_pos = sentence.find('') 45 | e2_starttag_pos = sentence.find('') 46 | e2_endtag_pos = sentence.find('') 47 | e1_start = e1_starttag_pos 48 | e1_end = e1_endtag_pos - reduce_lengths[0] 49 | e2_start = e2_starttag_pos - reduce_lengths[1] 50 | e2_end = e2_endtag_pos - reduce_lengths[2] 51 | raw_sentence = sentence.replace('', '').replace('', '')\ 52 | .replace('', '').replace('','').strip() 53 | e1 = raw_sentence[e1_start:e1_end] 54 | e2 = 
raw_sentence[e2_start:e2_end] 55 | raw_words = list(filter(lambda w:len(w)>0 and w != ' ',raw_sentence.split(' '))) # 带标点的单词 56 | raw_words_index = [0] 57 | for raw_word in raw_words[:-1]: 58 | raw_words_index.append(raw_words_index[-1] + len(raw_word) + 1) 59 | 60 | words = [] # 分离标点后的单词 61 | words_index = [] 62 | for raw_word_index, raw_word in enumerate(raw_words): 63 | if len(raw_word) > 1: 64 | if len(raw_word) >2 and not raw_word[-1].isalnum() and not raw_word[0].isalnum(): 65 | words.append(raw_word[0]) 66 | words.append(raw_word[1:-1]) 67 | words.append(raw_word[-1]) 68 | words_index.append(raw_words_index[raw_word_index]) 69 | words_index.append(raw_words_index[raw_word_index]+1) 70 | words_index.append(raw_words_index[raw_word_index] + len(raw_word) - 1) 71 | elif not raw_word[-1].isalnum(): 72 | words.append(raw_word[:-1]) 73 | words.append(raw_word[-1]) 74 | words_index.append(raw_words_index[raw_word_index]) 75 | words_index.append(raw_words_index[raw_word_index] + len(raw_word) - 1) 76 | elif not raw_word[0].isalnum(): 77 | words.append(raw_word[0]) 78 | words.append(raw_word[1:]) 79 | words_index.append(raw_words_index[raw_word_index]) 80 | words_index.append(raw_words_index[raw_word_index] + 1) 81 | else: 82 | words.append(raw_word) 83 | words_index.append(raw_words_index[raw_word_index]) 84 | length.append(len(words)) 85 | 86 | e1_index = words_index.index(e1_start) 87 | e2_index = words_index.index(e2_start) 88 | for word in words: 89 | if len(word) == 0: 90 | print('fuck') 91 | if sections[1] != 'Other': 92 | relation = re.search(r'([a-zA-Z-]*)\(', sections[1]).groups()[0] 93 | if relation not in relation_categories: 94 | relation_categories[relation] = len(relation_categories) 95 | relation_count[relation] = 0 96 | primary, secondary = re.search(r'\((\S+),(\S+)\)', sections[1]).groups() 97 | if primary == 'e2' and secondary == 'e1': 98 | e1_index = words_index.index(e2_start) 99 | e2_index = words_index.index(e1_start) 100 | else: 101 | relation = 'Other' 102 | relation_count[relation] += 1 103 | relations.append({'id': idx, 'words': words, 'primary': e1_index, 'secondary': e2_index, 104 | 'type': relation_categories[relation]}) 105 | 106 | print(relation_count) 107 | print(relation_categories) 108 | all_count = reduce(lambda a,b:a+b,relation_count.values()) 109 | return relation_categories, relations 110 | 111 | def build_dictionary(self): 112 | words_set = set() 113 | for relation in self.relations: 114 | words_set = words_set.union(set(relation['words'])) 115 | dictionary = {'BATCH_PAD': 0, 'UNK': 1} 116 | with open('corpus/semeval_dict.utf8', 'w', encoding='utf8') as file: 117 | for word in words_set: 118 | dictionary[word] = len(dictionary) 119 | file.write(word+' '+str(dictionary[word])+'\n') 120 | file.write('BATCH_PAD'+' '+str(dictionary['BATCH_PAD'])+'\n') 121 | file.write('UNK' + ' ' + str(dictionary['UNK']) + '\n') 122 | return dictionary 123 | 124 | def build_batches(self): 125 | batches = [] 126 | sentence = [] 127 | primary = [] 128 | secondary = [] 129 | label = [] 130 | base_index = range(self.batch_length, 2 * self.batch_length) 131 | for relation_index, relation in enumerate(self.relations,1): 132 | words = list(map(lambda w: self.dictionary[w], relation['words'])) 133 | words += [self.dictionary['BATCH_PAD']] * (self.batch_length - len(words)) 134 | sentence.append(words) 135 | base_index = range(self.batch_length,len(relation['words'])+self.batch_length) 136 | 137 | p = list(map(lambda i: i - relation['primary'], base_index)) 138 | s = 
list(map(lambda i: i - relation['secondary'], base_index)) 139 | 140 | p.extend([self.dictionary['BATCH_PAD']]*(self.batch_length-len(relation['words']))) 141 | s.extend([self.dictionary['BATCH_PAD']] * (self.batch_length - len(relation['words']))) 142 | 143 | primary.append(p) 144 | secondary.append(s) 145 | 146 | relation_arr = [0]*len(self.relation_categories) 147 | relation_arr[relation['type']] = 1 148 | label.append(relation_arr) 149 | 150 | if relation_index % self.batch_size == 0: 151 | batches.append({'sentence': np.array(sentence, np.int32), 'primary': np.array(primary, np.int32), 152 | 'secondary': np.array(secondary, np.int32), 'label': np.array(label, np.float32)}) 153 | sentence.clear() 154 | primary.clear() 155 | secondary.clear() 156 | label.clear() 157 | 158 | return batches 159 | 160 | if __name__ == '__main__': 161 | sem = PrepareDataSemeval() 162 | -------------------------------------------------------------------------------- /prepare_data.py: -------------------------------------------------------------------------------- 1 | # -*- coding: UTF-8 -*- 2 | import numpy as np 3 | import re 4 | import os 5 | import collections 6 | from utils import plot_lengths 7 | from config import CorpusType, TrainMode 8 | 9 | 10 | class PrepareData: 11 | def __init__(self, vocab_size, corpus, batch_length=224, batch_size=50, dict_path=None, mode=TrainMode.Batch, 12 | type=CorpusType.Train): 13 | self.vocab_size = vocab_size 14 | self.dict_path = dict_path 15 | self.batch_length = batch_length 16 | self.batch_size = batch_size 17 | self.SPLIT_CHAR = ' ' # 分隔符:双空格 18 | # 字符数量, 19 | # 其中'BATCH_PAD'表示构建batch时不足时补的字符,'UNK'表示词汇表外的字符, 20 | # 'STAT'表示句子首字符之前的字符,'END'表示句子尾字符后面的字符,这两个字符用于生成字的上下文 21 | self.count = [['BATCH_PAD', 0], ['UNK', 0], ['STRT', 0], ['END', 0]] 22 | self.init_count = len(self.count) 23 | if type == CorpusType.Train: 24 | self.input_file = 'corpus/' + corpus + '_' + 'training.utf8' 25 | self.output_base = 'corpus/' + corpus + '_' + 'training_' 26 | elif type == CorpusType.Test: 27 | self.input_file = 'corpus/' + corpus + '_' + 'test.utf8' 28 | self.output_base = 'corpus/' + corpus + '_' + 'test_' 29 | self.lines = self.read_lines() 30 | 31 | if self.dict_path == None: 32 | self.dictionary, self.reverse_dictionary = self.build_dictionary('corpus/' + corpus + '_dict.utf8') 33 | else: 34 | self.dictionary, self.reverse_dictionary = self.read_dictionary() 35 | if self.dictionary == None: 36 | print('vocabulary size larger than dictionary size') 37 | exit(1) 38 | 39 | if type == CorpusType.Train: 40 | self.characters_index, self.labels_index = self.build_dataset() 41 | if mode == TrainMode.Sentence: 42 | self.characters_index, self.labels_index, self.lengths = self.build_batch(trunc=False) 43 | np.save(self.output_base + 'characters', self.characters_index) 44 | np.save(self.output_base + 'labels', self.labels_index) 45 | np.save(self.output_base + 'lengths', self.lengths) 46 | elif mode == TrainMode.Batch: 47 | self.character_batches, self.label_batches, self.lengths, self.sentences, self.sentence_labels, self.sentence_lengths = self.build_batch() 48 | np.save(self.output_base + 'character_batches', self.character_batches) 49 | np.save(self.output_base + 'label_batches', self.label_batches) 50 | np.save(self.output_base + 'lengths_batches', self.lengths) 51 | elif type == CorpusType.Test: 52 | self.raw_lines = list(map(lambda s: s.replace(self.SPLIT_CHAR, ''), self.lines)) 53 | if os.path.exists('corpus/' + corpus + '_test_labels.npy'): 54 | self.labels_index = 
np.load('corpus/' + corpus + '_test_labels.npy') 55 | else: 56 | _, self.labels_index = self.build_dataset() 57 | np.save('corpus/' + corpus + '_test_labels', self.labels_index) 58 | 59 | # plot_lengths(self.sentence_lengths) 60 | 61 | def read_lines(self): 62 | file = open(self.input_file, 'r', encoding='utf-8') 63 | content = file.read() 64 | # sentences = re.sub('[ ]+', self.SPLIT_CHAR, strQ2B(content)).splitlines() # 将词分隔符统一为双空格 65 | sentences = re.sub('[ ]+', self.SPLIT_CHAR, content).splitlines() # 将词分隔符统一为双空格 66 | sentences = list(map(lambda s: s.strip(), filter(None, sentences))) # 去除空行,去首尾空格 67 | file.close() 68 | return sentences 69 | 70 | def build_dictionary(self, output=None): 71 | dictionary = {} 72 | words = ''.join(self.lines).replace(' ', '') 73 | vocab_count = len(collections.Counter(words)) 74 | print('characters count'+str(vocab_count)) 75 | if vocab_count + self.init_count < self.vocab_size: 76 | return None 77 | self.count.extend(collections.Counter(words).most_common(self.vocab_size - self.init_count)) 78 | 79 | for word, _ in self.count: 80 | dictionary[word] = len(dictionary) 81 | if output != None: 82 | with open(output, 'w', encoding='utf8') as file: 83 | for ch, index in dictionary.items(): 84 | file.write(ch + ' ' + str(index) + '\n') 85 | reverse_dictionary = dict(zip(dictionary.values(), dictionary.keys())) 86 | return dictionary, reverse_dictionary 87 | 88 | def read_dictionary(self): 89 | dict_file = open(self.dict_path, 'r', encoding='utf-8') 90 | dict_content = dict_file.read().splitlines() 91 | dictionary = {} 92 | dict_arr = map(lambda item: item.split(' '), dict_content) 93 | for _, dict_item in enumerate(dict_arr): 94 | dictionary[dict_item[0]] = int(dict_item[1]) 95 | dict_file.close() 96 | # if len(dictionary) < self.vocab_size: 97 | # return None 98 | reverse_dictionary = dict(zip(dictionary.values(), dictionary.keys())) 99 | # else: 100 | # reverse_dictionary = dict(zip(dictionary.values(), dictionary.keys())) 101 | # for i in range(self.vocab_size, len(dictionary)): 102 | # dictionary.pop(reverse_dictionary[i]) 103 | return dictionary, reverse_dictionary 104 | 105 | def build_dataset(self): 106 | sentence_index = [] 107 | labels_index = [] 108 | for sentence in self.lines: 109 | sentence_label = [] 110 | word_index = [] 111 | words = sentence.strip().split(self.SPLIT_CHAR) 112 | for word in words: 113 | l = len(word) 114 | if l == 0: 115 | continue 116 | elif l == 1: 117 | sentence_label.append(0) 118 | else: 119 | sentence_label.append(1) 120 | sentence_label.extend([2] * (l - 2)) 121 | sentence_label.append(3) 122 | for ch in word: 123 | index = self.dictionary.get(ch) 124 | if index is not None: 125 | word_index.append(index) 126 | else: 127 | word_index.append(0) 128 | sentence_index.append(word_index) 129 | labels_index.append(sentence_label) 130 | return np.array(sentence_index), np.array(labels_index) 131 | 132 | def build_batch(self, trunc = True): 133 | sentence_batches = [] 134 | label_batches = [] 135 | sentence_lengths = [] 136 | lengths = [] 137 | sentences = [] 138 | labels = [] 139 | unknown = 4 140 | seg_ch = [self.dictionary['。'], self.dictionary['!'], self.dictionary['?']] 141 | no_seg_ch = [self.dictionary['”']] 142 | characters_index = self.characters_index.tolist() 143 | labels_index = self.labels_index.tolist() 144 | line_lengths = list(map(lambda chs: len(chs), characters_index)) 145 | 146 | def is_seg(item): 147 | return item[1] in seg_ch and (item[0] < item[2] - 1 and characters[item[0] + 1] not in no_seg_ch) 148 | 149 | 
for characters, label, length in zip(characters_index, labels_index, line_lengths): 150 | if length <= 1: 151 | continue 152 | seg_indices = [0] + [i[0] + 1 for i in filter(is_seg, zip(range(length), characters, [length] * length))] 153 | for pre_seg_index, cur_seg_index in zip(seg_indices[:-1], seg_indices[1:]): 154 | sentence = characters[pre_seg_index:cur_seg_index] 155 | sentence_labels = label[pre_seg_index:cur_seg_index] 156 | sentences.append(sentence) 157 | labels.append(sentence_labels) 158 | sentence_length = len(sentence) 159 | sentence_lengths.append(sentence_length) 160 | if sentence_length <= self.batch_length: 161 | pad_length = self.batch_length - sentence_length 162 | sentence_batches.append(sentence + [self.dictionary['BATCH_PAD']] * pad_length) 163 | label_batches.append(sentence_labels + [unknown] * pad_length) 164 | lengths.append(sentence_length) 165 | else: 166 | if sentence_labels[sentence_length - 1] != 0 and 1 in sentence_labels[:sentence_length:-1]: 167 | last_index = sentence_labels[:sentence_length:-1].index(1) 168 | pad_length = self.batch_length - last_index 169 | sentence_batches.append(sentence[:last_index] + [self.dictionary['BATCH_PAD']] * pad_length) 170 | label_batches.append(sentence_labels[:last_index] + [unknown] * pad_length) 171 | lengths.append(last_index) 172 | else: 173 | sentence_batches.append(sentence[:self.batch_length]) 174 | label_batches.append(sentence_labels[:self.batch_length]) 175 | lengths.append(self.batch_length) 176 | 177 | if trunc: 178 | extra_count = len(sentence_batches) % self.batch_size 179 | sentence_batches = np.array(sentence_batches[:-extra_count], dtype=np.int32).reshape( 180 | [-1, self.batch_size, self.batch_length]) 181 | label_batches = np.array(label_batches[:-extra_count], dtype=np.int32).reshape( 182 | [-1, self.batch_size, self.batch_length]) 183 | lengths = np.array(lengths[:-extra_count], dtype=np.int32).reshape([-1, self.batch_size]) 184 | return sentence_batches, label_batches, lengths, sentences, labels, sentence_lengths 185 | else: 186 | return np.array(sentence_batches, dtype=np.int32),np.array(label_batches, dtype=np.int32), np.array(lengths,dtype=np.int32) 187 | 188 | 189 | if __name__ == '__main__': 190 | # PrepareData(4600, 'pku', mode=TrainMode.Batch) 191 | # PrepareData(4000, 'pku', type=CorpusType.Test, dict_path='corpus/pku_dict.utf8') 192 | # PrepareData(5000, 'msr', mode=TrainMode.Batch) 193 | PrepareData(None, 'emr',mode=TrainMode.Sentence) 194 | -------------------------------------------------------------------------------- /dnn_base.py: -------------------------------------------------------------------------------- 1 | # -*- coding: UTF-8 -*- 2 | import numpy as np 3 | from base import Base 4 | from collections import OrderedDict 5 | 6 | 7 | class DNNBase(Base): 8 | def __init__(self): 9 | Base.__init__(self) 10 | self.tags_count = 3 11 | self.dictionary = None 12 | self.skip_window_left = 0 13 | self.skip_window_right = 1 14 | self.window_size = self.skip_window_left + self.skip_window_right + 1 15 | self.hinge_discount = 0.2 16 | self.reverse_categories, self.category_reverse_dict, self.zh_categories = self.init_categories() 17 | 18 | def init_categories(self): 19 | categories = {'Sign': 'SN', 'Symptom': 'SYM', 'Part': 'PT', 'Property': 'PTY', 'Degree': 'DEG', 20 | 'Quality': 'QLY', 'Quantity': 'QNY', 'Unit': 'UNT', 'Time': 'T', 'Date': 'DT', 'Result': 'RES', 21 | 'Disease': 'DIS', 'DiseaseType': 'DIT', 'Examination': 'EXN', 'Location': 'LOC', 22 | 'Medicine': 'MED', 'Spec': 'SPEC', 
'Usage': 'USG', 'Dose': 'DSE', 'Treatment': 'TRT', 23 | 'Family': 'FAM', 'Modifier': 'MOF'} 24 | zh_categories = {'Sign': '体征', 'Symptom': '症状', 'Part': '部位', 'Property': '性质', 'Degree': '程度', 25 | 'Quality': '定性值', 'Quantity': '定量值', 'Unit': '单位', 'Time': '时间', 'Date': '日期', 'Result': '结果', 26 | 'Disease': '疾病', 'DiseaseType': '疾病分型分歧', 'Examination': '检查', 'Location': '机构', 27 | 'Medicine': '药物', 'Spec': '规格', 'Usage': '用法', 'Dose': '用量', 'Treatment': '治疗', 28 | 'Family': '家族成员', 'Modifier': '其他修饰词'} 29 | category_labels_dict = OrderedDict({'O': 0}) 30 | category_index = 1 31 | for category in categories: 32 | category_labels_dict[categories[category] + '_B'] = category_index 33 | category_index += 1 34 | category_labels_dict[categories[category] + '_O'] = category_index 35 | category_index += 1 36 | category_labels_dict['P'] = category_index 37 | return OrderedDict(zip(categories.values(), categories.keys())), OrderedDict( 38 | zip(category_labels_dict.values(), category_labels_dict.keys())), zh_categories 39 | 40 | def viterbi(self, emission, A, init_A, return_score=False, is_constraint=False, labels=None, size=4): 41 | """ 42 | 维特比算法的实现,所有输入和返回参数均为numpy数组对象 43 | :param emission: 发射概率矩阵,对应于本模型中的分数矩阵,4*length 44 | :param A: 转移概率矩阵,4*4 45 | :param init_A: 初始转移概率矩阵,4 46 | :param return_score: 是否返回最优路径的分值,默认为False 47 | :return: 最优路径,若return_score为True,返回最优路径及其对应分值 48 | """ 49 | 50 | constraint = [[0, 1], [2, 3], [2, 3], [0, 1]] 51 | length = emission.shape[1] 52 | path = np.ones([self.tags_count, length], dtype=np.int32) * -1 53 | corr_path = np.zeros([length], dtype=np.int32) 54 | path_score = np.ones([self.tags_count, length], dtype=np.float64) * (np.finfo('f').min / 2) 55 | path_score[:, 0] = init_A + emission[:, 0] 56 | 57 | if labels is not None: 58 | for i in range(size): 59 | if i != labels[0]: 60 | path_score[i, 0] += self.hinge_discount 61 | 62 | for pos in range(1, length): 63 | for t in range(self.tags_count): 64 | for prev in range(self.tags_count): 65 | if is_constraint: 66 | if t not in constraint[prev]: 67 | continue 68 | temp = path_score[prev][pos - 1] + A[prev][t] + emission[t][pos] 69 | if labels is not None: 70 | if t != labels[pos]: 71 | temp += self.hinge_discount 72 | if temp >= path_score[t][pos]: 73 | path[t][pos] = prev 74 | path_score[t][pos] = temp 75 | 76 | max_index = np.argmax(path_score[:, -1]) 77 | corr_path[length - 1] = max_index 78 | for i in range(length - 1, 0, -1): 79 | max_index = path[max_index][i] 80 | corr_path[i - 1] = max_index 81 | if return_score: 82 | return corr_path, path_score[max_index, -1] 83 | else: 84 | return corr_path 85 | 86 | def viterbi_new(self, emission, transition, transition_init, labels=None): 87 | constraint = [[0, 1], [2, 3], [2, 3], [0, 1]] 88 | length = emission.shape[1] 89 | path = np.ones([self.tags_count, length + 1], dtype=np.int32) * -1 90 | corr_path = np.zeros([length], dtype=np.int32) 91 | path_score = np.ones([self.tags_count, length + 1], dtype=np.float64) * (np.finfo('f').min / 2) 92 | # path_score[:, 0] = transition_init + emission[:, 0] 93 | path_score[0, 0] = 0 94 | 95 | for pos in range(1, length + 1): 96 | for path_index in range(self.tags_count): 97 | for curr_label in constraint[path_index]: 98 | tmp = path_score[path_index, pos - 1] + emission[curr_label, pos - 1] + transition[path_index, curr_label] 99 | if labels is not None: 100 | if curr_label != labels[pos - 1]: 101 | tmp += self.hinge_discount 102 | if tmp > path_score[curr_label, pos]: 103 | path_score[curr_label, pos] = tmp 104 | 
path[curr_label, pos] = path_index 105 | 106 | # print(path) 107 | # print(path_score) 108 | max_index = np.argmax(path_score[:, -1]) 109 | corr_path[length - 1] = max_index 110 | for i in range(length - 1, 0, -1): 111 | max_index = path[max_index][i + 1] 112 | corr_path[i - 1] = max_index 113 | return corr_path 114 | 115 | def generate_transition_update(self, correct_tags, current_tags): 116 | if correct_tags.shape != current_tags.shape: 117 | print('序列长度不同') 118 | return None 119 | 120 | A_update = np.zeros([self.tags_count, self.tags_count], dtype=np.float32) 121 | init_A_update = np.zeros([self.tags_count], dtype=np.float32) 122 | before_corr = correct_tags[0] 123 | before_curr = current_tags[0] 124 | update_init = False 125 | 126 | if before_corr != before_curr: 127 | init_A_update[before_corr] += 1 128 | init_A_update[before_curr] -= 1 129 | update_init = True 130 | 131 | for _, (corr_tag, curr_tag) in enumerate(zip(correct_tags[1:], current_tags[1:])): 132 | if corr_tag != curr_tag or before_corr != before_curr: 133 | A_update[before_corr, corr_tag] += 1 134 | A_update[before_curr, curr_tag] -= 1 135 | before_corr = corr_tag 136 | before_curr = curr_tag 137 | 138 | return A_update, init_A_update, update_init 139 | 140 | def generate_transition_update_index(self, correct_labels, current_labels): 141 | if correct_labels.shape != current_labels.shape: 142 | print('序列长度不同') 143 | return None 144 | 145 | before_corr = correct_labels[0] 146 | before_curr = current_labels[0] 147 | update_init = False 148 | 149 | trans_init_pos = None 150 | trans_init_neg = None 151 | trans_pos = [] 152 | trans_neg = [] 153 | 154 | if before_corr != before_curr: 155 | trans_init_pos = [before_corr] 156 | trans_init_neg = [before_curr] 157 | update_init = True 158 | 159 | for _, (corr_label, curr_label) in enumerate(zip(correct_labels[1:], current_labels[1:])): 160 | if corr_label != curr_label or before_corr != before_curr: 161 | trans_pos.append([before_corr, corr_label]) 162 | trans_neg.append([before_curr, curr_label]) 163 | before_corr = corr_label 164 | before_curr = curr_label 165 | 166 | return trans_pos, trans_neg, trans_init_pos, trans_init_neg, update_init 167 | 168 | def sentence2index(self, sentence): 169 | index = [] 170 | for word in sentence: 171 | if word not in self.dictionary: 172 | index.append(1) 173 | else: 174 | index.append(self.dictionary[word]) 175 | 176 | return index 177 | 178 | def index2seq(self, indices): 179 | ext_indices = [2] * self.skip_window_left 180 | ext_indices.extend(indices + [3] * self.skip_window_right) 181 | seq = [] 182 | for index in range(self.skip_window_left, len(ext_indices) - self.skip_window_right): 183 | seq.append(ext_indices[index - self.skip_window_left: index + self.skip_window_right + 1]) 184 | 185 | return seq 186 | 187 | def tags2words(self, sentence, tags): 188 | words = [] 189 | word = '' 190 | for tag_index, tag in enumerate(tags): 191 | if tag == 0: 192 | words.append(sentence[tag_index]) 193 | elif tag == 1: 194 | word = sentence[tag_index] 195 | elif tag == 2: 196 | word += sentence[tag_index] 197 | else: 198 | words.append(word + sentence[tag_index]) 199 | word = '' 200 | # 处理最后一个标记为I的情况 201 | if word != '': 202 | words.append(word) 203 | 204 | return words 205 | 206 | def tags2entities(self, sentence, tags, return_start=True): 207 | entities = [] 208 | entity_starts = [] 209 | entity = '' 210 | 211 | for tag_index, tag in enumerate(tags): 212 | if tag == 0: 213 | continue 214 | elif tag == 1: 215 | if entity: 216 | entities.append(entity) 
217 | entity = sentence[tag_index] 218 | entity_starts.append(tag_index) 219 | else: 220 | entity += sentence[tag_index] 221 | if entity != '': 222 | entities.append(entity) 223 | if return_start: 224 | return entities,entity_starts 225 | else: 226 | return entities 227 | 228 | def tags2category_entities(self, sentence, tags): 229 | entities = [] 230 | entity = '' 231 | category = '' 232 | for tag_index, tag in enumerate(tags): 233 | type = self.category_reverse_dict[tag] 234 | if tag == 0: 235 | continue 236 | elif type[-1] == 'B': 237 | entities.append(entity + '/' + category) 238 | entity = sentence[tag_index] 239 | category = self.zh_categories[self.reverse_categories[type[:-2]]] 240 | else: 241 | entity += sentence[tag_index] 242 | if entity != '': 243 | entities.append(entity + '/' + category) 244 | return entities 245 | -------------------------------------------------------------------------------- /re_cnn.py: -------------------------------------------------------------------------------- 1 | # -*- coding: UTF-8 -*- 2 | import numpy as np 3 | import tensorflow as tf 4 | import pickle 5 | 6 | 7 | class RECNN(): 8 | def __init__(self, relation_count=2, window_size=(3,), batch_size=50, batch_length=85,train=True): 9 | tf.reset_default_graph() 10 | self.dtype = tf.float32 11 | self.window_size = window_size 12 | self.filter_size = 150 13 | self.relation_count = relation_count 14 | self.batch_length = batch_length 15 | self.batch_size = batch_size 16 | self.learning_rate = 0.01 17 | self.dropout_rate = 0.5 18 | self.lam = 0.0001 19 | self.character_embed_size = 300 20 | self.position_embed_size = 50 21 | self.dict_path = 'corpus/emr_words_dict.utf8' 22 | self.dictionary = self.read_dictionary() 23 | self.words_size = len(self.dictionary) 24 | self.is_train = train 25 | if relation_count == 2: 26 | self.batch_path = 'corpus/emr_all_relation_batches.rel' 27 | self.output_folder = 'tmp/re_two/' 28 | self.test_batch_path = 'corpus/emr_test_all_relations.rel' 29 | elif relation_count == 29: 30 | self.batch_path = 'corpus/emr_relation_batches.rel' 31 | self.output_folder = 'tmp/re_multi/' 32 | self.test_batch_path = 'corpus/emr_test_relations.rel' 33 | else: 34 | raise Exception('relation count error') 35 | 36 | self.concat_embed_size = self.character_embed_size + 2 * self.position_embed_size 37 | self.input_characters = tf.placeholder(tf.int32, [None, self.batch_length]) 38 | self.input_position = tf.placeholder(tf.int32, [None, self.batch_length]) 39 | self.input = tf.placeholder(self.dtype, [None, self.batch_length, self.concat_embed_size, 1]) 40 | self.input_relation = tf.placeholder(self.dtype, [None, self.relation_count]) 41 | self.position_embedding = self.weight_variable([2 * self.batch_length, self.position_embed_size]) 42 | self.character_embedding = self.weight_variable([self.words_size, self.character_embed_size]) 43 | self.conv_kernel = self.get_conv_kernel() 44 | self.bias = [self.weight_variable([self.filter_size])] * len(self.window_size) 45 | self.full_connected_weight = self.weight_variable([self.filter_size*len(self.window_size), self.relation_count]) 46 | self.full_connected_bias = self.weight_variable([self.relation_count]) 47 | self.position_lookup = tf.nn.embedding_lookup(self.position_embedding, self.input_position) 48 | self.character_lookup = tf.nn.embedding_lookup(self.character_embedding, self.input_characters) 49 | self.character_embed_holder = tf.placeholder(self.dtype, 50 | [None, self.batch_length, self.character_embed_size]) 51 | self.primary_embed_holder = 
tf.placeholder(self.dtype, 52 | [None, self.batch_length, self.position_embed_size]) 53 | self.secondary_embed_holder = tf.placeholder(self.dtype, 54 | [None, self.batch_length, self.position_embed_size]) 55 | self.emebd_concat = tf.expand_dims( 56 | tf.concat([self.character_embed_holder, self.primary_embed_holder, self.secondary_embed_holder], 2), 3) 57 | if train: 58 | self.hidden_layer = tf.layers.dropout(self.get_hidden(), self.dropout_rate) 59 | else: 60 | self.hidden_layer = tf.expand_dims(tf.layers.dropout(self.get_hidden(), self.dropout_rate),0) 61 | self.output_no_softmax = tf.matmul(self.hidden_layer, self.full_connected_weight) + self.full_connected_bias 62 | self.output = tf.nn.softmax(tf.matmul(self.hidden_layer, self.full_connected_weight) + self.full_connected_bias) 63 | self.params = [self.position_embedding, self.character_embedding, self.full_connected_weight, 64 | self.full_connected_bias] + self.conv_kernel + self.bias 65 | self.regularization = tf.contrib.layers.apply_regularization(tf.contrib.layers.l2_regularizer(self.lam), 66 | self.params) 67 | self.loss = tf.reduce_sum(tf.square(self.output - self.input_relation)) / self.batch_size + self.regularization 68 | self.cross_entropy = tf.nn.softmax_cross_entropy_with_logits(labels=self.input_relation, 69 | logits=self.output_no_softmax) + self.regularization 70 | self.optimizer = tf.train.GradientDescentOptimizer(self.learning_rate) 71 | # self.optimizer = tf.train.AdagradOptimizer(self.learning_rate) 72 | self.train_model = self.optimizer.minimize(self.loss) 73 | self.train_cross_entropy_model = self.optimizer.minimize(self.cross_entropy) 74 | self.saver = tf.train.Saver(max_to_keep=100) 75 | 76 | def weight_variable(self, shape): 77 | initial = tf.truncated_normal(shape, stddev=0.1, dtype=self.dtype) 78 | return tf.Variable(initial) 79 | 80 | def get_conv_kernel(self): 81 | conv_kernel = [] 82 | for w in self.window_size: 83 | conv_kernel.append(self.weight_variable([w, self.concat_embed_size, 1, self.filter_size])) 84 | return conv_kernel 85 | 86 | def get_max_pooling(self, x): 87 | max_pooling = [] 88 | for w in self.window_size: 89 | max_pooling.append(self.max_pooling(x, w)) 90 | return max_pooling 91 | 92 | def get_hidden(self): 93 | h = None 94 | for w, conv, bias in zip(self.window_size, self.conv_kernel, self.bias): 95 | if h is None: 96 | h = tf.squeeze(self.max_pooling(tf.nn.relu(self.conv(conv) + bias), w)) 97 | else: 98 | hh = tf.squeeze(self.max_pooling(tf.nn.relu(self.conv(conv) + bias), w)) 99 | if self.is_train: 100 | h = tf.concat([h, hh], 1) 101 | else: 102 | h = tf.concat([h,hh], 0) 103 | return h 104 | 105 | def conv(self, conv_kernel): 106 | return tf.nn.conv2d(self.input, conv_kernel, strides=[1, 1, 1, 1], padding='VALID') 107 | 108 | def max_pooling(self, x, window_size): 109 | return tf.nn.max_pool(x, ksize=[1, self.batch_length - window_size + 1, 1, 1], 110 | strides=[1, 1, 1, 1], padding='VALID') 111 | 112 | def train(self): 113 | batches = self.load_batches(self.batch_path) 114 | with tf.Session() as sess: 115 | tf.global_variables_initializer().run() 116 | sess.graph.finalize() 117 | epochs = 100 118 | for i in range(1, epochs + 1): 119 | print('epoch:' + str(i)) 120 | for batch in batches: 121 | character_embeds, primary_embeds = sess.run([self.character_lookup, self.position_lookup], 122 | feed_dict={self.input_characters: batch['sentence'], 123 | self.input_position: batch['primary']}) 124 | secondary_embeds = sess.run(self.position_lookup, feed_dict={self.input_position: 
batch['secondary']}) 125 | input = sess.run(self.emebd_concat, feed_dict={self.character_embed_holder: character_embeds, 126 | self.primary_embed_holder: primary_embeds, 127 | self.secondary_embed_holder: secondary_embeds}) 128 | # sess.run(self.train_model, feed_dict={self.input: input, self.input_relation: batch['label']}) 129 | sess.run(self.train_cross_entropy_model, feed_dict={self.input: input, self.input_relation: batch['label']}) 130 | if i % 50 == 0: 131 | model_name = 'cnn_emr_model{0}_{1}.ckpt'.format(i, '_'.join(map(str, self.window_size))) 132 | self.saver.save(sess, self.output_folder + model_name) 133 | 134 | def load_batches(self, path): 135 | with open(path, 'rb') as f: 136 | batches = pickle.load(f) 137 | return batches 138 | 139 | def read_dictionary(self): 140 | dict_file = open(self.dict_path, 'r', encoding='utf-8') 141 | dict_content = dict_file.read().splitlines() 142 | dictionary = {} 143 | dict_arr = map(lambda item: item.split(' '), dict_content) 144 | for _, dict_item in enumerate(dict_arr): 145 | dictionary[dict_item[0]] = int(dict_item[1]) 146 | dict_file.close() 147 | return dictionary 148 | 149 | def predict(self, sentences, primary_indies, secondary_indices): 150 | with tf.Session() as sess: 151 | self.saver.restore(sess, self.output_folder + 'cnn_emr_model3.ckpt') 152 | character_embeds, primary_embeds = sess.run([self.character_lookup, self.position_lookup], 153 | feed_dict={self.input_characters: sentences, 154 | self.input_position: primary_indies}) 155 | secondary_embeds = sess.run(self.position_lookup, feed_dict={self.input_position: secondary_indices}) 156 | input = sess.run(self.emebd_concat, feed_dict={self.character_embed_holder: character_embeds, 157 | self.primary_embed_holder: primary_embeds, 158 | self.secondary_embed_holder: secondary_embeds}) 159 | output = sess.run(self.output, feed_dict={self.input: input}) 160 | return np.argmax(output, 1) 161 | 162 | def evaluate(self, model_file): 163 | #tf.reset_default_graph() 164 | with tf.Session() as sess: 165 | #tf.global_variables_initializer().run() 166 | 167 | self.saver.restore(sess=sess, save_path=self.output_folder + model_file) 168 | items = self.load_batches(self.test_batch_path) 169 | corr_count = [0] * self.relation_count 170 | prec_count = [0] * self.relation_count 171 | recall_count = [0] * self.relation_count 172 | 173 | for item in items: 174 | character_embeds, primary_embeds = sess.run([self.character_lookup, self.position_lookup], 175 | feed_dict={self.input_characters: item['sentence'], 176 | self.input_position: item['primary']}) 177 | secondary_embeds = sess.run(self.position_lookup, feed_dict={self.input_position: item['secondary']}) 178 | input = sess.run(self.emebd_concat, feed_dict={self.character_embed_holder: character_embeds, 179 | self.primary_embed_holder: primary_embeds, 180 | self.secondary_embed_holder: secondary_embeds}) 181 | # print(input) 182 | output = np.squeeze(sess.run(self.output, feed_dict={self.input: input})) 183 | target = np.argmax(item['label']) 184 | current = np.argmax(output) 185 | if target == current: 186 | corr_count[target] += 1 187 | prec_count[current] += 1 188 | recall_count[target] += 1 189 | 190 | precs = [c / p for c, p in zip(corr_count, prec_count) if p != 0 and c != 0] 191 | recalls = [c / r for c, r in zip(corr_count, recall_count) if r!= 0 and c != 0] 192 | print(corr_count) 193 | print(recall_count) 194 | print(corr_count) 195 | print(precs) 196 | print(recalls) 197 | prec = sum(precs) / len(precs) 198 | recall = sum(recalls) / 
len(recalls) 199 | f1 = 2*prec*recall/(prec+recall) 200 | print('precision:', prec) 201 | print('recall:', recall) 202 | print('f1',f1) 203 | 204 | def train_two(): 205 | re_2 = RECNN(window_size=(2,)) 206 | re_2.train() 207 | re_3 = RECNN(window_size=(3,)) 208 | re_3.train() 209 | re_4 = RECNN(window_size=(4,)) 210 | re_4.train() 211 | re_2_3 = RECNN(window_size=(2, 3)) 212 | re_2_3.train() 213 | re_3_4 = RECNN(window_size=(3, 4)) 214 | re_3_4.train() 215 | re_2_3_4 = RECNN(window_size=(2, 3, 4)) 216 | re_2_3_4.train() 217 | 218 | def train_multi(): 219 | re_2 = RECNN(window_size=(2,),relation_count=29) 220 | re_2.train() 221 | re_3 = RECNN(window_size=(3,),relation_count=29) 222 | re_3.train() 223 | re_4 = RECNN(window_size=(4,),relation_count=29) 224 | re_4.train() 225 | re_2_3 = RECNN(window_size=(2, 3),relation_count=29) 226 | re_2_3.train() 227 | re_3_4 = RECNN(window_size=(3, 4),relation_count=29) 228 | re_3_4.train() 229 | re_2_3_4 = RECNN(window_size=(2, 3, 4),relation_count=29) 230 | re_2_3_4.train() 231 | 232 | if __name__ == '__main__': 233 | train_two() 234 | train_multi() 235 | -------------------------------------------------------------------------------- /dnn.py: -------------------------------------------------------------------------------- 1 | # -*- coding: UTF-8 -*- 2 | import math 3 | import time 4 | import tensorflow as tf 5 | import numpy as np 6 | from dnn_base import DNNBase 7 | from preprocess_data import PreprocessData 8 | from config import TrainMode 9 | 10 | 11 | class DNN(DNNBase): 12 | def __init__(self, type = 'mlp', batch_size = 10, batch_length = 224, 13 | mode = TrainMode.Batch, task = 'cws', is_seg = False, 14 | is_embed = False): 15 | tf.reset_default_graph() 16 | DNNBase.__init__(self) 17 | # 参数初始化 18 | self.dtype = tf.float32 19 | self.skip_window_left = 0 20 | self.skip_window_right = 1 21 | self.window_size = self.skip_window_left + self.skip_window_right + 1 22 | # self.vocab_size = 4000 23 | self.embed_size = 100 24 | self.hidden_units = 150 25 | if task == 'cws': 26 | self.tags = [0, 1, 2, 3] 27 | elif task == 'ner': 28 | self.tags = [0, 1, 2] 29 | else: 30 | raise Exception('task name error') 31 | self.is_embed = is_embed 32 | self.tags_count = len(self.tags) 33 | self.concat_embed_size = self.window_size * self.embed_size 34 | self.learning_rate = 0.01 35 | self.lam = 0.0001 36 | self.batch_length = batch_length 37 | self.batch_size = batch_size 38 | self.mode = mode 39 | self.type = type 40 | self.is_seg = is_seg 41 | self.dropout_rate = 0.2 42 | # 数据初始化 43 | pre = PreprocessData('emr_ner', self.mode, force_generate = True) 44 | self.character_batches = pre.character_batches 45 | self.label_batches = pre.label_batches 46 | self.lengths = pre.lengths 47 | self.dictionary = pre.dictionary 48 | self.vocab_size = len(self.dictionary) 49 | # 模型定义和初始化 50 | self.sess = tf.Session() 51 | 52 | initializer = tf.contrib.layers.xavier_initializer(dtype = self.dtype) 53 | if not self.is_embed: 54 | self.embeddings = tf.Variable( 55 | tf.truncated_normal([self.vocab_size, self.embed_size], 56 | stddev = 1.0 / math.sqrt(self.embed_size), 57 | dtype = self.dtype), name = 'embeddings') 58 | # self.embeddings = tf.get_variable('embeddings', [self.vocab_size, self.embed_size], dtype=self.dtype, 59 | # initializer=initializer) 60 | else: 61 | self.embeddings = tf.Variable(np.load('corpus/embed/embeddings.npy'), 62 | dtype = self.dtype, name = 'embeddings') 63 | self.input = tf.placeholder(tf.int32, shape = [None, self.window_size]) 64 | self.label_index_correct 
= tf.placeholder(tf.int32, shape = [None, 2]) 65 | self.label_index_current = tf.placeholder(tf.int32, shape = [None, 2]) 66 | # self.w = tf.Variable( 67 | # tf.truncated_normal([self.tags_count, self.hidden_units], stddev=1.0 / math.sqrt(self.concat_embed_size), 68 | # dtype=self.dtype), name='w') 69 | self.w = tf.get_variable('w', [self.tags_count, self.hidden_units], 70 | dtype = self.dtype, initializer = initializer) 71 | # self.transition = tf.Variable(tf.random_uniform([self.tags_count, self.tags_count], -0.2, 0.2, dtype=self.dtype)) 72 | # self.transition_init = tf.Variable(tf.random_uniform([self.tags_count], -0.2, 0.2, dtype=self.dtype)) 73 | self.transition = tf.get_variable('transition', 74 | [self.tags_count, self.tags_count], 75 | dtype = self.dtype, 76 | initializer = initializer) 77 | self.transition_init = tf.get_variable('transition_init', [self.tags_count], 78 | dtype = self.dtype, 79 | initializer = initializer) 80 | self.transition_holder = tf.placeholder(self.dtype, 81 | shape = self.transition.get_shape()) 82 | self.transition_init_holder = tf.placeholder(self.dtype, 83 | shape = self.transition_init.get_shape()) 84 | # self.optimizer = tf.train.GradientDescentOptimizer(self.learning_rate) 85 | self.optimizer = tf.train.AdagradOptimizer(0.02) 86 | # self.optimizer = tf.train.MomentumOptimizer(0.01,0.9) 87 | # self.optimizer = tf.train.AdamOptimizer(0.0001)#,beta1=0.1,beta2=0.001) 88 | self.update_transition = self.transition.assign( 89 | tf.add((1 - self.learning_rate * self.lam) * self.transition, 90 | self.learning_rate * self.transition_holder)) 91 | self.update_transition_init = self.transition_init.assign( 92 | tf.add((1 - self.learning_rate * self.lam) * self.transition_init, 93 | self.learning_rate * self.transition_init_holder)) 94 | self.look_up = tf.reshape( 95 | tf.nn.embedding_lookup(self.embeddings, self.input), 96 | [-1, self.concat_embed_size]) 97 | self.params = [self.w, self.embeddings] 98 | if type == 'mlp': 99 | self.b = tf.Variable(tf.zeros([self.tags_count, 1], dtype = self.dtype), 100 | name = 'b') 101 | self.params.append(self.b) 102 | self.input_embeds = tf.transpose( 103 | tf.reshape(tf.nn.embedding_lookup(self.embeddings, self.input), 104 | [-1, self.concat_embed_size])) 105 | self.hidden_w = tf.Variable( 106 | tf.random_uniform([self.hidden_units, self.concat_embed_size], 107 | 4.0 / math.sqrt(self.concat_embed_size), 108 | 4 / math.sqrt(self.concat_embed_size), 109 | dtype = self.dtype), name = 'hidden_w') 110 | self.hidden_b = tf.Variable( 111 | tf.zeros([self.hidden_units, 1], dtype = self.dtype), name = 'hidden_b') 112 | self.word_scores = tf.matmul(self.w, 113 | tf.sigmoid(tf.matmul(self.hidden_w, 114 | self.input_embeds) + self.hidden_b)) + self.b 115 | self.params += [self.hidden_w, self.hidden_b] 116 | self.loss = tf.reduce_sum( 117 | tf.gather_nd(self.word_scores, self.label_index_current) - 118 | tf.gather_nd(self.word_scores, 119 | self.label_index_correct)) + tf.contrib.layers.apply_regularization( 120 | tf.contrib.layers.l2_regularizer(self.lam), self.params) 121 | elif type == 'lstm': 122 | self.lstm = tf.nn.rnn_cell.BasicLSTMCell(self.hidden_units) 123 | self.b = tf.Variable( 124 | tf.zeros([self.tags_count, 1, 1], dtype = self.dtype), name = 'b') 125 | self.params.append(self.b) 126 | if self.mode == TrainMode.Batch: 127 | if not self.is_seg: 128 | self.input = tf.placeholder(tf.int32, shape = [self.batch_size, 129 | self.batch_length, 130 | self.window_size]) 131 | self.input_embeds = tf.reshape( 132 | 
tf.nn.embedding_lookup(self.embeddings, self.input), 133 | [self.batch_size, self.batch_length, self.concat_embed_size]) 134 | self.input_embeds = tf.layers.dropout(self.input_embeds, 135 | self.dropout_rate) 136 | self.lstm_output, self.lstm_out_state = tf.nn.dynamic_rnn(self.lstm, 137 | self.input_embeds, 138 | dtype = self.dtype) 139 | self.params += [v for v in tf.global_variables() if 140 | v.name.startswith('rnn')] 141 | self.word_scores = tf.tensordot(self.w, 142 | tf.transpose(self.lstm_output), 143 | [[1], [0]]) + self.b 144 | self.label_index_correct = tf.placeholder(tf.int32, shape = [None, 3]) 145 | self.label_index_current = tf.placeholder(tf.int32, shape = [None, 3]) 146 | self.transition_correct_holder = tf.placeholder(tf.int32, [None, 2]) 147 | self.transition_current_holder = tf.placeholder(tf.int32, [None, 2]) 148 | self.transition_init_correct_holder = tf.placeholder(tf.int32, 149 | [None, 1]) 150 | self.transition_init_current_holder = tf.placeholder(tf.int32, 151 | [None, 1]) 152 | self.loss_scores = tf.reduce_sum( 153 | tf.gather_nd(self.word_scores, self.label_index_current) - 154 | tf.gather_nd(self.word_scores, 155 | self.label_index_correct)) + tf.reduce_sum( 156 | tf.gather_nd(self.transition, 157 | self.transition_current_holder) - tf.gather_nd( 158 | self.transition, 159 | self.transition_correct_holder)) 160 | self.loss_scores_with_init = self.loss_scores + tf.reduce_sum( 161 | tf.gather_nd(self.transition_init, 162 | self.transition_init_current_holder) - tf.gather_nd( 163 | self.transition_init, 164 | self.transition_init_correct_holder)) 165 | self.regularization = tf.contrib.layers.apply_regularization( 166 | tf.contrib.layers.l2_regularizer(self.lam), 167 | self.params + [self.transition]) 168 | self.regularization_with_init = tf.contrib.layers.apply_regularization( 169 | tf.contrib.layers.l2_regularizer(self.lam), 170 | self.params + [self.transition, self.transition_init]) 171 | self.loss = self.loss_scores / self.batch_size + self.regularization 172 | self.loss_with_init = self.loss_scores_with_init / self.batch_size + self.regularization_with_init 173 | else: 174 | self.input_embeds = tf.reshape( 175 | tf.nn.embedding_lookup(self.embeddings, self.input), 176 | [1, -1, self.concat_embed_size]) 177 | self.lstm_output, self.lstm_out_state = tf.nn.dynamic_rnn(self.lstm, 178 | self.input_embeds, 179 | dtype = self.dtype) 180 | self.word_scores = tf.matmul(self.w, tf.transpose( 181 | self.lstm_output[-1, :, :])) + self.b[:, :, -1] 182 | 183 | if self.is_seg == False: 184 | gvs = self.optimizer.compute_gradients(self.loss) 185 | cliped_grad = [ 186 | (tf.clip_by_norm(grad, 5) if grad is not None else grad, var) for 187 | grad, var in gvs] 188 | self.train = self.optimizer.apply_gradients( 189 | cliped_grad) # self.optimizer.minimize(self.loss) 190 | if self.is_seg == False and self.type == 'lstm': 191 | gvs2 = self.optimizer.compute_gradients(self.loss_with_init) 192 | cliped_grad2 = [ 193 | (tf.clip_by_norm(grad2, 5) if grad2 is not None else grad2, var2) for 194 | grad2, var2 in gvs2] 195 | self.train_with_init = self.optimizer.apply_gradients(cliped_grad2) 196 | # self.train_with_init = self.optimizer.minimize(self.loss_with_init) 197 | self.saver = tf.train.Saver(max_to_keep = 100) 198 | # self.saver.restore(self.sess, 'tmp/lstm-bbbmodel6.ckpt') 199 | self.sentence_index = 0 200 | 201 | def train_exe(self): 202 | tf.global_variables_initializer().run(session = self.sess) 203 | self.sess.graph.finalize() 204 | epochs = 50 205 | last_time = time.time() 206 
| if self.mode == TrainMode.Sentence: 207 | for i in range(epochs): 208 | print('epoch:%d' % i) 209 | for sentence_index, (sentence, labels, length) in enumerate( 210 | zip(self.character_batches, self.label_batches, self.lengths)): 211 | # self.train_sentence(sentence[:length], labels[:length]) 212 | self.train_sentence(sentence, labels) 213 | self.sentence_index = sentence_index 214 | if sentence_index > 0 and sentence_index % 8000 == 0: 215 | print(sentence_index) 216 | print(time.time() - last_time) 217 | last_time = time.time() 218 | if (i + 1) % 10 == 0: 219 | if self.type == 'mlp': 220 | if self.is_embed: 221 | self.saver.save(self.sess, 222 | 'tmp/mlp/mlp-ner-embed-model{0}.ckpt'.format( 223 | i + 1)) 224 | else: 225 | self.saver.save(self.sess, 226 | 'tmp/mlp/mlp-ner-model{0}.ckpt'.format(i + 1)) 227 | elif self.type == 'lstm': 228 | if self.is_embed: 229 | self.saver.save(self.sess, 230 | 'tmp/lstm/lstm-ner-embed-model{0}.ckpt'.format( 231 | i + 1)) 232 | else: 233 | self.saver.save(self.sess, 234 | 'tmp/lstm/lstm-ner-model{0}.ckpt'.format(i + 1)) 235 | elif self.mode == TrainMode.Batch: 236 | for i in range(epochs): 237 | self.step = i 238 | print('epoch:%d' % i) 239 | for batch_index, (character_batch, label_batch, lengths) in enumerate( 240 | zip(self.character_batches, self.label_batches, self.lengths)): 241 | self.train_batch(character_batch, label_batch, lengths) 242 | if batch_index > 0 and batch_index % 100 == 0: 243 | print(batch_index) 244 | print(time.time() - last_time) 245 | last_time = time.time() 246 | if (i + 1) % 10 == 0: 247 | if self.is_embed: 248 | self.saver.save(self.sess, 249 | 'tmp/lstm/lstm-ner-embed-model{0}.ckpt'.format( 250 | i + 1)) 251 | else: 252 | self.saver.save(self.sess, 253 | 'tmp/lstm/lstm-ner-model{0}.ckpt'.format(i + 1)) 254 | 255 | def train_sentence(self, sentence, labels): 256 | scores = self.sess.run(self.word_scores, feed_dict = {self.input: sentence}) 257 | current_labels = self.viterbi(scores, 258 | self.transition.eval(session = self.sess), 259 | self.transition_init.eval( 260 | session = self.sess), labels = labels, 261 | size = 3) 262 | diff_tags = np.subtract(labels, current_labels) 263 | update_index = np.where(diff_tags != 0)[0] 264 | update_length = len(update_index) 265 | 266 | if update_length == 0: 267 | return 268 | 269 | update_labels_pos = np.stack([labels[update_index], update_index], 270 | axis = -1) 271 | update_labels_neg = np.stack([current_labels[update_index], update_index], 272 | axis = -1) 273 | feed_dict = {self.input: sentence, 274 | self.label_index_current: update_labels_neg, 275 | self.label_index_correct: update_labels_pos} 276 | self.sess.run(self.train, feed_dict) 277 | 278 | # 更新转移矩阵 279 | transition_update, transition_init_update, update_init = self.generate_transition_update( 280 | labels, current_labels) 281 | self.sess.run(self.update_transition, 282 | feed_dict = {self.transition_holder: transition_update}) 283 | if update_init: 284 | self.sess.run(self.update_transition_init, feed_dict = { 285 | self.transition_init_holder: transition_init_update}) 286 | 287 | def train_batch(self, sentence_batches, label_batches, lengths): 288 | scores = self.sess.run(self.word_scores, 289 | feed_dict = {self.input: sentence_batches}) 290 | transition = self.transition.eval(session = self.sess) 291 | transition_init = self.transition_init.eval(session = self.sess) 292 | update_labels_pos = None 293 | update_labels_neg = None 294 | current_labels = [] 295 | trans_pos_indices = [] 296 | trans_neg_indices = [] 297 | 
trans_init_pos_indices = [] 298 | trans_init_neg_indices = [] 299 | for i in range(self.batch_size): 300 | current_label = self.viterbi(scores[:, :lengths[i], i], transition, 301 | transition_init) 302 | # current_label = self.viterbi(scores[:, :lengths[i], i], transition, transition_init, is_constraint=True, 303 | # labels=label_batches[i, :lengths[i]]) 304 | # current_label = self.viterbi_new(scores[:, :lengths[i], i], transition, transition_init, 305 | # label_batches[i, :lengths[i]]) 306 | current_labels.append(current_label) 307 | diff_tag = np.subtract(label_batches[i, :lengths[i]], current_label) 308 | update_index = np.where(diff_tag != 0)[0] 309 | update_length = len(update_index) 310 | if update_length == 0: 311 | continue 312 | update_label_pos = np.stack([label_batches[i, update_index], update_index, 313 | i * np.ones([update_length])], axis = -1) 314 | update_label_neg = np.stack([current_label[update_index], update_index, 315 | i * np.ones([update_length])], axis = -1) 316 | if update_labels_pos is not None: 317 | np.concatenate((update_labels_pos, update_label_pos)) 318 | np.concatenate((update_labels_neg, update_label_neg)) 319 | else: 320 | update_labels_pos = update_label_pos 321 | update_labels_neg = update_label_neg 322 | trans_pos_index, trans_neg_index, trans_init_pos, trans_init_neg, update_init = self.generate_transition_update_index( 323 | label_batches[i, :lengths[i]], current_labels[i]) 324 | trans_pos_indices.extend(trans_pos_index) 325 | trans_neg_indices.extend(trans_neg_index) 326 | if update_init: 327 | trans_init_pos_indices.append(trans_init_pos) 328 | trans_init_neg_indices.append(trans_init_neg) 329 | 330 | if update_labels_pos is not None and update_labels_neg is not None: 331 | feed_dict = {self.input: sentence_batches, 332 | self.label_index_current: update_labels_neg, 333 | self.label_index_correct: update_labels_pos, 334 | self.transition_current_holder: trans_neg_indices, 335 | self.transition_correct_holder: trans_pos_indices} 336 | # self.sess.run(self.train, feed_dict) 337 | 338 | if len(trans_init_pos_indices) == 0: 339 | self.sess.run(self.train, feed_dict) 340 | else: 341 | feed_dict[self.transition_init_correct_holder] = trans_init_pos_indices 342 | feed_dict[self.transition_init_current_holder] = trans_init_neg_indices 343 | self.sess.run(self.train_with_init, feed_dict) 344 | 345 | def seg(self, sentence, model_path = 'tmp/mlp-model0.ckpt', debug = False, 346 | ner = False, trans = False): 347 | tf.global_variables_initializer().run(session = self.sess) 348 | self.saver.restore(self.sess, model_path) 349 | if not trans: 350 | s = self.sentence2index(sentence) 351 | else: 352 | if isinstance(sentence, np.ndarray): 353 | s = sentence.tolist() 354 | else: 355 | s = sentence 356 | seq = self.index2seq(s) 357 | 358 | sentence_scores = self.sess.run(self.word_scores, 359 | feed_dict = {self.input: seq}) 360 | transition_init = self.transition_init.eval(session = self.sess) 361 | transition = self.transition.eval(session = self.sess) 362 | if debug: 363 | print(transition) 364 | embeds = self.sess.run(self.look_up, feed_dict = {self.input: seq}) 365 | print(sentence_scores) 366 | if self.type == 'lstm': 367 | output = self.sess.run(self.lstm_output, feed_dict = {self.input: seq}) 368 | print(output[-1, :, 10]) 369 | print(self.transition_init.eval(session = self.sess)) 370 | current_labels = self.viterbi(sentence_scores, transition, transition_init) 371 | if not ner: 372 | return self.tags2words(sentence, current_labels), current_labels 373 | 
else: 374 | # return self.tags2entities(sentence, current_labels), current_labels 375 | return None, current_labels 376 | # return self.tags2category_entities(sentence, current_labels), current_labels 377 | 378 | 379 | if __name__ == '__main__': 380 | mlp = DNN('mlp', mode = TrainMode.Sentence, task = 'ner') 381 | mlp.train_exe() 382 | mlp_embed = DNN('mlp', mode = TrainMode.Sentence, task = 'ner', 383 | is_embed = True) 384 | mlp_embed.train_exe() 385 | lstm = DNN('lstm', task = 'ner') 386 | lstm.train_exe() 387 | lstm_embed = DNN('lstm', task = 'ner', is_embed = True) 388 | lstm_embed.train_exe() 389 | -------------------------------------------------------------------------------- /prepare_data_emr.py: -------------------------------------------------------------------------------- 1 | # -*- coding: UTF-8 -*- 2 | import os 3 | import numpy as np 4 | from collections import OrderedDict 5 | import pickle 6 | from itertools import chain 7 | from utils import plot_lengths 8 | from evaluate import estimate_ner 9 | 10 | 11 | class PrepareDataNer(): 12 | def __init__(self, entity_batch_length=224, relation_batch_length=85, entity_batch_size=10, relation_batch_size=50): 13 | self.entity_tags = {'O': 0, 'B': 1, 'I': 2, 'P': 3} 14 | self.reversed_tags = dict(zip(self.entity_tags.values(),self.entity_tags.keys())) 15 | self.entity_categories = {'Sign': 'SN', 'Symptom': 'SYM', 'Part': 'PT', 'Property': 'PTY', 'Degree': 'DEG', 16 | 'Quality': 'QLY', 'Quantity': 'QNY', 'Unit': 'UNT', 'Time': 'T', 'Date': 'DT', 17 | 'Result': 'RES', 18 | 'Disease': 'DIS', 'DiseaseType': 'DIT', 'Examination': 'EXN', 'Location': 'LOC', 19 | 'Medicine': 'MED', 'Spec': 'SPEC', 'Usage': 'USG', 'Dose': 'DSE', 'Treatment': 'TRT', 20 | 'Family': 'FAM', 21 | 'Modifier': 'MOF'} 22 | self.entity_category_labels = OrderedDict({'O': 0}) 23 | entity_category_index = 1 24 | for category in self.entity_categories: 25 | self.entity_category_labels[self.entity_categories[category] + '_B'] = entity_category_index 26 | entity_category_index += 1 27 | self.entity_category_labels[self.entity_categories[category] + '_O'] = entity_category_index 28 | entity_category_index += 1 29 | self.entity_category_labels['P'] = entity_category_index 30 | self.entity_labels_count = len(self.entity_tags) 31 | self.relation_categories = {'PartOf': '部位', 'PropertyOf': '性质', 'DegreeOf': '程度', 'QualityValue': '定性值', 32 | 'QuantityValue': '定量值', 'UnitOf': '单位', 'TimeOf': '持续时间', 'StartTime': '开始时间', 33 | 'EndTime': '结束时间', 'Moment': '时间点', 'DateOf': '日期', 'ResultOf': '结果', 34 | 'LocationOf': '地点', 'DiseaseTypeOf': '疾病分型分期', 'SpecOf': '规格', 'UsageOf': '用法', 35 | 'DoseOf': '用量', 'FamilyOf': '家族成员', 'ModifierOf': '其他修饰词', 'UseMedicine': '用药', 36 | 'LeadTo': '导致', 'Find': '发现', 'Confirm': '证实', 'Adopt': '采取', 'Take': '用药', 37 | 'Limit': '限定', 'AlongWith': '伴随', 'Complement': '补足'} 38 | with open('data/rel_pairs', 'rb') as pairs_file: 39 | self.relation_constraint = pickle.load(pairs_file) 40 | self.relation_category_labels = {'NoRelation': 0} 41 | relation_category_index = 1 42 | for relation_category in self.relation_categories: 43 | self.relation_category_labels[relation_category] = relation_category_index 44 | relation_category_index += 1 45 | self.relation_category_label_count = len(self.relation_category_labels) 46 | self.relation_labels = {'Y': 1, 'N': 0} 47 | self.relation_label_count = len(self.relation_labels) 48 | self.base_folder = 'corpus/emr_paper/train/' 49 | self.test_base_folder = 'corpus/emr_paper/test/' 50 | self.filenames = [] 51 | 
self.test_filenames = [] 52 | self.ext_dict_path = ['corpus/msr_dict.utf8', 'corpus/pku_dict.utf8'] 53 | self.dict_path = 'corpus/emr_ner_dict.utf8' 54 | self.words_dict_path = 'corpus/emr_words_dict.utf8' 55 | self.entity_batch_length = entity_batch_length 56 | self.relation_batch_length = relation_batch_length 57 | self.entity_batch_size = entity_batch_size 58 | self.relation_batch_size = relation_batch_size 59 | for _, _, filenames in os.walk(self.base_folder): 60 | for filename in filenames: 61 | filename, _ = os.path.splitext(filename) 62 | if filename not in self.filenames: 63 | self.filenames.append(filename) 64 | for _, _, filenames in os.walk(self.test_base_folder): 65 | for filename in filenames: 66 | filename, _ = os.path.splitext(filename) 67 | if filename not in self.test_filenames: 68 | self.test_filenames.append(filename) 69 | self.words = set() 70 | self.content = '' 71 | e_categories = ['Sign', 'Part', 'Quantity'] 72 | r_categories = ['PartOf', 'QuantityValue'] 73 | self.annotations = self.read_annotation(self.base_folder, self.filenames, e_categories, r_categories) 74 | self.test_annotations = self.read_annotation(self.test_base_folder, self.test_filenames, e_categories, r_categories) 75 | self.dictionary, self.reverse_dictionary = self.build_dictionary() 76 | self.words_dictionary = self.build_words_dictionary() 77 | # 二分类 78 | _, _, self.all_relations, _ = self.build_dataset(self.filenames, self.annotations, is_entity_category=False) 79 | # 多分类 80 | self.characters, self.entity_labels, self.relations, _ = self.build_dataset(self.filenames, self.annotations, 81 | is_entity_category=False, 82 | is_negative_relation=False, 83 | is_relation_category=True) 84 | 85 | self.test_characters, self.test_entity_labels, _, self.test_all_relations = self.build_dataset( 86 | self.test_filenames, 87 | self.test_annotations, 88 | is_entity_category=False) 89 | _, _, self.test_relations, _ = self.build_dataset(self.test_filenames, self.test_annotations, 90 | is_entity_category=False, 91 | is_negative_relation=False, 92 | is_relation_category=True) 93 | # self.plot_words_sentences() 94 | self.export_coll(self.characters,self.entity_labels,'corpus/emr_training.conll') 95 | self.export_coll(self.test_characters, self.test_entity_labels, 'corpus/emr_test.conll') 96 | exit(1) 97 | np.save('corpus/emr_ner_training_characters', self.characters) 98 | np.save('corpus/emr_ner_training_labels', self.entity_labels) 99 | np.save('corpus/emr_ner_test_characters', self.test_characters) 100 | np.save('corpus/emr_ner_test_labels', self.test_entity_labels) 101 | with open('corpus/emr_training_relations.rel', 'wb') as f: 102 | pickle.dump(self.relations, f) 103 | 104 | extra_count = len(self.characters) % self.entity_batch_size 105 | lengths = np.array(list(map(lambda item: len(item), self.characters[:-extra_count])), np.int32).reshape( 106 | [-1, self.entity_batch_size]) 107 | np.save('corpus/emr_ner_training_lengths_batches', lengths) 108 | self.character_batches, self.label_batches = self.build_entity_batch() 109 | np.save('corpus/emr_ner_training_character_batches', self.character_batches) 110 | np.save('corpus/emr_ner_training_label_batches', self.label_batches) 111 | self.train_relation_batches = self.build_relation_batch(self.relations, self.relation_batch_size) 112 | self.all_relation_batches = self.build_relation_batch(self.all_relations, self.relation_batch_size) 113 | self.test_all_relation_batches = self.build_relation_batch(self.test_all_relations, 1) 114 | self.test_relation_batches = 
self.build_relation_batch(self.test_relations, 1) 115 | with open('corpus/emr_relation_batches.rel', 'wb') as f: 116 | pickle.dump(self.train_relation_batches, f) 117 | with open('corpus/emr_all_relation_batches.rel', 'wb') as f: 118 | pickle.dump(self.all_relation_batches, f) 119 | with open('corpus/emr_test_relations.rel', 'wb') as f: 120 | pickle.dump(self.test_relation_batches, f) 121 | with open('corpus/emr_test_all_relations.rel', 'wb') as f: 122 | pickle.dump(self.test_all_relation_batches, f) 123 | 124 | def export_coll(self,characters,labels,src_file): 125 | text = '' 126 | for character,label in zip(characters,labels): 127 | chs = [self.reverse_dictionary[c] for c in character] 128 | lbs = [self.reversed_tags[l] for l in label] 129 | text += '\n'.join([' '.join(l) for l in zip(chs,lbs)]) 130 | text += '\n\n' 131 | 132 | with open(src_file, 'w',encoding='utf-8') as f: 133 | f.write(text) 134 | 135 | def read_annotation(self, base_folder, filenames, e_categories, r_categories): 136 | annotation = {} 137 | for filename in filenames: 138 | with open(base_folder + filename + '.txt', encoding='utf8') as raw_file: 139 | raw_text = raw_file.read().replace('\n', '\r\n') 140 | self.content += raw_text 141 | with open(base_folder + filename + '.ann', encoding='utf8') as annotation_file: 142 | results = annotation_file.read().replace('\t', ' ').splitlines() 143 | annotation_results = {'entity': {}, 'relations': [], 'entity_start': {}, 'cws': {}} 144 | 145 | for result in results: 146 | sections = result.split(' ') 147 | if sections[0][0] == 'T': 148 | if sections[1] in e_categories: 149 | entity = {'id': sections[0], 'category': sections[1], 'start': int(sections[2]), 'end': int(sections[3]), 150 | 'content': sections[4]} 151 | annotation_results['entity_start'][int(sections[2])] = {'id': sections[0]} 152 | annotation_results['entity'][sections[0]] = entity 153 | elif sections[0][0] == 'R': 154 | if sections[1] in r_categories: 155 | relation = {'id': sections[0], 'category': sections[1], 'primary': sections[2].split(':')[-1], 156 | 'secondary': sections[3].split(':')[-1]} 157 | annotation_results['relations'].append(relation) 158 | with open(base_folder + filename + '.cws', encoding='utf8') as cws_file: 159 | words = cws_file.read().strip().split(' ') 160 | lengths = [0] 161 | 162 | for i, w in enumerate(words): 163 | lengths.append(lengths[-1] + len(w)) 164 | words[i] = words[i].replace('\n', '') 165 | self.words.add(words[i]) 166 | 167 | # 验证 168 | for e in annotation_results['entity'].values(): 169 | s = e['start'] 170 | end = e['end'] 171 | if s in lengths and end in lengths: 172 | if lengths.index(end) - lengths.index(s) != 1: 173 | print(filename) 174 | print(e) 175 | 176 | annotation_results['cws']['words'] = words 177 | annotation_results['cws']['words_index'] = lengths 178 | annotation[filename] = {'raw': raw_text, 'annotation': annotation_results} 179 | print('datasets summary:') 180 | print('entities count', len(annotation_results['entity'].values()), ' relation count', 181 | len(annotation_results['relations'])) 182 | return annotation 183 | 184 | def build_dictionary(self): 185 | dictionary = {} 186 | characters = [] 187 | for dict_path in self.ext_dict_path: 188 | d = self.read_dictionary(dict_path) 189 | characters.extend(d.keys()) 190 | 191 | # print(len(list(content)) / 1024) 192 | characters.extend(list(self.content.replace('\r\n', ''))) 193 | characters = list( 194 | filter(lambda ch: ch != 'UNK' and ch != 'STRT' and ch != 'END' and ch != 'BATCH_PAD', set(characters))) 
195 | dictionary['BATCH_PAD'] = 0 196 | dictionary['UNK'] = 1 197 | dictionary['STRT'] = 2 198 | dictionary['END'] = 3 199 | for index, character in enumerate(characters, 3): 200 | dictionary[character] = index 201 | 202 | with open(self.dict_path, 'w', encoding='utf8') as dict_file: 203 | for character in dictionary: 204 | dict_file.write(character + ' ' + str(dictionary[character]) + '\n') 205 | return dictionary, dict(zip(dictionary.values(), dictionary.keys())) 206 | 207 | def build_words_dictionary(self): 208 | words = set() 209 | words_dictionary = {'BATCH_PAD': 0, 'UNK': 1} 210 | 211 | with open(self.words_dict_path, 'w', encoding='utf8') as dict_path: 212 | dict_path.write('BATCH_PAD 0\n') 213 | dict_path.write('UNK 1\n') 214 | for w in self.words: 215 | if len(w) > 0: 216 | words_dictionary[w] = len(words_dictionary) 217 | dict_path.write(w + ' ' + str(words_dictionary[w]) + '\n') 218 | 219 | return words_dictionary 220 | 221 | @staticmethod 222 | def read_dictionary(dict_path): 223 | dict_file = open(dict_path, 'r', encoding='utf-8') 224 | dict_content = dict_file.read().splitlines() 225 | dictionary = {} 226 | dict_arr = map(lambda item: item.split(' '), dict_content) 227 | for _, dict_item in enumerate(dict_arr): 228 | dictionary[dict_item[0]] = int(dict_item[1]) 229 | dict_file.close() 230 | return dictionary 231 | 232 | def build_dataset(self, filenames, ann, is_entity_category=False, is_relation_category=False, 233 | is_negative_relation=True): 234 | rn = ['\r', '\n'] 235 | seg = [self.dictionary['。']] 236 | seg_in_sentence = [self.dictionary[',']] 237 | word_seg = [self.words_dictionary['。']] 238 | word_seg_in_sentence = [self.words_dictionary[',']] 239 | characters_index = [] 240 | entity_labels = [] 241 | all_relations = {} 242 | relations = {} 243 | pos = 0 244 | neg = 0 245 | all_neg = 0 246 | max_len = 0 247 | 248 | for filename in filenames: 249 | raw_text = ann[filename]['raw'] 250 | annotations = ann[filename]['annotation'] 251 | cws_list = annotations['cws']['words'] 252 | cws_list_index = annotations['cws']['words_index'] 253 | entity_start = annotations['entity_start'] 254 | all_entities = annotations['entity'] 255 | character_index = [] 256 | word_index = [] 257 | entity_label = [self.entity_tags['O']] * len(raw_text) 258 | rn_index = [] 259 | relation = {} 260 | primary_entity = [] 261 | seg_index = [0] # 分隔符的字索引 262 | word_seg_index = [0] # 分隔符的词索引 263 | 264 | for index, character in enumerate(raw_text): 265 | if character in rn: 266 | rn_index.append(index) 267 | elif character not in self.dictionary: 268 | character_index.append(1) 269 | else: 270 | character_index.append(self.dictionary[character]) 271 | 272 | for index, word in enumerate(cws_list): 273 | if word not in self.words_dictionary: 274 | word_index.append(1) 275 | else: 276 | word_index.append(self.words_dictionary[word]) 277 | 278 | for entity_annotation in annotations['entity'].values(): 279 | start = entity_annotation['start'] 280 | end = entity_annotation['end'] 281 | content = entity_annotation['content'] 282 | type = entity_annotation['category'] 283 | if is_entity_category: 284 | entity_label[start] = self.entity_category_labels[self.entity_categories[type] + '_B'] 285 | if len(content) > 1: 286 | entity_label[start + 1:end] = [self.entity_category_labels[self.entity_categories[type] + '_O']] * ( 287 | end - start - 1) 288 | else: 289 | entity_label[start] = self.entity_tags['B'] 290 | if len(content) > 1: 291 | entity_label[start + 1:end] = [self.entity_tags['I']] * (end - start - 1) 292 
| 293 | for relation_annotation in annotations['relations']: 294 | id = relation_annotation['id'] 295 | type = relation_annotation['category'] 296 | primary = relation_annotation['primary'] 297 | secondary = relation_annotation['secondary'] 298 | relation[primary] = (secondary, type, id) 299 | primary_entity.append(primary) 300 | 301 | # 处理回车 302 | if len(rn_index) != 0: 303 | entity_label = [l[1] for l in filter(lambda ch_item: ch_item[0] not in rn_index, enumerate(entity_label))] 304 | 305 | # 分割 306 | doc_length = len(character_index) 307 | for index, ch_index in enumerate(character_index): 308 | if ch_index in seg: 309 | if index != doc_length - 1 and self.dictionary['”'] != character_index[index + 1] : 310 | seg_index.append(index + 1) 311 | if seg_index[-1] != doc_length: 312 | seg_index.append(doc_length) 313 | 314 | words_length = len(word_index) 315 | for i, w in enumerate(word_index): 316 | if w in word_seg: 317 | if i != words_length - 1 and self.words_dictionary['”'] != word_index[i + 1]: 318 | word_seg_index.append(i) 319 | if word_seg_index[-1] != words_length: 320 | word_seg_index.append(words_length) 321 | 322 | # 检验 323 | if len(seg_index) != len(word_seg_index): 324 | print(filename) 325 | print(len(seg_index) - len(word_seg_index)) 326 | 327 | for sentence_index, (cur_index, latter_index, cur_word_index, latter_word_index) in enumerate( 328 | zip(seg_index[:-1], seg_index[1:], 329 | word_seg_index[:-1], word_seg_index[1:])): 330 | sentence_id = filename + '-' + str(sentence_index) 331 | # 寻找最长句子 332 | if max_len < latter_word_index - cur_word_index: 333 | max_len = latter_word_index - cur_word_index 334 | 335 | # 以句号分隔的句子中每个字的索引 336 | characters_index.append(np.array(character_index[cur_index:latter_index], dtype=np.int32)) 337 | # 每个字对应的实体标签 338 | entity_labels.append(np.array(entity_label[cur_index:latter_index], dtype=np.int32)) 339 | 340 | # 处理关系 341 | entity_dict = {} # 每个句子中所有实体字典,键为实体id,值为实体在句子中的索引 342 | positive_relations = [] # 训练用关系的'hash',primary_id < secondary_id 343 | current_relations = [] # 已添加的关系`hash`,防止无序关系添加两次 344 | current_all_relations = [] 345 | sentence_word_index = [] # 句子中每个词在词典中的索引 346 | all_positive_relations = [] # 未处理的关系的hash 347 | 348 | for ii, i in enumerate(cws_list_index[cur_word_index:latter_word_index]): 349 | sentence_word_index.append(self.words_dictionary[cws_list[cws_list_index.index(i)]]) 350 | if entity_start.get(i) != None: 351 | entity_dict[entity_start[i]['id']] = ii 352 | 353 | arr = np.arange(0, latter_word_index - cur_word_index) + self.relation_batch_length - 1 # 位置索引baseline 354 | for primary_id in [e for e in entity_dict if e in primary_entity]: 355 | secondary_id = relation[primary_id][0] 356 | type = relation[primary_id][1] 357 | if is_relation_category: 358 | relation_label = [0] * self.relation_category_label_count 359 | relation_label[self.relation_category_labels[type]] = 1 360 | else: 361 | relation_label = [0, 1] 362 | 363 | primary = entity_dict[primary_id] 364 | if entity_dict.get(secondary_id) is not None: 365 | secondary = entity_dict[secondary_id] 366 | # 无向 367 | if primary_id > secondary_id: 368 | positive_relations.append(secondary_id + ':' + primary_id) 369 | else: 370 | positive_relations.append(primary_id + ':' + secondary_id) 371 | all_positive_relations.append(primary_id + ':' + secondary_id) 372 | relation_item = {'sentence': np.array(word_index[cur_word_index:latter_word_index], dtype=np.int32), 373 | 'primary': arr - primary, 'secondary': arr - secondary, 374 | 'label': relation_label} 375 | # 
train_relations.append(relation_item) 376 | if relations.get(sentence_id) is None: 377 | relations[sentence_id] = [relation_item] 378 | else: 379 | relations[sentence_id].append(relation_item) 380 | if all_relations.get(sentence_id) is None: 381 | all_relations[sentence_id] = [relation_item] 382 | else: 383 | all_relations[sentence_id].append(relation_item) 384 | 385 | pos += len(positive_relations) 386 | entities = list(entity_dict.keys()) 387 | # 添加非关系,可认为是负采样 388 | distance = 8 389 | if is_negative_relation: 390 | for entity_i, entity in enumerate(entities): 391 | secondaries = [] 392 | all_secondaries = [] 393 | for s in entities[:entity_i] + entities[entity_i + 1:]: 394 | secondary_constraint = self.relation_constraint.get(all_entities[entity]['category']) 395 | if secondary_constraint is None or all_entities[s]['category'] not in secondary_constraint: 396 | continue 397 | 398 | if entity < s: 399 | first, second = entity, s 400 | else: 401 | first, second = s, entity 402 | 403 | first_index, second_index = entity_dict[entity], entity_dict[s] 404 | if first_index > second_index: 405 | first_index, second_index = second_index, first_index 406 | for i in sentence_word_index[first_index:second_index + 1]: 407 | if i in word_seg_index: 408 | second_index = i - 1 409 | 410 | rel_hash = first + ':' + second 411 | if rel_hash not in positive_relations and rel_hash not in current_relations: 412 | if abs(entity_dict[first] - entity_dict[second]) < distance: 413 | secondaries.append(s) 414 | current_relations.append(rel_hash) 415 | if rel_hash not in positive_relations and rel_hash not in current_all_relations: 416 | all_secondaries.append(s) 417 | current_all_relations.append(rel_hash) 418 | # all_secondaries = [s for s in entities[:entity_i] + entities[entity_i + 1:] 419 | # if entity + ':' + s not in all_positive_relations] 420 | primary_start = entity_dict[entity] 421 | neg += len(secondaries) 422 | all_neg += len(all_secondaries) 423 | 424 | for s in secondaries: 425 | if is_relation_category: 426 | relation_label = [0] * self.relation_category_label_count 427 | relation_label[self.relation_category_labels['NoRelation']] = 1 428 | else: 429 | relation_label = [1, 0] 430 | relation_item = {'sentence': np.array(word_index[cur_word_index:latter_word_index], dtype=np.int32), 431 | 'primary': arr - primary_start, 'secondary': arr - entity_dict[s], 432 | 'label': relation_label} 433 | # train_relations.append(relation_item) 434 | if relations.get(sentence_id) is None: 435 | relations[sentence_id] = [relation_item] 436 | else: 437 | relations[sentence_id].append(relation_item) 438 | for s in all_secondaries: 439 | if is_relation_category: 440 | relation_label = [0] * self.relation_category_label_count 441 | relation_label[self.relation_category_labels['NoRelation']] = 1 442 | else: 443 | relation_label = [1, 0] 444 | relation_item = {'sentence': np.array(word_index[cur_word_index:latter_word_index], dtype=np.int32), 445 | 'primary': arr - primary_start, 'secondary': arr - entity_dict[s], 446 | 'label': relation_label} 447 | if all_relations.get(sentence_id) is None: 448 | all_relations[sentence_id] = [relation_item] 449 | else: 450 | all_relations[sentence_id].append(relation_item) 451 | 452 | print(neg / (pos + neg)) 453 | # print(all_neg / (pos + all_neg)) 454 | train_relations = [r for rs in relations.values() for r in rs] 455 | all_relations = [r for rs in all_relations.values() for r in rs] 456 | for i, chs in enumerate(characters_index): 457 | sentence = '' 458 | for ch in chs: 459 | 
459 |         sentence += self.reverse_dictionary[ch]
460 |     return np.array(characters_index), np.array(entity_labels), train_relations, all_relations
461 |
462 |   def plot_words_sentences(self):
463 |     lengths = list(map(lambda r: len(r['sentence']), self.relations))
464 |     lengths.sort()
465 |     plot_lengths(lengths)
466 |
467 |   def build_entity_batch(self, category=False):
468 |     characters = []
469 |     labels = []
470 |     for line_characters, line_labels in zip(self.characters, self.entity_labels):
471 |       length = len(line_characters)
472 |       if length >= self.entity_batch_length:
473 |         characters.append(line_characters[:self.entity_batch_length])
474 |         labels.append(line_labels[:self.entity_batch_length])
475 |       else:
476 |         characters.append(
477 |           line_characters.tolist() + [self.dictionary['BATCH_PAD']] * (self.entity_batch_length - length))
478 |         if category:
479 |           labels.append(line_labels.tolist() + [self.entity_category_labels['P']] * (self.entity_batch_length - length))
480 |         else:
481 |           labels.append(line_labels.tolist() + [self.entity_tags['P']] * (self.entity_batch_length - length))
482 |     extra_count = len(characters) % self.entity_batch_size  # sentences that do not fill a whole batch are dropped
483 |     characters = np.array(characters[:len(characters) - extra_count], np.int32).reshape(
484 |       [-1, self.entity_batch_size, self.entity_batch_length])
485 |     labels = np.array(labels[:len(labels) - extra_count], np.int32).reshape([-1, self.entity_batch_size, self.entity_batch_length])
486 |     return characters, labels
487 |
488 |   def build_relation_batch(self, relations, batch_size):
489 |     relation_batches = []
490 |     sentence_batch = []
491 |     primary_batch = []
492 |     secondary_batch = []
493 |     label_batch = []
494 |     index = 0
495 |     for relation in relations:
496 |       sentence = relation['sentence'].tolist()
497 |       if len(sentence) > self.relation_batch_length:
498 |         sentence = sentence[:self.relation_batch_length]
499 |       else:
500 |         sentence.extend([self.dictionary['BATCH_PAD']] * (self.relation_batch_length - len(sentence)))
501 |       primary = relation['primary'].tolist()
502 |       if len(primary) > self.relation_batch_length:
503 |         primary = primary[:self.relation_batch_length]
504 |       else:
505 |         primary.extend(range(primary[-1] + 1, primary[-1] + 1 + self.relation_batch_length - len(primary)))
506 |       secondary = relation['secondary'].tolist()
507 |       if len(secondary) > self.relation_batch_length:
508 |         secondary = secondary[:self.relation_batch_length]
509 |       else:
510 |         secondary.extend(range(secondary[-1] + 1, secondary[-1] + 1 + self.relation_batch_length - len(secondary)))
511 |       sentence_batch.append(sentence)
512 |       primary_batch.append(primary)
513 |       secondary_batch.append(secondary)
514 |       label_batch.append(relation['label'])
515 |       index += 1
516 |       if batch_size != 1:
517 |         if index > 0 and index % self.relation_batch_size == 0:
518 |           batch = {'sentence': np.array(sentence_batch, np.int32), 'primary': np.array(primary_batch, np.int32),
519 |                    'secondary': np.array(secondary_batch, np.int32), 'label': np.array(label_batch, np.float32)}
520 |           relation_batches.append(batch)
521 |           sentence_batch.clear()
522 |           primary_batch.clear()
523 |           secondary_batch.clear()
524 |           label_batch.clear()
525 |           index = 0
526 |       else:
527 |         batch = {'sentence': np.array(sentence_batch, np.int32), 'primary': np.array(primary_batch, np.int32),
528 |                  'secondary': np.array(secondary_batch, np.int32), 'label': np.array(label_batch, np.float32)}
529 |         relation_batches.append(batch)
530 |         sentence_batch.clear()
531 |         primary_batch.clear()
532 |         secondary_batch.clear()
533 |         label_batch.clear()
534 |     return relation_batches
535 |
536 |
537 | def prepare_for_crfpp(folder, output_name):  # convert brat .txt/.ann pairs into character-level CRF++ data: char \t BIO label
538 |   content = []
539 |   filenames = set()
540 |   for _, _, names in os.walk(folder):
541 |     for filename in names:
542 |       name, _ = os.path.splitext(filename)
543 |       if name not in filenames:
544 |         filenames.add(name)
545 |   for filename in filenames:
546 |     path = folder + filename
547 |     with open(path + '.txt', encoding='utf-8') as src_file:
548 |       raw_text = src_file.read().replace('\n', '\r\n')
549 |     labels = len(raw_text) * ['O']
550 |     with open(path + '.ann', encoding='utf-8') as ann_file:
551 |       ann_items = ann_file.read().splitlines()
552 |     for item in ann_items:
553 |       sections = item.split('\t')
554 |       if sections[0].startswith('T'):
555 |         pos = sections[1].split(' ')
556 |         start, end = int(pos[1]), int(pos[2])
557 |         labels[start] = 'B'
558 |         if end - start - 1 > 0:
559 |           labels[start + 1:end] = ['I'] * (end - start - 1)
560 |     for ch, l in zip(raw_text, labels):
561 |       if ch == '\r':
562 |         continue
563 |       if ch == '。':
564 |         content.append(ch + '\t' + l + '\n')
565 |       else:
566 |         content.append(ch + '\t' + l)
567 |   with open(output_name, mode='w', encoding='utf-8') as o:
568 |     o.write('\n'.join(content))
569 |
570 |
571 | def evaluate_ner(path):  # score a CRF++ output file with columns: char \t gold label \t predicted label
572 |   with open(path, encoding='utf-8') as f:
573 |     entries = map(lambda l: l.split('\t'), [l for l in f.read().splitlines() if l])
574 |     res = list(zip(*entries))
575 |   label_map = {'O': 0, 'B': 1, 'I': 2}
576 |   correct = list(map(lambda l: label_map[l], res[1]))
577 |   current = list(map(lambda l: label_map[l], res[2]))
578 |   corr, p_count, r_count = estimate_ner(current, correct)
579 |   p = corr / p_count
580 |   r = corr / r_count
581 |   f1 = 2 * p * r / (p + r)
582 |   print('precision:', p)
583 |   print('recall:', r)
584 |   print('f1:', f1)
585 |
586 |
587 | if __name__ == '__main__':
588 |   # PrepareDataNer()
589 |   # train_folder = 'corpus/emr_paper/train/'
590 |   # test_folder = 'corpus/emr_paper/test/'
591 |   # prepare_for_crfpp(test_folder, 'corpus/test.data')
592 |   # prepare_for_crfpp(train_folder, 'corpus/train.data')
593 |   evaluate_ner(r'D:\Learning\master_project\clinicalText\CRF++-0.58\res.data')
594 |   evaluate_ner(r'D:\Learning\master_project\clinicalText\CRF++-0.58\res_slim.data')
595 |
--------------------------------------------------------------------------------
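For reference, evaluate_ner delegates the entity counting to estimate_ner, which is defined elsewhere in the project. Below is a minimal, self-contained sketch of span-level precision/recall/F1 over the same three-column file (char, gold label, predicted label). The helper names bio_spans and span_f1 are illustrative only, and the counting convention (correct spans, predicted spans, gold spans) is an assumption inferred from how corr, p_count and r_count are used above, not the project's actual implementation.

    def bio_spans(labels):
      # collect (start, end) entity spans from a B/I/O label sequence
      spans, start = [], None
      for i, label in enumerate(labels):
        if label == 'B':
          if start is not None:
            spans.append((start, i))
          start = i
        elif label == 'O':
          if start is not None:
            spans.append((start, i))
          start = None
      if start is not None:
        spans.append((start, len(labels)))
      return spans


    def span_f1(path):
      # path points at a CRF++ output file whose non-empty lines look like 'char\tgold\tpredicted'
      with open(path, encoding='utf-8') as f:
        rows = [line.split('\t') for line in f.read().splitlines() if line]
      gold = set(bio_spans([r[1] for r in rows]))
      pred = set(bio_spans([r[2] for r in rows]))
      corr = len(gold & pred)
      p, r = corr / len(pred), corr / len(gold)
      return p, r, 2 * p * r / (p + r)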