├── .gitignore
├── Config
│   ├── NER
│   │   ├── README.md
│   │   └── ner_data.conf
│   └── SIM
│       └── README.md
├── Data
│   ├── DB_Data
│   │   ├── README.md
│   │   └── clean_triple.csv
│   ├── NER_Data
│   │   ├── README.md
│   │   ├── dev.txt
│   │   ├── q_t_a_df_testing.csv
│   │   ├── q_t_a_df_training.csv
│   │   ├── testing.txt
│   │   └── train.txt
│   ├── NLPCC2016KBQA
│   │   ├── nlpcc-iccpol-2016.kbqa.kb
│   │   ├── nlpcc-iccpol-2016.kbqa.testing-data
│   │   └── nlpcc-iccpol-2016.kbqa.training-data
│   ├── Sim_Data
│   │   ├── README.md
│   │   ├── dev.txt
│   │   ├── train.txt
│   │   └── val.txt
│   ├── construct_dataset.py
│   ├── construct_dataset_attribute.py
│   ├── data_process.py
│   ├── load_dbdata.py
│   └── triple_clean.py
├── LICENSE
├── README.md
├── adcf.py
├── args.py
├── bert
│   ├── CONTRIBUTING.md
│   ├── LICENSE
│   ├── README.md
│   ├── __init__.py
│   ├── create_pretraining_data.py
│   ├── extract_features.py
│   ├── modeling.py
│   ├── modeling_test.py
│   ├── multilingual.md
│   ├── optimization.py
│   ├── optimization_test.py
│   ├── requirements.txt
│   ├── run_classifier.py
│   ├── run_pretraining.py
│   ├── run_squad.py
│   ├── sample_text.txt
│   ├── tokenization.py
│   └── tokenization_test.py
├── conlleval.pl
├── conlleval.py
├── fujc.py
├── global_config.py
├── his.py
├── image
│   ├── KB.png
│   └── NER.jpg
├── kbqa_test.py
├── kl.py
├── lstm_crf_layer.py
├── neo4j_qa.py
├── prt.py
├── qa_my.py
├── qa_my.sh
├── recommend_articles.log.2019-08-06
├── recommend_articles.log.2019-08-21
├── recommend_articles.log.2019-08-23
├── run_ner.py
├── run_ner.sh
├── run_similarity.py
├── subpcs.py
├── terminal_ner.sh
├── terminal_predict.py
└── tf_metrics.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | *.egg-info/
24 | .installed.cfg
25 | *.egg
26 | MANIFEST
27 |
28 | # PyInstaller
29 | # Usually these files are written by a python script from a template
30 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | .pytest_cache/ 49 | 50 | # Translations 51 | *.mo 52 | *.pot 53 | 54 | # Django stuff: 55 | *.log 56 | local_settings.py 57 | db.sqlite3 58 | 59 | # Flask stuff: 60 | instance/ 61 | .webassets-cache 62 | 63 | # Scrapy stuff: 64 | .scrapy 65 | 66 | # Sphinx documentation 67 | docs/_build/ 68 | 69 | # PyBuilder 70 | target/ 71 | 72 | # Jupyter Notebook 73 | .ipynb_checkpoints 74 | 75 | # pyenv 76 | .python-version 77 | 78 | # celery beat schedule file 79 | celerybeat-schedule 80 | 81 | # SageMath parsed files 82 | *.sage.py 83 | 84 | # Environments 85 | .env 86 | .venv 87 | env/ 88 | venv/ 89 | ENV/ 90 | env.bak/ 91 | venv.bak/ 92 | 93 | # Spyder project settings 94 | .spyderproject 95 | .spyproject 96 | 97 | # Rope project settings 98 | .ropeproject 99 | 100 | # mkdocs documentation 101 | /site 102 | 103 | # mypy 104 | .mypy_cache/ 105 | -------------------------------------------------------------------------------- /Config/NER/README.md: -------------------------------------------------------------------------------- 1 | # KBQA-BERT 2 | 基于知识图谱的问答系统,BERT做命名实体识别和句子相似度,分为online和outline模式 3 | -------------------------------------------------------------------------------- /Config/NER/ner_data.conf: -------------------------------------------------------------------------------- 1 | {"num_train_steps": 17046, "num_warmup_steps": 1704, "num_train_size": 13637, "eval.tf_record_path": "./Output/NER\\eval.tf_record", "num_eval_size": 9016} -------------------------------------------------------------------------------- /Config/SIM/README.md: -------------------------------------------------------------------------------- 1 | # KBQA-BERT 2 | 基于知识图谱的问答系统,BERT做命名实体识别和句子相似度,分为online和outline模式 3 | -------------------------------------------------------------------------------- /Data/DB_Data/README.md: -------------------------------------------------------------------------------- 1 | # KBQA-BERT 2 | 基于知识图谱的问答系统,BERT做命名实体识别和句子相似度,分为online和outline模式 3 | -------------------------------------------------------------------------------- /Data/NER_Data/README.md: -------------------------------------------------------------------------------- 1 | # KBQA-BERT 2 | 基于知识图谱的问答系统,BERT做命名实体识别和句子相似度,分为online和outline模式 3 | -------------------------------------------------------------------------------- /Data/NLPCC2016KBQA/nlpcc-iccpol-2016.kbqa.kb: -------------------------------------------------------------------------------- 1 | 空气干燥 ||| 别名 ||| 空气干燥 2 | 空气干燥 ||| 中文名 ||| 空气干燥 3 | 空气干燥 ||| 外文名 ||| air drying 4 | 空气干燥 ||| 形式 ||| 两个 5 | 空气干燥 ||| 作用 ||| 将空气中的水份去除 6 | 罗育德 ||| 别名 ||| 罗育德 7 | 罗育德 ||| 中文名 ||| 罗育德 8 | 罗育德 ||| 民族 ||| 汉族 9 | 罗育德 ||| 出生地 ||| 河南郑州 10 | 罗育德 ||| 出生日期 ||| 1971年6月 11 | 罗育德 ||| 职业 ||| 中共深圳市南山区委常委、区政府常务副区长 12 | 罗育德 ||| 毕业院校 ||| 深圳大学 13 | 罗育德 ||| 民 族 ||| 汉族 14 | 罗育德 ||| 职 业 ||| 中共深圳市南山区委常委、区政府常务副区长 15 | 鳞 ||| 别名 ||| 鳞 16 | 鳞 ||| 中文名 ||| 鳞 17 | 鳞 ||| 拼音 ||| lín 18 | 鳞 ||| 英文名 ||| Squama 19 | 鳞 ||| 汉字笔划 ||| 20 20 | 鳞 ||| 名称 ||| 鳞 21 | 鳞 ||| 繁体 ||| 鱗 22 | 鳞 ||| 笔画 ||| 20 23 | 鳞 ||| 部首 ||| 鱼 24 | 鳞 ||| 释义 ||| 鳞 (形声。从鱼,粦声。本义:鱼甲) 同本义 鳞,鱼甲也。――《说文》 鳞罗布烈。――扬雄《羽猎赋》 25 | 浙江村 ||| 别名 ||| 浙江村 26 | 浙江村 ||| 中文名 ||| 浙江村 27 | 浙江村 ||| 主要构成 ||| 温州人 28 | 浙江村 ||| 位置 ||| 北京丰台区大红门、木樨园地区 29 | 浙江村 ||| 辐射范围 ||| 周围5公里区域 30 | 浙江村 ||| 中心 ||| 
木樨园桥环岛 31 | 浙江村 ||| 从事 ||| 服装生产批发、五金电器等 32 | 于明诠 ||| 别名 ||| 于明诠 33 | 于明诠 ||| 中文名 ||| 于明诠 34 | 于明诠 ||| 国籍 ||| 中国 35 | 于明诠 ||| 出生地 ||| 山东乐陵 36 | 于明诠 ||| 出生日期 ||| 1963年生 37 | 于明诠 ||| 职业 ||| 书法家 38 | 于明诠 ||| 主要成就 ||| 获"全国百名优秀青年文艺家\ 39 | 于明诠 ||| 代表作品 ||| 《书法的传统与传统的书法》 40 | 金刚部 ||| 别名 ||| 金刚部 41 | 金刚部 ||| 中文名 ||| 金刚部 42 | 金刚部 ||| 部主 ||| 阿閦如来 43 | 金刚部 ||| 含义 ||| 能摧破诸怨敌 44 | 金刚部 ||| 部主金刚界 ||| 阿閦佛 45 | 钟飙 ||| 别名 ||| 钟飙 46 | 钟飙 ||| 中文名 ||| 钟飙 47 | 钟飙 ||| 出生地 ||| 中国重庆 48 | 钟飙 ||| 出生日期 ||| 1968年 49 | 钟飙 ||| 主要成就 ||| 以独特的构图方式和切入当代艺术的视角著称 50 | 钟飙 ||| 性别 ||| 男 51 | 钟飙 ||| 国籍 ||| 中国 52 | 钟飙 ||| 出生年月 ||| 1968年 53 | 钟飙 ||| 职业 ||| 艺术家 54 | 钟飙 ||| 毕业院校 ||| 中国美术学院 55 | 葵涌中学 ||| 别名 ||| 葵涌中学 56 | 葵涌中学 ||| 中文名 ||| 葵涌中学 57 | 葵涌中学 ||| 创建时间 ||| 1969年 58 | 葵涌中学 ||| 占地面积 ||| 43547.26㎡ 59 | 葵涌中学 ||| 绿化占地 ||| 15865 ㎡ 60 | 汪尧田 ||| 别名 ||| 汪尧田 61 | 汪尧田 ||| 中文名 ||| 汪尧田 62 | 汪尧田 ||| 国籍 ||| 中国 63 | 汪尧田 ||| 民族 ||| 汉族 64 | 汪尧田 ||| 出生地 ||| 不详 65 | 愈裂霜 ||| 别名 ||| 愈裂霜 66 | 愈裂霜 ||| 中文名 ||| 愈裂霜 67 | 愈裂霜 ||| 药剂类型 ||| 软膏制剂 68 | 愈裂霜 ||| 主治 ||| 真菌引起的手气、脚气等 69 | 盖盖虫 ||| 别名 ||| 盖盖虫 70 | 盖盖虫 ||| 中文名 ||| 盖盖虫 71 | 盖盖虫 ||| 日文名 ||| カブルモ 72 | 盖盖虫 ||| 英文名 ||| Karrablast 73 | 盖盖虫 ||| 全国编号 ||| 588 74 | 盖盖虫 ||| 属性 ||| 虫 75 | 盖盖虫 ||| 种类 ||| 咬住神奇宝贝 76 | 盖盖虫 ||| 身高 ||| 0.5m 77 | 盖盖虫 ||| 体重 ||| 5.9kg 78 | 李冀川(成都电视台《道听途说》节目主持人) ||| 别名 ||| 李冀川 79 | 李冀川(成都电视台《道听途说》节目主持人) ||| 中文名 ||| 李冀川 80 | 李冀川(成都电视台《道听途说》节目主持人) ||| 出生日期 ||| 1981.12 81 | 李冀川(成都电视台《道听途说》节目主持人) ||| 籍贯 ||| 成都市温江区 82 | 李冀川(成都电视台《道听途说》节目主持人) ||| 学历 ||| 本科 83 | 唐本 ||| 别名 ||| 唐本 84 | 唐本 ||| 中文名 ||| 唐本 85 | 唐本 ||| 国籍 ||| 中国 86 | 唐本 ||| 民族 ||| 汉族 87 | 唐本 ||| 出生地 ||| 浙江省兰溪 88 | 诺水河镇 ||| 别名 ||| 诺水河镇 89 | 诺水河镇 ||| 中文名称 ||| 诺水河镇 90 | 诺水河镇 ||| 地理位置 ||| 东经107°27″,北纬32°16″——32°33″ 91 | 诺水河镇 ||| 面积 ||| 308.38平方公里 92 | 诺水河镇 ||| 人口 ||| 21114人 93 | 诺水河镇 ||| 中文名 ||| 诺水河镇 94 | 诺水河镇 ||| 所属国 ||| 中国 95 | 诺水河镇 ||| 所属市 ||| 巴山市 96 | 诺水河镇 ||| 诺水河镇 ||| 中华人民共和国 97 | 诺水河镇 ||| 上级行政区 ||| 通江县 98 | 诺水河镇 ||| 行政区类型 ||| 镇 99 | 诺水河镇 ||| 行政区划代码 ||| 511921113 100 | 诺水河镇 ||| 村级区划单位数 ||| 25 101 | 诺水河镇 ||| - 社区数 ||| 1 102 | 诺水河镇 ||| - 行政村数 ||| 24 103 | 诺水河镇 ||| 地理、人口、经济 ||| 32°23′19″N 107°10′39″E / 32.38857°N 107.17753°E坐标:32°23′19″N 107°10′39″E / 32.38857°N 107.17753°E 104 | 诺水河镇 ||| 电话区号 ||| +86 105 | 水冷 ||| 别名 ||| 水冷 106 | 水冷 ||| 中文名 ||| 水冷 107 | 水冷 ||| 产品 ||| 水冷散热器 108 | 水冷 ||| 分类 ||| 主动式水冷和被动式水冷 109 | 水冷 ||| 类别 ||| 散热器 110 | 水冷 ||| 材料 ||| 聚氯乙烯、聚乙烯和硅树脂三种 111 | 水冷 ||| 类似 ||| 风冷散热 112 | 水冷 ||| 分 类 ||| 主动式水冷和被动式水冷 113 | 水冷 ||| 材 料 ||| 聚氯乙烯、聚乙烯和硅树脂三种 114 | 水冷 ||| 产 品 ||| 水冷散热器 115 | 水冷 ||| 类 别 ||| 散热器 116 | 水冷 ||| 类 似 ||| 风冷散热 117 | 水冷机箱 ||| 别名 ||| 水冷机箱 118 | 水冷机箱 ||| 中文名 ||| 水冷机箱 119 | 水冷机箱 ||| 缺点 ||| 普遍体积过大,操作不够简单 120 | 水冷机箱 ||| 类型 ||| 电脑内部发热部件散热的一种装置 121 | 水冷机箱 ||| 功能 ||| 它包括水冷散热系统和防尘机箱 122 | 水冷机箱 ||| 英文名 ||| Water-cooled chassis 123 | 美少女战士R ||| 别名 ||| 美少女战士R 124 | 美少女战士R ||| 制作 ||| テレビ朝日 125 | 美少女战士R ||| 其他名称 ||| 美少女戦士 セーラームーンR 126 | 美少女战士R ||| 香港TVB首播 ||| 1995年8月26日~1996年1月21日 127 | 美少女战士R ||| 日本播放 ||| 1993年3月6日~1994年3月5日 128 | 美少女战士R ||| 香港TVB重播 ||| 2011年6月3日 129 | 美少女战士R ||| 导演 ||| 佐藤顺一、几原邦彦 130 | 美少女战士R ||| 中文名 ||| 美少女战士R 131 | 美少女战士R ||| 美少女戦士 セーラームーンR ||| 美少女戦士 セーラームーンR 132 | 美少女战士R ||| 日语假名 ||| びしょうじょせんし セーラームーンR 133 | 美少女战士R ||| 罗马字 ||| Bishōjo Senshi Sērā Mūn Aru 134 | 美少女战士R ||| 类型 ||| 魔法少女、爱情 135 | 美少女战士R ||| 作品原名 ||| 美少女戦士セーラームーンR 136 | 美少女战士R ||| 正式译名 ||| 美少女战士R 美少女战士第二部 137 | 美少女战士R ||| 原作 ||| 武内直子 138 | 美少女战士R ||| 企划 ||| 东伊里弥、太田贤司 139 | 美少女战士R ||| 剧本统筹 ||| 富田祐弘 140 | 美少女战士R ||| 编剧 ||| 富田祐弘等10人 141 | 美少女战士R ||| 人物设定 ||| 只野和子 142 | 美少女战士R ||| 动画制作 ||| 东映动画 143 | 美少女战士R ||| 播放电视台 ||| 朝日电视台 无线电视翡翠台 中华电视公司 无线电视J2台 
144 | 美少女战士R ||| 播放期间 ||| 1993年3月6日-1994年3月5日 145 | 美少女战士R ||| 话数 ||| 全43话 146 | 美少女战士R ||| 版权信息 ||| ©东映动画 147 | 马丁·泰勒 ||| 别名 ||| 马丁·泰勒 148 | 马丁·泰勒 ||| 外文名 ||| MartinTaylor 149 | 马丁·泰勒 ||| 国籍 ||| 英格兰 150 | 马丁·泰勒 ||| 出生日期 ||| 1979年11月9日 151 | 马丁·泰勒 ||| 身高 ||| 193cm 152 | 马丁·泰勒 ||| 体重 ||| 89Kg 153 | 马丁·泰勒 ||| 运动项目 ||| 足球 154 | 马丁·泰勒 ||| 所属运动队 ||| 沃特福德足球俱乐部 155 | 马丁·泰勒 ||| 中文名 ||| 马丁。泰勒 156 | 马丁·泰勒 ||| 专业特点 ||| 后卫 157 | 马丁·泰勒 ||| 英文名 ||| Martin Taylor 158 | 马丁·泰勒 ||| 籍贯 ||| 英格兰阿兴顿 159 | 马丁·泰勒 ||| 性别 ||| 男 160 | 马丁·泰勒 ||| 出生年月 ||| 1979年11月9日 161 | 马丁·泰勒 ||| 马丁·泰勒(2008年摄于老特拉福德球场) ||| 马丁·泰勒(2008年摄于老特拉福德球场) 162 | 马丁·泰勒 ||| 出生 ||| 1945年09月14日(69岁) 切斯特, 柴郡, 英格兰 163 | 马丁·泰勒 ||| 教育程度 ||| 皇家文理学校(吉尔福德)(英语:Royal Grammar School, Guildford) 东英吉利亚大学 164 | 马丁·泰勒 ||| 职业 ||| 足球评论员 165 | 马丁·泰勒 ||| 雇主 ||| 天空电视台, ESPN, TSN(英语:TSN) 166 | 马丁·泰勒 ||| 个人资料 ||| Martin Taylor[1] 167 | 马丁·泰勒 ||| 出生地点 ||| 英格兰阿兴顿 168 | 马丁·泰勒 ||| 位置 ||| 后卫 169 | 马丁·泰勒 ||| 俱乐部资料 ||| 谢周三 170 | 马丁·泰勒 ||| 球衣号码 ||| 5 171 | 马丁·泰勒 ||| 青年队 ||| 克拉姆灵顿少年队 172 | 马丁·泰勒 ||| 职业俱乐部* ||| 出场 173 | 马丁·泰勒 ||| 1997–2004 ||| 88 174 | 马丁·泰勒 ||| 2000 ||| 4 175 | 马丁·泰勒 ||| 2004–2010 ||| 99 176 | 马丁·泰勒 ||| 2007 ||| 8 177 | 马丁·泰勒 ||| 2010–2012 ||| 90 178 | 马丁·泰勒 ||| 2012– ||| 10 179 | 马丁·泰勒 ||| 国家队 ||| 1 180 | 入射线 ||| 别名 ||| 入射线 181 | 入射线 ||| 中文名 ||| 入射线 182 | 入射线 ||| 学科 ||| 物理光学 183 | 入射线 ||| 类别 ||| 射线 184 | 入射线 ||| 相对 ||| 反射线 185 | 万家灯火(林兆华李六乙导演话剧) ||| 别名 ||| 万家灯火 186 | 万家灯火(林兆华李六乙导演话剧) ||| 中文名 ||| 万家灯火 187 | 万家灯火(林兆华李六乙导演话剧) ||| 导演 ||| 林兆华、李六乙 188 | 万家灯火(林兆华李六乙导演话剧) ||| 编剧 ||| 李云龙 189 | 万家灯火(林兆华李六乙导演话剧) ||| 制作 ||| 北京人民艺术剧院 190 | 万家灯火(林兆华李六乙导演话剧) ||| 上映时间 ||| 2002年 191 | 万家灯火(林兆华李六乙导演话剧) ||| 主演 ||| 宋丹丹、濮存昕、米铁增、何冰 192 | 紫屋魔恋 ||| 别名 ||| 紫屋魔恋 193 | 紫屋魔恋 ||| 中文名 ||| 紫屋魔恋 194 | 紫屋魔恋 ||| 制片人 ||| 彼得.古伯 195 | 紫屋魔恋 ||| 主演 ||| 杰克.尼科尔森 196 | 紫屋魔恋 ||| 片长 ||| 121分钟 197 | 美丽的日子(王心凌演唱专辑) ||| 别名 ||| 美丽的日子 198 | 美丽的日子(王心凌演唱专辑) ||| 中文名 ||| 美丽的日子 199 | 美丽的日子(王心凌演唱专辑) ||| 发行时间 ||| 2009年11月13日 200 | 美丽的日子(王心凌演唱专辑) ||| 地区 ||| 台湾 201 | 美丽的日子(王心凌演唱专辑) ||| 语言 ||| 普通话 202 | 美丽的日子(王心凌演唱专辑) ||| 歌手 ||| 王心凌 203 | 美丽的日子(王心凌演唱专辑) ||| 音乐风格 ||| 流行 204 | 埃尔文·约翰逊 ||| 别名 ||| 埃尔文·约翰逊 205 | 埃尔文·约翰逊 ||| 外文名 ||| EarvinJohnson 206 | 埃尔文·约翰逊 ||| 国籍 ||| 美国 207 | 埃尔文·约翰逊 ||| 出生日期 ||| 1959年08月14日 208 | 埃尔文·约翰逊 ||| 身高 ||| 2.06米/6尺9寸 209 | 埃尔文·约翰逊 ||| 体重 ||| 98公斤/215磅 210 | 埃尔文·约翰逊 ||| 运动项目 ||| 篮球 211 | 埃尔文·约翰逊 ||| 所属运动队 ||| 湖人队 212 | 埃尔文·约翰逊 ||| 中文名 ||| 埃尔文·约翰逊 213 | 埃尔文·约翰逊 ||| 出生地 ||| 密歇根 214 | 埃尔文·约翰逊 ||| 主要奖项 ||| 五届NBA总冠军 三届NBA最有价值球员奖 三届NBA总决赛最有价值球员奖 二次NBA全明星赛最有价值球员奖 215 | 埃尔文·约翰逊 ||| 粤语地区译名 ||| 埃尔文·莊逊 216 | 埃尔文·约翰逊 ||| 英文名 ||| ErvinJohnson 217 | 埃尔文·约翰逊 ||| 性别 ||| 男 218 | 埃尔文·约翰逊 ||| 出生年月 ||| 1959年8月14日 219 | 埃尔文·约翰逊 ||| 职业 ||| 篮球 220 | 埃尔文·约翰逊 ||| 体 重 ||| 98公斤/215磅 221 | 埃尔文·约翰逊 ||| 别 名 ||| Magic Johnson/魔术师-约翰逊 222 | 埃尔文·约翰逊 ||| 国 籍 ||| 美国 223 | 河北外国语职业学院 ||| 别名 ||| 河北外国语职业学院 224 | 河北外国语职业学院 ||| 中文名 ||| 河北外国语职业学院 225 | 河北外国语职业学院 ||| 英文名 ||| 英文:Hebei Vocational College of Foreign Languages韩语:하북외국어전문대학은 226 | 河北外国语职业学院 ||| 简称 ||| 河北外国语职院 227 | 河北外国语职业学院 ||| 创办时间 ||| 1948年 228 | 河北外国语职业学院 ||| 类别 ||| 公立大学 229 | 河北外国语职业学院 ||| 学校类型 ||| 语言 230 | 河北外国语职业学院 ||| 所属地区 ||| 中国秦皇岛 231 | 河北外国语职业学院 ||| 现任校长 ||| 丁国声 232 | 河北外国语职业学院 ||| 主管部门 ||| 河北省教育厅 233 | 河北外国语职业学院 ||| 校训 ||| 德以立校 学以弘业 234 | 河北外国语职业学院 ||| 主要院系 ||| 英语系、教育系、国际商务系、酒店航空系、西语系等 235 | 河北外国语职业学院 ||| 外文名 ||| 英文:Hebei Vocational College of Foreign Languages韩语:하북외국어전문대학은 236 | 河北外国语职业学院 ||| 中文名称 ||| 河北外国语职业学院 237 | 河北外国语职业学院 ||| 校址 ||| 学院地址:河北省秦皇岛市南戴河前进路6号 238 | 河北外国语职业学院 ||| 邮编 ||| 066311 239 | 河北外国语职业学院 ||| 外文名称 ||| Hebei Vocational College of Foreign 
Languages 240 | 河北外国语职业学院 ||| 学校主页 ||| http://www.hbvcfl.com.cn/ 241 | 徐峥 ||| 别名 ||| 徐峥 242 | 徐峥 ||| 中文名 ||| 徐峥[13] 243 | 徐峥 ||| 外文名 ||| Xú Zhēng 244 | 徐峥 ||| 国籍 ||| 中国 245 | 徐峥 ||| 民族 ||| 汉族 246 | 徐峥 ||| 星座 ||| 白羊座 247 | 徐峥 ||| 血型 ||| A型 248 | 徐峥 ||| 身高 ||| 178cm 249 | 徐峥 ||| 体重 ||| 72kg 250 | 徐峥 ||| 出生地 ||| 上海市闸北区 251 | 徐峥 ||| 出生日期 ||| 1972年4月18日 252 | 徐峥 ||| 职业 ||| 演员、导演 253 | 徐峥 ||| 毕业院校 ||| 上海戏剧学院 254 | 徐峥 ||| 代表作品 ||| 人再囧途之泰囧、无人区、人在囧途、爱情呼叫转移、春光灿烂猪八戒、李卫当官 255 | 徐峥 ||| 主要成就 ||| 中国话剧金狮奖演员奖 北京大学生电影节最受欢迎导演 电影频道传媒大奖最受关注男演员 上海国际电影节最卖座电影 安徽卫视国剧盛典年度最佳角色 展开 256 | 徐峥 ||| 妻子 ||| 陶虹 257 | 徐峥 ||| 女儿 ||| 徐小宝 258 | 徐峥 ||| 籍贯 ||| 上海 259 | 徐峥 ||| 性别 ||| 男 260 | 徐峥 ||| 出生年月 ||| 1972年4月18日 261 | 徐峥 ||| 语言 ||| 普通话、吴语、英语 262 | 徐峥 ||| 祖籍 ||| 浙江省嘉兴市平湖市曹桥街道百寿村沈家门 263 | 徐峥 ||| 导演代表作 ||| 人再囧途之泰囧 264 | 徐峥 ||| 本名 ||| 徐峥 265 | 徐峥 ||| 出生 ||| 1972年4月18日(42岁)  中国上海市 266 | 徐峥 ||| 教育程度 ||| 上海戏剧学院90级表演系本科毕业 267 | 徐峥 ||| 配偶 ||| 陶虹(2003年-) 268 | 徐峥 ||| 儿女 ||| 一女 269 | 徐峥 ||| 活跃年代 ||| 1997年- 270 | 徐峥 ||| 经纪公司 ||| 影响力明星经纪公司 271 | 徐峥 ||| 奖项 ||| 详见下文 272 | 徐峥 ||| 电影 ||| 《爱情呼叫转移》 《疯狂的石头》 《疯狂的赛车》 《人在囧途》 《人再囧途之泰囧》 273 | 徐峥 ||| 电视剧 ||| 《春光灿烂猪八戒》 《李卫当官》 《大男当婚》 274 | 计算机应用基础 ||| 别名 ||| 计算机应用基础 275 | 计算机应用基础 ||| 中文名 ||| 计算机应用基础 276 | 计算机应用基础 ||| 作者 ||| 刘晓斌、魏智荣、刘庆生 277 | 计算机应用基础 ||| 类别 ||| 计算机/网络 > 计算机理论 278 | 计算机应用基础 ||| 字数 ||| 445000 279 | 计算机应用基础 ||| ISBN ||| 9787122177513 280 | 计算机应用基础 ||| 出版社 ||| 化学工业出版社 281 | 计算机应用基础 ||| 页数 ||| 255 282 | 计算机应用基础 ||| 开本 ||| 16开 283 | 计算机应用基础 ||| 出版时间 ||| 2013年 284 | 计算机应用基础 ||| 装帧 ||| 平装 285 | 计算机应用基础 ||| 原作者 ||| 刘升贵,黄敏,庄强兵 286 | 计算机应用基础 ||| 定价 ||| 29.00元 287 | 登山临水 ||| 别名 ||| 登山临水 288 | 登山临水 ||| 名称 ||| 登山临水 289 | 登山临水 ||| 拼音 ||| dēng shān lín shuǐ 290 | 登山临水 ||| 出处 ||| 战国·楚·宋玉《九辩》:“登山临水兮送将归。” 291 | 登山临水 ||| 释义 ||| 形容旅途遥远。也指游山玩水。 292 | 登山临水 ||| 用法 ||| 褒义 谓语 293 | 登山临水 ||| 结构 ||| 联合式 294 | 白驹过隙 ||| 别名 ||| 白驹过隙 295 | 白驹过隙 ||| 名称 ||| 白驹过隙 296 | 石川数正 ||| 别名 ||| 石川数正 297 | 石川数正 ||| 中文名 ||| 石川数正 298 | 石川数正 ||| 外文名 ||| いしかわ かずまさ 299 | 石川数正 ||| 国籍 ||| 日本 300 | 石川数正 ||| 民族 ||| 大和民族 301 | 石川数正 ||| 出生地 ||| 大阪府 302 | 石川数正 ||| 出生日期 ||| 1533年 303 | 石川数正 ||| 逝世日期 ||| 1592年 304 | 石川数正 ||| 职业 ||| 家臣 305 | 石川数正 ||| 信仰 ||| 神道教 306 | 石川数正 ||| 主要成就 ||| 德川二重臣之一 307 | 石川数正 ||| 信 仰 ||| 神道教 308 | 石川数正 ||| 职 业 ||| 家臣 309 | 石川数正 ||| 国 籍 ||| 日本 310 | 石川数正 ||| 别 名 ||| 石川伯耆守 311 | 石川数正 ||| 民 族 ||| 大和民族 312 | 石川数正 ||| 日语写法 ||| 石川 数正 313 | 石川数正 ||| 假名 ||| いしかわ かずまさ 314 | 石川数正 ||| 平文式罗马字 ||| Ishikawa Kazumasa 315 | 净莲妖圣 ||| 别名 ||| 净莲妖圣 316 | 净莲妖圣 ||| 中文名 ||| 净莲妖圣 317 | 净莲妖圣 ||| 职业 ||| 斗圣巅峰 318 | 净莲妖圣 ||| 代表作品 ||| 斗破苍穹 319 | 净莲妖圣 ||| 异火 ||| 净莲妖火 320 | 对句作起法 ||| 别名 ||| 对句作起法 321 | 对句作起法 ||| 中文名 ||| 对句作起法 322 | 对句作起法 ||| 类属 ||| 诗歌技巧之一 323 | 对句作起法 ||| 人物 ||| 刘铁冷 324 | 对句作起法 ||| 著作 ||| 《作诗百法》 325 | 花右京女仆队 ||| 别名 ||| 花右京女仆队 326 | 花右京女仆队 ||| 中文名 ||| 花右京女仆队 327 | 花右京女仆队 ||| 原版名称 ||| 花右京メイド队 328 | 花右京女仆队 ||| 其他名称 ||| 花右京女佣队 329 | 花右京女仆队 ||| 作者 ||| 森繁(もりしげ) 330 | 花右京女仆队 ||| 类型 ||| 恋爱 331 | 花右京女仆队 ||| 地区 ||| 日本 332 | 花右京女仆队 ||| 连载杂志 ||| 少年champion 333 | 花右京女仆队 ||| 连载期间 ||| 1999年4月-2006年9月 334 | 花右京女仆队 ||| 出版社 ||| 秋田书店 335 | 花右京女仆队 ||| 单行本册数 ||| 全14卷 336 | 花右京女仆队 ||| 其他外文名 ||| 花右京メイド队 337 | 花右京女仆队 ||| 其他译名 ||| 花右京女佣队 338 | 花右京女仆队 ||| 出品时间 ||| 1999年4月-2006年9月 339 | 花右京女仆队 ||| 动画话数 ||| 全15话(3话电视没有播放)+12话 340 | 花右京女仆队 ||| 地 区 ||| 日本 341 | 花右京女仆队 ||| 外文名 ||| 花右京メイド队 342 | 花右京女仆队 ||| 播出时间 ||| 2001年(第1季)2004年(第2季) 343 | 花右京女仆队 ||| 作 者 ||| 森繁(もりしげ) 344 | 直销 ||| 别名 ||| 直销 345 | 直销 ||| 中文名 ||| 直销 346 | 直销 ||| 外文名 ||| Direct Selling 347 | 直销 ||| 归类 ||| 销售方式 348 | 直销 ||| 途径 ||| 直接向最终消费者推销产品 349 | 直销 ||| 别称 ||| 厂家直接销售 350 | 直销 ||| 形式 ||| 生产商文化 销售商文化 
351 | 无尘衣 ||| 别名 ||| 无尘衣 352 | 无尘衣 ||| 中文名 ||| 无尘衣 353 | 无尘衣 ||| 类别 ||| 服装 354 | 无尘衣 ||| 国家 ||| 中国 355 | 无尘衣 ||| 适用 ||| 男女通用 356 | 异度传说 ||| 别名 ||| 异度传说 357 | 异度传说 ||| 中文名 ||| 异度传说 358 | 异度传说 ||| 游戏类别 ||| 角色扮演 359 | 异度传说 ||| 开发商 ||| MONOLITHSOFT 360 | 异度传说 ||| 发行时间 ||| 2002年4月3日(第一部) 361 | 异度传说 ||| 外文名 ||| XenoSaga 362 | 异度传说 ||| 游戏平台 ||| PS2 363 | 异度传说 ||| 发行商 ||| NAMCO 364 | 异度传说 ||| 主要角色 ||| 卯月紫苑、KOS-MOS、基奇、M.O.M.O.、凯欧斯、卯月仁等 365 | 秋和 ||| 别名 ||| 秋和 366 | 秋和 ||| 中文名 ||| 秋和 367 | 秋和 ||| 国籍 ||| 中国 368 | 秋和 ||| 出生日期 ||| 11月11日 369 | 秋和 ||| 职业 ||| 作家 370 | 秋和 ||| 代表作品 ||| 《尘埃眠于光年》 371 | 哪里 ||| 别名 ||| 哪里 372 | 哪里 ||| 中文名 ||| 哪里 373 | 哪里 ||| 外文名 ||| Where 374 | 哪里 ||| 拼音 ||| nǎ lǐ 375 | 哪里 ||| 注音 ||| ㄣㄚˇ ㄌㄧˇ 376 | 哪里 ||| 同义词 ||| 那边 那里 那处 那儿 何处 377 | 哪里 ||| 英文名 ||| where 378 | 机械能守恒定律 ||| 别名 ||| 机械能守恒定律 379 | 机械能守恒定律 ||| 中文名 ||| 机械能守恒定律 380 | 机械能守恒定律 ||| 外文名 ||| law of conservation of mechanical energy 381 | 机械能守恒定律 ||| 所属领域 ||| 动力学 382 | 机械能守恒定律 ||| 基本公式 ||| △E机=E(末)-E(初)=0 383 | 机械能守恒定律 ||| 条件 ||| 无外界能量损失 384 | 机械能守恒定律 ||| 贡献者 ||| 焦耳、迈尔和亥姆霍兹 385 | 机械能守恒定律 ||| 应用学科 ||| 物理学 386 | 机械能守恒定律 ||| 条 件 ||| 无外界能量损失 387 | 折扣优惠 ||| 别名 ||| 折扣优惠 388 | 折扣优惠 ||| 中文名 ||| 折扣优惠 389 | 折扣优惠 ||| 注音 ||| zhē kòu yōu huì 390 | 折扣优惠 ||| 相关词 ||| 降价、促销 391 | 折扣优惠 ||| 体现形式 ||| 代金券、促销活动 392 | 周易科学观 ||| 别名 ||| 周易科学观 393 | 周易科学观 ||| 书名 ||| 周易科学观 394 | 周易科学观 ||| 作者 ||| 徐道一 395 | 周易科学观 ||| ISBN ||| 9787502804596 396 | 周易科学观 ||| 页数 ||| 32开 397 | 周易科学观 ||| 出版社 ||| 地震出版社北京发行部 398 | 周易科学观 ||| 出版时间 ||| 1992-5-1 399 | 周易科学观 ||| 开本 ||| 32开 400 | 东陵少主 ||| 别名 ||| 东陵少主 401 | 东陵少主 ||| 中文名 ||| 东陵少主 402 | 东陵少主 ||| 配音 ||| 黄文择 403 | 东陵少主 ||| 登场作品 ||| 霹雳布袋戏 404 | 东陵少主 ||| 生日 ||| 5.21-6.22(双子座,会刊第149期) 405 | 东陵少主 ||| 性别 ||| 男 406 | 东陵少主 ||| 身高 ||| 184cm 407 | 东陵少主 ||| 身份 ||| 五方主星之东青龙 408 | 东陵少主 ||| 初登场 ||| 霹雳英雄榜之争王记第1集 409 | 东陵少主 ||| 退场 ||| 霹雳图腾第16集 410 | 东陵少主 ||| 根据地 ||| 怀拥天地七步阶 411 | 东陵少主 ||| 人物创作者 ||| 罗陵 412 | 金匮要略 ||| 别名 ||| 金匮要略 413 | 金匮要略 ||| 书名 ||| 活解金匮要略 414 | 金匮要略 ||| 作者 ||| (汉)张仲景著,(宋)林亿校正,杨鹏举,侯仙明,杨延巍注释 415 | 金匮要略 ||| 类别 ||| 图书 医学 中医学 经典古籍 416 | 金匮要略 ||| 页数 ||| 148页 417 | 金匮要略 ||| 定价 ||| 14.00元 418 | 金匮要略 ||| 出版社 ||| 学苑出版社 419 | 金匮要略 ||| 出版时间 ||| 2008年1月 420 | 金匮要略 ||| 书 名 ||| 活解金匮要略 421 | 金匮要略 ||| 类 别 ||| 图书 医学 中医学 经典古籍 422 | 金匮要略 ||| 定 价 ||| 14.00元 423 | 金匮要略 ||| 作 者 ||| (汉)张仲景著,(宋)林亿校正,杨鹏举,侯仙明,杨延巍注释 424 | 金匮要略 ||| 页 数 ||| 148页 425 | 祖国网 ||| 别名 ||| 祖国网 426 | 祖国网 ||| 中文名 ||| 祖国网 427 | 祖国网 ||| 成立时间 ||| 2009年2月 428 | 祖国网 ||| 口号 ||| 爱国者的精神家园 429 | 祖国网 ||| 因故关闭 ||| 2011年 430 | 祖国网 ||| 覆盖方面 ||| 政、经、军、商、史 431 | 祖国网 ||| 编辑总部 ||| 西柏坡 432 | 村官(大学生村官) ||| 别名 ||| 村官 433 | 村官(大学生村官) ||| 中文名 ||| 村官 434 | 村官(大学生村官) ||| 职务 ||| 村党支部书记助理、村主任助理 435 | 村官(大学生村官) ||| 面向 ||| 应届高校本科及以上学历毕业生 436 | 村官(大学生村官) ||| 开始于 ||| 1995年 437 | 村官(大学生村官) ||| 名称 ||| 村官 438 | 村官(大学生村官) ||| 拼音 ||| cunguan 439 | 村官(大学生村官) ||| 英文 ||| village official 440 | 村官(大学生村官) ||| 笔画 ||| 15 441 | 舜玉路街道 ||| 别名 ||| 舜玉路街道 442 | 舜玉路街道 ||| 中文名称 ||| 舜玉路街道 443 | 舜玉路街道 ||| 行政区类别 ||| 街道 444 | 舜玉路街道 ||| 所属地区 ||| 济南市 445 | 舜玉路街道 ||| 电话区号 ||| 0531 446 | 舜玉路街道 ||| 面积 ||| 4.6平方公里 447 | 舜玉路街道 ||| 人口 ||| 6.1万人 448 | 舜玉路街道 ||| 舜玉路街道 ||| 中华人民共和国 449 | 舜玉路街道 ||| 上级行政区 ||| 市中区 450 | 舜玉路街道 ||| 行政区类型 ||| 街道 451 | 舜玉路街道 ||| 行政区划代码 ||| 370103012 452 | 舜玉路街道 ||| 村级区划单位数 ||| 24 453 | 舜玉路街道 ||| - 社区数 ||| 24 454 | 舜玉路街道 ||| - 行政村数 ||| 0 455 | 舜玉路街道 ||| 地理、人口、经济 ||| 36°37′33″N 117°00′27″E / 36.62588°N 117.00746°E坐标:36°37′33″N 117°00′27″E / 36.62588°N 117.00746°E 456 | 谍影重重之上海 ||| 别名 ||| 谍影重重之上海 457 | 谍影重重之上海 ||| 中文名 ||| 《谍影重重之上海》 458 | 谍影重重之上海 ||| 出品公司 ||| 新力量影视 459 | 谍影重重之上海 ||| 导演 ||| 黄文利 460 | 谍影重重之上海 ||| 主演 ||| 
李光洁, 秦海璐, 廖凡 461 | 谍影重重之上海 ||| 类型 ||| 谍战 462 | 谍影重重之上海 ||| 出品时间 ||| 2007年 463 | 谍影重重之上海 ||| 制片地区 ||| 中国大陆 464 | 谍影重重之上海 ||| 编剧 ||| 霍昕 465 | 谍影重重之上海 ||| 集数 ||| 24集 466 | 谍影重重之上海 ||| 上映时间 ||| 2009年9月5日 467 | 游珍珠泉记 ||| 别名 ||| 游珍珠泉记 468 | 游珍珠泉记 ||| 作品名称 ||| 游珍珠泉记 469 | 游珍珠泉记 ||| 创作年代 ||| 清代 470 | 游珍珠泉记 ||| 文学体裁 ||| 散文作品原文 471 | 游珍珠泉记 ||| 作者 ||| 王昶 472 | 日本电视剧 ||| 别名 ||| 日本电视剧 473 | 日本电视剧 ||| 中文名 ||| 日本电视剧 474 | 日本电视剧 ||| 外文名 ||| テレビドラマ 475 | 日本电视剧 ||| 俗称 ||| 日剧、霓虹剧 476 | 日本电视剧 ||| 语言 ||| 日语 477 | 日本电视剧 ||| 奖项 ||| 日剧学院赏、银河赏 等 478 | 日本电视剧 ||| 剧长 ||| 多数10-13集 479 | 日本电视剧 ||| 片长 ||| 40分钟到60分钟不等 480 | 日本电视剧 ||| 发行语言 ||| 日语 481 | 日本电视剧 ||| 拍摄地点 ||| 日本 482 | 日本电视剧 ||| 颜色 ||| 彩色 483 | 日本电视剧 ||| 类型 ||| 冬季剧,春季剧,夏季剧,秋季剧 484 | 日本电视剧 ||| 其它译名 ||| 日剧 霓虹剧 485 | 日本电视剧 ||| 制片地区 ||| 日本 486 | 日本电视剧 ||| 时段 ||| 月九,木十,金十,土九 487 | 平安银行 ||| 别名 ||| 平安银行 488 | 平安银行 ||| 公司名称 ||| 平安银行股份有限公司 489 | 平安银行 ||| 外文名称 ||| Ping An Bank Co., Ltd. 490 | 平安银行 ||| 总部地点 ||| 中国深圳 491 | 平安银行 ||| 成立时间 ||| 1987年 492 | 平安银行 ||| 经营范围 ||| 提供一站式的综合金融产品与服务 493 | 平安银行 ||| 公司性质 ||| 上市公司 494 | 平安银行 ||| 公司口号 ||| 平安银行 真的不一样 495 | 平安银行 ||| 总资产 ||| 18270亿元人民币(2013年6月) 496 | 平安银行 ||| 员工数 ||| 32851人(2013年6月) 497 | 平安银行 ||| 战略思想 ||| 变革 创新 发展 498 | 平安银行 ||| 经营理念 ||| 对外以客户为中心 对内以人为本 499 | 平安银行 ||| 银行特色 ||| 专业化 集约化 综合金融 500 | 平安银行 ||| 董事长 ||| 孙建一 501 | 平安银行 ||| 行长 ||| 邵平 502 | 平安银行 ||| 中文名称 ||| 平安银行股份有限公司 503 | 平安银行 ||| 行 长 ||| 邵平 504 | 平安银行 ||| 公司类型 ||| 商业银行 505 | 平安银行 ||| 股票代号 ||| 深交所:000001 506 | 平安银行 ||| 成立 ||| 1995年6月22日 507 | 平安银行 ||| 代表人物 ||| 董事长:孙建一 副总经理: 行长:邵平 监事长:邱伟 508 | 平安银行 ||| 产业 ||| 银行 509 | 平安银行 ||| 产品 ||| 存款结算、理财、信用卡等 510 | 平安银行 ||| 净利润 ||| 1,636,029,604元人民币 511 | 平安银行 ||| 母公司 ||| 中国平安集团 49.57% 512 | 平安银行 ||| 网址 ||| 平安银行股份有限公司 513 | 墓地风水 ||| 别名 ||| 墓地风水 514 | 墓地风水 ||| 中文名 ||| 墓地风水 515 | 墓地风水 ||| 外文名 ||| The graveyard of feng shui 516 | 墓地风水 ||| 学术分类 ||| 堪舆地理学 517 | 墓地风水 ||| 代表作品 ||| 《葬书》 518 | 墓地风水 ||| 代表人物 ||| 晋代的郭璞 519 | 揠苗助长(米切尔胡尔维茨导演美国电影) ||| 别名 ||| 揠苗助长 520 | 揠苗助长(米切尔胡尔维茨导演美国电影) ||| 中文名 ||| 揠苗助长 521 | 揠苗助长(米切尔胡尔维茨导演美国电影) ||| 制片地区 ||| 美国 522 | 揠苗助长(米切尔胡尔维茨导演美国电影) ||| 制片人 ||| 米切尔-胡尔维茨 523 | 揠苗助长(米切尔胡尔维茨导演美国电影) ||| 类型 ||| 喜剧 524 | 揠苗助长(米切尔胡尔维茨导演美国电影) ||| 主演 ||| 杰森-贝特曼,迈克尔-布拉斯 525 | 揠苗助长(米切尔胡尔维茨导演美国电影) ||| 语言 ||| 英语 526 | 胸神经前支 ||| 别名 ||| 胸神经前支 527 | 胸神经前支 ||| 中文名 ||| 胸神经前支 528 | 胸神经前支 ||| 外文名 ||| anterior branch of thoracic nerves 529 | 胸神经前支 ||| 数量 ||| 12对 530 | 胸神经前支 ||| 分类 ||| 肋间神经,肋下神经 531 | 胸神经前支 ||| 分布特性 ||| 节段性 532 | 胸神经前支 ||| 检查部位 ||| 胸骨角、肋骨、剑突 533 | 闭环 ||| 别名 ||| 闭环 534 | 闭环 ||| 中文名 ||| 闭环 535 | 闭环 ||| 又叫 ||| 反馈控制系统 536 | 闭环 ||| 全称 ||| 闭环结构 537 | 闭环 ||| 原理 ||| 输出值尽量接近于期望值 538 | 闭环 ||| 应用 ||| 系统元件参数存在无法预计的变化 539 | 闭环 ||| 举例 ||| 调节水龙头、油路不畅流量下降 540 | HTC myTouch 4G Slide ||| 别名 ||| HTC myTouch 4G Slide 541 | HTC myTouch 4G Slide ||| 外文名 ||| HTC myTouch 4G Slide 542 | HTC myTouch 4G Slide ||| 网络模式 ||| GSM,WCDMA 543 | HTC myTouch 4G Slide ||| ROM容量 ||| 4GB 544 | HTC myTouch 4G Slide ||| RAM容量 ||| 768MB 545 | 三级甲等医院 ||| 别名 ||| 三级甲等医院 546 | 三级甲等医院 ||| 中文名 ||| 三级甲等医院 547 | 三级甲等医院 ||| 简称 ||| 三甲医院 548 | 三级甲等医院 ||| 类别 ||| 医院等级之一 549 | 三级甲等医院 ||| 三甲标准 ||| 病床数在501张以上、面向多地区 550 | 三级甲等医院 ||| 评选标准 ||| 按分等评分标准获得超过900分 551 | 三级甲等医院 ||| 级别 ||| 三甲医院为中国医院的最高级别 552 | 三级甲等医院 ||| 甲等标准 ||| 按分等评分标准获得超过900分 553 | 三级甲等医院 ||| 简 称 ||| 三甲医院 554 | 三级甲等医院 ||| 三级医院标准 ||| 病床数在501张以上、面向多地区 555 | 三级甲等医院 ||| 类 别 ||| 医院等级之一 556 | 不列颠尼亚(罗马帝国行省) ||| 别名 ||| 不列颠尼亚 557 | 不列颠尼亚(罗马帝国行省) ||| 中文名 ||| 不列颠尼亚 558 | 不列颠尼亚(罗马帝国行省) ||| 外文名 ||| Britannia 559 | 不列颠尼亚(罗马帝国行省) ||| 定义 ||| 不列颠岛古称;罗马帝国行省 560 | 白领 ||| 别名 ||| 白领 561 | 白领 ||| 中文名 ||| 白领 562 | 白领 ||| 英文名 ||| White-collar 
worker 563 | 归去来兮辞 ||| 别名 ||| 归去来兮辞 564 | 归去来兮辞 ||| 作品名称 ||| 归去来兮辞 -------------------------------------------------------------------------------- /Data/Sim_Data/README.md: -------------------------------------------------------------------------------- 1 | # KBQA-BERT 2 | 基于知识图谱的问答系统,BERT做命名实体识别和句子相似度,分为online和outline模式 3 | -------------------------------------------------------------------------------- /Data/construct_dataset.py: -------------------------------------------------------------------------------- 1 | # coding:utf-8 2 | import sys 3 | import os 4 | import pandas as pd 5 | 6 | 7 | ''' 8 | 构造NER训练集,实体序列标注,训练BERT+BiLSTM+CRF 9 | ''' 10 | # [training, testing] 11 | data_type = "testing" 12 | file = "./NLPCC2016KBQA/nlpcc-iccpol-2016.kbqa."+data_type+"-data" 13 | question_str = "")[1].strip() 38 | q_str = q_str.split(">")[1].replace(" ","").strip() 39 | if entities in q_str: 40 | q_list = list(q_str) 41 | seq_q_list.extend(q_list) 42 | seq_q_list.extend([" "]) 43 | tag_list = ["O" for i in range(len(q_list))] 44 | tag_start_index = q_str.find(entities) 45 | for i in range(tag_start_index, tag_start_index+len(entities)): 46 | if tag_start_index == i: 47 | tag_list[i] = "B-LOC" 48 | else: 49 | tag_list[i] = "I-LOC" 50 | seq_tag_list.extend(tag_list) 51 | seq_tag_list.extend([" "]) 52 | else: 53 | pass 54 | q_t_a_list.append([q_str, t_str, a_str]) 55 | 56 | print('\t'.join(seq_tag_list[0:50])) 57 | print('\t'.join(seq_q_list[0:50])) 58 | seq_result = [str(q)+" "+tag for q, tag in zip(seq_q_list, seq_tag_list)] 59 | with open("./NER_Data/"+data_type+".txt", "w", encoding='utf-8') as f: 60 | f.write("\n".join(seq_result)) 61 | 62 | df = pd.DataFrame(q_t_a_list, columns=["q_str", "t_str", "a_str"]) 63 | df.to_csv("./NER_Data/q_t_a_df_"+data_type+".csv", encoding='utf-8', index=False) -------------------------------------------------------------------------------- /Data/construct_dataset_attribute.py: -------------------------------------------------------------------------------- 1 | # coding:utf-8 2 | import sys 3 | import os 4 | import random 5 | import pandas as pd 6 | 7 | 8 | ''' 9 | 构造属性关联训练集,分类问题,训练BERT分类模型 10 | 1 11 | ''' 12 | # [training, testing] 13 | data_type = "training" 14 | file = "nlpcc-iccpol-2016.kbqa."+data_type+"-data" 15 | target = "./NER_Data/q_t_a_df_"+data_type+".csv" 16 | 17 | attribute_classify_sample = [] 18 | 19 | # count the number of attribute 20 | testing_df = pd.read_csv(target, encoding='utf-8') 21 | testing_df['attribute'] = testing_df['t_str'].apply(lambda x: x.split('|||')[1].strip()) 22 | attribute_list = testing_df['attribute'].tolist() 23 | print(len(set(attribute_list))) 24 | print(testing_df.head()) 25 | 26 | 27 | # construct sample 28 | for row in testing_df.index: 29 | question, pos_att = testing_df.loc[row][['q_str', 'attribute']] 30 | question = question.strip() 31 | pos_att = pos_att.strip() #positive attribute 32 | # random.shuffle(attribute_list) the complex is big 33 | # neg_att_list = attribute_list[0:5] 34 | neg_att_list = random.sample(attribute_list, 5) 35 | attribute_classify_sample.append([question, pos_att, '1']) 36 | neg_att_sample = [[question, neg_att, '0'] for neg_att in neg_att_list if neg_att != pos_att] 37 | attribute_classify_sample.extend(neg_att_sample) 38 | 39 | seq_result = [str(lineno) + '\t' + '\t'.join(line) for (lineno, line) in enumerate(attribute_classify_sample)] 40 | 41 | if data_type == 'testing': 42 | with open("./Sim_Data/"+data_type+".txt", "w", encoding='utf-8') as f: 43 | f.write("\n".join(seq_result)) 44 | 
else: 45 | val_seq_result = seq_result[0:12000] 46 | with open("./Sim_Data/"+"val"+".txt", "w", encoding='utf-8') as f: 47 | f.write("\n".join(val_seq_result)) 48 | 49 | training_seq_result = seq_result[12000:] 50 | with open("./Sim_Data/"+data_type+".txt", "w", encoding='utf-8') as f: 51 | f.write("\n".join(training_seq_result)) 52 | -------------------------------------------------------------------------------- /Data/data_process.py: -------------------------------------------------------------------------------- 1 | # encoding=utf-8 2 | 3 | """ 4 | 用于语料库的处理 5 | 1. 全部处理成小于max_seq_length的序列,这样可以避免解码出现不合法的数据或者在最后算结果的时候出现out of range 的错误。 6 | 7 | @Author: Macan 8 | """ 9 | 10 | 11 | import os 12 | import codecs 13 | import argparse 14 | 15 | 16 | def load_file(file_path): 17 | if not os.path.exists(file_path): 18 | return None 19 | with codecs.open(file_path, 'r', encoding='utf-8') as fd: 20 | for line in fd: 21 | yield line 22 | 23 | 24 | def _cut(sentence): 25 | new_sentence = [] 26 | sen = [] 27 | for i in sentence: 28 | if i.split(' ')[0] in ['。', '!', '?'] and len(sen) != 0: 29 | sen.append(i) 30 | new_sentence.append(sen) 31 | sen = [] 32 | continue 33 | sen.append(i) 34 | if len(new_sentence) == 1: #娄底那种一句话超过max_seq_length的且没有句号的,用,分割,再长的不考虑了。。。 35 | new_sentence = [] 36 | sen = [] 37 | for i in sentence: 38 | if i.split(' ')[0] in [','] and len(sen) != 0: 39 | sen.append(i) 40 | new_sentence.append(sen) 41 | sen = [] 42 | continue 43 | sen.append(i) 44 | return new_sentence 45 | 46 | 47 | def cut_sentence(file, max_seq_length): 48 | """ 49 | 句子截断 50 | :param file: 51 | :param max_seq_length: 52 | :return: 53 | """ 54 | context = [] 55 | sentence = [] 56 | cnt = 0 57 | for line in load_file(file): 58 | line = line.strip() 59 | if line == '' and len(sentence) != 0: 60 | # 判断这一句是否超过最大长度 61 | if len(sentence) > max_seq_length: 62 | sentence = _cut(sentence) 63 | context.extend(sentence) 64 | else: 65 | context.append(sentence) 66 | sentence = [] 67 | continue 68 | cnt += 1 69 | sentence.append(line) 70 | print('token cnt:{}'.format(cnt)) 71 | return context 72 | 73 | def write_to_file(file, context): 74 | # 首先将源文件改名为新文件名,避免覆盖 75 | os.rename(file, '{}.bak'.format(file)) 76 | with codecs.open(file, 'w', encoding='utf-8') as fd: 77 | for sen in context: 78 | for token in sen: 79 | fd.write(token + '\n') 80 | fd.write('\n') 81 | 82 | 83 | if __name__ == '__main__': 84 | parser = argparse.ArgumentParser(description='data pre process') 85 | parser.add_argument('--train_data', type=str, default='./NERdata/train.txt') 86 | parser.add_argument('--dev_data', type=str, default='./NERdata/dev.txt') 87 | parser.add_argument('--test_data', type=str, default='./NERdata/test.txt') 88 | parser.add_argument('--max_seq_length', type=int, default=126) 89 | args = parser.parse_args() 90 | 91 | print('cut train data to max sequence length:{}'.format(args.max_seq_length)) 92 | context = cut_sentence(args.train_data, args.max_seq_length) 93 | write_to_file(args.train_data, context) 94 | 95 | print('cut dev data to max sequence length:{}'.format(args.max_seq_length)) 96 | context = cut_sentence(args.dev_data, args.max_seq_length) 97 | write_to_file(args.dev_data, context) 98 | 99 | print('cut test data to max sequence length:{}'.format(args.max_seq_length)) 100 | context = cut_sentence(args.test_data, args.max_seq_length) 101 | write_to_file(args.test_data, context) -------------------------------------------------------------------------------- /Data/load_dbdata.py: 
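Data/load_dbdata.py, whose source follows, bulk-loads Data/DB_Data/clean_triple.csv (produced by Data/triple_clean.py) into a MySQL table named nlpccQA and also provides upload_data() for running raw SQL. The sketch below illustrates the kind of entity/attribute lookup the QA step can issue against that table. It is a minimal illustration only: the connection values mirror the hard-coded user/host/port/db in the file below, the password is a placeholder, and query_answer is a hypothetical helper that is not part of this repository.

# -*- coding: utf-8 -*-
# Minimal sketch (not part of the repo): query the nlpccQA table filled by load_dbdata.py.
# Connection values mirror the hard-coded ones in that file; replace the password placeholder.
import pymysql


def query_answer(entity, attribute):
    # Hypothetical helper: fetch every answer stored for an (entity, attribute) pair.
    connect = pymysql.connect(user="root", password="<your-password>", host="localhost",
                              port=3306, db="MySQL", charset="utf8")
    cursor = connect.cursor()
    try:
        # Parameterized query; entity and attribute would come from the NER and SIM steps.
        sql = "SELECT answer FROM nlpccQA WHERE entity = %s AND attribute = %s"
        cursor.execute(sql, (entity, attribute))
        rows = cursor.fetchall()
    finally:
        cursor.close()
        connect.close()
    return [row[0] for row in rows]


if __name__ == "__main__":
    # Triple taken from nlpcc-iccpol-2016.kbqa.kb above; the result may be empty if
    # triple_clean.py filtered this record out before loading.
    print(query_answer("机械能守恒定律", "外文名"))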
-------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Time : 2019/4/18 20:47 3 | # @Author : Alan 4 | # @Email : xiezhengwen2013@163.com 5 | # @File : load_dbdata.py 6 | # @Software: PyCharm 7 | 8 | 9 | import pymysql 10 | import pandas as pd 11 | from sqlalchemy import create_engine 12 | 13 | 14 | def create_db(): 15 | connect = pymysql.connect( # 连接数据库服务器 16 | user="root", 17 | password="Nic180319", 18 | host="localhost", 19 | port=3306, 20 | db="MySQL", 21 | charset="utf8" 22 | ) 23 | conn = connect.cursor() # 创建操作游标 24 | # 你需要一个游标 来实现对数据库的操作相当于一条线索 25 | 26 | # 创建表 27 | conn.execute("drop database if exists KB_QA") # 如果new_database数据库存在则删除 28 | conn.execute("create database KB_QA") # 新创建一个数据库 29 | conn.execute("use KB_QA") # 选择new_database这个数据库 30 | 31 | # sql 中的内容为创建一个名为new_table的表 32 | sql = """create table nlpccQA(entity VARCHAR(20) character set utf8 collate utf8_unicode_ci, 33 | attribute VARCHAR(20) character set utf8 collate utf8_unicode_ci, answer VARCHAR(20) character set utf8 34 | collate utf8_unicode_ci)""" # ()中的参数可以自行设置 35 | conn.execute("drop table if exists nlpccQA") # 如果表存在则删除 36 | conn.execute(sql) # 创建表 37 | 38 | # 删除 39 | # conn.execute("drop table new_table") 40 | 41 | conn.close() # 关闭游标连接 42 | connect.close() # 关闭数据库服务器连接 释放内存 43 | 44 | 45 | def loaddata(): 46 | # 初始化数据库连接,使用pymysql模块 47 | db_info = {'user': 'root', 48 | 'password': 'Nic180319', 49 | 'host': 'localhost', 50 | 'port': 3306, 51 | 'database': 'MySQL' 52 | } 53 | 54 | engine = create_engine( 55 | 'mysql+pymysql://%(user)s:%(password)s@%(host)s:%(port)d/%(database)s?charset=utf8' % db_info, encoding='utf-8') 56 | # 直接使用下一种形式也可以 57 | # engine = create_engine('mysql+pymysql://root:123456@localhost:3306/test') 58 | 59 | # 读取本地CSV文件 60 | df = pd.read_csv("./DB_Data/clean_triple.csv", sep=',', encoding='utf-8') 61 | print(df) 62 | # 将新建的DataFrame储存为MySQL中的数据表,不储存index列(index=False) 63 | # if_exists: 64 | # 1.fail:如果表存在,啥也不做 65 | # 2.replace:如果表存在,删了表,再建立一个新表,把数据插入 66 | # 3.append:如果表存在,把数据插入,如果表不存在创建一个表!! 
67 | pd.io.sql.to_sql(df, 'nlpccQA', con=engine, index=False, if_exists='append', chunksize=10000) 68 | # df.to_sql('example', con=engine, if_exists='replace')这种形式也可以 69 | print("Write to MySQL successfully!") 70 | 71 | 72 | def upload_data(sql): 73 | connect = pymysql.connect( # 连接数据库服务器 74 | user="root", 75 | password="Nic180319", 76 | host="localhost", 77 | port=3306, 78 | db="MySQL", 79 | charset="utf8" 80 | ) 81 | cursor = connect.cursor() # 创建操作游标 82 | try: 83 | # 执行SQL语句 84 | cursor.execute(sql) 85 | # 获取所有记录列表 86 | results = cursor.fetchall() 87 | except Exception as e: 88 | print("Error: unable to fecth data: %s ,%s" % (repr(e), sql)) 89 | finally: 90 | # 关闭数据库连接 91 | cursor.close() 92 | connect.close() 93 | return results 94 | 95 | 96 | if __name__ == '__main__': 97 | create_db() 98 | loaddata() -------------------------------------------------------------------------------- /Data/triple_clean.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Time : 2019/4/18 20:16 3 | # @Author : Alan 4 | # @Email : xiezhengwen2013@163.com 5 | # @File : triple_clean.py 6 | # @Software: PyCharm 7 | 8 | 9 | import pandas as pd 10 | 11 | 12 | ''' 13 | 构造NER训练集,实体序列标注,训练BERT+BiLSTM+CRF 14 | ''' 15 | 16 | question_str = "")[1].strip() 38 | q_str = q_str.split(">")[1].replace(" ","").strip() 39 | if ''.join(entities.split(' ')) in q_str: 40 | clean_triple = t_str.split(">")[1].replace('\t','').replace(" ","").strip().split("|||") 41 | triple_list.append(clean_triple) 42 | else: 43 | print(entities) 44 | print(q_str) 45 | print('------------------------') 46 | 47 | df = pd.DataFrame(triple_list, columns=["entity", "attribute", "answer"]) 48 | print(df) 49 | print(df.info()) 50 | df.to_csv("./DB_Data/clean_triple.csv", encoding='utf-8', index=False) -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 
34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 
202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # KBQA-BERT 2 | 基于知识图谱的QA系统,BERT模型 3 | 需要下载BERT预训练模型(中文)chinese_L-12_H-768_A-12 https://storage.googleapis.com/bert_models/2018_11_03/chinese_L-12_H-768_A-12.zip 解压缩后放在./ModelParams文件夹里面(注意整个chinese_L-12_H-768_A-12文件夹放进去) 4 | 5 | 另外需要在根目录建立Output文件夹存放训练的模型参数文件 6 | 分为Output/NER(命名实体识别)文件夹和Output/SIM(相似度)文件夹 7 | 8 | 1.run_ner.sh训练(命名实体识别) 9 | 10 | 2.terminal_ner.sh(命名实体识别测试) 11 | 12 | 3.args.py 13 | 14 | train = true 预训练模式 15 | 16 | test = true 相似度测试 17 | 18 | 4.run_similarity 相似度的训练或测试(根据第3步的设置决定) 19 | 20 | 5.qa_my.sh(连接了本地的neo4j知识库) 21 | 22 | 问答 23 | 24 | 参考:https://github.com/WenRichard/KBQA-BERT 25 | -------------------------------------------------------------------------------- /adcf.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | from neo4j import GraphDatabase 3 | driver = GraphDatabase.driver("bolt://localhost:7687", auth=("neo4j", "Nic180319")) 4 | class Klg(object): 5 | #向neo4j添加实体 6 | def add_en_0(self,tx,name): 7 | tx.run("MERGE (a:Node {name: $name})",name = name) 8 | def add_en(self,name): 9 | with driver.session() as session: 10 | session.write_transaction(self.add_en_0,name) 11 | #向neo4j添加关系 12 | def add_rel_0(self,tx,en1,rel,en2): 13 | tx.run("MATCH (n1:Node{name:$en1}),(n2:Node{name:$en2})" 14 | "MERGE (n1)-[r:"+rel+"]->(n2)",en1=en1,en2=en2) 15 | def add_rel(self,en1,rel,en2): 16 | with driver.session() as session: 17 | session.write_transaction(self.add_rel_0,en1,rel,en2) 18 | #添加属性 19 | def add_att_0(self,tx,name,att_dict): 20 | for key in att_dict.keys(): 21 | value = att_dict[key] 22 | tx.run("MATCH (n:Node{name:$name})" 23 | "SET n."+key+" = $value",name=name,value=value) 24 | def add_att(self,name,att_dict): 25 | with driver.session() as session: 26 | session.write_transaction(self.add_att_0,name,att_dict) 27 | #删除实体 28 | def delete_0(self,tx,name): 29 | tx.run("MATCH (n:Node{name:$name})" 30 | "DETACH DELETE n",name = name) #DETACH会无视该节点的关系 31 | def delete(self,name): 32 | with driver.session() as session: 33 | session.write_transaction(self.delete_0,name) 34 | 35 | #删除关系 36 | def delete_rel_0(self,tx,en1,rel,en2): 37 | tx.run('MATCH (n1:Node{name:$en1})-[r:'+rel+']->(n2:Node{name:$en2})' 38 | 'DELETE r',en1=en1,en2=en2) 39 | def delete_rel(self,en1,rel,en2): 40 | with driver.session() as session: 41 | session.write_transaction(self.delete_rel_0,en1,rel,en2) 42 | 43 | #根据实体和关系查找实体,例如查找蛇纹玉的产地,find(driver,'蛇纹玉','产地') 44 | def find(self,driver,en1,relation): 45 | with driver.session() as session: 46 | relationship = relation 47 | entity1 = en1 48 | n = len(entity1) 49 | cypher_statement = 'MATCH (name1:Node)-[r:' +relationship+ ']->(name2:Node)\ 50 | WHERE SUBSTRING(name1.name,0,{0}) = "{1}" RETURN [name1.name,"{2}",\ 51 | name2.name]'.format(n,entity1,relationship) 52 | result = session.run(cypher_statement).value() 53 | return result 54 | #查找一个节点所关联的所有关系 55 | def old_allrel(self,driver,name): 56 | with driver.session() as session: 57 | cypher_statement = 'MATCH (n1:Node)-[r]->(n2:Node) WHERE n1.name = "{0}" RETURN r'.format(name) 58 | result = [] 59 | for i in range(len(session.run(cypher_statement).value())): 60 | result.append(session.run(cypher_statement).value()[i].type) 61 | return result 62 | 63 | 64 | 65 | if __name__ == "__main__": 66 | a = Klg() 67 | 68 | #a.add_en('中国') 69 | #a.add_en('同济大学') 70 | 
a.add_rel('周星驰','国籍','中国') 71 | result = a.find(driver,'周星驰','代表作品') 72 | print(result) 73 | #a.delete('同济大学') 74 | #a.delete_rel('李小龙','代表作品','《猛龙过江 》') 75 | #a.add_att('黄晓明',{'职业':'演员'}) 76 | #print(a.old_allrel(driver,name='高克')) -------------------------------------------------------------------------------- /args.py: -------------------------------------------------------------------------------- 1 | import os 2 | import tensorflow as tf 3 | 4 | tf.logging.set_verbosity(tf.logging.INFO) 5 | 6 | file_path = os.path.dirname(__file__) 7 | 8 | model_dir = os.path.join(file_path, 'ModelParams/chinese_L-12_H-768_A-12/') 9 | config_name = os.path.join(model_dir, 'bert_config.json') 10 | ckpt_name = os.path.join(model_dir, 'bert_model.ckpt') 11 | output_dir = os.path.join(file_path, 'Output/SIM/result/') 12 | vocab_file = os.path.join(model_dir, 'vocab.txt') 13 | data_dir = os.path.join(file_path, 'Data/Sim_Data/') 14 | 15 | num_train_epochs = 10 16 | batch_size = 128 17 | learning_rate = 0.00005 18 | 19 | # gpu使用率 20 | gpu_memory_fraction = 0.8 21 | 22 | # 默认取倒数第二层的输出值作为句向量 23 | layer_indexes = [-2] 24 | 25 | # 序列的最大程度,单文本建议把该值调小 26 | max_seq_len = 32 27 | 28 | # 预训练模型 29 | train = False 30 | 31 | # 测试模型 32 | test = True 33 | -------------------------------------------------------------------------------- /bert/CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # How to Contribute 2 | 3 | BERT needs to maintain permanent compatibility with the pre-trained model files, 4 | so we do not plan to make any major changes to this library (other than what was 5 | promised in the README). However, we can accept small patches related to 6 | re-factoring and documentation. To submit contributes, there are just a few 7 | small guidelines you need to follow. 8 | 9 | ## Contributor License Agreement 10 | 11 | Contributions to this project must be accompanied by a Contributor License 12 | Agreement. You (or your employer) retain the copyright to your contribution; 13 | this simply gives us permission to use and redistribute your contributions as 14 | part of the project. Head over to to see 15 | your current agreements on file or to sign a new one. 16 | 17 | You generally only need to submit a CLA once, so if you've already submitted one 18 | (even if it was for a different project), you probably don't need to do it 19 | again. 20 | 21 | ## Code reviews 22 | 23 | All submissions, including submissions by project members, require review. We 24 | use GitHub pull requests for this purpose. Consult 25 | [GitHub Help](https://help.github.com/articles/about-pull-requests/) for more 26 | information on using pull requests. 27 | 28 | ## Community Guidelines 29 | 30 | This project follows 31 | [Google's Open Source Community Guidelines](https://opensource.google.com/conduct/). 32 | -------------------------------------------------------------------------------- /bert/LICENSE: -------------------------------------------------------------------------------- 1 | 2 | Apache License 3 | Version 2.0, January 2004 4 | http://www.apache.org/licenses/ 5 | 6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 7 | 8 | 1. Definitions. 9 | 10 | "License" shall mean the terms and conditions for use, reproduction, 11 | and distribution as defined by Sections 1 through 9 of this document. 12 | 13 | "Licensor" shall mean the copyright owner or entity authorized by 14 | the copyright owner that is granting the License. 
15 | 16 | "Legal Entity" shall mean the union of the acting entity and all 17 | other entities that control, are controlled by, or are under common 18 | control with that entity. For the purposes of this definition, 19 | "control" means (i) the power, direct or indirect, to cause the 20 | direction or management of such entity, whether by contract or 21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 22 | outstanding shares, or (iii) beneficial ownership of such entity. 23 | 24 | "You" (or "Your") shall mean an individual or Legal Entity 25 | exercising permissions granted by this License. 26 | 27 | "Source" form shall mean the preferred form for making modifications, 28 | including but not limited to software source code, documentation 29 | source, and configuration files. 30 | 31 | "Object" form shall mean any form resulting from mechanical 32 | transformation or translation of a Source form, including but 33 | not limited to compiled object code, generated documentation, 34 | and conversions to other media types. 35 | 36 | "Work" shall mean the work of authorship, whether in Source or 37 | Object form, made available under the License, as indicated by a 38 | copyright notice that is included in or attached to the work 39 | (an example is provided in the Appendix below). 40 | 41 | "Derivative Works" shall mean any work, whether in Source or Object 42 | form, that is based on (or derived from) the Work and for which the 43 | editorial revisions, annotations, elaborations, or other modifications 44 | represent, as a whole, an original work of authorship. For the purposes 45 | of this License, Derivative Works shall not include works that remain 46 | separable from, or merely link (or bind by name) to the interfaces of, 47 | the Work and Derivative Works thereof. 48 | 49 | "Contribution" shall mean any work of authorship, including 50 | the original version of the Work and any modifications or additions 51 | to that Work or Derivative Works thereof, that is intentionally 52 | submitted to Licensor for inclusion in the Work by the copyright owner 53 | or by an individual or Legal Entity authorized to submit on behalf of 54 | the copyright owner. For the purposes of this definition, "submitted" 55 | means any form of electronic, verbal, or written communication sent 56 | to the Licensor or its representatives, including but not limited to 57 | communication on electronic mailing lists, source code control systems, 58 | and issue tracking systems that are managed by, or on behalf of, the 59 | Licensor for the purpose of discussing and improving the Work, but 60 | excluding communication that is conspicuously marked or otherwise 61 | designated in writing by the copyright owner as "Not a Contribution." 62 | 63 | "Contributor" shall mean Licensor and any individual or Legal Entity 64 | on behalf of whom a Contribution has been received by Licensor and 65 | subsequently incorporated within the Work. 66 | 67 | 2. Grant of Copyright License. Subject to the terms and conditions of 68 | this License, each Contributor hereby grants to You a perpetual, 69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 70 | copyright license to reproduce, prepare Derivative Works of, 71 | publicly display, publicly perform, sublicense, and distribute the 72 | Work and such Derivative Works in Source or Object form. 73 | 74 | 3. Grant of Patent License. 
Subject to the terms and conditions of 75 | this License, each Contributor hereby grants to You a perpetual, 76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 77 | (except as stated in this section) patent license to make, have made, 78 | use, offer to sell, sell, import, and otherwise transfer the Work, 79 | where such license applies only to those patent claims licensable 80 | by such Contributor that are necessarily infringed by their 81 | Contribution(s) alone or by combination of their Contribution(s) 82 | with the Work to which such Contribution(s) was submitted. If You 83 | institute patent litigation against any entity (including a 84 | cross-claim or counterclaim in a lawsuit) alleging that the Work 85 | or a Contribution incorporated within the Work constitutes direct 86 | or contributory patent infringement, then any patent licenses 87 | granted to You under this License for that Work shall terminate 88 | as of the date such litigation is filed. 89 | 90 | 4. Redistribution. You may reproduce and distribute copies of the 91 | Work or Derivative Works thereof in any medium, with or without 92 | modifications, and in Source or Object form, provided that You 93 | meet the following conditions: 94 | 95 | (a) You must give any other recipients of the Work or 96 | Derivative Works a copy of this License; and 97 | 98 | (b) You must cause any modified files to carry prominent notices 99 | stating that You changed the files; and 100 | 101 | (c) You must retain, in the Source form of any Derivative Works 102 | that You distribute, all copyright, patent, trademark, and 103 | attribution notices from the Source form of the Work, 104 | excluding those notices that do not pertain to any part of 105 | the Derivative Works; and 106 | 107 | (d) If the Work includes a "NOTICE" text file as part of its 108 | distribution, then any Derivative Works that You distribute must 109 | include a readable copy of the attribution notices contained 110 | within such NOTICE file, excluding those notices that do not 111 | pertain to any part of the Derivative Works, in at least one 112 | of the following places: within a NOTICE text file distributed 113 | as part of the Derivative Works; within the Source form or 114 | documentation, if provided along with the Derivative Works; or, 115 | within a display generated by the Derivative Works, if and 116 | wherever such third-party notices normally appear. The contents 117 | of the NOTICE file are for informational purposes only and 118 | do not modify the License. You may add Your own attribution 119 | notices within Derivative Works that You distribute, alongside 120 | or as an addendum to the NOTICE text from the Work, provided 121 | that such additional attribution notices cannot be construed 122 | as modifying the License. 123 | 124 | You may add Your own copyright statement to Your modifications and 125 | may provide additional or different license terms and conditions 126 | for use, reproduction, or distribution of Your modifications, or 127 | for any such Derivative Works as a whole, provided Your use, 128 | reproduction, and distribution of the Work otherwise complies with 129 | the conditions stated in this License. 130 | 131 | 5. Submission of Contributions. Unless You explicitly state otherwise, 132 | any Contribution intentionally submitted for inclusion in the Work 133 | by You to the Licensor shall be under the terms and conditions of 134 | this License, without any additional terms or conditions. 
135 | Notwithstanding the above, nothing herein shall supersede or modify 136 | the terms of any separate license agreement you may have executed 137 | with Licensor regarding such Contributions. 138 | 139 | 6. Trademarks. This License does not grant permission to use the trade 140 | names, trademarks, service marks, or product names of the Licensor, 141 | except as required for reasonable and customary use in describing the 142 | origin of the Work and reproducing the content of the NOTICE file. 143 | 144 | 7. Disclaimer of Warranty. Unless required by applicable law or 145 | agreed to in writing, Licensor provides the Work (and each 146 | Contributor provides its Contributions) on an "AS IS" BASIS, 147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 148 | implied, including, without limitation, any warranties or conditions 149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 150 | PARTICULAR PURPOSE. You are solely responsible for determining the 151 | appropriateness of using or redistributing the Work and assume any 152 | risks associated with Your exercise of permissions under this License. 153 | 154 | 8. Limitation of Liability. In no event and under no legal theory, 155 | whether in tort (including negligence), contract, or otherwise, 156 | unless required by applicable law (such as deliberate and grossly 157 | negligent acts) or agreed to in writing, shall any Contributor be 158 | liable to You for damages, including any direct, indirect, special, 159 | incidental, or consequential damages of any character arising as a 160 | result of this License or out of the use or inability to use the 161 | Work (including but not limited to damages for loss of goodwill, 162 | work stoppage, computer failure or malfunction, or any and all 163 | other commercial damages or losses), even if such Contributor 164 | has been advised of the possibility of such damages. 165 | 166 | 9. Accepting Warranty or Additional Liability. While redistributing 167 | the Work or Derivative Works thereof, You may choose to offer, 168 | and charge a fee for, acceptance of support, warranty, indemnity, 169 | or other liability obligations and/or rights consistent with this 170 | License. However, in accepting such obligations, You may act only 171 | on Your own behalf and on Your sole responsibility, not on behalf 172 | of any other Contributor, and only if You agree to indemnify, 173 | defend, and hold each Contributor harmless for any liability 174 | incurred by, or claims asserted against, such Contributor by reason 175 | of your accepting any such warranty or additional liability. 176 | 177 | END OF TERMS AND CONDITIONS 178 | 179 | APPENDIX: How to apply the Apache License to your work. 180 | 181 | To apply the Apache License to your work, attach the following 182 | boilerplate notice, with the fields enclosed by brackets "[]" 183 | replaced with your own identifying information. (Don't include 184 | the brackets!) The text should be enclosed in the appropriate 185 | comment syntax for the file format. We also recommend that a 186 | file or class name and description of purpose be included on the 187 | same "printed page" as the copyright notice for easier 188 | identification within third-party archives. 189 | 190 | Copyright [yyyy] [name of copyright owner] 191 | 192 | Licensed under the Apache License, Version 2.0 (the "License"); 193 | you may not use this file except in compliance with the License. 
194 | You may obtain a copy of the License at 195 | 196 | http://www.apache.org/licenses/LICENSE-2.0 197 | 198 | Unless required by applicable law or agreed to in writing, software 199 | distributed under the License is distributed on an "AS IS" BASIS, 200 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 201 | See the License for the specific language governing permissions and 202 | limitations under the License. 203 | -------------------------------------------------------------------------------- /bert/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AI-YULU/KBQA-BERT/7150d7ef00f2ec0a28ffb11fd7d6b30c84a28eb9/bert/__init__.py -------------------------------------------------------------------------------- /bert/create_pretraining_data.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """Create masked LM/next sentence masked_lm TF examples for BERT.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import collections 22 | import random 23 | 24 | import tokenization 25 | import tensorflow as tf 26 | 27 | flags = tf.flags 28 | 29 | FLAGS = flags.FLAGS 30 | 31 | flags.DEFINE_string("input_file", None, 32 | "Input raw text file (or comma-separated list of files).") 33 | 34 | flags.DEFINE_string( 35 | "output_file", None, 36 | "Output TF example file (or comma-separated list of files).") 37 | 38 | flags.DEFINE_string("vocab_file", None, 39 | "The vocabulary file that the BERT model was trained on.") 40 | 41 | flags.DEFINE_bool( 42 | "do_lower_case", True, 43 | "Whether to lower case the input text. 
Should be True for uncased " 44 | "models and False for cased models.") 45 | 46 | flags.DEFINE_integer("max_seq_length", 128, "Maximum sequence length.") 47 | 48 | flags.DEFINE_integer("max_predictions_per_seq", 20, 49 | "Maximum number of masked LM predictions per sequence.") 50 | 51 | flags.DEFINE_integer("random_seed", 12345, "Random seed for data generation.") 52 | 53 | flags.DEFINE_integer( 54 | "dupe_factor", 10, 55 | "Number of times to duplicate the input data (with different masks).") 56 | 57 | flags.DEFINE_float("masked_lm_prob", 0.15, "Masked LM probability.") 58 | 59 | flags.DEFINE_float( 60 | "short_seq_prob", 0.1, 61 | "Probability of creating sequences which are shorter than the " 62 | "maximum length.") 63 | 64 | 65 | class TrainingInstance(object): 66 | """A single training instance (sentence pair).""" 67 | 68 | def __init__(self, tokens, segment_ids, masked_lm_positions, masked_lm_labels, 69 | is_random_next): 70 | self.tokens = tokens 71 | self.segment_ids = segment_ids 72 | self.is_random_next = is_random_next 73 | self.masked_lm_positions = masked_lm_positions 74 | self.masked_lm_labels = masked_lm_labels 75 | 76 | def __str__(self): 77 | s = "" 78 | s += "tokens: %s\n" % (" ".join( 79 | [tokenization.printable_text(x) for x in self.tokens])) 80 | s += "segment_ids: %s\n" % (" ".join([str(x) for x in self.segment_ids])) 81 | s += "is_random_next: %s\n" % self.is_random_next 82 | s += "masked_lm_positions: %s\n" % (" ".join( 83 | [str(x) for x in self.masked_lm_positions])) 84 | s += "masked_lm_labels: %s\n" % (" ".join( 85 | [tokenization.printable_text(x) for x in self.masked_lm_labels])) 86 | s += "\n" 87 | return s 88 | 89 | def __repr__(self): 90 | return self.__str__() 91 | 92 | 93 | def write_instance_to_example_files(instances, tokenizer, max_seq_length, 94 | max_predictions_per_seq, output_files): 95 | """Create TF example files from `TrainingInstance`s.""" 96 | writers = [] 97 | for output_file in output_files: 98 | writers.append(tf.python_io.TFRecordWriter(output_file)) 99 | 100 | writer_index = 0 101 | 102 | total_written = 0 103 | for (inst_index, instance) in enumerate(instances): 104 | input_ids = tokenizer.convert_tokens_to_ids(instance.tokens) 105 | input_mask = [1] * len(input_ids) 106 | segment_ids = list(instance.segment_ids) 107 | assert len(input_ids) <= max_seq_length 108 | 109 | while len(input_ids) < max_seq_length: 110 | input_ids.append(0) 111 | input_mask.append(0) 112 | segment_ids.append(0) 113 | 114 | assert len(input_ids) == max_seq_length 115 | assert len(input_mask) == max_seq_length 116 | assert len(segment_ids) == max_seq_length 117 | 118 | masked_lm_positions = list(instance.masked_lm_positions) 119 | masked_lm_ids = tokenizer.convert_tokens_to_ids(instance.masked_lm_labels) 120 | masked_lm_weights = [1.0] * len(masked_lm_ids) 121 | 122 | while len(masked_lm_positions) < max_predictions_per_seq: 123 | masked_lm_positions.append(0) 124 | masked_lm_ids.append(0) 125 | masked_lm_weights.append(0.0) 126 | 127 | next_sentence_label = 1 if instance.is_random_next else 0 128 | 129 | features = collections.OrderedDict() 130 | features["input_ids"] = create_int_feature(input_ids) 131 | features["input_mask"] = create_int_feature(input_mask) 132 | features["segment_ids"] = create_int_feature(segment_ids) 133 | features["masked_lm_positions"] = create_int_feature(masked_lm_positions) 134 | features["masked_lm_ids"] = create_int_feature(masked_lm_ids) 135 | features["masked_lm_weights"] = create_float_feature(masked_lm_weights) 136 | 
features["next_sentence_labels"] = create_int_feature([next_sentence_label]) 137 | 138 | tf_example = tf.train.Example(features=tf.train.Features(feature=features)) 139 | 140 | writers[writer_index].write(tf_example.SerializeToString()) 141 | writer_index = (writer_index + 1) % len(writers) 142 | 143 | total_written += 1 144 | 145 | if inst_index < 20: 146 | tf.logging.info("*** Example ***") 147 | tf.logging.info("tokens: %s" % " ".join( 148 | [tokenization.printable_text(x) for x in instance.tokens])) 149 | 150 | for feature_name in features.keys(): 151 | feature = features[feature_name] 152 | values = [] 153 | if feature.int64_list.value: 154 | values = feature.int64_list.value 155 | elif feature.float_list.value: 156 | values = feature.float_list.value 157 | tf.logging.info( 158 | "%s: %s" % (feature_name, " ".join([str(x) for x in values]))) 159 | 160 | for writer in writers: 161 | writer.close() 162 | 163 | tf.logging.info("Wrote %d total instances", total_written) 164 | 165 | 166 | def create_int_feature(values): 167 | feature = tf.train.Feature(int64_list=tf.train.Int64List(value=list(values))) 168 | return feature 169 | 170 | 171 | def create_float_feature(values): 172 | feature = tf.train.Feature(float_list=tf.train.FloatList(value=list(values))) 173 | return feature 174 | 175 | 176 | def create_training_instances(input_files, tokenizer, max_seq_length, 177 | dupe_factor, short_seq_prob, masked_lm_prob, 178 | max_predictions_per_seq, rng): 179 | """Create `TrainingInstance`s from raw text.""" 180 | all_documents = [[]] 181 | 182 | # Input file format: 183 | # (1) One sentence per line. These should ideally be actual sentences, not 184 | # entire paragraphs or arbitrary spans of text. (Because we use the 185 | # sentence boundaries for the "next sentence prediction" task). 186 | # (2) Blank lines between documents. Document boundaries are needed so 187 | # that the "next sentence prediction" task doesn't span between documents. 188 | for input_file in input_files: 189 | with tf.gfile.GFile(input_file, "r") as reader: 190 | while True: 191 | line = tokenization.convert_to_unicode(reader.readline()) 192 | if not line: 193 | break 194 | line = line.strip() 195 | 196 | # Empty lines are used as document delimiters 197 | if not line: 198 | all_documents.append([]) 199 | tokens = tokenizer.tokenize(line) 200 | if tokens: 201 | all_documents[-1].append(tokens) 202 | 203 | # Remove empty documents 204 | all_documents = [x for x in all_documents if x] 205 | rng.shuffle(all_documents) 206 | 207 | vocab_words = list(tokenizer.vocab.keys()) 208 | instances = [] 209 | for _ in range(dupe_factor): 210 | for document_index in range(len(all_documents)): 211 | instances.extend( 212 | create_instances_from_document( 213 | all_documents, document_index, max_seq_length, short_seq_prob, 214 | masked_lm_prob, max_predictions_per_seq, vocab_words, rng)) 215 | 216 | rng.shuffle(instances) 217 | return instances 218 | 219 | 220 | def create_instances_from_document( 221 | all_documents, document_index, max_seq_length, short_seq_prob, 222 | masked_lm_prob, max_predictions_per_seq, vocab_words, rng): 223 | """Creates `TrainingInstance`s for a single document.""" 224 | document = all_documents[document_index] 225 | 226 | # Account for [CLS], [SEP], [SEP] 227 | max_num_tokens = max_seq_length - 3 228 | 229 | # We *usually* want to fill up the entire sequence since we are padding 230 | # to `max_seq_length` anyways, so short sequences are generally wasted 231 | # computation. 
However, we *sometimes* 232 | # (i.e., short_seq_prob == 0.1 == 10% of the time) want to use shorter 233 | # sequences to minimize the mismatch between pre-training and fine-tuning. 234 | # The `target_seq_length` is just a rough target however, whereas 235 | # `max_seq_length` is a hard limit. 236 | target_seq_length = max_num_tokens 237 | if rng.random() < short_seq_prob: 238 | target_seq_length = rng.randint(2, max_num_tokens) 239 | 240 | # We DON'T just concatenate all of the tokens from a document into a long 241 | # sequence and choose an arbitrary split point because this would make the 242 | # next sentence prediction task too easy. Instead, we split the input into 243 | # segments "A" and "B" based on the actual "sentences" provided by the user 244 | # input. 245 | instances = [] 246 | current_chunk = [] 247 | current_length = 0 248 | i = 0 249 | while i < len(document): 250 | segment = document[i] 251 | current_chunk.append(segment) 252 | current_length += len(segment) 253 | if i == len(document) - 1 or current_length >= target_seq_length: 254 | if current_chunk: 255 | # `a_end` is how many segments from `current_chunk` go into the `A` 256 | # (first) sentence. 257 | a_end = 1 258 | if len(current_chunk) >= 2: 259 | a_end = rng.randint(1, len(current_chunk) - 1) 260 | 261 | tokens_a = [] 262 | for j in range(a_end): 263 | tokens_a.extend(current_chunk[j]) 264 | 265 | tokens_b = [] 266 | # Random next 267 | is_random_next = False 268 | if len(current_chunk) == 1 or rng.random() < 0.5: 269 | is_random_next = True 270 | target_b_length = target_seq_length - len(tokens_a) 271 | 272 | # This should rarely go for more than one iteration for large 273 | # corpora. However, just to be careful, we try to make sure that 274 | # the random document is not the same as the document 275 | # we're processing. 276 | for _ in range(10): 277 | random_document_index = rng.randint(0, len(all_documents) - 1) 278 | if random_document_index != document_index: 279 | break 280 | 281 | random_document = all_documents[random_document_index] 282 | random_start = rng.randint(0, len(random_document) - 1) 283 | for j in range(random_start, len(random_document)): 284 | tokens_b.extend(random_document[j]) 285 | if len(tokens_b) >= target_b_length: 286 | break 287 | # We didn't actually use these segments so we "put them back" so 288 | # they don't go to waste. 
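# Illustrative (hypothetical numbers): if `current_chunk` held 5 segments and
# `a_end` were 2, the 3 segments that never made it into `tokens_a` are
# recovered by stepping `i` back by 3, so they are re-read on later
# iterations instead of being dropped.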
289 | num_unused_segments = len(current_chunk) - a_end 290 | i -= num_unused_segments 291 | # Actual next 292 | else: 293 | is_random_next = False 294 | for j in range(a_end, len(current_chunk)): 295 | tokens_b.extend(current_chunk[j]) 296 | truncate_seq_pair(tokens_a, tokens_b, max_num_tokens, rng) 297 | 298 | assert len(tokens_a) >= 1 299 | assert len(tokens_b) >= 1 300 | 301 | tokens = [] 302 | segment_ids = [] 303 | tokens.append("[CLS]") 304 | segment_ids.append(0) 305 | for token in tokens_a: 306 | tokens.append(token) 307 | segment_ids.append(0) 308 | 309 | tokens.append("[SEP]") 310 | segment_ids.append(0) 311 | 312 | for token in tokens_b: 313 | tokens.append(token) 314 | segment_ids.append(1) 315 | tokens.append("[SEP]") 316 | segment_ids.append(1) 317 | 318 | (tokens, masked_lm_positions, 319 | masked_lm_labels) = create_masked_lm_predictions( 320 | tokens, masked_lm_prob, max_predictions_per_seq, vocab_words, rng) 321 | instance = TrainingInstance( 322 | tokens=tokens, 323 | segment_ids=segment_ids, 324 | is_random_next=is_random_next, 325 | masked_lm_positions=masked_lm_positions, 326 | masked_lm_labels=masked_lm_labels) 327 | instances.append(instance) 328 | current_chunk = [] 329 | current_length = 0 330 | i += 1 331 | 332 | return instances 333 | 334 | 335 | def create_masked_lm_predictions(tokens, masked_lm_prob, 336 | max_predictions_per_seq, vocab_words, rng): 337 | """Creates the predictions for the masked LM objective.""" 338 | 339 | cand_indexes = [] 340 | for (i, token) in enumerate(tokens): 341 | if token == "[CLS]" or token == "[SEP]": 342 | continue 343 | cand_indexes.append(i) 344 | 345 | rng.shuffle(cand_indexes) 346 | 347 | output_tokens = list(tokens) 348 | 349 | masked_lm = collections.namedtuple("masked_lm", ["index", "label"]) # pylint: disable=invalid-name 350 | 351 | num_to_predict = min(max_predictions_per_seq, 352 | max(1, int(round(len(tokens) * masked_lm_prob)))) 353 | 354 | masked_lms = [] 355 | covered_indexes = set() 356 | for index in cand_indexes: 357 | if len(masked_lms) >= num_to_predict: 358 | break 359 | if index in covered_indexes: 360 | continue 361 | covered_indexes.add(index) 362 | 363 | masked_token = None 364 | # 80% of the time, replace with [MASK] 365 | if rng.random() < 0.8: 366 | masked_token = "[MASK]" 367 | else: 368 | # 10% of the time, keep original 369 | if rng.random() < 0.5: 370 | masked_token = tokens[index] 371 | # 10% of the time, replace with random word 372 | else: 373 | masked_token = vocab_words[rng.randint(0, len(vocab_words) - 1)] 374 | 375 | output_tokens[index] = masked_token 376 | 377 | masked_lms.append(masked_lm(index=index, label=tokens[index])) 378 | 379 | masked_lms = sorted(masked_lms, key=lambda x: x.index) 380 | 381 | masked_lm_positions = [] 382 | masked_lm_labels = [] 383 | for p in masked_lms: 384 | masked_lm_positions.append(p.index) 385 | masked_lm_labels.append(p.label) 386 | 387 | return (output_tokens, masked_lm_positions, masked_lm_labels) 388 | 389 | 390 | def truncate_seq_pair(tokens_a, tokens_b, max_num_tokens, rng): 391 | """Truncates a pair of sequences to a maximum sequence length.""" 392 | while True: 393 | total_length = len(tokens_a) + len(tokens_b) 394 | if total_length <= max_num_tokens: 395 | break 396 | 397 | trunc_tokens = tokens_a if len(tokens_a) > len(tokens_b) else tokens_b 398 | assert len(trunc_tokens) >= 1 399 | 400 | # We want to sometimes truncate from the front and sometimes from the 401 | # back to add more randomness and avoid biases. 
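# Illustrative (hypothetical lengths): with max_num_tokens == 8, a 6-token
# tokens_a and a 5-token tokens_b lose three tokens in total, each removed
# from whichever sequence is currently longer (ties go to tokens_b) and taken
# from a randomly chosen end.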
402 | if rng.random() < 0.5: 403 | del trunc_tokens[0] 404 | else: 405 | trunc_tokens.pop() 406 | 407 | 408 | def main(_): 409 | tf.logging.set_verbosity(tf.logging.INFO) 410 | 411 | tokenizer = tokenization.FullTokenizer( 412 | vocab_file=FLAGS.vocab_file, do_lower_case=FLAGS.do_lower_case) 413 | 414 | input_files = [] 415 | for input_pattern in FLAGS.input_file.split(","): 416 | input_files.extend(tf.gfile.Glob(input_pattern)) 417 | 418 | tf.logging.info("*** Reading from input files ***") 419 | for input_file in input_files: 420 | tf.logging.info(" %s", input_file) 421 | 422 | rng = random.Random(FLAGS.random_seed) 423 | instances = create_training_instances( 424 | input_files, tokenizer, FLAGS.max_seq_length, FLAGS.dupe_factor, 425 | FLAGS.short_seq_prob, FLAGS.masked_lm_prob, FLAGS.max_predictions_per_seq, 426 | rng) 427 | 428 | output_files = FLAGS.output_file.split(",") 429 | tf.logging.info("*** Writing to output files ***") 430 | for output_file in output_files: 431 | tf.logging.info(" %s", output_file) 432 | 433 | write_instance_to_example_files(instances, tokenizer, FLAGS.max_seq_length, 434 | FLAGS.max_predictions_per_seq, output_files) 435 | 436 | 437 | if __name__ == "__main__": 438 | flags.mark_flag_as_required("input_file") 439 | flags.mark_flag_as_required("output_file") 440 | flags.mark_flag_as_required("vocab_file") 441 | tf.app.run() 442 | -------------------------------------------------------------------------------- /bert/extract_features.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """Extract pre-computed feature vectors from BERT.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import codecs 22 | import collections 23 | import json 24 | import re 25 | 26 | import modeling 27 | import tokenization 28 | import tensorflow as tf 29 | 30 | flags = tf.flags 31 | 32 | FLAGS = flags.FLAGS 33 | 34 | flags.DEFINE_string("input_file", None, "") 35 | 36 | flags.DEFINE_string("output_file", None, "") 37 | 38 | flags.DEFINE_string("layers", "-1,-2,-3,-4", "") 39 | 40 | flags.DEFINE_string( 41 | "bert_config_file", None, 42 | "The config json file corresponding to the pre-trained BERT model. " 43 | "This specifies the model architecture.") 44 | 45 | flags.DEFINE_integer( 46 | "max_seq_length", 128, 47 | "The maximum total input sequence length after WordPiece tokenization. 
" 48 | "Sequences longer than this will be truncated, and sequences shorter " 49 | "than this will be padded.") 50 | 51 | flags.DEFINE_string( 52 | "init_checkpoint", None, 53 | "Initial checkpoint (usually from a pre-trained BERT model).") 54 | 55 | flags.DEFINE_string("vocab_file", None, 56 | "The vocabulary file that the BERT model was trained on.") 57 | 58 | flags.DEFINE_bool( 59 | "do_lower_case", True, 60 | "Whether to lower case the input text. Should be True for uncased " 61 | "models and False for cased models.") 62 | 63 | flags.DEFINE_integer("batch_size", 32, "Batch size for predictions.") 64 | 65 | flags.DEFINE_bool("use_tpu", False, "Whether to use TPU or GPU/CPU.") 66 | 67 | flags.DEFINE_string("master", None, 68 | "If using a TPU, the address of the master.") 69 | 70 | flags.DEFINE_integer( 71 | "num_tpu_cores", 8, 72 | "Only used if `use_tpu` is True. Total number of TPU cores to use.") 73 | 74 | flags.DEFINE_bool( 75 | "use_one_hot_embeddings", False, 76 | "If True, tf.one_hot will be used for embedding lookups, otherwise " 77 | "tf.nn.embedding_lookup will be used. On TPUs, this should be True " 78 | "since it is much faster.") 79 | 80 | 81 | class InputExample(object): 82 | 83 | def __init__(self, unique_id, text_a, text_b): 84 | self.unique_id = unique_id 85 | self.text_a = text_a 86 | self.text_b = text_b 87 | 88 | 89 | class InputFeatures(object): 90 | """A single set of features of data.""" 91 | 92 | def __init__(self, unique_id, tokens, input_ids, input_mask, input_type_ids): 93 | self.unique_id = unique_id 94 | self.tokens = tokens 95 | self.input_ids = input_ids 96 | self.input_mask = input_mask 97 | self.input_type_ids = input_type_ids 98 | 99 | 100 | def input_fn_builder(features, seq_length): 101 | """Creates an `input_fn` closure to be passed to TPUEstimator.""" 102 | 103 | all_unique_ids = [] 104 | all_input_ids = [] 105 | all_input_mask = [] 106 | all_input_type_ids = [] 107 | 108 | for feature in features: 109 | all_unique_ids.append(feature.unique_id) 110 | all_input_ids.append(feature.input_ids) 111 | all_input_mask.append(feature.input_mask) 112 | all_input_type_ids.append(feature.input_type_ids) 113 | 114 | def input_fn(params): 115 | """The actual input function.""" 116 | batch_size = params["batch_size"] 117 | 118 | num_examples = len(features) 119 | 120 | # This is for demo purposes and does NOT scale to large data sets. We do 121 | # not use Dataset.from_generator() because that uses tf.py_func which is 122 | # not TPU compatible. The right way to load data is with TFRecordReader. 
123 | d = tf.data.Dataset.from_tensor_slices({ 124 | "unique_ids": 125 | tf.constant(all_unique_ids, shape=[num_examples], dtype=tf.int32), 126 | "input_ids": 127 | tf.constant( 128 | all_input_ids, shape=[num_examples, seq_length], 129 | dtype=tf.int32), 130 | "input_mask": 131 | tf.constant( 132 | all_input_mask, 133 | shape=[num_examples, seq_length], 134 | dtype=tf.int32), 135 | "input_type_ids": 136 | tf.constant( 137 | all_input_type_ids, 138 | shape=[num_examples, seq_length], 139 | dtype=tf.int32), 140 | }) 141 | 142 | d = d.batch(batch_size=batch_size, drop_remainder=False) 143 | return d 144 | 145 | return input_fn 146 | 147 | 148 | def model_fn_builder(bert_config, init_checkpoint, layer_indexes, use_tpu, 149 | use_one_hot_embeddings): 150 | """Returns `model_fn` closure for TPUEstimator.""" 151 | 152 | def model_fn(features, labels, mode, params): # pylint: disable=unused-argument 153 | """The `model_fn` for TPUEstimator.""" 154 | 155 | unique_ids = features["unique_ids"] 156 | input_ids = features["input_ids"] 157 | input_mask = features["input_mask"] 158 | input_type_ids = features["input_type_ids"] 159 | 160 | model = modeling.BertModel( 161 | config=bert_config, 162 | is_training=False, 163 | input_ids=input_ids, 164 | input_mask=input_mask, 165 | token_type_ids=input_type_ids, 166 | use_one_hot_embeddings=use_one_hot_embeddings) 167 | 168 | if mode != tf.estimator.ModeKeys.PREDICT: 169 | raise ValueError("Only PREDICT modes are supported: %s" % (mode)) 170 | 171 | tvars = tf.trainable_variables() 172 | scaffold_fn = None 173 | (assignment_map, 174 | initialized_variable_names) = modeling.get_assignment_map_from_checkpoint( 175 | tvars, init_checkpoint) 176 | if use_tpu: 177 | 178 | def tpu_scaffold(): 179 | tf.train.init_from_checkpoint(init_checkpoint, assignment_map) 180 | return tf.train.Scaffold() 181 | 182 | scaffold_fn = tpu_scaffold 183 | else: 184 | tf.train.init_from_checkpoint(init_checkpoint, assignment_map) 185 | 186 | tf.logging.info("**** Trainable Variables ****") 187 | for var in tvars: 188 | init_string = "" 189 | if var.name in initialized_variable_names: 190 | init_string = ", *INIT_FROM_CKPT*" 191 | tf.logging.info(" name = %s, shape = %s%s", var.name, var.shape, 192 | init_string) 193 | 194 | all_layers = model.get_all_encoder_layers() 195 | 196 | predictions = { 197 | "unique_id": unique_ids, 198 | } 199 | 200 | for (i, layer_index) in enumerate(layer_indexes): 201 | predictions["layer_output_%d" % i] = all_layers[layer_index] 202 | 203 | output_spec = tf.contrib.tpu.TPUEstimatorSpec( 204 | mode=mode, predictions=predictions, scaffold_fn=scaffold_fn) 205 | return output_spec 206 | 207 | return model_fn 208 | 209 | 210 | def convert_examples_to_features(examples, seq_length, tokenizer): 211 | """Loads a data file into a list of `InputBatch`s.""" 212 | 213 | features = [] 214 | for (ex_index, example) in enumerate(examples): 215 | tokens_a = tokenizer.tokenize(example.text_a) 216 | 217 | tokens_b = None 218 | if example.text_b: 219 | tokens_b = tokenizer.tokenize(example.text_b) 220 | 221 | if tokens_b: 222 | # Modifies `tokens_a` and `tokens_b` in place so that the total 223 | # length is less than the specified length. 
224 | # Account for [CLS], [SEP], [SEP] with "- 3" 225 | _truncate_seq_pair(tokens_a, tokens_b, seq_length - 3) 226 | else: 227 | # Account for [CLS] and [SEP] with "- 2" 228 | if len(tokens_a) > seq_length - 2: 229 | tokens_a = tokens_a[0:(seq_length - 2)] 230 | 231 | # The convention in BERT is: 232 | # (a) For sequence pairs: 233 | # tokens: [CLS] is this jack ##son ##ville ? [SEP] no it is not . [SEP] 234 | # type_ids: 0 0 0 0 0 0 0 0 1 1 1 1 1 1 235 | # (b) For single sequences: 236 | # tokens: [CLS] the dog is hairy . [SEP] 237 | # type_ids: 0 0 0 0 0 0 0 238 | # 239 | # Where "type_ids" are used to indicate whether this is the first 240 | # sequence or the second sequence. The embedding vectors for `type=0` and 241 | # `type=1` were learned during pre-training and are added to the wordpiece 242 | # embedding vector (and position vector). This is not *strictly* necessary 243 | # since the [SEP] token unambiguously separates the sequences, but it makes 244 | # it easier for the model to learn the concept of sequences. 245 | # 246 | # For classification tasks, the first vector (corresponding to [CLS]) is 247 | # used as as the "sentence vector". Note that this only makes sense because 248 | # the entire model is fine-tuned. 249 | tokens = [] 250 | input_type_ids = [] 251 | tokens.append("[CLS]") 252 | input_type_ids.append(0) 253 | for token in tokens_a: 254 | tokens.append(token) 255 | input_type_ids.append(0) 256 | tokens.append("[SEP]") 257 | input_type_ids.append(0) 258 | 259 | if tokens_b: 260 | for token in tokens_b: 261 | tokens.append(token) 262 | input_type_ids.append(1) 263 | tokens.append("[SEP]") 264 | input_type_ids.append(1) 265 | 266 | input_ids = tokenizer.convert_tokens_to_ids(tokens) 267 | 268 | # The mask has 1 for real tokens and 0 for padding tokens. Only real 269 | # tokens are attended to. 270 | input_mask = [1] * len(input_ids) 271 | 272 | # Zero-pad up to the sequence length. 273 | while len(input_ids) < seq_length: 274 | input_ids.append(0) 275 | input_mask.append(0) 276 | input_type_ids.append(0) 277 | 278 | assert len(input_ids) == seq_length 279 | assert len(input_mask) == seq_length 280 | assert len(input_type_ids) == seq_length 281 | 282 | if ex_index < 5: 283 | tf.logging.info("*** Example ***") 284 | tf.logging.info("unique_id: %s" % (example.unique_id)) 285 | tf.logging.info("tokens: %s" % " ".join( 286 | [tokenization.printable_text(x) for x in tokens])) 287 | tf.logging.info("input_ids: %s" % " ".join([str(x) for x in input_ids])) 288 | tf.logging.info("input_mask: %s" % " ".join([str(x) for x in input_mask])) 289 | tf.logging.info( 290 | "input_type_ids: %s" % " ".join([str(x) for x in input_type_ids])) 291 | 292 | features.append( 293 | InputFeatures( 294 | unique_id=example.unique_id, 295 | tokens=tokens, 296 | input_ids=input_ids, 297 | input_mask=input_mask, 298 | input_type_ids=input_type_ids)) 299 | return features 300 | 301 | 302 | def _truncate_seq_pair(tokens_a, tokens_b, max_length): 303 | """Truncates a sequence pair in place to the maximum length.""" 304 | 305 | # This is a simple heuristic which will always truncate the longer sequence 306 | # one token at a time. This makes more sense than truncating an equal percent 307 | # of tokens from each, since if one sequence is very short then each token 308 | # that's truncated likely contains more information than a longer sequence. 
309 | while True: 310 | total_length = len(tokens_a) + len(tokens_b) 311 | if total_length <= max_length: 312 | break 313 | if len(tokens_a) > len(tokens_b): 314 | tokens_a.pop() 315 | else: 316 | tokens_b.pop() 317 | 318 | 319 | def read_examples(input_file): 320 | """Read a list of `InputExample`s from an input file.""" 321 | examples = [] 322 | unique_id = 0 323 | with tf.gfile.GFile(input_file, "r") as reader: 324 | while True: 325 | line = tokenization.convert_to_unicode(reader.readline()) 326 | if not line: 327 | break 328 | line = line.strip() 329 | text_a = None 330 | text_b = None 331 | m = re.match(r"^(.*) \|\|\| (.*)$", line) 332 | if m is None: 333 | text_a = line 334 | else: 335 | text_a = m.group(1) 336 | text_b = m.group(2) 337 | examples.append( 338 | InputExample(unique_id=unique_id, text_a=text_a, text_b=text_b)) 339 | unique_id += 1 340 | return examples 341 | 342 | 343 | def main(_): 344 | tf.logging.set_verbosity(tf.logging.INFO) 345 | 346 | layer_indexes = [int(x) for x in FLAGS.layers.split(",")] 347 | 348 | bert_config = modeling.BertConfig.from_json_file(FLAGS.bert_config_file) 349 | 350 | tokenizer = tokenization.FullTokenizer( 351 | vocab_file=FLAGS.vocab_file, do_lower_case=FLAGS.do_lower_case) 352 | 353 | is_per_host = tf.contrib.tpu.InputPipelineConfig.PER_HOST_V2 354 | run_config = tf.contrib.tpu.RunConfig( 355 | master=FLAGS.master, 356 | tpu_config=tf.contrib.tpu.TPUConfig( 357 | num_shards=FLAGS.num_tpu_cores, 358 | per_host_input_for_training=is_per_host)) 359 | 360 | examples = read_examples(FLAGS.input_file) 361 | 362 | features = convert_examples_to_features( 363 | examples=examples, seq_length=FLAGS.max_seq_length, tokenizer=tokenizer) 364 | 365 | unique_id_to_feature = {} 366 | for feature in features: 367 | unique_id_to_feature[feature.unique_id] = feature 368 | 369 | model_fn = model_fn_builder( 370 | bert_config=bert_config, 371 | init_checkpoint=FLAGS.init_checkpoint, 372 | layer_indexes=layer_indexes, 373 | use_tpu=FLAGS.use_tpu, 374 | use_one_hot_embeddings=FLAGS.use_one_hot_embeddings) 375 | 376 | # If TPU is not available, this will fall back to normal Estimator on CPU 377 | # or GPU. 
378 | estimator = tf.contrib.tpu.TPUEstimator( 379 | use_tpu=FLAGS.use_tpu, 380 | model_fn=model_fn, 381 | config=run_config, 382 | predict_batch_size=FLAGS.batch_size) 383 | 384 | input_fn = input_fn_builder( 385 | features=features, seq_length=FLAGS.max_seq_length) 386 | 387 | with codecs.getwriter("utf-8")(tf.gfile.Open(FLAGS.output_file, 388 | "w")) as writer: 389 | for result in estimator.predict(input_fn, yield_single_examples=True): 390 | unique_id = int(result["unique_id"]) 391 | feature = unique_id_to_feature[unique_id] 392 | output_json = collections.OrderedDict() 393 | output_json["linex_index"] = unique_id 394 | all_features = [] 395 | for (i, token) in enumerate(feature.tokens): 396 | all_layers = [] 397 | for (j, layer_index) in enumerate(layer_indexes): 398 | layer_output = result["layer_output_%d" % j] 399 | layers = collections.OrderedDict() 400 | layers["index"] = layer_index 401 | layers["values"] = [ 402 | round(float(x), 6) for x in layer_output[i:(i + 1)].flat 403 | ] 404 | all_layers.append(layers) 405 | features = collections.OrderedDict() 406 | features["token"] = token 407 | features["layers"] = all_layers 408 | all_features.append(features) 409 | output_json["features"] = all_features 410 | writer.write(json.dumps(output_json) + "\n") 411 | 412 | 413 | if __name__ == "__main__": 414 | flags.mark_flag_as_required("input_file") 415 | flags.mark_flag_as_required("vocab_file") 416 | flags.mark_flag_as_required("bert_config_file") 417 | flags.mark_flag_as_required("init_checkpoint") 418 | flags.mark_flag_as_required("output_file") 419 | tf.app.run() 420 | -------------------------------------------------------------------------------- /bert/modeling_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
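"""Tests for modeling.py: build a small, randomly configured BertModel and
check output tensor shapes, BertConfig JSON serialization, and that every op
in the graph is reachable from the model outputs."""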
15 | from __future__ import absolute_import 16 | from __future__ import division 17 | from __future__ import print_function 18 | 19 | import collections 20 | import json 21 | import random 22 | import re 23 | 24 | import modeling 25 | import six 26 | import tensorflow as tf 27 | 28 | 29 | class BertModelTest(tf.test.TestCase): 30 | 31 | class BertModelTester(object): 32 | 33 | def __init__(self, 34 | parent, 35 | batch_size=13, 36 | seq_length=7, 37 | is_training=True, 38 | use_input_mask=True, 39 | use_token_type_ids=True, 40 | vocab_size=99, 41 | hidden_size=32, 42 | num_hidden_layers=5, 43 | num_attention_heads=4, 44 | intermediate_size=37, 45 | hidden_act="gelu", 46 | hidden_dropout_prob=0.1, 47 | attention_probs_dropout_prob=0.1, 48 | max_position_embeddings=512, 49 | type_vocab_size=16, 50 | initializer_range=0.02, 51 | scope=None): 52 | self.parent = parent 53 | self.batch_size = batch_size 54 | self.seq_length = seq_length 55 | self.is_training = is_training 56 | self.use_input_mask = use_input_mask 57 | self.use_token_type_ids = use_token_type_ids 58 | self.vocab_size = vocab_size 59 | self.hidden_size = hidden_size 60 | self.num_hidden_layers = num_hidden_layers 61 | self.num_attention_heads = num_attention_heads 62 | self.intermediate_size = intermediate_size 63 | self.hidden_act = hidden_act 64 | self.hidden_dropout_prob = hidden_dropout_prob 65 | self.attention_probs_dropout_prob = attention_probs_dropout_prob 66 | self.max_position_embeddings = max_position_embeddings 67 | self.type_vocab_size = type_vocab_size 68 | self.initializer_range = initializer_range 69 | self.scope = scope 70 | 71 | def create_model(self): 72 | input_ids = BertModelTest.ids_tensor([self.batch_size, self.seq_length], 73 | self.vocab_size) 74 | 75 | input_mask = None 76 | if self.use_input_mask: 77 | input_mask = BertModelTest.ids_tensor( 78 | [self.batch_size, self.seq_length], vocab_size=2) 79 | 80 | token_type_ids = None 81 | if self.use_token_type_ids: 82 | token_type_ids = BertModelTest.ids_tensor( 83 | [self.batch_size, self.seq_length], self.type_vocab_size) 84 | 85 | config = modeling.BertConfig( 86 | vocab_size=self.vocab_size, 87 | hidden_size=self.hidden_size, 88 | num_hidden_layers=self.num_hidden_layers, 89 | num_attention_heads=self.num_attention_heads, 90 | intermediate_size=self.intermediate_size, 91 | hidden_act=self.hidden_act, 92 | hidden_dropout_prob=self.hidden_dropout_prob, 93 | attention_probs_dropout_prob=self.attention_probs_dropout_prob, 94 | max_position_embeddings=self.max_position_embeddings, 95 | type_vocab_size=self.type_vocab_size, 96 | initializer_range=self.initializer_range) 97 | 98 | model = modeling.BertModel( 99 | config=config, 100 | is_training=self.is_training, 101 | input_ids=input_ids, 102 | input_mask=input_mask, 103 | token_type_ids=token_type_ids, 104 | scope=self.scope) 105 | 106 | outputs = { 107 | "embedding_output": model.get_embedding_output(), 108 | "sequence_output": model.get_sequence_output(), 109 | "pooled_output": model.get_pooled_output(), 110 | "all_encoder_layers": model.get_all_encoder_layers(), 111 | } 112 | return outputs 113 | 114 | def check_output(self, result): 115 | self.parent.assertAllEqual( 116 | result["embedding_output"].shape, 117 | [self.batch_size, self.seq_length, self.hidden_size]) 118 | 119 | self.parent.assertAllEqual( 120 | result["sequence_output"].shape, 121 | [self.batch_size, self.seq_length, self.hidden_size]) 122 | 123 | self.parent.assertAllEqual(result["pooled_output"].shape, 124 | [self.batch_size, 
self.hidden_size]) 125 | 126 | def test_default(self): 127 | self.run_tester(BertModelTest.BertModelTester(self)) 128 | 129 | def test_config_to_json_string(self): 130 | config = modeling.BertConfig(vocab_size=99, hidden_size=37) 131 | obj = json.loads(config.to_json_string()) 132 | self.assertEqual(obj["vocab_size"], 99) 133 | self.assertEqual(obj["hidden_size"], 37) 134 | 135 | def run_tester(self, tester): 136 | with self.test_session() as sess: 137 | ops = tester.create_model() 138 | init_op = tf.group(tf.global_variables_initializer(), 139 | tf.local_variables_initializer()) 140 | sess.run(init_op) 141 | output_result = sess.run(ops) 142 | tester.check_output(output_result) 143 | 144 | self.assert_all_tensors_reachable(sess, [init_op, ops]) 145 | 146 | @classmethod 147 | def ids_tensor(cls, shape, vocab_size, rng=None, name=None): 148 | """Creates a random int32 tensor of the shape within the vocab size.""" 149 | if rng is None: 150 | rng = random.Random() 151 | 152 | total_dims = 1 153 | for dim in shape: 154 | total_dims *= dim 155 | 156 | values = [] 157 | for _ in range(total_dims): 158 | values.append(rng.randint(0, vocab_size - 1)) 159 | 160 | return tf.constant(value=values, dtype=tf.int32, shape=shape, name=name) 161 | 162 | def assert_all_tensors_reachable(self, sess, outputs): 163 | """Checks that all the tensors in the graph are reachable from outputs.""" 164 | graph = sess.graph 165 | 166 | ignore_strings = [ 167 | "^.*/assert_less_equal/.*$", 168 | "^.*/dilation_rate$", 169 | "^.*/Tensordot/concat$", 170 | "^.*/Tensordot/concat/axis$", 171 | "^testing/.*$", 172 | ] 173 | 174 | ignore_regexes = [re.compile(x) for x in ignore_strings] 175 | 176 | unreachable = self.get_unreachable_ops(graph, outputs) 177 | filtered_unreachable = [] 178 | for x in unreachable: 179 | do_ignore = False 180 | for r in ignore_regexes: 181 | m = r.match(x.name) 182 | if m is not None: 183 | do_ignore = True 184 | if do_ignore: 185 | continue 186 | filtered_unreachable.append(x) 187 | unreachable = filtered_unreachable 188 | 189 | self.assertEqual( 190 | len(unreachable), 0, "The following ops are unreachable: %s" % 191 | (" ".join([x.name for x in unreachable]))) 192 | 193 | @classmethod 194 | def get_unreachable_ops(cls, graph, outputs): 195 | """Finds all of the tensors in graph that are unreachable from outputs.""" 196 | outputs = cls.flatten_recursive(outputs) 197 | output_to_op = collections.defaultdict(list) 198 | op_to_all = collections.defaultdict(list) 199 | assign_out_to_in = collections.defaultdict(list) 200 | 201 | for op in graph.get_operations(): 202 | for x in op.inputs: 203 | op_to_all[op.name].append(x.name) 204 | for y in op.outputs: 205 | output_to_op[y.name].append(op.name) 206 | op_to_all[op.name].append(y.name) 207 | if str(op.type) == "Assign": 208 | for y in op.outputs: 209 | for x in op.inputs: 210 | assign_out_to_in[y.name].append(x.name) 211 | 212 | assign_groups = collections.defaultdict(list) 213 | for out_name in assign_out_to_in.keys(): 214 | name_group = assign_out_to_in[out_name] 215 | for n1 in name_group: 216 | assign_groups[n1].append(out_name) 217 | for n2 in name_group: 218 | if n1 != n2: 219 | assign_groups[n1].append(n2) 220 | 221 | seen_tensors = {} 222 | stack = [x.name for x in outputs] 223 | while stack: 224 | name = stack.pop() 225 | if name in seen_tensors: 226 | continue 227 | seen_tensors[name] = True 228 | 229 | if name in output_to_op: 230 | for op_name in output_to_op[name]: 231 | if op_name in op_to_all: 232 | for input_name in 
op_to_all[op_name]: 233 | if input_name not in stack: 234 | stack.append(input_name) 235 | 236 | expanded_names = [] 237 | if name in assign_groups: 238 | for assign_name in assign_groups[name]: 239 | expanded_names.append(assign_name) 240 | 241 | for expanded_name in expanded_names: 242 | if expanded_name not in stack: 243 | stack.append(expanded_name) 244 | 245 | unreachable_ops = [] 246 | for op in graph.get_operations(): 247 | is_unreachable = False 248 | all_names = [x.name for x in op.inputs] + [x.name for x in op.outputs] 249 | for name in all_names: 250 | if name not in seen_tensors: 251 | is_unreachable = True 252 | if is_unreachable: 253 | unreachable_ops.append(op) 254 | return unreachable_ops 255 | 256 | @classmethod 257 | def flatten_recursive(cls, item): 258 | """Flattens (potentially nested) a tuple/dictionary/list to a list.""" 259 | output = [] 260 | if isinstance(item, list): 261 | output.extend(item) 262 | elif isinstance(item, tuple): 263 | output.extend(list(item)) 264 | elif isinstance(item, dict): 265 | for (_, v) in six.iteritems(item): 266 | output.append(v) 267 | else: 268 | return [item] 269 | 270 | flat_output = [] 271 | for x in output: 272 | flat_output.extend(cls.flatten_recursive(x)) 273 | return flat_output 274 | 275 | 276 | if __name__ == "__main__": 277 | tf.test.main() 278 | -------------------------------------------------------------------------------- /bert/multilingual.md: -------------------------------------------------------------------------------- 1 | ## Models 2 | 3 | There are two multilingual models currently available. We do not plan to release 4 | more single-language models, but we may release `BERT-Large` versions of these 5 | two in the future: 6 | 7 | * **[`BERT-Base, Multilingual`](https://storage.googleapis.com/bert_models/2018_11_03/multilingual_L-12_H-768_A-12.zip)**: 8 | 102 languages, 12-layer, 768-hidden, 12-heads, 110M parameters 9 | * **[`BERT-Base, Chinese`](https://storage.googleapis.com/bert_models/2018_11_03/chinese_L-12_H-768_A-12.zip)**: 10 | Chinese Simplified and Traditional, 12-layer, 768-hidden, 12-heads, 110M 11 | parameters 12 | 13 | See the [list of languages](#list-of-languages) that the Multilingual model 14 | supports. The Multilingual model does include Chinese (and English), but if your 15 | fine-tuning data is Chinese-only, then the Chinese model will likely produce 16 | better results. 17 | 18 | ## Results 19 | 20 | To evaluate these systems, we use the 21 | [XNLI dataset](https://github.com/facebookresearch/XNLI) dataset, which is a 22 | version of [MultiNLI](https://www.nyu.edu/projects/bowman/multinli/) where the 23 | dev and test sets have been translated (by humans) into 15 languages. Note that 24 | the training set was *machine* translated (we used the translations provided by 25 | XNLI, not Google NMT). 
For clarity, we only report on 6 languages below: 26 | 27 | 28 | 29 | | System | English | Chinese | Spanish | German | Arabic | Urdu | 30 | | ------------------------------- | -------- | -------- | -------- | -------- | -------- | -------- | 31 | | XNLI Baseline - Translate Train | 73.7 | 67.0 | 68.8 | 66.5 | 65.8 | 56.6 | 32 | | XNLI Baseline - Translate Test | 73.7 | 68.3 | 70.7 | 68.7 | 66.8 | 59.3 | 33 | | BERT - Translate Train | **81.4** | **74.2** | **77.3** | **75.2** | **70.5** | 61.7 | 34 | | BERT - Translate Test | 81.4 | 70.1 | 74.9 | 74.4 | 70.4 | **62.1** | 35 | | BERT - Zero Shot | 81.4 | 63.8 | 74.3 | 70.5 | 62.1 | 58.3 | 36 | 37 | 38 | 39 | The first two rows are baselines from the XNLI paper and the last three rows are 40 | our results with BERT. 41 | 42 | **Translate Train** means that the MultiNLI training set was machine translated 43 | from English into the foreign language. So training and evaluation were both 44 | done in the foreign language. Unfortunately, training was done on 45 | machine-translated data, so it is impossible to quantify how much of the lower 46 | accuracy (compared to English) is due to the quality of the machine translation 47 | vs. the quality of the pre-trained model. 48 | 49 | **Translate Test** means that the XNLI test set was machine translated from the 50 | foreign language into English. So training and evaluation were both done on 51 | English. However, test evaluation was done on machine-translated English, so the 52 | accuracy depends on the quality of the machine translation system. 53 | 54 | **Zero Shot** means that the Multilingual BERT system was fine-tuned on English 55 | MultiNLI, and then evaluated on the foreign language XNLI test. In this case, 56 | machine translation was not involved at all in either the pre-training or 57 | fine-tuning. 58 | 59 | Note that the English result is worse than the 84.2 MultiNLI baseline because 60 | this training used Multilingual BERT rather than English-only BERT. This implies 61 | that for high-resource languages, the Multilingual model is somewhat worse than 62 | a single-language model. However, it is not feasible for us to train and 63 | maintain dozens of single-language models. Therefore, if your goal is to maximize 64 | performance with a language other than English or Chinese, you might find it 65 | beneficial to run pre-training for additional steps starting from our 66 | Multilingual model on data from your language of interest. 67 | 68 | Here is a comparison of training Chinese models with the Multilingual 69 | `BERT-Base` and Chinese-only `BERT-Base`: 70 | 71 | System | Chinese 72 | ----------------------- | ------- 73 | XNLI Baseline | 67.0 74 | BERT Multilingual Model | 74.2 75 | BERT Chinese-only Model | 77.2 76 | 77 | Similar to English, the single-language model does 3% better than the 78 | Multilingual model. 79 | 80 | ## Fine-tuning Example 81 | 82 | The multilingual model does **not** require any special consideration or API 83 | changes. We did update the implementation of `BasicTokenizer` in 84 | `tokenization.py` to support Chinese character tokenization, so please update if 85 | you forked it. However, we did not change the tokenization API. 86 | 87 | To test the new models, we did modify `run_classifier.py` to add support for the 88 | [XNLI dataset](https://github.com/facebookresearch/XNLI). This is a 15-language 89 | version of MultiNLI where the dev/test sets have been human-translated, and the 90 | training set has been machine-translated.
91 | 92 | To run the fine-tuning code, please download the 93 | [XNLI dev/test set](https://s3.amazonaws.com/xnli/XNLI-1.0.zip) and the 94 | [XNLI machine-translated training set](https://s3.amazonaws.com/xnli/XNLI-MT-1.0.zip) 95 | and then unpack both .zip files into some directory `$XNLI_DIR`. 96 | 97 | To run fine-tuning on XNLI, note that the language is hard-coded into `run_classifier.py` 98 | (Chinese by default), so please modify `XnliProcessor` if you want to run on 99 | another language. 100 | 101 | This is a large dataset, so training will take a few hours on a GPU 102 | (or about 30 minutes on a Cloud TPU). To run an experiment quickly for 103 | debugging, just set `num_train_epochs` to a small value like `0.1`. 104 | 105 | ```shell 106 | export BERT_BASE_DIR=/path/to/bert/chinese_L-12_H-768_A-12 # or multilingual_L-12_H-768_A-12 107 | export XNLI_DIR=/path/to/xnli 108 | 109 | python run_classifier.py \ 110 | --task_name=XNLI \ 111 | --do_train=true \ 112 | --do_eval=true \ 113 | --data_dir=$XNLI_DIR \ 114 | --vocab_file=$BERT_BASE_DIR/vocab.txt \ 115 | --bert_config_file=$BERT_BASE_DIR/bert_config.json \ 116 | --init_checkpoint=$BERT_BASE_DIR/bert_model.ckpt \ 117 | --max_seq_length=128 \ 118 | --train_batch_size=32 \ 119 | --learning_rate=5e-5 \ 120 | --num_train_epochs=2.0 \ 121 | --output_dir=/tmp/xnli_output/ 122 | ``` 123 | 124 | With the Chinese-only model, the results should look something like this: 125 | 126 | ``` 127 | ***** Eval results ***** 128 | eval_accuracy = 0.774116 129 | eval_loss = 0.83554 130 | global_step = 24543 131 | loss = 0.74603 132 | ``` 133 | 134 | ## Details 135 | 136 | ### Data Source and Sampling 137 | 138 | The languages chosen were the 139 | [top 100 languages with the largest Wikipedias](https://meta.wikimedia.org/wiki/List_of_Wikipedias). 140 | The entire Wikipedia dump for each language (excluding user and talk pages) was 141 | taken as the training data for that language. 142 | 143 | However, the size of the Wikipedia for a given language varies greatly, and 144 | therefore low-resource languages may be "under-represented" in terms of the 145 | neural network model (under the assumption that languages are "competing" for 146 | limited model capacity to some extent). 147 | 148 | On the other hand, the size of a Wikipedia also correlates with the number of speakers of 149 | a language, and we also don't want to overfit the model by performing thousands 150 | of epochs over a tiny Wikipedia for a particular language. 151 | 152 | To balance these two factors, we performed exponentially smoothed weighting of 153 | the data during pre-training data creation (and WordPiece vocab creation). In 154 | other words, let's say that the probability of a language is *P(L)*, e.g., 155 | *P(English) = 0.21* means that after concatenating all of the Wikipedias 156 | together, 21% of our data is English. We exponentiate each probability by some 157 | factor *S* and then re-normalize, and sample from that distribution. In our case 158 | we use *S=0.7*. So, high-resource languages like English will be under-sampled, 159 | and low-resource languages like Icelandic will be over-sampled. E.g., in the 160 | original distribution English would be sampled 1000x more than Icelandic, but 161 | after smoothing it's only sampled 100x more. 162 | 163 | ### Tokenization 164 | 165 | For tokenization, we use a 110k shared WordPiece vocabulary. The word counts are 166 | weighted the same way as the data, so low-resource languages are upweighted by 167 | some factor.
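As a rough sketch of the exponentiate-and-renormalize weighting just described, applied both to the sampling of pre-training data and to the WordPiece word counts above (the function and the counts below are illustrative placeholders; only *S=0.7* comes from this document):

```python
def smoothed_probs(counts, s=0.7):
  """Exponentially smooth an empirical language distribution."""
  total = float(sum(counts.values()))
  weighted = {lang: (n / total) ** s for lang, n in counts.items()}
  norm = sum(weighted.values())
  return {lang: w / norm for lang, w in weighted.items()}

# Hypothetical sentence counts: English dominates, Icelandic is tiny.
probs = smoothed_probs({"en": 10000000, "zh": 4000000, "is": 10000})
# Before smoothing "en" is sampled 1000x more often than "is"; afterwards the
# ratio is 1000 ** 0.7, i.e. roughly the "100x" figure quoted above.
```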
We intentionally do *not* use any marker to denote the input 168 | language (so that zero-shot training can work). 169 | 170 | Because Chinese does not have whitespace characters, we add spaces around every 171 | character in the 172 | [CJK Unicode range](https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_\(Unicode_block\)) 173 | before applying WordPiece. This means that Chinese is effectively 174 | character-tokenized. Note that the CJK Unicode block only includes 175 | Chinese-origin characters and does *not* include Hangul Korean or 176 | Katakana/Hiragana Japanese, which are tokenized with whitespace+WordPiece like 177 | all other languages. 178 | 179 | For all other languages, we apply the 180 | [same recipe as English](https://github.com/google-research/bert#tokenization): 181 | (a) lower casing+accent removal, (b) punctuation splitting, (c) whitespace 182 | tokenization. We understand that accent markers have substantial meaning in some 183 | languages, but felt that the benefits of reducing the effective vocabulary make 184 | up for this. Generally the strong contextual models of BERT should make up for 185 | any ambiguity introduced by stripping accent markers. 186 | 187 | ### List of Languages 188 | 189 | The multilingual model supports the following languages. These languages were 190 | chosen because they are the top 100 languages with the largest Wikipedias: 191 | 192 | * Afrikaans 193 | * Albanian 194 | * Arabic 195 | * Aragonese 196 | * Armenian 197 | * Asturian 198 | * Azerbaijani 199 | * Bashkir 200 | * Basque 201 | * Bavarian 202 | * Belarusian 203 | * Bengali 204 | * Bishnupriya Manipuri 205 | * Bosnian 206 | * Breton 207 | * Bulgarian 208 | * Burmese 209 | * Catalan 210 | * Cebuano 211 | * Chechen 212 | * Chinese (Simplified) 213 | * Chinese (Traditional) 214 | * Chuvash 215 | * Croatian 216 | * Czech 217 | * Danish 218 | * Dutch 219 | * English 220 | * Estonian 221 | * Finnish 222 | * French 223 | * Galician 224 | * Georgian 225 | * German 226 | * Greek 227 | * Gujarati 228 | * Haitian 229 | * Hebrew 230 | * Hindi 231 | * Hungarian 232 | * Icelandic 233 | * Ido 234 | * Indonesian 235 | * Irish 236 | * Italian 237 | * Japanese 238 | * Javanese 239 | * Kannada 240 | * Kazakh 241 | * Kirghiz 242 | * Korean 243 | * Latin 244 | * Latvian 245 | * Lithuanian 246 | * Lombard 247 | * Low Saxon 248 | * Luxembourgish 249 | * Macedonian 250 | * Malagasy 251 | * Malay 252 | * Malayalam 253 | * Marathi 254 | * Minangkabau 255 | * Nepali 256 | * Newar 257 | * Norwegian (Bokmal) 258 | * Norwegian (Nynorsk) 259 | * Occitan 260 | * Persian (Farsi) 261 | * Piedmontese 262 | * Polish 263 | * Portuguese 264 | * Punjabi 265 | * Romanian 266 | * Russian 267 | * Scots 268 | * Serbian 269 | * Serbo-Croatian 270 | * Sicilian 271 | * Slovak 272 | * Slovenian 273 | * South Azerbaijani 274 | * Spanish 275 | * Sundanese 276 | * Swahili 277 | * Swedish 278 | * Tagalog 279 | * Tajik 280 | * Tamil 281 | * Tatar 282 | * Telugu 283 | * Turkish 284 | * Ukrainian 285 | * Urdu 286 | * Uzbek 287 | * Vietnamese 288 | * Volapük 289 | * Waray-Waray 290 | * Welsh 291 | * West 292 | * Western Punjabi 293 | * Yoruba 294 | 295 | The only language which we had to unfortunately exclude was Thai, since it is 296 | the only language (other than Chinese) that does not use whitespace to delimit 297 | words, and it has too many characters-per-word to use character-based 298 | tokenization. 
Our WordPiece algorithm is quadratic with respect to the size of 299 | the input token so very long character strings do not work with it. 300 | -------------------------------------------------------------------------------- /bert/optimization.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """Functions and classes related to optimization (weight updates).""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import re 22 | import tensorflow as tf 23 | 24 | 25 | def create_optimizer(loss, init_lr, num_train_steps, num_warmup_steps, use_tpu): 26 | """Creates an optimizer training op.""" 27 | global_step = tf.train.get_or_create_global_step() 28 | 29 | learning_rate = tf.constant(value=init_lr, shape=[], dtype=tf.float32) 30 | 31 | # Implements linear decay of the learning rate. 32 | learning_rate = tf.train.polynomial_decay( 33 | learning_rate, 34 | global_step, 35 | num_train_steps, 36 | end_learning_rate=0.0, 37 | power=1.0, 38 | cycle=False) 39 | 40 | # Implements linear warmup. I.e., if global_step < num_warmup_steps, the 41 | # learning rate will be `global_step/num_warmup_steps * init_lr`. 42 | if num_warmup_steps: 43 | global_steps_int = tf.cast(global_step, tf.int32) 44 | warmup_steps_int = tf.constant(num_warmup_steps, dtype=tf.int32) 45 | 46 | global_steps_float = tf.cast(global_steps_int, tf.float32) 47 | warmup_steps_float = tf.cast(warmup_steps_int, tf.float32) 48 | 49 | warmup_percent_done = global_steps_float / warmup_steps_float 50 | warmup_learning_rate = init_lr * warmup_percent_done 51 | 52 | is_warmup = tf.cast(global_steps_int < warmup_steps_int, tf.float32) 53 | learning_rate = ( 54 | (1.0 - is_warmup) * learning_rate + is_warmup * warmup_learning_rate) 55 | 56 | # It is recommended that you use this optimizer for fine tuning, since this 57 | # is how the model was trained (note that the Adam m/v variables are NOT 58 | # loaded from init_checkpoint.) 59 | optimizer = AdamWeightDecayOptimizer( 60 | learning_rate=learning_rate, 61 | weight_decay_rate=0.01, 62 | beta_1=0.9, 63 | beta_2=0.999, 64 | epsilon=1e-6, 65 | exclude_from_weight_decay=["LayerNorm", "layer_norm", "bias"]) 66 | 67 | if use_tpu: 68 | optimizer = tf.contrib.tpu.CrossShardOptimizer(optimizer) 69 | 70 | tvars = tf.trainable_variables() 71 | grads = tf.gradients(loss, tvars) 72 | 73 | # This is how the model was pre-trained. 
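# (Gradients are clipped to a global norm of 1.0 before being applied; the
# comment above refers to this clipping, which BERT pre-training also used.)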
74 | (grads, _) = tf.clip_by_global_norm(grads, clip_norm=1.0) 75 | 76 | train_op = optimizer.apply_gradients( 77 | zip(grads, tvars), global_step=global_step) 78 | 79 | new_global_step = global_step + 1 80 | train_op = tf.group(train_op, [global_step.assign(new_global_step)]) 81 | return train_op 82 | 83 | 84 | class AdamWeightDecayOptimizer(tf.train.Optimizer): 85 | """A basic Adam optimizer that includes "correct" L2 weight decay.""" 86 | 87 | def __init__(self, 88 | learning_rate, 89 | weight_decay_rate=0.0, 90 | beta_1=0.9, 91 | beta_2=0.999, 92 | epsilon=1e-6, 93 | exclude_from_weight_decay=None, 94 | name="AdamWeightDecayOptimizer"): 95 | """Constructs a AdamWeightDecayOptimizer.""" 96 | super(AdamWeightDecayOptimizer, self).__init__(False, name) 97 | 98 | self.learning_rate = learning_rate 99 | self.weight_decay_rate = weight_decay_rate 100 | self.beta_1 = beta_1 101 | self.beta_2 = beta_2 102 | self.epsilon = epsilon 103 | self.exclude_from_weight_decay = exclude_from_weight_decay 104 | 105 | def apply_gradients(self, grads_and_vars, global_step=None, name=None): 106 | """See base class.""" 107 | assignments = [] 108 | for (grad, param) in grads_and_vars: 109 | if grad is None or param is None: 110 | continue 111 | 112 | param_name = self._get_variable_name(param.name) 113 | 114 | m = tf.get_variable( 115 | name=param_name + "/adam_m", 116 | shape=param.shape.as_list(), 117 | dtype=tf.float32, 118 | trainable=False, 119 | initializer=tf.zeros_initializer()) 120 | v = tf.get_variable( 121 | name=param_name + "/adam_v", 122 | shape=param.shape.as_list(), 123 | dtype=tf.float32, 124 | trainable=False, 125 | initializer=tf.zeros_initializer()) 126 | 127 | # Standard Adam update. 128 | next_m = ( 129 | tf.multiply(self.beta_1, m) + tf.multiply(1.0 - self.beta_1, grad)) 130 | next_v = ( 131 | tf.multiply(self.beta_2, v) + tf.multiply(1.0 - self.beta_2, 132 | tf.square(grad))) 133 | 134 | update = next_m / (tf.sqrt(next_v) + self.epsilon) 135 | 136 | # Just adding the square of the weights to the loss function is *not* 137 | # the correct way of using L2 regularization/weight decay with Adam, 138 | # since that will interact with the m and v parameters in strange ways. 139 | # 140 | # Instead we want ot decay the weights in a manner that doesn't interact 141 | # with the m/v parameters. This is equivalent to adding the square 142 | # of the weights to the loss with plain (non-momentum) SGD. 
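# In other words, this is decoupled weight decay (the scheme popularized as AdamW): the decay
# term is added directly to the parameter update below rather than folded into the gradient.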
143 | if self._do_use_weight_decay(param_name): 144 | update += self.weight_decay_rate * param 145 | 146 | update_with_lr = self.learning_rate * update 147 | 148 | next_param = param - update_with_lr 149 | 150 | assignments.extend( 151 | [param.assign(next_param), 152 | m.assign(next_m), 153 | v.assign(next_v)]) 154 | return tf.group(*assignments, name=name) 155 | 156 | def _do_use_weight_decay(self, param_name): 157 | """Whether to use L2 weight decay for `param_name`.""" 158 | if not self.weight_decay_rate: 159 | return False 160 | if self.exclude_from_weight_decay: 161 | for r in self.exclude_from_weight_decay: 162 | if re.search(r, param_name) is not None: 163 | return False 164 | return True 165 | 166 | def _get_variable_name(self, param_name): 167 | """Get the variable name from the tensor name.""" 168 | m = re.match("^(.*):\\d+$", param_name) 169 | if m is not None: 170 | param_name = m.group(1) 171 | return param_name 172 | -------------------------------------------------------------------------------- /bert/optimization_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | from __future__ import absolute_import 16 | from __future__ import division 17 | from __future__ import print_function 18 | 19 | import optimization 20 | import tensorflow as tf 21 | 22 | 23 | class OptimizationTest(tf.test.TestCase): 24 | 25 | def test_adam(self): 26 | with self.test_session() as sess: 27 | w = tf.get_variable( 28 | "w", 29 | shape=[3], 30 | initializer=tf.constant_initializer([0.1, -0.2, -0.1])) 31 | x = tf.constant([0.4, 0.2, -0.5]) 32 | loss = tf.reduce_mean(tf.square(x - w)) 33 | tvars = tf.trainable_variables() 34 | grads = tf.gradients(loss, tvars) 35 | global_step = tf.train.get_or_create_global_step() 36 | optimizer = optimization.AdamWeightDecayOptimizer(learning_rate=0.2) 37 | train_op = optimizer.apply_gradients(zip(grads, tvars), global_step) 38 | init_op = tf.group(tf.global_variables_initializer(), 39 | tf.local_variables_initializer()) 40 | sess.run(init_op) 41 | for _ in range(100): 42 | sess.run(train_op) 43 | w_np = sess.run(w) 44 | self.assertAllClose(w_np.flat, [0.4, 0.2, -0.5], rtol=1e-2, atol=1e-2) 45 | 46 | 47 | if __name__ == "__main__": 48 | tf.test.main() 49 | -------------------------------------------------------------------------------- /bert/requirements.txt: -------------------------------------------------------------------------------- 1 | tensorflow >= 1.11.0 # CPU Version of TensorFlow. 2 | # tensorflow-gpu >= 1.11.0 # GPU version of TensorFlow. 
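# Note: the code in this repo (e.g. tf.contrib.tpu in optimization.py and
# tensorflow.contrib.rnn / tensorflow.contrib.crf in lstm_crf_layer.py) depends on
# tf.contrib, which was removed in TensorFlow 2.x, so a 1.x release is required.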
3 | -------------------------------------------------------------------------------- /bert/sample_text.txt: -------------------------------------------------------------------------------- 1 | This text is included to make sure Unicode is handled properly: 力加勝北区ᴵᴺᵀᵃছজটডণত 2 | Text should be one-sentence-per-line, with empty lines between documents. 3 | This sample text is public domain and was randomly selected from Project Guttenberg. 4 | 5 | The rain had only ceased with the gray streaks of morning at Blazing Star, and the settlement awoke to a moral sense of cleanliness, and the finding of forgotten knives, tin cups, and smaller camp utensils, where the heavy showers had washed away the debris and dust heaps before the cabin doors. 6 | Indeed, it was recorded in Blazing Star that a fortunate early riser had once picked up on the highway a solid chunk of gold quartz which the rain had freed from its incumbering soil, and washed into immediate and glittering popularity. 7 | Possibly this may have been the reason why early risers in that locality, during the rainy season, adopted a thoughtful habit of body, and seldom lifted their eyes to the rifted or india-ink washed skies above them. 8 | "Cass" Beard had risen early that morning, but not with a view to discovery. 9 | A leak in his cabin roof,--quite consistent with his careless, improvident habits,--had roused him at 4 A. M., with a flooded "bunk" and wet blankets. 10 | The chips from his wood pile refused to kindle a fire to dry his bed-clothes, and he had recourse to a more provident neighbor's to supply the deficiency. 11 | This was nearly opposite. 12 | Mr. Cassius crossed the highway, and stopped suddenly. 13 | Something glittered in the nearest red pool before him. 14 | Gold, surely! 15 | But, wonderful to relate, not an irregular, shapeless fragment of crude ore, fresh from Nature's crucible, but a bit of jeweler's handicraft in the form of a plain gold ring. 16 | Looking at it more attentively, he saw that it bore the inscription, "May to Cass." 17 | Like most of his fellow gold-seekers, Cass was superstitious. 18 | 19 | The fountain of classic wisdom, Hypatia herself. 20 | As the ancient sage--the name is unimportant to a monk--pumped water nightly that he might study by day, so I, the guardian of cloaks and parasols, at the sacred doors of her lecture-room, imbibe celestial knowledge. 21 | From my youth I felt in me a soul above the matter-entangled herd. 22 | She revealed to me the glorious fact, that I am a spark of Divinity itself. 23 | A fallen star, I am, sir!' continued he, pensively, stroking his lean stomach--'a fallen star!--fallen, if the dignity of philosophy will allow of the simile, among the hogs of the lower world--indeed, even into the hog-bucket itself. Well, after all, I will show you the way to the Archbishop's. 24 | There is a philosophic pleasure in opening one's treasures to the modest young. 25 | Perhaps you will assist me by carrying this basket of fruit?' And the little man jumped up, put his basket on Philammon's head, and trotted off up a neighbouring street. 
26 | Philammon followed, half contemptuous, half wondering at what this philosophy might be, which could feed the self-conceit of anything so abject as his ragged little apish guide; 27 | but the novel roar and whirl of the street, the perpetual stream of busy faces, the line of curricles, palanquins, laden asses, camels, elephants, which met and passed him, and squeezed him up steps and into doorways, as they threaded their way through the great Moon-gate into the ample street beyond, drove everything from his mind but wondering curiosity, and a vague, helpless dread of that great living wilderness, more terrible than any dead wilderness of sand which he had left behind. 28 | Already he longed for the repose, the silence of the Laura--for faces which knew him and smiled upon him; but it was too late to turn back now. 29 | His guide held on for more than a mile up the great main street, crossed in the centre of the city, at right angles, by one equally magnificent, at each end of which, miles away, appeared, dim and distant over the heads of the living stream of passengers, the yellow sand-hills of the desert; 30 | while at the end of the vista in front of them gleamed the blue harbour, through a network of countless masts. 31 | At last they reached the quay at the opposite end of the street; 32 | and there burst on Philammon's astonished eyes a vast semicircle of blue sea, ringed with palaces and towers. 33 | He stopped involuntarily; and his little guide stopped also, and looked askance at the young monk, to watch the effect which that grand panorama should produce on him. 34 | -------------------------------------------------------------------------------- /bert/tokenization.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | """Tokenization classes.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import collections 22 | import unicodedata 23 | import six 24 | import tensorflow as tf 25 | 26 | 27 | def convert_to_unicode(text): 28 | """Converts `text` to Unicode (if it's not already), assuming utf-8 input.""" 29 | if six.PY3: 30 | if isinstance(text, str): 31 | return text 32 | elif isinstance(text, bytes): 33 | return text.decode("utf-8", "ignore") 34 | else: 35 | raise ValueError("Unsupported string type: %s" % (type(text))) 36 | elif six.PY2: 37 | if isinstance(text, str): 38 | return text.decode("utf-8", "ignore") 39 | elif isinstance(text, unicode): 40 | return text 41 | else: 42 | raise ValueError("Unsupported string type: %s" % (type(text))) 43 | else: 44 | raise ValueError("Not running on Python2 or Python 3?") 45 | 46 | 47 | def printable_text(text): 48 | """Returns text encoded in a way suitable for print or `tf.logging`.""" 49 | 50 | # These functions want `str` for both Python2 and Python3, but in one case 51 | # it's a Unicode string and in the other it's a byte string. 52 | if six.PY3: 53 | if isinstance(text, str): 54 | return text 55 | elif isinstance(text, bytes): 56 | return text.decode("utf-8", "ignore") 57 | else: 58 | raise ValueError("Unsupported string type: %s" % (type(text))) 59 | elif six.PY2: 60 | if isinstance(text, str): 61 | return text 62 | elif isinstance(text, unicode): 63 | return text.encode("utf-8") 64 | else: 65 | raise ValueError("Unsupported string type: %s" % (type(text))) 66 | else: 67 | raise ValueError("Not running on Python2 or Python 3?") 68 | 69 | 70 | def load_vocab(vocab_file): 71 | """Loads a vocabulary file into a dictionary.""" 72 | vocab = collections.OrderedDict() 73 | index = 0 74 | with tf.gfile.GFile(vocab_file, "r") as reader: 75 | while True: 76 | token = convert_to_unicode(reader.readline()) 77 | if not token: 78 | break 79 | token = token.strip() 80 | vocab[token] = index 81 | index += 1 82 | return vocab 83 | 84 | 85 | def convert_by_vocab(vocab, items): 86 | """Converts a sequence of [tokens|ids] using the vocab.""" 87 | output = [] 88 | for item in items: 89 | output.append(vocab[item]) 90 | return output 91 | 92 | 93 | def convert_tokens_to_ids(vocab, tokens): 94 | return convert_by_vocab(vocab, tokens) 95 | 96 | 97 | def convert_ids_to_tokens(inv_vocab, ids): 98 | return convert_by_vocab(inv_vocab, ids) 99 | 100 | 101 | def whitespace_tokenize(text): 102 | """Runs basic whitespace cleaning and splitting on a peice of text.""" 103 | text = text.strip() 104 | if not text: 105 | return [] 106 | tokens = text.split() 107 | return tokens 108 | 109 | 110 | class FullTokenizer(object): 111 | """Runs end-to-end tokenziation.""" 112 | 113 | def __init__(self, vocab_file, do_lower_case=True): 114 | self.vocab = load_vocab(vocab_file) 115 | self.inv_vocab = {v: k for k, v in self.vocab.items()} 116 | self.basic_tokenizer = BasicTokenizer(do_lower_case=do_lower_case) 117 | self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab) 118 | 119 | def tokenize(self, text): 120 | split_tokens = [] 121 | for token in self.basic_tokenizer.tokenize(text): 122 | for sub_token in self.wordpiece_tokenizer.tokenize(token): 123 | split_tokens.append(sub_token) 124 | 125 | return split_tokens 126 | 127 | def convert_tokens_to_ids(self, tokens): 128 | return convert_by_vocab(self.vocab, tokens) 129 | 130 | def convert_ids_to_tokens(self, ids): 131 | return 
convert_by_vocab(self.inv_vocab, ids) 132 | 133 | 134 | class BasicTokenizer(object): 135 | """Runs basic tokenization (punctuation splitting, lower casing, etc.).""" 136 | 137 | def __init__(self, do_lower_case=True): 138 | """Constructs a BasicTokenizer. 139 | 140 | Args: 141 | do_lower_case: Whether to lower case the input. 142 | """ 143 | self.do_lower_case = do_lower_case 144 | 145 | def tokenize(self, text): 146 | """Tokenizes a piece of text.""" 147 | text = convert_to_unicode(text) 148 | text = self._clean_text(text) 149 | 150 | # This was added on November 1st, 2018 for the multilingual and Chinese 151 | # models. This is also applied to the English models now, but it doesn't 152 | # matter since the English models were not trained on any Chinese data 153 | # and generally don't have any Chinese data in them (there are Chinese 154 | # characters in the vocabulary because Wikipedia does have some Chinese 155 | # words in the English Wikipedia.). 156 | text = self._tokenize_chinese_chars(text) 157 | 158 | orig_tokens = whitespace_tokenize(text) 159 | split_tokens = [] 160 | for token in orig_tokens: 161 | if self.do_lower_case: 162 | token = token.lower() 163 | token = self._run_strip_accents(token) 164 | split_tokens.extend(self._run_split_on_punc(token)) 165 | 166 | output_tokens = whitespace_tokenize(" ".join(split_tokens)) 167 | return output_tokens 168 | 169 | def _run_strip_accents(self, text): 170 | """Strips accents from a piece of text.""" 171 | text = unicodedata.normalize("NFD", text) 172 | output = [] 173 | for char in text: 174 | cat = unicodedata.category(char) 175 | if cat == "Mn": 176 | continue 177 | output.append(char) 178 | return "".join(output) 179 | 180 | def _run_split_on_punc(self, text): 181 | """Splits punctuation on a piece of text.""" 182 | chars = list(text) 183 | i = 0 184 | start_new_word = True 185 | output = [] 186 | while i < len(chars): 187 | char = chars[i] 188 | if _is_punctuation(char): 189 | output.append([char]) 190 | start_new_word = True 191 | else: 192 | if start_new_word: 193 | output.append([]) 194 | start_new_word = False 195 | output[-1].append(char) 196 | i += 1 197 | 198 | return ["".join(x) for x in output] 199 | 200 | def _tokenize_chinese_chars(self, text): 201 | """Adds whitespace around any CJK character.""" 202 | output = [] 203 | for char in text: 204 | cp = ord(char) 205 | if self._is_chinese_char(cp): 206 | output.append(" ") 207 | output.append(char) 208 | output.append(" ") 209 | else: 210 | output.append(char) 211 | return "".join(output) 212 | 213 | def _is_chinese_char(self, cp): 214 | """Checks whether CP is the codepoint of a CJK character.""" 215 | # This defines a "chinese character" as anything in the CJK Unicode block: 216 | # https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block) 217 | # 218 | # Note that the CJK Unicode block is NOT all Japanese and Korean characters, 219 | # despite its name. The modern Korean Hangul alphabet is a different block, 220 | # as is Japanese Hiragana and Katakana. Those alphabets are used to write 221 | # space-separated words, so they are not treated specially and handled 222 | # like the all of the other languages. 
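    # The ranges tested below are the CJK Unified Ideographs block (0x4E00-0x9FFF), Extension A
    # (0x3400-0x4DBF), Extensions B through E (0x20000-0x2CEAF), and the CJK Compatibility
    # Ideographs block plus its supplement (0xF900-0xFAFF, 0x2F800-0x2FA1F).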
223 | if ((cp >= 0x4E00 and cp <= 0x9FFF) or # 224 | (cp >= 0x3400 and cp <= 0x4DBF) or # 225 | (cp >= 0x20000 and cp <= 0x2A6DF) or # 226 | (cp >= 0x2A700 and cp <= 0x2B73F) or # 227 | (cp >= 0x2B740 and cp <= 0x2B81F) or # 228 | (cp >= 0x2B820 and cp <= 0x2CEAF) or 229 | (cp >= 0xF900 and cp <= 0xFAFF) or # 230 | (cp >= 0x2F800 and cp <= 0x2FA1F)): # 231 | return True 232 | 233 | return False 234 | 235 | def _clean_text(self, text): 236 | """Performs invalid character removal and whitespace cleanup on text.""" 237 | output = [] 238 | for char in text: 239 | cp = ord(char) 240 | if cp == 0 or cp == 0xfffd or _is_control(char): 241 | continue 242 | if _is_whitespace(char): 243 | output.append(" ") 244 | else: 245 | output.append(char) 246 | return "".join(output) 247 | 248 | 249 | class WordpieceTokenizer(object): 250 | """Runs WordPiece tokenziation.""" 251 | 252 | def __init__(self, vocab, unk_token="[UNK]", max_input_chars_per_word=100): 253 | self.vocab = vocab 254 | self.unk_token = unk_token 255 | self.max_input_chars_per_word = max_input_chars_per_word 256 | 257 | def tokenize(self, text): 258 | """Tokenizes a piece of text into its word pieces. 259 | 260 | This uses a greedy longest-match-first algorithm to perform tokenization 261 | using the given vocabulary. 262 | 263 | For example: 264 | input = "unaffable" 265 | output = ["un", "##aff", "##able"] 266 | 267 | Args: 268 | text: A single token or whitespace separated tokens. This should have 269 | already been passed through `BasicTokenizer. 270 | 271 | Returns: 272 | A list of wordpiece tokens. 273 | """ 274 | 275 | text = convert_to_unicode(text) 276 | 277 | output_tokens = [] 278 | for token in whitespace_tokenize(text): 279 | chars = list(token) 280 | if len(chars) > self.max_input_chars_per_word: 281 | output_tokens.append(self.unk_token) 282 | continue 283 | 284 | is_bad = False 285 | start = 0 286 | sub_tokens = [] 287 | while start < len(chars): 288 | end = len(chars) 289 | cur_substr = None 290 | while start < end: 291 | substr = "".join(chars[start:end]) 292 | if start > 0: 293 | substr = "##" + substr 294 | if substr in self.vocab: 295 | cur_substr = substr 296 | break 297 | end -= 1 298 | if cur_substr is None: 299 | is_bad = True 300 | break 301 | sub_tokens.append(cur_substr) 302 | start = end 303 | 304 | if is_bad: 305 | output_tokens.append(self.unk_token) 306 | else: 307 | output_tokens.extend(sub_tokens) 308 | return output_tokens 309 | 310 | 311 | def _is_whitespace(char): 312 | """Checks whether `chars` is a whitespace character.""" 313 | # \t, \n, and \r are technically contorl characters but we treat them 314 | # as whitespace since they are generally considered as such. 315 | if char == " " or char == "\t" or char == "\n" or char == "\r": 316 | return True 317 | cat = unicodedata.category(char) 318 | if cat == "Zs": 319 | return True 320 | return False 321 | 322 | 323 | def _is_control(char): 324 | """Checks whether `chars` is a control character.""" 325 | # These are technically control characters but we count them as whitespace 326 | # characters. 327 | if char == "\t" or char == "\n" or char == "\r": 328 | return False 329 | cat = unicodedata.category(char) 330 | if cat.startswith("C"): 331 | return True 332 | return False 333 | 334 | 335 | def _is_punctuation(char): 336 | """Checks whether `chars` is a punctuation character.""" 337 | cp = ord(char) 338 | # We treat all non-letter/number ASCII as punctuation. 
339 | # Characters such as "^", "$", and "`" are not in the Unicode 340 | # Punctuation class but we treat them as punctuation anyways, for 341 | # consistency. 342 | if ((cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64) or 343 | (cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126)): 344 | return True 345 | cat = unicodedata.category(char) 346 | if cat.startswith("P"): 347 | return True 348 | return False 349 | -------------------------------------------------------------------------------- /bert/tokenization_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | from __future__ import absolute_import 16 | from __future__ import division 17 | from __future__ import print_function 18 | 19 | import os 20 | import tempfile 21 | 22 | import tokenization 23 | import tensorflow as tf 24 | 25 | 26 | class TokenizationTest(tf.test.TestCase): 27 | 28 | def test_full_tokenizer(self): 29 | vocab_tokens = [ 30 | "[UNK]", "[CLS]", "[SEP]", "want", "##want", "##ed", "wa", "un", "runn", 31 | "##ing", "," 32 | ] 33 | with tempfile.NamedTemporaryFile(delete=False) as vocab_writer: 34 | vocab_writer.write("".join([x + "\n" for x in vocab_tokens])) 35 | 36 | vocab_file = vocab_writer.name 37 | 38 | tokenizer = tokenization.FullTokenizer(vocab_file) 39 | os.unlink(vocab_file) 40 | 41 | tokens = tokenizer.tokenize(u"UNwant\u00E9d,running") 42 | self.assertAllEqual(tokens, ["un", "##want", "##ed", ",", "runn", "##ing"]) 43 | 44 | self.assertAllEqual( 45 | tokenizer.convert_tokens_to_ids(tokens), [7, 4, 5, 10, 8, 9]) 46 | 47 | def test_chinese(self): 48 | tokenizer = tokenization.BasicTokenizer() 49 | 50 | self.assertAllEqual( 51 | tokenizer.tokenize(u"ah\u535A\u63A8zz"), 52 | [u"ah", u"\u535A", u"\u63A8", u"zz"]) 53 | 54 | def test_basic_tokenizer_lower(self): 55 | tokenizer = tokenization.BasicTokenizer(do_lower_case=True) 56 | 57 | self.assertAllEqual( 58 | tokenizer.tokenize(u" \tHeLLo!how \n Are yoU? "), 59 | ["hello", "!", "how", "are", "you", "?"]) 60 | self.assertAllEqual(tokenizer.tokenize(u"H\u00E9llo"), ["hello"]) 61 | 62 | def test_basic_tokenizer_no_lower(self): 63 | tokenizer = tokenization.BasicTokenizer(do_lower_case=False) 64 | 65 | self.assertAllEqual( 66 | tokenizer.tokenize(u" \tHeLLo!how \n Are yoU? 
"), 67 | ["HeLLo", "!", "how", "Are", "yoU", "?"]) 68 | 69 | def test_wordpiece_tokenizer(self): 70 | vocab_tokens = [ 71 | "[UNK]", "[CLS]", "[SEP]", "want", "##want", "##ed", "wa", "un", "runn", 72 | "##ing" 73 | ] 74 | 75 | vocab = {} 76 | for (i, token) in enumerate(vocab_tokens): 77 | vocab[token] = i 78 | tokenizer = tokenization.WordpieceTokenizer(vocab=vocab) 79 | 80 | self.assertAllEqual(tokenizer.tokenize(""), []) 81 | 82 | self.assertAllEqual( 83 | tokenizer.tokenize("unwanted running"), 84 | ["un", "##want", "##ed", "runn", "##ing"]) 85 | 86 | self.assertAllEqual( 87 | tokenizer.tokenize("unwantedX running"), ["[UNK]", "runn", "##ing"]) 88 | 89 | def test_convert_tokens_to_ids(self): 90 | vocab_tokens = [ 91 | "[UNK]", "[CLS]", "[SEP]", "want", "##want", "##ed", "wa", "un", "runn", 92 | "##ing" 93 | ] 94 | 95 | vocab = {} 96 | for (i, token) in enumerate(vocab_tokens): 97 | vocab[token] = i 98 | 99 | self.assertAllEqual( 100 | tokenization.convert_tokens_to_ids( 101 | vocab, ["un", "##want", "##ed", "runn", "##ing"]), [7, 4, 5, 8, 9]) 102 | 103 | def test_is_whitespace(self): 104 | self.assertTrue(tokenization._is_whitespace(u" ")) 105 | self.assertTrue(tokenization._is_whitespace(u"\t")) 106 | self.assertTrue(tokenization._is_whitespace(u"\r")) 107 | self.assertTrue(tokenization._is_whitespace(u"\n")) 108 | self.assertTrue(tokenization._is_whitespace(u"\u00A0")) 109 | 110 | self.assertFalse(tokenization._is_whitespace(u"A")) 111 | self.assertFalse(tokenization._is_whitespace(u"-")) 112 | 113 | def test_is_control(self): 114 | self.assertTrue(tokenization._is_control(u"\u0005")) 115 | 116 | self.assertFalse(tokenization._is_control(u"A")) 117 | self.assertFalse(tokenization._is_control(u" ")) 118 | self.assertFalse(tokenization._is_control(u"\t")) 119 | self.assertFalse(tokenization._is_control(u"\r")) 120 | 121 | def test_is_punctuation(self): 122 | self.assertTrue(tokenization._is_punctuation(u"-")) 123 | self.assertTrue(tokenization._is_punctuation(u"$")) 124 | self.assertTrue(tokenization._is_punctuation(u"`")) 125 | self.assertTrue(tokenization._is_punctuation(u".")) 126 | 127 | self.assertFalse(tokenization._is_punctuation(u"A")) 128 | self.assertFalse(tokenization._is_punctuation(u" ")) 129 | 130 | 131 | if __name__ == "__main__": 132 | tf.test.main() 133 | -------------------------------------------------------------------------------- /conlleval.pl: -------------------------------------------------------------------------------- 1 | #!/usr/bin/perl -w 2 | # conlleval: evaluate result of processing CoNLL-2000 shared task 3 | # usage: conlleval [-l] [-r] [-d delimiterTag] [-o oTag] < file 4 | # README: http://cnts.uia.ac.be/conll2000/chunking/output.html 5 | # options: l: generate LaTeX output for tables like in 6 | # http://cnts.uia.ac.be/conll2003/ner/example.tex 7 | # r: accept raw result tags (without B- and I- prefix; 8 | # assumes one word per chunk) 9 | # d: alternative delimiter tag (default is single space) 10 | # o: alternative outside tag (default is O) 11 | # note: the file should contain lines with items separated 12 | # by $delimiter characters (default space). The final 13 | # two items should contain the correct tag and the 14 | # guessed tag in that order. Sentences should be 15 | # separated from each other by empty lines or lines 16 | # with $boundary fields (default -X-). 
17 | # url: http://lcg-www.uia.ac.be/conll2000/chunking/ 18 | # started: 1998-09-25 19 | # version: 2004-01-26 20 | # author: Erik Tjong Kim Sang 21 | 22 | use strict; 23 | 24 | my $false = 0; 25 | my $true = 42; 26 | 27 | my $boundary = "-X-"; # sentence boundary 28 | my $correct; # current corpus chunk tag (I,O,B) 29 | my $correctChunk = 0; # number of correctly identified chunks 30 | my $correctTags = 0; # number of correct chunk tags 31 | my $correctType; # type of current corpus chunk tag (NP,VP,etc.) 32 | my $delimiter = " "; # field delimiter 33 | my $FB1 = 0.0; # FB1 score (Van Rijsbergen 1979) 34 | my $firstItem; # first feature (for sentence boundary checks) 35 | my $foundCorrect = 0; # number of chunks in corpus 36 | my $foundGuessed = 0; # number of identified chunks 37 | my $guessed; # current guessed chunk tag 38 | my $guessedType; # type of current guessed chunk tag 39 | my $i; # miscellaneous counter 40 | my $inCorrect = $false; # currently processed chunk is correct until now 41 | my $lastCorrect = "O"; # previous chunk tag in corpus 42 | my $latex = 0; # generate LaTeX formatted output 43 | my $lastCorrectType = ""; # type of previously identified chunk tag 44 | my $lastGuessed = "O"; # previously identified chunk tag 45 | my $lastGuessedType = ""; # type of previous chunk tag in corpus 46 | my $lastType; # temporary storage for detecting duplicates 47 | my $line; # line 48 | my $nbrOfFeatures = -1; # number of features per line 49 | my $precision = 0.0; # precision score 50 | my $oTag = "O"; # outside tag, default O 51 | my $raw = 0; # raw input: add B to every token 52 | my $recall = 0.0; # recall score 53 | my $tokenCounter = 0; # token counter (ignores sentence breaks) 54 | 55 | my %correctChunk = (); # number of correctly identified chunks per type 56 | my %foundCorrect = (); # number of chunks in corpus per type 57 | my %foundGuessed = (); # number of identified chunks per type 58 | 59 | my @features; # features on line 60 | my @sortedTypes; # sorted list of chunk type names 61 | 62 | # sanity check 63 | while (@ARGV and $ARGV[0] =~ /^-/) { 64 | if ($ARGV[0] eq "-l") { $latex = 1; shift(@ARGV); } 65 | elsif ($ARGV[0] eq "-r") { $raw = 1; shift(@ARGV); } 66 | elsif ($ARGV[0] eq "-d") { 67 | shift(@ARGV); 68 | if (not defined $ARGV[0]) { 69 | die "conlleval: -d requires delimiter character"; 70 | } 71 | $delimiter = shift(@ARGV); 72 | } elsif ($ARGV[0] eq "-o") { 73 | shift(@ARGV); 74 | if (not defined $ARGV[0]) { 75 | die "conlleval: -o requires delimiter character"; 76 | } 77 | $oTag = shift(@ARGV); 78 | } else { die "conlleval: unknown argument $ARGV[0]\n"; } 79 | } 80 | if (@ARGV) { die "conlleval: unexpected command line argument\n"; } 81 | # process input 82 | while () { 83 | chomp($line = $_); 84 | @features = split(/$delimiter/,$line); 85 | # @features = split(/\t/,$line); 86 | if ($nbrOfFeatures < 0) { $nbrOfFeatures = $#features; } 87 | elsif ($nbrOfFeatures != $#features and @features != 0) { 88 | printf STDERR "unexpected number of features: %d (%d)\n", 89 | $#features+1,$nbrOfFeatures+1; 90 | exit(1); 91 | } 92 | if (@features == 0 or 93 | $features[0] eq $boundary) { @features = ($boundary,"O","O"); } 94 | if (@features < 2) { 95 | printf STDERR "feature length is %d. 
\n", @features; 96 | die "conlleval: unexpected number of features in line $line\n"; 97 | } 98 | if ($raw) { 99 | if ($features[$#features] eq $oTag) { $features[$#features] = "O"; } 100 | if ($features[$#features-1] eq $oTag) { $features[$#features-1] = "O"; } 101 | if ($features[$#features] ne "O") { 102 | $features[$#features] = "B-$features[$#features]"; 103 | } 104 | if ($features[$#features-1] ne "O") { 105 | $features[$#features-1] = "B-$features[$#features-1]"; 106 | } 107 | } 108 | # 20040126 ET code which allows hyphens in the types 109 | if ($features[$#features] =~ /^([^-]*)-(.*)$/) { 110 | $guessed = $1; 111 | $guessedType = $2; 112 | } else { 113 | $guessed = $features[$#features]; 114 | $guessedType = ""; 115 | } 116 | pop(@features); 117 | if ($features[$#features] =~ /^([^-]*)-(.*)$/) { 118 | $correct = $1; 119 | $correctType = $2; 120 | } else { 121 | $correct = $features[$#features]; 122 | $correctType = ""; 123 | } 124 | pop(@features); 125 | # ($guessed,$guessedType) = split(/-/,pop(@features)); 126 | # ($correct,$correctType) = split(/-/,pop(@features)); 127 | $guessedType = $guessedType ? $guessedType : ""; 128 | $correctType = $correctType ? $correctType : ""; 129 | $firstItem = shift(@features); 130 | 131 | # 1999-06-26 sentence breaks should always be counted as out of chunk 132 | if ( $firstItem eq $boundary ) { $guessed = "O"; } 133 | 134 | if ($inCorrect) { 135 | if ( &endOfChunk($lastCorrect,$correct,$lastCorrectType,$correctType) and 136 | &endOfChunk($lastGuessed,$guessed,$lastGuessedType,$guessedType) and 137 | $lastGuessedType eq $lastCorrectType) { 138 | $inCorrect=$false; 139 | $correctChunk++; 140 | $correctChunk{$lastCorrectType} = $correctChunk{$lastCorrectType} ? 141 | $correctChunk{$lastCorrectType}+1 : 1; 142 | } elsif ( 143 | &endOfChunk($lastCorrect,$correct,$lastCorrectType,$correctType) != 144 | &endOfChunk($lastGuessed,$guessed,$lastGuessedType,$guessedType) or 145 | $guessedType ne $correctType ) { 146 | $inCorrect=$false; 147 | } 148 | } 149 | 150 | if ( &startOfChunk($lastCorrect,$correct,$lastCorrectType,$correctType) and 151 | &startOfChunk($lastGuessed,$guessed,$lastGuessedType,$guessedType) and 152 | $guessedType eq $correctType) { $inCorrect = $true; } 153 | 154 | if ( &startOfChunk($lastCorrect,$correct,$lastCorrectType,$correctType) ) { 155 | $foundCorrect++; 156 | $foundCorrect{$correctType} = $foundCorrect{$correctType} ? 157 | $foundCorrect{$correctType}+1 : 1; 158 | } 159 | if ( &startOfChunk($lastGuessed,$guessed,$lastGuessedType,$guessedType) ) { 160 | $foundGuessed++; 161 | $foundGuessed{$guessedType} = $foundGuessed{$guessedType} ? 162 | $foundGuessed{$guessedType}+1 : 1; 163 | } 164 | if ( $firstItem ne $boundary ) { 165 | if ( $correct eq $guessed and $guessedType eq $correctType ) { 166 | $correctTags++; 167 | } 168 | $tokenCounter++; 169 | } 170 | 171 | $lastGuessed = $guessed; 172 | $lastCorrect = $correct; 173 | $lastGuessedType = $guessedType; 174 | $lastCorrectType = $correctType; 175 | } 176 | if ($inCorrect) { 177 | $correctChunk++; 178 | $correctChunk{$lastCorrectType} = $correctChunk{$lastCorrectType} ? 
179 | $correctChunk{$lastCorrectType}+1 : 1; 180 | } 181 | 182 | if (not $latex) { 183 | # compute overall precision, recall and FB1 (default values are 0.0) 184 | $precision = 100*$correctChunk/$foundGuessed if ($foundGuessed > 0); 185 | $recall = 100*$correctChunk/$foundCorrect if ($foundCorrect > 0); 186 | $FB1 = 2*$precision*$recall/($precision+$recall) 187 | if ($precision+$recall > 0); 188 | 189 | # print overall performance 190 | printf "processed $tokenCounter tokens with $foundCorrect phrases; "; 191 | printf "found: $foundGuessed phrases; correct: $correctChunk.\n"; 192 | if ($tokenCounter>0) { 193 | printf "accuracy: %6.2f%%; ",100*$correctTags/$tokenCounter; 194 | printf "precision: %6.2f%%; ",$precision; 195 | printf "recall: %6.2f%%; ",$recall; 196 | printf "FB1: %6.2f\n",$FB1; 197 | } 198 | } 199 | 200 | # sort chunk type names 201 | undef($lastType); 202 | @sortedTypes = (); 203 | foreach $i (sort (keys %foundCorrect,keys %foundGuessed)) { 204 | if (not($lastType) or $lastType ne $i) { 205 | push(@sortedTypes,($i)); 206 | } 207 | $lastType = $i; 208 | } 209 | # print performance per chunk type 210 | if (not $latex) { 211 | for $i (@sortedTypes) { 212 | $correctChunk{$i} = $correctChunk{$i} ? $correctChunk{$i} : 0; 213 | if (not($foundGuessed{$i})) { $foundGuessed{$i} = 0; $precision = 0.0; } 214 | else { $precision = 100*$correctChunk{$i}/$foundGuessed{$i}; } 215 | if (not($foundCorrect{$i})) { $recall = 0.0; } 216 | else { $recall = 100*$correctChunk{$i}/$foundCorrect{$i}; } 217 | if ($precision+$recall == 0.0) { $FB1 = 0.0; } 218 | else { $FB1 = 2*$precision*$recall/($precision+$recall); } 219 | printf "%17s: ",$i; 220 | printf "precision: %6.2f%%; ",$precision; 221 | printf "recall: %6.2f%%; ",$recall; 222 | printf "FB1: %6.2f %d\n",$FB1,$foundGuessed{$i}; 223 | } 224 | } else { 225 | print " & Precision & Recall & F\$_{\\beta=1} \\\\\\hline"; 226 | for $i (@sortedTypes) { 227 | $correctChunk{$i} = $correctChunk{$i} ? 
$correctChunk{$i} : 0; 228 | if (not($foundGuessed{$i})) { $precision = 0.0; } 229 | else { $precision = 100*$correctChunk{$i}/$foundGuessed{$i}; } 230 | if (not($foundCorrect{$i})) { $recall = 0.0; } 231 | else { $recall = 100*$correctChunk{$i}/$foundCorrect{$i}; } 232 | if ($precision+$recall == 0.0) { $FB1 = 0.0; } 233 | else { $FB1 = 2*$precision*$recall/($precision+$recall); } 234 | printf "\n%-7s & %6.2f\\%% & %6.2f\\%% & %6.2f \\\\", 235 | $i,$precision,$recall,$FB1; 236 | } 237 | print "\\hline\n"; 238 | $precision = 0.0; 239 | $recall = 0; 240 | $FB1 = 0.0; 241 | $precision = 100*$correctChunk/$foundGuessed if ($foundGuessed > 0); 242 | $recall = 100*$correctChunk/$foundCorrect if ($foundCorrect > 0); 243 | $FB1 = 2*$precision*$recall/($precision+$recall) 244 | if ($precision+$recall > 0); 245 | printf "Overall & %6.2f\\%% & %6.2f\\%% & %6.2f \\\\\\hline\n", 246 | $precision,$recall,$FB1; 247 | } 248 | 249 | exit 0; 250 | 251 | # endOfChunk: checks if a chunk ended between the previous and current word 252 | # arguments: previous and current chunk tags, previous and current types 253 | # note: this code is capable of handling other chunk representations 254 | # than the default CoNLL-2000 ones, see EACL'99 paper of Tjong 255 | # Kim Sang and Veenstra http://xxx.lanl.gov/abs/cs.CL/9907006 256 | 257 | sub endOfChunk { 258 | my $prevTag = shift(@_); 259 | my $tag = shift(@_); 260 | my $prevType = shift(@_); 261 | my $type = shift(@_); 262 | my $chunkEnd = $false; 263 | 264 | if ( $prevTag eq "B" and $tag eq "B" ) { $chunkEnd = $true; } 265 | if ( $prevTag eq "B" and $tag eq "O" ) { $chunkEnd = $true; } 266 | if ( $prevTag eq "I" and $tag eq "B" ) { $chunkEnd = $true; } 267 | if ( $prevTag eq "I" and $tag eq "O" ) { $chunkEnd = $true; } 268 | 269 | if ( $prevTag eq "E" and $tag eq "E" ) { $chunkEnd = $true; } 270 | if ( $prevTag eq "E" and $tag eq "I" ) { $chunkEnd = $true; } 271 | if ( $prevTag eq "E" and $tag eq "O" ) { $chunkEnd = $true; } 272 | if ( $prevTag eq "I" and $tag eq "O" ) { $chunkEnd = $true; } 273 | 274 | if ($prevTag ne "O" and $prevTag ne "." 
and $prevType ne $type) { 275 | $chunkEnd = $true; 276 | } 277 | 278 | # corrected 1998-12-22: these chunks are assumed to have length 1 279 | if ( $prevTag eq "]" ) { $chunkEnd = $true; } 280 | if ( $prevTag eq "[" ) { $chunkEnd = $true; } 281 | 282 | return($chunkEnd); 283 | } 284 | 285 | # startOfChunk: checks if a chunk started between the previous and current word 286 | # arguments: previous and current chunk tags, previous and current types 287 | # note: this code is capable of handling other chunk representations 288 | # than the default CoNLL-2000 ones, see EACL'99 paper of Tjong 289 | # Kim Sang and Veenstra http://xxx.lanl.gov/abs/cs.CL/9907006 290 | 291 | sub startOfChunk { 292 | my $prevTag = shift(@_); 293 | my $tag = shift(@_); 294 | my $prevType = shift(@_); 295 | my $type = shift(@_); 296 | my $chunkStart = $false; 297 | 298 | if ( $prevTag eq "B" and $tag eq "B" ) { $chunkStart = $true; } 299 | if ( $prevTag eq "I" and $tag eq "B" ) { $chunkStart = $true; } 300 | if ( $prevTag eq "O" and $tag eq "B" ) { $chunkStart = $true; } 301 | if ( $prevTag eq "O" and $tag eq "I" ) { $chunkStart = $true; } 302 | 303 | if ( $prevTag eq "E" and $tag eq "E" ) { $chunkStart = $true; } 304 | if ( $prevTag eq "E" and $tag eq "I" ) { $chunkStart = $true; } 305 | if ( $prevTag eq "O" and $tag eq "E" ) { $chunkStart = $true; } 306 | if ( $prevTag eq "O" and $tag eq "I" ) { $chunkStart = $true; } 307 | 308 | if ($tag ne "O" and $tag ne "." and $prevType ne $type) { 309 | $chunkStart = $true; 310 | } 311 | 312 | # corrected 1998-12-22: these chunks are assumed to have length 1 313 | if ( $tag eq "[" ) { $chunkStart = $true; } 314 | if ( $tag eq "]" ) { $chunkStart = $true; } 315 | 316 | return($chunkStart); 317 | } -------------------------------------------------------------------------------- /conlleval.py: -------------------------------------------------------------------------------- 1 | # Python version of the evaluation script from CoNLL'00- 2 | # Originates from: https://github.com/spyysalo/conlleval.py 3 | 4 | 5 | # Intentional differences: 6 | # - accept any space as delimiter by default 7 | # - optional file argument (default STDIN) 8 | # - option to set boundary (-b argument) 9 | # - LaTeX output (-l argument) not supported 10 | # - raw tags (-r argument) not supported 11 | 12 | # add function :evaluate(predicted_label, ori_label): which will not read from file ////ner评估 13 | 14 | import sys 15 | import re 16 | import codecs 17 | from collections import defaultdict, namedtuple 18 | 19 | ANY_SPACE = '' 20 | 21 | 22 | class FormatError(Exception): 23 | pass 24 | 25 | Metrics = namedtuple('Metrics', 'tp fp fn prec rec fscore') 26 | 27 | 28 | class EvalCounts(object): 29 | def __init__(self): 30 | self.correct_chunk = 0 # number of correctly identified chunks 31 | self.correct_tags = 0 # number of correct chunk tags 32 | self.found_correct = 0 # number of chunks in corpus 33 | self.found_guessed = 0 # number of identified chunks 34 | self.token_counter = 0 # token counter (ignores sentence breaks) 35 | 36 | # counts by type 37 | self.t_correct_chunk = defaultdict(int) 38 | self.t_found_correct = defaultdict(int) 39 | self.t_found_guessed = defaultdict(int) 40 | 41 | 42 | def parse_args(argv): 43 | import argparse 44 | parser = argparse.ArgumentParser( 45 | description='evaluate tagging results using CoNLL criteria', 46 | formatter_class=argparse.ArgumentDefaultsHelpFormatter 47 | ) 48 | arg = parser.add_argument 49 | arg('-b', '--boundary', metavar='STR', default='-X-', 50 | 
help='sentence boundary') 51 | arg('-d', '--delimiter', metavar='CHAR', default=ANY_SPACE, 52 | help='character delimiting items in input') 53 | arg('-o', '--otag', metavar='CHAR', default='O', 54 | help='alternative outside tag') 55 | arg('file', nargs='?', default=None) 56 | return parser.parse_args(argv) 57 | 58 | 59 | def parse_tag(t): 60 | m = re.match(r'^([^-]*)-(.*)$', t) 61 | return m.groups() if m else (t, '') 62 | 63 | 64 | def evaluate(iterable, options=None): 65 | if options is None: 66 | options = parse_args([]) # use defaults 67 | 68 | counts = EvalCounts() 69 | num_features = None # number of features per line 70 | in_correct = False # currently processed chunks is correct until now 71 | last_correct = 'O' # previous chunk tag in corpus 72 | last_correct_type = '' # type of previously identified chunk tag 73 | last_guessed = 'O' # previously identified chunk tag 74 | last_guessed_type = '' # type of previous chunk tag in corpus 75 | 76 | for line in iterable: 77 | line = line.rstrip('\r\n') 78 | 79 | if options.delimiter == ANY_SPACE: 80 | features = line.split() 81 | else: 82 | features = line.split(options.delimiter) 83 | 84 | if num_features is None: 85 | num_features = len(features) 86 | elif num_features != len(features) and len(features) != 0: 87 | raise FormatError('unexpected number of features: %d (%d)' % 88 | (len(features), num_features)) 89 | 90 | if len(features) == 0 or features[0] == options.boundary: 91 | features = [options.boundary, 'O', 'O'] 92 | if len(features) < 3: 93 | raise FormatError('unexpected number of features in line %s' % line) 94 | 95 | guessed, guessed_type = parse_tag(features.pop()) 96 | correct, correct_type = parse_tag(features.pop()) 97 | first_item = features.pop(0) 98 | 99 | if first_item == options.boundary: 100 | guessed = 'O' 101 | 102 | end_correct = end_of_chunk(last_correct, correct, 103 | last_correct_type, correct_type) 104 | end_guessed = end_of_chunk(last_guessed, guessed, 105 | last_guessed_type, guessed_type) 106 | start_correct = start_of_chunk(last_correct, correct, 107 | last_correct_type, correct_type) 108 | start_guessed = start_of_chunk(last_guessed, guessed, 109 | last_guessed_type, guessed_type) 110 | 111 | if in_correct: 112 | if (end_correct and end_guessed and 113 | last_guessed_type == last_correct_type): 114 | in_correct = False 115 | counts.correct_chunk += 1 116 | counts.t_correct_chunk[last_correct_type] += 1 117 | elif (end_correct != end_guessed or guessed_type != correct_type): 118 | in_correct = False 119 | 120 | if start_correct and start_guessed and guessed_type == correct_type: 121 | in_correct = True 122 | 123 | if start_correct: 124 | counts.found_correct += 1 125 | counts.t_found_correct[correct_type] += 1 126 | if start_guessed: 127 | counts.found_guessed += 1 128 | counts.t_found_guessed[guessed_type] += 1 129 | if first_item != options.boundary: 130 | if correct == guessed and guessed_type == correct_type: 131 | counts.correct_tags += 1 132 | counts.token_counter += 1 133 | 134 | last_guessed = guessed 135 | last_correct = correct 136 | last_guessed_type = guessed_type 137 | last_correct_type = correct_type 138 | 139 | if in_correct: 140 | counts.correct_chunk += 1 141 | counts.t_correct_chunk[last_correct_type] += 1 142 | 143 | return counts 144 | 145 | 146 | 147 | def uniq(iterable): 148 | seen = set() 149 | return [i for i in iterable if not (i in seen or seen.add(i))] 150 | 151 | 152 | def calculate_metrics(correct, guessed, total): 153 | tp, fp, fn = correct, guessed-correct, total-correct 
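    # precision = tp / (tp + fp), recall = tp / (tp + fn), F1 = 2*p*r / (p + r),
    # with each ratio defined as 0 when its denominator is 0 (as below).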
154 | p = 0 if tp + fp == 0 else 1.*tp / (tp + fp) 155 | r = 0 if tp + fn == 0 else 1.*tp / (tp + fn) 156 | f = 0 if p + r == 0 else 2 * p * r / (p + r) 157 | return Metrics(tp, fp, fn, p, r, f) 158 | 159 | 160 | def metrics(counts): 161 | c = counts 162 | overall = calculate_metrics( 163 | c.correct_chunk, c.found_guessed, c.found_correct 164 | ) 165 | by_type = {} 166 | for t in uniq(list(c.t_found_correct) + list(c.t_found_guessed)): 167 | by_type[t] = calculate_metrics( 168 | c.t_correct_chunk[t], c.t_found_guessed[t], c.t_found_correct[t] 169 | ) 170 | return overall, by_type 171 | 172 | 173 | def report(counts, out=None): 174 | if out is None: 175 | out = sys.stdout 176 | 177 | overall, by_type = metrics(counts) 178 | 179 | c = counts 180 | out.write('processed %d tokens with %d phrases; ' % 181 | (c.token_counter, c.found_correct)) 182 | out.write('found: %d phrases; correct: %d.\n' % 183 | (c.found_guessed, c.correct_chunk)) 184 | 185 | if c.token_counter > 0: 186 | out.write('accuracy: %6.2f%%; ' % 187 | (100.*c.correct_tags/c.token_counter)) 188 | out.write('precision: %6.2f%%; ' % (100.*overall.prec)) 189 | out.write('recall: %6.2f%%; ' % (100.*overall.rec)) 190 | out.write('FB1: %6.2f\n' % (100.*overall.fscore)) 191 | 192 | for i, m in sorted(by_type.items()): 193 | out.write('%17s: ' % i) 194 | out.write('precision: %6.2f%%; ' % (100.*m.prec)) 195 | out.write('recall: %6.2f%%; ' % (100.*m.rec)) 196 | out.write('FB1: %6.2f %d\n' % (100.*m.fscore, c.t_found_guessed[i])) 197 | 198 | 199 | def report_notprint(counts, out=None): 200 | if out is None: 201 | out = sys.stdout 202 | 203 | overall, by_type = metrics(counts) 204 | 205 | c = counts 206 | final_report = [] 207 | line = [] 208 | line.append('processed %d tokens with %d phrases; ' % 209 | (c.token_counter, c.found_correct)) 210 | line.append('found: %d phrases; correct: %d.\n' % 211 | (c.found_guessed, c.correct_chunk)) 212 | final_report.append("".join(line)) 213 | 214 | if c.token_counter > 0: 215 | line = [] 216 | line.append('accuracy: %6.2f%%; ' % 217 | (100.*c.correct_tags/c.token_counter)) 218 | line.append('precision: %6.2f%%; ' % (100.*overall.prec)) 219 | line.append('recall: %6.2f%%; ' % (100.*overall.rec)) 220 | line.append('FB1: %6.2f\n' % (100.*overall.fscore)) 221 | final_report.append("".join(line)) 222 | 223 | for i, m in sorted(by_type.items()): 224 | line = [] 225 | line.append('%17s: ' % i) 226 | line.append('precision: %6.2f%%; ' % (100.*m.prec)) 227 | line.append('recall: %6.2f%%; ' % (100.*m.rec)) 228 | line.append('FB1: %6.2f %d\n' % (100.*m.fscore, c.t_found_guessed[i])) 229 | final_report.append("".join(line)) 230 | return final_report 231 | 232 | 233 | def end_of_chunk(prev_tag, tag, prev_type, type_): 234 | # check if a chunk ended between the previous and current word 235 | # arguments: previous and current chunk tags, previous and current types 236 | chunk_end = False 237 | 238 | if prev_tag == 'E': chunk_end = True 239 | if prev_tag == 'S': chunk_end = True 240 | 241 | if prev_tag == 'B' and tag == 'B': chunk_end = True 242 | if prev_tag == 'B' and tag == 'S': chunk_end = True 243 | if prev_tag == 'B' and tag == 'O': chunk_end = True 244 | if prev_tag == 'I' and tag == 'B': chunk_end = True 245 | if prev_tag == 'I' and tag == 'S': chunk_end = True 246 | if prev_tag == 'I' and tag == 'O': chunk_end = True 247 | 248 | if prev_tag != 'O' and prev_tag != '.' 
and prev_type != type_:
249 |         chunk_end = True
250 |
251 |     # these chunks are assumed to have length 1
252 |     if prev_tag == ']': chunk_end = True
253 |     if prev_tag == '[': chunk_end = True
254 |
255 |     return chunk_end
256 |
257 |
258 | def start_of_chunk(prev_tag, tag, prev_type, type_):
259 |     # check if a chunk started between the previous and current word
260 |     # arguments: previous and current chunk tags, previous and current types
261 |     chunk_start = False
262 |
263 |     if tag == 'B': chunk_start = True
264 |     if tag == 'S': chunk_start = True
265 |
266 |     if prev_tag == 'E' and tag == 'E': chunk_start = True
267 |     if prev_tag == 'E' and tag == 'I': chunk_start = True
268 |     if prev_tag == 'S' and tag == 'E': chunk_start = True
269 |     if prev_tag == 'S' and tag == 'I': chunk_start = True
270 |     if prev_tag == 'O' and tag == 'E': chunk_start = True
271 |     if prev_tag == 'O' and tag == 'I': chunk_start = True
272 |
273 |     if tag != 'O' and tag != '.' and prev_type != type_:
274 |         chunk_start = True
275 |
276 |     # these chunks are assumed to have length 1
277 |     if tag == '[': chunk_start = True
278 |     if tag == ']': chunk_start = True
279 |
280 |     return chunk_start
281 |
282 |
283 | def return_report(input_file):
284 |     with codecs.open(input_file, "r", "utf8") as f:
285 |         counts = evaluate(f)
286 |     return report_notprint(counts)
287 |
288 |
289 | def main(argv):
290 |     args = parse_args(argv[1:])
291 |
292 |     if args.file is None:
293 |         counts = evaluate(sys.stdin, args)
294 |     else:
295 |         with open(args.file) as f:
296 |             counts = evaluate(f, args)
297 |     report(counts)
298 |
299 | if __name__ == '__main__':
300 |     sys.exit(main(sys.argv))
--------------------------------------------------------------------------------
/fujc.py:
--------------------------------------------------------------------------------
1 | #coding=utf-8
2 | import os
3 | from subprocess import PIPE, Popen
4 |
5 | # 这几行是获取 子程序 可执行文件的位置,没有啥意义
6 | cur_file = os.path.realpath(__file__)
7 | cur_path = os.path.dirname(cur_file)
8 | exec_file = cur_path + '/' + 'subpcs.py'
9 | # 假设目的程序是 B.py
10 | # 注意:代码在 Unix 类操作系统上正常,对于 Windows 不知道,可能结果错误
11 |
12 | # 下面这行是创建子进程,在终端里就是 python B.py。 输入输出是管道。
13 | p = Popen("python " + exec_file, shell=True, stdout=PIPE, stdin=PIPE)
14 | print("%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%")
15 | # 假设下面是程序里用户交互部分:
16 | while True:
17 |     some_str = "可能是用户输的字符串,作为标准输入。\n"
18 |     # 注意结尾一定要有换行符,因为标准输入是通过换行符结束的
19 |     # 因此弊端就是 如果输入中就有换行符,需要自己处理,比如
20 |     # 替换为 ||换行||,再在子程序里处理
21 |
22 |     # 转 utf8
23 |     p.stdin.write(bytes(some_str, encoding='utf8'))
24 |     p.stdin.flush()
25 |     # 至此,some_str 已经被传输到了 B 的标准输入流
26 |     # 参照下面 B.py 看
27 |     out_bytes = p.stdout.readline()
28 |     # 至此,B 的标准输出流的一行 被读到 out_bytes
29 |
30 |     print(str(out_bytes, encoding='utf8'))
31 |
32 |     import time
33 |
34 |     time.sleep(1)
35 |
--------------------------------------------------------------------------------
/global_config.py:
--------------------------------------------------------------------------------
# coding:utf-8
import os
import logging
from logging import handlers

os.environ['NLS_LANG'] = 'SIMPLIFIED CHINESE_CHINA.UTF8'


class Logger(object):
    # https://www.cnblogs.com/nancyzhu/p/8551506.html
    # 日志级别关系映射
    level_relations = {
        'debug': logging.DEBUG,
        'info': logging.INFO,
        'warning': logging.WARNING,
        'error': logging.ERROR,
        'crit': logging.CRITICAL
    }

    def __init__(self, filename, level='info', when='D', backCount=3,
                 fmt='%(asctime)s|%(filename)s[line:%(lineno)d]|%(funcName)s|%(levelname)s|%(message)s'):
        self.logger = logging.getLogger(filename)
        # 设置日志格式
        format_str = logging.Formatter(fmt)
        # 设置日志级别
        self.logger.setLevel(self.level_relations.get(level))
        # 往屏幕上输出
        sh = logging.StreamHandler()
        # 设置屏幕上显示的格式
        sh.setFormatter(format_str)
        # 往文件里写入#指定间隔时间自动生成文件的处理器
        th = handlers.TimedRotatingFileHandler(filename=filename, when=when, backupCount=backCount, encoding='utf-8')
        # 实例化TimedRotatingFileHandler
        # interval是时间间隔,backupCount是备份文件的个数,如果超过这个个数,就会自动删除,when是间隔的时间单位,单位有以下几种:
        # S 秒
        # M 分
        # H 小时、
        # D 天、
        # W 每星期(interval==0时代表星期一)
        # midnight 每天凌晨
        # 设置文件里写入的格式
        th.setFormatter(format_str)
        # 把对象加到logger里
        self.logger.addHandler(sh)
        self.logger.addHandler(th)


#loginfo = Logger("recommend_articles.log", "info")
if __name__ == '__main__':
    log = Logger('all.log',level='debug')
    log.logger.debug('debug')
    log.logger.info('info')
    log.logger.warning('警告')
    log.logger.error('报错')
    log.logger.critical('严重')
    Logger('error.log', level='error').logger.error('error')
--------------------------------------------------------------------------------
/his.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/python
2 |
3 | import os
4 | import sys
5 | import subprocess
6 |
7 |
8 | # call mib2c
9 | def callmib():
10 |     exestr = 'python zijc.py'
11 |     p = subprocess.Popen(exestr, stdin=subprocess.PIPE, stdout=subprocess.PIPE, stderr=subprocess.PIPE, shell=True)
12 |
13 |     while True:
14 |         line = p.stdout.readline()
15 |         print(line)
16 |
17 |         if not line:
18 |             break
19 |         if line.startswith('this is opt1'):
20 |             opt = 'aaa'
21 |         elif line.startswith('this is opt2'):
22 |             opt = 'bbb'
23 |
24 |         p.stdin.write(opt + os.linesep)
25 |
26 |     p.wait()
27 |     errout = p.stderr.read()
28 |     p.stdout.close()
29 |     p.stderr.close()
30 |     p.stdin.close()
31 |     print()
32 |
33 |
34 | # main function
35 | # Step 1.
make C files first time 36 | callmib() -------------------------------------------------------------------------------- /image/KB.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AI-YULU/KBQA-BERT/7150d7ef00f2ec0a28ffb11fd7d6b30c84a28eb9/image/KB.png -------------------------------------------------------------------------------- /image/NER.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AI-YULU/KBQA-BERT/7150d7ef00f2ec0a28ffb11fd7d6b30c84a28eb9/image/NER.jpg -------------------------------------------------------------------------------- /kbqa_test.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | import io 4 | import re 5 | import time 6 | import jieba 7 | import numpy as np 8 | import pandas as pd 9 | import urllib.request 10 | import urllib.parse 11 | import tensorflow as tf 12 | from Data.load_dbdata import upload_data 13 | from global_config import Logger 14 | 15 | from run_similarity import BertSim 16 | # 模块导入 https://blog.csdn.net/xiongchengluo1129/article/details/80453599 17 | 18 | loginfo = Logger("recommend_articles.log", "info") 19 | file = "./Data/NER_Data/q_t_a_df_testing.csv" 20 | 21 | bs = BertSim() 22 | bs.set_mode(tf.estimator.ModeKeys.PREDICT) 23 | 24 | 25 | def dataset_test(): 26 | ''' 27 | 用训练问答对中的实体+属性,去知识库中进行问答测试准确率上限 28 | :return: 29 | ''' 30 | with open(file) as f: 31 | total = 0 32 | recall = 0 33 | correct = 0 34 | 35 | for line in f: 36 | question, entity, attribute, answer, ner = line.split("\t") 37 | ner = ner.replace("#", "").replace("[UNK]", "%") 38 | # case1: entity and attribute Exact Match 39 | sql_e1_a1 = "select * from nlpccQA where entity='"+entity+"' and attribute='"+attribute+"' limit 10" 40 | result_e1_a1 = upload_data(sql_e1_a1) 41 | 42 | # case2: entity Fuzzy Match and attribute Exact Match 43 | sql_e0_a1 = "select * from nlpccQA where entity like '%" + entity + "%' and attribute='" + attribute + "' limit 10" 44 | #result_e0_a1 = upload_data(sql_e0_a1, True) 45 | 46 | # case3: entity Exact Match and attribute Fuzzy Match 47 | sql_e1_a0 = "select * from nlpccQA where entity like '" + entity + "' and attribute='%" + attribute + "%' limit 10" 48 | #result_e1_a0 = upload_data(sql_e1_a0) 49 | 50 | if len(result_e1_a1) > 0: 51 | recall += 1 52 | for l in result_e1_a1: 53 | if l[2] == answer: 54 | correct += 1 55 | else: 56 | result_e0_a1 = upload_data(sql_e0_a1) 57 | if len(result_e0_a1) > 0: 58 | recall += 1 59 | for l in result_e0_a1: 60 | if l[2] == answer: 61 | correct += 1 62 | else: 63 | result_e1_a0 = upload_data(sql_e1_a0) 64 | if len(result_e1_a0) > 0: 65 | recall += 1 66 | for l in result_e1_a0: 67 | if l[2] == answer: 68 | correct += 1 69 | else: 70 | loginfo.logger.info(sql_e1_a0) 71 | if total > 100: 72 | break 73 | total += 1 74 | time.sleep(1) 75 | loginfo.logger.info("total: {}, recall: {}, correct:{}, accuracy: {}%".format(total, recall, correct, correct * 100.0 / recall)) 76 | #loginfo.logger.info("total: {}, recall: {}, correct:{}, accuracy: {}%".format(total, recall, correct, correct*100.0/recall)) 77 | 78 | 79 | def estimate_answer(candidate, answer): 80 | ''' 81 | :param candidate: 82 | :param answer: 83 | :return: 84 | ''' 85 | candidate = candidate.strip().lower() 86 | answer = answer.strip().lower() 87 | if candidate == answer: 88 | return True 89 | 90 | if not answer.isdigit() and 
candidate.isdigit(): 91 | candidate_temp = "{:.5E}".format(int(candidate)) 92 | if candidate_temp == answer: 93 | return True 94 | candidate_temp == "{:.4E}".format(int(candidate)) 95 | if candidate_temp == answer: 96 | return True 97 | 98 | return False 99 | 100 | 101 | def kb_fuzzy_classify_test(): 102 | ''' 103 | 进行问答测试: 104 | 1、 实体检索:输入问题,ner得出实体集合,在数据库中检索与输入实体相关的所有三元组 105 | 2、 属性映射——bert分类/文本相似度 106 | + 非语义匹配:如果所得三元组的关系(attribute)属性是 输入问题 字符串的子集,将所得三元组的答案(answer)属性与正确答案匹配,correct +1 107 | + 语义匹配:利用bert计算输入问题(input question)与所得三元组的关系(attribute)属性的相似度,将最相似的三元组的答案作为答案,并与正确 108 | 的答案进行匹配,correct +1 109 | 3、 答案组合 110 | :return: 111 | ''' 112 | with open(file, encoding='utf-8') as f: 113 | total = 0 114 | recall = 0 115 | correct = 0 116 | ambiguity = 0 # 属性匹配正确但是答案不正确 117 | 118 | for line in f: 119 | try: 120 | total += 1 121 | question, entity, attribute, answer, ner = line.split("\t") 122 | ner = ner.replace("#", "").replace("[UNK]", "%").replace("\n", "") 123 | # case: entity Fuzzy Match 124 | # 找出所有包含这些实体的三元组 125 | sql_e0_a1 = "select * from nlpccQA where entity like '%" + ner + "%' order by length(entity) asc limit 20" 126 | # sql查出来的是tuple,要转换成list才不会报错 127 | result_e0_a1 = list(upload_data(sql_e0_a1)) 128 | 129 | if len(result_e0_a1) > 0: 130 | recall += 1 131 | 132 | flag_fuzzy = True 133 | # 非语义匹配,加快速度 134 | # l1[0]: entity 135 | # l1[1]: attribute 136 | # l1[2]: answer 137 | flag_ambiguity = True 138 | for l in result_e0_a1: 139 | if l[1] in question or l[1].lower() in question or l[1].upper() in question: 140 | flag_fuzzy = False 141 | 142 | if estimate_answer(l[2], answer): 143 | correct += 1 144 | flag_ambiguity = False 145 | else: 146 | loginfo.logger.info("\t".join(l)) 147 | 148 | # 非语义匹配成功,继续下一次 149 | if not flag_fuzzy: 150 | 151 | if flag_ambiguity: 152 | ambiguity += 1 153 | 154 | time.sleep(1) 155 | loginfo.logger.info("total: {}, recall: {}, correct:{}, accuracy: {}%, ambiguity:{}".format(total, recall, correct, correct * 100.0 / recall, ambiguity)) 156 | continue 157 | 158 | # 语义匹配 159 | result_df = pd.DataFrame(result_e0_a1, columns=['entity', 'attribute', 'value']) 160 | # loginfo.logger.info(result_df.head(100)) 161 | 162 | attribute_candicate_sim = [(k, bs.predict(question, k)[0][1]) for k in result_df['attribute'].tolist()] 163 | attribute_candicate_sort = sorted(attribute_candicate_sim, key=lambda candicate: candicate[1], reverse=True) 164 | loginfo.logger.info("\n".join([str(k)+" "+str(v) for (k, v) in attribute_candicate_sort])) 165 | 166 | answer_candicate_df = result_df[result_df["attribute"] == attribute_candicate_sort[0][0]] 167 | for row in answer_candicate_df.index: 168 | if estimate_answer(answer_candicate_df.loc[row, "value"], answer): 169 | correct += 1 170 | else: 171 | loginfo.logger.info("\t".join(answer_candicate_df.loc[row].tolist())) 172 | time.sleep(1) 173 | loginfo.logger.info("total: {}, recall: {}, correct:{}, accuracy: {}%, ambiguity:{}".format(total, recall, correct, correct * 100.0 / recall, ambiguity)) 174 | except Exception as e: 175 | loginfo.logger.info("the question id % d occur error %s" % (total, repr(e))) 176 | 177 | 178 | if __name__ == '__main__': 179 | kb_fuzzy_classify_test() -------------------------------------------------------------------------------- /kl.py: -------------------------------------------------------------------------------- 1 | import os 2 | os.system('sh terminal_ner.sh') 3 | -------------------------------------------------------------------------------- /lstm_crf_layer.py: 
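The retrieval-then-rank flow implemented in kbqa_test.py above reduces to three steps: fuzzy-match the NER entity against the triple store, answer directly when a stored attribute literally appears in the question, and otherwise rank the candidate attributes with the BERT sentence-similarity model. A minimal sketch of that flow, assuming upload_data() returns (entity, attribute, value) tuples and BertSim.predict(a, b) returns [[p_dissimilar, p_similar]] as used above; answer_question is a hypothetical helper, not part of the repository:

def answer_question(question, ner_entity, bs, upload_data):
    # Fuzzy entity match: every triple whose subject contains the recognized entity.
    sql = ("select * from nlpccQA where entity like '%" + ner_entity +
           "%' order by length(entity) asc limit 20")
    triples = list(upload_data(sql))
    if not triples:
        return None
    # Fast path (non-semantic match): an attribute that literally occurs in the question wins.
    for entity, attribute, value in triples:
        if attribute.lower() in question.lower():
            return value
    # Semantic match: rank attributes by BERT sentence similarity to the question.
    scored = [(attribute, bs.predict(question, attribute)[0][1]) for _, attribute, _ in triples]
    best_attribute = max(scored, key=lambda pair: pair[1])[0]
    # Return the value of the best-scoring attribute.
    return next(value for _, attribute, value in triples if attribute == best_attribute)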
-------------------------------------------------------------------------------- 1 | # encoding=utf-8 2 | 3 | """ 4 | bert-blstm-crf layer 5 | @Author:Macan 6 | """ 7 | 8 | import tensorflow as tf 9 | from tensorflow.contrib import rnn 10 | from tensorflow.contrib import crf 11 | 12 | class BLSTM_CRF(object): 13 | def __init__(self, embedded_chars, hidden_unit, cell_type, num_layers, dropout_rate, 14 | initializers, num_labels, seq_length, labels, lengths, is_training): 15 | """ 16 | BLSTM-CRF 网络 17 | :param embedded_chars: Fine-tuning embedding input 18 | :param hidden_unit: LSTM的隐含单元个数 19 | :param cell_type: RNN类型(LSTM OR GRU DICNN will be add in feature) 20 | :param num_layers: RNN的层数 21 | :param droupout_rate: droupout rate 22 | :param initializers: variable init class 23 | :param num_labels: 标签数量 24 | :param seq_length: 序列最大长度 25 | :param labels: 真实标签 26 | :param lengths: [batch_size] 每个batch下序列的真实长度 27 | :param is_training: 是否是训练过程 28 | """ 29 | self.hidden_unit = hidden_unit 30 | self.dropout_rate = dropout_rate 31 | self.cell_type = cell_type 32 | self.num_layers = num_layers 33 | self.embedded_chars = embedded_chars 34 | self.initializers = initializers 35 | self.seq_length = seq_length 36 | self.num_labels = num_labels 37 | self.labels = labels 38 | self.lengths = lengths 39 | self.embedding_dims = embedded_chars.shape[-1].value 40 | self.is_training = is_training 41 | 42 | def add_blstm_crf_layer(self, crf_only): 43 | 44 | """ 45 | blstm-crf网络 46 | :return: 47 | """ 48 | if self.is_training: 49 | # lstm input dropout rate i set 0.9 will get best score 50 | self.embedded_chars = tf.nn.dropout(self.embedded_chars, self.dropout_rate) 51 | 52 | if crf_only: 53 | logits = self.project_crf_layer(self.embedded_chars) 54 | else: 55 | #blstm 56 | lstm_output = self.blstm_layer(self.embedded_chars) 57 | #project 58 | logits = self.project_bilstm_layer(lstm_output) 59 | #crf 60 | loss, trans = self.crf_layer(logits) 61 | # CRF decode, pred_ids 是一条最大概率的标注路径 62 | pred_ids, _ = crf.crf_decode(potentials=logits, transition_params=trans, sequence_length=self.lengths) 63 | return ((loss, logits, trans, pred_ids)) 64 | 65 | def _witch_cell(self): 66 | """ 67 | RNN 类型 68 | :return: 69 | """ 70 | cell_tmp = None 71 | if self.cell_type == 'lstm': 72 | cell_tmp = rnn.BasicLSTMCell(self.hidden_unit) 73 | elif self.cell_type == 'gru': 74 | cell_tmp = rnn.GRUCell(self.hidden_unit) 75 | # 是否需要进行dropout 76 | if self.dropout_rate is not None: 77 | cell_tmp = rnn.DropoutWrapper(cell_tmp, output_keep_prob=self.dropout_rate) 78 | return cell_tmp 79 | 80 | def _bi_dir_rnn(self): 81 | """ 82 | 双向RNN 83 | :return: 84 | """ 85 | cell_fw = self._witch_cell() 86 | cell_bw = self._witch_cell() 87 | return cell_fw, cell_bw 88 | 89 | def blstm_layer(self, embedding_chars): 90 | """ 91 | 92 | :return: 93 | """ 94 | with tf.variable_scope('rnn_layer'): 95 | cell_fw, cell_bw = self._bi_dir_rnn() 96 | if self.num_layers > 1: 97 | cell_fw = rnn.MultiRNNCell([cell_fw] * self.num_layers, state_is_tuple=True) 98 | cell_bw = rnn.MultiRNNCell([cell_bw] * self.num_layers, state_is_tuple=True) 99 | 100 | outputs, _ = tf.nn.bidirectional_dynamic_rnn(cell_fw, cell_bw, embedding_chars, 101 | dtype=tf.float32) 102 | outputs = tf.concat(outputs, axis=2) 103 | return outputs 104 | 105 | def project_bilstm_layer(self, lstm_outputs, name=None): 106 | """ 107 | hidden layer between lstm layer and logits 108 | :param lstm_outputs: [batch_size, num_steps, emb_size] 109 | :return: [batch_size, num_steps, num_tags] 110 | """ 111 | with 
tf.variable_scope("project" if not name else name): 112 | with tf.variable_scope("hidden"): 113 | W = tf.get_variable("W", shape=[self.hidden_unit * 2, self.hidden_unit], 114 | dtype=tf.float32, initializer=self.initializers.xavier_initializer()) 115 | 116 | b = tf.get_variable("b", shape=[self.hidden_unit], dtype=tf.float32, 117 | initializer=tf.zeros_initializer()) 118 | output = tf.reshape(lstm_outputs, shape=[-1, self.hidden_unit * 2]) 119 | hidden = tf.tanh(tf.nn.xw_plus_b(output, W, b)) 120 | 121 | # project to score of tags 122 | with tf.variable_scope("logits"): 123 | W = tf.get_variable("W", shape=[self.hidden_unit, self.num_labels], 124 | dtype=tf.float32, initializer=self.initializers.xavier_initializer()) 125 | 126 | b = tf.get_variable("b", shape=[self.num_labels], dtype=tf.float32, 127 | initializer=tf.zeros_initializer()) 128 | 129 | pred = tf.nn.xw_plus_b(hidden, W, b) 130 | return tf.reshape(pred, [-1, self.seq_length, self.num_labels]) 131 | 132 | def project_crf_layer(self, embedding_chars, name=None): 133 | """ 134 | hidden layer between input layer and logits 135 | :param lstm_outputs: [batch_size, num_steps, emb_size] 136 | :return: [batch_size, num_steps, num_tags] 137 | """ 138 | with tf.variable_scope("project" if not name else name): 139 | with tf.variable_scope("logits"): 140 | W = tf.get_variable("W", shape=[self.embedding_dims, self.num_labels], 141 | dtype=tf.float32, initializer=self.initializers.xavier_initializer()) 142 | 143 | b = tf.get_variable("b", shape=[self.num_labels], dtype=tf.float32, 144 | initializer=tf.zeros_initializer()) 145 | output = tf.reshape(self.embedded_chars, shape=[-1, self.embedding_dims]) #[batch_size, embedding_dims] 146 | pred = tf.tanh(tf.nn.xw_plus_b(output, W, b)) 147 | return tf.reshape(pred, [-1, self.seq_length, self.num_labels]) 148 | 149 | def crf_layer(self, logits): 150 | """ 151 | calculate crf loss 152 | :param project_logits: [1, num_steps, num_tags] 153 | :return: scalar loss 154 | """ 155 | with tf.variable_scope("crf_loss"): 156 | trans = tf.get_variable( 157 | "transitions", 158 | shape=[self.num_labels, self.num_labels], 159 | initializer=self.initializers.xavier_initializer()) 160 | log_likelihood, trans = tf.contrib.crf.crf_log_likelihood( 161 | inputs=logits, 162 | tag_indices=self.labels, 163 | transition_params=trans, 164 | sequence_lengths=self.lengths) 165 | return tf.reduce_mean(-log_likelihood), trans -------------------------------------------------------------------------------- /neo4j_qa.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | import os 4 | import subprocess 5 | import io 6 | import re 7 | import time 8 | import jieba 9 | import numpy as np 10 | import pandas as pd 11 | import urllib.request 12 | import urllib.parse 13 | import tensorflow as tf 14 | from Data.load_dbdata import upload_data 15 | from global_config import Logger 16 | 17 | from run_similarity import BertSim 18 | # 模块导入 https://blog.csdn.net/xiongchengluo1129/article/details/80453599 19 | 20 | loginfo = Logger("answer.log", "info") 21 | s= subprocess.Popen(['python','prt.py'], stdin=subprocess.PIPE, stdout=subprocess.PIPE, stderr=subprocess.PIPE) 22 | output,err = s.communicate(b'wkdnlkadn') 23 | #print(s.stdout.read()) 24 | 25 | ''' 26 | nextline=s.stdout.readlines() 27 | 28 | for line in nextline: 29 | line = line.decode('gbk') 30 | #if line.strip() == '输入:': 31 | print(line.strip()) 32 | ''' 33 | ''' 34 | a = os.system('sh 
terminal_ner.sh') 35 | print(a) 36 | ''' -------------------------------------------------------------------------------- /prt.py: -------------------------------------------------------------------------------- 1 | while 1: 2 | try: 3 | a = input('输入: ') 4 | print('这句话的第二个字是:',a[1]) 5 | except: 6 | break -------------------------------------------------------------------------------- /qa_my.sh: -------------------------------------------------------------------------------- 1 | python qa_my.py \ 2 | --task_name=ner \ 3 | --data_dir=./Data/NER_Data \ 4 | --vocab_file=./ModelParams/chinese_L-12_H-768_A-12/vocab.txt \ 5 | --bert_config_file=./ModelParams/chinese_L-12_H-768_A-12/bert_config.json \ 6 | --output_dir=./Output/NER \ 7 | --init_checkpoint=./ModelParams/chinese_L-12_H-768_A-12/bert_model.ckpt \ 8 | --data_config_path=./Config/NER/ner_data.conf \ 9 | --do_train=True \ 10 | --do_eval=True \ 11 | --max_seq_length=128 \ 12 | --lstm_size=128 \ 13 | --num_layers=1 \ 14 | --train_batch_size=64 \ 15 | --eval_batch_size=8 \ 16 | --predict_batch_size=8 \ 17 | --learning_rate=5e-5 \ 18 | --num_train_epochs=1 \ 19 | --droupout_rate=0.5 \ 20 | --clip=5 \ 21 | --do_predict_online=True -------------------------------------------------------------------------------- /recommend_articles.log.2019-08-06: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AI-YULU/KBQA-BERT/7150d7ef00f2ec0a28ffb11fd7d6b30c84a28eb9/recommend_articles.log.2019-08-06 -------------------------------------------------------------------------------- /run_ner.sh: -------------------------------------------------------------------------------- 1 | python run_ner.py \ 2 | --task_name=ner \ 3 | --data_dir=./Data/NER_Data \ 4 | --vocab_file=./ModelParams/chinese_L-12_H-768_A-12/vocab.txt \ 5 | --bert_config_file=./ModelParams/chinese_L-12_H-768_A-12/bert_config.json \ 6 | --output_dir=./Output/NER \ 7 | --init_checkpoint=./ModelParams/chinese_L-12_H-768_A-12/bert_model.ckpt \ 8 | --data_config_path=./Config/NER/ner_data.conf \ 9 | --do_train=True \ 10 | --do_eval=True \ 11 | --max_seq_length=32 \ 12 | --lstm_size=128 \ 13 | --num_layers=1 \ 14 | --train_batch_size=16 \ 15 | --eval_batch_size=8 \ 16 | --predict_batch_size=8 \ 17 | --learning_rate=5e-5 \ 18 | --num_train_epochs=20 \ 19 | --droupout_rate=0.5 \ 20 | --clip=5 21 | -------------------------------------------------------------------------------- /subpcs.py: -------------------------------------------------------------------------------- 1 | #coding=utf-8 2 | import sys 3 | 4 | while True: 5 | # 这里的 txt,就是 A write 的数据,标准输入 6 | txt = input() 7 | 8 | # 这里是业务处理 9 | v = txt + ' from B' 10 | 11 | # 标准输出,在别的可执行程序里可能是 echo 等,结果都流到了 A 中 12 | print(v) 13 | sys.stdout.flush() 14 | -------------------------------------------------------------------------------- /terminal_ner.sh: -------------------------------------------------------------------------------- 1 | python terminal_predict.py \ 2 | --task_name=ner \ 3 | --data_dir=./Data/NER_Data \ 4 | --vocab_file=./ModelParams/chinese_L-12_H-768_A-12/vocab.txt \ 5 | --bert_config_file=./ModelParams/chinese_L-12_H-768_A-12/bert_config.json \ 6 | --output_dir=./Output/NER \ 7 | --init_checkpoint=./ModelParams/chinese_L-12_H-768_A-12/bert_model.ckpt \ 8 | --data_config_path=./Config/NER/ner_data.conf \ 9 | --do_train=True \ 10 | --do_eval=True \ 11 | --max_seq_length=128 \ 12 | --lstm_size=128 \ 13 | --num_layers=1 \ 14 | --train_batch_size=64 \ 15 | --eval_batch_size=8 \ 
16 | --predict_batch_size=8 \ 17 | --learning_rate=5e-5 \ 18 | --num_train_epochs=1 \ 19 | --droupout_rate=0.5 \ 20 | --clip=5 \ 21 | --do_predict_online=True \ 22 | 23 | -------------------------------------------------------------------------------- /terminal_predict.py: -------------------------------------------------------------------------------- 1 | # encoding=utf-8 2 | 3 | """ 4 | 基于命令行的在线预测方法 5 | @Author: Macan (ma_cancan@163.com) 6 | """ 7 | import pandas as pd 8 | import tensorflow as tf 9 | import numpy as np 10 | import codecs 11 | import pickle 12 | import os 13 | from datetime import time, timedelta, datetime 14 | 15 | from run_ner import create_model, InputFeatures, InputExample 16 | from bert import tokenization 17 | from bert import modeling 18 | 19 | from neo4j import GraphDatabase #yueuu 20 | 21 | 22 | os.environ['CUDA_VISIBLE_DEVICES'] = '0' 23 | 24 | flags = tf.flags 25 | 26 | FLAGS = flags.FLAGS 27 | 28 | flags.DEFINE_bool( 29 | "do_predict_outline", False, 30 | "Whether to do predict outline." 31 | ) 32 | flags.DEFINE_bool( 33 | "do_predict_online", False, 34 | "Whether to do predict online." 35 | ) 36 | 37 | # init mode and session 38 | # move something codes outside of function, so that this code will run only once during online prediction when predict_online is invoked. 39 | is_training=False 40 | use_one_hot_embeddings=False 41 | batch_size=1 42 | 43 | gpu_config = tf.ConfigProto() 44 | gpu_config.gpu_options.allow_growth = True 45 | sess=tf.Session(config=gpu_config) 46 | model=None 47 | 48 | global graph 49 | input_ids_p, input_mask_p, label_ids_p, segment_ids_p = None, None, None, None 50 | print(FLAGS.output_dir) 51 | print('checkpoint path:{}'.format(os.path.join(FLAGS.output_dir, "checkpoint"))) 52 | if not os.path.exists(os.path.join(FLAGS.output_dir, "checkpoint")): 53 | raise Exception("failed to get checkpoint. going to return ") 54 | 55 | # 加载label->id的词典 56 | with codecs.open(os.path.join(FLAGS.output_dir, 'label2id.pkl'), 'rb') as rf: 57 | label2id = pickle.load(rf) 58 | id2label = {value: key for key, value in label2id.items()} 59 | 60 | with codecs.open(os.path.join(FLAGS.output_dir, 'label_list.pkl'), 'rb') as rf: 61 | label_list = pickle.load(rf) 62 | num_labels = len(label_list) + 1 63 | 64 | graph = tf.get_default_graph() 65 | with graph.as_default(): 66 | print("going to restore checkpoint") 67 | #sess.run(tf.global_variables_initializer()) 68 | input_ids_p = tf.placeholder(tf.int32, [batch_size, FLAGS.max_seq_length], name="input_ids") 69 | input_mask_p = tf.placeholder(tf.int32, [batch_size, FLAGS.max_seq_length], name="input_mask") 70 | label_ids_p = tf.placeholder(tf.int32, [batch_size, FLAGS.max_seq_length], name="label_ids") 71 | segment_ids_p = tf.placeholder(tf.int32, [batch_size, FLAGS.max_seq_length], name="segment_ids") 72 | 73 | bert_config = modeling.BertConfig.from_json_file(FLAGS.bert_config_file) 74 | (total_loss, logits, trans, pred_ids) = create_model( 75 | bert_config, is_training, input_ids_p, input_mask_p, segment_ids_p, 76 | label_ids_p, num_labels, use_one_hot_embeddings) 77 | 78 | saver = tf.train.Saver() 79 | saver.restore(sess, tf.train.latest_checkpoint(FLAGS.output_dir)) 80 | 81 | 82 | tokenizer = tokenization.FullTokenizer( 83 | vocab_file=FLAGS.vocab_file, do_lower_case=FLAGS.do_lower_case) 84 | 85 | 86 | def predict_online(): 87 | """ 88 | do online prediction. each time make prediction for one instance. 89 | you can change to a batch if you want. 90 | 91 | :param line: a list. 
element is: [dummy_label,text_a,text_b] 92 | :return: 93 | """ 94 | 95 | def convert(line): 96 | feature = convert_single_example(0, line, label_list, FLAGS.max_seq_length, tokenizer, 'p') 97 | input_ids = np.reshape([feature.input_ids],(batch_size, FLAGS.max_seq_length)) 98 | input_mask = np.reshape([feature.input_mask],(batch_size, FLAGS.max_seq_length)) 99 | segment_ids = np.reshape([feature.segment_ids],(batch_size, FLAGS.max_seq_length)) 100 | label_ids =np.reshape([feature.label_ids],(batch_size, FLAGS.max_seq_length)) 101 | return input_ids, input_mask, segment_ids, label_ids 102 | 103 | global graph 104 | with graph.as_default(): 105 | print(id2label) 106 | while True: 107 | print('input the test sentence:') 108 | sentence = str(input()) 109 | start = datetime.now() 110 | if len(sentence) < 2: 111 | print(sentence) 112 | continue 113 | sentence = tokenizer.tokenize(sentence) 114 | # print('your input is:{}'.format(sentence)) 115 | input_ids, input_mask, segment_ids, label_ids = convert(sentence) 116 | 117 | feed_dict = {input_ids_p: input_ids, 118 | input_mask_p: input_mask, 119 | segment_ids_p:segment_ids, 120 | label_ids_p:label_ids} 121 | # run session get current feed_dict result 122 | pred_ids_result = sess.run([pred_ids], feed_dict) 123 | pred_label_result = convert_id_to_label(pred_ids_result, id2label) 124 | print(pred_label_result) 125 | #todo: 组合策略 126 | result = strage_combined_link_org_loc(sentence, pred_label_result[0], True) 127 | print('识别的实体有:{}'.format(' '.join(result))) 128 | print('Time used: {} sec'.format((datetime.now() - start).seconds)) 129 | 130 | # yueuu 131 | driver = GraphDatabase.driver("bolt://localhost:7687", auth=("neo4j", "Nic180319")) 132 | 133 | 134 | 135 | def predict_outline(): 136 | """ 137 | do online prediction. each time make prediction for one instance. 138 | you can change to a batch if you want. 139 | 140 | :param line: a list. 
element is: [dummy_label,text_a,text_b] 141 | :return: 142 | """ 143 | def convert(line): 144 | feature = convert_single_example(0, line, label_list, FLAGS.max_seq_length, tokenizer, 'p') 145 | input_ids = np.reshape([feature.input_ids],(batch_size, FLAGS.max_seq_length)) 146 | input_mask = np.reshape([feature.input_mask],(batch_size, FLAGS.max_seq_length)) 147 | segment_ids = np.reshape([feature.segment_ids],(batch_size, FLAGS.max_seq_length)) 148 | label_ids =np.reshape([feature.label_ids],(batch_size, FLAGS.max_seq_length)) 149 | return input_ids, input_mask, segment_ids, label_ids 150 | 151 | global graph 152 | with graph.as_default(): 153 | start = datetime.now() 154 | nlpcc_test_data = pd.read_csv("./Data/NER_Data/q_t_a_df_testing.csv") 155 | correct = 0 156 | test_size = nlpcc_test_data.shape[0] 157 | nlpcc_test_result = [] 158 | 159 | for row in nlpcc_test_data.index: 160 | question = nlpcc_test_data.loc[row,"q_str"] 161 | entity = nlpcc_test_data.loc[row,"t_str"].split("|||")[0].split(">")[1].strip() 162 | attribute = nlpcc_test_data.loc[row, "t_str"].split("|||")[1].strip() 163 | answer = nlpcc_test_data.loc[row, "t_str"].split("|||")[2].strip() 164 | 165 | sentence = str(question) 166 | start = datetime.now() 167 | if len(sentence) < 2: 168 | print(sentence) 169 | continue 170 | sentence = tokenizer.tokenize(sentence) 171 | input_ids, input_mask, segment_ids, label_ids = convert(sentence) 172 | 173 | feed_dict = {input_ids_p: input_ids, 174 | input_mask_p: input_mask, 175 | segment_ids_p:segment_ids, 176 | label_ids_p:label_ids} 177 | # run session get current feed_dict result 178 | pred_ids_result = sess.run([pred_ids], feed_dict) 179 | pred_label_result = convert_id_to_label(pred_ids_result, id2label) 180 | # print(pred_label_result) 181 | #todo: 组合策略 182 | result = strage_combined_link_org_loc(sentence, pred_label_result[0], False) 183 | if entity in result: 184 | correct += 1 185 | nlpcc_test_result.append(question+"\t"+entity+"\t"+attribute+"\t"+answer+"\t"+','.join(result)) 186 | with open("./Data/NER_Data/q_t_a_testing_predict.txt", "w") as f: 187 | f.write("\n".join(nlpcc_test_result)) 188 | print("accuracy: {}%, correct: {}, total: {}".format(correct*100.0/float(test_size), correct, test_size)) 189 | print('Time used: {} sec'.format((datetime.now() - start).seconds)) 190 | 191 | 192 | def convert_id_to_label(pred_ids_result, idx2label): 193 | """ 194 | 将id形式的结果转化为真实序列结果 195 | :param pred_ids_result: 196 | :param idx2label: 197 | :return: 198 | """ 199 | result = [] 200 | for row in range(batch_size): 201 | curr_seq = [] 202 | for ids in pred_ids_result[row][0]: 203 | if ids == 0: 204 | break 205 | curr_label = idx2label[ids] 206 | if curr_label in ['[CLS]', '[SEP]']: 207 | continue 208 | curr_seq.append(curr_label) 209 | result.append(curr_seq) 210 | return result 211 | 212 | 213 | def strage_combined_link_org_loc(tokens, tags, flag): 214 | """ 215 | 组合策略 216 | :param pred_label_result: 217 | :param types: 218 | :return: 219 | """ 220 | def print_output(data, type): 221 | line = [] 222 | for i in data: 223 | line.append(i.word) 224 | print('{}: {}'.format(type, ', '.join(line))) 225 | 226 | def string_output(data): 227 | line = [] 228 | for i in data: 229 | line.append(i.word) 230 | return line 231 | 232 | params = None 233 | eval = Result(params) 234 | if len(tokens) > len(tags): 235 | tokens = tokens[:len(tags)] 236 | person, loc, org = eval.get_result(tokens, tags) 237 | if flag: 238 | if len(loc) != 0: 239 | print_output(loc, 'LOC') 240 | if len(person) != 0: 241 | 
print_output(person, 'PER') 242 | if len(org) != 0: 243 | print_output(org, 'ORG') 244 | person_list = string_output(person) 245 | person_list.extend(string_output(loc)) 246 | person_list.extend(string_output(org)) 247 | return person_list 248 | 249 | 250 | def convert_single_example(ex_index, example, label_list, max_seq_length, tokenizer, mode): 251 | """ 252 | 将一个样本进行分析,然后将字转化为id, 标签转化为id,然后结构化到InputFeatures对象中 253 | :param ex_index: index 254 | :param example: 一个样本 255 | :param label_list: 标签列表 256 | :param max_seq_length: 257 | :param tokenizer: 258 | :param mode: 259 | :return: 260 | """ 261 | label_map = {} 262 | # 1表示从1开始对label进行index化 263 | for (i, label) in enumerate(label_list, 1): 264 | label_map[label] = i 265 | # 保存label->index 的map 266 | if not os.path.exists(os.path.join(FLAGS.output_dir, 'label2id.pkl')): 267 | with codecs.open(os.path.join(FLAGS.output_dir, 'label2id.pkl'), 'wb') as w: 268 | pickle.dump(label_map, w) 269 | 270 | tokens = example 271 | # tokens = tokenizer.tokenize(example.text) 272 | # 序列截断 273 | if len(tokens) >= max_seq_length - 1: 274 | tokens = tokens[0:(max_seq_length - 2)] # -2 的原因是因为序列需要加一个句首和句尾标志 275 | ntokens = [] 276 | segment_ids = [] 277 | label_ids = [] 278 | ntokens.append("[CLS]") # 句子开始设置CLS 标志 279 | segment_ids.append(0) 280 | # append("O") or append("[CLS]") not sure! 281 | label_ids.append(label_map["[CLS]"]) # O OR CLS 没有任何影响,不过我觉得O 会减少标签个数,不过拒收和句尾使用不同的标志来标注,使用LCS 也没毛病 282 | for i, token in enumerate(tokens): 283 | ntokens.append(token) 284 | segment_ids.append(0) 285 | label_ids.append(0) 286 | ntokens.append("[SEP]") # 句尾添加[SEP] 标志 287 | segment_ids.append(0) 288 | # append("O") or append("[SEP]") not sure! 289 | label_ids.append(label_map["[SEP]"]) 290 | input_ids = tokenizer.convert_tokens_to_ids(ntokens) # 将序列中的字(ntokens)转化为ID形式 291 | input_mask = [1] * len(input_ids) 292 | 293 | # padding, 使用 294 | while len(input_ids) < max_seq_length: 295 | input_ids.append(0) 296 | input_mask.append(0) 297 | segment_ids.append(0) 298 | # we don't concerned about it! 
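# (padding positions reuse label id 0, which can never collide with a real tag because label_map above indexes labels from 1)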
299 | label_ids.append(0) 300 | ntokens.append("**NULL**") 301 | # label_mask.append(0) 302 | # print(len(input_ids)) 303 | assert len(input_ids) == max_seq_length 304 | assert len(input_mask) == max_seq_length 305 | assert len(segment_ids) == max_seq_length 306 | assert len(label_ids) == max_seq_length 307 | # assert len(label_mask) == max_seq_length 308 | 309 | # 结构化为一个类 310 | feature = InputFeatures( 311 | input_ids=input_ids, 312 | input_mask=input_mask, 313 | segment_ids=segment_ids, 314 | label_ids=label_ids, 315 | # label_mask = label_mask 316 | ) 317 | return feature 318 | 319 | 320 | class Pair(object): 321 | def __init__(self, word, start, end, type, merge=False): 322 | self.__word = word 323 | self.__start = start 324 | self.__end = end 325 | self.__merge = merge 326 | self.__types = type 327 | 328 | @property 329 | def start(self): 330 | return self.__start 331 | @property 332 | def end(self): 333 | return self.__end 334 | @property 335 | def merge(self): 336 | return self.__merge 337 | @property 338 | def word(self): 339 | return self.__word 340 | 341 | @property 342 | def types(self): 343 | return self.__types 344 | @word.setter 345 | def word(self, word): 346 | self.__word = word 347 | @start.setter 348 | def start(self, start): 349 | self.__start = start 350 | @end.setter 351 | def end(self, end): 352 | self.__end = end 353 | @merge.setter 354 | def merge(self, merge): 355 | self.__merge = merge 356 | 357 | @types.setter 358 | def types(self, type): 359 | self.__types = type 360 | 361 | def __str__(self) -> str: 362 | line = [] 363 | line.append('entity:{}'.format(self.__word)) 364 | line.append('start:{}'.format(self.__start)) 365 | line.append('end:{}'.format(self.__end)) 366 | line.append('merge:{}'.format(self.__merge)) 367 | line.append('types:{}'.format(self.__types)) 368 | return '\t'.join(line) 369 | 370 | 371 | class Result(object): 372 | def __init__(self, config): 373 | self.config = config 374 | self.person = [] 375 | self.loc = [] 376 | self.org = [] 377 | self.others = [] 378 | def get_result(self, tokens, tags, config=None): 379 | # 先获取标注结果 380 | self.result_to_json(tokens, tags) 381 | return self.person, self.loc, self.org 382 | 383 | def result_to_json(self, string, tags): 384 | """ 385 | 将模型标注序列和输入序列结合 转化为结果 386 | :param string: 输入序列 387 | :param tags: 标注结果 388 | :return: 389 | """ 390 | item = {"entities": []} 391 | entity_name = "" 392 | entity_start = 0 393 | idx = 0 394 | last_tag = '' 395 | 396 | for char, tag in zip(string, tags): 397 | if tag[0] == "S": 398 | self.append(char, idx, idx+1, tag[2:]) 399 | item["entities"].append({"word": char, "start": idx, "end": idx+1, "type":tag[2:]}) 400 | elif tag[0] == "B": 401 | if entity_name != '': 402 | self.append(entity_name, entity_start, idx, last_tag[2:]) 403 | item["entities"].append({"word": entity_name, "start": entity_start, "end": idx, "type": last_tag[2:]}) 404 | entity_name = "" 405 | entity_name += char 406 | entity_start = idx 407 | elif tag[0] == "I": 408 | entity_name += char 409 | elif tag[0] == "O": 410 | if entity_name != '': 411 | self.append(entity_name, entity_start, idx, last_tag[2:]) 412 | item["entities"].append({"word": entity_name, "start": entity_start, "end": idx, "type": last_tag[2:]}) 413 | entity_name = "" 414 | else: 415 | entity_name = "" 416 | entity_start = idx 417 | idx += 1 418 | last_tag = tag 419 | if entity_name != '': 420 | self.append(entity_name, entity_start, idx, last_tag[2:]) 421 | item["entities"].append({"word": entity_name, "start": entity_start, "end": idx, 
"type": last_tag[2:]}) 422 | return item 423 | 424 | def append(self, word, start, end, tag): 425 | if tag == 'LOC': 426 | self.loc.append(Pair(word, start, end, 'LOC')) 427 | elif tag == 'PER': 428 | self.person.append(Pair(word, start, end, 'PER')) 429 | elif tag == 'ORG': 430 | self.org.append(Pair(word, start, end, 'ORG')) 431 | else: 432 | self.others.append(Pair(word, start, end, tag)) 433 | 434 | 435 | if __name__ == "__main__": 436 | if FLAGS.do_predict_outline: 437 | predict_outline() 438 | if FLAGS.do_predict_online: 439 | predict_online() 440 | 441 | 442 | -------------------------------------------------------------------------------- /tf_metrics.py: -------------------------------------------------------------------------------- 1 | """ 2 | Multiclass 3 | from: 4 | https://github.com/guillaumegenthial/tf_metrics/blob/master/tf_metrics/__init__.py 多分类评测 5 | 6 | """ 7 | 8 | __author__ = "Guillaume Genthial" 9 | 10 | import numpy as np 11 | import tensorflow as tf 12 | from tensorflow.python.ops.metrics_impl import _streaming_confusion_matrix 13 | 14 | 15 | def precision(labels, predictions, num_classes, pos_indices=None, 16 | weights=None, average='micro'): 17 | """Multi-class precision metric for Tensorflow 18 | Parameters 19 | ---------- 20 | labels : Tensor of tf.int32 or tf.int64 21 | The true labels 22 | predictions : Tensor of tf.int32 or tf.int64 23 | The predictions, same shape as labels 24 | num_classes : int 25 | The number of classes 26 | pos_indices : list of int, optional 27 | The indices of the positive classes, default is all 28 | weights : Tensor of tf.int32, optional 29 | Mask, must be of compatible shape with labels 30 | average : str, optional 31 | 'micro': counts the total number of true positives, false 32 | positives, and false negatives for the classes in 33 | `pos_indices` and infer the metric from it. 34 | 'macro': will compute the metric separately for each class in 35 | `pos_indices` and average. Will not account for class 36 | imbalance. 37 | 'weighted': will compute the metric separately for each class in 38 | `pos_indices` and perform a weighted average by the total 39 | number of true labels for each class. 40 | Returns 41 | ------- 42 | tuple of (scalar float Tensor, update_op) 43 | """ 44 | cm, op = _streaming_confusion_matrix( 45 | labels, predictions, num_classes, weights) 46 | pr, _, _ = metrics_from_confusion_matrix( 47 | cm, pos_indices, average=average) 48 | op, _, _ = metrics_from_confusion_matrix( 49 | op, pos_indices, average=average) 50 | return (pr, op) 51 | 52 | 53 | def recall(labels, predictions, num_classes, pos_indices=None, weights=None, 54 | average='micro'): 55 | """Multi-class recall metric for Tensorflow 56 | Parameters 57 | ---------- 58 | labels : Tensor of tf.int32 or tf.int64 59 | The true labels 60 | predictions : Tensor of tf.int32 or tf.int64 61 | The predictions, same shape as labels 62 | num_classes : int 63 | The number of classes 64 | pos_indices : list of int, optional 65 | The indices of the positive classes, default is all 66 | weights : Tensor of tf.int32, optional 67 | Mask, must be of compatible shape with labels 68 | average : str, optional 69 | 'micro': counts the total number of true positives, false 70 | positives, and false negatives for the classes in 71 | `pos_indices` and infer the metric from it. 72 | 'macro': will compute the metric separately for each class in 73 | `pos_indices` and average. Will not account for class 74 | imbalance. 
75 | 'weighted': will compute the metric separately for each class in 76 | `pos_indices` and perform a weighted average by the total 77 | number of true labels for each class. 78 | Returns 79 | ------- 80 | tuple of (scalar float Tensor, update_op) 81 | """ 82 | cm, op = _streaming_confusion_matrix( 83 | labels, predictions, num_classes, weights) 84 | _, re, _ = metrics_from_confusion_matrix( 85 | cm, pos_indices, average=average) 86 | _, op, _ = metrics_from_confusion_matrix( 87 | op, pos_indices, average=average) 88 | return (re, op) 89 | 90 | 91 | def f1(labels, predictions, num_classes, pos_indices=None, weights=None, 92 | average='micro'): 93 | return fbeta(labels, predictions, num_classes, pos_indices, weights, 94 | average) 95 | 96 | 97 | def fbeta(labels, predictions, num_classes, pos_indices=None, weights=None, 98 | average='micro', beta=1): 99 | """Multi-class fbeta metric for Tensorflow 100 | Parameters 101 | ---------- 102 | labels : Tensor of tf.int32 or tf.int64 103 | The true labels 104 | predictions : Tensor of tf.int32 or tf.int64 105 | The predictions, same shape as labels 106 | num_classes : int 107 | The number of classes 108 | pos_indices : list of int, optional 109 | The indices of the positive classes, default is all 110 | weights : Tensor of tf.int32, optional 111 | Mask, must be of compatible shape with labels 112 | average : str, optional 113 | 'micro': counts the total number of true positives, false 114 | positives, and false negatives for the classes in 115 | `pos_indices` and infer the metric from it. 116 | 'macro': will compute the metric separately for each class in 117 | `pos_indices` and average. Will not account for class 118 | imbalance. 119 | 'weighted': will compute the metric separately for each class in 120 | `pos_indices` and perform a weighted average by the total 121 | number of true labels for each class. 122 | beta : int, optional 123 | Weight of precision in harmonic mean 124 | Returns 125 | ------- 126 | tuple of (scalar float Tensor, update_op) 127 | """ 128 | cm, op = _streaming_confusion_matrix( 129 | labels, predictions, num_classes, weights) 130 | _, _, fbeta = metrics_from_confusion_matrix( 131 | cm, pos_indices, average=average, beta=beta) 132 | _, _, op = metrics_from_confusion_matrix( 133 | op, pos_indices, average=average, beta=beta) 134 | return (fbeta, op) 135 | 136 | 137 | def safe_div(numerator, denominator): 138 | """Safe division, return 0 if denominator is 0""" 139 | numerator, denominator = tf.to_float(numerator), tf.to_float(denominator) 140 | zeros = tf.zeros_like(numerator, dtype=numerator.dtype) 141 | denominator_is_zero = tf.equal(denominator, zeros) 142 | return tf.where(denominator_is_zero, zeros, numerator / denominator) 143 | 144 | 145 | def pr_re_fbeta(cm, pos_indices, beta=1): 146 | """Uses a confusion matrix to compute precision, recall and fbeta""" 147 | num_classes = cm.shape[0] 148 | neg_indices = [i for i in range(num_classes) if i not in pos_indices] 149 | cm_mask = np.ones([num_classes, num_classes]) 150 | cm_mask[neg_indices, neg_indices] = 0 151 | diag_sum = tf.reduce_sum(tf.diag_part(cm * cm_mask)) 152 | 153 | cm_mask = np.ones([num_classes, num_classes]) 154 | cm_mask[:, neg_indices] = 0 155 | tot_pred = tf.reduce_sum(cm * cm_mask) 156 | 157 | cm_mask = np.ones([num_classes, num_classes]) 158 | cm_mask[neg_indices, :] = 0 159 | tot_gold = tf.reduce_sum(cm * cm_mask) 160 | 161 | pr = safe_div(diag_sum, tot_pred) 162 | re = safe_div(diag_sum, tot_gold) 163 | fbeta = safe_div((1. 
+ beta**2) * pr * re, beta**2 * pr + re) 164 | 165 | return pr, re, fbeta 166 | 167 | 168 | def metrics_from_confusion_matrix(cm, pos_indices=None, average='micro', 169 | beta=1): 170 | """Precision, Recall and F1 from the confusion matrix 171 | Parameters 172 | ---------- 173 | cm : tf.Tensor of type tf.int32, of shape (num_classes, num_classes) 174 | The streaming confusion matrix. 175 | pos_indices : list of int, optional 176 | The indices of the positive classes 177 | beta : int, optional 178 | Weight of precision in harmonic mean 179 | average : str, optional 180 | 'micro', 'macro' or 'weighted' 181 | """ 182 | num_classes = cm.shape[0] 183 | if pos_indices is None: 184 | pos_indices = [i for i in range(num_classes)] 185 | 186 | if average == 'micro': 187 | return pr_re_fbeta(cm, pos_indices, beta) 188 | elif average in {'macro', 'weighted'}: 189 | precisions, recalls, fbetas, n_golds = [], [], [], [] 190 | for idx in pos_indices: 191 | pr, re, fbeta = pr_re_fbeta(cm, [idx], beta) 192 | precisions.append(pr) 193 | recalls.append(re) 194 | fbetas.append(fbeta) 195 | cm_mask = np.zeros([num_classes, num_classes]) 196 | cm_mask[idx, :] = 1 197 | n_golds.append(tf.to_float(tf.reduce_sum(cm * cm_mask))) 198 | 199 | if average == 'macro': 200 | pr = tf.reduce_mean(precisions) 201 | re = tf.reduce_mean(recalls) 202 | fbeta = tf.reduce_mean(fbetas) 203 | return pr, re, fbeta 204 | if average == 'weighted': 205 | n_gold = tf.reduce_sum(n_golds) 206 | pr_sum = sum(p * n for p, n in zip(precisions, n_golds)) 207 | pr = safe_div(pr_sum, n_gold) 208 | re_sum = sum(r * n for r, n in zip(recalls, n_golds)) 209 | re = safe_div(re_sum, n_gold) 210 | fbeta_sum = sum(f * n for f, n in zip(fbetas, n_golds)) 211 | fbeta = safe_div(fbeta_sum, n_gold) 212 | return pr, re, fbeta 213 | 214 | else: 215 | raise NotImplementedError() --------------------------------------------------------------------------------
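The precision/recall/f1 helpers above are streaming TF1 metrics: each call returns a (value, update_op) pair backed by a local confusion-matrix variable, so the update op has to be run over the evaluation batches before the value tensor is read. A minimal usage sketch; the toy batch and the choice of label id 0 as the padding/'O' class excluded via pos_indices are assumptions, not taken from the repository:

import tensorflow as tf
from tf_metrics import precision, recall, f1

labels = tf.placeholder(tf.int64, [None])    # gold tag ids
preds = tf.placeholder(tf.int64, [None])     # predicted tag ids
num_classes = 5                              # toy tag set; id 0 plays the role of padding/'O'
pos_indices = [1, 2, 3, 4]                   # score only the real tags

pr, pr_op = precision(labels, preds, num_classes, pos_indices, average='micro')
re, re_op = recall(labels, preds, num_classes, pos_indices, average='micro')
f, f_op = f1(labels, preds, num_classes, pos_indices, average='micro')

with tf.Session() as sess:
    sess.run(tf.local_variables_initializer())    # the streaming confusion matrix lives in local variables
    for gold, guess in [([1, 2, 0, 3], [1, 2, 0, 4])]:    # toy evaluation batches
        sess.run([pr_op, re_op, f_op], {labels: gold, preds: guess})
    print(sess.run([pr, re, f]))    # aggregated micro P/R/F1 (2/3 each for this toy batch)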