├── .gitignore ├── README.md ├── configuration.py ├── data └── build_pku_msr_input.py ├── inference.py ├── lstm_based_cws_model.py ├── ops ├── input_ops.py └── vocab.py ├── process_chr_embedding.py ├── train.py └── word_count /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | env/ 12 | build/ 13 | develop-eggs/ 14 | dist/ 15 | downloads/ 16 | eggs/ 17 | .eggs/ 18 | lib/ 19 | lib64/ 20 | parts/ 21 | sdist/ 22 | var/ 23 | wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *,cover 47 | .hypothesis/ 48 | 49 | # Translations 50 | *.mo 51 | *.pot 52 | 53 | # Django stuff: 54 | *.log 55 | local_settings.py 56 | 57 | # Flask stuff: 58 | instance/ 59 | .webassets-cache 60 | 61 | # Scrapy stuff: 62 | .scrapy 63 | 64 | # Sphinx documentation 65 | docs/_build/ 66 | 67 | # PyBuilder 68 | target/ 69 | 70 | # Jupyter Notebook 71 | .ipynb_checkpoints 72 | 73 | # pyenv 74 | .python-version 75 | 76 | # celery beat schedule file 77 | celerybeat-schedule 78 | 79 | # dotenv 80 | .env 81 | 82 | # virtualenv 83 | .venv 84 | venv/ 85 | ENV/ 86 | 87 | # Spyder project settings 88 | .spyderproject 89 | 90 | # Rope project settings 91 | .ropeproject 92 | 93 | #model and data folder 94 | save_model/ 95 | data/output_dir 96 | output/ 97 | 98 | #notebook 99 | *.ipynb 100 | 101 | #pickle 102 | *.pkl 103 | 104 | #ini file 105 | *.ini 106 | 107 | *.TFRecord 108 | data/ 109 | ckpt/ -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Tensorflow中文分词模型 2 | 3 | 注: 如果对准确度比较高的要求, 请使用 https://github.com/JayYip/bert-multitask-learning 4 | 5 | 部分代码参考 [TensorFlow Model Zoo](https://github.com/tensorflow/models) 6 | 7 | 运行环境: 8 | 9 | - Python 3.5 / Python 2.7 10 | - Tensorflow r1.4 11 | - Windows / Ubuntu 16.04 12 | - hanziconv 0.3.2 13 | - numpy 14 | 15 | ## 训练模型 16 | 17 | ### 1. 建立训练数据 18 | 进入到data目录下,执行以下命令 19 | 20 | ``` 21 | DATA_OUTPUT="output_dir" 22 | 23 | python build_pku_msr_input.py \ 24 | --num_threads=4 \ 25 | --output_dir=${DATA_OUTPUT} 26 | ``` 27 | 28 | ### 2. 字符嵌入 29 | 30 | #### 2.1 预训练好的字嵌入 31 | 1. 将`configuration.py`中的`ModelConfig`的`self.random_embedding`设置为`False` 32 | 2. 从[Polygot](https://sites.google.com/site/rmyeid/projects/polyglot)下载中文字嵌入数据集至项目目录,运行项目目录下`process_chr_embedding.py`。 33 | 34 | ``` 35 | EMBEDDING_DIR=... 36 | VOCAB_DIR=... 37 | 38 | python process_chr_embedding.py \ 39 | --chr_embedding_dir=${EMBEDDING_DIR} 40 | --vocab_dir=${VOCAB_DIR} 41 | ``` 42 | 43 | #### 2.2 随机初始化字嵌入 44 | 45 | 将`configuration.py`中的`ModelConfig`的`self.random_embedding`设置为`True` 46 | 47 | ### 3. 
训练模型 48 | 49 | 根据需要修改configuration.py里面的模型及训练参数,开始训练模型。 50 | 以下参数如不提供将会使用默认值。 51 | 52 | ``` 53 | TRAIN_INPUT="data\${DATA_OUTPUT}" 54 | MODEL="save_model" 55 | 56 | python train.py \ 57 | --input_file_dir=${TRAIN_INPUT} \ 58 | --train_dir=${MODEL} \ 59 | --log_every_n_steps=10 60 | 61 | ``` 62 | 63 | ## 使用训练好的模型进行分词 64 | 65 | 编码须为utf8,检测的后缀为'txt','csv', 'utf8'。 66 | 67 | ``` 68 | INF_INPUT=... 69 | INF_OUTPUT=... 70 | 71 | python inference.py \ 72 | --input_file_dir=${INF_INPUT} \ 73 | --train_dir=${MODEL} \ 74 | --vocab_dir=${VOCAB_DIR} \ 75 | --out_dir=${INF_OUTPUT} 76 | ``` 77 | 78 | ## 如何根据自己需要修改算法 79 | 80 | 本模型使用的是单向LSTM+CRF,但是提供了算法修改的可能性。在```lstm_based_cws_model.py```文件中的 81 | 82 | 83 | 84 | -------------------------------------------------------------------------------- /configuration.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | #Author: Jay Yip 4 | #Date 04Mar2017 5 | """Set the configuration of model and training parameters""" 6 | 7 | from __future__ import absolute_import 8 | from __future__ import division 9 | from __future__ import print_function 10 | 11 | 12 | class ModelConfig(object): 13 | """docstring for ModelConfig""" 14 | 15 | def __init__(self): 16 | 17 | #Set the feature name of context and tags 18 | self.context_feature_name = 'content_id' 19 | self.tag_feature_name = 'tag_id' 20 | self.length_name = 'length' 21 | 22 | #Number of thread for prefetching SequenceExample 23 | #self.num_input_reader_thread = 2 24 | #Number of preprocessing threads 25 | self.num_preprocess_thread = 2 26 | 27 | #Batch size 28 | self.batch_size = 32 29 | 30 | #LSTM input and output dimensions 31 | self.embedding_size = 64 32 | self.num_lstm_units = 128 33 | 34 | #Fully connected layer output dimensions 35 | self.num_tag = 5 36 | 37 | #Dropout 38 | self.lstm_dropout_keep_prob = 0.35 39 | #Margin loss discount 40 | self.margin_loss_discount = 0.2 41 | #Regularization 42 | self.regularization = 0.0001 43 | 44 | self.seq_max_len = 60 45 | 46 | 47 | class TrainingConfig(object): 48 | """docstring for TrainingConfig""" 49 | 50 | def __init__(self): 51 | 52 | self.num_examples_per_epoch = 500000 53 | 54 | #Optimizer for training 55 | self.optimizer = 'Adam' 56 | 57 | #Learning rate 58 | self.initial_learning_rate = 0.01 59 | #If decay factor <= 0 then not decay 60 | self.learning_rate_decay_factor = 0.5 61 | self.num_epochs_per_decay = 3 62 | 63 | #Gradient clipping 64 | self.clip_gradients = 1.0 65 | 66 | #Max checkpoints to keep 67 | self.max_checkpoints_to_keep = 2 68 | 69 | #Set training step 70 | self.training_step = 3000000 71 | 72 | self.embedding_random = False 73 | -------------------------------------------------------------------------------- /data/build_pku_msr_input.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | #Author: Jay Yip 4 | #Date 20Feb2017 5 | """ 6 | Download the PKU-MSR datasets or Chinese Wiki dataset and convert 7 | it to TFRecords. 8 | 9 | PKU-MSR Download Address: http://sighan.cs.uchicago.edu/bakeoff2005/data/icwb2-data.zip 10 | Chinese Wiki Address: PENDING 11 | 12 | Each file is a TFRecord 13 | 14 | Output: 15 | download_dir/train-00000-of-00xxx 16 | ... 
17 | download_dir/train-00127-of-00xxx 18 | 19 | Processing Description: 20 | 21 | 22 | Files Description: 23 | 24 | 25 | """ 26 | 27 | from __future__ import absolute_import 28 | from __future__ import division 29 | from __future__ import print_function 30 | 31 | import sys 32 | if sys.version_info >= (3, 0): 33 | import urllib.request 34 | else: 35 | from six.moves import urllib 36 | import zipfile 37 | import zlib 38 | import os 39 | from collections import Counter 40 | import numpy as np 41 | import threading 42 | from datetime import datetime 43 | import pickle 44 | from hanziconv.hanziconv import HanziConv 45 | from multiprocessing import Process 46 | from itertools import chain, islice 47 | from builtins import open 48 | 49 | import tensorflow as tf 50 | 51 | tf.flags.DEFINE_string("data_source", "pku-msr", 52 | "Specify the data source: pku-msr or wiki-chn") 53 | tf.flags.DEFINE_string("download_dir", "download_dir", "Output data directory.") 54 | tf.flags.DEFINE_string("word_counts_output_file", "word_count", 55 | "Word Count output dir") 56 | tf.flags.DEFINE_integer("train_shards", 128, 57 | "Number of shards in training TFRecord files.") 58 | tf.flags.DEFINE_integer("num_threads", 8, 59 | "Number of threads to preprocess the images.") 60 | tf.flags.DEFINE_integer("window_size", 5, "The window size of skip-gram model") 61 | tf.flags.DEFINE_integer("seq_max_len", 30, "Max length of seqence") 62 | FLAGS = tf.flags.FLAGS 63 | 64 | 65 | class Vocabulary(object): 66 | """Simple vocabulary wrapper.""" 67 | 68 | def __init__(self, vocab, id_vocab, unk_id, unk_word=''): 69 | """Initializes the vocabulary. 70 | 71 | Args: 72 | vocab: A dictionary of word to word_id. 73 | unk_id: Id of the special 'unknown' word. 74 | """ 75 | self._vocab = vocab 76 | self._id_vocab = id_vocab 77 | self._unk_id = unk_id 78 | self._vocab[unk_word] = len(self._vocab) 79 | self._id_vocab[len(self._vocab)] = unk_word 80 | 81 | def word_to_id(self, word): 82 | """Returns the integer id of a word string.""" 83 | if word in self._vocab: 84 | return self._vocab[word] 85 | else: 86 | return self._unk_id 87 | 88 | def id_to_word(self, word_id): 89 | """Returns the word string of an integer word id.""" 90 | if word_id >= len(self._vocab): 91 | return self._id_vocab[self.unk_id] 92 | else: 93 | return self._id_vocab[word_id] 94 | 95 | 96 | def tag_to_id(t): 97 | 98 | if t == 's': 99 | return 1 100 | 101 | elif t == 'b': 102 | return 2 103 | 104 | elif t == 'm': 105 | return 3 106 | 107 | elif t == 'e': 108 | return 4 109 | 110 | 111 | #Line processing functions 112 | 113 | 114 | def split_list(alist, wanted_parts=1): 115 | length = len(alist) 116 | return [ 117 | alist[i * length // wanted_parts:(i + 1) * length // wanted_parts] 118 | for i in range(wanted_parts) 119 | ] 120 | 121 | 122 | def process_line_msr_pku(l): 123 | decoded_line = l.decode('utf8').strip().split(' ') 124 | return [w.strip('\r\n') for w in decoded_line] 125 | 126 | 127 | def process_line_as_training(l): 128 | if sys.version_info >= (3, 0): 129 | decoded_line = HanziConv.toSimplified( 130 | l.decode('utf8')).strip().split('\u3000') 131 | else: 132 | decoded_line = HanziConv.toSimplified( 133 | l.decode('utf8')).strip().split(u'\u3000') 134 | return [w.strip('\r\n') for w in decoded_line] 135 | 136 | 137 | def process_line_cityu(l): 138 | decoded_line = HanziConv.toSimplified(l.decode('utf8')).strip().split(' ') 139 | return [w.strip('\r\n') for w in decoded_line] 140 | 141 | 142 | def get_process_fn(filename): 143 | 144 | if 'msr' in filename or 
'pk' in filename: 145 | return process_line_msr_pku 146 | 147 | elif 'as' in filename: 148 | return process_line_as_training 149 | 150 | elif 'cityu' in filename: 151 | return process_line_cityu 152 | 153 | 154 | def _is_valid_data_source(data_source): 155 | return data_source in ['pku-msr', 'wiki-chn'] 156 | 157 | 158 | # Convert feature functions 159 | def _int64_feature(value): 160 | """Wrapper for inserting an int64 Feature into a SequenceExample proto.""" 161 | return tf.train.Feature(int64_list=tf.train.Int64List(value=[value])) 162 | 163 | 164 | def _bytes_feature(value): 165 | """Wrapper for inserting a bytes Feature into a SequenceExample proto.""" 166 | return tf.train.Feature(bytes_list=tf.train.BytesList( 167 | value=[value.encode('utf8')])) 168 | 169 | 170 | def _int64_feature_list(values): 171 | """Wrapper for inserting an int64 FeatureList into a SequenceExample proto.""" 172 | return tf.train.FeatureList(feature=[_int64_feature(v) for v in values]) 173 | 174 | 175 | def _bytes_feature_list(values): 176 | """Wrapper for inserting a bytes FeatureList into a SequenceExample proto.""" 177 | return tf.train.FeatureList(feature=[_bytes_feature(v) for v in values]) 178 | 179 | 180 | def download_extract(data_source, download='Y'): 181 | """ 182 | Download files from web and extract 183 | """ 184 | if data_source == 'pku-msr': 185 | 186 | if download == 'Y': 187 | file_name = 'icwb2-data.zip' 188 | if sys.version_info >= (3, 0): 189 | urllib.request.urlretrieve( 190 | 'http://sighan.cs.uchicago.edu/bakeoff2005/data/icwb2-data.zip', 191 | os.path.join(FLAGS.download_dir, file_name)) 192 | else: 193 | urllib.request.urlopen('http://sighan.cs.uchicago.edu/bakeoff2005/data/icwb2-data.zip', 194 | os.path.join(FLAGS.download_dir, file_name)) 195 | 196 | zip_ref = zipfile.ZipFile( 197 | os.path.join(FLAGS.download_dir, file_name), 'r') 198 | zip_ref.extractall(FLAGS.download_dir) 199 | zip_ref.close() 200 | 201 | elif data_source == 'wiki-chn': 202 | 203 | #Implement in the future... 204 | #If there's future... 205 | pass 206 | 207 | else: 208 | assert _is_valid_num_shards(FLAGS.data_source), ( 209 | "Please make sure the data source is either 'pku-msr' or 'wiki-chn'" 210 | ) 211 | 212 | 213 | def _create_vocab(path_list): 214 | """ 215 | Create vocab objects 216 | """ 217 | 218 | counter = Counter() 219 | row_count = 0 220 | 221 | for file_path in path_list: 222 | print("Processing" + file_path) 223 | with open(file_path, 'rb') as f: 224 | for l in f: 225 | counter.update(HanziConv.toSimplified(l.decode('utf8'))) 226 | row_count = row_count + 1 227 | 228 | print("Total char:", len(counter)) 229 | 230 | # Filter uncommon words and sort by descending count. 231 | word_counts = [x for x in counter.items() if x is not ' '] 232 | word_counts.sort(key=lambda x: x[1], reverse=True) 233 | print("Words in vocabulary:", len(word_counts)) 234 | 235 | # Write out the word counts file. 236 | with open(FLAGS.word_counts_output_file, "wb") as f: 237 | 238 | #line = str("\n".join(["%s %d" % (w, c) for w, c in word_counts])) 239 | line = ["%s %d" % (w, c) for w, c in word_counts] 240 | line = "\n".join(w for w in line).encode('utf8') 241 | 242 | f.write(line) 243 | print("Wrote vocabulary file:", FLAGS.word_counts_output_file) 244 | 245 | # Create the vocabulary dictionary. 
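    # Characters are kept in descending-frequency order, so a character's id is
    # simply its position in reverse_vocab; unk_id (== vocabulary size) is passed
    # to Vocabulary below, which appends the extra "unknown" entry that unseen
    # characters map to.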
246 | reverse_vocab = [x[0] for x in word_counts] 247 | unk_id = len(reverse_vocab) 248 | vocab_dict = dict([(x, y) for (y, x) in enumerate(reverse_vocab)]) 249 | id_vocab_dict = dict([(y, x) for (y, x) in enumerate(reverse_vocab)]) 250 | vocab = Vocabulary(vocab_dict, id_vocab_dict, unk_id) 251 | 252 | return vocab 253 | 254 | 255 | def _to_sequence_example(decoded_str, pos_tag_str, vocab): 256 | 257 | #Transfor word to word_id 258 | content_id = [vocab.word_to_id(c) for c in decoded_str] 259 | content_id = content_id[:FLAGS.seq_max_len] 260 | tag_id = [tag_to_id(t) for t in pos_tag_str] 261 | tag_id = tag_id[:FLAGS.seq_max_len] 262 | length = min(FLAGS.seq_max_len, len(content_id)) 263 | 264 | feature_lists = tf.train.FeatureLists(feature_list={ 265 | "content_id": 266 | _int64_feature_list(content_id), 267 | "tag_id": 268 | _int64_feature_list(tag_id) 269 | }) 270 | 271 | context = tf.train.Features(feature={"length": _int64_feature(length)}) 272 | 273 | sequence_example = tf.train.SequenceExample( 274 | feature_lists=feature_lists, context=context) 275 | 276 | return sequence_example 277 | 278 | 279 | def _process_text_files(thread_index, name, path_list, vocab, num_shards): 280 | 281 | #Create possible tags for fast lookup 282 | possible_tags = [] 283 | for i in range(1, 300): 284 | if i == 1: 285 | possible_tags.append('s') 286 | else: 287 | possible_tags.append('b' + 'm' * (i - 2) + 'e') 288 | 289 | for s in range(len(path_list)): 290 | filename = path_list[s] 291 | #Create file names for shards 292 | output_filename = "%s-%s" % (name, 293 | os.path.split(filename)[-1]) 294 | output_file = os.path.join(output_filename + '.TFRecord') 295 | 296 | #Init writer 297 | writer = tf.python_io.TFRecordWriter(output_file) 298 | 299 | #Get the input file name 300 | 301 | counter = 0 302 | 303 | #Init left and right queue 304 | 305 | sequence_example = None 306 | with open(filename, 'rb') as f: 307 | 308 | process_fn = get_process_fn(os.path.split(filename)[-1]) 309 | 310 | for l in f: 311 | pos_tag = [] 312 | final_line = [] 313 | 314 | decoded_line = process_fn(l) 315 | 316 | 317 | for w in decoded_line: 318 | if w and len(w) <= 299: 319 | final_line.append(w) 320 | pos_tag.append(possible_tags[len(w) - 1]) 321 | 322 | decode_str = ''.join(final_line) 323 | 324 | pos_tag_str = ''.join(pos_tag) 325 | 326 | if len(pos_tag_str) != len(decode_str): 327 | continue 328 | print('Skip one row. ' + pos_tag_str + ';' + decode_str) 329 | 330 | if len(decode_str) > 0: 331 | sequence_example = _to_sequence_example( 332 | decode_str, pos_tag_str, vocab) 333 | writer.write(sequence_example.SerializeToString()) 334 | counter += 1 335 | 336 | if not counter % 5000: 337 | print("%s [thread %d]: Processed %d in thread batch." % 338 | (datetime.now(), thread_index, counter)) 339 | sys.stdout.flush() 340 | 341 | writer.close() 342 | print("%s [thread %d]: Finished writing to %s" % 343 | (datetime.now(), thread_index, output_file)) 344 | sys.stdout.flush() 345 | counter = 0 346 | 347 | 348 | def _process_dataset(name, path_list, vocab): 349 | """ 350 | """ 351 | 352 | #Set number of threads 353 | num_threads = FLAGS.num_threads 354 | num_shards = len(path_list) 355 | 356 | #Decide 357 | spacing = np.linspace(0, len(path_list), num_threads + 1).astype(np.int) 358 | ranges = [] 359 | threads = [] 360 | for i in range(len(spacing) - 1): 361 | ranges.append([spacing[i], spacing[i + 1]]) 362 | 363 | # Create a mechanism for monitoring when all threads are finished. 
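    # Note: the workers launched below are multiprocessing.Process objects rather
    # than threads; coord.join(threads) still waits for them because Process
    # exposes a Thread-like interface (is_alive()/join()).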
364 | coord = tf.train.Coordinator() 365 | 366 | #Assign path_list based on thread to avoid error 367 | path_list_list = split_list(path_list, wanted_parts=num_threads) 368 | print(path_list_list) 369 | 370 | #Launch thread for batch processing 371 | print("Launching %d threads" % (num_threads)) 372 | for thread_index in range(num_threads): 373 | args = (thread_index, name, path_list_list[thread_index], vocab, 374 | num_shards) 375 | t = Process(target=_process_text_files, args=args) 376 | t.start() 377 | threads.append(t) 378 | 379 | # Wait for all the threads to terminate. 380 | coord.join(threads) 381 | print("%s: Finished processing all %d text files in data set '%s'." % 382 | (datetime.now(), len(path_list), name)) 383 | 384 | 385 | def get_path(data_dir='.', suffix='utf8', mode='train'): 386 | 387 | path_list = [] 388 | for dirpath, dirnames, filenames in os.walk(data_dir): 389 | for filename in filenames: 390 | fullpath = os.path.join(dirpath, filename) 391 | if fullpath.endswith(suffix) and mode in fullpath: 392 | path_list.append(fullpath) 393 | 394 | return path_list 395 | 396 | def split_files(path_list, num_rows = 50000): 397 | tmp_dir = os.path.join(FLAGS.download_dir, 'tmp') 398 | if not os.path.exists(tmp_dir): 399 | os.mkdir(tmp_dir) 400 | return_path_list = [] 401 | 402 | def chunks(iterable, n): 403 | "chunks(ABCDE,2) => AB CD E" 404 | iterable = iter(iterable) 405 | while True: 406 | # store one line in memory, 407 | # chain it to an iterator on the rest of the chunk 408 | yield chain([next(iterable)], islice(iterable, n-1)) 409 | 410 | for path in path_list: 411 | with open(path, encoding = 'utf8') as bigfile: 412 | for i, lines in enumerate(chunks(bigfile, num_rows)): 413 | file_split = '{}.{}'.format(os.path.split(path)[-1], i) 414 | write_file = os.path.join(tmp_dir, file_split) 415 | return_path_list.append(write_file) 416 | with open(write_file, 'w', encoding = 'utf8') as f: 417 | f.writelines(lines) 418 | 419 | return return_path_list 420 | 421 | 422 | 423 | 424 | def main(unused_argv): 425 | 426 | try: 427 | os.makedirs(FLAGS.download_dir) 428 | except (OSError, IOError) as err: 429 | # Windows may complain if the folders already exist 430 | pass 431 | 432 | download_extract(FLAGS.data_source, 'N') 433 | 434 | path_list = get_path(data_dir=os.path.join(FLAGS.download_dir, 'icwb2-data', 435 | 'training')) 436 | 437 | 438 | 439 | vocab = _create_vocab(path_list) 440 | pickle.dump(vocab, open('vocab.pkl', 'wb')) 441 | 442 | path_list = split_files(path_list) 443 | 444 | trimmed_path_list = [] 445 | for filename in path_list: 446 | output_filename = "%s-%s" % ('train', 447 | os.path.split(filename)[-1]) 448 | output_file = os.path.join(output_filename + '.TFRecord') 449 | if os.path.isfile(output_file): 450 | pass 451 | else: 452 | trimmed_path_list.append(filename) 453 | 454 | path_list = trimmed_path_list 455 | 456 | _process_dataset('train', path_list, vocab) 457 | 458 | 459 | 460 | if __name__ == '__main__': 461 | tf.app.run() 462 | -------------------------------------------------------------------------------- /inference.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | #Author: Jay Yip 4 | #Date 22Mar2017 5 | """Inference""" 6 | 7 | from __future__ import absolute_import 8 | from __future__ import division 9 | from __future__ import print_function 10 | 11 | import os 12 | os.environ["CUDA_VISIBLE_DEVICES"]="-1" 13 | 14 | import tensorflow as tf 15 | import pickle 16 | from 
hanziconv.hanziconv import HanziConv 17 | import numpy as np 18 | 19 | from ops import input_ops 20 | from ops.vocab import Vocabulary 21 | import configuration 22 | from lstm_based_cws_model import LSTMCWS 23 | 24 | tf.flags.DEFINE_string("input_file_dir", "data/download_dir/icwb2-data/gold/", 25 | "Path of input files.") 26 | tf.flags.DEFINE_string("vocab_dir", "data/vocab.pkl", 27 | "Path of vocabulary file.") 28 | tf.flags.DEFINE_string("train_dir", "save_model", 29 | "Directory for saving and loading model checkpoints.") 30 | tf.flags.DEFINE_string("out_dir", 'output', 31 | "Frequency at which loss and global step are logged.") 32 | 33 | FLAGS = tf.app.flags.FLAGS 34 | 35 | 36 | def _create_restore_fn(checkpoint_path, saver): 37 | """Creates a function that restores a model from checkpoint. 38 | 39 | Args: 40 | checkpoint_path: Checkpoint file or a directory containing a checkpoint 41 | file. 42 | saver: Saver for restoring variables from the checkpoint file. 43 | 44 | Returns: 45 | restore_fn: A function such that restore_fn(sess) loads model variables 46 | from the checkpoint file. 47 | 48 | Raises: 49 | ValueError: If checkpoint_path does not refer to a checkpoint file or a 50 | directory containing a checkpoint file. 51 | """ 52 | if tf.gfile.IsDirectory(checkpoint_path): 53 | checkpoint_path = tf.train.latest_checkpoint(checkpoint_path) 54 | if not checkpoint_path: 55 | raise ValueError( 56 | "No checkpoint file found in: %s" % checkpoint_path) 57 | 58 | def _restore_fn(sess): 59 | tf.logging.info("Loading model from checkpoint: %s", checkpoint_path) 60 | saver.restore(sess, checkpoint_path) 61 | tf.logging.info("Successfully loaded checkpoint: %s", 62 | os.path.basename(checkpoint_path)) 63 | 64 | return _restore_fn 65 | 66 | 67 | def insert_space(char, tag): 68 | if tag == 1 or tag == 4: 69 | return char + ' ' 70 | else: 71 | return char 72 | 73 | 74 | def get_final_output(line, predict_tag): 75 | return ''.join( 76 | [insert_space(char, tag) for char, tag in zip(line, predict_tag)]) 77 | 78 | 79 | def append_to_file(output_buffer, filename): 80 | #filename = os.path.join(FLAGS.out_dir, 'out_' + os.path.split(filename)[-1]) 81 | 82 | if os.path.exists(filename): 83 | append_write = 'ab' # append if already exists 84 | else: 85 | append_write = 'wb' # make a new file if not 86 | 87 | with open(filename, append_write) as file: 88 | for item in output_buffer: 89 | file.write(item.encode('utf8') + b'\n') 90 | 91 | 92 | def tag_to_id(t): 93 | if t == 's': 94 | return 1 95 | 96 | elif t == 'b': 97 | return 2 98 | 99 | elif t == 'm': 100 | return 3 101 | 102 | elif t == 'e': 103 | return 4 104 | 105 | 106 | def seq_acc(seq1, seq2): 107 | correct = 0 108 | 109 | for seq_ind, char in enumerate(seq1): 110 | if char == seq2[seq_ind]: 111 | correct += 1 112 | 113 | return correct 114 | 115 | 116 | def main(unused_argv): 117 | 118 | #Preprocess before building graph 119 | #Read vocab file 120 | with open(FLAGS.vocab_dir, 'rb') as f: 121 | p = pickle.load(f) 122 | 123 | if not tf.gfile.IsDirectory(FLAGS.out_dir): 124 | tf.logging.info('Create Output dir as %s', FLAGS.out_dir) 125 | tf.gfile.MakeDirs(FLAGS.out_dir) 126 | 127 | filename_list = [] 128 | for dirpath, dirnames, filenames in os.walk(FLAGS.input_file_dir): 129 | for filename in filenames: 130 | fullpath = os.path.join(dirpath, filename) 131 | if fullpath.split('.')[-1] in ['utf8'] and 'test' in fullpath: 132 | filename_list.append(fullpath) 133 | 134 | checkpoint_path = FLAGS.train_dir 135 | 136 | model_config = 
configuration.ModelConfig() 137 | 138 | #Create possible tags for fast lookup 139 | possible_tags = [] 140 | for i in range(1, 300): 141 | if i == 1: 142 | possible_tags.append('s') 143 | else: 144 | possible_tags.append('b' + 'm' * (i - 2) + 'e') 145 | 146 | #Build graph for inference 147 | g = tf.Graph() 148 | with g.as_default(): 149 | 150 | input_seq_feed = tf.placeholder(name='input_seq_feed', dtype=tf.int64) 151 | seq_length = tf.placeholder(name='seq_length', dtype=tf.int64) 152 | 153 | #Build model 154 | model = LSTMCWS(model_config, 'inference') 155 | print('Building model...') 156 | model.build() 157 | 158 | with tf.Session(graph=g) as sess: 159 | 160 | #Restore ckpt 161 | saver = tf.train.Saver() 162 | restore_fn = _create_restore_fn(checkpoint_path, saver) 163 | restore_fn(sess) 164 | 165 | for filename in filename_list: 166 | output_buffer = [] 167 | num_correct = 0 168 | num_total = 0 169 | proc_fn = input_ops.get_process_fn(filename) 170 | with open(filename, 'rb') as f: 171 | # set out name and remove old output 172 | out_filename = os.path.join(FLAGS.out_dir, 'out_' + os.path.split(filename)[-1]) 173 | if os.path.exists(out_filename): 174 | os.remove(out_filename) 175 | 176 | for line in f: 177 | l = proc_fn(line) 178 | input_seqs_list = [p.word_to_id(x) for x in ''.join(l)] 179 | 180 | #get seqence label 181 | #str_input_seqs_list = [str(x) for x in input_seqs_list] 182 | input_label = [] 183 | for w in l: 184 | if len(w) > 0 and len(w) <= 299: 185 | input_label.append(possible_tags[len(w) - 1]) 186 | elif len(w) == 0: 187 | pass 188 | else: 189 | input_label.append('s') 190 | 191 | str_input_label = ''.join(input_label) 192 | input_label = [tag_to_id(x) for x in str_input_label] 193 | 194 | #get input sequence, seq length 195 | input_seqs_list = [x for x in input_seqs_list if x != 1] 196 | seq_len = min( 197 | len(input_seqs_list), model_config.seq_max_len) 198 | # pad to same shape 199 | for _ in range(model_config.seq_max_len): 200 | input_seqs_list.append(0) 201 | input_seqs_list = input_seqs_list[:model_config.seq_max_len] 202 | 203 | #get seqence length 204 | input_label = input_label[:model_config.seq_max_len] 205 | 206 | if seq_len <= 1: 207 | predict_tag = [0] 208 | output_buffer.append(get_final_output(l, predict_tag)) 209 | 210 | else: 211 | predict_tag = sess.run( 212 | model.predict_tag, 213 | feed_dict={ 214 | input_seq_feed: input_seqs_list, 215 | seq_length: seq_len 216 | }) 217 | 218 | predict_tag = predict_tag[0][:seq_len] 219 | 220 | if len(predict_tag) != len(input_label): 221 | print('predict not right') 222 | print('predict len %d' % len(predict_tag )) 223 | print('label len %d' % len(input_label)) 224 | print('text len %d' % len(input_seqs_list)) 225 | print(seq_len) 226 | raise ValueError 227 | 228 | output_buffer.append(get_final_output(l, predict_tag)) 229 | 230 | input_label = np.array(input_label) 231 | num_correct += np.sum(input_label == predict_tag) 232 | #num_correct += seq_acc(input_label, predict_tag) 233 | num_total += len(input_label) 234 | 235 | if len(output_buffer) >= 1000: 236 | append_to_file(output_buffer, out_filename) 237 | output_buffer = [] 238 | 239 | if output_buffer: 240 | append_to_file(output_buffer, out_filename) 241 | 242 | print('%s Acc: %f' % (filename, num_correct / num_total)) 243 | print('%s Correct: %d' % (filename, num_correct)) 244 | print('%s Total: %d' % (filename, num_total)) 245 | 246 | 247 | if __name__ == '__main__': 248 | tf.app.run() 249 | 
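A minimal sketch (assumed example data, not part of the repository) of the BMES convention used by `tag_to_id`, `insert_space` and `get_final_output` above: the CRF decoder emits one tag id per character, and a space is inserted after every character tagged `s` (1) or `e` (4).

```
# Hypothetical illustration of how predicted tag ids map back to segmented text.
# Tag ids: 1 = 's' (single-char word), 2 = 'b' (begin), 3 = 'm' (middle), 4 = 'e' (end)
def segment(chars, tag_ids):
    # same spacing rule as insert_space(): break after word-final tags
    return ''.join(c + ' ' if t in (1, 4) else c for c, t in zip(chars, tag_ids))

print(segment(u'今天天气很好', [2, 4, 2, 4, 1, 1]))
# -> '今天 天气 很 好 '
```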
-------------------------------------------------------------------------------- /lstm_based_cws_model.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | #Author: Jay Yip 4 | #Date 04Mar2017 5 | """Chinese words segmentation model based on aclweb.org/anthology/D15-1141""" 6 | 7 | from __future__ import absolute_import 8 | from __future__ import division 9 | from __future__ import print_function 10 | 11 | import tensorflow as tf 12 | 13 | import os 14 | from ops import input_ops 15 | 16 | 17 | class LSTMCWS(object): 18 | """docstring for LSTMCWS""" 19 | 20 | def __init__(self, config, mode): 21 | """ 22 | Init mode. 23 | 24 | Args: 25 | Config: configuration object 26 | mode: 'train', 'test' or 'inference' 27 | """ 28 | 29 | self.config = config 30 | self.mode = mode 31 | 32 | #Set up initializer 33 | self.initializer = tf.contrib.layers.xavier_initializer( 34 | uniform=True, seed=None, dtype=tf.float32) 35 | 36 | #Set up sequence embeddings with the shape of [batch_size, padded_length, embedding_size] 37 | self.seq_embedding = None 38 | 39 | #Set up batch losses for tracking performance with the length of batch_size * padded_length 40 | self.batch_losses = None 41 | 42 | #Set up global step tensor 43 | self.global_step = None 44 | 45 | def is_training(self): 46 | return self.mode == 'train' 47 | 48 | def is_test(self): 49 | return self.mode == 'test' 50 | 51 | def is_inf(self): 52 | return self.mode == 'inference' 53 | 54 | def build_inputs(self): 55 | """ 56 | Input prefetching, preprocessing and batching for trianing 57 | 58 | For inference mode, input seqs and input mask needs to be provided. 59 | 60 | Returns: 61 | self.input_seqs: A tensor of Input sequence to seq_lstm with the shape of [batch_size, padding_size] 62 | self.tag_seqs: A tensor of output sequence to seq_lstm with the shape of [batch_size, padding_size] 63 | self.tag_input_seq: A tensor of input sequence to tag inference model with the shape of [batch_size, padding_size -1] 64 | self.tag_output_seq: A tensor of input sequence to tag inference model with the shape of [batch_size, padding_size -1] 65 | """ 66 | 67 | if not self.is_inf(): 68 | 69 | with tf.variable_scope('train_eval_input'): 70 | #Get all TFRecord path into a list 71 | data_files = [] 72 | file_pattern = os.path.join(self.config.input_file_dir, '*.TFRecord') 73 | data_files.extend(tf.gfile.Glob(file_pattern)) 74 | 75 | 76 | data_files = [ 77 | x for x in data_files 78 | if os.path.split(x)[-1].startswith(self.mode) 79 | ] 80 | 81 | if not data_files: 82 | tf.logging.fatal("Found no input files matching %s", 83 | file_pattern) 84 | else: 85 | tf.logging.info( 86 | "Prefetching values from %d files matching %s", 87 | len(data_files), file_pattern) 88 | 89 | def _parse_wrapper(l): 90 | return input_ops.parse_example_queue(l, self.config) 91 | 92 | dataset = tf.data.TFRecordDataset(data_files).map( 93 | _parse_wrapper) 94 | if self.is_training(): 95 | dataset = dataset.shuffle( 96 | buffer_size=256).repeat(10000).shuffle(buffer_size=256) 97 | 98 | dataset = dataset.padded_batch(batch_size=self.config.batch_size, 99 | padded_shapes = (tf.TensorShape([self.config.seq_max_len]), 100 | tf.TensorShape([self.config.seq_max_len]), 101 | tf.TensorShape([]))).filter( 102 | lambda x, y, z: tf.equal(tf.shape(x)[0], self.config.batch_size) ) 103 | 104 | iterator = dataset.make_one_shot_iterator() 105 | 106 | input_seqs, tag_seqs, sequence_length = iterator.get_next() 107 | 108 | else: 109 | with 
tf.variable_scope('inf_input'): 110 | #Inference 111 | input_seq_feed = tf.get_default_graph().get_tensor_by_name( 112 | "input_seq_feed:0") 113 | sequence_length = tf.get_default_graph().get_tensor_by_name( 114 | "seq_length:0") 115 | input_seqs = tf.expand_dims(input_seq_feed, 0) 116 | sequence_length = tf.expand_dims(sequence_length, 0) 117 | 118 | tag_seqs = None 119 | 120 | self.input_seqs = input_seqs 121 | self.tag_seqs = tag_seqs 122 | self.sequence_length = sequence_length 123 | 124 | def build_chr_embedding(self): 125 | """ 126 | Build Chinese character embedding 127 | 128 | Returns: 129 | self.seq_embedding: A tensor with the shape of [batch_size, padding_size, embedding_size] 130 | self.tag_embedding: A tensor with the shape of [batch_size, padding_size, num_tag] 131 | """ 132 | with tf.variable_scope( 133 | 'seq_embedding', reuse=True) as seq_embedding_scope: 134 | #chr_embedding = tf.Variable(self.embedding_tensor, name="chr_embedding") 135 | if self.is_training(): 136 | chr_embedding = tf.get_variable( 137 | name="chr_embedding", validate_shape=False, trainable=False) 138 | else: 139 | chr_embedding = tf.Variable( 140 | tf.zeros([10]), validate_shape=False, name="chr_embedding") 141 | 142 | seq_embedding = tf.nn.embedding_lookup(chr_embedding, 143 | self.input_seqs) 144 | if self.is_training(): 145 | tag_embedding = tf.one_hot(self.tag_seqs, self.config.num_tag) 146 | else: 147 | tag_embedding = None 148 | 149 | self.seq_embedding = seq_embedding 150 | self.tag_embedding = tag_embedding 151 | 152 | def build_lstm_model(self): 153 | """ 154 | Build model. 155 | 156 | Returns: 157 | self.logit: A tensor containing the probability of prediction with the shape of [batch_size, padding_size, num_tag] 158 | """ 159 | 160 | #Setup LSTM Cell 161 | fw_lstm_cell = tf.contrib.rnn.BasicLSTMCell( 162 | num_units=self.config.num_lstm_units, state_is_tuple=True) 163 | bw_lstm_cell = tf.contrib.rnn.BasicLSTMCell( 164 | num_units=self.config.num_lstm_units, state_is_tuple=True) 165 | 166 | #Dropout when training 167 | if self.is_training(): 168 | fw_lstm_cell = tf.contrib.rnn.DropoutWrapper( 169 | fw_lstm_cell, 170 | input_keep_prob=self.config.lstm_dropout_keep_prob, 171 | output_keep_prob=self.config.lstm_dropout_keep_prob) 172 | bw_lstm_cell = tf.contrib.rnn.DropoutWrapper( 173 | bw_lstm_cell, 174 | input_keep_prob=self.config.lstm_dropout_keep_prob, 175 | output_keep_prob=self.config.lstm_dropout_keep_prob) 176 | 177 | self.seq_embedding.set_shape([None, None, self.config.embedding_size]) 178 | 179 | with tf.variable_scope('seq_lstm') as lstm_scope: 180 | 181 | #Run LSTM with sequence_length timesteps 182 | bi_output, _ = tf.nn.bidirectional_dynamic_rnn( 183 | cell_fw=fw_lstm_cell, 184 | cell_bw=bw_lstm_cell, 185 | inputs=self.seq_embedding, 186 | sequence_length=self.sequence_length, 187 | dtype=tf.float32, 188 | scope=lstm_scope) 189 | fw_out, bw_out = bi_output 190 | lstm_output = tf.concat([fw_out, bw_out], 2) 191 | 192 | self.lstm_output = lstm_output 193 | 194 | def build_sentence_score_loss(self): 195 | """ 196 | Use CRF log likelihood to get sentence score and loss 197 | """ 198 | #Fully connected layer to get logit 199 | with tf.variable_scope('logit') as logit_scope: 200 | logit = tf.contrib.layers.fully_connected( 201 | inputs=self.lstm_output, 202 | num_outputs=self.config.num_tag, 203 | activation_fn=None, 204 | weights_initializer=self.initializer, 205 | scope=logit_scope) 206 | self.logit = logit 207 | 208 | if self.is_inf(): 209 | with tf.variable_scope('tag_inf') as 
tag_scope: 210 | transition_param = tf.get_variable( 211 | 'transitions', 212 | shape=[self.config.num_tag, self.config.num_tag]) 213 | 214 | self.predict_tag, _ = tf.contrib.crf.crf_decode( 215 | logit, transition_param, self.sequence_length) 216 | 217 | else: 218 | with tf.variable_scope('tag_inf') as tag_scope: 219 | sentence_likelihood, transition_param = tf.contrib.crf.crf_log_likelihood( 220 | inputs=logit, 221 | tag_indices=tf.to_int32(self.tag_seqs), 222 | sequence_lengths=self.sequence_length) 223 | 224 | self.predict_tag, _ = tf.contrib.crf.crf_decode( 225 | logit, transition_param, self.sequence_length) 226 | 227 | with tf.variable_scope('loss'): 228 | batch_loss = tf.reduce_mean(-sentence_likelihood) 229 | 230 | # if self.is_inf(): 231 | # prob = tf.nn.softmax(logit) 232 | # self.predict_tag = tf.squeeze(tf.nn.top_k(prob, k=1)[1]) 233 | 234 | # else: 235 | # with tf.variable_scope('loss'): 236 | # prob = tf.nn.softmax(logit) 237 | # self.predict_tag = tf.squeeze(tf.nn.top_k(prob, k=1)[1]) 238 | 239 | # batch_loss = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=logit, labels=self.tag_seqs) 240 | # batch_loss = tf.reduce_mean(batch_loss) 241 | 242 | #Add to total loss 243 | tf.losses.add_loss(batch_loss) 244 | 245 | #Get total loss 246 | total_loss = tf.losses.get_total_loss() 247 | 248 | tf.summary.scalar('batch_loss', batch_loss) 249 | tf.summary.scalar('total_loss', total_loss) 250 | 251 | with tf.variable_scope('accuracy'): 252 | 253 | seq_len = tf.cast( 254 | tf.reduce_sum(self.sequence_length), tf.float32) 255 | padded_len = tf.cast( 256 | tf.reduce_sum( 257 | self.config.batch_size * self.config.seq_max_len), 258 | tf.float32) 259 | 260 | # Calculate acc 261 | correct = tf.cast( 262 | tf.equal(self.predict_tag, tf.cast(self.tag_seqs, 263 | tf.int32)), tf.float32) 264 | correct = tf.reduce_sum(correct) - padded_len + seq_len 265 | 266 | self.accuracy = correct / seq_len 267 | 268 | if self.is_test(): 269 | 270 | tf.summary.scalar('eval_accuracy', self.accuracy) 271 | else: 272 | tf.summary.scalar('average_len', 273 | tf.reduce_mean(self.sequence_length)) 274 | tf.summary.scalar('train_accuracy', self.accuracy) 275 | 276 | #Output loss 277 | self.batch_loss = batch_loss 278 | self.total_loss = total_loss 279 | 280 | def setup_global_step(self): 281 | """Sets up the global step Tensor.""" 282 | global_step = tf.Variable( 283 | initial_value=0, 284 | name="global_step", 285 | trainable=False, 286 | collections=[ 287 | tf.GraphKeys.GLOBAL_STEP, tf.GraphKeys.GLOBAL_VARIABLES 288 | ]) 289 | 290 | self.global_step = global_step 291 | 292 | def build(self): 293 | """Create all ops for model""" 294 | self.build_inputs() 295 | self.build_chr_embedding() 296 | self.build_lstm_model() 297 | self.build_sentence_score_loss() 298 | self.setup_global_step() 299 | -------------------------------------------------------------------------------- /ops/input_ops.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | #Author: Jay Yip 4 | #Date 04Mar2017 5 | """Batching, padding and masking the input sequence and output sequence""" 6 | 7 | from __future__ import absolute_import 8 | from __future__ import division 9 | from __future__ import print_function 10 | 11 | import tensorflow as tf 12 | import sys 13 | 14 | from hanziconv.hanziconv import HanziConv 15 | 16 | 17 | def parse_example_queue(example_queue, config): 18 | """ Read one example. 
19 | This function read one example and return context sequence and tag sequence 20 | correspondingly. 21 | 22 | Args: 23 | filename_queue: A filename queue returned by string_input_producer 24 | context_feature_name: Context feature name in TFRecord. Set in ModelConfig 25 | tag_feature_name: Tag feature name in TFRecord. Set in ModelConfig 26 | 27 | Returns: 28 | input_seq: An int32 Tensor with different length. 29 | tag_seq: An int32 Tensor with different length. 30 | """ 31 | 32 | #Parse one example 33 | context, features = tf.parse_single_sequence_example( 34 | example_queue, 35 | context_features={ 36 | config.length_name: tf.FixedLenFeature([], dtype=tf.int64) 37 | }, 38 | sequence_features={ 39 | config.context_feature_name: 40 | tf.FixedLenSequenceFeature([], dtype=tf.int64), 41 | config.tag_feature_name: 42 | tf.FixedLenSequenceFeature([], dtype=tf.int64) 43 | }) 44 | 45 | return (features[config.context_feature_name], 46 | features[config.tag_feature_name], context[config.length_name]) 47 | 48 | 49 | def example_queue_shuffle(reader, 50 | filename_queue, 51 | is_training, 52 | example_queue_name='example_queue', 53 | capacity=50000, 54 | num_reader_threads=1): 55 | """ 56 | This function shuffle the examples within the filename queues. Since there's no 57 | padding option in shuffle_batch, we have to manually shuffle the example queue. 58 | 59 | The process is given as below. 60 | create filename queue >> read examples from filename queue >> enqueue example to example queue(RandomShuffleQueue) 61 | 62 | However, this is not totally random shuffle since the memory limiation. Therefore, 63 | we need to specify a capacity of the example queue. 64 | 65 | Args: 66 | reader: A TFRecord Reader 67 | filename_queue: A queue generated by string_input_producer 68 | is_traning: If not training then use FIFOqueue(No need to shuffle). 69 | example_queue_name: Name of the example queue 70 | capacity: Value queue capacity. Should be large enough for better mixing 71 | num_reader_threads: Number of thread to enqueue the value queue 72 | 73 | Returns: 74 | example_queue: An example queue that is shuffled. Ready for parsing and batching. 75 | """ 76 | 77 | #Init queue 78 | if is_training: 79 | example_queue = tf.RandomShuffleQueue( 80 | capacity=capacity, 81 | min_after_dequeue=capacity % 2, 82 | dtypes=[tf.string], 83 | name="random_" + example_queue_name) 84 | else: 85 | example_queue = tf.FIFOQueue( 86 | capacity=capacity, 87 | dtypes=[tf.string], 88 | name="fifo_" + example_queue_name) 89 | 90 | #Manually create ops to enqueue 91 | enqueue_example_ops = [] 92 | for _ in range(num_reader_threads): 93 | _, example = reader.read(filename_queue) 94 | enqueue_example_ops.append(example_queue.enqueue([example])) 95 | 96 | #Add queue runner 97 | tf.train.queue_runner.add_queue_runner( 98 | tf.train.queue_runner.QueueRunner(example_queue, enqueue_example_ops)) 99 | tf.summary.scalar( 100 | "queue/%s/fraction_of_%d_full" % (example_queue.name, capacity), 101 | tf.cast(example_queue.size(), tf.float32) * (1. 
/ capacity)) 102 | 103 | return example_queue 104 | 105 | 106 | def process_line_msr_pku(l): 107 | decoded_line = l.decode('utf8').strip().split(' ') 108 | return [w.strip('\r\n') for w in decoded_line] 109 | 110 | 111 | def process_line_as_training(l): 112 | if sys.version_info >= (3, 0): 113 | decoded_line = HanziConv.toSimplified( 114 | l.decode('utf8')).strip().split('\u3000') 115 | else: 116 | decoded_line = HanziConv.toSimplified( 117 | l.decode('utf8')).strip().split(u'\u3000') 118 | return [w.strip('\r\n') for w in decoded_line] 119 | 120 | 121 | def process_line_cityu(l): 122 | decoded_line = HanziConv.toSimplified(l.decode('utf8')).strip().split(' ') 123 | return [w.strip('\r\n') for w in decoded_line] 124 | 125 | 126 | def get_process_fn(filename): 127 | 128 | if 'msr' in filename or 'pk' in filename: 129 | return process_line_msr_pku 130 | 131 | elif 'as' in filename: 132 | return process_line_as_training 133 | 134 | elif 'cityu' in filename: 135 | return process_line_cityu 136 | -------------------------------------------------------------------------------- /ops/vocab.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | #Author: Jay Yip 4 | #Date 04Mar2017 5 | """Vocab class""" 6 | 7 | 8 | class Vocabulary(object): 9 | """Simple vocabulary wrapper.""" 10 | 11 | def __init__(self, vocab, id_vocab, unk_id, unk_word=''): 12 | """Initializes the vocabulary. 13 | 14 | Args: 15 | vocab: A dictionary of word to word_id. 16 | unk_id: Id of the special 'unknown' word. 17 | """ 18 | self._vocab = vocab 19 | self._id_vocab = id_vocab 20 | self._unk_id = unk_id 21 | self._vocab[unk_word] = len(self._vocab) 22 | self._id_vocab[len(self._vocab)] = unk_word 23 | 24 | def word_to_id(self, word): 25 | """Returns the integer id of a word string.""" 26 | if word in self._vocab: 27 | return self._vocab[word] 28 | else: 29 | return self._unk_id 30 | 31 | def id_to_word(self, word_id): 32 | """Returns the word string of an integer word id.""" 33 | if word_id >= len(self._vocab): 34 | return self._id_vocab[self.unk_id] 35 | else: 36 | return self._id_vocab[word_id] 37 | -------------------------------------------------------------------------------- /process_chr_embedding.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | #Author: Jay Yip 4 | #Date 05Mar2017 5 | """Download and process the Chinese character embedding table""" 6 | 7 | from __future__ import absolute_import 8 | from __future__ import division 9 | from __future__ import print_function 10 | 11 | import urllib.request 12 | import os 13 | import pickle 14 | import numpy as np 15 | 16 | import tensorflow as tf 17 | import configuration 18 | 19 | FLAGS = tf.app.flags.FLAGS 20 | 21 | tf.flags.DEFINE_string("chr_embedding_dir", 'polyglot-zh_char.pkl', 22 | "Path to polyglot embedding file") 23 | tf.flags.DEFINE_string("vocab_dir", "data/vocab.pkl", 24 | "Path of vocabulary file.") 25 | 26 | 27 | class Vocabulary(object): 28 | """Simple vocabulary wrapper.""" 29 | 30 | def __init__(self, vocab, unk_id, unk_word=''): 31 | """Initializes the vocabulary. 32 | 33 | Args: 34 | vocab: A dictionary of word to word_id. 35 | unk_id: Id of the special 'unknown' word. 
36 | """ 37 | self._vocab = vocab 38 | self._unk_id = unk_id 39 | self._vocab[unk_word] = 0 40 | 41 | def word_to_id(self, word): 42 | """Returns the integer id of a word string.""" 43 | if word in self._vocab: 44 | return self._vocab[word] 45 | else: 46 | return self._unk_id 47 | 48 | def id_to_word(self, word_id): 49 | """Returns the word string of an integer word id.""" 50 | if word_id >= len(self._vocab): 51 | return self._vocab[self.unk_id] 52 | else: 53 | return self._vocab[word_id] 54 | 55 | 56 | def download_embedding(): 57 | """ 58 | Download files from web 59 | Seems cannot download by pgm 60 | Download from: https://sites.google.com/site/rmyeid/projects/polyglot 61 | 62 | Returns: 63 | A tuple (word, embedding). Emebddings shape is (100004, 64). 64 | """ 65 | 66 | assert (tf.gfile.Exists(FLAGS.chr_embedding_dir)), ( 67 | "Embedding pkl don't found, please \ 68 | download the Chinese chr embedding from https://sites.google.com/site/rmyeid/projects/polyglot" 69 | ) 70 | 71 | with open(FLAGS.chr_embedding_dir, 'rb') as f: 72 | u = pickle._Unpickler(f) 73 | u.encoding = 'latin1' 74 | p = u.load() 75 | 76 | return p 77 | 78 | 79 | def process_embedding(vocab, original_embedding, config): 80 | """ 81 | This function will process the embedding. The embedding table will be organized with 82 | the same order as the word_count. Any unknown features will be abandomed. 83 | 84 | Args: 85 | vocab: Vocabulary obj generated by build input 86 | original_embedding: A tuple (word, embedding). Emebddings shape is (100004, 64). 87 | 88 | Returns: 89 | embedding_table: A numpy 2d array. Will be feed to embedding_placeholder when graph execution 90 | """ 91 | 92 | #Init 2d numpy array 93 | embedding_table = np.zeros((len(vocab._vocab), config.embedding_size)) 94 | 95 | word, embedding = original_embedding 96 | 97 | for i, w in enumerate(word): 98 | embedding_table[vocab.word_to_id(w), :] = embedding[i, :] 99 | 100 | #Manually set the last row of embedding(unknown chr) 101 | embedding_table[0, :] = embedding[0, :] 102 | 103 | return embedding_table 104 | 105 | 106 | def main(unused_argv): 107 | 108 | #Load configuration 109 | model_config = configuration.ModelConfig() 110 | 111 | #Load vocabulary object 112 | vocab = pickle.load(open(FLAGS.vocab_dir, 'rb')) 113 | 114 | original_embedding = download_embedding() 115 | 116 | chr_embedding = process_embedding(vocab, original_embedding, model_config) 117 | 118 | pickle.dump(chr_embedding, open('chr_embedding.pkl', 'wb')) 119 | 120 | 121 | if __name__ == '__main__': 122 | tf.app.run() 123 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | #Author: Jay Yip 4 | #Date 04Mar2017 5 | """Train the model""" 6 | 7 | from __future__ import absolute_import 8 | from __future__ import division 9 | from __future__ import print_function 10 | 11 | import tensorflow as tf 12 | 13 | import configuration 14 | from lstm_based_cws_model import LSTMCWS 15 | from ops.vocab import Vocabulary 16 | 17 | import pickle 18 | 19 | FLAGS = tf.app.flags.FLAGS 20 | 21 | tf.flags.DEFINE_string("input_file_dir", "data", 22 | "Path of TFRecord input files.") 23 | tf.flags.DEFINE_string("train_dir", "save_model", 24 | "Directory for saving and loading model checkpoints.") 25 | tf.flags.DEFINE_integer("log_every_n_steps", 5000, 26 | "Frequency at which loss and global step are logged.") 27 | tf.flags.DEFINE_string("log_dir", "log", 
"Path of summary") 28 | 29 | tf.logging.set_verbosity(tf.logging.INFO) 30 | 31 | 32 | def main(unused_argv): 33 | 34 | assert FLAGS.input_file_dir, "--input_file_dir is required" 35 | assert FLAGS.train_dir, "--train_dir is required" 36 | 37 | #Load configuration 38 | model_config = configuration.ModelConfig() 39 | train_config = configuration.TrainingConfig() 40 | model_config.train_dir = FLAGS.train_dir 41 | model_config.input_file_dir = FLAGS.input_file_dir 42 | 43 | #Create train dir 44 | train_dir = FLAGS.train_dir 45 | if not tf.gfile.IsDirectory(train_dir): 46 | tf.logging.info('Create Training dir as %s', train_dir) 47 | tf.gfile.MakeDirs(train_dir) 48 | 49 | #Load chr emdedding table 50 | if train_config.embedding_random: 51 | shape = [ 52 | len(pickle.load(open('data/vocab.pkl', 'rb'))._vocab), 53 | model_config.embedding_size 54 | ] 55 | else: 56 | chr_embedding = pickle.load(open('chr_embedding.pkl', 'rb')) 57 | shape = chr_embedding.shape 58 | 59 | #Build graph 60 | g = tf.Graph() 61 | with g.as_default(): 62 | #Set embedding table 63 | with tf.variable_scope('seq_embedding') as seq_embedding_scope: 64 | chr_embedding_var = tf.get_variable( 65 | name='chr_embedding', 66 | shape=(shape[0], shape[1]), 67 | trainable=True, 68 | initializer=tf.initializers.orthogonal(-0.1, 0.1)) 69 | if not train_config.embedding_random: 70 | embedding = tf.convert_to_tensor( 71 | chr_embedding, dtype=tf.float32) 72 | embedding_assign_op = chr_embedding_var.assign(chr_embedding) 73 | 74 | #Build model 75 | model = LSTMCWS(model_config, 'train') 76 | print('Building model...') 77 | model.build() 78 | 79 | # merged = tf.summary.merge_all() 80 | # train_writer = tf.summary.FileWriter(FLAGS.logdir + '/train', 81 | # g) 82 | 83 | #Set up learning rate and learning rate decay function 84 | learning_rate_decay_fn = None 85 | learning_rate = tf.constant(train_config.initial_learning_rate) 86 | if train_config.learning_rate_decay_factor > 0: 87 | num_batches_per_epoch = ( 88 | train_config.num_examples_per_epoch / model_config.batch_size) 89 | decay_steps = int( 90 | num_batches_per_epoch * train_config.num_epochs_per_decay) 91 | 92 | def _learning_rate_decay_fn(learning_rate, global_step): 93 | return tf.train.exponential_decay( 94 | learning_rate, 95 | global_step, 96 | decay_steps=decay_steps, 97 | decay_rate=train_config.learning_rate_decay_factor, 98 | staircase=True) 99 | 100 | learning_rate_decay_fn = _learning_rate_decay_fn 101 | 102 | print('Setting up training ops...') 103 | #Set up training op 104 | train_op = tf.contrib.layers.optimize_loss( 105 | loss=model.batch_loss, 106 | global_step=model.global_step, 107 | learning_rate=learning_rate, 108 | optimizer=train_config.optimizer, 109 | clip_gradients=train_config.clip_gradients, 110 | learning_rate_decay_fn=learning_rate_decay_fn, 111 | name='train_op') 112 | 113 | #Set up saver 114 | saver = tf.train.Saver(max_to_keep=train_config.max_checkpoints_to_keep) 115 | 116 | gpu_options = tf.GPUOptions( 117 | visible_device_list=",".join(map(str, [0])), 118 | per_process_gpu_memory_fraction=0.33) 119 | 120 | sess_config = tf.ConfigProto(gpu_options=gpu_options) 121 | 122 | print('Start Training...') 123 | # Run training. 
124 | tf.contrib.slim.learning.train( 125 | train_op, 126 | train_dir, 127 | log_every_n_steps=FLAGS.log_every_n_steps, 128 | graph=g, 129 | global_step=model.global_step, 130 | number_of_steps=train_config.training_step, 131 | saver=saver, 132 | save_summaries_secs=30, 133 | session_config=sess_config) 134 | 135 | 136 | if __name__ == '__main__': 137 | tf.app.run() 138 | -------------------------------------------------------------------------------- /word_count: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JayYip/cws-tensorflow/dd6495cddc1fed99dae837b51daa056f1f281218/word_count --------------------------------------------------------------------------------
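For reference, a back-of-the-envelope check of the staircase learning-rate schedule that `_learning_rate_decay_fn` in `train.py` sets up (a sketch assuming the default `TrainingConfig` values above are left unchanged):

```
# Assumed defaults from configuration.py: 500000 examples per epoch, batch_size 32,
# decay factor 0.5 every 3 epochs, initial learning rate 0.01, staircase decay.
num_batches_per_epoch = 500000 / 32
decay_steps = int(num_batches_per_epoch * 3)   # 46875

def learning_rate(global_step, initial=0.01, decay=0.5):
    # same value tf.train.exponential_decay(..., staircase=True) would return
    return initial * decay ** (global_step // decay_steps)

print(decay_steps)            # 46875
print(learning_rate(0))       # 0.01
print(learning_rate(100000))  # 0.0025 (decayed twice)
```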