├── .gitignore
├── Configure.py
├── README.md
├── __init__.py
├── build_data.py
├── char_representation.png
├── config.cfg
├── data
│   ├── dev.txt
│   ├── test.txt
│   ├── train.txt
│   └── words.txt
├── evaluate.py
├── makefile
├── model.png
├── model
│   ├── Dataset.py
│   ├── Vocab_Util.py
│   ├── __init__.py
│   ├── base_model.py
│   ├── log_util.py
│   └── ner_model.py
└── train.py
/.gitignore:
--------------------------------------------------------------------------------
1 | data/glove*
2 | data/words.txt/
3 | data/chars.txt
4 | data/tags.txt
5 | results/*
6 | *.pyc
--------------------------------------------------------------------------------
/Configure.py:
--------------------------------------------------------------------------------
1 | #coding=utf-8
2 | # Configuration file
3 |
4 | class Configure(object):
5 |     """Holds the configuration options and hyperparameters"""
6 |
7 |     '''
8 |     Read the configuration file
9 |     '''
10 |     def __init__(self, conf_file = 'config.cfg'):
11 |         self.confs = {}
12 |         for line in open(conf_file):
13 |             line = line.strip().split()
14 |             if len(line) < 3:
15 |                 continue
16 |             key, value, type = line
17 |             # bool('False') is truthy, so booleans need explicit handling
18 |             if type == 'bool':
19 |                 self.confs[key] = (value == 'True')
20 |             else:
21 |                 self.confs[key] = eval(type + "('" + value + "')")
22 |
23 |
24 |     def __setitem__(self, key, value):
25 |         self.confs[key] = value
26 |
27 |     def __getitem__(self, key):
28 |         return self.confs[key]
29 |
30 | if __name__ == '__main__':
31 |     configure = Configure('./config.cfg')
32 |     print configure['train_data']
33 |     print configure['test_data']
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # SequenceTagging
2 | A sequence tagging (NER) implementation for natural language processing, including
3 | * GloVe word vector handling
4 | * vocab generation
5 | * Bi-LSTM + CRF model
6 | * early stopping
7 | * learning rate decay
8 |
9 | The model is a Bi-LSTM + CRF, see ![](model.png)
10 | Each word is represented by its word vector concatenated with a character-level Bi-LSTM encoding, as shown below
11 | ![](char_representation.png)
12 |
13 | How to run:
14 | 1. make glove
15 | Download the GloVe vector file and unzip it
16 | 2. make data
17 | Process the train/dev/test data and build the vocab and trimmed-vector files
18 | 3. make train
19 | Build the model and train it, using exponential learning rate decay and early stopping
20 | 4. make evaluate
21 | Evaluate the trained model on the test set
22 |
23 | See config.cfg for the configuration options
24 |
25 |
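26 | ## Data format
27 | The train/dev/test files use one `word TAG` pair per line, separated by a space, with a blank line between sentences (see data/train.txt), for example:
28 | ```
29 | Jean B-PER
30 | Pierre I-PER
31 | lives O
32 | in O
33 | New B-LOC
34 | York I-LOC
35 | . O
36 | ```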
--------------------------------------------------------------------------------
/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xixy/SequenceTagging/2433b4f735d49aa6b2159f5ceeea341bc97a1cfe/__init__.py
--------------------------------------------------------------------------------
/build_data.py:
--------------------------------------------------------------------------------
1 | #coding=utf-8
2 | from model.Dataset import *
3 |
4 | from Configure import *
5 | from model.Vocab_Util import *
6 |
7 | def main():
8 |     '''
9 |     Preprocess the data
10 |     '''
11 |     configure = Configure('./config.cfg')
12 |
13 |     processing_word = get_processing_word(lowercase=True)
14 |
15 |     # build the datasets
16 |     train_dataset = Dataset(configure['train_data'], processing_word = processing_word)
17 |     dev_dataset = Dataset(configure['dev_data'], processing_word = processing_word)
18 |     test_dataset = Dataset(configure['test_data'], processing_word = processing_word)
19 |
20 |     # build the word and tag vocabs
21 |     vocab_util = Vocab_Util()
22 |     vocab_words, vocab_tags = vocab_util.get_vocabs_from_datasets([train_dataset, dev_dataset, test_dataset])
23 |     # words that appear in the GloVe vectors
24 |     vocab_glove = vocab_util.get_vocabs_from_glove(configure['glove_file'])
25 |
26 |     # intersection: words that appear both in the GloVe vectors and in the datasets
27 |     vocab_words = vocab_words & vocab_glove
28 |     # add UNK and NUM
29 |     vocab_words.add(UNK)
30 |     vocab_words.add(NUM)
31 |
32 |     # save the word and tag vocab files
33 |     vocab_util.write_vocab(vocab_words, configure['word_vocab_file'])
34 |     vocab_util.write_vocab(vocab_tags, configure['tag_vocab_file'])
35 |
36 |     # trim the GloVe vectors to the vocab and store them
37 |     vocab = vocab_util.load_vocab(configure['word_vocab_file'])
38 |     vocab_util.export_trimmed_glove_vectors(vocab, configure['glove_file'],
39 |         configure['trimmed_file'], configure['word_embedding_dim'])
40 |
41 |     # build the char vocab and store it
42 |     train_dataset = Dataset(configure['train_data'])
43 |     vocab_chars = vocab_util.get_char_vocab_from_datasets(train_dataset)
44 |     vocab_util.write_vocab(vocab_chars, configure['char_vocab_file'])
45 |
46 |
47 |
48 |
49 |
50 |     # save vocab
51 |
52 |
53 | if __name__ == '__main__':
54 |     main()
55 |
56 |
57 |
58 |
59 |
--------------------------------------------------------------------------------
/char_representation.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xixy/SequenceTagging/2433b4f735d49aa6b2159f5ceeea341bc97a1cfe/char_representation.png
--------------------------------------------------------------------------------
/config.cfg:
--------------------------------------------------------------------------------
1 | train_data ./data/train.txt str
2 | test_data ./data/test.txt str
3 | dev_data ./data/dev.txt str
4 | word_embedding_dim 300 int
5 | char_embedding_dim 100 int
6 | glove_file ./data/glove.6B/glove.6B.300d.txt str
7 | trimmed_file ./data/glove.6B.300d.trimmed.npz str
8 | word_vocab_file ./data/words.txt str
9 | tag_vocab_file ./data/tags.txt str
10 | char_vocab_file ./data/chars.txt str
11 |
12 | dir_model ./results/test/model.weights/ str
13 | path_log ./results/test/log.txt str
14 | dir_output ./results/test str
15 |
16 | training_embeddings False bool
17 | epochs 15 int
18 | dropout 0.5 float
19 | batch_size 2 int
20 | optimizer adam str
21 | learning_rate 0.01 float
22 | learning_rate_decay 0.9 float
23 | num_of_epoch_no_imprv 3 int
24 | clip -1 int
25 | hidden_size_char 100
int 26 | hidden_size_lstm 300 int 27 | use_crf True bool 28 | use_chars True bool 29 | -------------------------------------------------------------------------------- /data/dev.txt: -------------------------------------------------------------------------------- 1 | Jean B-PER 2 | Pierre I-PER 3 | lives O 4 | in O 5 | New B-LOC 6 | York I-LOC 7 | . O 8 | 9 | The O 10 | European B-ORG 11 | Union I-ORG 12 | is O 13 | a O 14 | political O 15 | and O 16 | economic O 17 | union O 18 | 19 | A O 20 | French B-MISC 21 | American I-MISC 22 | actor O 23 | won O 24 | an O 25 | oscar O 26 | 27 | Jean B-PER 28 | Pierre I-PER 29 | lives O 30 | in O 31 | New B-LOC 32 | York I-LOC 33 | . O 34 | 35 | The O 36 | European B-ORG 37 | Union I-ORG 38 | is O 39 | a O 40 | political O 41 | and O 42 | economic O 43 | union O 44 | 45 | A O 46 | French B-MISC 47 | American I-MISC 48 | actor O 49 | won O 50 | an O 51 | oscar O 52 | 53 | Jean B-PER 54 | Pierre I-PER 55 | lives O 56 | in O 57 | New B-LOC 58 | York I-LOC 59 | . O 60 | 61 | The O 62 | European B-ORG 63 | Union I-ORG 64 | is O 65 | a O 66 | political O 67 | and O 68 | economic O 69 | union O 70 | 71 | A O 72 | French B-MISC 73 | American I-MISC 74 | actor O 75 | won O 76 | an O 77 | oscar O 78 | 79 | Jean B-PER 80 | Pierre I-PER 81 | lives O 82 | in O 83 | New B-LOC 84 | York I-LOC 85 | . O 86 | 87 | The O 88 | European B-ORG 89 | Union I-ORG 90 | is O 91 | a O 92 | political O 93 | and O 94 | economic O 95 | union O 96 | 97 | A O 98 | French B-MISC 99 | American I-MISC 100 | actor O 101 | won O 102 | an O 103 | oscar O 104 | 105 | -------------------------------------------------------------------------------- /data/test.txt: -------------------------------------------------------------------------------- 1 | Jean B-PER 2 | Pierre I-PER 3 | lives O 4 | in O 5 | New B-LOC 6 | York I-LOC 7 | . O 8 | 9 | The O 10 | European B-ORG 11 | Union I-ORG 12 | is O 13 | a O 14 | political O 15 | and O 16 | economic O 17 | union O 18 | 19 | A O 20 | French B-MISC 21 | American I-MISC 22 | actor O 23 | won O 24 | an O 25 | oscar O 26 | 27 | Jean B-PER 28 | Pierre I-PER 29 | lives O 30 | in O 31 | New B-LOC 32 | York I-LOC 33 | . O 34 | 35 | The O 36 | European B-ORG 37 | Union I-ORG 38 | is O 39 | a O 40 | political O 41 | and O 42 | economic O 43 | union O 44 | 45 | A O 46 | French B-MISC 47 | American I-MISC 48 | actor O 49 | won O 50 | an O 51 | oscar O 52 | 53 | Jean B-PER 54 | Pierre I-PER 55 | lives O 56 | in O 57 | New B-LOC 58 | York I-LOC 59 | . O 60 | 61 | The O 62 | European B-ORG 63 | Union I-ORG 64 | is O 65 | a O 66 | political O 67 | and O 68 | economic O 69 | union O 70 | 71 | A O 72 | French B-MISC 73 | American I-MISC 74 | actor O 75 | won O 76 | an O 77 | oscar O 78 | 79 | Jean B-PER 80 | Pierre I-PER 81 | lives O 82 | in O 83 | New B-LOC 84 | York I-LOC 85 | . O 86 | 87 | The O 88 | European B-ORG 89 | Union I-ORG 90 | is O 91 | a O 92 | political O 93 | and O 94 | economic O 95 | union O 96 | 97 | A O 98 | French B-MISC 99 | American I-MISC 100 | actor O 101 | won O 102 | an O 103 | oscar O 104 | 105 | -------------------------------------------------------------------------------- /data/train.txt: -------------------------------------------------------------------------------- 1 | Jean B-PER 2 | Pierre I-PER 3 | lives O 4 | in O 5 | New B-LOC 6 | York I-LOC 7 | . 
O 8 | 9 | The O 10 | European B-ORG 11 | Union I-ORG 12 | is O 13 | a O 14 | political O 15 | and O 16 | economic O 17 | union O 18 | 19 | A O 20 | French B-MISC 21 | American I-MISC 22 | actor O 23 | won O 24 | an O 25 | oscar O 26 | 27 | Jean B-PER 28 | Pierre I-PER 29 | lives O 30 | in O 31 | New B-LOC 32 | York I-LOC 33 | . O 34 | 35 | The O 36 | European B-ORG 37 | Union I-ORG 38 | is O 39 | a O 40 | political O 41 | and O 42 | economic O 43 | union O 44 | 45 | A O 46 | French B-MISC 47 | American I-MISC 48 | actor O 49 | won O 50 | an O 51 | oscar O 52 | 53 | Jean B-PER 54 | Pierre I-PER 55 | lives O 56 | in O 57 | New B-LOC 58 | York I-LOC 59 | . O 60 | 61 | The O 62 | European B-ORG 63 | Union I-ORG 64 | is O 65 | a O 66 | political O 67 | and O 68 | economic O 69 | union O 70 | 71 | A O 72 | French B-MISC 73 | American I-MISC 74 | actor O 75 | won O 76 | an O 77 | oscar O 78 | 79 | Jean B-PER 80 | Pierre I-PER 81 | lives O 82 | in O 83 | New B-LOC 84 | York I-LOC 85 | . O 86 | 87 | The O 88 | European B-ORG 89 | Union I-ORG 90 | is O 91 | a O 92 | political O 93 | and O 94 | economic O 95 | union O 96 | 97 | A O 98 | French B-MISC 99 | American I-MISC 100 | actor O 101 | won O 102 | an O 103 | oscar O 104 | 105 | -------------------------------------------------------------------------------- /data/words.txt: -------------------------------------------------------------------------------- 1 | and 2 | is 3 | $UNK$ 4 | an 5 | economic 6 | in 7 | $NUM$ 8 | union 9 | political 10 | actor 11 | . 12 | won 13 | new 14 | european 15 | oscar 16 | french 17 | pierre 18 | lives 19 | york 20 | the 21 | a 22 | american 23 | jean -------------------------------------------------------------------------------- /evaluate.py: -------------------------------------------------------------------------------- 1 | #coding=utf-8 2 | 3 | from Configure import * 4 | from model.Dataset import * 5 | from model.Vocab_Util import * 6 | from model.log_util import * 7 | from model.ner_model import * 8 | 9 | def main(): 10 | ''' 11 | 构造测试集,并用restore之前的模型来进行evaluate 12 | ''' 13 | # 读取配置文件 14 | config = Configure() 15 | # 设置logger 16 | logger = get_logger(config['path_log']) 17 | # 读取词典 18 | vocab_util = Vocab_Util() 19 | # dict[word] = idx 20 | vocab_words = vocab_util.load_vocab(config['word_vocab_file']) 21 | # dict[char] = idx 22 | vocab_chars = vocab_util.load_vocab(config['char_vocab_file']) 23 | # dict[tag] = idx 24 | vocab_tags = vocab_util.load_vocab(config['tag_vocab_file']) 25 | # 将词典封装给模型 26 | vocabs = [vocab_words, vocab_chars, vocab_tags] 27 | 28 | embeddings = vocab_util.get_trimmed_glove_vectors(config['trimmed_file']) 29 | 30 | # 对数据进行处理 31 | processing_word = get_processing_word(vocab_words = vocab_words, vocab_chars = vocab_chars, 32 | lowercase = True, chars = config['use_chars'], allow_unk = True) 33 | processing_tag = get_processing_word(vocab_words = vocab_tags, lowercase = False, allow_unk = False) 34 | 35 | # 得到训练数据 36 | test_dataset = Dataset(filename = config['test_data'], 37 | max_iter = None, processing_word = processing_word, processing_tag = processing_tag) 38 | 39 | # 构造模型进行训练 40 | model = ner_model(config,logger,vocabs,embeddings) 41 | model.build() 42 | model.restore_session() 43 | model.evaluate(test_dataset) 44 | 45 | if __name__ == '__main__': 46 | main() -------------------------------------------------------------------------------- /makefile: -------------------------------------------------------------------------------- 1 | glove: 2 | wget -P ./data/ 
"http://nlp.stanford.edu/data/glove.6B.zip" 3 | unzip ./data/glove.6B.zip -d data/glove.6B/ 4 | rm ./data/glove.6B.zip 5 | data: 6 | python build_data.py 7 | train: 8 | python train.py 9 | evaluate: 10 | python evaluate.py -------------------------------------------------------------------------------- /model.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xixy/SequenceTagging/2433b4f735d49aa6b2159f5ceeea341bc97a1cfe/model.png -------------------------------------------------------------------------------- /model/Dataset.py: -------------------------------------------------------------------------------- 1 | #coding=utf-8 2 | 3 | class Dataset(object): 4 | """用来处理数据的类""" 5 | def __init__(self, filename, max_iter = None, processing_word = None, processing_tag = None): 6 | ''' 7 | Args: 8 | filename: 该Dataset对应的数据文件路径 ./data/train.txt 9 | max_iter: max number of sentences to yield from this dataset 10 | processing_word: 对词进行处理的函数 11 | processing_tag: 对tag进行处理的函数 12 | ''' 13 | self.filename = filename 14 | self.max_iter = max_iter 15 | self.length = None # 数据集中的句子数量 16 | self.processing_word = processing_word 17 | self.processing_tag = processing_tag 18 | 19 | def __iter__(self): 20 | ''' 21 | 支持iteration的遍历操作 22 | ''' 23 | count = 0 24 | with open(self.filename) as f: 25 | words, tags = [], [] 26 | for line in f: 27 | line = line.strip() 28 | if (len(line) == 0 or line.startswith("-DOCSTART-")): 29 | # 如果是新的一行,表示句子结束 30 | if len(words) != 0: 31 | count += 1 32 | # 如果超过了 33 | if self.max_iter is not None and count > self.max_iter: 34 | break 35 | yield words, tags 36 | words, tags = [], [] 37 | 38 | else: 39 | elems = line.split(' ') 40 | word, tag = elems[0], elems[1] 41 | # 进行预处理 42 | if self.processing_word is not None: 43 | word = self.processing_word(word) 44 | if self.processing_tag is not None: 45 | tag = self.processing_tag(tag) 46 | 47 | words += [word] 48 | tags += [tag] 49 | 50 | def __len__(self): 51 | ''' 52 | 支持len操作来查看句子数量 53 | ''' 54 | # 如果没初始化,然后再进行统计 55 | if self.length is None: 56 | self.length = 0 57 | for _ in self: 58 | self.length += 1 59 | return self.length 60 | 61 | def get_minibatch(self, batch_size): 62 | ''' 63 | 从数据集中获取minibatch 64 | ''' 65 | x_batch, y_batch = [], [] 66 | for (x, y) in self: 67 | if len(x_batch) == batch_size: 68 | yield x_batch, y_batch 69 | x_batch, y_batch = [], [] 70 | 71 | # 如果是([char_ids], word_id)的形式 72 | if type(x[0]) == tuple: 73 | x = zip(*x) 74 | 75 | x_batch += [x] 76 | y_batch += [y] 77 | if len(x_batch) != 0: 78 | yield x_batch, y_batch 79 | 80 | 81 | 82 | def get_processing_word(vocab_words = None, vocab_chars = None, 83 | lowercase = False, chars = False, allow_unk = True): 84 | """Return lambda function that transform a word (string) into list, 85 | or tuple of (list, id) of int corresponding to the ids of the word and 86 | its corresponding characters. 
87 | 88 | Args: 89 | vocab: dict[word] = idx 90 | vocab_chars: dict[char] = idx 91 | lowercase: 是否将单词进行小写化 92 | chars: 是否要返回单词id列表 93 | allow_unk: 是否将不在词典中的单词作为UNK 94 | 95 | Returns: 96 | f("cat") = ([12, 4, 32], 12345) 97 | = (list of char ids, word id) 98 | 99 | """ 100 | def f(word): 101 | 102 | # 首先转化为chars list 103 | if vocab_chars is not None and chars: 104 | char_ids = [] 105 | for char in word: 106 | # 如果字母不在的话,就忽略,基本上很少这种情况 107 | if char in vocab_chars: 108 | char_ids.append(vocab_chars[char]) 109 | # 对单词进行处理 110 | if lowercase: 111 | word = word.lower() 112 | if word.isdigit(): 113 | word = NUM 114 | 115 | # 得到单词的id 116 | if vocab_words is not None: 117 | if word in vocab_words: 118 | word = vocab_words[word] 119 | else: 120 | if allow_unk: 121 | word = vocab_words[UNK] 122 | else: 123 | raise Exception('出现了不存在词典中的词,请检查是否正确' + word) 124 | # 返回char ids, word 125 | if vocab_chars is not None and chars == True: 126 | return char_ids, word 127 | else: 128 | return word 129 | return f 130 | 131 | if __name__ == '__main__': 132 | processing_word = get_processing_word(lowercase=True) 133 | dataset = Dataset('../data/test.txt', processing_word = processing_word) 134 | 135 | for data in dataset: 136 | print data 137 | print len(dataset) 138 | for x_batch , y_batch in dataset.get_minibatch(5): 139 | print x_batch 140 | print y_batch 141 | # print data 142 | 143 | -------------------------------------------------------------------------------- /model/Vocab_Util.py: -------------------------------------------------------------------------------- 1 | #coding=utf-8 2 | 3 | import numpy as np 4 | import os 5 | from Dataset import * 6 | 7 | 8 | # shared global variables to be imported from model also 9 | UNK = "$UNK$" 10 | NUM = "$NUM$" 11 | NONE = "O" 12 | 13 | class Vocab_Util(object): 14 | """用来做词典操作""" 15 | def get_vocabs_from_datasets(self, datasets): 16 | ''' 17 | 从Dataset中获取词典 18 | Args: 19 | datasets:[dataset]多个Dataset对象的集合 20 | Returns: 21 | 该dataset集合中的所有单词 22 | ''' 23 | vocab_words = set() 24 | vocab_tags = set() 25 | for dataset in datasets: 26 | for words, tags in dataset: 27 | # 进行更新添加 28 | vocab_words.update(words) 29 | vocab_tags.update(tags) 30 | return vocab_words, vocab_tags 31 | 32 | def get_vocabs_from_glove(self, filename): 33 | ''' 34 | 从glove向量中获取vocab 35 | Args: 36 | filename: glove vector file路径 37 | Return: 38 | set() of words 39 | ''' 40 | vocab = set() 41 | with open(filename) as f: 42 | for line in f: 43 | word = line.strip().split(' ')[0] 44 | vocab.add(word) 45 | return vocab 46 | 47 | def write_vocab(self, vocab, filename): 48 | ''' 49 | 将vocab写入到文件中,格式为一行一个词 50 | Args: 51 | vocab:iterable that yields word 52 | filename: 保存文件路径 53 | ''' 54 | with open(filename, 'w') as f: 55 | for i, word in enumerate(vocab): 56 | if i != len(vocab) - 1: 57 | f.write(word + '\n') 58 | else: 59 | f.write(word) 60 | 61 | def load_vocab(self, filename): 62 | ''' 63 | 从vocab文件中加载vocab 64 | Args: 65 | filename: 保存文件路径 66 | Return: 67 | vocab: dict[word] = index 68 | ''' 69 | vocab = dict() 70 | idx = 0 71 | with open(filename) as f: 72 | for line in f: 73 | word = line.strip() 74 | vocab[word] = idx 75 | idx += 1 76 | return vocab 77 | 78 | def export_trimmed_glove_vectors(self, vocab, glove_filename, trimmed_filename, dim): 79 | ''' 80 | 将vocab中的词从glove vector中取出来,并存储到trimed_filename文件中 81 | Args: 82 | vocab: dictionary vocab[word] = index 83 | glove_filename: glove vector文件路径 84 | trimmed_filename: 存储np matrix的路径 85 | dim: embedding的维度 86 | ''' 87 | embeddings = 
np.zeros([len(vocab), dim]) 88 | with open(glove_filename) as f: 89 | for line in f: 90 | line = line.strip().split(' ') 91 | word = line[0] 92 | embedding = [float(x) for x in line[1:]] 93 | if word in vocab: 94 | word_idx = vocab[word] 95 | embeddings[word_idx] = np.asarray(embedding) 96 | 97 | np.savez_compressed(trimmed_filename, embeddings = embeddings) 98 | 99 | def get_trimmed_glove_vectors(self, trimmed_filename): 100 | ''' 101 | 加载trimmed glove vectors 102 | Args: 103 | trimmed_filename: 存储np matrix的路径 104 | Returns: 105 | matrix of embeddings (instance of np array) 106 | ''' 107 | with np.load(trimmed_filename) as data: 108 | return data['embeddings'] 109 | 110 | 111 | def get_char_vocab_from_datasets(self, datasets): 112 | ''' 113 | 从Dataset中获取字母vocab 114 | Args: 115 | datasets:[dataset]多个Dataset对象的集合 116 | Returns: 117 | 该dataset集合中的所有字母 118 | ''' 119 | vocab = set() 120 | for words, tags in datasets: 121 | for word in words: 122 | vocab.update(word) 123 | return vocab 124 | 125 | 126 | if __name__ == '__main__': 127 | processing_word = get_processing_word(lowercase=True) 128 | dataset = Dataset('../data/test.txt', processing_word = processing_word) 129 | vocab = Vocab_Util() 130 | print vocab.get_vocabs([dataset]) 131 | 132 | 133 | -------------------------------------------------------------------------------- /model/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xixy/SequenceTagging/2433b4f735d49aa6b2159f5ceeea341bc97a1cfe/model/__init__.py -------------------------------------------------------------------------------- /model/base_model.py: -------------------------------------------------------------------------------- 1 | #coding=utf-8 2 | import os 3 | import tensorflow as tf 4 | 5 | class base_model(object): 6 | """docstring for BaseModel""" 7 | def __init__(self, config, logger): 8 | self.config = config 9 | self.sess = None 10 | self.saver = None 11 | self.logger = logger 12 | 13 | def reinitialize_weights(self, scope_name): 14 | ''' 15 | 对某一scope的变量进行重新初始化 16 | ''' 17 | variables = tf.contrib.framework.get_variables(scope_name) 18 | init = tf.variables_initializer(variables) 19 | self.sess.run(init) 20 | 21 | def add_train_op(self, lr_method, lr, loss, clip = -1): 22 | ''' 23 | 定义self.train_op来进行训练操作 24 | Args: 25 | lr_method: adam/adagrad/sgd/rmsprop 26 | lr: learning rate 27 | loss: 损失函数 28 | clip: 梯度的clipping值,如果<0,那么就不clipping 29 | ''' 30 | _lr_m = lr_method.lower() 31 | with tf.variable_scope("train_step"): 32 | if _lr_m == 'adam': 33 | optimizer = tf.train.AdamOptimizer(lr) 34 | elif _lr_m == 'adagrad': 35 | optimizer = tf.train.AdagradOptimizer(lr) 36 | elif _lr_m == 'sgd': 37 | optimizer = tf.train.GradientDescentOptimizer(lr) 38 | elif _lr_m == 'rmsprop': 39 | optimizer = tf.train.RMSPropOptimizer(lr) 40 | else: 41 | raise NotImplementedError("Unknown method {}".format(_lr_m)) 42 | 43 | # 进行clip 44 | if clip > 0: 45 | grads, vs = zip(*optimizer.compute_gradients(loss)) 46 | grads, gnorm = tf.clip_by_global_norm(grads, clip) 47 | self.train_op = optimizer.apply_gradients(zip(grads, vs)) 48 | else: 49 | self.train_op = optimizer.minimize(loss) 50 | 51 | def initialize_session(self): 52 | ''' 53 | 定义self.sess并且初始化变量和self.saver 54 | ''' 55 | self.logger.info("Initializing tf session") 56 | self.sess = tf.Session() 57 | self.sess.run(tf.global_variables_initializer()) 58 | self.saver = tf.train.Saver() 59 | 60 | def restore_session(self, dir_model = None): 61 | ''' 62 | reload 
weights into self.session
63 |         Args:
64 |             dir_model: path to the saved model
65 |         '''
66 |         if dir_model is None:
67 |             # read the model path from the config
68 |             dir_model = self.config['dir_model']
69 |
70 |         self.logger.info("Reloading the latest trained model...")
71 |         self.saver.restore(self.sess, dir_model)
72 |
73 |     def save_session(self):
74 |         '''
75 |         Save the session
76 |         '''
77 |         if not os.path.exists(self.config['dir_model']):
78 |             os.makedirs(self.config['dir_model'])
79 |         self.saver.save(self.sess, self.config['dir_model'])
80 |
81 |     def close_session(self):
82 |         '''
83 |         Close the session
84 |         '''
85 |         self.sess.close()
86 |
87 |     def add_summary(self):
88 |         '''
89 |         Define the tensorboard summary ops; output goes to dir_output
90 |
91 |         '''
92 |         self.merged = tf.summary.merge_all()
93 |         self.file_writer = tf.summary.FileWriter(self.config["dir_output"],
94 |             self.sess.graph)
95 |
96 |     def train(self, train, dev):
97 |         '''
98 |         Train the model, using early stopping and exponential learning rate decay
99 |         Args:
100 |             train: dataset that yields tuples of (sentence, tags)
101 |             dev: dataset
102 |         '''
103 |         best_score = 0
104 |         num_of_epoch_no_imprv = 0 # for early stopping: counts epochs without improvement
105 |         self.add_summary() # tensorboard
106 |         for epoch in range(self.config['epochs']):
107 |             self.logger.info('Epoch {:} out of {:}'.format(epoch + 1, self.config['epochs']))
108 |
109 |             # run one epoch of training and return the f1 score on the dev set
110 |             score = self.run_epoch(train, dev, epoch)
111 |             # apply learning rate decay
112 |             self.config['learning_rate'] *= self.config['learning_rate_decay']
113 |
114 |             # early stopping: keep the best parameters seen so far
115 |             # if the score improved
116 |             if score >= best_score:
117 |                 # reset the counter
118 |                 num_of_epoch_no_imprv = 0
119 |                 # save the current parameters
120 |                 self.save_session()
121 |                 # update the best score
122 |                 best_score = score
123 |                 self.logger.info("- new best score! ")
124 |             # if the score did not improve
125 |             else:
126 |                 num_of_epoch_no_imprv += 1
127 |                 # stop if there has been no improvement for too many epochs
128 |                 if num_of_epoch_no_imprv >= self.config['num_of_epoch_no_imprv']:
129 |                     self.logger.info("- early stopping {} epochs without "\
130 |                         "improvement".format(num_of_epoch_no_imprv))
131 |                     break
132 |
133 |
134 |     def evaluate(self, test):
135 |         '''
136 |         Evaluate the model on the test set
137 |         Args:
138 |             test: dataset built from test.txt
139 |         '''
140 |         self.logger.info("Testing model over test set")
141 |         # run the evaluation
142 |         metrics = self.run_evaluate(test)
143 |         msg = " - ".join(["{} {:04.2f}".format(k, v) for k, v in metrics.items()])
144 |         self.logger.info(msg)
145 |
146 |
147 |
148 |
149 |
150 |
151 |
152 |
--------------------------------------------------------------------------------
/model/log_util.py:
--------------------------------------------------------------------------------
1 | #coding=utf-8
2 | import time
3 | import sys
4 | import logging
5 |
6 | def get_logger(filename):
7 |     '''
8 |     Get a logger instance that writes to filename
9 |     Args:
10 |         filename: path to log.txt
11 |     Returns:
12 |         logger
13 |     '''
14 |     logger = logging.getLogger('logger')
15 |     logger.setLevel(logging.DEBUG)
16 |     logging.basicConfig(format='%(message)s', level = logging.DEBUG)
17 |     handler = logging.FileHandler(filename)
18 |     handler.setLevel(logging.DEBUG)
19 |     handler.setFormatter(logging.Formatter(
20 |         '%(asctime)s:%(levelname)s: %(message)s'))
21 |     logging.getLogger().addHandler(handler)
22 |     return logger
--------------------------------------------------------------------------------
/model/ner_model.py:
--------------------------------------------------------------------------------
1 | #coding=utf-8
2 | from base_model import *
3 | import numpy as np
4 | class ner_model(base_model):
5 |     """NER model"""
6 |     def __init__(self, config, logger, vocabs,
embeddings = None): 7 | super(ner_model, self).__init__(config, logger) 8 | self.embeddings = embeddings 9 | self.vocabs = vocabs 10 | self.vocab_words = vocabs[0] 11 | self.vocab_chars = vocabs[1] 12 | self.vocab_tags = vocabs[2] 13 | 14 | def build(self): 15 | ''' 16 | 构建可计算图 17 | ''' 18 | # 添加placeholders 19 | self.add_placeholders() 20 | # 添加向量化操作,得到每个词的向量表示 21 | self.add_word_embeddings_op() 22 | # 计算logits 23 | self.add_logits_op() 24 | # 计算概率 25 | self.add_pred_op() 26 | # 计算损失 27 | self.add_loss_op() 28 | 29 | # 定义训练操作 30 | self.add_train_op(self.config['optimizer'], self.lr, self.loss, 31 | self.config['clip']) 32 | # 创建session和logger 33 | self.initialize_session() 34 | 35 | 36 | 37 | 38 | def add_placeholders(self): 39 | ''' 40 | 定义placeholder 41 | ''' 42 | # 表示batch中每个句子的词的id表示 43 | # shape = (batch size, max length of sentence in batch) 44 | self.word_ids = tf.placeholder(tf.int32, shape = [None, None], name = "word_ids") 45 | 46 | # 表示batch中每个句子的长度 47 | # shape = (batch size,) 48 | self.sequence_lengths = tf.placeholder(tf.int32, shape = [None], name = "sequence_lengths") 49 | 50 | # 表示batch中每个句子的每个词的字母id表示 51 | # shape = (batch size, max length of sentence in batch, max length of word) 52 | self.char_ids = tf.placeholder(tf.int32, shape = [None, None, None], 53 | name = "char_ids") 54 | 55 | # 表示batch中每个句子的每个词的长度 56 | # shape = (batch size, max length of sentence in batch) 57 | self.word_lengths = tf.placeholder(tf.int32, shape = [None, None], name = "word_lengths") 58 | 59 | # 表示batch中每个句子的每个word的label 60 | # shape = (batch size, max length of sentence in batch) 61 | self.labels = tf.placeholder(tf.int32, shape = [None, None], name = "labels") 62 | 63 | # dropout 64 | self.dropout = tf.placeholder(tf.float32, shape = [], name = "dropout") 65 | # 学习率 66 | self.lr = tf.placeholder(tf.float32, shape = [], name = "lr") 67 | 68 | 69 | def add_word_embeddings_op(self): 70 | ''' 71 | 添加embedding操作,包括词向量和字向量 72 | 如果self.embeddings不是None,那么词向量就采用pre-trained vectors,否则自行训练 73 | 字向量是自行训练的 74 | ''' 75 | with tf.variable_scope("words"): 76 | # 如果词向量是None 77 | if self.embeddings is None: 78 | self.logger.info("WARNING: randomly initializing word vectors") 79 | _word_embeddings = tf.get_variable( 80 | name = '_word_embeddings', 81 | dtype = tf.float32, 82 | shape = [len(self.vocab_words), self.config['word_embedding_dim']] 83 | ) 84 | else: 85 | # 加载已有的词向量 86 | _word_embeddings = tf.Variable( 87 | self.embeddings, 88 | name = '_word_embeddings', 89 | dtype = tf.float32, 90 | trainable = self.config['training_embeddings'] 91 | ) 92 | # lookup来获取word_ids对应的embeddings 93 | # shape = (batch size, max_length_sentence, dim) 94 | word_embeddings = tf.nn.embedding_lookup( 95 | _word_embeddings, 96 | self.word_ids, 97 | name = 'word_embeddings' 98 | ) 99 | with tf.variable_scope('chars'): 100 | if self.config['use_chars']: 101 | _char_embeddings = tf.get_variable( 102 | name = '_char_embeddings', 103 | dtype = tf.float32, 104 | shape = [len(self.vocab_chars), self.config['char_embedding_dim']] 105 | ) 106 | # shape = (batch, max_length_sentence, max_length_word, dim of char embeddings) 107 | char_embeddings = tf.nn.embedding_lookup( 108 | _char_embeddings, 109 | self.char_ids, 110 | name = 'char_embeddings' 111 | ) 112 | 113 | # 2. 
put the time dimension on axis=1 for dynamic_rnn
114 |                 s = tf.shape(char_embeddings)
115 |                 # shape = (batch * max_length_sentence, max_length_word, dim of char embeddings)
116 |                 char_embeddings = tf.reshape(char_embeddings,
117 |                     shape = [s[0]*s[1], s[-2], self.config['char_embedding_dim']]
118 |                 )
119 |
120 |                 # length of every word of every sentence in the batch
121 |                 # shape = (batch size * max_length_sentence,)
122 |                 word_lengths = tf.reshape(self.word_lengths, shape=[s[0]*s[1]])
123 |
124 |                 # 3 bi-lstm on chars
125 |                 cell_fw = tf.contrib.rnn.LSTMCell(self.config['hidden_size_char'],
126 |                     state_is_tuple = True
127 |                 )
128 |                 cell_bw = tf.contrib.rnn.LSTMCell(self.config['hidden_size_char'],
129 |                     state_is_tuple = True
130 |                 )
131 |                 _output = tf.nn.bidirectional_dynamic_rnn(
132 |                     cell_fw, # forward RNN
133 |                     cell_bw, # backward RNN
134 |                     char_embeddings, # input sequence
135 |                     sequence_length = word_lengths, # sequence lengths
136 |                     dtype = tf.float32
137 |                 )
138 |
139 |                 # take the final state h, not c
140 |                 # shape = (batch * max_length_sentence, hidden_size_char)
141 |                 _, ((_, output_fw), (_, output_bw)) = _output
142 |                 # concatenate the outputs of both directions
143 |
144 |                 # shape = (batch * max_length_sentence, 2 * hidden_size_char)
145 |                 output = tf.concat([output_fw, output_bw], axis=-1)
146 |
147 |
148 |                 # reshape to shape = (batch size, max_length_sentence, 2 * char hidden size)
149 |                 output = tf.reshape(
150 |                     output,
151 |                     shape = [s[0], s[1], 2 * self.config['hidden_size_char']]
152 |                 )
153 |
154 |                 # concatenate onto the word embedding
155 |                 # shape = (batch size, max_length_sentence, 2 * char hidden size + word vector dim)
156 |                 word_embeddings = tf.concat([word_embeddings, output], axis = -1)
157 |
158 |         self.word_embeddings = tf.nn.dropout(word_embeddings, self.dropout)
159 |
160 |
161 |     def add_logits_op(self):
162 |         '''
163 |         Define self.logits: each word in a sentence gets a score vector whose dimension is the number of tags
164 |         '''
165 |         # first run a bi-LSTM over the sentence
166 |         with tf.variable_scope('bi-lstm'):
167 |             cell_fw = tf.contrib.rnn.LSTMCell(
168 |                 self.config['hidden_size_lstm']
169 |             )
170 |             cell_bw = tf.contrib.rnn.LSTMCell(
171 |                 self.config['hidden_size_lstm']
172 |             )
173 |             output = tf.nn.bidirectional_dynamic_rnn(
174 |                 cell_fw,
175 |                 cell_bw,
176 |                 self.word_embeddings,
177 |                 sequence_length = self.sequence_lengths,
178 |                 dtype = tf.float32
179 |             )
180 |             # take the outputs
181 |             # shape = (batch size, max_length_sentence, hidden_size_lstm)
182 |             (output_fw, output_bw), _ = output
183 |             # shape = (batch size, max_length_sentence, 2 * hidden_size_lstm)
184 |             output = tf.concat([output_fw, output_bw], axis = -1)
185 |             # apply dropout
186 |             # shape = (batch size, max_length_sentence, 2 * hidden_size_lstm)
187 |             output = tf.nn.dropout(output, self.dropout)
188 |
189 |         # then compute the tag scores with a fully-connected layer
190 |         with tf.variable_scope('proj'):
191 |             W = tf.get_variable(
192 |                 name = 'w',
193 |                 dtype = tf.float32,
194 |                 shape = [2 * self.config['hidden_size_lstm'], len(self.vocab_tags)]
195 |             )
196 |             b = tf.get_variable(
197 |                 name = 'b',
198 |                 dtype = tf.float32,
199 |                 shape = [len(self.vocab_tags)],
200 |                 initializer = tf.zeros_initializer()
201 |             )
202 |             # max_length_sentence
203 |             nsteps = tf.shape(output)[1]
204 |
205 |             # shape = (batch size * max_length_sentence, 2 * hidden_size_lstm)
206 |             output = tf.reshape(output, [-1, 2*self.config['hidden_size_lstm']])
207 |
208 |             # shape = (batch size * max_length_sentence, vocab_tags_nums)
209 |             pred = tf.matmul(output, W) + b
210 |
211 |             # shape = (batch size, max_length_sentence, vocab_tags_nums)
212 |             self.logits = tf.reshape(pred, [-1, nsteps, len(self.vocab_tags)])
213 |
214 |     def add_pred_op(self):
215 |         '''
216 |         Compute the predictions; when CRF is used, decoding is done with Viterbi in predict_batch instead
217 | ''' 218 | # 取出概率最大的维度的idx 219 | # shape = (batch size, max_length_sentence) 220 | if not self.config['use_crf']: 221 | self.labels_pred = tf.cast(tf.argmax(self.logits, axis=-1), tf.int32) 222 | 223 | def add_loss_op(self): 224 | ''' 225 | 计算损失 226 | ''' 227 | # 如果使用crf部分 228 | if self.config['use_crf']: 229 | log_likelihood, trans_params = tf.contrib.crf.crf_log_likelihood( 230 | self.logits, 231 | self.labels, 232 | self.sequence_lengths 233 | ) 234 | self.trans_params = trans_params 235 | self.loss = tf.reduce_mean(-log_likelihood) 236 | # 如果不计算crf部分 237 | else: 238 | # shape = (batch size, max_length_sentence) 239 | losses = tf.nn.sparse_softmax_cross_entrophy_with_logits( 240 | logits = self.logits, 241 | labels = self.labels 242 | ) 243 | mask = tf.sequence_mask(self.sequence_lengths) 244 | losses = tf.boolean_mask(losses, mask) 245 | self.loss = tf.reduce_mean(losses) 246 | # for tensorboard 247 | tf.summary.scalar('loss', self.loss) 248 | 249 | 250 | def run_epoch(self, train, dev, epoch): 251 | ''' 252 | 运行一个epoch,包括在训练集上训练、dev集上测试,一个epoch中对train集合的所有数据进行训练 253 | ''' 254 | batch_size = self.config['batch_size'] 255 | nbatches = (len(train) + batch_size - 1) // batch_size 256 | 257 | 258 | for i, (words, labels) in enumerate(train.get_minibatch(batch_size)): 259 | 260 | # 构造feed_dict,主要是包括: 261 | # 1. word_ids, word_length 262 | # 2. char_ids, char_length 263 | # 3. learning rate 264 | # 4. dropout keep prob 265 | fd, _ = self.get_feed_dict(words, labels, 266 | self.config['learning_rate'], self.config['dropout']) 267 | 268 | # 执行计算 269 | _, train_loss, summary = self.sess.run( 270 | [self.train_op, self.loss, self.merged], 271 | feed_dict = fd 272 | ) 273 | 274 | # tensorboard 275 | if i % 10 == 0: 276 | self.file_writer.add_summary(summary, epoch * nbatches + i) 277 | 278 | # 在dev集上面进行测试 279 | metrics = self.run_evaluate(dev) 280 | 281 | msg = " - ".join(["{} {:04.2f}".format(k, v) 282 | for k, v in metrics.items()]) 283 | self.logger.info(msg) 284 | 285 | # 返回f1 286 | return metrics["f1"] 287 | 288 | def pad_sequences(self, sequences, pad_token, nlevels = 1): 289 | ''' 290 | 对sequence进行填充 291 | Args: 292 | sequences: a generator of list or tuple 293 | pad_token: the token to pad with 294 | nlevels: padding的深度,如果是1,则表示对词进行填充,如果是2表示对字进行填充 295 | Return: 296 | a list of list where each sublist has same length 297 | ''' 298 | # print '--------pad-sequences--------' 299 | if nlevels == 1: 300 | # 找到sequences中的句子最大长度 301 | max_length_sentence = max(map(lambda x: len(x), sequences)) 302 | # 然后直接进行padding 303 | sequences_padded, sequences_length = self._pad_sequences(sequences, 304 | pad_token, max_length_sentence) 305 | 306 | if nlevels == 2: 307 | # 找到sequence中所有句子的所有单词中字母数最大的单词 308 | max_length_word = max([max(map(lambda x: len(x), seq)) for seq in sequences]) 309 | # print max_length_word 310 | sequences_padded, sequences_length = [], [] 311 | for seq in sequences: 312 | # print seq 313 | # 将每个句子的每个词都进行填充 314 | sp, sl = self._pad_sequences(seq, pad_token, max_length_word) 315 | # print sp, sl 316 | # 每个句子的字母的表示 317 | sequences_padded += [sp] 318 | # 每个句子的字母的长度 319 | sequences_length += [sl] 320 | # 然后对句子进行填充 321 | # batch中最大长度的句子 322 | max_length_sentence = max(map(lambda x : len(x), sequences)) 323 | # 填充的时候用[0,0,0,0,0]用字母向量进行填充 324 | sequences_padded, _ = self._pad_sequences(sequences_padded, 325 | [pad_token] * max_length_word, max_length_sentence) 326 | # 得到句子的每个单词的字母的长度 (batch, max_length_sentence, letter_length) 327 | sequences_length, _ = 
self._pad_sequences(sequences_length, 0, max_length_sentence) 328 | 329 | 330 | 331 | return sequences_padded, sequences_length 332 | 333 | 334 | 335 | def _pad_sequences(self, sequences, pad_token, max_length): 336 | ''' 337 | 对sequences进行填充 338 | Args: 339 | pad_token: the token to pad with 340 | ''' 341 | sequences_padded, sequences_lengths = [], [] 342 | for sequence in sequences: 343 | sequence = list(sequence) 344 | # 获取句子长度 345 | sequences_lengths += [min(len(sequence), max_length)] 346 | # 进行填充 347 | sequence = sequence[:max_length] + [pad_token] * max(max_length - len(sequence), 0) 348 | sequences_padded += [sequence] 349 | return sequences_padded, sequences_lengths 350 | 351 | 352 | 353 | def get_feed_dict(self, words, labels = None, lr = None, dropout = None): 354 | ''' 355 | Args: 356 | words: list of sentences. A sentences is a list of ids of 357 | a list of words. A word is a list of ids 358 | labels: list of ids 359 | lr: learning rate 360 | dropout: keep prob 361 | ''' 362 | if self.config['use_chars']: 363 | # char_ids (sent1, sent2, ..., sentn) 364 | # senti ([char1, char2, char3], [char1, char2, char3], ..., [cahr1, char2, char3]) 365 | # word_ids ((word1, word2, word3),(word1, word2, word3), ... ) 366 | char_ids, word_ids = zip(*words) 367 | word_ids, sequence_lengths = self.pad_sequences(word_ids, 0) 368 | # print word_ids 369 | # print sequence_lengths 370 | # print char_ids 371 | char_ids, word_lengths = self.pad_sequences(char_ids, pad_token=0, nlevels = 2) 372 | else: 373 | word_ids, sequence_lengths = self.pad_sequences(words, 0) 374 | 375 | feed = { 376 | self.word_ids: word_ids, # 句子的词的id表示(batch, max_length_sentence) 377 | self.sequence_lengths: sequence_lengths # 句子的词的长度(batch, ) 378 | } 379 | 380 | if self.config['use_chars']: 381 | feed[self.char_ids] = char_ids # 句子的词的字母的id表示(batch, max_length_sentence, max_length_word) 382 | feed[self.word_lengths] = word_lengths # 句子的字母的长度(batch, max_length_sentence,) 383 | 384 | if labels is not None: 385 | labels, _ = self.pad_sequences(labels, 0) 386 | feed[self.labels] = labels 387 | 388 | if lr is not None: 389 | feed[self.lr] = lr 390 | if dropout is not None: 391 | feed[self.dropout] = dropout 392 | 393 | return feed, sequence_lengths 394 | 395 | def run_evaluate(self, test): 396 | ''' 397 | 在测试集上运行,并且统计结果,包括precision/recall/accuracy/f1 398 | Args: 399 | test: 一个dataset instance 400 | Returns: 401 | metrics: dict metrics['acc'] = 98.4 402 | ''' 403 | accs = [] 404 | correct_preds, total_correct, total_preds = 0., 0., 0. 
405 |         for words, labels in test.get_minibatch(self.config['batch_size']):
406 |             # predict_batch
407 |             # shape = (batch size, max_length_sentence)
408 |             # shape = (batch size,)
409 |             labels_pred, sequence_lengths = self.predict_batch(words)
410 |
411 |             for lab, lab_pred, length in zip(labels, labels_pred, sequence_lengths):
412 |                 # take each sentence, of length `length`
413 |                 # gold labels
414 |                 lab = lab[:length]
415 |                 # predicted labels
416 |                 lab_pred = lab_pred[:length]
417 |
418 |                 # number of correctly predicted tokens
419 |                 accs += [a == b for (a, b) in zip(lab, lab_pred)]
420 |
421 |                 lab_chunks = set(self.get_chunks(lab, self.vocab_tags))
422 |
423 |                 lab_pred_chunks = set(self.get_chunks(lab_pred, self.vocab_tags))
424 |
425 |                 # number of correctly predicted chunks
426 |                 correct_preds += len(lab_chunks & lab_pred_chunks)
427 |                 # number of predicted chunks
428 |                 total_preds += len(lab_pred_chunks)
429 |                 # number of gold chunks
430 |                 total_correct += len(lab_chunks)
431 |
432 |         # precision: how many of the predicted chunks are correct
433 |         p = correct_preds / total_preds if correct_preds > 0 else 0
434 |         # recall: how many of the gold chunks were predicted correctly
435 |         r = correct_preds / total_correct if correct_preds > 0 else 0
436 |         # f1
437 |         f1 = 2 * p * r / (p + r) if correct_preds > 0 else 0
438 |         # accuracy: fraction of correctly predicted tokens
439 |         acc = np.mean(accs)
440 |
441 |         # return the results
442 |         return {
443 |             'acc': 100 * acc,
444 |             'f1': 100 * f1
445 |         }
446 |
447 |
448 |
449 |     def predict_batch(self, words):
450 |         '''
451 |         Run prediction on one batch and return the predicted labels
452 |         Args:
453 |             words: list of sentences
454 |         Returns:
455 |             labels_pred: list of labels for each sentence
456 |         '''
457 |         fd, sequence_lengths = self.get_feed_dict(words, dropout = 1.0)
458 |         if self.config['use_crf']:
459 |             viterbi_sequences = []
460 |             logits, trans_params = self.sess.run(
461 |                 [self.logits, self.trans_params],
462 |                 feed_dict = fd
463 |             )
464 |             for logit, sequence_length in zip(logits, sequence_lengths):
465 |                 logit = logit[:sequence_length]
466 |                 viterbi_seq, viterbi_score = tf.contrib.crf.viterbi_decode(
467 |                     logit, trans_params)
468 |                 viterbi_sequences += [viterbi_seq]
469 |
470 |             return viterbi_sequences, sequence_lengths
471 |
472 |         else:
473 |             # shape = (batch size, max_length_sentence)
474 |             labels_pred = self.sess.run(self.labels_pred, feed_dict = fd)
475 |
476 |             return labels_pred, sequence_lengths
477 |
478 |
479 |     def get_chunks(self, seq, tags):
480 |         '''
481 |         Given the tag sequence of a sentence, extract the entities and their positions
482 |         Args:
483 |             seq: [4,4,0,0,1,2...]
一个句子的label 484 | tags: dict['I-LOC'] = 2 485 | Returns: 486 | list of (chunk_type, chunk_start, chunk_end) 487 | 488 | Examples: 489 | seq = [4, 5, 0, 3] 490 | tags = { 491 | 'B-PER' : 4, 492 | 'I-PER' : 5, 493 | 'B-LOC' : 3 494 | 'O' : 0 495 | } 496 | 497 | Returns: 498 | chunks = [ 499 | ('PER', 0, 2), 500 | ('LOC', 3, 4) 501 | ] 502 | ''' 503 | idx_to_tag = {idx : tag for tag, idx in tags.items()} 504 | chunks = [] 505 | 506 | # 表示当前的chunk的起点和类型 507 | chunk_start, chunk_type = None, None 508 | # print seq 509 | 510 | for i, tag_idx in enumerate(seq): 511 | # 如果不是entity的一部分 512 | if tag_idx == tags['O']: 513 | # 如果chunk_type不是None,那么就是一个entity的结束 514 | if chunk_type != None: 515 | chunk = (chunk_type, chunk_start, i) 516 | chunks.append(chunk) 517 | chunk_start, chunk_type = None, None 518 | # 如果chunk_type是None,那么就不需要处理 519 | else: 520 | pass 521 | # 如果是BI 522 | else: 523 | tag = idx_to_tag[tag_idx] 524 | # 如果是B 525 | if tag[0] == 'B': 526 | # 如果前面有entity,那么这个entity就完成了 527 | if chunk_type != None: 528 | chunk = (chunk_type, chunk_start, i) 529 | chunks.append(chunk) 530 | chunk_start, chunk_type = None, None 531 | 532 | # 记录开始 533 | chunk_start = i 534 | chunk_type = tag[2:] 535 | 536 | # 如果是I 537 | else: 538 | if chunk_type != None: 539 | # 如果chunk_type发生了变化,例如(B-PER, I-PER, B-LOC),那么就需要将(B-PER, I-PER)归类为chunk 540 | if chunk_type != tag[2:]: 541 | chunk = (chunk_type, chunk_start, i) 542 | chunks.append(chunk) 543 | chunk_start, chunk_type = None, None 544 | 545 | # 处理可能存在的最后一个未结尾的chunk 546 | if chunk_type != None: 547 | chunk = (chunk_type, chunk_start, i + 1) 548 | chunks.append(chunk) 549 | return chunks 550 | 551 | if __name__ == '__main__': 552 | model = ner_model(None, None, [None,None,None], None) 553 | seq = [4, 4, 5, 0, 3, 5] 554 | tags = { 555 | 'B-PER' : 4, 556 | 'I-PER' : 5, 557 | 'B-LOC' : 3, 558 | 'O' : 0 559 | } 560 | print model.get_chunks(seq, tags) 561 | 562 | 563 | 564 | 565 | 566 | 567 | 568 | 569 | 570 | 571 | 572 | 573 | 574 | 575 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | #coding=utf-8 2 | 3 | from Configure import * 4 | from model.Dataset import * 5 | from model.Vocab_Util import * 6 | from model.log_util import * 7 | from model.ner_model import * 8 | 9 | def main(): 10 | # 读取配置文件 11 | config = Configure() 12 | # 设置logger 13 | logger = get_logger(config['path_log']) 14 | 15 | # 读取词典 16 | vocab_util = Vocab_Util() 17 | # dict[word] = idx 18 | vocab_words = vocab_util.load_vocab(config['word_vocab_file']) 19 | # dict[char] = idx 20 | vocab_chars = vocab_util.load_vocab(config['char_vocab_file']) 21 | # dict[tag] = idx 22 | vocab_tags = vocab_util.load_vocab(config['tag_vocab_file']) 23 | # 将词典封装给模型 24 | vocabs = [vocab_words, vocab_chars, vocab_tags] 25 | 26 | embeddings = vocab_util.get_trimmed_glove_vectors(config['trimmed_file']) 27 | 28 | # 对数据进行处理 29 | processing_word = get_processing_word(vocab_words = vocab_words, vocab_chars = vocab_chars, 30 | lowercase = True, chars = config['use_chars'], allow_unk = True) 31 | processing_tag = get_processing_word(vocab_words = vocab_tags, lowercase = False, allow_unk = False) 32 | 33 | # 得到训练数据 34 | train_dataset = Dataset(filename = config['train_data'], 35 | max_iter = None, processing_word = processing_word, processing_tag = processing_tag) 36 | # 得到dev数据 37 | dev_dataset = Dataset(filename = config['dev_data'], 38 | max_iter = None, processing_word = processing_word, processing_tag 
= processing_tag) 39 | 40 | # for data in train_dataset: 41 | # print data 42 | for x_batch , y_batch in train_dataset.get_minibatch(4): 43 | print x_batch 44 | print y_batch 45 | 46 | # 构造模型进行训练 47 | model = ner_model(config,logger,vocabs,embeddings) 48 | # 构建模型图 49 | model.build() 50 | # 训练 51 | model.train(train_dataset, dev_dataset) 52 | 53 | if __name__ == '__main__': 54 | main() --------------------------------------------------------------------------------
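A usage sketch (not part of the repository): restoring a trained model to tag a single sentence, following the setup code in evaluate.py and the batching convention of Dataset.get_minibatch. The file name tag_sentence.py and the example sentence are hypothetical.

#coding=utf-8
# tag_sentence.py (hypothetical, not part of the repository)
from Configure import *
from model.Dataset import *
from model.Vocab_Util import *
from model.log_util import *
from model.ner_model import *

config = Configure()
logger = get_logger(config['path_log'])
vocab_util = Vocab_Util()
vocab_words = vocab_util.load_vocab(config['word_vocab_file'])
vocab_chars = vocab_util.load_vocab(config['char_vocab_file'])
vocab_tags = vocab_util.load_vocab(config['tag_vocab_file'])
embeddings = vocab_util.get_trimmed_glove_vectors(config['trimmed_file'])

# same preprocessing as in train.py / evaluate.py
processing_word = get_processing_word(vocab_words = vocab_words, vocab_chars = vocab_chars,
    lowercase = True, chars = config['use_chars'], allow_unk = True)

# rebuild the graph and restore the trained weights
model = ner_model(config, logger, [vocab_words, vocab_chars, vocab_tags], embeddings)
model.build()
model.restore_session()

# process one sentence word by word, then regroup it into
# ([char_ids, ...], [word_ids, ...]) exactly as Dataset.get_minibatch does
sentence = ['Jean', 'Pierre', 'lives', 'in', 'New', 'York']
words = [processing_word(w) for w in sentence]
if type(words[0]) == tuple:
    words = zip(*words)

# predict_batch takes a batch (a list of sentences) and returns tag ids
labels_pred, _ = model.predict_batch([words])
idx_to_tag = {idx: tag for tag, idx in vocab_tags.items()}
print [idx_to_tag[idx] for idx in labels_pred[0]]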