├── tools ├── __init__.py ├── total_size.py ├── word_count.py ├── make_dicts.py ├── convert_to_h5.py ├── predict.py ├── score_preprocess.py ├── score.py └── ner_data_preprocess.py ├── dl_segmenter ├── custom │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-36.pyc │ │ ├── layers.cpython-36.pyc │ │ └── callbacks.cpython-36.pyc │ └── callbacks.py ├── __pycache__ │ ├── core.cpython-36.pyc │ ├── utils.cpython-36.pyc │ ├── __init__.cpython-36.pyc │ └── data_loader.cpython-36.pyc ├── __init__.py ├── utils.py ├── data_loader.py └── core.py ├── assets ├── loss.png ├── Bi-LSTM.png └── accuracy.png ├── setup.py ├── examples ├── decode_example.py └── train_example.py ├── .gitignore ├── train_example.py ├── README.md └── LICENSE.txt /tools/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /dl_segmenter/custom/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /assets/loss.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GlassyWing/bi-lstm-crf/HEAD/assets/loss.png -------------------------------------------------------------------------------- /assets/Bi-LSTM.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GlassyWing/bi-lstm-crf/HEAD/assets/Bi-LSTM.png -------------------------------------------------------------------------------- /assets/accuracy.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GlassyWing/bi-lstm-crf/HEAD/assets/accuracy.png -------------------------------------------------------------------------------- /dl_segmenter/__pycache__/core.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GlassyWing/bi-lstm-crf/HEAD/dl_segmenter/__pycache__/core.cpython-36.pyc -------------------------------------------------------------------------------- /dl_segmenter/__pycache__/utils.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GlassyWing/bi-lstm-crf/HEAD/dl_segmenter/__pycache__/utils.cpython-36.pyc -------------------------------------------------------------------------------- /dl_segmenter/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GlassyWing/bi-lstm-crf/HEAD/dl_segmenter/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /dl_segmenter/__pycache__/data_loader.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GlassyWing/bi-lstm-crf/HEAD/dl_segmenter/__pycache__/data_loader.cpython-36.pyc -------------------------------------------------------------------------------- /dl_segmenter/custom/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GlassyWing/bi-lstm-crf/HEAD/dl_segmenter/custom/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- 
/dl_segmenter/custom/__pycache__/layers.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GlassyWing/bi-lstm-crf/HEAD/dl_segmenter/custom/__pycache__/layers.cpython-36.pyc -------------------------------------------------------------------------------- /dl_segmenter/custom/__pycache__/callbacks.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GlassyWing/bi-lstm-crf/HEAD/dl_segmenter/custom/__pycache__/callbacks.cpython-36.pyc -------------------------------------------------------------------------------- /dl_segmenter/__init__.py: -------------------------------------------------------------------------------- 1 | from dl_segmenter.core import DLSegmenter 2 | import json 3 | 4 | get_or_create = DLSegmenter.get_or_create 5 | 6 | 7 | def save_config(obj, config_path, encoding="utf-8"): 8 | with open(config_path, mode="w+", encoding=encoding) as file: 9 | json.dump(obj.get_config(), file) 10 | -------------------------------------------------------------------------------- /tools/total_size.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | if __name__ == '__main__': 4 | count = 0 5 | for root, dirs, files in os.walk("../data/2014/valid"): 6 | for name in files: 7 | file = os.path.join(root, name) 8 | with open(file, encoding='utf-8') as f: 9 | count += len(f.readlines()) 10 | 11 | print(count) -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | 3 | setup(name='dl_segmenter', 4 | 5 | version='0.1-SNAPSHOT', 6 | 7 | url='https://github.com/GlassyWing/bi-lstm-crf', 8 | 9 | license='Apache License 2.0', 10 | 11 | author='Manlier', 12 | 13 | author_email='dengjiaxim@gmail.com', 14 | 15 | description='inset pest predict model', 16 | 17 | packages=find_packages(exclude=['tests', 'examples']), 18 | 19 | package_data={'dl_segmenter': ['*.*', 'checkpoints/*', 'config/*']}, 20 | 21 | long_description=open('README.md', encoding="utf-8").read(), 22 | 23 | zip_safe=False, 24 | 25 | install_requires=['keras', 'keras-contrib'], 26 | 27 | ) 28 | -------------------------------------------------------------------------------- /tools/word_count.py: -------------------------------------------------------------------------------- 1 | from rx import Observable 2 | 3 | 4 | def word_counts(lines): 5 | word_count = {} 6 | for line in lines: 7 | words = line.split() 8 | for word in words: 9 | if word_count.get(word) is None: 10 | word_count[word] = 1 11 | else: 12 | word_count[word] += 1 13 | return word_count 14 | 15 | 16 | def save_to_file(source_file, target_file): 17 | with open(source_file, "r", encoding="UTF-8") as f: 18 | lines = f.readlines() 19 | 20 | word_count = word_counts(lines) 21 | with open(target_file, "a", encoding="UTF-8") as f: 22 | for w, c in word_count.items(): 23 | f.write("{} {}\n".format(w, c)) 24 | f.flush() 25 | 26 | 27 | save_to_file("./score/gold_full.utf8", "./score/jieba.dict") 28 | -------------------------------------------------------------------------------- /tools/make_dicts.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | from dl_segmenter.utils import make_dictionaries 4 | 5 | if __name__ == '__main__': 6 | parser = 
argparse.ArgumentParser(description="生成字典。") 7 | parser.add_argument("file_path", type=str, help="用于生成字典的标注文件") 8 | parser.add_argument("-s", "--src_dict_path", type=str, help="源字典保存路径") 9 | parser.add_argument("-t", "--tgt_dict_path", type=str, help="目标字典保存路径") 10 | parser.add_argument("--min_freq", type=int, default=1, help="词频数阈值,小于该阈值的词将被忽略") 11 | 12 | args = parser.parse_args() 13 | 14 | make_dictionaries(args.file_path, 15 | src_dict_path=args.src_dict_path, 16 | tgt_dict_path=args.tgt_dict_path, 17 | filters="\t\n", 18 | oov_token="", 19 | min_freq=args.min_freq) 20 | -------------------------------------------------------------------------------- /tools/convert_to_h5.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | from dl_segmenter.data_loader import DataLoader 4 | 5 | if __name__ == '__main__': 6 | parser = argparse.ArgumentParser(description="转换为hdf5格式") 7 | parser.add_argument("txt_path", type=str, help="BIS标注文本文件路径") 8 | parser.add_argument("h5_path", type=str, help="转换后的hdf5文件保存路径") 9 | parser.add_argument("-s", "--src_dict_path", type=str, help="源字典保存路径", required=True) 10 | parser.add_argument("-t", "--tgt_dict_path", type=str, help="目标字典保存路径", required=True) 11 | parser.add_argument("--seq_len", help="语句长度", default=150, type=int) 12 | 13 | args = parser.parse_args() 14 | 15 | data_loader = DataLoader(args.src_dict_path, args.tgt_dict_path, 16 | batch_size=1, 17 | max_len=args.seq_len, 18 | sparse_target=False) 19 | 20 | data_loader.load_and_dump_to_h5(args.txt_path, args.h5_path, encoding='utf-8') 21 | -------------------------------------------------------------------------------- /tools/predict.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | from dl_segmenter import get_or_create 4 | 5 | if __name__ == '__main__': 6 | parser = argparse.ArgumentParser(description="执行命令行分词") 7 | parser.add_argument("-s", "--text", help="要进行分割的语句") 8 | parser.add_argument("-f", "--file", help="要进行分割的文件。", default="../data/restore.utf8") 9 | parser.add_argument("-o", "--out_file", help="分割完成后输出的文件。", default="../data/pred_text.utf8") 10 | 11 | args = parser.parse_args() 12 | 13 | tokenizer = get_or_create("../data/default-config.json", 14 | src_dict_path="../data/src_dict.json", 15 | tgt_dict_path="../data/tgt_dict.json", 16 | weights_path="../models/weights.32--0.18.h5") 17 | 18 | text = args.text 19 | file = args.file 20 | out_file = args.out_file 21 | 22 | texts = [] 23 | if text is not None: 24 | texts = text.split(' ') 25 | results = tokenizer.decode_texts(texts) 26 | print(results) 27 | 28 | elif file is not None: 29 | with open(file, encoding='utf-8') as f: 30 | texts = list(map(lambda x: x[0:-1], f.readlines())) 31 | 32 | if out_file is not None: 33 | with open(out_file, mode="w+", encoding="utf-8") as f: 34 | for text in texts: 35 | seq, tag = tokenizer.decode_texts([text])[0] 36 | f.write(' '.join(seq) + '\n') 37 | 38 | -------------------------------------------------------------------------------- /examples/decode_example.py: -------------------------------------------------------------------------------- 1 | import time 2 | import logging 3 | from dl_segmenter import get_or_create, DLSegmenter 4 | 5 | if __name__ == '__main__': 6 | LOG_FORMAT = "%(asctime)s - %(levelname)s - %(message)s" 7 | logging.basicConfig(level=logging.INFO, format=LOG_FORMAT) 8 | segmenter: DLSegmenter = get_or_create("../data/default-config.json", 9 | 
src_dict_path="../data/src_dict.json", 10 | tgt_dict_path="../data/tgt_dict.json", 11 | weights_path="../models/weights.32--0.18.h5") 12 | 13 | texts = [ 14 | "昨晚,英国首相特里萨•梅(TheresaMay)试图挽救其退欧协议的努力,在布鲁塞尔遭遇了严重麻烦。" 15 | "倍感失望的欧盟领导人们指责她没有拿出可行的提案来向充满敌意的英国议会兜售她的退欧计划。" 16 | , 17 | "物理仿真引擎的作用,是让虚拟世界中的物体运动符合真实世界的物理定律,经常用于游戏领域,以便让画面看起来更富有真实感。" 18 | "PhysX是由英伟达提出的物理仿真引擎,其物理模拟计算由专门加速芯片GPU来进行处理," 19 | "在节省CPU负担的同时还能将物理运算效能成倍提升,由此带来更加符合真实世界的物理效果。" 20 | , 21 | "好莱坞女演员奥黛丽·赫本(AudreyHepburn)被称为“坠入人间的天使”," 22 | "主演了《蒂凡尼的早餐》《龙凤配》《罗马假日》等经典影片,并以《罗马假日》获封奥斯卡影后。" 23 | "据外媒报道,奥黛丽·赫本的故事将被拍成一部剧集。" 24 | , 25 | "巴纳德星的名字起源于一百多年前一位名叫爱德华·爱默生·巴纳德的天文学家。" 26 | "他发现有一颗星在夜空中划过的速度很快,这引起了他极大的注意。" 27 | , 28 | "叶依姆的家位于仓山区池后弄6号,属于烟台山历史风貌区," 29 | "一家三代五口人挤在五六十平方米土木结构的公房里,屋顶逢雨必漏,居住环境不好。" 30 | "2013年11月,烟台山历史风貌区地块房屋征收工作启动,叶依姆的梦想正逐渐变为现实。" 31 | , 32 | "人民网北京1月2日电据中央纪委监察部网站消息,日前,经中共中央批准," 33 | "中共中央纪委对湖南省政协原副主席童名谦严重违纪违法问题进行了立案检查。" 34 | ] 35 | 36 | for _ in range(1): 37 | start_time = time.time() 38 | for sent, tag in segmenter.decode_texts(texts): 39 | print(sent) 40 | print(tag) 41 | # for s, t in zip(sent, tag): 42 | # print(s, t) 43 | print(f"cost {(time.time() - start_time) * 1000}ms") 44 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .idea/ 2 | data/ 3 | models/ 4 | tests/ 5 | logs/ 6 | config/ 7 | 8 | # Byte-compiled / optimized / DLL files 9 | __pycache__/ 10 | *.py[cod] 11 | *$py.class 12 | 13 | # C extensions 14 | *.so 15 | 16 | # Distribution / packaging 17 | .Python 18 | build/ 19 | develop-eggs/ 20 | dist/ 21 | downloads/ 22 | eggs/ 23 | .eggs/ 24 | lib/ 25 | lib64/ 26 | parts/ 27 | sdist/ 28 | var/ 29 | wheels/ 30 | *.egg-info/ 31 | .installed.cfg 32 | *.egg 33 | MANIFEST 34 | 35 | # PyInstaller 36 | # Usually these files are written by a python script from a template 37 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
38 | *.manifest 39 | *.spec 40 | 41 | # Installer logs 42 | pip-log.txt 43 | pip-delete-this-directory.txt 44 | 45 | # Unit test / coverage reports 46 | htmlcov/ 47 | .tox/ 48 | .coverage 49 | .coverage.* 50 | .cache 51 | nosetests.xml 52 | coverage.xml 53 | *.cover 54 | .hypothesis/ 55 | .pytest_cache/ 56 | 57 | # Translations 58 | *.mo 59 | *.pot 60 | 61 | # Django stuff: 62 | *.log 63 | local_settings.py 64 | db.sqlite3 65 | 66 | # Flask stuff: 67 | instance/ 68 | .webassets-cache 69 | 70 | # Scrapy stuff: 71 | .scrapy 72 | 73 | # Sphinx documentation 74 | docs/_build/ 75 | 76 | # PyBuilder 77 | target/ 78 | 79 | # Jupyter Notebook 80 | .ipynb_checkpoints 81 | 82 | # pyenv 83 | .python-version 84 | 85 | # celery beat schedule file 86 | celerybeat-schedule 87 | 88 | # SageMath parsed files 89 | *.sage.py 90 | 91 | # Environments 92 | .env 93 | .venv 94 | env/ 95 | venv/ 96 | ENV/ 97 | env.bak/ 98 | venv.bak/ 99 | 100 | # Spyder project settings 101 | .spyderproject 102 | .spyproject 103 | 104 | # Rope project settings 105 | .ropeproject 106 | 107 | # mkdocs documentation 108 | /site 109 | 110 | # mypy 111 | .mypy_cache/ 112 | -------------------------------------------------------------------------------- /tools/score_preprocess.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import re 4 | 5 | 6 | def process_file(file): 7 | with open(file, 'r', encoding='UTF-8') as f: 8 | text = f.readlines() 9 | bises = _parse_text(text) 10 | return bises 11 | 12 | 13 | def _parse_text(text: list): 14 | bises = [] 15 | for line in text: 16 | line, _ = re.subn('\n', '', line) 17 | if line == '' or line == '\n': 18 | continue 19 | bises.append(_tag(line)) 20 | return bises 21 | 22 | 23 | def _tag(line): 24 | """ 25 | 给指定的一行文本打上BIS标签 26 | :param line: 文本行 27 | :return: 28 | """ 29 | bis = [] 30 | words = re.split('\s+', line) 31 | pre_word = None 32 | pos_t = None 33 | for word in words: 34 | tokens = word.split('/') 35 | 36 | if len(tokens) == 2: 37 | word, pos = tokens 38 | elif len(tokens) == 3: 39 | word, pos_t, pos = tokens 40 | else: 41 | continue 42 | 43 | if len(word) == 0 or word.strip() == '': 44 | continue 45 | 46 | if word[0] == '[': 47 | pre_word = word 48 | continue 49 | if pre_word is not None: 50 | pre_word += word 51 | if pos_t is None: 52 | continue 53 | elif pos_t[-1] != ']': 54 | pos_t = None 55 | continue 56 | else: 57 | word = pre_word[1:] 58 | pre_word = None 59 | pos_t = None 60 | bis.append((word, pos)) 61 | 62 | return bis 63 | 64 | 65 | def remove_pos(source_dir, target_path): 66 | for root, dirs, files in os.walk(source_dir): 67 | for name in files: 68 | file = os.path.join(root, name) 69 | bises = process_file(file) 70 | 71 | with open(target_path, encoding="utf-8", mode="a") as f: 72 | for bis in bises: 73 | sent, tags = [], [] 74 | for char, tag in bis: 75 | sent.append(char) 76 | tags.append(tag) 77 | sent = ' '.join(sent) 78 | f.write(sent + "\n") 79 | 80 | 81 | def restore(source_dir, target_path): 82 | for root, dirs, files in os.walk(source_dir): 83 | for name in files: 84 | file = os.path.join(root, name) 85 | bises = process_file(file) 86 | with open(target_path, encoding="utf-8", mode="a") as f: 87 | for bis in bises: 88 | sent, tags = [], [] 89 | for char, tag in bis: 90 | sent.append(char) 91 | tags.append(tag) 92 | sent = ''.join(sent) 93 | f.write(sent + "\n") 94 | 95 | 96 | if __name__ == '__main__': 97 | parse = argparse.ArgumentParser(description="根据指定的语料生成黄金标准文件与其相应的无分词标记的原始文件") 98 | 
parse.add_argument("--corups_dir", help="语料文件夹", default="../data/2014/") 99 | parse.add_argument("--gold_file_path", help="生成的黄金标准文件路径", default="../data/gold.utf8") 100 | parse.add_argument("--restore_file_path", help="生成无标记的原始文件路径", default="../data/restore.utf8") 101 | 102 | args = parse.parse_args() 103 | corups_dir = args.corups_dir 104 | gold_file_path = args.gold_file_path 105 | restore_file_path = args.restore_file_path 106 | 107 | print("Processing...") 108 | remove_pos(corups_dir, gold_file_path) 109 | restore(corups_dir, restore_file_path) 110 | print("Process done.") 111 | -------------------------------------------------------------------------------- /examples/train_example.py: -------------------------------------------------------------------------------- 1 | from keras.callbacks import ModelCheckpoint, TensorBoard 2 | from keras.optimizers import Adam 3 | 4 | from dl_segmenter import get_or_create, save_config 5 | from dl_segmenter.custom.callbacks import LRFinder, SGDRScheduler, WatchScheduler 6 | from dl_segmenter.data_loader import DataLoader 7 | 8 | if __name__ == '__main__': 9 | h5_dataset_path = "../data/2014_processed.h5" # 转换为hdf5格式的数据集 10 | config_save_path = "../data/default-config.json" # 模型配置路径 11 | weights_save_path = "../models/weights.{epoch:02d}-{val_loss:.2f}.h5" # 模型权重保存路径 12 | init_weights_path = "../models/weights.23-0.02.sgdr.h5" # 预训练模型权重文件路径 13 | embedding_file_path = "G:\data\word-vectors\word.embdding.iter5" # 词向量文件路径,若不使用设为None 14 | embedding_file_path = None # 词向量文件路径,若不使用设为None 15 | 16 | src_dict_path = "../data/src_dict.json" # 源字典路径 17 | tgt_dict_path = "../data/tgt_dict.json" # 目标字典路径 18 | batch_size = 32 19 | epochs = 32 20 | 21 | import os 22 | 23 | # GPU 下用于选择训练的GPU 24 | os.environ["CUDA_VISIBLE_DEVICES"] = "3" 25 | 26 | data_loader = DataLoader(src_dict_path=src_dict_path, 27 | tgt_dict_path=tgt_dict_path, 28 | batch_size=batch_size) 29 | 30 | # steps_per_epoch = 415030 // data_loader.batch_size 31 | # validation_steps = 20379 // data_loader.batch_size 32 | 33 | steps_per_epoch = 2000 34 | validation_steps = 20 35 | 36 | config = { 37 | "vocab_size": data_loader.src_vocab_size, 38 | "chunk_size": data_loader.tgt_vocab_size, 39 | "embed_dim": 300, 40 | "bi_lstm_units": 256, 41 | "max_num_words": 20000, 42 | "dropout_rate": 0.1 43 | } 44 | 45 | tokenizer = get_or_create(config, 46 | optimizer=Adam(), 47 | embedding_file=embedding_file_path, 48 | src_dict_path=src_dict_path, 49 | weights_path=init_weights_path) 50 | 51 | save_config(tokenizer, config_save_path) 52 | 53 | # tokenizer.model.summary() 54 | 55 | ck = ModelCheckpoint(weights_save_path, 56 | save_best_only=True, 57 | save_weights_only=True, 58 | monitor='val_loss', 59 | verbose=0) 60 | log = TensorBoard(log_dir='../logs', 61 | histogram_freq=0, 62 | batch_size=data_loader.batch_size, 63 | write_graph=True, 64 | write_grads=False) 65 | 66 | # Use LRFinder to find effective learning rate 67 | lr_finder = LRFinder(1e-6, 1e-2, steps_per_epoch, epochs=1) # => (2e-4, 3e-4) 68 | lr_scheduler = WatchScheduler(lambda _, lr: lr / 2, min_lr=2e-4, max_lr=4e-4, watch="val_loss", watch_his_len=2) 69 | lr_scheduler = SGDRScheduler(min_lr=4e-5, max_lr=1e-3, steps_per_epoch=steps_per_epoch, 70 | cycle_length=15, 71 | lr_decay=0.9, 72 | mult_factor=1.2) 73 | 74 | X_train, Y_train, X_valid, Y_valid = DataLoader.load_data(h5_dataset_path, frac=0.8) 75 | 76 | tokenizer.model.fit_generator(data_loader.generator_from_data(X_train, Y_train), 77 | epochs=1, 78 | steps_per_epoch=steps_per_epoch, 79 | 
validation_data=data_loader.generator_from_data(X_valid, Y_valid), 80 | validation_steps=validation_steps, 81 | callbacks=[ck, log, lr_finder]) 82 | 83 | lr_finder.plot_loss() 84 | -------------------------------------------------------------------------------- /dl_segmenter/utils.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import numpy as np 4 | from keras_preprocessing.text import Tokenizer 5 | from keras_preprocessing.text import tokenizer_from_json 6 | 7 | 8 | def _parse_data(fh, word_delimiter=' ', sent_delimiter='\t'): 9 | text = fh.readlines() 10 | sent, chunk = [], [], 11 | for line in text: 12 | line = line[0:-1] 13 | chars, tags = line.split(sent_delimiter) 14 | sent.append(chars.split(word_delimiter)) 15 | chunk.append(tags.split(word_delimiter)) 16 | return sent, chunk 17 | 18 | 19 | def _parse_data_from_dir(file_dir, word_delimiter=' ', sent_delimiter='\t'): 20 | all_sent, all_chunk = [], [] 21 | for root, dirs, files in os.walk(file_dir): 22 | for name in files: 23 | file = os.path.join(root, name) 24 | sent, chunk = _parse_data(open(file, encoding="utf-8"), word_delimiter, sent_delimiter) 25 | all_sent.extend(sent) 26 | all_chunk.extend(chunk) 27 | return all_sent, all_chunk 28 | 29 | 30 | def save_dictionary(tokenizer, dict_path, encoding="utf-8"): 31 | with open(dict_path, mode="w+", encoding=encoding) as file: 32 | json.dump(tokenizer.to_json(), file) 33 | 34 | 35 | def load_dictionary(dict_path, encoding="utf-8"): 36 | with open(dict_path, mode="r", encoding=encoding) as file: 37 | return tokenizer_from_json(json.load(file)) 38 | 39 | 40 | def load_dictionaries(src_dict_path, tgt_dict_path, encoding="utf-8"): 41 | return load_dictionary(src_dict_path, encoding), load_dictionary(tgt_dict_path, encoding) 42 | 43 | 44 | def make_dictionaries(file_path, 45 | src_dict_path=None, 46 | tgt_dict_path=None, 47 | encoding="utf-8", 48 | min_freq=5, 49 | **kwargs): 50 | if not os.path.isdir(file_path): 51 | 52 | sents, chunks = _parse_data(open(file_path, 'r', encoding=encoding)) 53 | else: 54 | sents, chunks = _parse_data_from_dir(file_path) 55 | 56 | src_tokenizer = Tokenizer(**kwargs) 57 | tgt_tokenizer = Tokenizer(**kwargs) 58 | 59 | src_tokenizer.fit_on_texts(sents) 60 | tgt_tokenizer.fit_on_texts(chunks) 61 | 62 | src_sub = sum(map(lambda x: x[1] < min_freq, src_tokenizer.word_counts.items())) 63 | tgt_sub = sum(map(lambda x: x[1] < min_freq, tgt_tokenizer.word_counts.items())) 64 | 65 | src_tokenizer.num_words = len(src_tokenizer.word_index) - src_sub 66 | tgt_tokenizer.num_words = len(tgt_tokenizer.word_index) - tgt_sub 67 | 68 | if src_dict_path is not None: 69 | save_dictionary(src_tokenizer, src_dict_path, encoding=encoding) 70 | if tgt_dict_path is not None: 71 | save_dictionary(tgt_tokenizer, tgt_dict_path, encoding=encoding) 72 | 73 | return src_tokenizer, tgt_tokenizer 74 | 75 | 76 | def get_embedding_index(embedding_file): 77 | embedding_index = {} 78 | with open(os.path.join(embedding_file), encoding='UTF-8') as f: 79 | for line in f: 80 | values = line.split() 81 | word = values[0] 82 | coefs = np.asarray(values[1:], dtype=np.float32) 83 | embedding_index[word] = coefs 84 | return embedding_index 85 | 86 | 87 | def create_embedding_matrix(embeddings_index, word_index, vocab_size, embed_dim): 88 | embedding_matrix = np.zeros((vocab_size, embed_dim)) 89 | for word, i in word_index.items(): 90 | if i >= vocab_size: 91 | continue 92 | embedding_vector = embeddings_index.get(word) 93 | if 
embedding_vector is not None: 94 | # words not found in embedding index will be all-zeros. 95 | embedding_matrix[i] = embedding_vector 96 | return embedding_matrix 97 | -------------------------------------------------------------------------------- /train_example.py: -------------------------------------------------------------------------------- 1 | from keras.callbacks import ModelCheckpoint, TensorBoard 2 | from keras.optimizers import Adam 3 | 4 | from dl_segmenter import get_or_create, save_config 5 | from dl_segmenter.custom.callbacks import LRFinder, SGDRScheduler, WatchScheduler, SingleModelCK, LRSchedulerPerStep 6 | from dl_segmenter.data_loader import DataLoader 7 | import matplotlib.pyplot as plt 8 | import os 9 | 10 | if __name__ == '__main__': 11 | h5_dataset_path = "data/2014_processed.h5" # 转换为hdf5格式的数据集 12 | config_save_path = "config/default-config.json" # 模型配置路径 13 | weights_save_path = "models/weights.{epoch:02d}-{val_loss:.2f}.h5" # 模型权重保存路径 14 | init_weights_path = "models/weights.23-0.02.sgdr.h5" # 预训练模型权重文件路径 15 | embedding_file_path = "G:\data\word-vectors\word.embdding.iter5" # 词向量文件路径,若不使用设为None 16 | embedding_file_path = None # 词向量文件路径,若不使用设为None 17 | 18 | src_dict_path = "config/src_dict.json" # 源字典路径 19 | tgt_dict_path = "config/tgt_dict.json" # 目标字典路径 20 | batch_size = 32 21 | epochs = 128 22 | num_gpu = 1 23 | max_seq_len = 150 24 | initial_epoch = 0 25 | 26 | # GPU 下用于选择训练的GPU 27 | os.environ["CUDA_VISIBLE_DEVICES"] = "0" 28 | 29 | steps_per_epoch = 2000 30 | validation_steps = 20 31 | 32 | data_loader = DataLoader(src_dict_path=src_dict_path, 33 | tgt_dict_path=tgt_dict_path, 34 | batch_size=batch_size, 35 | max_len=max_seq_len, 36 | shuffle_batch=steps_per_epoch, 37 | sparse_target=False) 38 | 39 | config = { 40 | "vocab_size": data_loader.src_vocab_size, 41 | "chunk_size": data_loader.tgt_vocab_size, 42 | "sparse_target": data_loader.sparse_target, 43 | "embed_dim": 300, 44 | "bi_lstm_units": 256, 45 | } 46 | 47 | os.makedirs(os.path.dirname(weights_save_path), exist_ok=True) 48 | 49 | segmenter = get_or_create(config, 50 | optimizer=Adam(), 51 | embedding_file=embedding_file_path, 52 | src_dict_path=src_dict_path, 53 | weights_path=init_weights_path) 54 | 55 | save_config(segmenter, config_save_path) 56 | 57 | segmenter.model.summary() 58 | 59 | ck = SingleModelCK(weights_save_path, 60 | segmenter.model, 61 | save_best_only=True, 62 | save_weights_only=True, 63 | monitor='val_loss', 64 | verbose=0) 65 | log = TensorBoard(log_dir='logs', 66 | histogram_freq=0, 67 | batch_size=data_loader.batch_size, 68 | write_graph=True, 69 | write_grads=False) 70 | 71 | # Use LRFinder to find effective learning rate 72 | lr_finder = LRFinder(1e-6, 1e-2, steps_per_epoch, epochs=1) # => (1e-4, 1e-3) 73 | lr_scheduler = SGDRScheduler(min_lr=1e-4, max_lr=1e-3, 74 | initial_epoch=initial_epoch, 75 | steps_per_epoch=steps_per_epoch, 76 | cycle_length=10, 77 | lr_decay=0.9, 78 | mult_factor=1.2) 79 | 80 | X_train, Y_train, X_valid, Y_valid = DataLoader.load_data(h5_dataset_path, frac=0.9) 81 | 82 | segmenter.parallel_model.fit_generator(data_loader.generator_from_data(X_train, Y_train), 83 | epochs=epochs, 84 | steps_per_epoch=steps_per_epoch, 85 | validation_data=data_loader.generator_from_data(X_valid, Y_valid), 86 | validation_steps=validation_steps, 87 | callbacks=[ck, log, lr_scheduler], 88 | initial_epoch=initial_epoch) 89 | 90 | # lr_finder.plot_loss() 91 | # plt.savefig("loss.png") 92 | -------------------------------------------------------------------------------- 
/tools/score.py: -------------------------------------------------------------------------------- 1 | # coding:utf_8 2 | import sys 3 | import re 4 | 5 | # """ 6 | # 通过与黄金标准文件对比分析中文分词效果. 7 | 8 | # 使用方法: 9 | # python crf_tag_score.py test_gold.utf8 your_tagger_output.utf8 10 | 11 | # 分析结果示例如下: 12 | # 标准词数:104372 个,正确词数:96211 个,错误词数:6037 个 13 | # 标准行数:1944,正确行数:589 ,错误行数:1355 14 | # Recall: 92.1808531024% 15 | # Precision: 94.0957280338% 16 | # F MEASURE: 93.1284483593% 17 | 18 | 19 | # 参考:中文分词器分词效果的评测方法 20 | # http://ju.outofmemory.cn/entry/46140 21 | 22 | # """ 23 | ''' 24 | 通过与黄金标准文件对比分析中文分词效果. 25 | 分析结果如下: 26 | 一次迭代: 27 | result: 28 | 标准词数:26940个,正确词数:25341个,错误词数:1739个 29 | 标准行数:929,正确行数:407,错误行数:522 30 | Recall: 0.940646 31 | Precision: 0.935783 32 | F MEASURE: 0.938208 33 | ERR RATE: 0.064551 34 | 35 | 三十次迭代: 36 | result: 37 | 标准词数:26940个,正确词数:25719个,错误词数:1299个 38 | 标准行数:929,正确行数:493,错误行数:436 39 | Recall: 0.954677 40 | Precision: 0.951921 41 | F MEASURE: 0.953297 42 | ERR RATE: 0.048218 43 | 44 | 45 | CRF: 46 | result: 47 | 标准词数:26940个,正确词数:25544个,错误词数:1354个 48 | 标准行数:929,正确行数:480,错误行数:449 49 | Recall: 0.948181 50 | Precision: 0.949662 51 | F MEASURE: 0.948921 52 | ERR RATE: 0.050260 53 | 54 | ''' 55 | 56 | 57 | def read_line(f): 58 | ''' 59 | 读取一行,并清洗空格和换行 60 | ''' 61 | line = f.readline() 62 | return line.strip() 63 | 64 | 65 | def prf_score(real_text_file, pred_text_file, prf_file, epoch): 66 | file_gold = open(real_text_file, 'r', encoding='utf8') 67 | # file_gold = codecs.open(r'../corpus/msr_test_gold.utf8', 'r', 'utf8') 68 | # file_tag = codecs.open(r'pred_standard.txt', 'r', 'utf8') 69 | file_tag = open(pred_text_file, 'r', encoding='utf8') 70 | 71 | line1 = read_line(file_gold) 72 | N_count = 0 # 将正类分为正或者将正类分为负 73 | e_count = 0 # 将负类分为正 74 | c_count = 0 # 正类分为正 75 | e_line_count = 0 76 | c_line_count = 0 77 | 78 | while line1: 79 | line2 = read_line(file_tag) 80 | 81 | list1 = line1.split(' ') 82 | list2 = line2.split(' ') 83 | 84 | count1 = len(list1) # 标准分词数 85 | N_count += count1 86 | if line1 == line2: 87 | c_line_count += 1 # 分对的行数 88 | c_count += count1 # 分对的词数 89 | else: 90 | e_line_count += 1 91 | count2 = len(list2) 92 | 93 | arr1 = [] 94 | arr2 = [] 95 | 96 | pos = 0 97 | for w in list1: 98 | arr1.append(tuple([pos, pos + len(w)])) # list1中各个单词的起始位置 99 | pos += len(w) 100 | 101 | pos = 0 102 | for w in list2: 103 | arr2.append(tuple([pos, pos + len(w)])) # list2中各个单词的起始位置 104 | pos += len(w) 105 | 106 | for tp in arr2: 107 | if tp in arr1: 108 | c_count += 1 109 | else: 110 | e_count += 1 111 | 112 | line1 = read_line(file_gold) 113 | 114 | R = float(c_count) / N_count 115 | P = float(c_count) / (c_count + e_count) 116 | F = 2. * P * R / (P + R) 117 | ER = 1. 
* e_count / N_count 118 | 119 | print("result:") 120 | print('标准词数:%d个,词数正确率:%f个,词数错误率:%f \n' % (N_count, c_count / N_count, e_count / N_count)) 121 | print('标准行数:%d,行数正确率:%f,行数错误率:%f \n' % (c_line_count + e_line_count, c_line_count / (c_line_count + e_line_count), 122 | e_line_count / (c_line_count + e_line_count))) 123 | print('Recall: %f' % (R)) 124 | print('Precision: %f' % (P)) 125 | print('F MEASURE: %f' % (F)) 126 | print('ERR RATE: %f' % (ER)) 127 | 128 | # print P,R,F 129 | 130 | f = open(prf_file, 'a', encoding='utf-8') 131 | f.write('result-(epoch:%s):\n' % epoch) 132 | f.write('标准词数:%d,词数正确率:%f,词数错误率:%f \n' % (N_count, c_count / N_count, e_count / N_count)) 133 | f.write('标准行数:%d,行数正确率:%f,行数错误率:%f \n' % (c_line_count + e_line_count, c_line_count / (c_line_count + e_line_count), 134 | e_line_count / (c_line_count + e_line_count))) 135 | f.write('Recall: %f\n' % (R)) 136 | f.write('Precision: %f\n' % (P)) 137 | f.write('F MEASURE: %f\n' % (F)) 138 | f.write('ERR RATE: %f\n' % (ER)) 139 | f.write('====================================\n') 140 | 141 | return F 142 | 143 | 144 | if __name__ == '__main__': 145 | prf_score('../data/gold.utf8', '../data/pred_text.utf8', '../data/prf_tmp.txt', 15) 146 | -------------------------------------------------------------------------------- /tools/ner_data_preprocess.py: -------------------------------------------------------------------------------- 1 | import os 2 | import re 3 | import argparse 4 | 5 | def print_process(process): 6 | num_processed = int(30 * process) 7 | num_unprocessed = 30 - num_processed 8 | print( 9 | f"{''.join(['['] + ['='] * num_processed + ['>'] + [' '] * num_unprocessed + [']'])}, {(process * 100):.2f} %") 10 | 11 | 12 | def convert_to_bis(source_dir, target_path, log=False, combine=False, single_line=True): 13 | print("Converting...") 14 | for root, dirs, files in os.walk(source_dir): 15 | total = len(files) 16 | tgt_dir = target_path + root[len(source_dir):] 17 | 18 | print(tgt_dir) 19 | for index, name in enumerate(files): 20 | file = os.path.join(root, name) 21 | bises = process_file(file) 22 | if combine: 23 | _save_bises(bises, target_path, write_mode='a', single_line=single_line) 24 | else: 25 | os.makedirs(tgt_dir, exist_ok=True) 26 | _save_bises(bises, os.path.join(tgt_dir, name), single_line=single_line) 27 | if log: 28 | print_process((index + 1) / total) 29 | print("All converted") 30 | 31 | 32 | def _save_bises(bises, path, write_mode='w+', single_line=True): 33 | with open(path, mode=write_mode, encoding='UTF-8') as f: 34 | if single_line: 35 | for bis in bises: 36 | sent, tags = [], [] 37 | for char, tag in bis: 38 | sent.append(char) 39 | tags.append(tag) 40 | sent = ' '.join(sent) 41 | tags = ' '.join(tags) 42 | f.write(sent + "\t" + tags) 43 | f.write('\n') 44 | else: 45 | for bis in bises: 46 | for char, tag in bis: 47 | f.write(char + "\t" + tag + "\n") 48 | f.write("\n") 49 | 50 | 51 | def process_file(file): 52 | with open(file, 'r', encoding='UTF-8') as f: 53 | text = f.readlines() 54 | bises = _parse_text(text) 55 | return bises 56 | 57 | 58 | def _parse_text(text: list): 59 | bises = [] 60 | for line in text: 61 | # remove POS tag 62 | line, _ = re.subn('\\n', '', line) 63 | if line == '' or line == '\n': 64 | continue 65 | words = re.split('\s+', line) 66 | 67 | if len(words) > MAX_LEN_SIZE: 68 | texts = re.split('[。?!,.?!,]/w', line) 69 | if len(min(texts, key=len)) > MAX_LEN_SIZE: 70 | continue 71 | bises.extend(_parse_text(texts)) 72 | else: 73 | bises.append(_tag(words)) 74 | return 
bises 75 | 76 | 77 | def _tag(words): 78 | """ 79 | 给指定的一行文本打上BIS标签 80 | :param line: 文本行 81 | :return: 82 | """ 83 | bis = [] 84 | # words = list(map(list, words)) 85 | pre_word = None 86 | for word in words: 87 | pos_t = None 88 | tokens = word.split('/') 89 | if len(tokens) == 2: 90 | word, pos = tokens 91 | elif len(tokens) == 3: 92 | word, pos_t, pos = tokens 93 | else: 94 | continue 95 | 96 | word = list(word) 97 | pos = pos.upper() 98 | 99 | if len(word) == 0: 100 | continue 101 | if word[0] == '[': 102 | pre_word = word 103 | continue 104 | if pre_word is not None: 105 | pre_word += word 106 | if pos_t is None: 107 | continue 108 | elif pos_t[-1] != ']': 109 | continue 110 | else: 111 | word = pre_word[1:] 112 | pre_word = None 113 | 114 | if len(word) == 1: 115 | bis.append((word[0], 'S-' + pos)) 116 | else: 117 | for i, char in enumerate(word): 118 | if i == 0: 119 | bis.append((char, 'B-' + pos)) 120 | else: 121 | bis.append((char, 'I-' + pos)) 122 | # bis.append(('\n', 'O')) 123 | return bis 124 | 125 | 126 | if __name__ == '__main__': 127 | parser = argparse.ArgumentParser(description="将使用词性标注的文件转换为用BIS分块标记的文件。") 128 | parser.add_argument("corups_dir", type=str, help="指定存放语料库的文件夹,程序将会递归查找目录下的文件。") 129 | parser.add_argument("output_path", type=str, default='.', help="指定标记好的文件的输出路径。") 130 | parser.add_argument("-c", "--combine", help="是否组装为一个文件", default=False, type=bool) 131 | parser.add_argument("-s", "--single_line", help="是否为单行模式", default=False, type=bool) 132 | parser.add_argument("--log", help="是否打印进度条", default=False, type=bool) 133 | parser.add_argument("--max_len", help="处理后的最大语句长度(将原句子按标点符号断句,若断句后的长度仍比最大长度长,将忽略", 134 | default=150, type=int) 135 | args = parser.parse_args() 136 | MAX_LEN_SIZE = args.max_len 137 | 138 | convert_to_bis(args.corups_dir, args.output_path, args.log, args.combine, args.single_line) 139 | -------------------------------------------------------------------------------- /dl_segmenter/data_loader.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import h5py 4 | import numpy as np 5 | from keras.preprocessing.sequence import pad_sequences 6 | from keras.utils import to_categorical 7 | 8 | from dl_segmenter.utils import load_dictionary 9 | 10 | 11 | class DataLoader: 12 | 13 | def __init__(self, 14 | src_dict_path, 15 | tgt_dict_path, 16 | batch_size=64, 17 | max_len=999, 18 | fix_len=True, 19 | word_delimiter=' ', 20 | sent_delimiter='\t', 21 | shuffle_batch=10, 22 | encoding="utf-8", 23 | sparse_target=False): 24 | self.src_tokenizer = load_dictionary(src_dict_path, encoding) 25 | self.tgt_tokenizer = load_dictionary(tgt_dict_path, encoding) 26 | self.batch_size = batch_size 27 | self.max_len = max_len 28 | self.fix_len = fix_len 29 | self.word_delimiter = word_delimiter 30 | self.sent_delimiter = sent_delimiter 31 | self.src_vocab_size = self.src_tokenizer.num_words 32 | self.tgt_vocab_size = self.tgt_tokenizer.num_words 33 | self.shuffle_batch = shuffle_batch 34 | self.sparse_target = sparse_target 35 | 36 | def generator(self, file_path, encoding="utf-8"): 37 | if os.path.isdir(file_path): 38 | while True: 39 | for sent, chunk in self.load_sents_from_dir(file_path): 40 | yield sent, chunk 41 | while True: 42 | for sent, chunk in self.load_sents_from_file(file_path, encoding): 43 | yield sent, chunk 44 | 45 | def load_sents_from_dir(self, source_dir, encoding="utf-8"): 46 | for root, dirs, files in os.walk(source_dir): 47 | for name in files: 48 | file = os.path.join(root, name) 49 | 
for sent, chunk in self.load_sents_from_file(file, encoding=encoding): 50 | yield sent, chunk 51 | 52 | def load_sents_from_file(self, file_path, encoding): 53 | with open(file_path, encoding=encoding) as f: 54 | sent, chunk = [], [] 55 | for line in f: 56 | line = line[:-1] 57 | chars, tags = line.split(self.sent_delimiter) 58 | sent.append(chars.split(self.word_delimiter)) 59 | chunk.append(tags.split(self.word_delimiter)) 60 | if len(sent) >= self.batch_size: 61 | sent = self.src_tokenizer.texts_to_sequences(sent) 62 | chunk = self.tgt_tokenizer.texts_to_sequences(chunk) 63 | sent, chunk = self._pad_seq(sent, chunk) 64 | if not self.sparse_target: 65 | chunk = to_categorical(chunk, num_classes=self.tgt_vocab_size + 1) 66 | yield sent, chunk 67 | sent, chunk = [], [] 68 | 69 | @staticmethod 70 | def load_data(h5_file_path, frac=None): 71 | with h5py.File(h5_file_path, 'r') as dfile: 72 | X, Y = dfile['X'][:], dfile['Y'][:] 73 | 74 | if frac is not None: 75 | assert 0 < frac < 1 76 | split_point = int(X.shape[0] * frac) 77 | X_train = X[:split_point] 78 | Y_train = Y[:split_point] 79 | X_valid = X[split_point:] 80 | Y_valid = Y[split_point:] 81 | return X_train, Y_train, X_valid, Y_valid 82 | return X, Y 83 | 84 | def generator_from_data(self, X, Y): 85 | steps = 0 86 | total_size = X.shape[0] 87 | while True: 88 | if steps >= self.shuffle_batch: 89 | indicates = list(range(total_size)) 90 | np.random.shuffle(indicates) 91 | X = X[indicates] 92 | Y = Y[indicates] 93 | steps = 0 94 | sample_index = np.random.randint(0, total_size - self.batch_size) 95 | ret_x = X[sample_index:sample_index + self.batch_size] 96 | ret_y = Y[sample_index:sample_index + self.batch_size] 97 | 98 | if not self.sparse_target: 99 | ret_y = to_categorical(ret_y, num_classes=self.tgt_vocab_size + 1) 100 | else: 101 | ret_y = np.expand_dims(ret_y, 2) 102 | yield ret_x, ret_y 103 | steps += 1 104 | 105 | def load_and_dump_to_h5(self, file_path, output_path, encoding): 106 | with open(file_path, encoding=encoding) as f: 107 | sent, chunk = [], [] 108 | for line in f: 109 | line = line[:-1] 110 | chars, tags = line.split(self.sent_delimiter) 111 | sent.append(chars.split(self.word_delimiter)) 112 | chunk.append(tags.split(self.word_delimiter)) 113 | 114 | sent = self.src_tokenizer.texts_to_sequences(sent) 115 | chunk = self.tgt_tokenizer.texts_to_sequences(chunk) 116 | sent, chunk = self._pad_seq(sent, chunk) 117 | 118 | indicates = list(range(sent.shape[0])) 119 | np.random.shuffle(indicates) 120 | sent = sent[indicates] 121 | chunk = chunk[indicates] 122 | 123 | with h5py.File(output_path, 'w') as dfile: 124 | dfile.create_dataset('X', data=sent) 125 | dfile.create_dataset('Y', data=chunk) 126 | 127 | def _pad_seq(self, sent, chunk): 128 | if not self.fix_len: 129 | len_sent = min(len(max(sent, key=len)), self.max_len) 130 | len_chunk = min(len(max(chunk, key=len)), self.max_len) 131 | sent = pad_sequences(sent, maxlen=len_sent, padding='post') 132 | chunk = pad_sequences(chunk, maxlen=len_chunk, padding='post') 133 | else: 134 | sent = pad_sequences(sent, maxlen=self.max_len, padding='post') 135 | chunk = pad_sequences(chunk, maxlen=self.max_len, padding='post') 136 | return sent, chunk 137 | -------------------------------------------------------------------------------- /dl_segmenter/core.py: -------------------------------------------------------------------------------- 1 | import json 2 | import logging 3 | import re 4 | import traceback 5 | from concurrent.futures import ThreadPoolExecutor 6 | from 
multiprocessing import Lock 7 | 8 | import numpy as np 9 | from keras import Input, Model 10 | from keras.layers import Embedding, Bidirectional, Dense, Dropout, GRU, LSTM 11 | from keras.optimizers import Adam 12 | from keras.preprocessing.sequence import pad_sequences 13 | from keras.utils import multi_gpu_model 14 | from keras_contrib.layers import CRF 15 | from keras_contrib.losses import crf_loss 16 | from keras_contrib.metrics import crf_accuracy 17 | from keras_preprocessing.text import Tokenizer 18 | 19 | from dl_segmenter.utils import load_dictionary, create_embedding_matrix, get_embedding_index 20 | 21 | 22 | class DLSegmenter: 23 | 24 | def __init__(self, 25 | vocab_size, 26 | chunk_size, 27 | embed_dim=300, 28 | bi_lstm_units=256, 29 | dropout_rate=0.1, 30 | num_gpu=0, 31 | optimizer=Adam(), 32 | sparse_target=False, 33 | emb_matrix=None, 34 | weights_path=None, 35 | rule_fn=None, 36 | src_tokenizer: Tokenizer = None, 37 | tgt_tokenizer: Tokenizer = None): 38 | self.vocab_size = vocab_size 39 | self.chunk_size = chunk_size 40 | self.embed_dim = embed_dim 41 | self.bi_lstm_units = bi_lstm_units 42 | self.dropout_rate = dropout_rate 43 | self.sparse_target = sparse_target 44 | 45 | self.rule_fn = rule_fn 46 | self.num_gpu = num_gpu 47 | self.optimizer = optimizer 48 | self.src_tokenizer = src_tokenizer 49 | self.tgt_tokenizer = tgt_tokenizer 50 | self.model, self.parallel_model = self.__build_model(emb_matrix) 51 | if weights_path is not None: 52 | try: 53 | self.model.load_weights(weights_path) 54 | logging.info("weights loaded!") 55 | except: 56 | logging.error("No weights found, create a new model.") 57 | 58 | def __build_model(self, emb_matrix=None): 59 | word_input = Input(shape=(None,), dtype='int32', name="word_input") 60 | 61 | word_emb = Embedding(self.vocab_size + 1, self.embed_dim, 62 | weights=[emb_matrix] if emb_matrix is not None else None, 63 | trainable=True if emb_matrix is None else False, 64 | name='word_emb')(word_input) 65 | 66 | bilstm_output = Bidirectional(LSTM(self.bi_lstm_units // 2, 67 | return_sequences=True))(word_emb) 68 | 69 | bilstm_output = Dropout(self.dropout_rate)(bilstm_output) 70 | 71 | output = Dense(self.chunk_size + 1, kernel_initializer="he_normal")(bilstm_output) 72 | output = CRF(self.chunk_size + 1, sparse_target=self.sparse_target)(output) 73 | 74 | model = Model([word_input], [output]) 75 | parallel_model = model 76 | if self.num_gpu > 1: 77 | parallel_model = multi_gpu_model(model, gpus=self.num_gpu) 78 | 79 | parallel_model.compile(optimizer=self.optimizer, loss=crf_loss, metrics=[crf_accuracy]) 80 | return model, parallel_model 81 | 82 | def decode_sequences(self, sequences): 83 | sequences = self._seq_to_matrix(sequences) 84 | output = self.model.predict_on_batch(sequences) # [N, -1, chunk_size + 1] 85 | output = np.argmax(output, axis=2) 86 | return self.tgt_tokenizer.sequences_to_texts(output) 87 | 88 | def _single_decode(self, args): 89 | sent, tag = args 90 | cur_sent, cur_tag = [], [] 91 | tag = tag.split(' ') 92 | t1, pre_pos = [], None 93 | for i in range(len(sent)): 94 | tokens = tag[i].split('-') 95 | if len(tokens) == 2: 96 | c, pos = tokens 97 | else: 98 | c = 'i' 99 | pos = "" 100 | 101 | word = sent[i] 102 | if c in 'sb': 103 | if len(t1) != 0: 104 | cur_sent.append(''.join(t1)) 105 | cur_tag.append(pre_pos) 106 | t1 = [word] 107 | pre_pos = pos 108 | elif c in 'ie': 109 | t1.append(word) 110 | pre_pos = pos 111 | 112 | if len(t1) != 0: 113 | cur_sent.append(''.join(t1)) 114 | cur_tag.append(pre_pos) 115 | 116 | if 
self.rule_fn is not None: 117 | return self.rule_fn(cur_sent, cur_tag) 118 | return cur_sent, cur_tag 119 | 120 | def decode_texts(self, texts): 121 | sents = [] 122 | with ThreadPoolExecutor() as executor: 123 | for text in executor.map(lambda x: list(re.subn("\s+", "", x)[0]), texts): 124 | sents.append(text) 125 | sequences = self.src_tokenizer.texts_to_sequences(sents) 126 | tags = self.decode_sequences(sequences) 127 | 128 | ret = [] 129 | with ThreadPoolExecutor() as executor: 130 | for cur_sent, cur_tag in executor.map(self._single_decode, zip(sents, tags)): 131 | ret.append((cur_sent, cur_tag)) 132 | 133 | return ret 134 | 135 | def _seq_to_matrix(self, sequences): 136 | max_len = len(max(sequences, key=len)) 137 | return pad_sequences(sequences, maxlen=max_len, padding="post") 138 | 139 | def get_config(self): 140 | return { 141 | "vocab_size": self.vocab_size, 142 | "chunk_size": self.chunk_size, 143 | "embed_dim": self.embed_dim, 144 | "sparse_target": self.sparse_target, 145 | "bi_lstm_units": self.bi_lstm_units, 146 | "dropout_rate": self.dropout_rate, 147 | } 148 | 149 | __singleton = None 150 | __lock = Lock() 151 | 152 | @staticmethod 153 | def get_or_create(config, src_dict_path=None, 154 | tgt_dict_path=None, 155 | weights_path=None, 156 | embedding_file=None, 157 | optimizer=Adam(), 158 | rule_fn=None, 159 | encoding="utf-8"): 160 | DLSegmenter.__lock.acquire() 161 | try: 162 | if DLSegmenter.__singleton is None: 163 | if type(config) == str: 164 | with open(config, encoding=encoding) as file: 165 | config = dict(json.load(file)) 166 | elif type(config) == dict: 167 | config = config 168 | else: 169 | raise ValueError("Unexpect config type!") 170 | 171 | if src_dict_path is not None: 172 | src_tokenizer = load_dictionary(src_dict_path, encoding) 173 | config['src_tokenizer'] = src_tokenizer 174 | if embedding_file is not None: 175 | emb_matrix = create_embedding_matrix(get_embedding_index(embedding_file), 176 | src_tokenizer.word_index, 177 | min(config['vocab_size'] + 1, config['max_num_words']), 178 | config['embed_dim']) 179 | config['emb_matrix'] = emb_matrix 180 | if tgt_dict_path is not None: 181 | config['tgt_tokenizer'] = load_dictionary(tgt_dict_path, encoding) 182 | 183 | config['rule_fn'] = rule_fn 184 | config['weights_path'] = weights_path 185 | config['optimizer'] = optimizer 186 | DLSegmenter.__singleton = DLSegmenter(**config) 187 | except Exception: 188 | traceback.print_exc() 189 | finally: 190 | DLSegmenter.__lock.release() 191 | return DLSegmenter.__singleton 192 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # bi-lstm-crf 2 | 3 | 基于Universal Transformer的分词模型见:[https://github.com/GlassyWing/transformer-word-segmenter](https://github.com/GlassyWing/transformer-word-segmenter): 4 | 5 | ## 简介 6 | 7 | 不同于英文自然语言处理,中文自然语言处理,例如语义分析、文本分类、词语蕴含等任务都需要预先进行分词。要将中文进行分割,直观的方式是通过为语句中的每一个字进行标记,以确定这个字是位于一个词的开头还是之中: 8 | 9 | 例如“**成功入侵民主党的电脑系统**”这句话,我们为其标注为: 10 | 11 | ```js 12 | "成功 入侵 民主党 的 电脑系统" 13 | B I B I B I I S B I I I 14 | ``` 15 | 16 | 其中`B`表示一个词语的开头,`I`表示非一个词语的开头,`S`表示单字成词。这样我们就能达到分词的效果。 17 | 18 | 对于句子这样的序列而言,要为其进行标注,常用的是使用Bi-LSTM卷积网络进行序列标注,如下图: 19 | 20 |
21 | 22 | ![Bi-LSTM](assets/Bi-LSTM.png) 23 |
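为便于理解上面介绍的 BIS 标注方式,下面给出一个极简的示意代码(仅作说明用的草稿,并非本项目源码;实际的语料转换逻辑见 `tools/ner_data_preprocess.py`,其标签还会附带词性,如 `B-N`、`S-W`):

```python
# 示意:将按空格切分好的句子转换为逐字的 B/I/S 标签
# 仅用于说明标注方式,函数名 to_bis 为示例命名
def to_bis(segmented_sentence: str):
    chars, tags = [], []
    for word in segmented_sentence.split():
        for i, ch in enumerate(word):
            chars.append(ch)
            if len(word) == 1:
                tags.append('S')   # 单字成词
            elif i == 0:
                tags.append('B')   # 词语开头
            else:
                tags.append('I')   # 非词语开头
    return chars, tags


if __name__ == '__main__':
    chars, tags = to_bis("成功 入侵 民主党 的 电脑系统")
    print(' '.join(chars))  # 成 功 入 侵 民 主 党 的 电 脑 系 统
    print(' '.join(tags))   # B I B I B I I S B I I I
```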
24 | 25 | 通过Bi-LSTM获得每个词所对应的所有标签的概率,取最大概率的标注即可获得整个标注序列,如上图序列`W0W1W2`的标注为`BIS`。但这样有可能会取得不合逻辑的标注序列,如`BS`、`SI`等。我们需要为其设定一些约束,如: 26 | 27 | * B后只能是I 28 | * S之后只能是B、S 29 | * ... 30 | 31 | 而要做到这一点,我们可以在原有的模型基础之上,加上一个CRF层,该层的作用即是学习符号之间的约束(如上所述)。模型架构变为Embedding + Bi-LSTM + CRF,原理参考论文:https://arxiv.org/abs/1508.01991。 32 | 33 | ## 语料预处理 34 | 35 | 要训练模型,首先需要准备好语料,这里选用人民日报2014年的80万语料作为训练语料。语料格式如下: 36 | 37 | ```js 38 | "人民网/nz 1月1日/t 讯/ng 据/p 《/w [纽约/nsf 时报/n]/nz 》/w 报道/v ,/w 美国/nsf 华尔街/nsf 股市/n 在/p 2013年/t 的/ude1 最后/f 一天/mq 继续/v 上涨/vn ,/w 和/cc [全球/n 股市/n]/nz 一样/uyy ,/w 都/d 以/p [最高/a 纪录/n]/nz 或/c 接近/v [最高/a 纪录/n]/nz 结束/v 本/rz 年/qt 的/ude1 交易/vn 。/w " 39 | ``` 40 | 41 | 原格式中每一个词语使用空格分开后面使用POS标记词性,而本模型所需要的语料格式如下: 42 | 43 | ```js 44 | 嫌 疑 人 赵 国 军 。 B-N I-N I-N B-NR I-NR I-NR S-W 45 | ``` 46 | 47 | 使用命令: 48 | 49 | ```sh 50 | python tools/data_preprocess.py people-2014/train 2014_processed -c True -s True 51 | ``` 52 | 53 | 可将原文件转换为用BIS标签(B:表示语句块的开始,I:表示非语句块的开始,S:表示单独成词)标注的文件。 54 | 55 | 如上将会使用`people-2014/train`下的文件生成文本文件`2014_processed` 56 | 57 | ## 生成字典 58 | 59 | 使用命令: 60 | 61 | ```python 62 | python tools/make_dicts.py 2014_processed -s src_dict.json -t tgt_dict.json 63 | ``` 64 | 65 | 这会使用文件`2014_processed`,生成两个字典文件,`src_dict.json`, `tgt_dict.json` 66 | 67 | 使用方式见:`python tools/make_dicts.py -h` 68 | 69 | ## 转换为hdf5格式 70 | 71 | 使用命令: 72 | 73 | ```python 74 | python tools/convert_to_h5.py 2014_processed 2014_processed.h5 -s src_dict.json -t tgt_dict.json 75 | ``` 76 | 77 | 可将文本文件`2014_processed`转换为hdf5格式,提升训练速度, 78 | 79 | 使用方式见:`python tools/convert_to_h5.py -h` 80 | 81 | ## 训练 82 | 83 | 训练示例见: 84 | 85 | ```python 86 | train_example.py 87 | ``` 88 | 89 | 训练时,默认会生成模型配置文件`data/default-config.json`, 权重文件将会生成在`models`文件夹下。 90 | 91 | ### 使用字(词)向量 92 | 93 | 在训练时可以使用已训练的字(词)向量作为每一个字的表征,字(词)向量的格式如下: 94 | 95 | ```js 96 | 而 -0.037438 0.143471 0.391358 ... 97 | 个 -0.045985 -0.065485 0.251576 ... 98 | 以 -0.085605 0.081578 0.227135 ... 99 | 可以 0.012544 0.069829 0.117207 ... 100 | 第 -0.321195 0.065808 0.089396 ... 101 | 上 -0.186070 0.189417 0.265060 ... 102 | 之 0.037873 0.075681 0.239715 ... 103 | 于 -0.197969 0.018578 0.233496 ... 104 | 对 -0.115746 -0.025029 -0 ... 105 | ``` 106 | 107 | 每一行,为一个字(词)和它所对应的特征向量。 108 | 109 | 汉字字(词)向量来源 110 | 可从[https://github.com/Embedding/Chinese-Word-Vectors](https://github.com/Embedding/Chinese-Word-Vectors)获得字(词)向量。字(词)向量文件中每一行格式为一个字(词)与其对应的300维向量。 111 | 112 | ### 训练效果 113 | 114 | 训练时模型配置如下: 115 | 116 | ```json 117 | config = { 118 | "vocab_size": 6864, 119 | "chunk_size": 259, 120 | "embed_dim": 300, 121 | "bi_lstm_units": 256, 122 | "max_num_words": 20000, 123 | "dropout_rate": 0.1 124 | } 125 | ``` 126 | 127 | 其它参数: 128 | 129 | | 参数 | 值 | 130 | | ---------------- | ---- | 131 | | batch size | 32 | 132 | | epochs | 32 | 133 | | steps_per_epoch | 2000 | 134 | | validation_steps | 20 | 135 | 136 | **注**: 训练未使用词向量 137 | 138 | 最终效果: 139 | 140 | 在迭代32次后,验证集精度达到98% 141 | 142 | 143 | 144 | 145 | ## 分词/解码 146 | 147 | 1. 
编码方式: 148 | 149 | ```python 150 | import time 151 | 152 | from dl_segmenter import get_or_create, DLSegmenter 153 | 154 | if __name__ == '__main__': 155 | segmenter: DLSegmenter = get_or_create("../data/default-config.json", 156 | src_dict_path="../data/src_dict.json", 157 | tgt_dict_path="../data/tgt_dict.json", 158 | weights_path="../models/weights.32--0.18.h5") 159 | 160 | for _ in range(1): 161 | start_time = time.time() 162 | for sent, tag in segmenter.decode_texts([ 163 | "美国司法部副部长罗森·施泰因(Rod Rosenstein)指," 164 | "这些俄罗斯情报人员涉嫌利用电脑病毒或“钓鱼电邮”," 165 | "成功入侵民主党的电脑系统,偷取民主党高层成员之间的电邮," 166 | "另外也从美国一个州的电脑系统偷取了50万名美国选民的资料。"]): 167 | print(sent) 168 | print(tag) 169 | print(f"cost {(time.time() - start_time) * 1000}ms") 170 | 171 | ``` 172 | 173 | `get_or_create`: 174 | 175 | - 参数: 176 | - config_path: 模型配置路径 177 | - src_dict_path:源字典文件路径 178 | - tgt_dict_path:目标字典文件路径 179 | - weights_path:权重文件路径 180 | - 返回: 181 | 分词器对象 182 | 183 | `decode_texts`: 184 | - 参数: 185 | - 字符串序列(即可同时处理多段文本) 186 | - 返回: 187 | - 一个序列,序列中每一个元素为对应语句的分词结果和每个词的词性标签。 188 | 189 | 2. 命令方式: 190 | 191 | ```python 192 | python examples/predict.py -s <语句> 193 | ``` 194 | 195 | 命令方式所使用的模型配置文件、字典文件等如编程方式中所示。进行分词时,多句话可用空格分隔,具体使用方式可使用`predict.py -h`查看。 196 | 197 | ### 分词效果展示 198 | 199 | 1. 科技类 200 | 201 | > _物理仿真引擎的作用,是让虚拟世界中的物体运动符合真实世界的物理定律,经常用于游戏领域,以便让画面看起来更富有真实感。PhysX是由英伟达提出的物理仿真引擎,其物理模拟计算由专门加速芯片GPU来进行处理,在节省CPU负担的同时还能将物理运算效能成倍提升,由此带来更加符合真实世界的物理效果。_ 202 | 203 | ```python 204 | ['物理', '仿真引擎', '的', '作用', ',', '是', '让', '虚拟世界', '中', '的', '物体运动', '符合', '真实世界', '的', '物理定律', ',', '经常', '用于', '游戏', '领域', ',', '以便', '让', '画面', '看起来', '更', '富有', '真实感', '。', 'PhysX', '是', '由', '英伟达', '提出', '的', '物理', '仿真引擎', ',', '其', '物理模拟计算', '由', '专门', '加速', '芯片', 'GPU', '来', '进行', '处理', ',', '在', '节省', 'CPU', '负担', '的', '同时', '还', '能', '将', '物理运算', '效能', '成', '倍', '提升', ',', '由此', '带来', '更加', '符合', '真实世界', '的', '物理', '效果', '。'] 205 | ['n', 'n', 'ude1', 'n', 'w', 'vshi', 'v', 'gi', 'f', 'ude1', 'nz', 'v', 'nz', 'ude1', 'nz', 'w', 'd', 'v', 'n', 'n', 'w', 'd', 'v', 'n', 'v', 'd', 'v', 'n', 'w', 'x', 'vshi', 'p', 'nz', 'v', 'ude1', 'n', 'n', 'w', 'rz', 'nz', 'p', 'd', 'vi', 'n', 'x', 'vf', 'vn', 'vn', 'w', 'p', 'v', 'x', 'n', 'ude1', 'c', 'd', 'v', 'd', 'nz', 'n', 'v', 'q', 'v', 'w', 'd', 'v', 'd', 'v', 'nz', 'ude1', 'n', 'n', 'w'] 206 | ``` 207 | 208 | 2. 政治类 209 | 210 | > _昨晚,英国首相特里萨•梅(Theresa May)试图挽救其退欧协议的努力,在布鲁塞尔遭遇了严重麻烦。倍感失望的欧盟领导人们指责她没有拿出可行的提案来向充满敌意的英国议会兜售她的退欧计划。_ 211 | 212 | ```python 213 | ['昨晚', ',', '英国', '首相', '特里萨•梅', '(', 'TheresaMay', ')', '试图', '挽救', '其', '退', '欧', '协议', '的', '努力', ',', '在', '布鲁塞尔', '遭遇', '了', '严重', '麻烦', '。', '倍感', '失望', '的', '欧盟', '领导', '人们', '指责', '她', '没有', '拿出', '可行', '的', '提案', '来', '向', '充满', '敌意', '的', '英国议会', '兜售', '她', '的', '退欧', '计划', '。'] 214 | ['t', 'w', 'ns', 'nnt', 'nrf', 'w', 'x', 'w', 'v', 'vn', 'rz', 'v', 'b', 'n', 'ude1', 'ad', 'w', 'p', 'nsf', 'v', 'ule', 'a', 'an', 'w', 'v', 'a', 'ude1', 'n', 'n', 'n', 'v', 'rr', 'v', 'v', 'a', 'ude1', 'n', 'vf', 'p', 'v', 'n', 'ude1', 'nt', 'v', 'rr', 'ude1', 'nz', 'n', 'w'] 215 | ``` 216 | 217 | 3. 
新闻类 218 | 219 | > _印度尼西亚国家抗灾署此前发布消息证实,印尼巽他海峡附近的万丹省当地时间22号晚遭海啸袭击。_ 220 | 221 | ```python 222 | ['印度尼西亚', '国家', '抗灾署', '此前', '发布', '消息', '证实', ',', '印尼', '巽他海峡', '附近', '的', '万丹省', '当地时间', '22号', '晚', '遭', '海啸', '袭击', '。'] 223 | ['nsf', 'n', 'nz', 't', 'v', 'n', 'v', 'w', 'ns', 'nz', 'f', 'ude1', 'ns', 'nz', 'mq', 'tg', 'v', 'n', 'vn', 'w'] 224 | ``` 225 | 226 | ## 分词评估结果 227 | 228 | 使用开发集进行评估: 229 | 230 | ```py 231 | result-(epoch:32): 232 | 标准词数:20744,词数正确率:0.939404,词数错误率:0.049653 233 | 标准行数:317,行数正确率:0.337539,行数错误率:0.662461 234 | Recall: 0.939404 235 | Precision: 0.949798 236 | F MEASURE: 0.944572 237 | ERR RATE: 0.049653 238 | ``` 239 | 240 | ## 其它 241 | 242 | ### 如何评估 243 | 244 | 使用与黄金标准文件进行对比的方式,进行评估。 245 | 246 | 1. 数据预处理 247 | 248 | 为了生成黄金标准文件和无分词标记的原始文件,可用下列命令: 249 | 250 | ```python 251 | python examples/score_preprocess.py --corups_dir <评估用语料文件夹> \ 252 | --gold_file_path <生成的黄金标准文件路径> \ 253 | --restore_file_path <生成无标记的原始文件路径> 254 | ``` 255 | 256 | 2. 读取无标记的原始文件,并进行分词,输出到文件: 257 | 258 | ```python 259 | python examples/predict.py -f <要分割的文本文件的路径> -o <保存分词结果的文件路径> 260 | ``` 261 | 262 | 3. 生成评估结果: 263 | 264 | 执行`score.py`可生成评估文件,默认使用黄金分割文件`../data/gold.utf8`,使用模型分词后的文件`../data/gold.utf8`,评估结果保存到`../data/prf_tmp.txt`中。 265 | 266 | ```py 267 | def main(): 268 | F = prf_score('../data/gold.utf8', '../data/gold.utf8', '../data/prf_tmp.txt', 15) 269 | ``` 270 | 271 | ## 附录 272 | 273 | 1. 分词语料库: https://pan.baidu.com/s/1EtXdhPR0lGF8c7tT8epn6Q 密码: yj9j 274 | 2. 已训练模型权重、配置及字典: https://pan.baidu.com/s/1_IK-e8CDrgaCn-jZqozKJA 提取码: grng -------------------------------------------------------------------------------- /dl_segmenter/custom/callbacks.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import keras.backend as K 4 | import numpy as np 5 | from keras.callbacks import Callback, ModelCheckpoint 6 | 7 | 8 | class HistoryCache: 9 | 10 | def __init__(self, his_len=10): 11 | self.history = [0] * his_len 12 | self.history_len = his_len 13 | self.cursor = 0 14 | self.len = 0 15 | 16 | def put(self, value): 17 | self.history[self.cursor] = value 18 | self.cursor += 1 19 | if self.cursor >= self.history_len: 20 | self.cursor = 0 21 | if self.len + 1 <= self.history_len: 22 | self.len += 1 23 | 24 | def mean(self): 25 | return np.array(self.history[0: self.len]).mean() 26 | 27 | 28 | class WatchScheduler(Callback): 29 | 30 | def __init__(self, schedule, min_lr, max_lr, watch="loss", watch_his_len=10): 31 | super().__init__() 32 | self.schedule = schedule 33 | self.watch = watch 34 | self.min_lr = min_lr 35 | self.max_lr = max_lr 36 | self.history_cache = HistoryCache(watch_his_len) 37 | 38 | def on_train_begin(self, logs=None): 39 | logs = logs or {} 40 | K.set_value(self.model.optimizer.lr, self.max_lr) 41 | 42 | def on_epoch_begin(self, epoch, logs=None): 43 | logs = logs or {} 44 | logs['lr'] = K.get_value(self.model.optimizer.lr) 45 | 46 | def on_epoch_end(self, epoch, logs=None): 47 | lr = float(K.get_value(self.model.optimizer.lr)) 48 | watch_value = logs.get(self.watch) 49 | if watch_value is None: 50 | raise ValueError(f"Watched value '{self.watch}' don't exist") 51 | 52 | if lr <= self.min_lr: 53 | return 54 | 55 | self.history_cache.put(watch_value) 56 | 57 | if watch_value > self.history_cache.mean(): 58 | lr = self.schedule(epoch, lr) 59 | print(f"Update learning rate: {lr}") 60 | K.set_value(self.model.optimizer.lr, lr) 61 | 62 | 63 | from keras.callbacks import Callback 64 | import matplotlib.pyplot as plt 65 | 66 | 67 | class 
LRFinder(Callback): 68 | ''' 69 | A simple callback for finding the optimal learning rate range for your model + dataset. 70 | 71 | # Usage 72 | ```python 73 | lr_finder = LRFinder(min_lr=1e-5, 74 | max_lr=1e-2, 75 | steps_per_epoch=np.ceil(epoch_size/batch_size), 76 | epochs=3) 77 | model.fit(X_train, Y_train, callbacks=[lr_finder]) 78 | 79 | lr_finder.plot_loss() 80 | ``` 81 | 82 | # Arguments 83 | min_lr: The lower bound of the learning rate range for the experiment. 84 | max_lr: The upper bound of the learning rate range for the experiment. 85 | steps_per_epoch: Number of mini-batches in the dataset. Calculated as `np.ceil(epoch_size/batch_size)`. 86 | epochs: Number of epochs to run experiment. Usually between 2 and 4 epochs is sufficient. 87 | 88 | # References 89 | Blog post: jeremyjordan.me/nn-learning-rate 90 | Original paper: https://arxiv.org/abs/1506.01186 91 | ''' 92 | 93 | def __init__(self, min_lr=1e-5, max_lr=1e-2, steps_per_epoch=None, epochs=None): 94 | super().__init__() 95 | 96 | self.min_lr = min_lr 97 | self.max_lr = max_lr 98 | self.total_iterations = steps_per_epoch * epochs 99 | self.iteration = 0 100 | self.history = {} 101 | 102 | def clr(self): 103 | '''Calculate the learning rate.''' 104 | x = self.iteration / self.total_iterations 105 | return self.min_lr + (self.max_lr - self.min_lr) * x 106 | 107 | def on_train_begin(self, logs=None): 108 | '''Initialize the learning rate to the minimum value at the start of training.''' 109 | logs = logs or {} 110 | K.set_value(self.model.optimizer.lr, self.min_lr) 111 | 112 | def on_batch_end(self, epoch, logs=None): 113 | '''Record previous batch statistics and update the learning rate.''' 114 | logs = logs or {} 115 | self.iteration += 1 116 | 117 | self.history.setdefault('lr', []).append(K.get_value(self.model.optimizer.lr)) 118 | self.history.setdefault('iterations', []).append(self.iteration) 119 | 120 | for k, v in logs.items(): 121 | self.history.setdefault(k, []).append(v) 122 | 123 | K.set_value(self.model.optimizer.lr, self.clr()) 124 | 125 | def plot_lr(self): 126 | '''Helper function to quickly inspect the learning rate schedule.''' 127 | plt.plot(self.history['iterations'], self.history['lr']) 128 | plt.yscale('log') 129 | plt.xlabel('Iteration') 130 | plt.ylabel('Learning rate') 131 | plt.show() 132 | 133 | def plot_loss(self): 134 | '''Helper function to quickly observe the learning rate experiment results.''' 135 | plt.plot(self.history['lr'], self.history['loss']) 136 | plt.xscale('log') 137 | plt.xlabel('Learning rate') 138 | plt.ylabel('Loss') 139 | plt.show() 140 | 141 | 142 | class SGDRScheduler(Callback): 143 | '''Cosine annealing learning rate scheduler with periodic restarts. 144 | # Usage 145 | ```python 146 | schedule = SGDRScheduler(min_lr=1e-5, 147 | max_lr=1e-2, 148 | steps_per_epoch=np.ceil(epoch_size/batch_size), 149 | lr_decay=0.9, 150 | cycle_length=5, 151 | mult_factor=1.5) 152 | model.fit(X_train, Y_train, epochs=100, callbacks=[schedule]) 153 | ``` 154 | # Arguments 155 | min_lr: The lower bound of the learning rate range for the experiment. 156 | max_lr: The upper bound of the learning rate range for the experiment. 157 | steps_per_epoch: Number of mini-batches in the dataset. Calculated as `np.ceil(epoch_size/batch_size)`. 158 | lr_decay: Reduce the max_lr after the completion of each cycle. 159 | Ex. To reduce the max_lr by 20% after each cycle, set this value to 0.8. 160 | cycle_length: Initial number of epochs in a cycle. 
161 | mult_factor: Scale epochs_to_restart after each full cycle completion. 162 | initial_epoch: Used to resume training, **note**: Other args must be same as last training. 163 | # References 164 | Blog post: jeremyjordan.me/nn-learning-rate 165 | Original paper: http://arxiv.org/abs/1608.03983 166 | ''' 167 | 168 | def __init__(self, 169 | min_lr, 170 | max_lr, 171 | steps_per_epoch, 172 | lr_decay=1, 173 | cycle_length=10, 174 | mult_factor=2, 175 | initial_epoch=0): 176 | 177 | self.min_lr = min_lr 178 | self.max_lr = max_lr 179 | self.lr_decay = lr_decay 180 | 181 | self.batch_since_restart = 0 182 | self.next_restart = cycle_length 183 | 184 | self.steps_per_epoch = steps_per_epoch 185 | 186 | self.cycle_length = cycle_length 187 | self.mult_factor = mult_factor 188 | 189 | self.history = {} 190 | 191 | self.recovery_status(initial_epoch) 192 | 193 | def recovery_status(self, initial_epoch): 194 | # Return to the last state when it was stopped. 195 | if initial_epoch < self.cycle_length: 196 | num_cycles = 0 197 | else: 198 | ratio = initial_epoch / self.cycle_length 199 | 200 | num_cycles = 0 201 | while ratio > 0: 202 | ratio -= self.mult_factor ** num_cycles 203 | num_cycles += 1 204 | 205 | # If haven't done 206 | if ratio < 0: 207 | num_cycles -= 1 208 | 209 | done_epochs = 0 210 | for _ in range(num_cycles): 211 | self.max_lr *= self.lr_decay 212 | done_epochs += self.cycle_length 213 | self.cycle_length = np.ceil(self.cycle_length * self.mult_factor) 214 | 215 | self.batch_since_restart = (initial_epoch - done_epochs) * self.steps_per_epoch 216 | 217 | def clr(self): 218 | '''Calculate the learning rate.''' 219 | fraction_to_restart = self.batch_since_restart / (self.steps_per_epoch * self.cycle_length) 220 | lr = self.min_lr + 0.5 * (self.max_lr - self.min_lr) * (1 + np.cos(fraction_to_restart * np.pi)) 221 | return lr 222 | 223 | def on_train_begin(self, logs=None): 224 | '''Initialize the learning rate to the minimum value at the start of training.''' 225 | logs = logs or {} 226 | K.set_value(self.model.optimizer.lr, self.max_lr) 227 | 228 | def on_batch_end(self, batch, logs=None): 229 | '''Record previous batch statistics and update the learning rate.''' 230 | logs = logs or {} 231 | self.history.setdefault('lr', []).append(K.get_value(self.model.optimizer.lr)) 232 | for k, v in logs.items(): 233 | self.history.setdefault(k, []).append(v) 234 | 235 | self.batch_since_restart += 1 236 | K.set_value(self.model.optimizer.lr, self.clr()) 237 | 238 | def on_epoch_end(self, epoch, logs=None): 239 | '''Check for end of current cycle, apply restarts when necessary.''' 240 | if epoch + 1 == self.next_restart: 241 | self.batch_since_restart = 0 242 | self.cycle_length = np.ceil(self.cycle_length * self.mult_factor) 243 | self.next_restart += self.cycle_length 244 | self.max_lr *= self.lr_decay 245 | self.best_weights = self.model.get_weights() 246 | 247 | def on_train_end(self, logs=None): 248 | '''Set weights to the values from the end of the most recent cycle for best performance.''' 249 | self.model.set_weights(self.best_weights) 250 | 251 | 252 | class LRSchedulerPerStep(Callback): 253 | def __init__(self, d_model, warmup=4000, initial_epoch=0, steps_per_epoch=None): 254 | """ 255 | learning rate decay strategy in "Attention is all you need" https://arxiv.org/abs/1706.03762 256 | :param d_model: model dimension 257 | :param warmup: warm up steps 258 | :param initial_epoch: Used to resume training, 259 | **note**: Other args must be same as last training. 
260 | """ 261 | self.basic = d_model ** -0.5 262 | self.warm = warmup ** -1.5 263 | self.step_num = 0 264 | if initial_epoch > 0: 265 | assert steps_per_epoch is not None 266 | self.step_num = initial_epoch * steps_per_epoch 267 | 268 | def on_batch_begin(self, batch, logs=None): 269 | self.step_num += 1 270 | lr = self.basic * min(self.step_num ** -0.5, self.step_num * self.warm) 271 | K.set_value(self.model.optimizer.lr, lr) 272 | 273 | 274 | class SingleModelCK(ModelCheckpoint): 275 | """ 276 | Works around the case where weights saved while training on multiple GPUs cannot be applied to a single-GPU model. 277 | """ 278 | 279 | def __init__(self, filepath, model, monitor='val_loss', verbose=0, 280 | save_best_only=False, save_weights_only=False, 281 | mode='auto', period=1): 282 | super().__init__(filepath=filepath, monitor=monitor, verbose=verbose, 283 | save_weights_only=save_weights_only, 284 | save_best_only=save_best_only, 285 | mode=mode, period=period) 286 | self.model = model 287 | 288 | def set_model(self, model): 289 | pass 290 | -------------------------------------------------------------------------------- /LICENSE.txt: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 
47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [2018] [manlier] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. --------------------------------------------------------------------------------
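
For quick reference, the following is a minimal, hypothetical sketch of how the learning-rate callbacks defined in `dl_segmenter/custom/callbacks.py` might be attached to an ordinary Keras training run. The toy data, the tiny model, and every hyper-parameter value here are illustrative assumptions, not part of the repository; only the callback names and constructor arguments are taken from that file, and an older Keras 2.x setup (where `model.optimizer.lr` is a backend variable) is assumed.

```python
# Hypothetical usage sketch (not shipped with the repository):
# 1) sweep the learning rate with LRFinder, 2) train with SGDR warm restarts.
import numpy as np
from keras.models import Sequential
from keras.layers import Dense

from dl_segmenter.custom.callbacks import LRFinder, SGDRScheduler

# Toy data and model, purely for illustration.
x = np.random.rand(1024, 16)
y = np.random.randint(0, 2, size=(1024, 1))

model = Sequential([Dense(32, activation="relu", input_shape=(16,)),
                    Dense(1, activation="sigmoid")])
model.compile(optimizer="sgd", loss="binary_crossentropy")

batch_size = 32
steps_per_epoch = int(np.ceil(len(x) / batch_size))

# Step 1: linearly ramp the learning rate over a few epochs.
lr_finder = LRFinder(min_lr=1e-5, max_lr=1e-2,
                     steps_per_epoch=steps_per_epoch, epochs=3)
model.fit(x, y, batch_size=batch_size, epochs=3, callbacks=[lr_finder])
# lr_finder.plot_loss()  # inspect the loss-vs-lr curve to pick a range

# Step 2: cosine annealing with warm restarts over the chosen range.
schedule = SGDRScheduler(min_lr=1e-5, max_lr=1e-2,
                         steps_per_epoch=steps_per_epoch,
                         lr_decay=0.9, cycle_length=5, mult_factor=1.5)
model.fit(x, y, batch_size=batch_size, epochs=20, callbacks=[schedule])
```

The `LRFinder` sweep is normally run once to choose sensible `min_lr`/`max_lr` bounds; `SGDRScheduler` then reuses that range and, at the end of training, restores the weights captured at the most recent restart.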