├── pkuseg ├── dicts │ ├── __init__.py │ └── default.pkl ├── models │ └── tagIndex.txt ├── gradient.py ├── postag │ ├── __init__.py │ ├── model.py │ └── feature_extractor.pyx ├── scorer.py ├── model.py ├── download.py ├── res_summarize.py ├── optimizer.py ├── config.py ├── inference.pyx ├── data.py ├── trainer.py ├── __init__.py └── feature_extractor.pyx ├── requirements.txt ├── readme ├── history.md ├── comparison.md ├── environment.md ├── multiprocess.md ├── interface.md └── readme_english.md ├── .gitignore ├── tags.txt ├── LICENSE ├── example.txt ├── setup.py └── README.md /pkuseg/dicts/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | numpy>=1.16.0 2 | cython 3 | -------------------------------------------------------------------------------- /pkuseg/models/tagIndex.txt: -------------------------------------------------------------------------------- 1 | B 0 2 | B_single 1 3 | I 2 4 | I_end 3 5 | I_first 4 6 | -------------------------------------------------------------------------------- /readme/history.md: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lancopku/pkuseg-python/HEAD/readme/history.md -------------------------------------------------------------------------------- /readme/comparison.md: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lancopku/pkuseg-python/HEAD/readme/comparison.md -------------------------------------------------------------------------------- /pkuseg/dicts/default.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lancopku/pkuseg-python/HEAD/pkuseg/dicts/default.pkl -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .vscode/ 2 | *.cpp 3 | *.pyd 4 | *.html 5 | **/__pycache__/ 6 | *.egg-info 7 | *.txt* 8 | build 9 | tmp/ 10 | data/ 11 | models/ 12 | *stats -------------------------------------------------------------------------------- /tags.txt: -------------------------------------------------------------------------------- 1 | n 名词 2 | t 时间词 3 | s 处所词 4 | f 方位词 5 | m 数词 6 | q 量词 7 | b 区别词 8 | r 代词 9 | v 动词 10 | a 形容词 11 | z 状态词 12 | d 副词 13 | p 介词 14 | c 连词 15 | u 助词 16 | y 语气词 17 | e 叹词 18 | o 拟声词 19 | i 成语 20 | l 习惯用语 21 | j 简称 22 | h 前接成分 23 | k 后接成分 24 | g 语素 25 | x 非语素字 26 | w 标点符号 27 | nr 人名 28 | ns 地名 29 | nt 机构名称 30 | nx 外文字符 31 | nz 其它专名 32 | vd 副动词 33 | vn 名动词 34 | vx 形式动词 35 | ad 副形词 36 | an 名形词 37 | -------------------------------------------------------------------------------- /readme/environment.md: -------------------------------------------------------------------------------- 1 | # 实验环境 2 | 3 | 考虑到jieba分词和THULAC工具包等并没有提供细领域的预训练模型,为了便于比较,我们重新使用它们提供的训练接口在细领域的数据集上进行训练,用训练得到的模型进行中文分词。 4 | 5 | 我们选择Linux作为测试环境,在新闻数据(MSRA)、混合型文本(CTB8)、网络文本(WEIBO)数据上对不同工具包进行了准确率测试。我们使用了第二届国际汉语分词评测比赛提供的分词评价脚本。其中MSRA与WEIBO使用标准训练集测试集划分,CTB8采用随机划分。对于不同的分词工具包,训练测试数据的划分都是一致的;**即所有的分词工具包都在相同的训练集上训练,在相同的测试集上测试**。对于所有数据集,pkuseg使用了不使用词典的训练和测试接口。以下是pkuseg训练和测试代码示例: 6 | 7 | ``` 8 | pkuseg.train('msr_training.utf8', 'msr_test_gold.utf8', './models') 9 | pkuseg.test('msr_test.raw', 'output.txt', user_dict=None) 10 | 
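# user_dict=None 表示不使用词典,与上文所述"不使用词典的训练和测试接口"的评测设置一致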
``` 11 | 12 | 13 | -------------------------------------------------------------------------------- /readme/multiprocess.md: -------------------------------------------------------------------------------- 1 | 2 | # 多进程分词 3 | 4 | 当将以上代码示例置于文件中运行时,如涉及多进程功能,请务必使用`if __name__ == '__main__'`保护全局语句,如: 5 | mp.py文件 6 | ```python3 7 | import pkuseg 8 | 9 | if __name__ == '__main__': 10 | pkuseg.test('input.txt', 'output.txt', nthread=20) 11 | pkuseg.train('msr_training.utf8', 'msr_test_gold.utf8', './models', nthread=20) 12 | ``` 13 | 运行 14 | ``` 15 | python3 mp.py 16 | ``` 17 | 详见[无法使用多进程分词和训练功能,提示RuntimeError和BrokenPipeError](https://github.com/lancopku/pkuseg-python/wiki#3-无法使用多进程分词和训练功能提示runtimeerror和brokenpipeerror)。 18 | 19 | **在Windows平台上,请当文件足够大时再使用多进程分词功能**,详见[关于多进程速度问题](https://github.com/lancopku/pkuseg-python/wiki#9-关于多进程速度问题)。 20 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2018-2019 pkuseg authors 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /example.txt: -------------------------------------------------------------------------------- 1 | 2 | 3 | 以下为使用细领域模型的切分结果。 4 | 在使用中,如果用户明确待分词的领域,可加载对应的模型进行分词。如果用户无法确定具体领域,推荐使用在混合领域上训练的通用模型。 5 | 6 | 7 | 8 | 医药领域分词示例: 9 | 10 | 医生 工具 通常 包括 病历 管理 、 药品 信息 查询 、 临床 指南 、 前沿 的 医学 资讯 。 11 | 医联 平台 : 包括 挂号 预约 查看 院内 信息 化验单 等 , 目前 出现 与 微信 、 支付宝 结合的 趋势 。 12 | 甲状腺功能减退症 简称 甲减 , 是 甲状腺 制造 的 甲状腺激素 过少 而 引发 的疾病 。 13 | 14 | 15 | 16 | 旅游领域分词示例: 17 | 18 | 在 这里 可以 俯瞰 维多利亚港 的 香港岛 , 九龙 半岛 两岸 , 美景 无敌 。 19 | 以往 去 香港 都 是 去 旺角 尖沙咀 中环 等等 闹市 地区 。 20 | 初 至 重庆 , 我 就 来到 了 洪崖洞 , 在 这里 , 旧时 城墙 、 吊脚楼 仿 若 镶嵌 在 现代 钢筋 水泥 城市 间 的 一 枚 朴玉 。 21 | 首都 机场 提供 了 手机 值机 、 自助 值机 、 自助 行李 托运 、 自助 通关 等 多种 便捷 举措 。 22 | 23 | 24 | 25 | 网络领域分词示例: 26 | 27 | 视频 中 , 胡可 负责 录制 , 沙溢 则 带 着 安吉 和 小鱼儿 坐在 沙发 上 唱 着 《 学猫 叫 》 , 小鱼儿 还 争 着 要 坐在 C位 , 一家人 其乐融融 28 | 【 这是 我 的 世界 , 你 还 未 见 过 】 欢迎 来 参加 我 的 演唱会 听点 音乐 29 | 被 全家 套路 的 小鱼儿 也 太 可怜 了 : 我 要求 C位 ! ! 
我 不要 唱 “ 喵喵喵 ” 结果 七 秒 记忆 又 继续 唱 了 起来 哈哈 哈哈 哈哈 哈哈 30 | 31 | 32 | 33 | 新闻领域分词示例: 34 | 35 | 乌克兰 一直 想 加入 北约 , 并 不断 的 按照 西方 国家 的 要求 “ 改造 ” 自己 , 据 乌克兰 之 声 2月20日 报道 称 , 乌克兰 政府 正式 通过 最新 《 宪法 修正案 》 , 正式 确定 乌克兰 将 加入 北约 作为 重要 国家 方针 , 该 法 强调 , " 这项 法律 将 于 发布 次日 起 生效 " 。 36 | 美国广播公司 网站 2月20日 报道 称 , 特朗普 19日 在 推特 上 写 道 : “ 正如 我 预测 的 那样 , 主要 由 开放 边界 的 民主党 人和 激进 左派 主导 的 16个 州 已经 在 第九巡回法院 提起 诉讼 。 ” 他 不 忘 讽刺 加州 : “ 加州 已 在 失控 的 高铁 项目 上 浪费 了 数十亿美元 , 完全 没有 完成 的 希望 。 37 | 38 | 39 | 40 | 41 | 42 | 43 | 44 | -------------------------------------------------------------------------------- /pkuseg/gradient.py: -------------------------------------------------------------------------------- 1 | import pkuseg.model 2 | from typing import List 3 | 4 | import pkuseg.inference as _inf 5 | import pkuseg.data 6 | 7 | 8 | def get_grad_SGD_minibatch( 9 | grad: List[float], model: pkuseg.model.Model, X: List[pkuseg.data.Example] 10 | ): 11 | # if idset is not None: 12 | # idset.clear() 13 | all_id_set = set() 14 | errors = 0 15 | for x in X: 16 | error, id_set = get_grad_CRF(grad, model, x) 17 | errors += error 18 | all_id_set.update(id_set) 19 | 20 | return errors, all_id_set 21 | 22 | 23 | def get_grad_CRF( 24 | grad: List[float], model: pkuseg.model.Model, x: pkuseg.data.Example 25 | ): 26 | 27 | id_set = set() 28 | 29 | n_tag = model.n_tag 30 | bel = _inf.belief(len(x), n_tag) 31 | belMasked = _inf.belief(len(x), n_tag) 32 | 33 | Ylist, YYlist, maskYlist, maskYYlist = _inf.getYYandY(model, x) 34 | Z, sum_edge = _inf.get_beliefs(bel, model, x, Ylist, YYlist) 35 | ZGold, sum_edge_masked = _inf.get_beliefs(belMasked, model, x, maskYlist, maskYYlist) 36 | 37 | for i, node_feature_list in enumerate(x.features): 38 | for feature_id in node_feature_list: 39 | trans_id = model._get_node_tag_feature_id(feature_id, 0) 40 | id_set.update(range(trans_id, trans_id + n_tag)) 41 | grad[trans_id:trans_id+n_tag] += bel.belState[i] - belMasked.belState[i] 42 | 43 | backoff = model.n_feature * n_tag 44 | grad[backoff:] += sum_edge - sum_edge_masked 45 | id_set.update(range(backoff, backoff + n_tag * n_tag)) 46 | 47 | return Z - ZGold, id_set 48 | -------------------------------------------------------------------------------- /pkuseg/postag/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import sys 3 | 4 | import os 5 | import time 6 | 7 | from ..inference import decodeViterbi_fast 8 | 9 | 10 | from .feature_extractor import FeatureExtractor 11 | from .model import Model 12 | 13 | class Postag: 14 | def __init__(self, model_name): 15 | modelDir = model_name 16 | self.feature_extractor = FeatureExtractor.load(modelDir) 17 | self.model = Model.load(modelDir) 18 | 19 | self.idx_to_tag = { 20 | idx: tag for tag, idx in self.feature_extractor.tag_to_idx.items() 21 | } 22 | 23 | self.n_feature = len(self.feature_extractor.feature_to_idx) 24 | self.n_tag = len(self.feature_extractor.tag_to_idx) 25 | 26 | # print("finish") 27 | 28 | def _cut(self, text): 29 | examples = list(self.feature_extractor.normalize_text(text)) 30 | length = len(examples) 31 | 32 | all_feature = [] # type: List[List[int]] 33 | for idx in range(length): 34 | node_feature_idx = self.feature_extractor.get_node_features_idx( 35 | idx, examples 36 | ) 37 | # node_feature = self.feature_extractor.get_node_features( 38 | # idx, examples 39 | # ) 40 | 41 | # node_feature_idx = [] 42 | # for feature in node_feature: 43 | # feature_idx = self.feature_extractor.feature_to_idx.get(feature) 44 | 
# if feature_idx is not None: 45 | # node_feature_idx.append(feature_idx) 46 | # if not node_feature_idx: 47 | # node_feature_idx.append(0) 48 | 49 | all_feature.append(node_feature_idx) 50 | 51 | _, tags = decodeViterbi_fast(all_feature, self.model) 52 | tags = list(map(lambda x:self.idx_to_tag[x], tags)) 53 | return tags 54 | 55 | def tag(self, sen): 56 | """txt: list[str], tags: list[str]""" 57 | tags = self._cut(sen) 58 | return tags 59 | 60 | -------------------------------------------------------------------------------- /readme/interface.md: -------------------------------------------------------------------------------- 1 | # 代码示例 2 | 3 | 以下代码示例适用于python交互式环境。 4 | 5 | 代码示例1:使用默认配置进行分词(**如果用户无法确定分词领域,推荐使用默认模型分词**) 6 | ```python3 7 | import pkuseg 8 | 9 | seg = pkuseg.pkuseg() # 以默认配置加载模型 10 | text = seg.cut('我爱北京天安门') # 进行分词 11 | print(text) 12 | ``` 13 | 14 | 代码示例2:细领域分词(**如果用户明确分词领域,推荐使用细领域模型分词**) 15 | ```python3 16 | import pkuseg 17 | 18 | seg = pkuseg.pkuseg(model_name='medicine') # 程序会自动下载所对应的细领域模型 19 | text = seg.cut('我爱北京天安门') # 进行分词 20 | print(text) 21 | ``` 22 | 23 | 代码示例3:分词同时进行词性标注,各词性标签的详细含义可参考 [tags.txt](https://github.com/lancopku/pkuseg-python/blob/master/tags.txt) 24 | ```python3 25 | import pkuseg 26 | 27 | seg = pkuseg.pkuseg(postag=True) # 开启词性标注功能 28 | text = seg.cut('我爱北京天安门') # 进行分词和词性标注 29 | print(text) 30 | ``` 31 | 32 | 33 | 代码示例4:对文件分词 34 | ```python3 35 | import pkuseg 36 | 37 | # 对input.txt的文件分词输出到output.txt中 38 | # 开20个进程 39 | pkuseg.test('input.txt', 'output.txt', nthread=20) 40 | ``` 41 | 42 | 43 | 代码示例5:额外使用用户自定义词典 44 | ```python3 45 | import pkuseg 46 | 47 | seg = pkuseg.pkuseg(user_dict='my_dict.txt') # 给定用户词典为当前目录下的"my_dict.txt" 48 | text = seg.cut('我爱北京天安门') # 进行分词 49 | print(text) 50 | ``` 51 | 52 | 53 | 代码示例6:使用自训练模型分词(以CTB8模型为例) 54 | ```python3 55 | import pkuseg 56 | 57 | seg = pkuseg.pkuseg(model_name='./ctb8') # 假设用户已经下载好了ctb8的模型并放在了'./ctb8'目录下,通过设置model_name加载该模型 58 | text = seg.cut('我爱北京天安门') # 进行分词 59 | print(text) 60 | ``` 61 | 62 | 63 | 64 | 代码示例7:训练新模型 (模型随机初始化) 65 | ```python3 66 | import pkuseg 67 | 68 | # 训练文件为'msr_training.utf8' 69 | # 测试文件为'msr_test_gold.utf8' 70 | # 训练好的模型存到'./models'目录下 71 | # 训练模式下会保存最后一轮模型作为最终模型 72 | # 目前仅支持utf-8编码,训练集和测试集要求所有单词以单个或多个空格分开 73 | pkuseg.train('msr_training.utf8', 'msr_test_gold.utf8', './models') 74 | ``` 75 | 76 | 77 | 代码示例8:fine-tune训练(从预加载的模型继续训练) 78 | ```python3 79 | import pkuseg 80 | 81 | # 训练文件为'train.txt' 82 | # 测试文件为'test.txt' 83 | # 加载'./pretrained'目录下的模型,训练好的模型保存在'./models',训练10轮 84 | pkuseg.train('train.txt', 'test.txt', './models', train_iter=10, init_model='./pretrained') 85 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | import setuptools 2 | import os 3 | from distutils.extension import Extension 4 | 5 | import numpy as np 6 | 7 | def is_source_release(path): 8 | return os.path.exists(os.path.join(path, "PKG-INFO")) 9 | 10 | def setup_package(): 11 | root = os.path.abspath(os.path.dirname(__file__)) 12 | 13 | long_description = "pkuseg-python" 14 | 15 | extensions = [ 16 | Extension( 17 | "pkuseg.inference", 18 | ["pkuseg/inference.pyx"], 19 | include_dirs=[np.get_include()], 20 | language="c++" 21 | ), 22 | Extension( 23 | "pkuseg.feature_extractor", 24 | ["pkuseg/feature_extractor.pyx"], 25 | include_dirs=[np.get_include()], 26 | ), 27 | Extension( 28 | "pkuseg.postag.feature_extractor", 29 | ["pkuseg/postag/feature_extractor.pyx"], 30 | 
include_dirs=[np.get_include()], 31 | ), 32 | ] 33 | 34 | if not is_source_release(root): 35 | from Cython.Build import cythonize 36 | extensions = cythonize(extensions, annotate=True) 37 | 38 | 39 | setuptools.setup( 40 | name="pkuseg", 41 | version="0.0.25", 42 | author="Lanco", 43 | author_email="luoruixuan97@pku.edu.cn", 44 | description="A small package for Chinese word segmentation", 45 | long_description=long_description, 46 | long_description_content_type="text/markdown", 47 | url="https://github.com/lancopku/pkuseg-python", 48 | packages=setuptools.find_packages(), 49 | package_data={"": ["*.txt*", "*.pkl", "*.npz", "*.pyx", "*.pxd"]}, 50 | classifiers=[ 51 | "Programming Language :: Python :: 3", 52 | "License :: Other/Proprietary License", 53 | "Operating System :: OS Independent", 54 | ], 55 | install_requires=["cython", "numpy>=1.16.0"], 56 | setup_requires=["cython", "numpy>=1.16.0"], 57 | ext_modules=extensions, 58 | zip_safe=False, 59 | ) 60 | 61 | 62 | if __name__ == "__main__": 63 | setup_package() 64 | -------------------------------------------------------------------------------- /pkuseg/postag/model.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import numpy as np 4 | 5 | class Model: 6 | def __init__(self, n_feature, n_tag): 7 | 8 | self.n_tag = n_tag 9 | self.n_feature = n_feature 10 | self.n_transition_feature = n_tag * (n_feature + n_tag) 11 | self.w = np.zeros(self.n_transition_feature) 12 | 13 | def _get_node_tag_feature_id(self, feature_id, tag_id): 14 | return feature_id * self.n_tag + tag_id 15 | 16 | def _get_tag_tag_feature_id(self, pre_tag_id, tag_id): 17 | return self.n_feature * self.n_tag + tag_id * self.n_tag + pre_tag_id 18 | 19 | @classmethod 20 | def load(cls, model_dir): 21 | model_path = os.path.join(model_dir, "weights.npz") 22 | if os.path.exists(model_path): 23 | npz = np.load(model_path) 24 | sizes = npz["sizes"] 25 | w = npz["w"] 26 | model = cls.__new__(cls) 27 | model.n_tag = int(sizes[0]) 28 | model.n_feature = int(sizes[1]) 29 | model.n_transition_feature = model.n_tag * ( 30 | model.n_feature + model.n_tag 31 | ) 32 | model.w = w 33 | assert model.w.shape[0] == model.n_transition_feature 34 | return model 35 | 36 | print( 37 | "WARNING: weights.npz does not exist, try loading using old format", 38 | file=sys.stderr, 39 | ) 40 | 41 | model_path = os.path.join(model_dir, "model.txt") 42 | with open(model_path, encoding="utf-8") as f: 43 | ary = f.readlines() 44 | 45 | model = cls.__new__(cls) 46 | model.n_tag = int(ary[0].strip()) 47 | wsize = int(ary[1].strip()) 48 | w = np.zeros(wsize) 49 | for i in range(2, wsize): 50 | w[i - 2] = float(ary[i].strip()) 51 | model.w = w 52 | model.n_feature = wsize // model.n_tag - model.n_tag 53 | model.n_transition_feature = wsize 54 | 55 | model.save(model_dir) 56 | return model 57 | 58 | @classmethod 59 | def new(cls, model, copy_weight=True): 60 | 61 | new_model = cls.__new__(cls) 62 | new_model.n_tag = model.n_tag 63 | if copy_weight: 64 | new_model.w = model.w.copy() 65 | else: 66 | new_model.w = np.zeros_like(model.w) 67 | new_model.n_feature = ( 68 | new_model.w.shape[0] // new_model.n_tag - new_model.n_tag 69 | ) 70 | new_model.n_transition_feature = new_model.w.shape[0] 71 | return new_model 72 | 73 | def save(self, model_dir): 74 | sizes = np.array([self.n_tag, self.n_feature]) 75 | np.savez( 76 | os.path.join(model_dir, "weights.npz"), sizes=sizes, w=self.w 77 | ) 78 | # np.save 79 | # with open(file, "w", 
encoding="utf-8") as f: 80 | # f.write("{}\n{}\n".format(self.n_tag, self.w.shape[0])) 81 | # for value in self.w: 82 | # f.write("{:.4f}\n".format(value)) 83 | -------------------------------------------------------------------------------- /pkuseg/scorer.py: -------------------------------------------------------------------------------- 1 | from pkuseg.config import Config 2 | 3 | 4 | def getFscore(goldTagList, resTagList, idx_to_chunk_tag): 5 | scoreList = [] 6 | assert len(resTagList) == len(goldTagList) 7 | getNewTagList(idx_to_chunk_tag, goldTagList) 8 | getNewTagList(idx_to_chunk_tag, resTagList) 9 | goldChunkList = getChunks(goldTagList) 10 | resChunkList = getChunks(resTagList) 11 | gold_chunk = 0 12 | res_chunk = 0 13 | correct_chunk = 0 14 | for i in range(len(goldChunkList)): 15 | res = resChunkList[i] 16 | gold = goldChunkList[i] 17 | resChunkAry = res.split(Config.comma) 18 | tmp = [] 19 | for t in resChunkAry: 20 | if len(t) > 0: 21 | tmp.append(t) 22 | resChunkAry = tmp 23 | goldChunkAry = gold.split(Config.comma) 24 | tmp = [] 25 | for t in goldChunkAry: 26 | if len(t) > 0: 27 | tmp.append(t) 28 | goldChunkAry = tmp 29 | gold_chunk += len(goldChunkAry) 30 | res_chunk += len(resChunkAry) 31 | goldChunkSet = set() 32 | for im in goldChunkAry: 33 | goldChunkSet.add(im) 34 | for im in resChunkAry: 35 | if im in goldChunkSet: 36 | correct_chunk += 1 37 | pre = correct_chunk / res_chunk * 100 38 | rec = correct_chunk / gold_chunk * 100 39 | f1 = 0 if correct_chunk == 0 else 2 * pre * rec / (pre + rec) 40 | scoreList.append(f1) 41 | scoreList.append(pre) 42 | scoreList.append(rec) 43 | infoList = [] 44 | infoList.append(gold_chunk) 45 | infoList.append(res_chunk) 46 | infoList.append(correct_chunk) 47 | return scoreList, infoList 48 | 49 | 50 | def getNewTagList(tagMap, tagList): 51 | tmpList = [] 52 | for im in tagList: 53 | tagAry = im.split(Config.comma) 54 | for i in range(len(tagAry)): 55 | if tagAry[i] == "": 56 | continue 57 | index = int(tagAry[i]) 58 | if not index in tagMap: 59 | raise Exception("Error") 60 | tagAry[i] = tagMap[index] 61 | newTags = ",".join(tagAry) 62 | tmpList.append(newTags) 63 | tagList.clear() 64 | for im in tmpList: 65 | tagList.append(im) 66 | 67 | 68 | def getChunks(tagList): 69 | tmpList = [] 70 | for im in tagList: 71 | tagAry = im.split(Config.comma) 72 | tmp = [] 73 | for t in tagAry: 74 | if t != "": 75 | tmp.append(t) 76 | tagAry = tmp 77 | chunks = "" 78 | for i in range(len(tagAry)): 79 | if tagAry[i].startswith("B"): 80 | pos = i 81 | length = 1 82 | ty = tagAry[i] 83 | for j in range(i + 1, len(tagAry)): 84 | if tagAry[j] == "I": 85 | length += 1 86 | else: 87 | break 88 | chunk = ty + "*" + str(length) + "*" + str(pos) 89 | chunks = chunks + chunk + "," 90 | tmpList.append(chunks) 91 | return tmpList 92 | -------------------------------------------------------------------------------- /pkuseg/model.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import numpy as np 4 | 5 | from .config import config 6 | 7 | 8 | class Model: 9 | def __init__(self, n_feature, n_tag): 10 | 11 | self.n_tag = n_tag 12 | self.n_feature = n_feature 13 | self.n_transition_feature = n_tag * (n_feature + n_tag) 14 | if config.random: 15 | self.w = np.random.random(size=(self.n_transition_feature,)) * 2 - 1 16 | else: 17 | self.w = np.zeros(self.n_transition_feature) 18 | 19 | def expand(self, n_feature, n_tag): 20 | new_transition_feature = n_tag * (n_feature + n_tag) 21 | if config.random: 
22 | new_w = np.random.random(size=(new_transition_feature,)) * 2 - 1 23 | else: 24 | new_w = np.zeros(new_transition_feature) 25 | n_node = self.n_tag * self.n_feature 26 | n_edge = self.n_tag * self.n_tag 27 | new_w[:n_node] = self.w[:n_node] 28 | new_w[-n_edge:] = self.w[-n_edge:] 29 | self.n_tag = n_tag 30 | self.n_feature = n_feature 31 | self.n_transition_feature = new_transition_feature 32 | self.w = new_w 33 | 34 | def _get_node_tag_feature_id(self, feature_id, tag_id): 35 | return feature_id * self.n_tag + tag_id 36 | 37 | def _get_tag_tag_feature_id(self, pre_tag_id, tag_id): 38 | return self.n_feature * self.n_tag + tag_id * self.n_tag + pre_tag_id 39 | 40 | @classmethod 41 | def load(cls, model_dir=None): 42 | if model_dir is None: 43 | model_dir = config.modelDir 44 | model_path = os.path.join(model_dir, "weights.npz") 45 | if os.path.exists(model_path): 46 | npz = np.load(model_path) 47 | sizes = npz["sizes"] 48 | w = npz["w"] 49 | model = cls.__new__(cls) 50 | model.n_tag = int(sizes[0]) 51 | model.n_feature = int(sizes[1]) 52 | model.n_transition_feature = model.n_tag * ( 53 | model.n_feature + model.n_tag 54 | ) 55 | model.w = w 56 | assert model.w.shape[0] == model.n_transition_feature 57 | return model 58 | 59 | print( 60 | "WARNING: weights.npz does not exist, try loading using old format", 61 | file=sys.stderr, 62 | ) 63 | 64 | model_path = os.path.join(model_dir, "model.txt") 65 | with open(model_path, encoding="utf-8") as f: 66 | ary = f.readlines() 67 | 68 | model = cls.__new__(cls) 69 | model.n_tag = int(ary[0].strip()) 70 | wsize = int(ary[1].strip()) 71 | w = np.zeros(wsize) 72 | for i in range(2, wsize): 73 | w[i - 2] = float(ary[i].strip()) 74 | model.w = w 75 | model.n_feature = wsize // model.n_tag - model.n_tag 76 | model.n_transition_feature = wsize 77 | 78 | model.save(model_dir) 79 | return model 80 | 81 | @classmethod 82 | def new(cls, model, copy_weight=True): 83 | 84 | new_model = cls.__new__(cls) 85 | new_model.n_tag = model.n_tag 86 | if copy_weight: 87 | new_model.w = model.w.copy() 88 | else: 89 | new_model.w = np.zeros_like(model.w) 90 | new_model.n_feature = ( 91 | new_model.w.shape[0] // new_model.n_tag - new_model.n_tag 92 | ) 93 | new_model.n_transition_feature = new_model.w.shape[0] 94 | return new_model 95 | 96 | def save(self, model_dir=None): 97 | if model_dir is None: 98 | model_dir = config.modelDir 99 | sizes = np.array([self.n_tag, self.n_feature]) 100 | np.savez( 101 | os.path.join(model_dir, "weights.npz"), sizes=sizes, w=self.w 102 | ) 103 | # np.save 104 | # with open(file, "w", encoding="utf-8") as f: 105 | # f.write("{}\n{}\n".format(self.n_tag, self.w.shape[0])) 106 | # for value in self.w: 107 | # f.write("{:.4f}\n".format(value)) 108 | -------------------------------------------------------------------------------- /pkuseg/download.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import, division, print_function, unicode_literals 2 | 3 | import hashlib 4 | import os 5 | import re 6 | import shutil 7 | import sys 8 | import tempfile 9 | import zipfile 10 | 11 | try: 12 | from requests.utils import urlparse 13 | from requests import get as urlopen 14 | requests_available = True 15 | except ImportError: 16 | requests_available = False 17 | if sys.version_info[0] == 2: 18 | from urlparse import urlparse # noqa f811 19 | from urllib2 import urlopen # noqa f811 20 | else: 21 | from urllib.request import urlopen 22 | from urllib.parse import urlparse 23 | try: 24 
| from tqdm import tqdm 25 | except ImportError: 26 | tqdm = None # defined below 27 | 28 | HASH_REGEX = re.compile(r'-([a-f0-9]*)\.') 29 | 30 | def download_model(url, model_dir, hash_prefix, progress=True): 31 | if not os.path.exists(model_dir): 32 | os.makedirs(model_dir) 33 | parts = urlparse(url) 34 | filename = os.path.basename(parts.path) 35 | cached_file = os.path.join(model_dir, filename) 36 | if not os.path.exists(cached_file): 37 | sys.stderr.write('Downloading: "{}" to {}\n'.format(url, cached_file)) 38 | _download_url_to_file(url, cached_file, hash_prefix, progress=progress) 39 | unzip_file(cached_file, os.path.join(model_dir, filename.split('.')[0])) 40 | 41 | 42 | def _download_url_to_file(url, dst, hash_prefix, progress): 43 | if requests_available: 44 | u = urlopen(url, stream=True, timeout=5) 45 | file_size = int(u.headers["Content-Length"]) 46 | u = u.raw 47 | else: 48 | u = urlopen(url, timeout=5) 49 | meta = u.info() 50 | if hasattr(meta, 'getheaders'): 51 | file_size = int(meta.getheaders("Content-Length")[0]) 52 | else: 53 | file_size = int(meta.get_all("Content-Length")[0]) 54 | 55 | f = tempfile.NamedTemporaryFile(delete=False) 56 | try: 57 | if hash_prefix is not None: 58 | sha256 = hashlib.sha256() 59 | with tqdm(total=file_size, disable=not progress) as pbar: 60 | while True: 61 | buffer = u.read(8192) 62 | if len(buffer) == 0: 63 | break 64 | f.write(buffer) 65 | if hash_prefix is not None: 66 | sha256.update(buffer) 67 | pbar.update(len(buffer)) 68 | 69 | f.close() 70 | if hash_prefix is not None: 71 | digest = sha256.hexdigest() 72 | if digest[:len(hash_prefix)] != hash_prefix: 73 | raise RuntimeError('invalid hash value (expected "{}", got "{}")' 74 | .format(hash_prefix, digest)) 75 | shutil.move(f.name, dst) 76 | finally: 77 | f.close() 78 | if os.path.exists(f.name): 79 | os.remove(f.name) 80 | 81 | 82 | if tqdm is None: 83 | # fake tqdm if it's not installed 84 | class tqdm(object): 85 | 86 | def __init__(self, total, disable=False): 87 | self.total = total 88 | self.disable = disable 89 | self.n = 0 90 | 91 | def update(self, n): 92 | if self.disable: 93 | return 94 | 95 | self.n += n 96 | sys.stderr.write("\r{0:.1f}%".format(100 * self.n / float(self.total))) 97 | sys.stderr.flush() 98 | 99 | def __enter__(self): 100 | return self 101 | 102 | def __exit__(self, exc_type, exc_val, exc_tb): 103 | if self.disable: 104 | return 105 | 106 | sys.stderr.write('\n') 107 | 108 | def unzip_file(zip_name, target_dir): 109 | if not os.path.exists(target_dir): 110 | os.makedirs(target_dir) 111 | file_zip = zipfile.ZipFile(zip_name, 'r') 112 | for file in file_zip.namelist(): 113 | file_zip.extract(file, target_dir) 114 | file_zip.close() 115 | 116 | if __name__ == '__main__': 117 | url = 'https://github.com/lancopku/pkuseg-python/releases/download/v0.0.14/mixed.zip' 118 | download_model(url, '.') 119 | 120 | -------------------------------------------------------------------------------- /pkuseg/res_summarize.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from .config import Config 3 | import os 4 | 5 | 6 | def tomatrix(s): 7 | lines = s.split(Config.lineEnd) 8 | lst = [] 9 | for line in lines: 10 | if line == "": 11 | continue 12 | if not line.startswith("%"): 13 | tmp = [] 14 | for i in line.split(Config.comma): 15 | tmp.append(float(i)) 16 | lst.append(tmp) 17 | return np.array(lst) 18 | 19 | 20 | def summarize(config): 21 | with open( 22 | os.path.join(config.outDir, config.fResRaw), 
encoding="utf-8" 23 | ) as sr: 24 | txt = sr.read() 25 | txt = txt.replace("\r", "") 26 | regions = txt.split(config.triLineEnd) 27 | 28 | with open( 29 | os.path.join(config.outDir, config.fResSum), "w", encoding="utf-8" 30 | ) as sw: 31 | for region in regions: 32 | if region == "": 33 | continue 34 | 35 | blocks = region.split(config.biLineEnd) 36 | mList = [] 37 | for im in blocks: 38 | mList.append(tomatrix(im)) 39 | 40 | avgM = np.zeros_like(mList[0]) 41 | for m in mList: 42 | avgM = avgM + m 43 | avgM = avgM / len(mList) 44 | 45 | sqravgM = np.zeros_like(mList[0]) 46 | for m in mList: 47 | sqravgM += m * m 48 | sqravgM = sqravgM / len(mList) 49 | 50 | deviM = (sqravgM - avgM * avgM) ** 0.5 51 | 52 | sw.write("%averaged values:\n") 53 | for i in range(avgM.shape[0]): 54 | for j in range(avgM.shape[1]): 55 | sw.write("{:.2f},".format(avgM[i, j])) 56 | sw.write("\n") 57 | 58 | sw.write("\n%deviations:\n") 59 | for i in range(deviM.shape[0]): 60 | for j in range(deviM.shape[1]): 61 | sw.write("{:.2f},".format(deviM[i, j])) 62 | # sw.write(("%.2f" % deviM[i, j]) + ",") 63 | sw.write("\n") 64 | 65 | sw.write("\n%avg & devi:\n") 66 | for i in range(avgM.shape[0]): 67 | for j in range(avgM.shape[1]): 68 | sw.write("{:.2f}+-{:,2f},".format(avgM[i, j], deviM[i, j])) 69 | sw.write("\n") 70 | 71 | sw.write("%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%\n\n\n") 72 | 73 | 74 | def write(config, timeList, errList, diffList, scoreListList): 75 | def log(message): 76 | if config.rawResWrite: 77 | config.swResRaw.write(message) 78 | 79 | log("% training results:" + config.metric + "\n") 80 | for i in range(config.ttlIter): 81 | it = i 82 | log("% iter#={} ".format(it)) 83 | lst = scoreListList[i] 84 | if config.evalMetric == "f1": 85 | log( 86 | "% f-score={:.2f}% precision={:.2f}% recall={:.2f}% ".format( 87 | lst[0], lst[1], lst[2] 88 | ) 89 | ) 90 | else: 91 | log("% {}={:.2f}% ".format(config.metric, lst[0])) 92 | time = 0 93 | for k in range(i + 1): 94 | time += timeList[k] 95 | log( 96 | "cumulative-time(sec)={:.2f} objective={:.2f} diff={:.2f}\n".format( 97 | time, errList[i], diffList[i] 98 | ) 99 | ) 100 | 101 | # #ttlScore = 0 102 | # for i in range(config.ttlIter): 103 | # it = i + 1 104 | # log("% iter#={} ".format(it)) 105 | # lst = scoreListList[i] 106 | # # ttlScore += lst[0] 107 | # if config.evalMetric == "f1": 108 | # log( 109 | # "% f-score={:.2f}% precision={:.2f}% recall={:.2f}% ".format( 110 | # lst[0], lst[1], lst[2] 111 | # ) 112 | # ) 113 | # else: 114 | # log("% {}={:.2f}% ".format(config.metric, lst[0])) 115 | # time = 0 116 | # for k in range(i + 1): 117 | # time += timeList[k] 118 | # log( 119 | # "cumulative-time(sec)={:.2f} objective={:.2f} diff={:.2f}\n".format( 120 | # time, errList[i], diffList[i] 121 | # ) 122 | # ) 123 | -------------------------------------------------------------------------------- /pkuseg/optimizer.py: -------------------------------------------------------------------------------- 1 | import random 2 | 3 | import numpy as np 4 | import pkuseg.gradient as _grad 5 | 6 | # from pkuseg.config import config 7 | 8 | 9 | class Optimizer: 10 | def __init__(self): 11 | self._preVals = [] 12 | 13 | def converge_test(self, err): 14 | val = 1e100 15 | if len(self._preVals) > 1: 16 | prevVal = self._preVals[0] 17 | if len(self._preVals) == 10: 18 | self._preVals.pop(0) 19 | avgImprovement = (prevVal - err) / len(self._preVals) 20 | relAvg = avgImprovement / abs(err) 21 | val = relAvg 22 | self._preVals.append(err) 23 | return val 24 | 25 | def optimize(self): 26 | 
raise NotImplementedError() 27 | 28 | 29 | class ADF(Optimizer): 30 | def __init__(self, config, dataset, model): 31 | 32 | super().__init__() 33 | 34 | self.config = config 35 | 36 | self._model = model 37 | self._X = dataset 38 | self.decayList = np.ones_like(self._model.w) * config.rate0 39 | 40 | def optimize(self): 41 | config = self.config 42 | sample_size = 0 43 | w = self._model.w 44 | fsize = w.shape[0] 45 | xsize = len(self._X) 46 | grad = np.zeros(fsize) 47 | error = 0 48 | 49 | feature_count_list = np.zeros(fsize) 50 | # feature_count_list = [0] * fsize 51 | ri = list(range(xsize)) 52 | random.shuffle(ri) 53 | 54 | update_interval = xsize // config.nUpdate 55 | 56 | # config.interval = xsize // config.nUpdate 57 | n_sample = 0 58 | for t in range(0, xsize, config.miniBatch): 59 | XX = [] 60 | end = False 61 | for k in range(t, t + config.miniBatch): 62 | i = ri[k] 63 | x = self._X[i] 64 | XX.append(x) 65 | if k == xsize - 1: 66 | end = True 67 | break 68 | mb_size = len(XX) 69 | n_sample += mb_size 70 | 71 | # fSet = set() 72 | 73 | err, feature_set = _grad.get_grad_SGD_minibatch( 74 | grad, self._model, XX 75 | ) 76 | error += err 77 | 78 | feature_set = list(feature_set) 79 | 80 | feature_count_list[feature_set] += 1 81 | 82 | # for i in feature_set: 83 | # feature_count_list[i] += 1 84 | check = False 85 | 86 | for k in range(t, t + config.miniBatch): 87 | if t != 0 and k % update_interval == 0: 88 | check = True 89 | 90 | # update decay rates 91 | if check or end: 92 | 93 | self.decayList *= ( 94 | config.upper 95 | - (config.upper - config.lower) 96 | * feature_count_list 97 | / n_sample 98 | ) 99 | feature_count_list.fill(0) 100 | 101 | # for i in range(fsize): 102 | # v = feature_count_list[i] 103 | # u = v / n_sample 104 | # eta = config.upper - (config.upper - config.lower) * u 105 | # self.decayList[i] *= eta 106 | # feature_count_list 107 | # for i in range(len(feature_count_list)): 108 | # feature_count_list[i] = 0 109 | # update weights 110 | 111 | w[feature_set] -= self.decayList[feature_set] * grad[feature_set] 112 | grad[feature_set] = 0 113 | # for i in feature_set: 114 | # w[i] -= self.decayList[i] * grad[i] 115 | # grad[i] = 0 116 | # reg 117 | if check or end: 118 | if config.reg != 0: 119 | w -= self.decayList * ( 120 | w / (config.reg * config.reg) * n_sample / xsize 121 | ) 122 | 123 | # for i in range(fsize): 124 | # grad_i = ( 125 | # w[i] / (config.reg * config.reg) * (n_sample / xsize) 126 | # ) 127 | # w[i] -= self.decayList[i] * grad_i 128 | n_sample = 0 129 | sample_size += mb_size 130 | if config.reg != 0: 131 | s = (w * w).sum() 132 | error += s / (2.0 * config.reg * config.reg) 133 | diff = self.converge_test(error) 134 | return error, sample_size, diff 135 | -------------------------------------------------------------------------------- /pkuseg/config.py: -------------------------------------------------------------------------------- 1 | import os 2 | import tempfile 3 | 4 | 5 | class Config: 6 | lineEnd = "\n" 7 | biLineEnd = "\n\n" 8 | triLineEnd = "\n\n\n" 9 | undrln = "_" 10 | blank = " " 11 | tab = "\t" 12 | star = "*" 13 | slash = "/" 14 | comma = "," 15 | delimInFeature = "." 
16 | B = "B" 17 | num = "0123456789.几二三四五六七八九十千万亿兆零1234567890%" 18 | letter = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghigklmnopqrstuvwxyz/・-" 19 | mark = "*" 20 | model_urls = { 21 | "postag": "https://github.com/lancopku/pkuseg-python/releases/download/v0.0.16/postag.zip", 22 | "medicine": "https://github.com/lancopku/pkuseg-python/releases/download/v0.0.16/medicine.zip", 23 | "tourism": "https://github.com/lancopku/pkuseg-python/releases/download/v0.0.16/tourism.zip", 24 | "news": "https://github.com/lancopku/pkuseg-python/releases/download/v0.0.16/news.zip", 25 | "web": "https://github.com/lancopku/pkuseg-python/releases/download/v0.0.16/web.zip", 26 | } 27 | model_hash = { 28 | "postag": "afdf15f4e39bc47a39be4c37e3761b0c8f6ad1783f3cd3aff52984aebc0a1da9", 29 | "medicine": "773d655713acd27dd1ea9f97d91349cc1b6aa2fc5b158cd742dc924e6f239dfc", 30 | "tourism": "1c84a0366fe6fda73eda93e2f31fd399923b2f5df2818603f426a200b05cbce9", 31 | "news": "18188b68e76b06fc437ec91edf8883a537fe25fa606641534f6f004d2f9a2e42", 32 | "web": "4867f5817f187246889f4db259298c3fcee07c0b03a2d09444155b28c366579e", 33 | } 34 | available_models = ["default", "medicine", "tourism", "web", "news"] 35 | models_with_dict = ["medicine", "tourism"] 36 | 37 | 38 | def __init__(self): 39 | # main setting 40 | self.pkuseg_home = os.path.expanduser(os.getenv('PKUSEG_HOME', '~/.pkuseg')) 41 | self.trainFile = os.path.join("data", "small_training.utf8") 42 | self.testFile = os.path.join("data", "small_test.utf8") 43 | self._tmp_dir = tempfile.TemporaryDirectory() 44 | self.homepath = self._tmp_dir.name 45 | self.tempFile = os.path.join(self.homepath, ".pkuseg", "temp") 46 | self.readFile = os.path.join("data", "small_test.utf8") 47 | self.outputFile = os.path.join("data", "small_test_output.utf8") 48 | 49 | self.modelOptimizer = "crf.adf" 50 | self.rate0 = 0.05 # init value of decay rate in SGD and ADF training 51 | # self.reg = 1 52 | # self.regs = [1] 53 | # self.regList = self.regs.copy() 54 | self.random = ( 55 | 0 56 | ) # 0 for 0-initialization of model weights, 1 for random init of model weights 57 | self.evalMetric = ( 58 | "f1" 59 | ) # tok.acc (token accuracy), str.acc (string accuracy), f1 (F1-score) 60 | self.trainSizeScale = 1 # for scaling the size of training data 61 | self.ttlIter = 20 # of training iterations 62 | self.nUpdate = 10 # for ADF training 63 | self.outFolder = os.path.join(self.tempFile, "output") 64 | self.save = 1 # save model file 65 | self.rawResWrite = True 66 | self.miniBatch = 1 # mini-batch in stochastic training 67 | self.nThread = 10 # number of processes 68 | # ADF training 69 | self.upper = 0.995 # was tuned for nUpdate = 10 70 | self.lower = 0.6 # was tuned for nUpdate = 10 71 | 72 | # global variables 73 | self.metric = None 74 | self.reg = 1 75 | self.outDir = self.outFolder 76 | self.testrawDir = "rawinputs/" 77 | self.testinputDir = "inputs/" 78 | self.tempDir = os.path.join(self.homepath, ".pkuseg", "temp") 79 | self.testoutputDir = "entityoutputs/" 80 | 81 | # self.GL_init = True 82 | self.weightRegMode = "L2" # choosing weight regularizer: L2, L1) 83 | 84 | self.c_train = os.path.join(self.tempFile, "train.conll.txt") 85 | self.f_train = os.path.join(self.tempFile, "train.feat.txt") 86 | 87 | self.c_test = os.path.join(self.tempFile, "test.conll.txt") 88 | self.f_test = os.path.join(self.tempFile, "test.feat.txt") 89 | 90 | self.fTune = "tune.txt" 91 | self.fLog = "trainLog.txt" 92 | self.fResSum = "summarizeResult.txt" 93 | self.fResRaw = "rawResult.txt" 94 | self.fOutput = "outputTag-{}.txt" 
95 | 96 | self.fFeatureTrain = os.path.join(self.tempFile, "ftrain.txt") 97 | self.fGoldTrain = os.path.join(self.tempFile, "gtrain.txt") 98 | self.fFeatureTest = os.path.join(self.tempFile, "ftest.txt") 99 | self.fGoldTest = os.path.join(self.tempFile, "gtest.txt") 100 | 101 | self.modelDir = os.path.join( 102 | os.path.dirname(os.path.realpath(__file__)), "models", "ctb8" 103 | ) 104 | self.fModel = os.path.join(self.modelDir, "model.txt") 105 | 106 | # feature 107 | self.numLetterNorm = True 108 | self.featureTrim = 0 109 | self.wordFeature = True 110 | self.wordMax = 6 111 | self.wordMin = 2 112 | self.nLabel = 5 113 | self.order = 1 114 | 115 | def globalCheck(self): 116 | if self.evalMetric == "f1": 117 | self.metric = "f-score" 118 | elif self.evalMetric == "tok.acc": 119 | self.metric = "token-accuracy" 120 | elif self.evalMetric == "str.acc": 121 | self.metric = "string-accuracy" 122 | else: 123 | raise Exception("invalid eval metric") 124 | assert self.rate0 > 0 125 | assert self.trainSizeScale > 0 126 | assert self.ttlIter > 0 127 | assert self.nUpdate > 0 128 | assert self.miniBatch > 0 129 | assert self.reg > 0 130 | 131 | 132 | config = Config() 133 | -------------------------------------------------------------------------------- /pkuseg/inference.pyx: -------------------------------------------------------------------------------- 1 | # distutils: language = c++ 2 | # cython: infer_types=True 3 | # cython: language_level=3 4 | cimport cython 5 | import numpy as np 6 | cimport numpy as np 7 | 8 | 9 | from libcpp.vector cimport vector 10 | from libc.math cimport exp, log 11 | 12 | 13 | np.import_array() 14 | 15 | 16 | class belief: 17 | def __init__(self, nNodes, nStates): 18 | self.belState = np.zeros((nNodes, nStates)) 19 | self.belEdge = np.zeros((nNodes, nStates * nStates)) 20 | self.Z = 0 21 | 22 | 23 | @cython.boundscheck(False) 24 | @cython.wraparound(False) 25 | cpdef get_beliefs(object bel, object m, object x, np.ndarray[double, ndim=2] Y, np.ndarray[double, ndim=2] YY): 26 | cdef: 27 | np.ndarray[double, ndim=2] belState = bel.belState 28 | np.ndarray[double, ndim=2] belEdge = bel.belEdge 29 | int nNodes = len(x) 30 | int nTag = m.n_tag 31 | double Z = 0 32 | np.ndarray[double, ndim=1] alpha_Y = np.zeros(nTag) 33 | np.ndarray[double, ndim=1] newAlpha_Y = np.zeros(nTag) 34 | np.ndarray[double, ndim=1] tmp_Y = np.zeros(nTag) 35 | np.ndarray[double, ndim=2] YY_trans = YY.transpose() 36 | np.ndarray[double, ndim=1] YY_t_r = YY_trans.reshape(-1) 37 | np.ndarray[double, ndim=1] sum_edge = np.zeros(nTag * nTag) 38 | 39 | for i in range(nNodes - 1, 0, -1): 40 | tmp_Y = belState[i] + Y[i] 41 | belState[i-1] = logMultiply(YY, tmp_Y) 42 | 43 | for i in range(nNodes): 44 | if i > 0: 45 | tmp_Y = alpha_Y.copy() 46 | newAlpha_Y = logMultiply(YY_trans, tmp_Y) + Y[i] 47 | else: 48 | newAlpha_Y = Y[i].copy() 49 | if i > 0: 50 | tmp_Y = Y[i] + belState[i] 51 | belEdge[i] = YY_t_r 52 | for yPre in range(nTag): 53 | for y in range(nTag): 54 | belEdge[i, y * nTag + yPre] += tmp_Y[y] + alpha_Y[yPre] 55 | belState[i] = belState[i] + newAlpha_Y 56 | alpha_Y = newAlpha_Y 57 | Z = logSum(alpha_Y) 58 | for i in range(nNodes): 59 | belState[i] = np.exp(belState[i] - Z) 60 | for i in range(1, nNodes): 61 | sum_edge += np.exp(belEdge[i] - Z) 62 | return Z, sum_edge 63 | 64 | 65 | 66 | @cython.boundscheck(False) 67 | @cython.wraparound(False) 68 | cpdef run_viterbi(np.ndarray[double, ndim=2] node_score, np.ndarray[double, ndim=2] edge_score): 69 | cdef int i, y, y_pre, i_pre, tag, 
w=node_score.shape[0], h=node_score.shape[1] 70 | cdef double ma, sc 71 | cdef np.ndarray[double, ndim=2] max_score = np.zeros((w, h), dtype=np.float64) 72 | cdef np.ndarray[int, ndim=2] pre_tag = np.zeros((w, h), dtype=np.int32) 73 | cdef np.ndarray[unsigned char, ndim=2] init_check = np.zeros((w, h), dtype=np.uint8) 74 | cdef np.ndarray[int, ndim=1] states = np.zeros(w, dtype=np.int32) 75 | for y in range(h): 76 | max_score[w-1, y] = node_score[w-1, y] 77 | for i in range(w - 2, -1, -1): 78 | for y in range(h): 79 | for y_pre in range(h): 80 | i_pre = i + 1 81 | sc = max_score[i_pre, y_pre] + node_score[i, y] + edge_score[y, y_pre] 82 | if not init_check[i, y]: 83 | init_check[i, y] = 1 84 | max_score[i, y] = sc 85 | pre_tag[i, y] = y_pre 86 | elif sc >= max_score[i, y]: 87 | max_score[i, y] = sc 88 | pre_tag[i, y] = y_pre 89 | ma = max_score[0, 0] 90 | tag = 0 91 | for y in range(1, h): 92 | sc = max_score[0, y] 93 | if ma < sc: 94 | ma = sc 95 | tag = y 96 | states[0] = tag 97 | for i in range(1, w): 98 | tag = pre_tag[i-1, tag] 99 | states[i] = tag 100 | if ma > 300: 101 | ma = 300 102 | return exp(ma), states 103 | 104 | @cython.boundscheck(False) 105 | @cython.wraparound(False) 106 | cpdef getLogYY(vector[vector[int]] feature_temp, int num_tag, int backoff, np.ndarray[double, ndim=1] w, double scalar): 107 | cdef: 108 | int num_node = feature_temp.size() 109 | np.ndarray[double, ndim=2] node_score = np.zeros((num_node, num_tag), dtype=np.float64) 110 | np.ndarray[double, ndim=2] edge_score = np.ones((num_tag, num_tag), dtype=np.float64) 111 | int s, s_pre, i 112 | double maskValue, tmp 113 | vector[int] f_list 114 | int f, ft 115 | for i in range(num_node): 116 | f_list = feature_temp[i] 117 | for ft in f_list: 118 | for s in range(num_tag): 119 | f = ft * num_tag + s 120 | node_score[i, s] += w[f] * scalar 121 | for s in range(num_tag): 122 | for s_pre in range(num_tag): 123 | f = backoff + s * num_tag + s_pre 124 | edge_score[s_pre, s] += w[f] * scalar 125 | return node_score, edge_score 126 | 127 | @cython.boundscheck(False) 128 | @cython.wraparound(False) 129 | cpdef maskY(object tags, int nNodes, int nTag, np.ndarray[double, ndim=2] Y): 130 | cdef np.ndarray[double, ndim=2] mask_Yi = Y.copy() 131 | cdef double maskValue = -1e100 132 | cdef list tagList = tags 133 | cdef int i 134 | for i in range(nNodes): 135 | for s in range(nTag): 136 | if tagList[i] != s: 137 | mask_Yi[i, s] = maskValue 138 | return mask_Yi 139 | 140 | @cython.boundscheck(False) 141 | @cython.wraparound(False) 142 | cdef logMultiply(np.ndarray[double, ndim=2] A, np.ndarray[double, ndim=1] B): 143 | cdef int r, c 144 | cdef np.ndarray[double, ndim=2] toSumLists = np.zeros_like(A) 145 | cdef np.ndarray[double, ndim=1] ret = np.zeros(A.shape[0]) 146 | for r in range(A.shape[0]): 147 | for c in range(A.shape[1]): 148 | toSumLists[r, c] = A[r, c] + B[c] 149 | for r in range(A.shape[0]): 150 | ret[r] = logSum(toSumLists[r]) 151 | return ret 152 | 153 | @cython.boundscheck(False) 154 | @cython.wraparound(False) 155 | cdef logSum(double[:] a): 156 | cdef int n = a.shape[0] 157 | cdef double s = a[0] 158 | cdef double m1 159 | cdef double m2 160 | for i in range(1, n): 161 | if s >= a[i]: 162 | m1, m2 = s, a[i] 163 | else: 164 | m1, m2 = a[i], s 165 | s = m1 + log(1 + exp(m2 - m1)) 166 | return s 167 | 168 | 169 | def decodeViterbi_fast(feature_temp, model): 170 | Y, YY = getLogYY(feature_temp, model.n_tag, model.n_feature*model.n_tag, model.w, 1.0) 171 | numer, tags = run_viterbi(Y, YY) 172 | tags = list(tags) 173 | 
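    # numer is exp(best-path score), clipped at exp(300) by run_viterbi; tags holds the best tag index per position (callers map indices to tag strings).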
return numer, tags 174 | 175 | 176 | def getYYandY(model, example): 177 | Y, YY = getLogYY(example.features, model.n_tag, model.n_feature*model.n_tag, model.w, 1.0) 178 | mask_Y = maskY(example.tags, len(example), model.n_tag, Y) 179 | mask_YY = YY 180 | return Y, YY, mask_Y, mask_YY 181 | 182 | -------------------------------------------------------------------------------- /readme/readme_english.md: -------------------------------------------------------------------------------- 1 | 2 | # Pkuseg 3 | 4 | A multi-domain Chinese word segmentation toolkit. 5 | 6 | ## Highlights 7 | 8 | The pkuseg-python toolkit has the following features: 9 | 10 | 1. Supporting multi-domain Chinese word segmentation. Pkuseg-python supports multi-domain segmentation, including domains like news, web, medicine, and tourism. Users are free to choose different pre-trained models according to the domain features of the text to be segmented. If not sure the domain of the text, users are recommended to use the default model trained on mixed-domain data. 11 | 12 | 2. Higher word segmentation results. Compared with existing word segmentation toolkits, pkuseg-python can achieve higher F1 scores on the same dataset. 13 | 14 | 3. Supporting model training. Pkuseg-python also supports users to train a new segmentation model with their own data. 15 | 16 | 4. Supporting POS tagging. We also provide users POS tagging interfaces for further lexical analysis. 17 | 18 | 19 | 20 | ## Installation 21 | 22 | - Requirements: python3 23 | 24 | 1. Install pkuseg-python by using PyPI: (with the default model trained on mixed-doimain data) 25 | ``` 26 | pip3 install pkuseg 27 | ``` 28 | or update to the latest version (**suggested**): 29 | ``` 30 | pip3 install -U pkuseg 31 | ``` 32 | 2. Install pkuseg-python by using image source for fast speed: 33 | ``` 34 | pip3 install -i https://pypi.tuna.tsinghua.edu.cn/simple pkuseg 35 | ``` 36 | or update to the latest version (**suggested**): 37 | ``` 38 | pip3 install -i https://pypi.tuna.tsinghua.edu.cn/simple -U pkuseg 39 | ``` 40 | Note: The previous two installing commands only support python3.5, python3.6, python3.7 on linux, mac, and **windows 64 bit**. 41 | 3. If the code is downloaded from GitHub, please run the following command to install pkuseg-python: 42 | ``` 43 | python setup.py build_ext -i 44 | ``` 45 | 46 | Note: the github code does not contain the pre-trained models, users need to download the pre-trained models from [release](https://github.com/lancopku/pkuseg-python/releases), and set parameter 'model_name' as the model path. 47 | 48 | 49 | 50 | 51 | ## Usage 52 | 53 | #### Examples 54 | 55 | 56 | Example 1: Segmentation under the default configuration. **If users are not sure the domain of the text to be segmented, the default configuration is recommended.** 57 | ```python3 58 | import pkuseg 59 | 60 | seg = pkuseg.pkuseg() #load the default model 61 | text = seg.cut('我爱北京天安门') 62 | print(text) 63 | ``` 64 | 65 | Example 2: Domain-specific segmentation. **If users know the text domain, they can select a pre-trained domain model according to the domain features.** 66 | 67 | ```python3 68 | import pkuseg 69 | seg = pkuseg.pkuseg(model_name='medicine') 70 | #Automatically download the domain-specific model. 71 | text = seg.cut('我爱北京天安门') 72 | print(text) 73 | ``` 74 | 75 | Example 3:Segmentation and POS tagging. For the detailed meaning of each POS tag, please refer to [tags.txt](https://github.com/lancopku/pkuseg-python/blob/master/tags.txt). 
76 | ```python3 77 | import pkuseg 78 | 79 | seg = pkuseg.pkuseg(postag=True) 80 | text = seg.cut('我爱北京天安门') 81 | print(text) 82 | ``` 83 | 84 | 85 | Example 4:Segmentation with a text file as input. 86 | ```python3 87 | import pkuseg 88 | 89 | #Take file 'input.txt' as input. 90 | #The segmented result is stored in file 'output.txt'. 91 | pkuseg.test('input.txt', 'output.txt', nthread=20) 92 | ``` 93 | 94 | 95 | Example 5: Segmentation with a user-defined dictionary. 96 | ```python3 97 | import pkuseg 98 | 99 | seg = pkuseg.pkuseg(user_dict='my_dict.txt') 100 | text = seg.cut('我爱北京天安门') 101 | print(text) 102 | ``` 103 | 104 | 105 | Example 6: Segmentation with a user-trained model. Take CTB8 as an example. 106 | ```python3 107 | import pkuseg 108 | 109 | seg = pkuseg.pkuseg(model_name='./ctb8') 110 | text = seg.cut('我爱北京天安门') 111 | print(text) 112 | ``` 113 | 114 | 115 | 116 | Example 7: Training a new model (randomly initialized). 117 | 118 | ```python3 119 | import pkuseg 120 | 121 | # Training file: 'msr_training.utf8'. 122 | # Test file: 'msr_test_gold.utf8'. 123 | # Save the trained model to './models'. 124 | # The training and test files are in utf-8 encoding. 125 | pkuseg.train('msr_training.utf8', 'msr_test_gold.utf8', './models') 126 | ``` 127 | 128 | Example 8: Fine-tuning. Take a pre-trained model as input. 129 | ```python3 130 | import pkuseg 131 | 132 | # Training file: 'train.txt'. 133 | # Testing file'test.txt'. 134 | # The path of the pre-trained model: './pretrained'. 135 | # Save the trained model to './models'. 136 | # The training and test files are in utf-8 encoding. 137 | pkuseg.train('train.txt', 'test.txt', './models', train_iter=10, init_model='./pretrained') 138 | ``` 139 | 140 | 141 | 142 | #### Parameter Settings 143 | 144 | Segmentation for sentences. 145 | ``` 146 | pkuseg.pkuseg(model_name = "default", user_dict = "default", postag = False) 147 | model_name The path of the used model. 148 | "default". The default mixed-domain model. 149 | "news". The model trained on news domain data. 150 | "web". The model trained on web domain data. 151 | "medicine". The model trained on medicine domain data. 152 | "tourism". The model trained on tourism domain data. 153 | model_path. Load a model from the user-specified path. 154 | user_dict Set up the user dictionary. 155 | "default". Use the default dictionary. 156 | None. No dictionary is used. 157 | dict_path. The path of the user-defined dictionary. Each line only contains one word. 158 | postag POS tagging or not. 159 | False. The default setting. Segmentation without POS tagging. 160 | True. Segmentation with POS tagging. 161 | ``` 162 | 163 | Segmentation for documents. 164 | 165 | ``` 166 | pkuseg.test(readFile, outputFile, model_name = "default", user_dict = "default", postag = False, nthread = 10) 167 | readFile The path of the input file. 168 | outputFile The path of the output file. 169 | model_name The path of the used model. Refer to pkuseg.pkuseg. 170 | user_dict The path of the user dictionary. Refer to pkuseg.pkuseg. 171 | postag POS tagging or not. Refer to pkuseg.pkuseg. 172 | nthread The number of threads. 173 | ``` 174 | 175 | Model training. 176 | ``` 177 | pkuseg.train(trainFile, testFile, savedir, train_iter = 20, init_model = None) 178 | trainFile The path of the training file. 179 | testFile The path of the test file. 180 | savedir The saved path of the trained model. 181 | train_iter The maximum number of training epochs. 182 | init_model By default, None means random initialization. 
Users can also load a pre-trained model as initialization, like init_model='./models/'. 183 | ``` 184 | 185 | 186 | ## Publication 187 | 188 | The toolkit is mainly based on the following publication. If you use the toolkit, please cite the paper: 189 | * Ruixuan Luo, Jingjing Xu, Yi Zhang, Zhiyuan Zhang, Xuancheng Ren, Xu Sun. [PKUSEG: A Toolkit for Multi-Domain Chinese Word Segmentation](https://arxiv.org/abs/1906.11455). Arxiv. 2019. 190 | 191 | ``` 192 | 193 | @article{pkuseg, 194 | author = {Luo, Ruixuan and Xu, Jingjing and Zhang, Yi and Zhang, Zhiyuan and Ren, Xuancheng and Sun, Xu}, 195 | journal = {CoRR}, 196 | title = {PKUSEG: A Toolkit for Multi-Domain Chinese Word Segmentation.}, 197 | url = {https://arxiv.org/abs/1906.11455}, 198 | volume = {abs/1906.11455}, 199 | year = 2019 200 | } 201 | ``` 202 | 203 | ## Related Work 204 | 205 | * Xu Sun, Houfeng Wang, Wenjie Li. Fast Online Training with Frequency-Adaptive Learning Rates for Chinese Word Segmentation and New Word Detection. ACL. 2012. 206 | * Jingjing Xu and Xu Sun. Dependency-based gated recursive neural network for chinese word segmentation. ACL. 2016. 207 | * Jingjing Xu and Xu Sun. Transfer learning for low-resource chinese word segmentation with a novel neural network. NLPCC. 2017. 208 | 209 | 210 | ## Authors 211 | 212 | Ruixuan Luo, Jingjing Xu, Xuancheng Ren, Yi Zhang, Zhiyuan Zhang, Bingzhen Wei, Xu Sun 213 | 214 | [Language Computing and Machine Learning Group](http://lanco.pku.edu.cn/), Peking University 215 | 216 | 217 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # pkuseg:一个多领域中文分词工具包 [**(English Version)**](readme/readme_english.md) 2 | 3 | pkuseg 是基于论文[[Luo et. al, 2019](#论文引用)]的工具包。其简单易用,支持细分领域分词,有效提升了分词准确度。 4 | 5 | 6 | 7 | ## 目录 8 | 9 | * [主要亮点](#主要亮点) 10 | * [编译和安装](#编译和安装) 11 | * [各类分词工具包的性能对比](#各类分词工具包的性能对比) 12 | * [使用方式](#使用方式) 13 | * [论文引用](#论文引用) 14 | * [作者](#作者) 15 | * [常见问题及解答](#常见问题及解答) 16 | 17 | 18 | 19 | ## 主要亮点 20 | 21 | pkuseg具有如下几个特点: 22 | 23 | 1. 多领域分词。不同于以往的通用中文分词工具,此工具包同时致力于为不同领域的数据提供个性化的预训练模型。根据待分词文本的领域特点,用户可以自由地选择不同的模型。 我们目前支持了新闻领域,网络领域,医药领域,旅游领域,以及混合领域的分词预训练模型。在使用中,如果用户明确待分词的领域,可加载对应的模型进行分词。如果用户无法确定具体领域,推荐使用在混合领域上训练的通用模型。各领域分词样例可参考 [**example.txt**](https://github.com/lancopku/pkuseg-python/blob/master/example.txt)。 24 | 2. 更高的分词准确率。相比于其他的分词工具包,当使用相同的训练数据和测试数据,pkuseg可以取得更高的分词准确率。 25 | 3. 支持用户自训练模型。支持用户使用全新的标注数据进行训练。 26 | 4. 支持词性标注。 27 | 28 | 29 | ## 编译和安装 30 | 31 | - 目前**仅支持python3** 32 | - **为了获得好的效果和速度,强烈建议大家通过pip install更新到目前的最新版本** 33 | 34 | 1. 通过PyPI安装(自带模型文件): 35 | ``` 36 | pip3 install pkuseg 37 | 之后通过import pkuseg来引用 38 | ``` 39 | **建议更新到最新版本**以获得更好的开箱体验: 40 | ``` 41 | pip3 install -U pkuseg 42 | ``` 43 | 2. 如果PyPI官方源下载速度不理想,建议使用镜像源,比如: 44 | 初次安装: 45 | ``` 46 | pip3 install -i https://pypi.tuna.tsinghua.edu.cn/simple pkuseg 47 | ``` 48 | 更新: 49 | ``` 50 | pip3 install -i https://pypi.tuna.tsinghua.edu.cn/simple -U pkuseg 51 | ``` 52 | 53 | 3. 
如果不使用pip安装方式,选择从GitHub下载,可运行以下命令安装: 54 | ``` 55 | python setup.py build_ext -i 56 | ``` 57 | 58 | GitHub的代码并不包括预训练模型,因此需要用户自行下载或训练模型,预训练模型可详见[release](https://github.com/lancopku/pkuseg-python/releases)。使用时需设定"model_name"为模型文件。 59 | 60 | 注意:**安装方式1和2目前仅支持linux(ubuntu)、mac、windows 64 位的python3版本**。如果非以上系统,请使用安装方式3进行本地编译安装。 61 | 62 | 63 | ## 各类分词工具包的性能对比 64 | 65 | 我们选择jieba、THULAC等国内代表分词工具包与pkuseg做性能比较,详细设置可参考[实验环境](readme/environment.md)。 66 | 67 | 68 | 69 | #### 细领域训练及测试结果 70 | 71 | 以下是在不同数据集上的对比结果: 72 | 73 | | MSRA | Precision | Recall | F-score | 74 | | :----- | --------: | -----: | --------: | 75 | | jieba | 87.01 | 89.88 | 88.42 | 76 | | THULAC | 95.60 | 95.91 | 95.71 | 77 | | pkuseg | 96.94 | 96.81 | **96.88** | 78 | 79 | 80 | | WEIBO | Precision | Recall | F-score | 81 | | :----- | --------: | -----: | --------: | 82 | | jieba | 87.79 | 87.54 | 87.66 | 83 | | THULAC | 93.40 | 92.40 | 92.87 | 84 | | pkuseg | 93.78 | 94.65 | **94.21** | 85 | 86 | 87 | 88 | 89 | #### 默认模型在不同领域的测试效果 90 | 91 | 考虑到很多用户在尝试分词工具的时候,大多数时候会使用工具包自带模型测试。为了直接对比“初始”性能,我们也比较了各个工具包的默认模型在不同领域的测试效果。请注意,这样的比较只是为了说明默认情况下的效果,并不一定是公平的。 92 | 93 | | Default | MSRA | CTB8 | PKU | WEIBO | All Average | 94 | | ------- | :---: | :---: | :---: | :---: | :---------: | 95 | | jieba | 81.45 | 79.58 | 81.83 | 83.56 | 81.61 | 96 | | THULAC | 85.55 | 87.84 | 92.29 | 86.65 | 88.08 | 97 | | pkuseg | 87.29 | 91.77 | 92.68 | 93.43 | **91.29** | 98 | 99 | 其中,`All Average`显示的是在所有测试集上F-score的平均。 100 | 101 | 更多详细比较可参见[和现有工具包的比较](readme/comparison.md)。 102 | 103 | ## 使用方式 104 | 105 | #### 代码示例 106 | 107 | 以下代码示例适用于python交互式环境。 108 | 109 | 代码示例1:使用默认配置进行分词(**如果用户无法确定分词领域,推荐使用默认模型分词**) 110 | ```python3 111 | import pkuseg 112 | 113 | seg = pkuseg.pkuseg() # 以默认配置加载模型 114 | text = seg.cut('我爱北京天安门') # 进行分词 115 | print(text) 116 | ``` 117 | 118 | 代码示例2:细领域分词(**如果用户明确分词领域,推荐使用细领域模型分词**) 119 | ```python3 120 | import pkuseg 121 | 122 | seg = pkuseg.pkuseg(model_name='medicine') # 程序会自动下载所对应的细领域模型 123 | text = seg.cut('我爱北京天安门') # 进行分词 124 | print(text) 125 | ``` 126 | 127 | 代码示例3:分词同时进行词性标注,各词性标签的详细含义可参考 [tags.txt](https://github.com/lancopku/pkuseg-python/blob/master/tags.txt) 128 | ```python3 129 | import pkuseg 130 | 131 | seg = pkuseg.pkuseg(postag=True) # 开启词性标注功能 132 | text = seg.cut('我爱北京天安门') # 进行分词和词性标注 133 | print(text) 134 | ``` 135 | 136 | 137 | 代码示例4:对文件分词 138 | ```python3 139 | import pkuseg 140 | 141 | # 对input.txt的文件分词输出到output.txt中 142 | # 开20个进程 143 | pkuseg.test('input.txt', 'output.txt', nthread=20) 144 | ``` 145 | 146 | 其他使用示例可参见[详细代码示例](readme/interface.md)。 147 | 148 | 149 | 150 | #### 参数说明 151 | 152 | 模型配置 153 | ``` 154 | pkuseg.pkuseg(model_name = "default", user_dict = "default", postag = False) 155 | model_name 模型路径。 156 | "default",默认参数,表示使用我们预训练好的混合领域模型(仅对pip下载的用户)。 157 | "news", 使用新闻领域模型。 158 | "web", 使用网络领域模型。 159 | "medicine", 使用医药领域模型。 160 | "tourism", 使用旅游领域模型。 161 | model_path, 从用户指定路径加载模型。 162 | user_dict 设置用户词典。 163 | "default", 默认参数,使用我们提供的词典。 164 | None, 不使用词典。 165 | dict_path, 在使用默认词典的同时会额外使用用户自定义词典,可以填自己的用户词典的路径,词典格式为一行一个词(如果选择进行词性标注并且已知该词的词性,则在该行写下词和词性,中间用tab字符隔开)。 166 | postag 是否进行词性分析。 167 | False, 默认参数,只进行分词,不进行词性标注。 168 | True, 会在分词的同时进行词性标注。 169 | ``` 170 | 171 | 对文件进行分词 172 | ``` 173 | pkuseg.test(readFile, outputFile, model_name = "default", user_dict = "default", postag = False, nthread = 10) 174 | readFile 输入文件路径。 175 | outputFile 输出文件路径。 176 | model_name 模型路径。同pkuseg.pkuseg 177 | user_dict 设置用户词典。同pkuseg.pkuseg 178 | postag 设置是否开启词性分析功能。同pkuseg.pkuseg 179 | nthread 测试时开的进程数。 180 | ``` 181 | 182 | 
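以上参数可以组合使用。下面给出一个示意用法(文件名与参数取值仅作演示),在对文件分词的同时指定细领域模型、使用默认词典并开启词性标注:

```python3
import pkuseg

# 使用医药领域模型和默认词典,开启词性标注,用10个进程对input.txt分词并输出到output.txt
pkuseg.test('input.txt', 'output.txt',
            model_name='medicine', user_dict='default',
            postag=True, nthread=10)
```
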
模型训练 183 | ``` 184 | pkuseg.train(trainFile, testFile, savedir, train_iter = 20, init_model = None) 185 | trainFile 训练文件路径。 186 | testFile 测试文件路径。 187 | savedir 训练模型的保存路径。 188 | train_iter 训练轮数。 189 | init_model 初始化模型,默认为None表示使用默认初始化,用户可以填自己想要初始化的模型的路径如init_model='./models/'。 190 | ``` 191 | 192 | 193 | 194 | #### 多进程分词 195 | 196 | 当将以上代码示例置于文件中运行时,如涉及多进程功能,请务必使用`if __name__ == '__main__'`保护全局语句,详见[多进程分词](readme/multiprocess.md)。 197 | 198 | 199 | 200 | ## 预训练模型 201 | 202 | 从pip安装的用户在使用细领域分词功能时,只需要设置model_name字段为对应的领域即可,会自动下载对应的细领域模型。 203 | 204 | 从github下载的用户则需要自己下载对应的预训练模型,并设置model_name字段为预训练模型路径。预训练模型可以在[release](https://github.com/lancopku/pkuseg-python/releases)部分下载。以下是对预训练模型的说明: 205 | 206 | - **news**: 在MSRA(新闻语料)上训练的模型。 207 | 208 | - **web**: 在微博(网络文本语料)上训练的模型。 209 | 210 | - **medicine**: 在医药领域上训练的模型。 211 | 212 | - **tourism**: 在旅游领域上训练的模型。 213 | 214 | - **mixed**: 混合数据集训练的通用模型。随pip包附带的是此模型。 215 | 216 | 我们还通过领域自适应的方法,利用维基百科的未标注数据实现了几个细领域预训练模型的自动构建以及通用模型的优化,这些模型目前仅可以在release中下载: 217 | 218 | - **art**: 在艺术与文化领域上训练的模型。 219 | 220 | - **entertainment**: 在娱乐与体育领域上训练的模型。 221 | 222 | - **science**: 在科学领域上训练的模型。 223 | 224 | - **default_v2**: 使用领域自适应方法得到的优化后的通用模型,相较于默认模型规模更大,但泛化性能更好。 225 | 226 | 227 | 228 | 欢迎更多用户可以分享自己训练好的细分领域模型。 229 | 230 | 231 | 232 | ## 版本历史 233 | 234 | 详见[版本历史](readme/history.md)。 235 | 236 | 237 | ## 开源协议 238 | 1. 本代码采用MIT许可证。 239 | 2. 欢迎对该工具包提出任何宝贵意见和建议,请发邮件至jingjingxu@pku.edu.cn。 240 | 241 | 242 | 243 | ## 论文引用 244 | 245 | 该代码包主要基于以下科研论文,如使用了本工具,请引用以下论文: 246 | * Ruixuan Luo, Jingjing Xu, Yi Zhang, Zhiyuan Zhang, Xuancheng Ren, Xu Sun. [PKUSEG: A Toolkit for Multi-Domain Chinese Word Segmentation](https://arxiv.org/abs/1906.11455). Arxiv. 2019. 247 | 248 | ``` 249 | 250 | @article{pkuseg, 251 | author = {Luo, Ruixuan and Xu, Jingjing and Zhang, Yi and Zhang, Zhiyuan and Ren, Xuancheng and Sun, Xu}, 252 | journal = {CoRR}, 253 | title = {PKUSEG: A Toolkit for Multi-Domain Chinese Word Segmentation.}, 254 | url = {https://arxiv.org/abs/1906.11455}, 255 | volume = {abs/1906.11455}, 256 | year = 2019 257 | } 258 | ``` 259 | 260 | ## 其他相关论文 261 | 262 | * Xu Sun, Houfeng Wang, Wenjie Li. Fast Online Training with Frequency-Adaptive Learning Rates for Chinese Word Segmentation and New Word Detection. ACL. 2012. 263 | * Jingjing Xu and Xu Sun. Dependency-based gated recursive neural network for chinese word segmentation. ACL. 2016. 264 | * Jingjing Xu and Xu Sun. Transfer learning for low-resource chinese word segmentation with a novel neural network. NLPCC. 2017. 265 | 266 | ## 常见问题及解答 267 | 268 | 269 | 1. [为什么要发布pkuseg?](https://github.com/lancopku/pkuseg-python/wiki/FAQ#1-为什么要发布pkuseg) 270 | 2. [pkuseg使用了哪些技术?](https://github.com/lancopku/pkuseg-python/wiki/FAQ#2-pkuseg使用了哪些技术) 271 | 3. [无法使用多进程分词和训练功能,提示RuntimeError和BrokenPipeError。](https://github.com/lancopku/pkuseg-python/wiki/FAQ#3-无法使用多进程分词和训练功能提示runtimeerror和brokenpipeerror) 272 | 4. [是如何跟其它工具包在细领域数据上进行比较的?](https://github.com/lancopku/pkuseg-python/wiki/FAQ#4-是如何跟其它工具包在细领域数据上进行比较的) 273 | 5. [在黑盒测试集上进行比较的话,效果如何?](https://github.com/lancopku/pkuseg-python/wiki/FAQ#5-在黑盒测试集上进行比较的话效果如何) 274 | 6. [如果我不了解待分词语料的所属领域呢?](https://github.com/lancopku/pkuseg-python/wiki/FAQ#6-如果我不了解待分词语料的所属领域呢) 275 | 7. [如何看待在一些特定样例上的分词结果?](https://github.com/lancopku/pkuseg-python/wiki/FAQ#7-如何看待在一些特定样例上的分词结果) 276 | 8. [关于运行速度问题?](https://github.com/lancopku/pkuseg-python/wiki/FAQ#8-关于运行速度问题) 277 | 9. 
[关于多进程速度问题?](https://github.com/lancopku/pkuseg-python/wiki/FAQ#9-关于多进程速度问题) 278 | 279 | 280 | ## 致谢 281 | 282 | 感谢俞士汶教授(北京大学计算语言所)与邱立坤博士提供的训练数据集! 283 | 284 | ## 作者 285 | 286 | Ruixuan Luo (罗睿轩), Jingjing Xu(许晶晶), Xuancheng Ren(任宣丞), Yi Zhang(张艺), Zhiyuan Zhang(张之远), Bingzhen Wei(位冰镇), Xu Sun (孙栩) 287 | 288 | 北京大学 [语言计算与机器学习研究组](http://lanco.pku.edu.cn/) 289 | 290 | 291 | 292 | 293 | 294 | 295 | 296 | 297 | 298 | -------------------------------------------------------------------------------- /pkuseg/postag/feature_extractor.pyx: -------------------------------------------------------------------------------- 1 | # distutils: language = c++ 2 | # cython: infer_types=True 3 | # cython: language_level=3 4 | import json 5 | import os 6 | import sys 7 | import pickle 8 | from collections import Counter 9 | from itertools import product 10 | 11 | import cython 12 | 13 | @cython.boundscheck(False) 14 | @cython.wraparound(False) 15 | cpdef get_slice_str(iterable, int start, int length, int all_len): 16 | if start < 0 or start >= all_len: 17 | return "" 18 | if start + length >= all_len + 1: 19 | return "" 20 | return "".join(iterable[start : start + length]) 21 | 22 | 23 | 24 | @cython.boundscheck(False) 25 | @cython.wraparound(False) 26 | @cython.nonecheck(False) 27 | def __get_node_features_idx(int idx, list nodes not None, dict feature_to_idx not None): 28 | 29 | cdef: 30 | list flist = [] 31 | Py_ssize_t i = idx 32 | int length = len(nodes) 33 | int j 34 | 35 | 36 | w = nodes[i] 37 | 38 | # $$ starts feature 39 | flist.append(0) 40 | 41 | # unigram/bgiram feature 42 | feat = "w." + w 43 | if feat in feature_to_idx: 44 | feature = feature_to_idx[feat] 45 | flist.append(feature) 46 | 47 | for j in range(1, 4): 48 | if len(w)>=j: 49 | feat = "tr1.pre.%d.%s"%(j, w[:j]) 50 | if feat in feature_to_idx: 51 | flist.append(feature_to_idx[feat]) 52 | feat = "tr1.post.%d.%s"%(j, w[-j:]) 53 | if feat in feature_to_idx: 54 | flist.append(feature_to_idx[feat]) 55 | 56 | if i > 0: 57 | feat = "tr1.w-1." + nodes[i - 1] 58 | else: 59 | feat = "tr1.w-1.BOS" 60 | if feat in feature_to_idx: 61 | flist.append(feature_to_idx[feat]) 62 | if i < length - 1: 63 | feat = "tr1.w1." + nodes[i + 1] 64 | else: 65 | feat = "tr1.w1.EOS" 66 | if feat in feature_to_idx: 67 | flist.append(feature_to_idx[feat]) 68 | if i > 1: 69 | feat = "tr1.w-2." + nodes[i - 2] 70 | else: 71 | feat = "tr1.w-2.BOS" 72 | if feat in feature_to_idx: 73 | flist.append(feature_to_idx[feat]) 74 | if i < length - 2: 75 | feat = "tr1.w2." + nodes[i + 2] 76 | else: 77 | feat = "tr1.w2.EOS" 78 | if feat in feature_to_idx: 79 | flist.append(feature_to_idx[feat]) 80 | if i > 0: 81 | feat = "tr1.w_-1_0." + nodes[i - 1] + "." + w 82 | else: 83 | feat = "tr1.w_-1_0.BOS" 84 | if feat in feature_to_idx: 85 | flist.append(feature_to_idx[feat]) 86 | if i < length - 1: 87 | feat = "tr1.w_0_1." + w + "." + nodes[i + 1] 88 | else: 89 | feat = "tr1.w_0_1.EOS" 90 | if feat in feature_to_idx: 91 | flist.append(feature_to_idx[feat]) 92 | 93 | return flist 94 | 95 | 96 | class FeatureExtractor: 97 | 98 | keywords = "-._,|/*:" 99 | 100 | num = set("0123456789." 
"几二三四五六七八九十千万亿兆零" "1234567890%") 101 | letter = set( 102 | "ABCDEFGHIJKLMNOPQRSTUVWXYZ" "abcdefghigklmnopqrstuvwxyz" "/・-" 103 | ) 104 | 105 | keywords_translate_table = str.maketrans("-._,|/*:", "&&&&&&&&") 106 | 107 | @classmethod 108 | def keyword_rename(cls, text): 109 | return text.translate(cls.keywords_translate_table) 110 | 111 | @classmethod 112 | def _num_letter_normalize(cls, word): 113 | if not list(filter(lambda x:x not in cls.num, word)): 114 | return "**Num" 115 | return word 116 | 117 | @classmethod 118 | def normalize_text(cls, text): 119 | for i in range(len(text)): 120 | text[i] = cls.keyword_rename(text[i]) 121 | for character in text: 122 | yield cls._num_letter_normalize(character) 123 | 124 | 125 | def __init__(self): 126 | 127 | # self.unigram = set() # type: Set[str] 128 | # self.bigram = set() # type: Set[str] 129 | self.feature_to_idx = {} # type: Dict[str, int] 130 | self.tag_to_idx = {} # type: Dict[str, int] 131 | 132 | def get_node_features_idx(self, idx, nodes): 133 | return __get_node_features_idx(idx, nodes, self.feature_to_idx) 134 | 135 | def get_node_features(self, idx, wordary): 136 | cdef int length = len(wordary) 137 | w = wordary[idx] 138 | flist = [] 139 | 140 | # 1 start feature 141 | flist.append("$$") 142 | 143 | # 8 unigram/bgiram feature 144 | flist.append("w." + w) 145 | 146 | # prefix/suffix 147 | for i in range(1, 4): 148 | if len(w)>=i: 149 | flist.append("tr1.pre.%d.%s"%(i, w[:i])) 150 | flist.append("tr1.post.%d.%s"%(i, w[-i:])) 151 | else: 152 | flist.append("/") 153 | flist.append("/") 154 | 155 | if idx > 0: 156 | flist.append("tr1.w-1." + wordary[idx - 1]) 157 | else: 158 | flist.append("tr1.w-1.BOS") 159 | if idx < len(wordary) - 1: 160 | flist.append("tr1.w1." + wordary[idx + 1]) 161 | else: 162 | flist.append("tr1.w1.EOS") 163 | if idx > 1: 164 | flist.append("tr1.w-2." + wordary[idx - 2]) 165 | else: 166 | flist.append("tr1.w-2.BOS") 167 | if idx < len(wordary) - 2: 168 | flist.append("tr1.w2." + wordary[idx + 2]) 169 | else: 170 | flist.append("tr1.w2.EOS") 171 | if idx > 0: 172 | flist.append("tr1.w_-1_0." + wordary[idx - 1] + "." + w) 173 | else: 174 | flist.append("tr1.w_-1_0.BOS") 175 | if idx < len(wordary) - 1: 176 | flist.append("tr1.w_0_1." + w + "." 
+ wordary[idx + 1]) 177 | else: 178 | flist.append("tr1.w_0_1.EOS") 179 | 180 | return flist 181 | 182 | def convert_feature_file_to_idx_file( 183 | self, feature_file, feature_idx_file, tag_idx_file 184 | ): 185 | 186 | with open(feature_file, "r", encoding="utf8") as reader: 187 | lines = reader.readlines() 188 | 189 | with open(feature_idx_file, "w", encoding="utf8") as f_writer, open( 190 | tag_idx_file, "w", encoding="utf8" 191 | ) as t_writer: 192 | 193 | f_writer.write("{}\n\n".format(len(self.feature_to_idx))) 194 | t_writer.write("{}\n\n".format(len(self.tag_to_idx))) 195 | 196 | tags_idx = [] # type: List[str] 197 | features_idx = [] # type: List[List[str]] 198 | for line in lines: 199 | line = line.strip() 200 | if not line: 201 | # sentence finish 202 | for feature_idx in features_idx: 203 | if not feature_idx: 204 | f_writer.write("0\n") 205 | else: 206 | f_writer.write(",".join(map(str, feature_idx))) 207 | f_writer.write("\n") 208 | f_writer.write("\n") 209 | 210 | t_writer.write(",".join(map(str, tags_idx))) 211 | t_writer.write("\n\n") 212 | 213 | tags_idx = [] 214 | features_idx = [] 215 | continue 216 | 217 | splits = line.split(" ") 218 | feature_idx = [ 219 | self.feature_to_idx[feat] 220 | for feat in splits[:-1] 221 | if feat in self.feature_to_idx 222 | ] 223 | features_idx.append(feature_idx) 224 | if not splits[-1] in self.tag_to_idx: 225 | tags_idx.append(-1) 226 | else: 227 | tags_idx.append(self.tag_to_idx[splits[-1]]) 228 | 229 | def convert_text_file_to_feature_file( 230 | self, text_file, conll_file=None, feature_file=None 231 | ): 232 | 233 | if conll_file is None: 234 | conll_file = "{}.conll{}".format(*os.path.split(text_file)) 235 | if feature_file is None: 236 | feature_file = "{}.feat{}".format(*os.path.split(text_file)) 237 | 238 | conll_line_format = "{} {}\n" 239 | 240 | with open(text_file, "r", encoding="utf8") as reader, open( 241 | conll_file, "w", encoding="utf8" 242 | ) as c_writer, open(feature_file, "w", encoding="utf8") as f_writer: 243 | for line in reader.read().strip().replace("\r", "").split("\n\n"): 244 | line = line.strip() 245 | if not line: 246 | continue 247 | line = self.keyword_rename(line).split("\n") 248 | words = [] 249 | tags = [] 250 | for word_tag in line: 251 | word, tag = word_tag.split() 252 | words.append(word) 253 | tags.append(tag) 254 | example = [ 255 | self._num_letter_normalize(word) 256 | for word in words 257 | ] 258 | for word, tag in zip(example, tags): 259 | c_writer.write(conll_line_format.format(word, tag)) 260 | c_writer.write("\n") 261 | 262 | for idx, tag in enumerate(tags): 263 | features = self.get_node_features(idx, example) 264 | features = [ 265 | (feature if feature in self.feature_to_idx else "/") 266 | for feature in features 267 | ] 268 | features.append(tag) 269 | f_writer.write(" ".join(features)) 270 | f_writer.write("\n") 271 | f_writer.write("\n") 272 | 273 | def save(self, model_dir): 274 | data = {} 275 | data["feature_to_idx"] = self.feature_to_idx 276 | data["tag_to_idx"] = self.tag_to_idx 277 | 278 | with open(os.path.join(model_dir, 'features.pkl'), 'wb') as writer: 279 | pickle.dump(data, writer, protocol=pickle.HIGHEST_PROTOCOL) 280 | 281 | 282 | @classmethod 283 | def load(cls, model_dir): 284 | extractor = cls.__new__(cls) 285 | 286 | feature_path = os.path.join(model_dir, "features.pkl") 287 | if os.path.exists(feature_path): 288 | with open(feature_path, "rb") as reader: 289 | data = pickle.load(reader) 290 | extractor.feature_to_idx = data["feature_to_idx"] 291 | 
extractor.tag_to_idx = data["tag_to_idx"] 292 | 293 | return extractor 294 | 295 | 296 | print( 297 | "WARNING: features.pkl does not exist, try loading features.json", 298 | file=sys.stderr, 299 | ) 300 | 301 | 302 | feature_path = os.path.join(model_dir, "features.json") 303 | if os.path.exists(feature_path): 304 | with open(feature_path, "r", encoding="utf8") as reader: 305 | data = json.load(reader) 306 | extractor.feature_to_idx = data["feature_to_idx"] 307 | extractor.tag_to_idx = data["tag_to_idx"] 308 | extractor.save(model_dir) 309 | return extractor 310 | print( 311 | "WARNING: features.json does not exist, try loading using old format", 312 | file=sys.stderr, 313 | ) 314 | 315 | extractor.feature_to_idx = {} 316 | feature_base_name = os.path.join(model_dir, "featureIndex.txt") 317 | for i in range(10): 318 | with open( 319 | "{}_{}".format(feature_base_name, i), "r", encoding="utf8" 320 | ) as reader: 321 | for line in reader: 322 | feature, index = line.split(" ") 323 | feature = ".".join(feature.split(".")[1:]) 324 | extractor.feature_to_idx[feature] = int(index) 325 | 326 | extractor.tag_to_idx = {} 327 | with open( 328 | os.path.join(model_dir, "tagIndex.txt"), "r", encoding="utf8" 329 | ) as reader: 330 | for line in reader: 331 | tag, index = line.split(" ") 332 | extractor.tag_to_idx[tag] = int(index) 333 | 334 | print( 335 | "INFO: features.json is saved", 336 | file=sys.stderr, 337 | ) 338 | extractor.save(model_dir) 339 | 340 | return extractor 341 | -------------------------------------------------------------------------------- /pkuseg/data.py: -------------------------------------------------------------------------------- 1 | # from .config import Config 2 | # from pkuseg.feature_generator import 3 | # import os 4 | import copy 5 | import random 6 | 7 | 8 | # class dataFormat: 9 | # def __init__(self, config): 10 | # self.featureIndexMap = {} 11 | # self.tagIndexMap = {} 12 | # self.config = config 13 | 14 | # def convert(self): 15 | # config = self.config 16 | # if config.runMode.find("train") >= 0: 17 | # self.getMaps(config.fTrain) 18 | # self.saveFeature(config.modelDir + "/featureIndex.txt") 19 | # self.convertFile(config.fTrain) 20 | # else: 21 | # self.readFeature(config.modelDir + "/featureIndex.txt") 22 | # self.readTag(config.modelDir + "/tagIndex.txt") 23 | # self.convertFile(config.fTest) 24 | # if config.dev: 25 | # self.convertFile(config.fDev) 26 | 27 | # def saveFeature(self, file): 28 | # featureList = list(self.featureIndexMap.keys()) 29 | # num = len(featureList) // 10 30 | # for i in range(10): 31 | # l = i * num 32 | # r = (i + 1) * num if i < 9 else len(featureList) 33 | # with open(file + "_" + str(i), "w", encoding="utf-8") as sw: 34 | # for w in range(l, r): 35 | # word = featureList[w] 36 | # sw.write(word + " " + str(self.featureIndexMap[word]) + "\n") 37 | 38 | # def readFeature(self, file): 39 | # featureList = [] 40 | # for i in range(10): 41 | # featureList.append([]) 42 | # with open(file + "_" + str(i), encoding="utf-8") as f: 43 | # lines = f.readlines() 44 | # for line in lines: 45 | # featureList[i].append(line.strip()) 46 | # feature = [] 47 | # for i in range(10): 48 | # for line in featureList[i]: 49 | # word, index = line.split(" ") 50 | # self.featureIndexMap[word] = int(index) 51 | 52 | # def readFeatureNormal(self, path): 53 | # with open(path, encoding="utf-8") as f: 54 | # lines = f.readlines() 55 | # for line in lines: 56 | # u, v = line.split(" ") 57 | # self.featureIndexMap[u] = int(v) 58 | 59 | # def readTag(self, 
path): 60 | # with open(path, encoding="utf-8") as f: 61 | # lines = f.readlines() 62 | # for line in lines: 63 | # u, v = line.split(" ") 64 | # self.tagIndexMap[u] = int(v) 65 | 66 | # def getMaps(self, file): 67 | # config = self.config 68 | # if not os.path.exists(file): 69 | # print("file {} not exist!".format(file)) 70 | # print("file {} converting...".format(file)) 71 | # featureFreqMap = {} 72 | # tagSet = set() 73 | # with open(file, encoding="utf-8") as f: 74 | # lines = f.readlines() 75 | # for line in lines: 76 | # line = line.replace("\t", " ") 77 | # line = line.replace("\r", "").strip() 78 | # if line == "": 79 | # continue 80 | # ary = line.split(config.blank) 81 | # for i in range(1, len(ary) - 1): 82 | # if ary[i] == "" or ary[i] == "/": 83 | # continue 84 | # if config.weightRegMode == "GL": 85 | # if not config.GL_init and config.groupTrim[i - 1]: 86 | # continue 87 | 88 | # ary2 = ary[i].split(config.slash) 89 | # feature = str(i) + "." + ary2[0] 90 | # if not feature in featureFreqMap: 91 | # featureFreqMap[feature] = 0 92 | # featureFreqMap[feature] += 1 93 | # tag = ary[-1] 94 | # tagSet.add(tag) 95 | # sortList = [] 96 | # for k in featureFreqMap: 97 | # sortList.append(k + " " + str(featureFreqMap[k])) 98 | # if config.weightRegMode == "GL": 99 | # sortList.sort(key=lambda x: (int(x.split(config.blank)[1].strip()), x)) 100 | # with open("featureTemp_sorted.txt", "w", encoding="utf-8") as f: 101 | # for x in sortList: 102 | # f.write(x + "\n") 103 | # config.groupStart = [0] 104 | # config.groupEnd = [] 105 | # for k in range(1, len(sortList)): 106 | # thisAry = sortList[k].split(config.dot) 107 | # preAry = sortList[k - 1].split(config.dot) 108 | # s = thisAry[0] 109 | # preAry = preAry[0] 110 | # if s != preAry: 111 | # config.groupStart.append(k) 112 | # config.groupEnd.append(k) 113 | # config.groupEnd.append(len(sortList)) 114 | # else: 115 | # sortList.sort( 116 | # key=lambda x: (int(x.split(config.blank)[1].strip()), x), reverse=True 117 | # ) 118 | 119 | # if config.weightRegMode == "GL" and config.GL_init: 120 | # if nFeatTemp != len(config.groupStart): 121 | # raise Exception( 122 | # "inconsistent # of features per line, check the feature file for consistency!" 
123 | # ) 124 | # with open( 125 | # os.path.join(config.modelDir, "featureIndex.txt"), "w", encoding="utf-8" 126 | # ) as swFeat: 127 | # for i, l in enumerate(sortList): 128 | # ary = l.split(config.blank) 129 | # self.featureIndexMap[ary[0]] = i 130 | # swFeat.write("{} {}\n".format(ary[0].strip(), i)) 131 | # with open(os.path.join(config.modelDir, "tagIndex.txt"), "w", encoding="utf-8") as swTag: 132 | # tagSortList = [] 133 | # for tag in tagSet: 134 | # tagSortList.append(tag) 135 | # tagSortList.sort() 136 | # for i, l in enumerate(tagSortList): 137 | # self.tagIndexMap[l] = i 138 | # swTag.write("{} {}\n".format(l, i)) 139 | 140 | # def convertFile(self, file): 141 | # config = self.config 142 | # if not os.path.exists(file): 143 | # print("file {} not exist!".format(file)) 144 | # print("file converting...") 145 | # if file == config.fTrain: 146 | # swFeature = open(config.fFeatureTrain, "w", encoding="utf-8") 147 | # swGold = open(config.fGoldTrain, "w", encoding="utf-8") 148 | # else: 149 | # swFeature = open(config.fFeatureTest, "w", encoding="utf-8") 150 | # swGold = open(config.fGoldTest, "w", encoding="utf-8") 151 | # swFeature.write(str(len(self.featureIndexMap)) + "\n\n") 152 | # swGold.write(str(len(self.tagIndexMap)) + "\n\n") 153 | # with open(file, encoding="utf-8") as sr: 154 | # readLines = sr.readlines() 155 | # featureList = [] 156 | # goldList = [] 157 | # for k in range(len(readLines)): 158 | # line = readLines[k] 159 | # line = line.replace("\t", "").strip() 160 | # featureLine = "" 161 | # goldLine = "" 162 | # if line == "": 163 | # featureLine = featureLine + "\n" 164 | # goldLine = goldLine + "\n\n" 165 | # featureList.append(featureLine) 166 | # goldList.append(goldLine) 167 | # continue 168 | # flag = 0 169 | # ary = line.split(config.blank) 170 | # tmp = [] 171 | # for i in ary: 172 | # if i != "": 173 | # tmp.append(i) 174 | # ary = tmp 175 | # for i in range(1, len(ary) - 1): 176 | # if ary[i] == "/": 177 | # continue 178 | # ary2 = ary[i].split(config.slash) 179 | # tmp = [] 180 | # for j in ary2: 181 | # if j != "": 182 | # tmp.append(j) 183 | # ary2 = tmp 184 | # feature = str(i) + "." 
+ ary2[0] 185 | # value = "" 186 | # real = False 187 | # if len(ary2) > 1: 188 | # value = ary2[1] 189 | # real = True 190 | # if not feature in self.featureIndexMap: 191 | # continue 192 | # flag = 1 193 | # fIndex = self.featureIndexMap[feature] 194 | # if not real: 195 | # featureLine = featureLine + str(fIndex) + "," 196 | # else: 197 | # featureLine = featureLine + str(fIndex) + "/" + value + "," 198 | # if flag == 0: 199 | # featureLine = featureLine + "0" 200 | # featureLine = featureLine + "\n" 201 | # tag = ary[-1] 202 | # tIndex = self.tagIndexMap[tag] 203 | # goldLine = goldLine + str(tIndex) + "," 204 | # featureList.append(featureLine) 205 | # goldList.append(goldLine) 206 | # for i in range(len(featureList)): 207 | # swFeature.write(featureList[i]) 208 | # swGold.write(goldList[i]) 209 | # swFeature.close() 210 | # swGold.close() 211 | 212 | 213 | class DataSet: 214 | def __init__(self, n_tag=0, n_feature=0): 215 | self.lst = [] # type: List[Example] 216 | self.n_tag = n_tag 217 | self.n_feature = n_feature 218 | # if len(args) == 2: 219 | # if type(args[0]) == int: 220 | # self.nTag, self.nFeature = args 221 | # else: 222 | # self.load(args[0], args[1]) 223 | 224 | def __len__(self): 225 | return len(self.lst) 226 | 227 | def __iter__(self): 228 | return self.iterator() 229 | 230 | def __getitem__(self, x): 231 | return self.lst[x] 232 | 233 | def iterator(self): 234 | for i in self.lst: 235 | yield i 236 | 237 | def append(self, x): 238 | self.lst.append(x) 239 | 240 | def clear(self): 241 | self.lst = [] 242 | 243 | def randomShuffle(self): 244 | cp = copy.deepcopy(self) 245 | random.shuffle(cp.lst) 246 | return cp 247 | 248 | # def setDataInfo(self, X): 249 | # self.nTag = X.nTag 250 | # self.nFeature = X.nFeature 251 | 252 | def resize(self, scale): 253 | dataset = DataSet(self.n_tag, self.n_feature) 254 | new_size = int(len(self) * scale) 255 | old_size = len(self) 256 | for i in range(new_size): 257 | if i >= old_size: 258 | i %= old_size 259 | dataset.append(self[i]) 260 | return dataset 261 | 262 | @classmethod 263 | def load(cls, feature_idx_file, tag_idx_file): 264 | dataset = cls.__new__(cls) 265 | 266 | # def load(self, fileFeature, fileTag): 267 | with open(feature_idx_file, encoding="utf-8") as f_reader, open( 268 | tag_idx_file, encoding="utf-8" 269 | ) as t_reader: 270 | 271 | example_strs = f_reader.read().split("\n\n")[:-1] 272 | tags_strs = t_reader.read().split("\n\n")[:-1] 273 | 274 | assert len(example_strs) == len( 275 | tags_strs 276 | ), "lengths do not match:\t{}\n{}\n".format(example_strs, tags_strs) 277 | 278 | n_feature = int(example_strs[0]) 279 | n_tag = int(tags_strs[0]) 280 | 281 | dataset.n_feature = n_feature 282 | dataset.n_tag = n_tag 283 | dataset.lst = [] 284 | 285 | for example_str, tags_str in zip(example_strs[1:], tags_strs[1:]): 286 | features = [ 287 | list(map(int, feature_line.split(","))) 288 | for feature_line in example_str.split("\n") 289 | ] 290 | tags = tags_str.split(",") 291 | example = Example(features, tags) 292 | dataset.lst.append(example) 293 | 294 | return dataset 295 | # txt = srfileFeature.read() 296 | # txt.replace("\r", "") 297 | # fAry = txt.split(Config.biLineEnd) 298 | # tmp = [] 299 | # for i in fAry: 300 | # if i != "": 301 | # tmp.append(i) 302 | # fAry = tmp 303 | # txt = srfileTag.read() 304 | # txt.replace("\r", "") 305 | # tAry = txt.split(Config.biLineEnd) 306 | # tmp = [] 307 | # for i in tAry: 308 | # if i != "": 309 | # tmp.append(i) 310 | # tAry = tmp 311 | 312 | # assert len(fAry) == len(tAry) 
313 | # self.nFeature = int(fAry[0]) 314 | # self.nTag = int(tAry[0]) 315 | # for i in range(1, len(fAry)): 316 | # features = fAry[i] 317 | # tags = tAry[i] 318 | # seq = dataSeq() 319 | # seq.read(features, tags) 320 | # self.append(seq) 321 | 322 | # @property 323 | # def NTag(self): 324 | # return self.nTag 325 | 326 | 327 | class Example: 328 | def __init__(self, features, tags): 329 | self.features = features # type: List[List[int]] 330 | self.tags = list(map(int, tags)) # type: List[int] 331 | self.predicted_tags = None 332 | 333 | def __len__(self): 334 | return len(self.features) 335 | 336 | 337 | # class dataSeq: 338 | # def __init__(self, *args): 339 | # self.featureTemps = [] 340 | # self.yGold = [] 341 | # if len(args) == 2: 342 | # self.featureTemps = copy.deepcopy(args[0]) 343 | # self.yGold = copy.deepcopy(args[1]) 344 | # elif len(args) == 3: 345 | # x, n, length = args 346 | # end = min(n + length, len(x)) 347 | # for i in range(n, end): 348 | # self.featureTemps.append(x.featureTemps[i]) 349 | # yGold.append(x.yGold[i]) 350 | 351 | # def __len__(self): 352 | # return len(self.featureTemps) 353 | 354 | # def read(self, a, b): 355 | # lineAry = a.split(Config.lineEnd) 356 | # for im in lineAry: 357 | # if im == "": 358 | # continue 359 | # nodeList = [] 360 | # imAry = im.split(Config.comma) 361 | # for imm in imAry: 362 | # if imm == "": 363 | # continue 364 | # if imm.find("/") >= 0: 365 | # biAry = imm.split(Config.slash) 366 | # ft = featureTemp(int(biAry[0], float(biAry[1]))) 367 | # nodeList.append(ft) 368 | # else: 369 | # ft = featureTemp(int(imm), 1) 370 | # nodeList.append(ft) 371 | # self.featureTemps.append(nodeList) 372 | # lineAry = b.split(Config.comma) 373 | # for im in lineAry: 374 | # if im == "": 375 | # continue 376 | # self.yGold.append(int(im)) 377 | 378 | # # def load(self, feature): 379 | # # for imAry in feature: 380 | # # nodeList = [] 381 | # # for imm in imAry: 382 | # # if imm == "": 383 | # # continue 384 | # # if imm.find("/") >= 0: 385 | # # biAry = imm.split(Config.slash) 386 | # # ft = featureTemp(int(biAry[0], float(biAry[1]))) 387 | # # nodeList.append(ft) 388 | # # else: 389 | # # ft = featureTemp(int(imm), 1) 390 | # # nodeList.append(ft) 391 | # # self.featureTemps.append(nodeList) 392 | # # self.yGold.append(0) 393 | 394 | # def getFeatureTemp(self, *args): 395 | # return ( 396 | # self.featureTemps if len(args) == 0 else self.featureTemps[args[0]] 397 | # ) 398 | 399 | # def getTags(self, *args): 400 | # return self.yGold if len(args) == 0 else self.yGold[args[0]] 401 | 402 | # def setTags(self, lst): 403 | # assert len(lst) == len(self.yGold) 404 | # for i in range(len(lst)): 405 | # self.yGold[i] = lst[i] 406 | 407 | 408 | # class dataSeqTest: 409 | # def __init__(self, x, yOutput): 410 | # self._x = x 411 | # self._yOutput = yOutput 412 | -------------------------------------------------------------------------------- /pkuseg/trainer.py: -------------------------------------------------------------------------------- 1 | # from .config import config 2 | # from .feature import * 3 | # from .data_format import * 4 | # from .toolbox import * 5 | import os 6 | import time 7 | from multiprocessing import Process, Queue 8 | 9 | from pkuseg import res_summarize 10 | 11 | # from .inference import * 12 | # from .config import Config 13 | from pkuseg.config import Config, config 14 | from pkuseg.data import DataSet 15 | from pkuseg.feature_extractor import FeatureExtractor 16 | 17 | # from .feature_generator import * 18 | from pkuseg.model 
import Model 19 | import pkuseg.inference as _inf 20 | 21 | # from .inference import * 22 | # from .gradient import * 23 | from pkuseg.optimizer import ADF 24 | from pkuseg.scorer import getFscore 25 | 26 | # from typing import TextIO 27 | 28 | # from .res_summarize import summarize 29 | # from .res_summarize import write as reswrite 30 | 31 | # from pkuseg.trainer import Trainer 32 | 33 | 34 | def train(config=None): 35 | if config is None: 36 | config = Config() 37 | 38 | if config.init_model is None: 39 | feature_extractor = FeatureExtractor() 40 | else: 41 | feature_extractor = FeatureExtractor.load(config.init_model) 42 | feature_extractor.build(config.trainFile) 43 | feature_extractor.save() 44 | 45 | feature_extractor.convert_text_file_to_feature_file( 46 | config.trainFile, config.c_train, config.f_train 47 | ) 48 | feature_extractor.convert_text_file_to_feature_file( 49 | config.testFile, config.c_test, config.f_test 50 | ) 51 | 52 | feature_extractor.convert_feature_file_to_idx_file( 53 | config.f_train, config.fFeatureTrain, config.fGoldTrain 54 | ) 55 | feature_extractor.convert_feature_file_to_idx_file( 56 | config.f_test, config.fFeatureTest, config.fGoldTest 57 | ) 58 | 59 | config.globalCheck() 60 | 61 | config.swLog = open(os.path.join(config.outDir, config.fLog), "w") 62 | config.swResRaw = open(os.path.join(config.outDir, config.fResRaw), "w") 63 | config.swTune = open(os.path.join(config.outDir, config.fTune), "w") 64 | 65 | print("\nstart training...") 66 | config.swLog.write("\nstart training...\n") 67 | 68 | print("\nreading training & test data...") 69 | config.swLog.write("\nreading training & test data...\n") 70 | 71 | trainset = DataSet.load(config.fFeatureTrain, config.fGoldTrain) 72 | testset = DataSet.load(config.fFeatureTest, config.fGoldTest) 73 | 74 | trainset = trainset.resize(config.trainSizeScale) 75 | 76 | print( 77 | "done! train/test data sizes: {}/{}".format(len(trainset), len(testset)) 78 | ) 79 | config.swLog.write( 80 | "done! 
train/test data sizes: {}/{}\n".format( 81 | len(trainset), len(testset) 82 | ) 83 | ) 84 | 85 | config.swLog.write("\nr: {}\n".format(config.reg)) 86 | print("\nr: {}".format(config.reg)) 87 | if config.rawResWrite: 88 | config.swResRaw.write("\n%r: {}\n".format(config.reg)) 89 | 90 | trainer = Trainer(config, trainset, feature_extractor) 91 | 92 | time_list = [] 93 | err_list = [] 94 | diff_list = [] 95 | score_list_list = [] 96 | 97 | for i in range(config.ttlIter): 98 | # config.glbIter += 1 99 | time_s = time.time() 100 | err, sample_size, diff = trainer.train_epoch() 101 | time_t = time.time() - time_s 102 | time_list.append(time_t) 103 | err_list.append(err) 104 | diff_list.append(diff) 105 | 106 | score_list = trainer.test(testset, i) 107 | score_list_list.append(score_list) 108 | score = score_list[0] 109 | 110 | logstr = "iter{} diff={:.2e} train-time(sec)={:.2f} {}={:.2f}%".format( 111 | i, diff, time_t, config.metric, score 112 | ) 113 | config.swLog.write(logstr + "\n") 114 | config.swLog.write("------------------------------------------------\n") 115 | config.swLog.flush() 116 | print(logstr) 117 | 118 | res_summarize.write(config, time_list, err_list, diff_list, score_list_list) 119 | if config.save == 1: 120 | trainer.model.save() 121 | 122 | config.swLog.close() 123 | config.swResRaw.close() 124 | config.swTune.close() 125 | 126 | res_summarize.summarize(config) 127 | 128 | print("finished.") 129 | 130 | 131 | class Trainer: 132 | def __init__(self, config, dataset, feature_extractor): 133 | self.config = config 134 | self.X = dataset 135 | self.n_feature = dataset.n_feature 136 | self.n_tag = dataset.n_tag 137 | 138 | if config.init_model is None: 139 | self.model = Model(self.n_feature, self.n_tag) 140 | else: 141 | self.model = Model.load(config.init_model) 142 | self.model.expand(self.n_feature, self.n_tag) 143 | 144 | self.optim = self._get_optimizer(dataset, self.model) 145 | 146 | self.feature_extractor = feature_extractor 147 | self.idx_to_chunk_tag = {} 148 | for tag, idx in feature_extractor.tag_to_idx.items(): 149 | if tag.startswith("I"): 150 | tag = "I" 151 | if tag.startswith("O"): 152 | tag = "O" 153 | self.idx_to_chunk_tag[idx] = tag 154 | 155 | def _get_optimizer(self, dataset, model): 156 | config = self.config 157 | if "adf" in config.modelOptimizer: 158 | return ADF(config, dataset, model) 159 | 160 | raise ValueError("Invalid Optimizer") 161 | 162 | def train_epoch(self): 163 | return self.optim.optimize() 164 | 165 | def test(self, testset, iteration): 166 | 167 | outfile = os.path.join(config.outDir, config.fOutput.format(iteration)) 168 | 169 | func_mapping = { 170 | "tok.acc": self._decode_tokAcc, 171 | "str.acc": self._decode_strAcc, 172 | "f1": self._decode_fscore, 173 | } 174 | 175 | with open(outfile, "w", encoding="utf8") as writer: 176 | score_list = func_mapping[config.evalMetric]( 177 | testset, self.model, writer 178 | ) 179 | 180 | for example in testset: 181 | example.predicted_tags = None 182 | 183 | return score_list 184 | 185 | def _decode(self, testset: DataSet, model: Model): 186 | if config.nThread == 1: 187 | self._decode_single(testset, model) 188 | else: 189 | self._decode_multi_proc(testset, model) 190 | 191 | def _decode_single(self, testset: DataSet, model: Model): 192 | # n_tag = model.n_tag 193 | for example in testset: 194 | _, tags = _inf.decodeViterbi_fast(example.features, model) 195 | example.predicted_tags = tags 196 | 197 | @staticmethod 198 | def _decode_proc(model, in_queue, out_queue): 199 | while True: 200 | item 
= in_queue.get() 201 | if item is None: 202 | return 203 | idx, features = item 204 | _, tags = _inf.decodeViterbi_fast(features, model) 205 | out_queue.put((idx, tags)) 206 | 207 | def _decode_multi_proc(self, testset: DataSet, model: Model): 208 | in_queue = Queue() 209 | out_queue = Queue() 210 | procs = [] 211 | nthread = self.config.nThread 212 | for i in range(nthread): 213 | p = Process( 214 | target=self._decode_proc, args=(model, in_queue, out_queue) 215 | ) 216 | procs.append(p) 217 | 218 | for idx, example in enumerate(testset): 219 | in_queue.put((idx, example.features)) 220 | 221 | for proc in procs: 222 | in_queue.put(None) 223 | proc.start() 224 | 225 | for _ in range(len(testset)): 226 | idx, tags = out_queue.get() 227 | testset[idx].predicted_tags = tags 228 | 229 | for p in procs: 230 | p.join() 231 | 232 | # token accuracy 233 | def _decode_tokAcc(self, dataset, model, writer): 234 | config = self.config 235 | 236 | self._decode(dataset, model) 237 | n_tag = model.n_tag 238 | all_correct = [0] * n_tag 239 | all_pred = [0] * n_tag 240 | all_gold = [0] * n_tag 241 | 242 | for example in dataset: 243 | pred = example.predicted_tags 244 | gold = example.tags 245 | 246 | if writer is not None: 247 | writer.write(",".join(map(str, pred))) 248 | writer.write("\n") 249 | 250 | for pred_tag, gold_tag in zip(pred, gold): 251 | all_pred[pred_tag] += 1 252 | all_gold[gold_tag] += 1 253 | if pred_tag == gold_tag: 254 | all_correct[gold_tag] += 1 255 | 256 | config.swLog.write( 257 | "% tag-type #gold #output #correct-output token-precision token-recall token-f-score\n" 258 | ) 259 | sumGold = 0 260 | sumOutput = 0 261 | sumCorrOutput = 0 262 | 263 | for i, (correct, gold, pred) in enumerate( 264 | zip(all_correct, all_gold, all_pred) 265 | ): 266 | sumGold += gold 267 | sumOutput += pred 268 | sumCorrOutput += correct 269 | 270 | if gold == 0: 271 | rec = 0 272 | else: 273 | rec = correct * 100.0 / gold 274 | 275 | if pred == 0: 276 | prec = 0 277 | else: 278 | prec = correct * 100.0 / pred 279 | 280 | config.swLog.write( 281 | "% {}: {} {} {} {:.2f} {:.2f} {:.2f}\n".format( 282 | i, 283 | gold, 284 | pred, 285 | correct, 286 | prec, 287 | rec, 288 | (2 * prec * rec / (prec + rec)), 289 | ) 290 | ) 291 | 292 | if sumGold == 0: 293 | rec = 0 294 | else: 295 | rec = sumCorrOutput * 100.0 / sumGold 296 | if sumOutput == 0: 297 | prec = 0 298 | else: 299 | prec = sumCorrOutput * 100.0 / sumOutput 300 | 301 | if prec == 0 and rec == 0: 302 | fscore = 0 303 | else: 304 | fscore = 2 * prec * rec / (prec + rec) 305 | 306 | config.swLog.write( 307 | "% overall-tags: {} {} {} {:.2f} {:.2f} {:.2f}\n".format( 308 | sumGold, sumOutput, sumCorrOutput, prec, rec, fscore 309 | ) 310 | ) 311 | config.swLog.flush() 312 | return [fscore] 313 | 314 | def _decode_strAcc(self, dataset, model, writer): 315 | 316 | config = self.config 317 | 318 | self._decode(dataset, model) 319 | 320 | correct = 0 321 | total = len(dataset) 322 | 323 | for example in dataset: 324 | pred = example.predicted_tags 325 | gold = example.tags 326 | 327 | if writer is not None: 328 | writer.write(",".join(map(str, pred))) 329 | writer.write("\n") 330 | 331 | for pred_tag, gold_tag in zip(pred, gold): 332 | if pred_tag != gold_tag: 333 | break 334 | else: 335 | correct += 1 336 | 337 | acc = correct / total * 100.0 338 | config.swLog.write( 339 | "total-tag-strings={} correct-tag-strings={} string-accuracy={}%".format( 340 | total, correct, acc 341 | ) 342 | ) 343 | return [acc] 344 | 345 | def _decode_fscore(self, dataset, 
model, writer): 346 | config = self.config 347 | 348 | self._decode(dataset, model) 349 | 350 | gold_tags = [] 351 | pred_tags = [] 352 | 353 | for example in dataset: 354 | pred = example.predicted_tags 355 | gold = example.tags 356 | 357 | pred_str = ",".join(map(str, pred)) 358 | pred_tags.append(pred_str) 359 | if writer is not None: 360 | writer.write(pred_str) 361 | writer.write("\n") 362 | gold_tags.append(",".join(map(str, gold))) 363 | 364 | scoreList, infoList = getFscore( 365 | gold_tags, pred_tags, self.idx_to_chunk_tag 366 | ) 367 | config.swLog.write( 368 | "#gold-chunk={} #output-chunk={} #correct-output-chunk={} precision={:.2f} recall={:.2f} f-score={:.2f}\n".format( 369 | infoList[0], 370 | infoList[1], 371 | infoList[2], 372 | scoreList[1], 373 | scoreList[2], 374 | scoreList[0], 375 | ) 376 | ) 377 | return scoreList 378 | 379 | # acc = correct / total * 100.0 380 | # config.swLog.write( 381 | # "total-tag-strings={} correct-tag-strings={} string-accuracy={}%".format( 382 | # total, correct, acc 383 | # ) 384 | # ) 385 | 386 | # goldTagList = [] 387 | # resTagList = [] 388 | # for x in X2: 389 | # res = "" 390 | # for im in x._yOutput: 391 | # res += str(im) + "," 392 | # resTagList.append(res) 393 | # # if not dynamic: 394 | # if writer is not None: 395 | # for i in range(len(x._yOutput)): 396 | # writer.write(str(x._yOutput[i]) + ",") 397 | # writer.write("\n") 398 | # goldTags = x._x.getTags() 399 | # gold = "" 400 | # for im in goldTags: 401 | # gold += str(im) + "," 402 | # goldTagList.append(gold) 403 | # # if dynamic: 404 | # # return resTagList 405 | # scoreList = [] 406 | 407 | # if config.runMode == "train": 408 | # infoList = [] 409 | # scoreList = getFscore( 410 | # goldTagList, resTagList, infoList, self.idx_to_chunk_tag 411 | # ) 412 | # config.swLog.write( 413 | # "#gold-chunk={} #output-chunk={} #correct-output-chunk={} precision={:.2f} recall={:.2f} f-score={:.2f}\n".format( 414 | # infoList[0], 415 | # infoList[1], 416 | # infoList[2], 417 | # "%.2f" % scoreList[1], 418 | # "%.2f" % scoreList[2], 419 | # "%.2f" % scoreList[0], 420 | # ) 421 | # ) 422 | # return scoreList 423 | 424 | # # def multiThreading(self, X, X2): 425 | # config = self.config 426 | # # if dynamic: 427 | # # for i in range(len(X)): 428 | # # X2.append(dataSeqTest(X[i], [])) 429 | # # for k, x in enumerate(X2): 430 | # # tags = [] 431 | # # prob = self.Inf.decodeViterbi_fast(self.Model, x._x, tags) 432 | # # X2[k]._yOutput.clear() 433 | # # X2[k]._yOutput.extend(tags) 434 | # # return 435 | 436 | # for i in range(len(X)): 437 | # X2.append(dataSeqTest(X[i], [])) 438 | # if len(X) < config.nThread: 439 | # config.nThread = len(X) 440 | # interval = (len(X2) + config.nThread - 1) // config.nThread 441 | # procs = [] 442 | # Q = Queue(5000) 443 | # for i in range(config.nThread): 444 | # start = i * interval 445 | # end = min(start + interval, len(X2)) 446 | # proc = Process( 447 | # target=Trainer.taskRunner_test, 448 | # args=(self.Inf, self.Model, X2, start, end, Q), 449 | # ) 450 | # proc.start() 451 | # procs.append(proc) 452 | # for i in range(len(X2)): 453 | # t = Q.get() 454 | # k, tags = t 455 | # X2[k]._yOutput.clear() 456 | # X2[k]._yOutput.extend(tags) 457 | # for proc in procs: 458 | # proc.join() 459 | 460 | # @staticmethod 461 | # def taskRunner_test(Inf, Model, X2, start, end, Q): 462 | # for k in range(start, end): 463 | # x = X2[k] 464 | # tags = [] 465 | # prob = Inf.decodeViterbi_fast(Model, x._x, tags) 466 | # Q.put((k, tags)) 467 | 
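# Usage note (for reference): train(config) above expects config.trainFile, config.testFile,
# config.modelDir, config.ttlIter and config.init_model to be set in advance; the
# pkuseg.train(trainFile, testFile, savedir, train_iter, init_model) wrapper in
# pkuseg/__init__.py fills in these fields and then calls this function.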
-------------------------------------------------------------------------------- /pkuseg/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import sys 3 | 4 | if sys.version_info[0] < 3: 5 | print("pkuseg does not support python2", file=sys.stderr) 6 | sys.exit(1) 7 | 8 | import os 9 | import time 10 | import pickle as pkl 11 | import multiprocessing 12 | 13 | from multiprocessing import Process, Queue 14 | 15 | import pkuseg.trainer as trainer 16 | import pkuseg.inference as _inf 17 | 18 | from pkuseg.config import config 19 | from pkuseg.feature_extractor import FeatureExtractor 20 | from pkuseg.model import Model 21 | from pkuseg.download import download_model 22 | from pkuseg.postag import Postag 23 | 24 | class TrieNode: 25 | """建立词典的Trie树节点""" 26 | 27 | def __init__(self, isword): 28 | self.isword = isword 29 | self.usertag = '' 30 | self.children = {} 31 | 32 | 33 | class Preprocesser: 34 | """预处理器,在用户词典中的词强制分割""" 35 | 36 | def __init__(self, dict_file): 37 | """初始化建立Trie树""" 38 | if dict_file is None: 39 | dict_file = [] 40 | self.dict_data = dict_file 41 | if isinstance(dict_file, str): 42 | with open(dict_file, encoding="utf-8") as f: 43 | lines = f.readlines() 44 | self.trie = TrieNode(False) 45 | for line in lines: 46 | fields = line.strip().split('\t') 47 | word = fields[0].strip() 48 | usertag = fields[1].strip() if len(fields) > 1 else '' 49 | self.insert(word, usertag) 50 | else: 51 | self.trie = TrieNode(False) 52 | for w_t in dict_file: 53 | if isinstance(w_t, str): 54 | w = w_t.strip() 55 | t = '' 56 | else: 57 | assert isinstance(w_t, tuple) 58 | assert len(w_t)==2 59 | w, t = map(lambda x:x.strip(), w_t) 60 | self.insert(w, t) 61 | 62 | def insert(self, word, usertag): 63 | """Trie树中插入单词""" 64 | l = len(word) 65 | now = self.trie 66 | for i in range(l): 67 | c = word[i] 68 | if not c in now.children: 69 | now.children[c] = TrieNode(False) 70 | now = now.children[c] 71 | now.isword = True 72 | now.usertag = usertag 73 | 74 | def solve(self, txt): 75 | """对文本进行预处理""" 76 | outlst = [] 77 | iswlst = [] 78 | taglst = [] 79 | l = len(txt) 80 | last = 0 81 | i = 0 82 | while i < l: 83 | now = self.trie 84 | j = i 85 | found = False 86 | usertag = '' 87 | last_word_idx = -1 # 表示从当前位置i往后匹配,最长匹配词词尾的idx 88 | while True: 89 | c = txt[j] 90 | if not c in now.children and last_word_idx != -1: 91 | found = True 92 | break 93 | if not c in now.children and last_word_idx == -1: 94 | break 95 | now = now.children[c] 96 | if now.isword: 97 | last_word_idx = j 98 | usertag = now.usertag 99 | j += 1 100 | if j == l and last_word_idx == -1: 101 | break 102 | if j == l and last_word_idx != -1 : 103 | j = last_word_idx + 1 104 | found = True 105 | break 106 | if found: 107 | if last != i: 108 | outlst.append(txt[last:i]) 109 | iswlst.append(False) 110 | taglst.append('') 111 | outlst.append(txt[i:j]) 112 | iswlst.append(True) 113 | taglst.append(usertag) 114 | last = j 115 | i = j 116 | else: 117 | i += 1 118 | if last < l: 119 | outlst.append(txt[last:l]) 120 | iswlst.append(False) 121 | taglst.append('') 122 | return outlst, iswlst, taglst 123 | 124 | class Postprocesser: 125 | """对分词结果后处理""" 126 | def __init__(self, common_name, other_names): 127 | if common_name is None and other_names is None: 128 | self.do_process = False 129 | return 130 | self.do_process = True 131 | if common_name is None: 132 | self.common_words = set() 133 | else: 134 | # with open(common_name, encoding='utf-8') as f: 135 | # 
lines = f.readlines() 136 | # self.common_words = set(map(lambda x:x.strip(), lines)) 137 | with open(common_name, "rb") as f: 138 | all_words = pkl.load(f).strip().split("\n") 139 | self.common_words = set(all_words) 140 | if other_names is None: 141 | self.other_words = set() 142 | else: 143 | self.other_words = set() 144 | for other_name in other_names: 145 | # with open(other_name, encoding='utf-8') as f: 146 | # lines = f.readlines() 147 | # self.other_words.update(set(map(lambda x:x.strip(), lines))) 148 | with open(other_name, "rb") as f: 149 | all_words = pkl.load(f).strip().split("\n") 150 | self.other_words.update(set(all_words)) 151 | 152 | def post_process(self, sent, check_seperated): 153 | for m in reversed(range(2, 8)): 154 | end = len(sent)-m 155 | if end < 0: 156 | continue 157 | i = 0 158 | while (i < end + 1): 159 | merged_words = ''.join(sent[i:i+m]) 160 | if merged_words in self.common_words: 161 | do_seg = True 162 | elif merged_words in self.other_words: 163 | if check_seperated: 164 | seperated = all(((w in self.common_words) 165 | or (w in self.other_words)) for w in sent[i:i+m]) 166 | else: 167 | seperated = False 168 | if seperated: 169 | do_seg = False 170 | else: 171 | do_seg = True 172 | else: 173 | do_seg = False 174 | if do_seg: 175 | for k in range(m): 176 | del sent[i] 177 | sent.insert(i, merged_words) 178 | i += 1 179 | end = len(sent) - m 180 | else: 181 | i += 1 182 | return sent 183 | 184 | def __call__(self, sent): 185 | if not self.do_process: 186 | return sent 187 | return self.post_process(sent, check_seperated=True) 188 | 189 | class pkuseg: 190 | def __init__(self, model_name="default", user_dict="default", postag=False): 191 | """初始化函数,加载模型及用户词典""" 192 | # print("loading model") 193 | # config = Config() 194 | # self.config = config 195 | self.postag = postag 196 | if model_name in ["default"]: 197 | config.modelDir = os.path.join( 198 | os.path.dirname(os.path.realpath(__file__)), 199 | "models", 200 | model_name, 201 | ) 202 | elif model_name in config.available_models: 203 | config.modelDir = os.path.join( 204 | config.pkuseg_home, 205 | model_name, 206 | ) 207 | download_model(config.model_urls[model_name], config.pkuseg_home, config.model_hash[model_name]) 208 | else: 209 | config.modelDir = model_name 210 | # config.fModel = os.path.join(config.modelDir, "model.txt") 211 | if user_dict is None: 212 | file_name = None 213 | other_names = None 214 | else: 215 | if user_dict not in config.available_models: 216 | file_name = user_dict 217 | else: 218 | file_name = None 219 | if model_name in config.models_with_dict: 220 | other_name = os.path.join( 221 | config.pkuseg_home, 222 | model_name, 223 | model_name+"_dict.pkl", 224 | ) 225 | default_name = os.path.join( 226 | os.path.dirname(os.path.realpath(__file__)), 227 | "dicts", "default.pkl", 228 | ) 229 | other_names = [other_name, default_name] 230 | else: 231 | default_name = os.path.join( 232 | os.path.dirname(os.path.realpath(__file__)), 233 | "dicts", "default.pkl", 234 | ) 235 | other_names = [default_name] 236 | 237 | self.preprocesser = Preprocesser(file_name) 238 | # self.preprocesser = Preprocesser([]) 239 | self.postprocesser = Postprocesser(None, other_names) 240 | 241 | self.feature_extractor = FeatureExtractor.load() 242 | self.model = Model.load() 243 | 244 | self.idx_to_tag = { 245 | idx: tag for tag, idx in self.feature_extractor.tag_to_idx.items() 246 | } 247 | 248 | self.n_feature = len(self.feature_extractor.feature_to_idx) 249 | self.n_tag = 
len(self.feature_extractor.tag_to_idx) 250 | 251 | if postag: 252 | download_model(config.model_urls["postag"], config.pkuseg_home, config.model_hash["postag"]) 253 | postag_dir = os.path.join( 254 | config.pkuseg_home, 255 | "postag", 256 | ) 257 | self.tagger = Postag(postag_dir) 258 | 259 | # print("finish") 260 | 261 | def _cut(self, text): 262 | """ 263 | 直接对文本分词 264 | """ 265 | 266 | examples = list(self.feature_extractor.normalize_text(text)) 267 | length = len(examples) 268 | 269 | all_feature = [] # type: List[List[int]] 270 | for idx in range(length): 271 | node_feature_idx = self.feature_extractor.get_node_features_idx( 272 | idx, examples 273 | ) 274 | # node_feature = self.feature_extractor.get_node_features( 275 | # idx, examples 276 | # ) 277 | 278 | # node_feature_idx = [] 279 | # for feature in node_feature: 280 | # feature_idx = self.feature_extractor.feature_to_idx.get(feature) 281 | # if feature_idx is not None: 282 | # node_feature_idx.append(feature_idx) 283 | # if not node_feature_idx: 284 | # node_feature_idx.append(0) 285 | 286 | all_feature.append(node_feature_idx) 287 | 288 | _, tags = _inf.decodeViterbi_fast(all_feature, self.model) 289 | 290 | words = [] 291 | current_word = None 292 | is_start = True 293 | for tag, char in zip(tags, text): 294 | if is_start: 295 | current_word = char 296 | is_start = False 297 | elif "B" in self.idx_to_tag[tag]: 298 | words.append(current_word) 299 | current_word = char 300 | else: 301 | current_word += char 302 | if current_word: 303 | words.append(current_word) 304 | 305 | return words 306 | 307 | def cut(self, txt): 308 | """分词,结果返回一个list""" 309 | 310 | txt = txt.strip() 311 | 312 | ret = [] 313 | usertags = [] 314 | 315 | if not txt: 316 | return ret 317 | 318 | imary = txt.split() # 根据空格分为多个片段 319 | 320 | # 对每个片段分词 321 | for w0 in imary: 322 | if not w0: 323 | continue 324 | 325 | # 根据用户词典拆成更多片段 326 | lst, isword, taglst = self.preprocesser.solve(w0) 327 | 328 | for w, isw, usertag in zip(lst, isword, taglst): 329 | if isw: 330 | ret.append(w) 331 | usertags.append(usertag) 332 | continue 333 | 334 | output = self._cut(w) 335 | post_output = self.postprocesser(output) 336 | ret.extend(post_output) 337 | usertags.extend(['']*len(post_output)) 338 | 339 | if self.postag: 340 | tags = self.tagger.tag(ret.copy()) 341 | for i, usertag in enumerate(usertags): 342 | if usertag: 343 | tags[i] = usertag 344 | ret = list(zip(ret, tags)) 345 | return ret 346 | 347 | 348 | def train(trainFile, testFile, savedir, train_iter=20, init_model=None): 349 | """用于训练模型""" 350 | # config = Config() 351 | starttime = time.time() 352 | if not os.path.exists(trainFile): 353 | raise Exception("trainfile does not exist.") 354 | if not os.path.exists(testFile): 355 | raise Exception("testfile does not exist.") 356 | if not os.path.exists(config.tempFile): 357 | os.makedirs(config.tempFile) 358 | if not os.path.exists(config.tempFile + "/output"): 359 | os.mkdir(config.tempFile + "/output") 360 | # config.runMode = "train" 361 | config.trainFile = trainFile 362 | config.testFile = testFile 363 | config.modelDir = savedir 364 | # config.fModel = os.path.join(config.modelDir, "model.txt") 365 | config.nThread = 1 366 | config.ttlIter = train_iter 367 | config.init_model = init_model 368 | 369 | os.makedirs(config.modelDir, exist_ok=True) 370 | 371 | trainer.train(config) 372 | 373 | # pkuseg.main.run(config) 374 | # clearDir(config.tempFile) 375 | print("Total time: " + str(time.time() - starttime)) 376 | 377 | 378 | def _test_single_proc( 379 | 
input_file, output_file, model_name="default", user_dict="default", postag=False, verbose=False 380 | ): 381 | 382 | times = [] 383 | times.append(time.time()) 384 | seg = pkuseg(model_name, user_dict, postag=postag) 385 | 386 | times.append(time.time()) 387 | if not os.path.exists(input_file): 388 | raise Exception("input_file {} does not exist.".format(input_file)) 389 | with open(input_file, "r", encoding="utf-8") as f: 390 | lines = f.readlines() 391 | 392 | times.append(time.time()) 393 | results = [] 394 | for line in lines: 395 | if not postag: 396 | results.append(" ".join(seg.cut(line))) 397 | else: 398 | results.append(" ".join(map(lambda x:"/".join(x), seg.cut(line)))) 399 | 400 | times.append(time.time()) 401 | with open(output_file, "w", encoding="utf-8") as f: 402 | f.write("\n".join(results)) 403 | times.append(time.time()) 404 | 405 | print("total_time:\t{:.3f}".format(times[-1] - times[0])) 406 | 407 | if verbose: 408 | time_strs = ["load_model", "read_file", "word_seg", "write_file"] 409 | for key, value in zip( 410 | time_strs, 411 | [end - start for start, end in zip(times[:-1], times[1:])], 412 | ): 413 | print("{}:\t{:.3f}".format(key, value)) 414 | 415 | 416 | def _proc_deprecated(seg, lines, start, end, q): 417 | for i in range(start, end): 418 | l = lines[i].strip() 419 | ret = seg.cut(l) 420 | q.put((i, " ".join(ret))) 421 | 422 | 423 | def _proc(seg, in_queue, out_queue): 424 | # TODO: load seg (json or pickle serialization) in sub_process 425 | # to avoid pickle seg online when using start method other 426 | # than fork 427 | while True: 428 | item = in_queue.get() 429 | if item is None: 430 | return 431 | idx, line = item 432 | if not seg.postag: 433 | output_str = " ".join(seg.cut(line)) 434 | else: 435 | output_str = " ".join(map(lambda x:"/".join(x), seg.cut(line))) 436 | out_queue.put((idx, output_str)) 437 | 438 | 439 | def _proc_alt(model_name, user_dict, postag, in_queue, out_queue): 440 | seg = pkuseg(model_name, user_dict, postag=postag) 441 | while True: 442 | item = in_queue.get() 443 | if item is None: 444 | return 445 | idx, line = item 446 | if not postag: 447 | output_str = " ".join(seg.cut(line)) 448 | else: 449 | output_str = " ".join(map(lambda x:"/".join(x), seg.cut(line))) 450 | out_queue.put((idx, output_str)) 451 | 452 | 453 | def _test_multi_proc( 454 | input_file, 455 | output_file, 456 | nthread, 457 | model_name="default", 458 | user_dict="default", 459 | postag=False, 460 | verbose=False, 461 | ): 462 | 463 | alt = multiprocessing.get_start_method() == "spawn" 464 | 465 | times = [] 466 | times.append(time.time()) 467 | 468 | if alt: 469 | seg = None 470 | else: 471 | seg = pkuseg(model_name, user_dict, postag) 472 | 473 | times.append(time.time()) 474 | if not os.path.exists(input_file): 475 | raise Exception("input_file {} does not exist.".format(input_file)) 476 | with open(input_file, "r", encoding="utf-8") as f: 477 | lines = f.readlines() 478 | 479 | times.append(time.time()) 480 | in_queue = Queue() 481 | out_queue = Queue() 482 | procs = [] 483 | for _ in range(nthread): 484 | if alt: 485 | p = Process( 486 | target=_proc_alt, 487 | args=(model_name, user_dict, postag, in_queue, out_queue), 488 | ) 489 | else: 490 | p = Process(target=_proc, args=(seg, in_queue, out_queue)) 491 | procs.append(p) 492 | 493 | for idx, line in enumerate(lines): 494 | in_queue.put((idx, line)) 495 | 496 | for proc in procs: 497 | in_queue.put(None) 498 | proc.start() 499 | 500 | times.append(time.time()) 501 | result = [None] * len(lines) 502 | 
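    # collect segmented lines from the output queue; worker processes may return
    # results out of order, so each line is written back by its original index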
for _ in result: 503 | idx, line = out_queue.get() 504 | result[idx] = line 505 | 506 | times.append(time.time()) 507 | for p in procs: 508 | p.join() 509 | 510 | times.append(time.time()) 511 | with open(output_file, "w", encoding="utf-8") as f: 512 | f.write("\n".join(result)) 513 | times.append(time.time()) 514 | 515 | print("total_time:\t{:.3f}".format(times[-1] - times[0])) 516 | 517 | if verbose: 518 | time_strs = [ 519 | "load_model", 520 | "read_file", 521 | "start_proc", 522 | "word_seg", 523 | "join_proc", 524 | "write_file", 525 | ] 526 | 527 | if alt: 528 | times = times[1:] 529 | time_strs = time_strs[1:] 530 | time_strs[2] = "load_modal & word_seg" 531 | 532 | for key, value in zip( 533 | time_strs, 534 | [end - start for start, end in zip(times[:-1], times[1:])], 535 | ): 536 | print("{}:\t{:.3f}".format(key, value)) 537 | 538 | 539 | def test( 540 | input_file, 541 | output_file, 542 | model_name="default", 543 | user_dict="default", 544 | nthread=10, 545 | postag=False, 546 | verbose=False, 547 | ): 548 | 549 | if nthread > 1: 550 | _test_multi_proc( 551 | input_file, output_file, nthread, model_name, user_dict, postag, verbose 552 | ) 553 | else: 554 | _test_single_proc( 555 | input_file, output_file, model_name, user_dict, postag, verbose 556 | ) 557 | 558 | -------------------------------------------------------------------------------- /pkuseg/feature_extractor.pyx: -------------------------------------------------------------------------------- 1 | # distutils: language = c++ 2 | # cython: infer_types=True 3 | # cython: language_level=3 4 | import json 5 | import os 6 | import sys 7 | import pickle 8 | from collections import Counter 9 | from itertools import product 10 | 11 | import cython 12 | from pkuseg.config import config 13 | 14 | 15 | @cython.boundscheck(False) 16 | @cython.wraparound(False) 17 | cpdef get_slice_str(iterable, int start, int length, int all_len): 18 | if start < 0 or start >= all_len: 19 | return "" 20 | if start + length >= all_len + 1: 21 | return "" 22 | return "".join(iterable[start : start + length]) 23 | 24 | 25 | 26 | @cython.boundscheck(False) 27 | @cython.wraparound(False) 28 | @cython.nonecheck(False) 29 | def __get_node_features_idx(object config not None, int idx, list nodes not None, dict feature_to_idx not None, set unigram not None): 30 | 31 | cdef: 32 | list flist = [] 33 | Py_ssize_t i = idx 34 | int length = len(nodes) 35 | int word_max = config.wordMax 36 | int word_min = config.wordMin 37 | int word_range = word_max - word_min + 1 38 | 39 | 40 | c = nodes[i] 41 | 42 | # $$ starts feature 43 | flist.append(0) 44 | 45 | # 8 unigram/bgiram feature 46 | feat = 'c.' + c 47 | if feat in feature_to_idx: 48 | feature = feature_to_idx[feat] 49 | flist.append(feature) 50 | 51 | 52 | if i > 0: 53 | prev_c = nodes[i-1] 54 | feat = 'c-1.' + prev_c 55 | if feat in feature_to_idx: 56 | feature = feature_to_idx[feat] 57 | flist.append(feature) 58 | 59 | feat = 'c-1c.' + prev_c + '.' + c 60 | if feat in feature_to_idx: 61 | feature = feature_to_idx[feat] 62 | flist.append(feature) 63 | 64 | if i + 1 < length: 65 | next_c = nodes[i+1] 66 | 67 | feat = 'c1.' + next_c 68 | if feat in feature_to_idx: 69 | feature = feature_to_idx[feat] 70 | flist.append(feature) 71 | 72 | feat = 'cc1.' + c + '.' + next_c 73 | if feat in feature_to_idx: 74 | feature = feature_to_idx[feat] 75 | flist.append(feature) 76 | 77 | 78 | if i > 1: 79 | prepre_char = nodes[i-2] 80 | feat = 'c-2.' 
+ prepre_char 81 | if feat in feature_to_idx: 82 | feature = feature_to_idx[feat] 83 | flist.append(feature) 84 | 85 | feat = 'c-2c-1.' + prepre_char + '.' + nodes[i-1] 86 | if feat in feature_to_idx: 87 | feature = feature_to_idx[feat] 88 | flist.append(feature) 89 | 90 | 91 | 92 | if i + 2 < length: 93 | feat = 'c2.' + nodes[i+2] 94 | if feat in feature_to_idx: 95 | feature = feature_to_idx[feat] 96 | flist.append(feature) 97 | 98 | 99 | # no num/letter based features 100 | if not config.wordFeature: 101 | return flist 102 | 103 | 104 | # 2 * (wordMax-wordMin+1) word features (default: 2*(6-2+1)=10 ) 105 | # the character starts or ends a word 106 | 107 | prelst_in = [] 108 | for l in range(word_max, word_min - 1, -1): 109 | # length 6 ... 2 (default) 110 | # "prefix including current c" wordary[n-l+1, n] 111 | # current character ends word 112 | tmp = get_slice_str(nodes, i - l + 1, l, length) 113 | if tmp in unigram: 114 | feat = 'w-1.' + tmp 115 | if feat in feature_to_idx: 116 | feature = feature_to_idx[feat] 117 | flist.append(feature) 118 | 119 | prelst_in.append(tmp) 120 | else: 121 | prelst_in.append("**noWord") 122 | 123 | 124 | 125 | postlst_in = [] 126 | for l in range(word_max, word_min - 1, -1): 127 | # "suffix" wordary[n, n+l-1] 128 | # current character starts word 129 | tmp = get_slice_str(nodes, i, l, length) 130 | if tmp in unigram: 131 | feat = 'w1.' + tmp 132 | if feat in feature_to_idx: 133 | feature = feature_to_idx[feat] 134 | flist.append(feature) 135 | 136 | postlst_in.append(tmp) 137 | else: 138 | postlst_in.append("**noWord") 139 | 140 | 141 | # these are not in feature list 142 | prelst_ex = [] 143 | for l in range(word_max, word_min - 1, -1): 144 | # "prefix excluding current c" wordary[n-l, n-1] 145 | tmp = get_slice_str(nodes, i - l, l, length) 146 | if tmp in unigram: 147 | prelst_ex.append(tmp) 148 | else: 149 | prelst_ex.append("**noWord") 150 | 151 | 152 | postlst_ex = [] 153 | for l in range(word_max, word_min - 1, -1): 154 | # "suffix excluding current c" wordary[n+1, n+l] 155 | tmp = get_slice_str(nodes, i + 1, l, length) 156 | if tmp in unigram: 157 | postlst_ex.append(tmp) 158 | else: 159 | postlst_ex.append("**noWord") 160 | 161 | 162 | # this character is in the middle of a word 163 | # 2*(wordMax-wordMin+1)^2 (default: 2*(6-2+1)^2=50) 164 | 165 | for pre in prelst_ex: 166 | for post in postlst_in: 167 | feat = 'ww.l.' + pre + '*' + post 168 | if feat in feature_to_idx: 169 | feature = feature_to_idx[feat] 170 | flist.append(feature) 171 | 172 | 173 | for pre in prelst_in: 174 | for post in postlst_ex: 175 | feat = 'ww.r.' + pre + '*' + post 176 | if feat in feature_to_idx: 177 | feature = feature_to_idx[feat] 178 | flist.append(feature) 179 | 180 | 181 | return flist 182 | 183 | 184 | class FeatureExtractor: 185 | 186 | keywords = "-._,|/*:" 187 | 188 | num = set("0123456789." 
"几二三四五六七八九十千万亿兆零" "1234567890%") 189 | letter = set( 190 | "ABCDEFGHIJKLMNOPQRSTUVWXYZ" "abcdefghigklmnopqrstuvwxyz" "/・-" 191 | ) 192 | 193 | keywords_translate_table = str.maketrans("-._,|/*:", "&&&&&&&&") 194 | 195 | @classmethod 196 | def keyword_rename(cls, text): 197 | return text.translate(cls.keywords_translate_table) 198 | 199 | @classmethod 200 | def _num_letter_normalize_char(cls, character): 201 | if character in cls.num: 202 | return "**Num" 203 | if character in cls.letter: 204 | return "**Letter" 205 | return character 206 | 207 | @classmethod 208 | def normalize_text(cls, text): 209 | text = cls.keyword_rename(text) 210 | for character in text: 211 | if config.numLetterNorm: 212 | yield cls._num_letter_normalize_char(character) 213 | else: 214 | yield character 215 | 216 | 217 | def __init__(self): 218 | 219 | self.unigram = set() # type: Set[str] 220 | self.bigram = set() # type: Set[str] 221 | self.feature_to_idx = {} # type: Dict[str, int] 222 | self.tag_to_idx = {} # type: Dict[str, int] 223 | 224 | def build(self, train_file): 225 | with open(train_file, "r", encoding="utf8") as reader: 226 | lines = reader.readlines() 227 | 228 | examples = [] # type: List[List[List[str]]] 229 | 230 | # first pass to collect unigram and bigram and tag info 231 | word_length_info = Counter() 232 | specials = set() 233 | for line in lines: 234 | line = line.strip("\n\r") # .replace("\t", " ") 235 | if not line: 236 | continue 237 | 238 | line = self.keyword_rename(line) 239 | 240 | # str.split() without sep sees consecutive whiltespaces as one separator 241 | # e.g., '\ra \t b \r\n'.split() = ['a', 'b'] 242 | words = [word for word in line.split()] 243 | 244 | word_length_info.update(map(len, words)) 245 | specials.update(word for word in words if len(word)>=10) 246 | self.unigram.update(words) 247 | 248 | for pre, suf in zip(words[:-1], words[1:]): 249 | self.bigram.add("{}*{}".format(pre, suf)) 250 | 251 | example = [ 252 | self._num_letter_normalize_char(character) 253 | for word in words 254 | for character in word 255 | ] 256 | examples.append(example) 257 | 258 | max_word_length = max(word_length_info.keys()) 259 | for length in range(1, max_word_length + 1): 260 | print("length = {} : {}".format(length, word_length_info[length])) 261 | # print('special words: {}'.format(', '.join(specials))) 262 | # second pass to get features 263 | 264 | feature_freq = Counter() 265 | 266 | for example in examples: 267 | for i, _ in enumerate(example): 268 | node_features = self.get_node_features(i, example) 269 | feature_freq.update( 270 | feature for feature in node_features if feature != "/" 271 | ) 272 | 273 | feature_set = ( 274 | feature 275 | for feature, freq in feature_freq.most_common() 276 | if freq > config.featureTrim 277 | ) 278 | 279 | tot = len(self.feature_to_idx) 280 | for feature in feature_set: 281 | if not feature in self.feature_to_idx: 282 | self.feature_to_idx[feature] = tot 283 | tot += 1 284 | # self.feature_to_idx = { 285 | # feature: idx for idx, feature in enumerate(feature_set) 286 | # } 287 | 288 | if config.nLabel == 2: 289 | B = B_single = "B" 290 | I_first = I = I_end = "I" 291 | elif config.nLabel == 3: 292 | B = B_single = "B" 293 | I_first = I = "I" 294 | I_end = "I_end" 295 | elif config.nLabel == 4: 296 | B = "B" 297 | B_single = "B_single" 298 | I_first = I = "I" 299 | I_end = "I_end" 300 | elif config.nLabel == 5: 301 | B = "B" 302 | B_single = "B_single" 303 | I_first = "I_first" 304 | I = "I" 305 | I_end = "I_end" 306 | 307 | tag_set = {B, B_single, 
I_first, I, I_end} 308 | self.tag_to_idx = {tag: idx for idx, tag in enumerate(sorted(tag_set))} 309 | 310 | 311 | 312 | def get_node_features_idx(self, idx, nodes): 313 | return __get_node_features_idx(config, idx, nodes, self.feature_to_idx, self.unigram) 314 | 315 | 316 | def get_node_features(self, idx, wordary): 317 | cdef int length = len(wordary) 318 | w = wordary[idx] 319 | flist = [] 320 | 321 | # 1 start feature 322 | flist.append("$$") 323 | 324 | # 8 unigram/bgiram feature 325 | flist.append("c." + w) 326 | if idx > 0: 327 | flist.append("c-1." + wordary[idx - 1]) 328 | else: 329 | flist.append("/") 330 | if idx < len(wordary) - 1: 331 | flist.append("c1." + wordary[idx + 1]) 332 | else: 333 | flist.append("/") 334 | if idx > 1: 335 | flist.append("c-2." + wordary[idx - 2]) 336 | else: 337 | flist.append("/") 338 | if idx < len(wordary) - 2: 339 | flist.append("c2." + wordary[idx + 2]) 340 | else: 341 | flist.append("/") 342 | if idx > 0: 343 | flist.append("c-1c." + wordary[idx - 1] + config.delimInFeature + w) 344 | else: 345 | flist.append("/") 346 | if idx < len(wordary) - 1: 347 | flist.append("cc1." + w + config.delimInFeature + wordary[idx + 1]) 348 | else: 349 | flist.append("/") 350 | if idx > 1: 351 | flist.append( 352 | "c-2c-1." 353 | + wordary[idx - 2] 354 | + config.delimInFeature 355 | + wordary[idx - 1] 356 | ) 357 | else: 358 | flist.append("/") 359 | 360 | # no num/letter based features 361 | if not config.wordFeature: 362 | return flist 363 | 364 | # 2 * (wordMax-wordMin+1) word features (default: 2*(6-2+1)=10 ) 365 | # the character starts or ends a word 366 | tmplst = [] 367 | for l in range(config.wordMax, config.wordMin - 1, -1): 368 | # length 6 ... 2 (default) 369 | # "prefix including current c" wordary[n-l+1, n] 370 | # current character ends word 371 | tmp = get_slice_str(wordary, idx - l + 1, l, length) 372 | if tmp != "": 373 | if tmp in self.unigram: 374 | flist.append("w-1." + tmp) 375 | tmplst.append(tmp) 376 | else: 377 | flist.append("/") 378 | tmplst.append("**noWord") 379 | else: 380 | flist.append("/") 381 | tmplst.append("**noWord") 382 | prelst_in = tmplst 383 | 384 | tmplst = [] 385 | for l in range(config.wordMax, config.wordMin - 1, -1): 386 | # "suffix" wordary[n, n+l-1] 387 | # current character starts word 388 | tmp = get_slice_str(wordary, idx, l, length) 389 | if tmp != "": 390 | if tmp in self.unigram: 391 | flist.append("w1." 
+ tmp) 392 | tmplst.append(tmp) 393 | else: 394 | flist.append("/") 395 | tmplst.append("**noWord") 396 | else: 397 | flist.append("/") 398 | tmplst.append("**noWord") 399 | postlst_in = tmplst 400 | 401 | # these are not in feature list 402 | tmplst = [] 403 | for l in range(config.wordMax, config.wordMin - 1, -1): 404 | # "prefix excluding current c" wordary[n-l, n-1] 405 | tmp = get_slice_str(wordary, idx - l, l, length) 406 | if tmp != "": 407 | if tmp in self.unigram: 408 | tmplst.append(tmp) 409 | else: 410 | tmplst.append("**noWord") 411 | else: 412 | tmplst.append("**noWord") 413 | prelst_ex = tmplst 414 | 415 | tmplst = [] 416 | for l in range(config.wordMax, config.wordMin - 1, -1): 417 | # "suffix excluding current c" wordary[n+1, n+l] 418 | tmp = get_slice_str(wordary, idx + 1, l, length) 419 | if tmp != "": 420 | if tmp in self.unigram: 421 | tmplst.append(tmp) 422 | else: 423 | tmplst.append("**noWord") 424 | else: 425 | tmplst.append("**noWord") 426 | postlst_ex = tmplst 427 | 428 | # this character is in the middle of a word 429 | # 2*(wordMax-wordMin+1)^2 (default: 2*(6-2+1)^2=50) 430 | 431 | for pre in prelst_ex: 432 | for post in postlst_in: 433 | bigram = pre + "*" + post 434 | if bigram in self.bigram: 435 | flist.append("ww.l." + bigram) 436 | else: 437 | flist.append("/") 438 | 439 | for pre in prelst_in: 440 | for post in postlst_ex: 441 | bigram = pre + "*" + post 442 | if bigram in self.bigram: 443 | flist.append("ww.r." + bigram) 444 | else: 445 | flist.append("/") 446 | 447 | return flist 448 | 449 | def convert_feature_file_to_idx_file( 450 | self, feature_file, feature_idx_file, tag_idx_file 451 | ): 452 | 453 | with open(feature_file, "r", encoding="utf8") as reader: 454 | lines = reader.readlines() 455 | 456 | with open(feature_idx_file, "w", encoding="utf8") as f_writer, open( 457 | tag_idx_file, "w", encoding="utf8" 458 | ) as t_writer: 459 | 460 | f_writer.write("{}\n\n".format(len(self.feature_to_idx))) 461 | t_writer.write("{}\n\n".format(len(self.tag_to_idx))) 462 | 463 | tags_idx = [] # type: List[str] 464 | features_idx = [] # type: List[List[str]] 465 | for line in lines: 466 | line = line.strip() 467 | if not line: 468 | # sentence finish 469 | for feature_idx in features_idx: 470 | if not feature_idx: 471 | f_writer.write("0\n") 472 | else: 473 | f_writer.write(",".join(map(str, feature_idx))) 474 | f_writer.write("\n") 475 | f_writer.write("\n") 476 | 477 | t_writer.write(",".join(map(str, tags_idx))) 478 | t_writer.write("\n\n") 479 | 480 | tags_idx = [] 481 | features_idx = [] 482 | continue 483 | 484 | splits = line.split(" ") 485 | feature_idx = [ 486 | self.feature_to_idx[feat] 487 | for feat in splits[:-1] 488 | if feat in self.feature_to_idx 489 | ] 490 | features_idx.append(feature_idx) 491 | tags_idx.append(self.tag_to_idx[splits[-1]]) 492 | 493 | def convert_text_file_to_feature_file( 494 | self, text_file, conll_file=None, feature_file=None 495 | ): 496 | 497 | if conll_file is None: 498 | conll_file = "{}.conll{}".format(*os.path.split(text_file)) 499 | if feature_file is None: 500 | feature_file = "{}.feat{}".format(*os.path.split(text_file)) 501 | 502 | if config.nLabel == 2: 503 | B = B_single = "B" 504 | I_first = I = I_end = "I" 505 | elif config.nLabel == 3: 506 | B = B_single = "B" 507 | I_first = I = "I" 508 | I_end = "I_end" 509 | elif config.nLabel == 4: 510 | B = "B" 511 | B_single = "B_single" 512 | I_first = I = "I" 513 | I_end = "I_end" 514 | elif config.nLabel == 5: 515 | B = "B" 516 | B_single = "B_single" 517 | 
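            # 5-label scheme: B = word-initial character, B_single = one-character word,
            # I_first = second character, I = interior character, I_end = final character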
I_first = "I_first" 518 | I = "I" 519 | I_end = "I_end" 520 | 521 | conll_line_format = "{} {}\n" 522 | 523 | with open(text_file, "r", encoding="utf8") as reader, open( 524 | conll_file, "w", encoding="utf8" 525 | ) as c_writer, open(feature_file, "w", encoding="utf8") as f_writer: 526 | for line in reader: 527 | line = line.strip() 528 | if not line: 529 | continue 530 | words = self.keyword_rename(line).split() 531 | example = [] 532 | tags = [] 533 | for word in words: 534 | word_length = len(word) 535 | for idx, character in enumerate(word): 536 | if word_length == 1: 537 | tag = B_single 538 | elif idx == 0: 539 | tag = B 540 | elif idx == word_length - 1: 541 | tag = I_end 542 | elif idx == 1: 543 | tag = I_first 544 | else: 545 | tag = I 546 | c_writer.write(conll_line_format.format(character, tag)) 547 | 548 | if config.numLetterNorm: 549 | example.append( 550 | self._num_letter_normalize_char(character) 551 | ) 552 | else: 553 | example.append(character) 554 | tags.append(tag) 555 | c_writer.write("\n") 556 | 557 | for idx, tag in enumerate(tags): 558 | features = self.get_node_features(idx, example) 559 | features = [ 560 | (feature if feature in self.feature_to_idx else "/") 561 | for feature in features 562 | ] 563 | features.append(tag) 564 | f_writer.write(" ".join(features)) 565 | f_writer.write("\n") 566 | f_writer.write("\n") 567 | 568 | def save(self, model_dir=None): 569 | if model_dir is None: 570 | model_dir = config.modelDir 571 | data = {} 572 | data["unigram"] = sorted(list(self.unigram)) 573 | data["bigram"] = sorted(list(self.bigram)) 574 | data["feature_to_idx"] = self.feature_to_idx 575 | data["tag_to_idx"] = self.tag_to_idx 576 | 577 | with open(os.path.join(model_dir, 'features.pkl'), 'wb') as writer: 578 | pickle.dump(data, writer, protocol=pickle.HIGHEST_PROTOCOL) 579 | 580 | 581 | # with open( 582 | # os.path.join(config.modelDir, "features.json"), "w", encoding="utf8" 583 | # ) as writer: 584 | # json.dump(data, writer, ensure_ascii=False) 585 | 586 | @classmethod 587 | def load(cls, model_dir=None): 588 | if model_dir is None: 589 | model_dir = config.modelDir 590 | extractor = cls.__new__(cls) 591 | 592 | feature_path = os.path.join(model_dir, "features.pkl") 593 | if os.path.exists(feature_path): 594 | with open(feature_path, "rb") as reader: 595 | data = pickle.load(reader) 596 | extractor.unigram = set(data["unigram"]) 597 | extractor.bigram = set(data["bigram"]) 598 | extractor.feature_to_idx = data["feature_to_idx"] 599 | extractor.tag_to_idx = data["tag_to_idx"] 600 | 601 | return extractor 602 | 603 | 604 | print( 605 | "WARNING: features.pkl does not exist, try loading features.json", 606 | file=sys.stderr, 607 | ) 608 | 609 | 610 | feature_path = os.path.join(model_dir, "features.json") 611 | if os.path.exists(feature_path): 612 | with open(feature_path, "r", encoding="utf8") as reader: 613 | data = json.load(reader) 614 | extractor.unigram = set(data["unigram"]) 615 | extractor.bigram = set(data["bigram"]) 616 | extractor.feature_to_idx = data["feature_to_idx"] 617 | extractor.tag_to_idx = data["tag_to_idx"] 618 | extractor.save(model_dir) 619 | return extractor 620 | print( 621 | "WARNING: features.json does not exist, try loading using old format", 622 | file=sys.stderr, 623 | ) 624 | 625 | with open( 626 | os.path.join(model_dir, "unigram_word.txt"), 627 | "r", 628 | encoding="utf8", 629 | ) as reader: 630 | extractor.unigram = set([line.strip() for line in reader]) 631 | 632 | with open( 633 | os.path.join(model_dir, "bigram_word.txt"), 
634 | "r", 635 | encoding="utf8", 636 | ) as reader: 637 | extractor.bigram = set(line.strip() for line in reader) 638 | 639 | extractor.feature_to_idx = {} 640 | feature_base_name = os.path.join(model_dir, "featureIndex.txt") 641 | for i in range(10): 642 | with open( 643 | "{}_{}".format(feature_base_name, i), "r", encoding="utf8" 644 | ) as reader: 645 | for line in reader: 646 | feature, index = line.split(" ") 647 | feature = ".".join(feature.split(".")[1:]) 648 | extractor.feature_to_idx[feature] = int(index) 649 | 650 | extractor.tag_to_idx = {} 651 | with open( 652 | os.path.join(model_dir, "tagIndex.txt"), "r", encoding="utf8" 653 | ) as reader: 654 | for line in reader: 655 | tag, index = line.split(" ") 656 | extractor.tag_to_idx[tag] = int(index) 657 | 658 | print( 659 | "INFO: features.json is saved", 660 | file=sys.stderr, 661 | ) 662 | extractor.save(model_dir) 663 | 664 | return extractor 665 | --------------------------------------------------------------------------------
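The `test()` interface above dispatches to `_test_single_proc` or `_test_multi_proc` depending on `nthread`. Below is a minimal usage sketch following that signature; the file names are placeholders, and the `__main__` guard is recommended because `nthread > 1` starts worker processes (under the `spawn` start method each worker loads the model itself via `_proc_alt`).

```
import pkuseg

if __name__ == "__main__":
    # Segment input.txt line by line into output.txt with 4 worker processes;
    # postag=True would instead write word/tag pairs joined by "/".
    pkuseg.test(
        "input.txt",            # placeholder: one sentence per line
        "output.txt",           # placeholder: segmented output
        model_name="default",
        user_dict="default",
        nthread=4,
        postag=False,
        verbose=True,           # also print per-stage timings
    )
```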