├── .gitignore
├── README.md
├── pretrain.py
├── run_pretrain.cmd
├── run_test_keras_model.cmd
├── run_train_keras_model.cmd
├── test_keras_model.py
├── test_keras_model.sh
├── train_keras_model.py
├── train_keras_model.sh
├── utils
│   ├── __init__.py
│   ├── viterbi.py
│   └── viterbi.pyc
└── word2vec_model
    ├── .gitignore
    ├── run.cmd
    └── train_word2vec_model.py

--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
dataset
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
Overview:
A Chinese word segmentation (CWS) algorithm implemented with an RNN.
Source code reference: http://www.jianshu.com/p/7e233ef57cb6
Dataset download: http://pan.baidu.com/s/1jIyNT7w

Training steps:
1. Train a word2vec model on the existing (already segmented) corpus.
2. Preprocess the corpus to obtain the training and test inputs.
3. Build the RNN and train it, evaluating accuracy as training proceeds.
4. Use the trained model to score the candidate tag sequences and pick the most probable one with the Viterbi algorithm.
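
Putting the steps together, a typical end-to-end run looks like the following (a sketch based on the run scripts in this repository; it assumes the dataset archive has been unpacked so that msr_training.utf8 and msr_test.utf8 sit at the paths the scripts expect):

# Step 1: train word2vec (run from inside word2vec_model/, as in word2vec_model/run.cmd)
python ./train_word2vec_model.py ./dataset/msr_training.utf8 ./msr_training_word2vec.model ./msr_training_word2vec.vector
# Step 2: preprocess the corpus into training inputs (run_pretrain.cmd)
python ./pretrain.py ./dataset/msr_training.utf8 ./cws.info ./cws.data
# Step 3: train the Keras model (run_train_keras_model.cmd)
python ./train_keras_model.py ./cws.info ./cws.data ./cws_keras_model ./keras_model_weights
# Step 4: segment the test set with the trained model (test_keras_model.sh / run_test_keras_model.cmd)
python ./test_keras_model.py ./cws.info ./cws_keras_model ./keras_model_weights ./dataset/msr_test.utf8 ./dataset/msr_test.utf8.cws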
--------------------------------------------------------------------------------
/pretrain.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-

'''
python pretrain.py input_file cws_info_filePath cws_data_filePath
'''

import json
import h5py
import string
import codecs
import sys
import time

corpus_tags = ['S', 'B', 'M', 'E']
retain_unknown = 'retain-unknown'
retain_padding = 'retain-padding'

def saveCwsInfo(path, cwsInfo):
    '''Save the segmentation vocabulary and the initial/transition probabilities.'''
    print('save cws info to %s' % path)
    fd = codecs.open(path, 'w', 'utf-8')
    (initProb, tranProb), (vocab, indexVocab) = cwsInfo
    j = json.dumps((initProb, tranProb))
    fd.write(j + '\n')
    for char in vocab:
        fd.write(char + '\t' + str(vocab[char]) + '\n')
    fd.close()

def loadCwsInfo(path):
    '''Load the segmentation vocabulary and the initial/transition probabilities.'''
    print('load cws info from %s' % path)
    fd = codecs.open(path, 'r', 'utf-8')
    line = fd.readline()
    j = json.loads(line.strip())
    initProb, tranProb = j[0], j[1]
    lines = fd.readlines()
    fd.close()
    vocab = {}
    indexVocab = [0 for i in range(len(lines))]
    for line in lines:
        rst = line.strip().split('\t')
        if len(rst) < 2: continue
        char, index = rst[0], int(rst[1])
        vocab[char] = index
        indexVocab[index] = char
    return (initProb, tranProb), (vocab, indexVocab)

def saveCwsData(path, cwsData):
    '''Save the segmentation training samples.'''
    print('save cws data to %s' % path)
    # hdf5 is the most efficient way to store large matrices
    fd = h5py.File(path, 'w')
    (X, y) = cwsData
    fd.create_dataset('X', data = X)
    fd.create_dataset('y', data = y)
    fd.close()

def loadCwsData(path):
    '''Load the segmentation training samples.'''
    print('load cws data from %s' % path)
    fd = h5py.File(path, 'r')
    X = fd['X'][:]
    y = fd['y'][:]
    fd.close()
    return (X, y)

def sent2vec2(sent, vocab, ctxWindows = 5):

    charVec = []
    for char in sent:
        if char in vocab:
            charVec.append(vocab[char])
        else:
            charVec.append(vocab[retain_unknown])
    # pad both ends of the sentence
    num = len(charVec)
    pad = int((ctxWindows - 1) / 2)
    for i in range(pad):
        charVec.insert(0, vocab[retain_padding])
        charVec.append(vocab[retain_padding])
    X = []
    for i in range(num):
        X.append(charVec[i:i + ctxWindows])
    return X

def sent2vec(sent, vocab, ctxWindows = 5):
    chars = []
    for char in sent:
        chars.append(char)
    return sent2vec2(chars, vocab, ctxWindows = ctxWindows)

def doc2vec(fname, vocab):
    '''Convert a document into context-window vectors and tags.'''

    # read the whole file at once -- watch the memory usage
    fd = codecs.open(fname, 'r', 'utf-8')
    lines = fd.readlines()
    fd.close()

    # sample set
    X = []
    y = []

    # tag statistics
    tagSize = len(corpus_tags)
    tagCnt = [0 for i in range(tagSize)]
    tagTranCnt = [[0 for i in range(tagSize)] for j in range(tagSize)]

    # iterate over lines
    for line in lines:
        # split on whitespace
        words = line.strip('\n').split()
        # characters and tags for this line
        chars = []
        tags = []
        # iterate over words
        for word in words:
            # words with two or more characters
            if len(word) > 1:
                # first character of the word
                chars.append(word[0])
                tags.append(corpus_tags.index('B'))
                # middle characters of the word
                for char in word[1:(len(word) - 1)]:
                    chars.append(char)
                    tags.append(corpus_tags.index('M'))
                # last character of the word
                chars.append(word[-1])
                tags.append(corpus_tags.index('E'))
            # single-character words
            else:
                chars.append(word)
                tags.append(corpus_tags.index('S'))

        # character context-window vectors
        lineVecX = sent2vec2(chars, vocab, ctxWindows = 7)

        # collect tag statistics
        lineVecY = []
        lastTag = -1
        for tag in tags:
            # tag vector
            lineVecY.append(tag)
            #lineVecY.append(corpus_tags[tag])
            # tag frequency
            tagCnt[tag] += 1
            # tag transition frequency
            if lastTag != -1:
                tagTranCnt[lastTag][tag] += 1
            # remember the previous tag
            lastTag = tag

        X.extend(lineVecX)
        y.extend(lineVecY)

    # total character count
    charCnt = sum(tagCnt)
    # total transition count
    tranCnt = sum([sum(tag) for tag in tagTranCnt])
    # initial tag probabilities
    initProb = []
    for i in range(tagSize):
        initProb.append(tagCnt[i] / float(charCnt))
    # tag transition probabilities
    tranProb = []
    for i in range(tagSize):
        p = []
        for j in range(tagSize):
            p.append(tagTranCnt[i][j] / float(tranCnt))
        tranProb.append(p)

    return X, y, initProb, tranProb

def genVocab(fname, delimiters = [' ', '\n']):

    # read the whole file at once -- watch the memory usage
    fd = codecs.open(fname, 'r', 'utf-8')
    data = fd.read()
    fd.close()

    vocab = {}
    indexVocab = []
    # iterate over characters
    index = 0
    for char in data:
        # delimiters are not added to the vocabulary
        if char not in delimiters and char not in vocab:
            vocab[char] = index
            indexVocab.append(char)
            index += 1

    # add the unknown-character and padding entries
    vocab[retain_unknown] = len(vocab)
    vocab[retain_padding] = len(vocab)
    indexVocab.append(retain_unknown)
    indexVocab.append(retain_padding)
    # return the vocabulary and its index
    return vocab, indexVocab

def load(fname):
    print('train from file %s' % fname)
    delims = [' ', '\n']
    vocab, indexVocab = genVocab(fname, delims)
    X, y, initProb, tranProb = doc2vec(fname, vocab)
    print('%d %d %d %d' % (len(X), len(y), len(vocab), len(indexVocab)))
    print(initProb)
    print(tranProb)
    return (X, y), (initProb, tranProb), (vocab, indexVocab)

if __name__ == '__main__':
    start_time = time.time()

    if len(sys.argv) < 4:
        print(globals()['__doc__'] % locals())
        sys.exit(1)
    input_file, cws_info_filePath, cws_data_filePath = sys.argv[1:4]

    (X, y), (initProb, tranProb), (vocab, indexVocab) = load(input_file)
    saveCwsInfo(cws_info_filePath, ((initProb, tranProb), (vocab, indexVocab)))
    saveCwsData(cws_data_filePath, (X, y))

    end_time = time.time()
    print("used time : %d s" % (end_time - start_time))
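
# Worked example (illustrative only): with ctxWindows = 3 and the sentence u'天气好',
# sent2vec2() maps each character to its vocabulary index and pads both ends with
# retain_padding, producing one window per character:
#     [[pad, 天, 气], [天, 气, 好], [气, 好, pad]]
# doc2vec() builds the same windows with ctxWindows = 7, which is why maxlen is set to
# 7 in train_keras_model.py; each character is tagged S/B/M/E according to its position
# in the word it belongs to.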
--------------------------------------------------------------------------------
/run_pretrain.cmd:
--------------------------------------------------------------------------------
python ./pretrain.py ./dataset/msr_training.utf8 ./cws.info ./cws.data
--------------------------------------------------------------------------------
/run_test_keras_model.cmd:
--------------------------------------------------------------------------------
python ./test_keras_model.py ./cws.info ./cws_keras_model ./keras_model_weights ./dataset/msr_test.utf8 ./dataset/msr_test.utf8.cws
--------------------------------------------------------------------------------
/run_train_keras_model.cmd:
--------------------------------------------------------------------------------
python ./train_keras_model.py ./cws.info ./cws.data ./cws_keras_model ./keras_model_weights
--------------------------------------------------------------------------------
/test_keras_model.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-

'''
python test_keras_model.py cws_info_file keras_model_file keras_model_weights_file test_data_file output_file
'''

import numpy as np
import json
import h5py
import codecs
import time
import sys

import pretrain as cws
from utils import viterbi

from sklearn import model_selection

from keras.preprocessing import sequence
from keras.optimizers import SGD, RMSprop, Adagrad
from keras.utils import np_utils
from keras.models import Sequential, Graph, model_from_json
from keras.layers.core import Dense, Dropout, Activation, TimeDistributedDense
from keras.layers.embeddings import Embedding
from keras.layers.recurrent import LSTM, GRU, SimpleRNN

from gensim.models import Word2Vec

def loadModel(modelPath, weightPath):

    fd = open(modelPath, 'r')
    j = fd.read()
    fd.close()

    model = model_from_json(j)

    model.load_weights(weightPath)

    return model


# infer the tag sequence for an input sentence
def cwsSent(sent, model, cwsInfo):
    (initProb, tranProb), (vocab, indexVocab) = cwsInfo
    vec = cws.sent2vec(sent, vocab, ctxWindows = 7)
    vec = np.array(vec)
    probs = model.predict_proba(vec)
    #classes = model.predict_classes(vec)

    prob, path = viterbi.viterbi(vec, cws.corpus_tags, initProb, tranProb, probs.transpose())

    # char/tag annotation (debug form, immediately overwritten below)
    ss = ''
    for i, t in enumerate(path):
        ss += '%s/%s ' % (sent[i], cws.corpus_tags[t])
    ss = ''
    word = ''
    for i, t in enumerate(path):
        if cws.corpus_tags[t] == 'S':
            ss += sent[i] + ' '
            word = ''
        elif cws.corpus_tags[t] == 'B':
            word += sent[i]
        elif cws.corpus_tags[t] == 'E':
            word += sent[i]
            ss += word + ' '
            word = ''
        elif cws.corpus_tags[t] == 'M':
            word += sent[i]

    return ss
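
# Shape summary for cwsSent() above (for reference): for a sentence of n characters,
# cws.sent2vec(..., ctxWindows = 7) yields an (n, 7) array of vocabulary indices,
# model.predict_proba() returns an (n, 4) array of per-character tag probabilities,
# and probs.transpose() hands viterbi() a (4, n) matrix indexed as emit_p[tag][position].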
93 | print("Loading used time : ", time.time() - start_time) 94 | print('Done!') 95 | print('-------------start predict----------------') 96 | #s = u'为寂寞的夜空画上一个月亮' 97 | #print cwsSent(s, model, cwsInfo) 98 | cwsFile(test_data_file, output_file, model, cwsInfo) 99 | -------------------------------------------------------------------------------- /test_keras_model.sh: -------------------------------------------------------------------------------- 1 | /home/escenter11/gym/anaconda/bin/python ./test_keras_model.py ./cws.info ./cws_keras_model ./keras_model_weights ./dataset/msr_test.utf8 ./dataset/msr_test.utf8.cws 2 | -------------------------------------------------------------------------------- /train_keras_model.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | ''' 4 | python train_keras_model.py cws_info_filePath cws_data_filePath output_keras_model_file output_keras_model_weights_file 5 | ''' 6 | 7 | import numpy as np 8 | import json 9 | import h5py 10 | import codecs 11 | import time 12 | import sys 13 | 14 | import pretrain as cws 15 | from uitls import viterbi 16 | 17 | from sklearn import model_selection 18 | 19 | from keras.preprocessing import sequence 20 | from keras.optimizers import SGD, RMSprop, Adagrad 21 | from keras.utils import np_utils 22 | from keras.models import Sequential,Graph, model_from_json 23 | from keras.layers.core import Dense, Dropout, Activation, TimeDistributedDense 24 | from keras.layers.embeddings import Embedding 25 | from keras.layers.recurrent import LSTM, GRU, SimpleRNN 26 | 27 | from gensim.models import Word2Vec 28 | 29 | def train(cwsInfo, cwsData, modelPath, weightPath): 30 | 31 | (initProb, tranProb), (vocab, indexVocab) = cwsInfo 32 | (X, y) = cwsData 33 | 34 | train_X, test_X, train_y, test_y = model_selection.train_test_split(X, y , train_size=0.9, random_state=1) 35 | 36 | train_X = np.array(train_X) 37 | train_y = np.array(train_y) 38 | test_X = np.array(test_X) 39 | test_y = np.array(test_y) 40 | 41 | outputDims = len(cws.corpus_tags) 42 | Y_train = np_utils.to_categorical(train_y, outputDims) 43 | Y_test = np_utils.to_categorical(test_y, outputDims) 44 | batchSize = 128 45 | vocabSize = len(vocab) + 1 46 | wordDims = 100 47 | maxlen = 7 48 | hiddenDims = 100 49 | 50 | w2vModel = Word2Vec.load('./word2vec_model/msr_training_word2vec.model') 51 | embeddingDim = w2vModel.vector_size 52 | embeddingUnknown = [0 for i in range(embeddingDim)] 53 | embeddingWeights = np.zeros((vocabSize + 1, embeddingDim)) 54 | for word, index in vocab.items(): 55 | if word in w2vModel: 56 | e = w2vModel[word] 57 | else: 58 | e = embeddingUnknown 59 | embeddingWeights[index, :] = e 60 | 61 | #LSTM 62 | model = Sequential() 63 | model.add(Embedding(output_dim = embeddingDim, input_dim = vocabSize + 1, 64 | input_length = maxlen, mask_zero = True, weights = [embeddingWeights])) 65 | model.add(LSTM(output_dim = hiddenDims, return_sequences = True)) 66 | model.add(LSTM(output_dim = hiddenDims, return_sequences = False)) 67 | model.add(Dropout(0.5)) 68 | model.add(Dense(outputDims)) 69 | model.add(Activation('softmax')) 70 | model.compile(loss = 'categorical_crossentropy', optimizer = 'adam', metrics=["accuracy"]) 71 | 72 | result = model.fit(train_X, Y_train, batch_size = batchSize, 73 | nb_epoch = 20, validation_data = (test_X,Y_test)) 74 | 75 | j = model.to_json() 76 | fd = open(modelPath, 'w') 77 | fd.write(j) 78 | fd.close() 79 | 80 | model.save_weights(weightPath) 81 | 82 | return model 83 | 

if __name__ == '__main__':
    if len(sys.argv) < 5:
        print(globals()['__doc__'] % locals())
        sys.exit(1)
    cws_info_filePath, cws_data_filePath, output_keras_model_file, output_keras_model_weights_file = sys.argv[1:5]

    print('Loading vocab...')
    start_time = time.time()
    cwsInfo = cws.loadCwsInfo(cws_info_filePath)
    cwsData = cws.loadCwsData(cws_data_filePath)
    print("Loading used time : %s s" % (time.time() - start_time))
    print('Done!')

    print('Training model...')
    start_time = time.time()
    model = train(cwsInfo, cwsData, output_keras_model_file, output_keras_model_weights_file)
    print("Training used time : %s s" % (time.time() - start_time))
    print('Done!')
--------------------------------------------------------------------------------
/train_keras_model.sh:
--------------------------------------------------------------------------------
/home/escenter11/gym/anaconda/bin/python ./train_keras_model.py ./cws.info ./cws.data ./cws_keras_model ./keras_model_weights
--------------------------------------------------------------------------------
/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/clayandgithub/rnn_cws/32f32cec87444bba4e0e3858c5972d0fac837f0a/utils/__init__.py
--------------------------------------------------------------------------------
/utils/viterbi.py:
--------------------------------------------------------------------------------
#!/usr/bin/python
# -*- coding: utf-8 -*-

# 2016-01-28 (Thursday) 17:14:03 CST by Demobin

def _print(hiddenstates, V):
    s = " " + " ".join(("%7d" % i) for i in range(len(V))) + "\n"
    for i, state in enumerate(hiddenstates):
        s += "%.5s: " % state
        s += " ".join("%.7s" % ("%f" % v[i]) for v in V)
        s += "\n"
    print(s)

# standard Viterbi algorithm; arguments are the observation sequence, the hidden
# states and the probability triple (initial, transition, emission)
def viterbi(obs, states, start_p, trans_p, emit_p):

    lenObs = len(obs)
    lenStates = len(states)

    V = [[0.0 for col in range(lenStates)] for row in range(lenObs)]
    path = [[0 for col in range(lenObs)] for row in range(lenStates)]

    # time step t = 0
    for y in range(lenStates):
        #V[0][y] = start_p[y] * emit_p[y][obs[0]]
        V[0][y] = start_p[y] * emit_p[y][0]
        path[y][0] = y

    # time steps t >= 1
    for t in range(1, lenObs):
        newpath = [[0.0 for col in range(lenObs)] for row in range(lenStates)]

        for y in range(lenStates):
            prob = -1
            state = 0
            for y0 in range(lenStates):
                #nprob = V[t - 1][y0] * trans_p[y0][y] * emit_p[y][obs[t]]
                nprob = V[t - 1][y0] * trans_p[y0][y] * emit_p[y][t]
                if nprob > prob:
                    prob = nprob
                    state = y0
            # record the best probability
            V[t][y] = prob
            # record the corresponding path
            newpath[y][:t] = path[state][:t]
            newpath[y][t] = y

        path = newpath

    prob = -1
    state = 0
    for y in range(lenStates):
        if V[lenObs - 1][y] > prob:
            prob = V[lenObs - 1][y]
            state = y

    #_print(states, V)
    return prob, path[state]
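
# Note: unlike the textbook formulation, emit_p here is indexed by time step
# (emit_p[state][t]) rather than by observation symbol -- the commented-out
# emit_p[y][obs[t]] lines show the textbook variant. This matches the call in
# test_keras_model.py, which passes the RNN's per-character tag probabilities
# transposed to shape (num_tags, num_chars).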

def example():
    # hidden states
    hiddenstates = ('Healthy', 'Fever')
    # observations
    observations = ('normal', 'cold', 'dizzy')

    # initial probabilities
    '''
    'Healthy': 0.6, 'Fever': 0.4
    '''
    start_p = [0.6, 0.4]
    # transition probabilities
    '''
    'Healthy' : {'Healthy': 0.7, 'Fever': 0.3},
    'Fever'   : {'Healthy': 0.4, 'Fever': 0.6}
    '''
    trans_p = [[0.7, 0.3], [0.4, 0.6]]
    # emission / observation probabilities
    '''
    'Healthy' : {'normal': 0.5, 'cold': 0.4, 'dizzy': 0.1},
    'Fever'   : {'normal': 0.1, 'cold': 0.3, 'dizzy': 0.6}
    '''
    emit_p = [[0.5, 0.4, 0.1], [0.1, 0.3, 0.6]]

    return viterbi(observations,
                   hiddenstates,
                   start_p,
                   trans_p,
                   emit_p)

if __name__ == '__main__':
    # expected output: (0.01512, [0, 0, 1]), i.e. Healthy, Healthy, Fever
    print(example())
--------------------------------------------------------------------------------
/utils/viterbi.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/clayandgithub/rnn_cws/32f32cec87444bba4e0e3858c5972d0fac837f0a/utils/viterbi.pyc
--------------------------------------------------------------------------------
/word2vec_model/.gitignore:
--------------------------------------------------------------------------------
dataset
--------------------------------------------------------------------------------
/word2vec_model/run.cmd:
--------------------------------------------------------------------------------
python ./train_word2vec_model.py ./dataset/msr_training.utf8 ./msr_training_word2vec.model ./msr_training_word2vec.vector
--------------------------------------------------------------------------------
/word2vec_model/train_word2vec_model.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-

'''
python train_word2vec_model.py input_file output_model_file output_vector_file
'''

# import modules & set up logging
import os
import sys
import logging
import multiprocessing
import time
import json

from gensim.models import Word2Vec
from gensim.models.word2vec import LineSentence

def output_vocab(vocab):
    for k, v in vocab.items():
        print(k)

if __name__ == '__main__':
    start_time = time.time()

    program = os.path.basename(sys.argv[0])
    logger = logging.getLogger(program)

    logging.basicConfig(format='%(asctime)s : %(levelname)s : %(message)s', level=logging.INFO)
    logger.info("running %s" % ' '.join(sys.argv))

    # check and process input arguments
    if len(sys.argv) < 4:
        print(globals()['__doc__'] % locals())
        sys.exit(1)
    input_file, output_model_file, output_vector_file = sys.argv[1:4]

    model = Word2Vec(LineSentence(input_file), size=128, window=5, min_count=5,
                     workers=multiprocessing.cpu_count())

    # trim unneeded model memory = use (much) less RAM
    #model.init_sims(replace=True)
    model.save(output_model_file)
    model.save_word2vec_format(output_vector_file, binary=False)

    end_time = time.time()
    print("used time : %d s" % (end_time - start_time))
--------------------------------------------------------------------------------
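
A quick way to sanity-check the saved word2vec model (an illustrative sketch, not part of the repository; run it from inside word2vec_model/ after run.cmd has finished). Note that the corpus is already segmented, so the vocabulary consists of whole words, and single characters only appear when they occur as one-character words:

# -*- coding: utf-8 -*-
from gensim.models import Word2Vec

model = Word2Vec.load('./msr_training_word2vec.model')
print(model.vector_size)    # 128, matching size=128 in train_word2vec_model.py
print(u'中国' in model)      # membership test (arbitrary example token), the same check train_keras_model.py relies on
# Newer gensim releases move these lookups to model.wv (e.g. model.wv[u'中国'])
# and expose save_word2vec_format as model.wv.save_word2vec_format.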