├── Config ├── NER │ └── README.md └── SIM │ └── README.md ├── Data ├── DB_Data │ └── README.md ├── NER_Data │ └── README.md ├── NLPCC2016KBQA │ ├── nlpcc-iccpol-2016.kbqa.kb │ ├── nlpcc-iccpol-2016.kbqa.testing-data │ └── nlpcc-iccpol-2016.kbqa.training-data ├── Sim_Data │ └── README.md ├── construct_dataset.py ├── construct_dataset_attribute.py ├── load_dbdata.py └── triple_clean.py ├── LICENSE ├── ModelParams └── README.md ├── Output └── README.md ├── README.md ├── args.py ├── bert ├── CONTRIBUTING.md ├── LICENSE ├── README.md ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-36.pyc │ ├── modeling.cpython-36.pyc │ ├── optimization.cpython-36.pyc │ └── tokenization.cpython-36.pyc ├── create_pretraining_data.py ├── extract_features.py ├── modeling.py ├── modeling_test.py ├── multilingual.md ├── optimization.py ├── optimization_test.py ├── requirements.txt ├── run_classifier.py ├── run_pretraining.py ├── run_squad.py ├── sample_text.txt ├── tokenization.py └── tokenization_test.py ├── conlleval.pl ├── conlleval.py ├── global_config.py ├── image ├── KB.png └── NER.jpg ├── kbqa_test.py ├── lstm_crf_layer.py ├── run_ner.py ├── run_ner.sh ├── run_similarity.py ├── terminal_ner.sh ├── terminal_predict.py └── tf_metrics.py /Config/NER/README.md: -------------------------------------------------------------------------------- 1 | # KBQA-BERT 2 | 基于知识图谱的问答系统,BERT做命名实体识别和句子相似度,分为online和outline模式 3 | -------------------------------------------------------------------------------- /Config/SIM/README.md: -------------------------------------------------------------------------------- 1 | # KBQA-BERT 2 | 基于知识图谱的问答系统,BERT做命名实体识别和句子相似度,分为online和outline模式 3 | -------------------------------------------------------------------------------- /Data/DB_Data/README.md: -------------------------------------------------------------------------------- 1 | # KBQA-BERT 2 | 基于知识图谱的问答系统,BERT做命名实体识别和句子相似度,分为online和outline模式 3 | 
-------------------------------------------------------------------------------- /Data/NER_Data/README.md: -------------------------------------------------------------------------------- 1 | # KBQA-BERT 2 | 基于知识图谱的问答系统,BERT做命名实体识别和句子相似度,分为online和outline模式 3 | -------------------------------------------------------------------------------- /Data/NLPCC2016KBQA/nlpcc-iccpol-2016.kbqa.kb: -------------------------------------------------------------------------------- 1 | 空气干燥 ||| 别名 ||| 空气干燥 2 | 空气干燥 ||| 中文名 ||| 空气干燥 3 | 空气干燥 ||| 外文名 ||| air drying 4 | 空气干燥 ||| 形式 ||| 两个 5 | 空气干燥 ||| 作用 ||| 将空气中的水份去除 6 | 罗育德 ||| 别名 ||| 罗育德 7 | 罗育德 ||| 中文名 ||| 罗育德 8 | 罗育德 ||| 民族 ||| 汉族 9 | 罗育德 ||| 出生地 ||| 河南郑州 10 | 罗育德 ||| 出生日期 ||| 1971年6月 11 | 罗育德 ||| 职业 ||| 中共深圳市南山区委常委、区政府常务副区长 12 | 罗育德 ||| 毕业院校 ||| 深圳大学 13 | 罗育德 ||| 民 族 ||| 汉族 14 | 罗育德 ||| 职 业 ||| 中共深圳市南山区委常委、区政府常务副区长 15 | 鳞 ||| 别名 ||| 鳞 16 | 鳞 ||| 中文名 ||| 鳞 17 | 鳞 ||| 拼音 ||| lín 18 | 鳞 ||| 英文名 ||| Squama 19 | 鳞 ||| 汉字笔划 ||| 20 20 | 鳞 ||| 名称 ||| 鳞 21 | 鳞 ||| 繁体 ||| 鱗 22 | 鳞 ||| 笔画 ||| 20 23 | 鳞 ||| 部首 ||| 鱼 24 | 鳞 ||| 释义 ||| 鳞 (形声。从鱼,粦声。本义:鱼甲) 同本义 鳞,鱼甲也。――《说文》 鳞罗布烈。――扬雄《羽猎赋》 25 | 浙江村 ||| 别名 ||| 浙江村 26 | 浙江村 ||| 中文名 ||| 浙江村 27 | 浙江村 ||| 主要构成 ||| 温州人 28 | 浙江村 ||| 位置 ||| 北京丰台区大红门、木樨园地区 29 | 浙江村 ||| 辐射范围 ||| 周围5公里区域 30 | 浙江村 ||| 中心 ||| 木樨园桥环岛 31 | 浙江村 ||| 从事 ||| 服装生产批发、五金电器等 32 | 于明诠 ||| 别名 ||| 于明诠 33 | 于明诠 ||| 中文名 ||| 于明诠 34 | 于明诠 ||| 国籍 ||| 中国 35 | 于明诠 ||| 出生地 ||| 山东乐陵 36 | 于明诠 ||| 出生日期 ||| 1963年生 37 | 于明诠 ||| 职业 ||| 书法家 38 | 于明诠 ||| 主要成就 ||| 获"全国百名优秀青年文艺家\ 39 | 于明诠 ||| 代表作品 ||| 《书法的传统与传统的书法》 40 | 金刚部 ||| 别名 ||| 金刚部 41 | 金刚部 ||| 中文名 ||| 金刚部 42 | 金刚部 ||| 部主 ||| 阿閦如来 43 | 金刚部 ||| 含义 ||| 能摧破诸怨敌 44 | 金刚部 ||| 部主金刚界 ||| 阿閦佛 45 | 钟飙 ||| 别名 ||| 钟飙 46 | 钟飙 ||| 中文名 ||| 钟飙 47 | 钟飙 ||| 出生地 ||| 中国重庆 48 | 钟飙 ||| 出生日期 ||| 1968年 49 | 钟飙 ||| 主要成就 ||| 以独特的构图方式和切入当代艺术的视角著称 50 | 钟飙 ||| 性别 ||| 男 51 | 钟飙 ||| 国籍 ||| 中国 52 | 钟飙 ||| 出生年月 ||| 1968年 53 | 钟飙 ||| 职业 ||| 艺术家 54 | 钟飙 ||| 毕业院校 ||| 中国美术学院 55 | 葵涌中学 ||| 别名 ||| 葵涌中学 56 | 葵涌中学 ||| 中文名 ||| 葵涌中学 57 | 
葵涌中学 ||| 创建时间 ||| 1969年 58 | 葵涌中学 ||| 占地面积 ||| 43547.26㎡ 59 | 葵涌中学 ||| 绿化占地 ||| 15865 ㎡ 60 | 汪尧田 ||| 别名 ||| 汪尧田 61 | 汪尧田 ||| 中文名 ||| 汪尧田 62 | 汪尧田 ||| 国籍 ||| 中国 63 | 汪尧田 ||| 民族 ||| 汉族 64 | 汪尧田 ||| 出生地 ||| 不详 65 | 愈裂霜 ||| 别名 ||| 愈裂霜 66 | 愈裂霜 ||| 中文名 ||| 愈裂霜 67 | 愈裂霜 ||| 药剂类型 ||| 软膏制剂 68 | 愈裂霜 ||| 主治 ||| 真菌引起的手气、脚气等 69 | 盖盖虫 ||| 别名 ||| 盖盖虫 70 | 盖盖虫 ||| 中文名 ||| 盖盖虫 71 | 盖盖虫 ||| 日文名 ||| カブルモ 72 | 盖盖虫 ||| 英文名 ||| Karrablast 73 | 盖盖虫 ||| 全国编号 ||| 588 74 | 盖盖虫 ||| 属性 ||| 虫 75 | 盖盖虫 ||| 种类 ||| 咬住神奇宝贝 76 | 盖盖虫 ||| 身高 ||| 0.5m 77 | 盖盖虫 ||| 体重 ||| 5.9kg 78 | 李冀川(成都电视台《道听途说》节目主持人) ||| 别名 ||| 李冀川 79 | 李冀川(成都电视台《道听途说》节目主持人) ||| 中文名 ||| 李冀川 80 | 李冀川(成都电视台《道听途说》节目主持人) ||| 出生日期 ||| 1981.12 81 | 李冀川(成都电视台《道听途说》节目主持人) ||| 籍贯 ||| 成都市温江区 82 | 李冀川(成都电视台《道听途说》节目主持人) ||| 学历 ||| 本科 83 | 唐本 ||| 别名 ||| 唐本 84 | 唐本 ||| 中文名 ||| 唐本 85 | 唐本 ||| 国籍 ||| 中国 86 | 唐本 ||| 民族 ||| 汉族 87 | 唐本 ||| 出生地 ||| 浙江省兰溪 88 | 诺水河镇 ||| 别名 ||| 诺水河镇 89 | 诺水河镇 ||| 中文名称 ||| 诺水河镇 90 | 诺水河镇 ||| 地理位置 ||| 东经107°27″,北纬32°16″——32°33″ 91 | 诺水河镇 ||| 面积 ||| 308.38平方公里 92 | 诺水河镇 ||| 人口 ||| 21114人 93 | 诺水河镇 ||| 中文名 ||| 诺水河镇 94 | 诺水河镇 ||| 所属国 ||| 中国 95 | 诺水河镇 ||| 所属市 ||| 巴山市 96 | 诺水河镇 ||| 诺水河镇 ||| 中华人民共和国 97 | 诺水河镇 ||| 上级行政区 ||| 通江县 98 | 诺水河镇 ||| 行政区类型 ||| 镇 99 | 诺水河镇 ||| 行政区划代码 ||| 511921113 100 | 诺水河镇 ||| 村级区划单位数 ||| 25 101 | 诺水河镇 ||| - 社区数 ||| 1 102 | 诺水河镇 ||| - 行政村数 ||| 24 103 | 诺水河镇 ||| 地理、人口、经济 ||| 32°23′19″N 107°10′39″E / 32.38857°N 107.17753°E坐标:32°23′19″N 107°10′39″E / 32.38857°N 107.17753°E 104 | 诺水河镇 ||| 电话区号 ||| +86 105 | 水冷 ||| 别名 ||| 水冷 106 | 水冷 ||| 中文名 ||| 水冷 107 | 水冷 ||| 产品 ||| 水冷散热器 108 | 水冷 ||| 分类 ||| 主动式水冷和被动式水冷 109 | 水冷 ||| 类别 ||| 散热器 110 | 水冷 ||| 材料 ||| 聚氯乙烯、聚乙烯和硅树脂三种 111 | 水冷 ||| 类似 ||| 风冷散热 112 | 水冷 ||| 分 类 ||| 主动式水冷和被动式水冷 113 | 水冷 ||| 材 料 ||| 聚氯乙烯、聚乙烯和硅树脂三种 114 | 水冷 ||| 产 品 ||| 水冷散热器 115 | 水冷 ||| 类 别 ||| 散热器 116 | 水冷 ||| 类 似 ||| 风冷散热 117 | 水冷机箱 ||| 别名 ||| 水冷机箱 118 | 水冷机箱 ||| 中文名 ||| 水冷机箱 119 | 水冷机箱 ||| 缺点 ||| 普遍体积过大,操作不够简单 120 | 水冷机箱 ||| 类型 ||| 电脑内部发热部件散热的一种装置 121 | 水冷机箱 ||| 功能 ||| 它包括水冷散热系统和防尘机箱 122 | 水冷机箱 
||| 英文名 ||| Water-cooled chassis 123 | 美少女战士R ||| 别名 ||| 美少女战士R 124 | 美少女战士R ||| 制作 ||| テレビ朝日 125 | 美少女战士R ||| 其他名称 ||| 美少女戦士 セーラームーンR 126 | 美少女战士R ||| 香港TVB首播 ||| 1995年8月26日~1996年1月21日 127 | 美少女战士R ||| 日本播放 ||| 1993年3月6日~1994年3月5日 128 | 美少女战士R ||| 香港TVB重播 ||| 2011年6月3日 129 | 美少女战士R ||| 导演 ||| 佐藤顺一、几原邦彦 130 | 美少女战士R ||| 中文名 ||| 美少女战士R 131 | 美少女战士R ||| 美少女戦士 セーラームーンR ||| 美少女戦士 セーラームーンR 132 | 美少女战士R ||| 日语假名 ||| びしょうじょせんし セーラームーンR 133 | 美少女战士R ||| 罗马字 ||| Bishōjo Senshi Sērā Mūn Aru 134 | 美少女战士R ||| 类型 ||| 魔法少女、爱情 135 | 美少女战士R ||| 作品原名 ||| 美少女戦士セーラームーンR 136 | 美少女战士R ||| 正式译名 ||| 美少女战士R 美少女战士第二部 137 | 美少女战士R ||| 原作 ||| 武内直子 138 | 美少女战士R ||| 企划 ||| 东伊里弥、太田贤司 139 | 美少女战士R ||| 剧本统筹 ||| 富田祐弘 140 | 美少女战士R ||| 编剧 ||| 富田祐弘等10人 141 | 美少女战士R ||| 人物设定 ||| 只野和子 142 | 美少女战士R ||| 动画制作 ||| 东映动画 143 | 美少女战士R ||| 播放电视台 ||| 朝日电视台 无线电视翡翠台 中华电视公司 无线电视J2台 144 | 美少女战士R ||| 播放期间 ||| 1993年3月6日-1994年3月5日 145 | 美少女战士R ||| 话数 ||| 全43话 146 | 美少女战士R ||| 版权信息 ||| ©东映动画 147 | 马丁·泰勒 ||| 别名 ||| 马丁·泰勒 148 | 马丁·泰勒 ||| 外文名 ||| MartinTaylor 149 | 马丁·泰勒 ||| 国籍 ||| 英格兰 150 | 马丁·泰勒 ||| 出生日期 ||| 1979年11月9日 151 | 马丁·泰勒 ||| 身高 ||| 193cm 152 | 马丁·泰勒 ||| 体重 ||| 89Kg 153 | 马丁·泰勒 ||| 运动项目 ||| 足球 154 | 马丁·泰勒 ||| 所属运动队 ||| 沃特福德足球俱乐部 155 | 马丁·泰勒 ||| 中文名 ||| 马丁。泰勒 156 | 马丁·泰勒 ||| 专业特点 ||| 后卫 157 | 马丁·泰勒 ||| 英文名 ||| Martin Taylor 158 | 马丁·泰勒 ||| 籍贯 ||| 英格兰阿兴顿 159 | 马丁·泰勒 ||| 性别 ||| 男 160 | 马丁·泰勒 ||| 出生年月 ||| 1979年11月9日 161 | 马丁·泰勒 ||| 马丁·泰勒(2008年摄于老特拉福德球场) ||| 马丁·泰勒(2008年摄于老特拉福德球场) 162 | 马丁·泰勒 ||| 出生 ||| 1945年09月14日(69岁) 切斯特, 柴郡, 英格兰 163 | 马丁·泰勒 ||| 教育程度 ||| 皇家文理学校(吉尔福德)(英语:Royal Grammar School, Guildford) 东英吉利亚大学 164 | 马丁·泰勒 ||| 职业 ||| 足球评论员 165 | 马丁·泰勒 ||| 雇主 ||| 天空电视台, ESPN, TSN(英语:TSN) 166 | 马丁·泰勒 ||| 个人资料 ||| Martin Taylor[1] 167 | 马丁·泰勒 ||| 出生地点 ||| 英格兰阿兴顿 168 | 马丁·泰勒 ||| 位置 ||| 后卫 169 | 马丁·泰勒 ||| 俱乐部资料 ||| 谢周三 170 | 马丁·泰勒 ||| 球衣号码 ||| 5 171 | 马丁·泰勒 ||| 青年队 ||| 克拉姆灵顿少年队 172 | 马丁·泰勒 ||| 职业俱乐部* ||| 出场 173 | 马丁·泰勒 ||| 1997–2004 ||| 88 174 | 马丁·泰勒 ||| 2000 ||| 4 175 | 马丁·泰勒 ||| 2004–2010 ||| 99 176 | 马丁·泰勒 ||| 2007 
||| 8 177 | 马丁·泰勒 ||| 2010–2012 ||| 90 178 | 马丁·泰勒 ||| 2012– ||| 10 179 | 马丁·泰勒 ||| 国家队 ||| 1 180 | 入射线 ||| 别名 ||| 入射线 181 | 入射线 ||| 中文名 ||| 入射线 182 | 入射线 ||| 学科 ||| 物理光学 183 | 入射线 ||| 类别 ||| 射线 184 | 入射线 ||| 相对 ||| 反射线 185 | 万家灯火(林兆华李六乙导演话剧) ||| 别名 ||| 万家灯火 186 | 万家灯火(林兆华李六乙导演话剧) ||| 中文名 ||| 万家灯火 187 | 万家灯火(林兆华李六乙导演话剧) ||| 导演 ||| 林兆华、李六乙 188 | 万家灯火(林兆华李六乙导演话剧) ||| 编剧 ||| 李云龙 189 | 万家灯火(林兆华李六乙导演话剧) ||| 制作 ||| 北京人民艺术剧院 190 | 万家灯火(林兆华李六乙导演话剧) ||| 上映时间 ||| 2002年 191 | 万家灯火(林兆华李六乙导演话剧) ||| 主演 ||| 宋丹丹、濮存昕、米铁增、何冰 192 | 紫屋魔恋 ||| 别名 ||| 紫屋魔恋 193 | 紫屋魔恋 ||| 中文名 ||| 紫屋魔恋 194 | 紫屋魔恋 ||| 制片人 ||| 彼得.古伯 195 | 紫屋魔恋 ||| 主演 ||| 杰克.尼科尔森 196 | 紫屋魔恋 ||| 片长 ||| 121分钟 197 | 美丽的日子(王心凌演唱专辑) ||| 别名 ||| 美丽的日子 198 | 美丽的日子(王心凌演唱专辑) ||| 中文名 ||| 美丽的日子 199 | 美丽的日子(王心凌演唱专辑) ||| 发行时间 ||| 2009年11月13日 200 | 美丽的日子(王心凌演唱专辑) ||| 地区 ||| 台湾 201 | 美丽的日子(王心凌演唱专辑) ||| 语言 ||| 普通话 202 | 美丽的日子(王心凌演唱专辑) ||| 歌手 ||| 王心凌 203 | 美丽的日子(王心凌演唱专辑) ||| 音乐风格 ||| 流行 204 | 埃尔文·约翰逊 ||| 别名 ||| 埃尔文·约翰逊 205 | 埃尔文·约翰逊 ||| 外文名 ||| EarvinJohnson 206 | 埃尔文·约翰逊 ||| 国籍 ||| 美国 207 | 埃尔文·约翰逊 ||| 出生日期 ||| 1959年08月14日 208 | 埃尔文·约翰逊 ||| 身高 ||| 2.06米/6尺9寸 209 | 埃尔文·约翰逊 ||| 体重 ||| 98公斤/215磅 210 | 埃尔文·约翰逊 ||| 运动项目 ||| 篮球 211 | 埃尔文·约翰逊 ||| 所属运动队 ||| 湖人队 212 | 埃尔文·约翰逊 ||| 中文名 ||| 埃尔文·约翰逊 213 | 埃尔文·约翰逊 ||| 出生地 ||| 密歇根 214 | 埃尔文·约翰逊 ||| 主要奖项 ||| 五届NBA总冠军 三届NBA最有价值球员奖 三届NBA总决赛最有价值球员奖 二次NBA全明星赛最有价值球员奖 215 | 埃尔文·约翰逊 ||| 粤语地区译名 ||| 埃尔文·莊逊 216 | 埃尔文·约翰逊 ||| 英文名 ||| ErvinJohnson 217 | 埃尔文·约翰逊 ||| 性别 ||| 男 218 | 埃尔文·约翰逊 ||| 出生年月 ||| 1959年8月14日 219 | 埃尔文·约翰逊 ||| 职业 ||| 篮球 220 | 埃尔文·约翰逊 ||| 体 重 ||| 98公斤/215磅 221 | 埃尔文·约翰逊 ||| 别 名 ||| Magic Johnson/魔术师-约翰逊 222 | 埃尔文·约翰逊 ||| 国 籍 ||| 美国 223 | 河北外国语职业学院 ||| 别名 ||| 河北外国语职业学院 224 | 河北外国语职业学院 ||| 中文名 ||| 河北外国语职业学院 225 | 河北外国语职业学院 ||| 英文名 ||| 英文:Hebei Vocational College of Foreign Languages韩语:하북외국어전문대학은 226 | 河北外国语职业学院 ||| 简称 ||| 河北外国语职院 227 | 河北外国语职业学院 ||| 创办时间 ||| 1948年 228 | 河北外国语职业学院 ||| 类别 ||| 公立大学 229 | 河北外国语职业学院 ||| 学校类型 ||| 语言 230 | 河北外国语职业学院 ||| 所属地区 ||| 中国秦皇岛 231 | 河北外国语职业学院 ||| 现任校长 ||| 丁国声 232 | 
河北外国语职业学院 ||| 主管部门 ||| 河北省教育厅 233 | 河北外国语职业学院 ||| 校训 ||| 德以立校 学以弘业 234 | 河北外国语职业学院 ||| 主要院系 ||| 英语系、教育系、国际商务系、酒店航空系、西语系等 235 | 河北外国语职业学院 ||| 外文名 ||| 英文:Hebei Vocational College of Foreign Languages韩语:하북외국어전문대학은 236 | 河北外国语职业学院 ||| 中文名称 ||| 河北外国语职业学院 237 | 河北外国语职业学院 ||| 校址 ||| 学院地址:河北省秦皇岛市南戴河前进路6号 238 | 河北外国语职业学院 ||| 邮编 ||| 066311 239 | 河北外国语职业学院 ||| 外文名称 ||| Hebei Vocational College of Foreign Languages 240 | 河北外国语职业学院 ||| 学校主页 ||| http://www.hbvcfl.com.cn/ 241 | 徐峥 ||| 别名 ||| 徐峥 242 | 徐峥 ||| 中文名 ||| 徐峥[13] 243 | 徐峥 ||| 外文名 ||| Xú Zhēng 244 | 徐峥 ||| 国籍 ||| 中国 245 | 徐峥 ||| 民族 ||| 汉族 246 | 徐峥 ||| 星座 ||| 白羊座 247 | 徐峥 ||| 血型 ||| A型 248 | 徐峥 ||| 身高 ||| 178cm 249 | 徐峥 ||| 体重 ||| 72kg 250 | 徐峥 ||| 出生地 ||| 上海市闸北区 251 | 徐峥 ||| 出生日期 ||| 1972年4月18日 252 | 徐峥 ||| 职业 ||| 演员、导演 253 | 徐峥 ||| 毕业院校 ||| 上海戏剧学院 254 | 徐峥 ||| 代表作品 ||| 人再囧途之泰囧、无人区、人在囧途、爱情呼叫转移、春光灿烂猪八戒、李卫当官 255 | 徐峥 ||| 主要成就 ||| 中国话剧金狮奖演员奖 北京大学生电影节最受欢迎导演 电影频道传媒大奖最受关注男演员 上海国际电影节最卖座电影 安徽卫视国剧盛典年度最佳角色 展开 256 | 徐峥 ||| 妻子 ||| 陶虹 257 | 徐峥 ||| 女儿 ||| 徐小宝 258 | 徐峥 ||| 籍贯 ||| 上海 259 | 徐峥 ||| 性别 ||| 男 260 | 徐峥 ||| 出生年月 ||| 1972年4月18日 261 | 徐峥 ||| 语言 ||| 普通话、吴语、英语 262 | 徐峥 ||| 祖籍 ||| 浙江省嘉兴市平湖市曹桥街道百寿村沈家门 263 | 徐峥 ||| 导演代表作 ||| 人再囧途之泰囧 264 | 徐峥 ||| 本名 ||| 徐峥 265 | 徐峥 ||| 出生 ||| 1972年4月18日(42岁)  中国上海市 266 | 徐峥 ||| 教育程度 ||| 上海戏剧学院90级表演系本科毕业 267 | 徐峥 ||| 配偶 ||| 陶虹(2003年-) 268 | 徐峥 ||| 儿女 ||| 一女 269 | 徐峥 ||| 活跃年代 ||| 1997年- 270 | 徐峥 ||| 经纪公司 ||| 影响力明星经纪公司 271 | 徐峥 ||| 奖项 ||| 详见下文 272 | 徐峥 ||| 电影 ||| 《爱情呼叫转移》 《疯狂的石头》 《疯狂的赛车》 《人在囧途》 《人再囧途之泰囧》 273 | 徐峥 ||| 电视剧 ||| 《春光灿烂猪八戒》 《李卫当官》 《大男当婚》 274 | 计算机应用基础 ||| 别名 ||| 计算机应用基础 275 | 计算机应用基础 ||| 中文名 ||| 计算机应用基础 276 | 计算机应用基础 ||| 作者 ||| 刘晓斌、魏智荣、刘庆生 277 | 计算机应用基础 ||| 类别 ||| 计算机/网络 > 计算机理论 278 | 计算机应用基础 ||| 字数 ||| 445000 279 | 计算机应用基础 ||| ISBN ||| 9787122177513 280 | 计算机应用基础 ||| 出版社 ||| 化学工业出版社 281 | 计算机应用基础 ||| 页数 ||| 255 282 | 计算机应用基础 ||| 开本 ||| 16开 283 | 计算机应用基础 ||| 出版时间 ||| 2013年 284 | 计算机应用基础 ||| 装帧 ||| 平装 285 | 计算机应用基础 ||| 原作者 ||| 刘升贵,黄敏,庄强兵 286 | 计算机应用基础 ||| 定价 ||| 29.00元 287 | 登山临水 ||| 别名 ||| 
登山临水 288 | 登山临水 ||| 名称 ||| 登山临水 289 | 登山临水 ||| 拼音 ||| dēng shān lín shuǐ 290 | 登山临水 ||| 出处 ||| 战国·楚·宋玉《九辩》:“登山临水兮送将归。” 291 | 登山临水 ||| 释义 ||| 形容旅途遥远。也指游山玩水。 292 | 登山临水 ||| 用法 ||| 褒义 谓语 293 | 登山临水 ||| 结构 ||| 联合式 294 | 白驹过隙 ||| 别名 ||| 白驹过隙 295 | 白驹过隙 ||| 名称 ||| 白驹过隙 296 | 石川数正 ||| 别名 ||| 石川数正 297 | 石川数正 ||| 中文名 ||| 石川数正 298 | 石川数正 ||| 外文名 ||| いしかわ かずまさ 299 | 石川数正 ||| 国籍 ||| 日本 300 | 石川数正 ||| 民族 ||| 大和民族 301 | 石川数正 ||| 出生地 ||| 大阪府 302 | 石川数正 ||| 出生日期 ||| 1533年 303 | 石川数正 ||| 逝世日期 ||| 1592年 304 | 石川数正 ||| 职业 ||| 家臣 305 | 石川数正 ||| 信仰 ||| 神道教 306 | 石川数正 ||| 主要成就 ||| 德川二重臣之一 307 | 石川数正 ||| 信 仰 ||| 神道教 308 | 石川数正 ||| 职 业 ||| 家臣 309 | 石川数正 ||| 国 籍 ||| 日本 310 | 石川数正 ||| 别 名 ||| 石川伯耆守 311 | 石川数正 ||| 民 族 ||| 大和民族 312 | 石川数正 ||| 日语写法 ||| 石川 数正 313 | 石川数正 ||| 假名 ||| いしかわ かずまさ 314 | 石川数正 ||| 平文式罗马字 ||| Ishikawa Kazumasa 315 | 净莲妖圣 ||| 别名 ||| 净莲妖圣 316 | 净莲妖圣 ||| 中文名 ||| 净莲妖圣 317 | 净莲妖圣 ||| 职业 ||| 斗圣巅峰 318 | 净莲妖圣 ||| 代表作品 ||| 斗破苍穹 319 | 净莲妖圣 ||| 异火 ||| 净莲妖火 320 | 对句作起法 ||| 别名 ||| 对句作起法 321 | 对句作起法 ||| 中文名 ||| 对句作起法 322 | 对句作起法 ||| 类属 ||| 诗歌技巧之一 323 | 对句作起法 ||| 人物 ||| 刘铁冷 324 | 对句作起法 ||| 著作 ||| 《作诗百法》 325 | 花右京女仆队 ||| 别名 ||| 花右京女仆队 326 | 花右京女仆队 ||| 中文名 ||| 花右京女仆队 327 | 花右京女仆队 ||| 原版名称 ||| 花右京メイド队 328 | 花右京女仆队 ||| 其他名称 ||| 花右京女佣队 329 | 花右京女仆队 ||| 作者 ||| 森繁(もりしげ) 330 | 花右京女仆队 ||| 类型 ||| 恋爱 331 | 花右京女仆队 ||| 地区 ||| 日本 332 | 花右京女仆队 ||| 连载杂志 ||| 少年champion 333 | 花右京女仆队 ||| 连载期间 ||| 1999年4月-2006年9月 334 | 花右京女仆队 ||| 出版社 ||| 秋田书店 335 | 花右京女仆队 ||| 单行本册数 ||| 全14卷 336 | 花右京女仆队 ||| 其他外文名 ||| 花右京メイド队 337 | 花右京女仆队 ||| 其他译名 ||| 花右京女佣队 338 | 花右京女仆队 ||| 出品时间 ||| 1999年4月-2006年9月 339 | 花右京女仆队 ||| 动画话数 ||| 全15话(3话电视没有播放)+12话 340 | 花右京女仆队 ||| 地 区 ||| 日本 341 | 花右京女仆队 ||| 外文名 ||| 花右京メイド队 342 | 花右京女仆队 ||| 播出时间 ||| 2001年(第1季)2004年(第2季) 343 | 花右京女仆队 ||| 作 者 ||| 森繁(もりしげ) 344 | 直销 ||| 别名 ||| 直销 345 | 直销 ||| 中文名 ||| 直销 346 | 直销 ||| 外文名 ||| Direct Selling 347 | 直销 ||| 归类 ||| 销售方式 348 | 直销 ||| 途径 ||| 直接向最终消费者推销产品 349 | 直销 ||| 别称 ||| 厂家直接销售 350 | 直销 ||| 形式 ||| 生产商文化 销售商文化 351 | 无尘衣 ||| 别名 ||| 无尘衣 352 | 无尘衣 ||| 中文名 
||| 无尘衣 353 | 无尘衣 ||| 类别 ||| 服装 354 | 无尘衣 ||| 国家 ||| 中国 355 | 无尘衣 ||| 适用 ||| 男女通用 356 | 异度传说 ||| 别名 ||| 异度传说 357 | 异度传说 ||| 中文名 ||| 异度传说 358 | 异度传说 ||| 游戏类别 ||| 角色扮演 359 | 异度传说 ||| 开发商 ||| MONOLITHSOFT 360 | 异度传说 ||| 发行时间 ||| 2002年4月3日(第一部) 361 | 异度传说 ||| 外文名 ||| XenoSaga 362 | 异度传说 ||| 游戏平台 ||| PS2 363 | 异度传说 ||| 发行商 ||| NAMCO 364 | 异度传说 ||| 主要角色 ||| 卯月紫苑、KOS-MOS、基奇、M.O.M.O.、凯欧斯、卯月仁等 365 | 秋和 ||| 别名 ||| 秋和 366 | 秋和 ||| 中文名 ||| 秋和 367 | 秋和 ||| 国籍 ||| 中国 368 | 秋和 ||| 出生日期 ||| 11月11日 369 | 秋和 ||| 职业 ||| 作家 370 | 秋和 ||| 代表作品 ||| 《尘埃眠于光年》 371 | 哪里 ||| 别名 ||| 哪里 372 | 哪里 ||| 中文名 ||| 哪里 373 | 哪里 ||| 外文名 ||| Where 374 | 哪里 ||| 拼音 ||| nǎ lǐ 375 | 哪里 ||| 注音 ||| ㄣㄚˇ ㄌㄧˇ 376 | 哪里 ||| 同义词 ||| 那边 那里 那处 那儿 何处 377 | 哪里 ||| 英文名 ||| where 378 | 机械能守恒定律 ||| 别名 ||| 机械能守恒定律 379 | 机械能守恒定律 ||| 中文名 ||| 机械能守恒定律 380 | 机械能守恒定律 ||| 外文名 ||| law of conservation of mechanical energy 381 | 机械能守恒定律 ||| 所属领域 ||| 动力学 382 | 机械能守恒定律 ||| 基本公式 ||| △E机=E(末)-E(初)=0 383 | 机械能守恒定律 ||| 条件 ||| 无外界能量损失 384 | 机械能守恒定律 ||| 贡献者 ||| 焦耳、迈尔和亥姆霍兹 385 | 机械能守恒定律 ||| 应用学科 ||| 物理学 386 | 机械能守恒定律 ||| 条 件 ||| 无外界能量损失 387 | 折扣优惠 ||| 别名 ||| 折扣优惠 388 | 折扣优惠 ||| 中文名 ||| 折扣优惠 389 | 折扣优惠 ||| 注音 ||| zhē kòu yōu huì 390 | 折扣优惠 ||| 相关词 ||| 降价、促销 391 | 折扣优惠 ||| 体现形式 ||| 代金券、促销活动 392 | 周易科学观 ||| 别名 ||| 周易科学观 393 | 周易科学观 ||| 书名 ||| 周易科学观 394 | 周易科学观 ||| 作者 ||| 徐道一 395 | 周易科学观 ||| ISBN ||| 9787502804596 396 | 周易科学观 ||| 页数 ||| 32开 397 | 周易科学观 ||| 出版社 ||| 地震出版社北京发行部 398 | 周易科学观 ||| 出版时间 ||| 1992-5-1 399 | 周易科学观 ||| 开本 ||| 32开 400 | 东陵少主 ||| 别名 ||| 东陵少主 401 | 东陵少主 ||| 中文名 ||| 东陵少主 402 | 东陵少主 ||| 配音 ||| 黄文择 403 | 东陵少主 ||| 登场作品 ||| 霹雳布袋戏 404 | 东陵少主 ||| 生日 ||| 5.21-6.22(双子座,会刊第149期) 405 | 东陵少主 ||| 性别 ||| 男 406 | 东陵少主 ||| 身高 ||| 184cm 407 | 东陵少主 ||| 身份 ||| 五方主星之东青龙 408 | 东陵少主 ||| 初登场 ||| 霹雳英雄榜之争王记第1集 409 | 东陵少主 ||| 退场 ||| 霹雳图腾第16集 410 | 东陵少主 ||| 根据地 ||| 怀拥天地七步阶 411 | 东陵少主 ||| 人物创作者 ||| 罗陵 412 | 金匮要略 ||| 别名 ||| 金匮要略 413 | 金匮要略 ||| 书名 ||| 活解金匮要略 414 | 金匮要略 ||| 作者 ||| (汉)张仲景著,(宋)林亿校正,杨鹏举,侯仙明,杨延巍注释 415 | 金匮要略 ||| 类别 ||| 图书 医学 中医学 经典古籍 416 | 金匮要略 
||| 页数 ||| 148页 417 | 金匮要略 ||| 定价 ||| 14.00元 418 | 金匮要略 ||| 出版社 ||| 学苑出版社 419 | 金匮要略 ||| 出版时间 ||| 2008年1月 420 | 金匮要略 ||| 书 名 ||| 活解金匮要略 421 | 金匮要略 ||| 类 别 ||| 图书 医学 中医学 经典古籍 422 | 金匮要略 ||| 定 价 ||| 14.00元 423 | 金匮要略 ||| 作 者 ||| (汉)张仲景著,(宋)林亿校正,杨鹏举,侯仙明,杨延巍注释 424 | 金匮要略 ||| 页 数 ||| 148页 425 | 祖国网 ||| 别名 ||| 祖国网 426 | 祖国网 ||| 中文名 ||| 祖国网 427 | 祖国网 ||| 成立时间 ||| 2009年2月 428 | 祖国网 ||| 口号 ||| 爱国者的精神家园 429 | 祖国网 ||| 因故关闭 ||| 2011年 430 | 祖国网 ||| 覆盖方面 ||| 政、经、军、商、史 431 | 祖国网 ||| 编辑总部 ||| 西柏坡 432 | 村官(大学生村官) ||| 别名 ||| 村官 433 | 村官(大学生村官) ||| 中文名 ||| 村官 434 | 村官(大学生村官) ||| 职务 ||| 村党支部书记助理、村主任助理 435 | 村官(大学生村官) ||| 面向 ||| 应届高校本科及以上学历毕业生 436 | 村官(大学生村官) ||| 开始于 ||| 1995年 437 | 村官(大学生村官) ||| 名称 ||| 村官 438 | 村官(大学生村官) ||| 拼音 ||| cunguan 439 | 村官(大学生村官) ||| 英文 ||| village official 440 | 村官(大学生村官) ||| 笔画 ||| 15 441 | 舜玉路街道 ||| 别名 ||| 舜玉路街道 442 | 舜玉路街道 ||| 中文名称 ||| 舜玉路街道 443 | 舜玉路街道 ||| 行政区类别 ||| 街道 444 | 舜玉路街道 ||| 所属地区 ||| 济南市 445 | 舜玉路街道 ||| 电话区号 ||| 0531 446 | 舜玉路街道 ||| 面积 ||| 4.6平方公里 447 | 舜玉路街道 ||| 人口 ||| 6.1万人 448 | 舜玉路街道 ||| 舜玉路街道 ||| 中华人民共和国 449 | 舜玉路街道 ||| 上级行政区 ||| 市中区 450 | 舜玉路街道 ||| 行政区类型 ||| 街道 451 | 舜玉路街道 ||| 行政区划代码 ||| 370103012 452 | 舜玉路街道 ||| 村级区划单位数 ||| 24 453 | 舜玉路街道 ||| - 社区数 ||| 24 454 | 舜玉路街道 ||| - 行政村数 ||| 0 455 | 舜玉路街道 ||| 地理、人口、经济 ||| 36°37′33″N 117°00′27″E / 36.62588°N 117.00746°E坐标:36°37′33″N 117°00′27″E / 36.62588°N 117.00746°E 456 | 谍影重重之上海 ||| 别名 ||| 谍影重重之上海 457 | 谍影重重之上海 ||| 中文名 ||| 《谍影重重之上海》 458 | 谍影重重之上海 ||| 出品公司 ||| 新力量影视 459 | 谍影重重之上海 ||| 导演 ||| 黄文利 460 | 谍影重重之上海 ||| 主演 ||| 李光洁, 秦海璐, 廖凡 461 | 谍影重重之上海 ||| 类型 ||| 谍战 462 | 谍影重重之上海 ||| 出品时间 ||| 2007年 463 | 谍影重重之上海 ||| 制片地区 ||| 中国大陆 464 | 谍影重重之上海 ||| 编剧 ||| 霍昕 465 | 谍影重重之上海 ||| 集数 ||| 24集 466 | 谍影重重之上海 ||| 上映时间 ||| 2009年9月5日 467 | 游珍珠泉记 ||| 别名 ||| 游珍珠泉记 468 | 游珍珠泉记 ||| 作品名称 ||| 游珍珠泉记 469 | 游珍珠泉记 ||| 创作年代 ||| 清代 470 | 游珍珠泉记 ||| 文学体裁 ||| 散文作品原文 471 | 游珍珠泉记 ||| 作者 ||| 王昶 472 | 日本电视剧 ||| 别名 ||| 日本电视剧 473 | 日本电视剧 ||| 中文名 ||| 日本电视剧 474 | 日本电视剧 ||| 外文名 ||| テレビドラマ 475 | 日本电视剧 ||| 俗称 ||| 日剧、霓虹剧 476 | 日本电视剧 ||| 语言 
||| 日语 477 | 日本电视剧 ||| 奖项 ||| 日剧学院赏、银河赏 等 478 | 日本电视剧 ||| 剧长 ||| 多数10-13集 479 | 日本电视剧 ||| 片长 ||| 40分钟到60分钟不等 480 | 日本电视剧 ||| 发行语言 ||| 日语 481 | 日本电视剧 ||| 拍摄地点 ||| 日本 482 | 日本电视剧 ||| 颜色 ||| 彩色 483 | 日本电视剧 ||| 类型 ||| 冬季剧,春季剧,夏季剧,秋季剧 484 | 日本电视剧 ||| 其它译名 ||| 日剧 霓虹剧 485 | 日本电视剧 ||| 制片地区 ||| 日本 486 | 日本电视剧 ||| 时段 ||| 月九,木十,金十,土九 487 | 平安银行 ||| 别名 ||| 平安银行 488 | 平安银行 ||| 公司名称 ||| 平安银行股份有限公司 489 | 平安银行 ||| 外文名称 ||| Ping An Bank Co., Ltd. 490 | 平安银行 ||| 总部地点 ||| 中国深圳 491 | 平安银行 ||| 成立时间 ||| 1987年 492 | 平安银行 ||| 经营范围 ||| 提供一站式的综合金融产品与服务 493 | 平安银行 ||| 公司性质 ||| 上市公司 494 | 平安银行 ||| 公司口号 ||| 平安银行 真的不一样 495 | 平安银行 ||| 总资产 ||| 18270亿元人民币(2013年6月) 496 | 平安银行 ||| 员工数 ||| 32851人(2013年6月) 497 | 平安银行 ||| 战略思想 ||| 变革 创新 发展 498 | 平安银行 ||| 经营理念 ||| 对外以客户为中心 对内以人为本 499 | 平安银行 ||| 银行特色 ||| 专业化 集约化 综合金融 500 | 平安银行 ||| 董事长 ||| 孙建一 501 | 平安银行 ||| 行长 ||| 邵平 502 | 平安银行 ||| 中文名称 ||| 平安银行股份有限公司 503 | 平安银行 ||| 行 长 ||| 邵平 504 | 平安银行 ||| 公司类型 ||| 商业银行 505 | 平安银行 ||| 股票代号 ||| 深交所:000001 506 | 平安银行 ||| 成立 ||| 1995年6月22日 507 | 平安银行 ||| 代表人物 ||| 董事长:孙建一 副总经理: 行长:邵平 监事长:邱伟 508 | 平安银行 ||| 产业 ||| 银行 509 | 平安银行 ||| 产品 ||| 存款结算、理财、信用卡等 510 | 平安银行 ||| 净利润 ||| 1,636,029,604元人民币 511 | 平安银行 ||| 母公司 ||| 中国平安集团 49.57% 512 | 平安银行 ||| 网址 ||| 平安银行股份有限公司 513 | 墓地风水 ||| 别名 ||| 墓地风水 514 | 墓地风水 ||| 中文名 ||| 墓地风水 515 | 墓地风水 ||| 外文名 ||| The graveyard of feng shui 516 | 墓地风水 ||| 学术分类 ||| 堪舆地理学 517 | 墓地风水 ||| 代表作品 ||| 《葬书》 518 | 墓地风水 ||| 代表人物 ||| 晋代的郭璞 519 | 揠苗助长(米切尔胡尔维茨导演美国电影) ||| 别名 ||| 揠苗助长 520 | 揠苗助长(米切尔胡尔维茨导演美国电影) ||| 中文名 ||| 揠苗助长 521 | 揠苗助长(米切尔胡尔维茨导演美国电影) ||| 制片地区 ||| 美国 522 | 揠苗助长(米切尔胡尔维茨导演美国电影) ||| 制片人 ||| 米切尔-胡尔维茨 523 | 揠苗助长(米切尔胡尔维茨导演美国电影) ||| 类型 ||| 喜剧 524 | 揠苗助长(米切尔胡尔维茨导演美国电影) ||| 主演 ||| 杰森-贝特曼,迈克尔-布拉斯 525 | 揠苗助长(米切尔胡尔维茨导演美国电影) ||| 语言 ||| 英语 526 | 胸神经前支 ||| 别名 ||| 胸神经前支 527 | 胸神经前支 ||| 中文名 ||| 胸神经前支 528 | 胸神经前支 ||| 外文名 ||| anterior branch of thoracic nerves 529 | 胸神经前支 ||| 数量 ||| 12对 530 | 胸神经前支 ||| 分类 ||| 肋间神经,肋下神经 531 | 胸神经前支 ||| 分布特性 ||| 节段性 532 | 胸神经前支 ||| 检查部位 ||| 胸骨角、肋骨、剑突 533 | 闭环 ||| 别名 ||| 闭环 534 | 闭环 ||| 
中文名 ||| 闭环 535 | 闭环 ||| 又叫 ||| 反馈控制系统 536 | 闭环 ||| 全称 ||| 闭环结构 537 | 闭环 ||| 原理 ||| 输出值尽量接近于期望值 538 | 闭环 ||| 应用 ||| 系统元件参数存在无法预计的变化 539 | 闭环 ||| 举例 ||| 调节水龙头、油路不畅流量下降 540 | HTC myTouch 4G Slide ||| 别名 ||| HTC myTouch 4G Slide 541 | HTC myTouch 4G Slide ||| 外文名 ||| HTC myTouch 4G Slide 542 | HTC myTouch 4G Slide ||| 网络模式 ||| GSM,WCDMA 543 | HTC myTouch 4G Slide ||| ROM容量 ||| 4GB 544 | HTC myTouch 4G Slide ||| RAM容量 ||| 768MB 545 | 三级甲等医院 ||| 别名 ||| 三级甲等医院 546 | 三级甲等医院 ||| 中文名 ||| 三级甲等医院 547 | 三级甲等医院 ||| 简称 ||| 三甲医院 548 | 三级甲等医院 ||| 类别 ||| 医院等级之一 549 | 三级甲等医院 ||| 三甲标准 ||| 病床数在501张以上、面向多地区 550 | 三级甲等医院 ||| 评选标准 ||| 按分等评分标准获得超过900分 551 | 三级甲等医院 ||| 级别 ||| 三甲医院为中国医院的最高级别 552 | 三级甲等医院 ||| 甲等标准 ||| 按分等评分标准获得超过900分 553 | 三级甲等医院 ||| 简 称 ||| 三甲医院 554 | 三级甲等医院 ||| 三级医院标准 ||| 病床数在501张以上、面向多地区 555 | 三级甲等医院 ||| 类 别 ||| 医院等级之一 556 | 不列颠尼亚(罗马帝国行省) ||| 别名 ||| 不列颠尼亚 557 | 不列颠尼亚(罗马帝国行省) ||| 中文名 ||| 不列颠尼亚 558 | 不列颠尼亚(罗马帝国行省) ||| 外文名 ||| Britannia 559 | 不列颠尼亚(罗马帝国行省) ||| 定义 ||| 不列颠岛古称;罗马帝国行省 560 | 白领 ||| 别名 ||| 白领 561 | 白领 ||| 中文名 ||| 白领 562 | 白领 ||| 英文名 ||| White-collar worker 563 | 归去来兮辞 ||| 别名 ||| 归去来兮辞 564 | 归去来兮辞 ||| 作品名称 ||| 归去来兮辞 -------------------------------------------------------------------------------- /Data/Sim_Data/README.md: -------------------------------------------------------------------------------- 1 | # KBQA-BERT 2 | 基于知识图谱的问答系统,BERT做命名实体识别和句子相似度,分为online和outline模式 3 | -------------------------------------------------------------------------------- /Data/construct_dataset.py: -------------------------------------------------------------------------------- 1 | # coding:utf-8 2 | import sys 3 | import os 4 | import pandas as pd 5 | 6 | 7 | ''' 8 | 构造NER训练集,实体序列标注,训练BERT+BiLSTM+CRF 9 | ''' 10 | # [training, testing] 11 | data_type = "training" 12 | file = "./NLPCC2016KBQA/nlpcc-iccpol-2016.kbqa."+data_type+"-data" 13 | question_str = "")[1].strip() 38 | q_str = q_str.split(">")[1].replace(" ","").strip() 39 | if entities in q_str: 40 | q_list = list(q_str) 41 | 
seq_q_list.extend(q_list) 42 | seq_q_list.extend([" "]) 43 | tag_list = ["O" for i in range(len(q_list))] 44 | tag_start_index = q_str.find(entities) 45 | for i in range(tag_start_index, tag_start_index+len(entities)): 46 | if tag_start_index == i: 47 | tag_list[i] = "B-LOC" 48 | else: 49 | tag_list[i] = "I-LOC" 50 | seq_tag_list.extend(tag_list) 51 | seq_tag_list.extend([" "]) 52 | else: 53 | pass 54 | q_t_a_list.append([q_str, t_str, a_str]) 55 | 56 | print('\t'.join(seq_tag_list[0:50])) 57 | print('\t'.join(seq_q_list[0:50])) 58 | seq_result = [str(q)+" "+tag for q, tag in zip(seq_q_list, seq_tag_list)] 59 | with open("./NER_Data/"+data_type+".txt", "w", encoding='utf-8') as f: 60 | f.write("\n".join(seq_result)) 61 | 62 | df = pd.DataFrame(q_t_a_list, columns=["q_str", "t_str", "a_str"]) 63 | df.to_csv("./NER_Data/q_t_a_df_"+data_type+".csv", encoding='utf-8', index=False) -------------------------------------------------------------------------------- /Data/construct_dataset_attribute.py: -------------------------------------------------------------------------------- 1 | # coding:utf-8 2 | import sys 3 | import os 4 | import random 5 | import pandas as pd 6 | 7 | 8 | ''' 9 | 构造属性关联训练集,分类问题,训练BERT分类模型 10 | 1 11 | ''' 12 | # [training, testing] 13 | data_type = "training" 14 | file = "nlpcc-iccpol-2016.kbqa."+data_type+"-data" 15 | target = "./NER_Data/q_t_a_df_"+data_type+".csv" 16 | 17 | attribute_classify_sample = [] 18 | 19 | # count the number of attribute 20 | testing_df = pd.read_csv(target, encoding='utf-8') 21 | testing_df['attribute'] = testing_df['t_str'].apply(lambda x: x.split('|||')[1].strip()) 22 | attribute_list = testing_df['attribute'].tolist() 23 | print(len(set(attribute_list))) 24 | print(testing_df.head()) 25 | 26 | 27 | # construct sample 28 | for row in testing_df.index: 29 | question, pos_att = testing_df.loc[row][['q_str', 'attribute']] 30 | question = question.strip() 31 | pos_att = pos_att.strip() 32 | # 
random.shuffle(attribute_list) the complex is big 33 | # neg_att_list = attribute_list[0:5] 34 | neg_att_list = random.sample(attribute_list, 5) 35 | attribute_classify_sample.append([question, pos_att, '1']) 36 | neg_att_sample = [[question, neg_att, '0'] for neg_att in neg_att_list if neg_att != pos_att] 37 | attribute_classify_sample.extend(neg_att_sample) 38 | 39 | seq_result = [str(lineno) + '\t' + '\t'.join(line) for (lineno, line) in enumerate(attribute_classify_sample)] 40 | 41 | if data_type == 'testing': 42 | with open("./Sim_Data/"+data_type+".txt", "w", encoding='utf-8') as f: 43 | f.write("\n".join(seq_result)) 44 | else: 45 | val_seq_result = seq_result[0:12000] 46 | with open("./Sim_Data/"+"val"+".txt", "w", encoding='utf-8') as f: 47 | f.write("\n".join(val_seq_result)) 48 | 49 | training_seq_result = seq_result[12000:] 50 | with open("./Sim_Data/"+data_type+".txt", "w", encoding='utf-8') as f: 51 | f.write("\n".join(training_seq_result)) 52 | -------------------------------------------------------------------------------- /Data/load_dbdata.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Time : 2019/4/18 20:47 3 | # @Author : Alan 4 | # @Email : xiezhengwen2013@163.com 5 | # @File : load_dbdata.py 6 | # @Software: PyCharm 7 | 8 | 9 | import pymysql 10 | import pandas as pd 11 | from sqlalchemy import create_engine 12 | 13 | 14 | def create_db(): 15 | connect = pymysql.connect( # 连接数据库服务器 16 | user="root", 17 | password="123456", 18 | host="127.0.0.1", 19 | port=3306, 20 | db="KB_QA", 21 | charset="utf8" 22 | ) 23 | conn = connect.cursor() # 创建操作游标 24 | # 你需要一个游标 来实现对数据库的操作相当于一条线索 25 | 26 | # 创建表 27 | conn.execute("drop database if exists KB_QA") # 如果new_database数据库存在则删除 28 | conn.execute("create database KB_QA") # 新创建一个数据库 29 | conn.execute("use KB_QA") # 选择new_database这个数据库 30 | 31 | # sql 中的内容为创建一个名为new_table的表 32 | sql = """create table nlpccQA(entity VARCHAR(20) character 
set utf8 collate utf8_unicode_ci, 33 | attribute VARCHAR(20) character set utf8 collate utf8_unicode_ci, answer VARCHAR(20) character set utf8 34 | collate utf8_unicode_ci)""" # ()中的参数可以自行设置 35 | conn.execute("drop table if exists nlpccQA") # 如果表存在则删除 36 | conn.execute(sql) # 创建表 37 | 38 | # 删除 39 | # conn.execute("drop table new_table") 40 | 41 | conn.close() # 关闭游标连接 42 | connect.close() # 关闭数据库服务器连接 释放内存 43 | 44 | 45 | def loaddata(): 46 | # 初始化数据库连接,使用pymysql模块 47 | db_info = {'user': 'root', 48 | 'password': '123456', 49 | 'host': '127.0.0.1', 50 | 'port': 3306, 51 | 'database': 'KB_QA' 52 | } 53 | 54 | engine = create_engine( 55 | 'mysql+pymysql://%(user)s:%(password)s@%(host)s:%(port)d/%(database)s?charset=utf8' % db_info, encoding='utf-8') 56 | # 直接使用下一种形式也可以 57 | # engine = create_engine('mysql+pymysql://root:123456@localhost:3306/test') 58 | 59 | # 读取本地CSV文件 60 | df = pd.read_csv("./DB_Data/clean_triple.csv", sep=',', encoding='utf-8') 61 | print(df) 62 | # 将新建的DataFrame储存为MySQL中的数据表,不储存index列(index=False) 63 | # if_exists: 64 | # 1.fail:如果表存在,啥也不做 65 | # 2.replace:如果表存在,删了表,再建立一个新表,把数据插入 66 | # 3.append:如果表存在,把数据插入,如果表不存在创建一个表!! 
67 | pd.io.sql.to_sql(df, 'nlpccQA', con=engine, index=False, if_exists='append', chunksize=10000) 68 | # df.to_sql('example', con=engine, if_exists='replace')这种形式也可以 69 | print("Write to MySQL successfully!") 70 | 71 | 72 | def upload_data(sql): 73 | connect = pymysql.connect( # 连接数据库服务器 74 | user="root", 75 | password="123456", 76 | host="127.0.0.1", 77 | port=3306, 78 | db="KB_QA", 79 | charset="utf8" 80 | ) 81 | cursor = connect.cursor() # 创建操作游标 82 | try: 83 | # 执行SQL语句 84 | cursor.execute(sql) 85 | # 获取所有记录列表 86 | results = cursor.fetchall() 87 | except Exception as e: 88 | print("Error: unable to fecth data: %s ,%s" % (repr(e), sql)) 89 | finally: 90 | # 关闭数据库连接 91 | cursor.close() 92 | connect.close() 93 | return results 94 | 95 | 96 | if __name__ == '__main__': 97 | create_db() 98 | loaddata() -------------------------------------------------------------------------------- /Data/triple_clean.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Time : 2019/4/18 20:16 3 | # @Author : Alan 4 | # @Email : xiezhengwen2013@163.com 5 | # @File : triple_clean.py 6 | # @Software: PyCharm 7 | 8 | 9 | import pandas as pd 10 | 11 | 12 | ''' 13 | 构造NER训练集,实体序列标注,训练BERT+BiLSTM+CRF 14 | ''' 15 | 16 | question_str = "")[1].strip() 38 | q_str = q_str.split(">")[1].replace(" ","").strip() 39 | if ''.join(entities.split(' ')) in q_str: 40 | clean_triple = t_str.split(">")[1].replace('\t','').replace(" ","").strip().split("|||") 41 | triple_list.append(clean_triple) 42 | else: 43 | print(entities) 44 | print(q_str) 45 | print('------------------------') 46 | 47 | df = pd.DataFrame(triple_list, columns=["entity", "attribute", "answer"]) 48 | print(df) 49 | print(df.info()) 50 | df.to_csv("./DB_Data/clean_triple.csv", encoding='utf-8', index=False) -------------------------------------------------------------------------------- /LICENSE: 
-------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 WenRichard 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /ModelParams/README.md: -------------------------------------------------------------------------------- 1 | # KBQA-BERT 2 | 基于知识图谱的问答系统,BERT做命名实体识别和句子相似度,分为online和outline模式 3 | -------------------------------------------------------------------------------- /Output/README.md: -------------------------------------------------------------------------------- 1 | # KBQA-BERT 2 | 基于知识图谱的问答系统,BERT做命名实体识别和句子相似度,分为online和outline模式 3 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # KBQA-BERT 2 | ## 基于知识图谱的问答系统,BERT做命名实体识别和句子相似度,分为online和outline模式 3 | 4 | ## Introduction 5 | 本项目主要由两个重要的点组成,一是**基于BERT的命名实体识别**,二是**基于BERT的句子相似度计算**,本项目将这两个模块进行融合,构建基于BERT的KBQA问答系统,在命名实体识别上分为online predict和outline predict;在句子相似度上,也分为online predict和outline predict,2个模块互不干扰,做到了高内聚低耦合的效果,最后的kbqa相当于融合这2个模块进行outline predict,具体介绍请见[我的知乎专栏](https://zhuanlan.zhihu.com/p/62946533)! 6 | 7 | ### ------------------------------------------- 2019/6/15 更新 ---------------------------------------- 8 | **把过去一段时间同学们遇到的主要问题汇总一下,下面是一些FAQ:** 9 | 10 | **Q:** 运行run_ner.py时未找到dev.txt,请问这个文件是怎么生成的呢? 11 | **A:** 这一部分我记得当初是没有足够多的数据,我把生成的test.txt copy, 改成dev.txt了。 12 | 13 | **Q:** 你好,我下载了你的项目,但在运行run_ner的时候总是会卡在Saving checkpoint 0 to....这里,请问是什么原因呢? 14 | **A:** ner部分是存在一些问题,我也没有解决,但是我没有遇到这种情况。微调bert大概需要12GB左右的显存,大家可以把batch_size和max_length调小一点,说不定能解决这个问题。 15 | 16 | **Q:** 该项目有没有相应的论文呢? 17 | **A:** 回答是肯定的,有的,送上 [**论文传送门**!](http://www.cnki.com.cn/Article/CJFDTotal-DLXZ201705041.htm) 18 | 19 | **Q:** 数据下载失败,或者现有数据不够用,怎么办?
20 | **A:** 数据在Data中,更多的数据在[**NLPCC2016**](http://tcci.ccf.org.cn/conference/2016/pages/page05_evadata.html) 和[**NLPCC2017**](http://tcci.ccf.org.cn/conference/2017/taskdata.php)。 21 | 22 | **PS:这个项目有很多需要提高的地方,如果大家有好点子,欢迎pull,感谢!这段时间发论文找工作比较忙,邮件和issue没有及时回复望见谅!** 23 | ### ------------------------------------------- 2019/6/15 更新 ---------------------------------------- 24 | ### 环境配置 25 | 26 | Python版本为3.6 27 | tensorflow版本为1.13 28 | XAMPP版本为3.3.2 29 | Navicat Premium12 30 | 31 | ### 目录说明 32 | 33 | bert文件夹是google官方下载的 34 | Data文件夹存放原始数据和处理好的数据 35 | construct_dataset.py 生成NER_Data的数据 36 | construct_dataset_attribute.py 生成Sim_Data的数据 37 | triple_clean.py 生成三元组数据 38 | load_dbdata.py 将数据导入mysql db 39 | ModelParams文件夹需要下载BERT的中文配置文件:chinese_L-12_H-768_A-12 40 | Output文件夹存放输出的数据 41 | 42 | 基于BERT的命名实体识别模块 43 | - lstm_crf_layer.py 44 | - run_ner.py 45 | - tf_metrics.py 46 | - conlleval.py 47 | - conlleval.pl 48 | - run_ner.sh 49 | 50 | 基于BERT的句子相似度计算模块 51 | - args.py 52 | - run_similarity.py 53 | 54 | KBQA模块 55 | - terminal_predict.py 56 | - terminal_ner.sh 57 | - kbqa_test.py 58 | 59 | ### 使用说明 60 | 61 | - run_ner.sh 62 | NER训练和调参 63 | 64 | - terminal_ner.sh 65 | do_predict_online=True NER线上预测 66 | do_predict_outline=True NER线下预测 67 | 68 | - args.py 69 | train = True 预训练模型 70 | test = True SIM线上测试 71 | 72 | - run_similarity.py 73 | python run一下就可以啦 74 | 75 | - kbqa_test.py 76 | 基于KB的问答测试 77 | 78 | ### 实验分析 79 | ![NER图]( https://github.com/WenRichard/KBQA-BERT/raw/master/image/NER.jpg "分析图") 80 | 81 | ![kb图]( https://github.com/WenRichard/KBQA-BERT/raw/master/image/KB.png "分析图") 82 | 83 | **Cite** 84 | 如果你在研究中使用了KBQA-BERT,请按如下格式引用: 85 | 86 | ``` 87 | @software{KBQA-BERT, 88 | author = {ZhengWen Xie}, 89 | title = {KBQA-BERT: Knowledge Base Question Answering based BERT}, 90 | year = {2019}, 91 | url = {https://github.com/WenRichard/KBQA-BERT}, 92 | } 93 | ``` 94 | 95 | -------------------------------------------------------------- 96 | 
**如果觉得我的工作对您有帮助,请不要吝啬右上角的小星星哦!欢迎Fork和Star!也欢迎一起建设这个项目!** 97 | **有时间就会更新问答相关项目,有兴趣的同学可以follow一下** 98 | **留言请在Issues或者email richardxie1205@gmail.com** 99 | 100 | -------------------------------------------------------------------------------- /args.py: -------------------------------------------------------------------------------- 1 | import os 2 | import tensorflow as tf 3 | 4 | tf.logging.set_verbosity(tf.logging.INFO) 5 | 6 | file_path = os.path.dirname(__file__) 7 | 8 | model_dir = os.path.join(file_path, 'ModelParams/chinese_L-12_H-768_A-12/') 9 | config_name = os.path.join(model_dir, 'bert_config.json') 10 | ckpt_name = os.path.join(model_dir, 'bert_model.ckpt') 11 | output_dir = os.path.join(file_path, 'Output/SIM/result/') 12 | vocab_file = os.path.join(model_dir, 'vocab.txt') 13 | data_dir = os.path.join(file_path, 'Data/Sim_Data/') 14 | 15 | num_train_epochs = 2 16 | batch_size = 128 17 | learning_rate = 0.00005 18 | 19 | # GPU显存占用比例 20 | gpu_memory_fraction = 0.8 21 | 22 | # 默认取倒数第二层的输出值作为句向量 23 | layer_indexes = [-2] 24 | 25 | # 序列的最大长度,单文本建议把该值调小 26 | max_seq_len = 32 27 | 28 | # 预训练模型 29 | train = True 30 | 31 | # 测试模型 32 | test = False 33 | -------------------------------------------------------------------------------- /bert/CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # How to Contribute 2 | 3 | BERT needs to maintain permanent compatibility with the pre-trained model files, 4 | so we do not plan to make any major changes to this library (other than what was 5 | promised in the README). However, we can accept small patches related to 6 | re-factoring and documentation. To submit contributions, there are just a few 7 | small guidelines you need to follow. 8 | 9 | ## Contributor License Agreement 10 | 11 | Contributions to this project must be accompanied by a Contributor License 12 | Agreement.
You (or your employer) retain the copyright to your contribution; 13 | this simply gives us permission to use and redistribute your contributions as 14 | part of the project. Head over to <https://cla.developers.google.com/> to see 15 | your current agreements on file or to sign a new one. 16 | 17 | You generally only need to submit a CLA once, so if you've already submitted one 18 | (even if it was for a different project), you probably don't need to do it 19 | again. 20 | 21 | ## Code reviews 22 | 23 | All submissions, including submissions by project members, require review. We 24 | use GitHub pull requests for this purpose. Consult 25 | [GitHub Help](https://help.github.com/articles/about-pull-requests/) for more 26 | information on using pull requests. 27 | 28 | ## Community Guidelines 29 | 30 | This project follows 31 | [Google's Open Source Community Guidelines](https://opensource.google.com/conduct/). 32 | -------------------------------------------------------------------------------- /bert/LICENSE: -------------------------------------------------------------------------------- 1 | 2 | Apache License 3 | Version 2.0, January 2004 4 | http://www.apache.org/licenses/ 5 | 6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 7 | 8 | 1. Definitions. 9 | 10 | "License" shall mean the terms and conditions for use, reproduction, 11 | and distribution as defined by Sections 1 through 9 of this document. 12 | 13 | "Licensor" shall mean the copyright owner or entity authorized by 14 | the copyright owner that is granting the License. 15 | 16 | "Legal Entity" shall mean the union of the acting entity and all 17 | other entities that control, are controlled by, or are under common 18 | control with that entity.
For the purposes of this definition, 19 | "control" means (i) the power, direct or indirect, to cause the 20 | direction or management of such entity, whether by contract or 21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 22 | outstanding shares, or (iii) beneficial ownership of such entity. 23 | 24 | "You" (or "Your") shall mean an individual or Legal Entity 25 | exercising permissions granted by this License. 26 | 27 | "Source" form shall mean the preferred form for making modifications, 28 | including but not limited to software source code, documentation 29 | source, and configuration files. 30 | 31 | "Object" form shall mean any form resulting from mechanical 32 | transformation or translation of a Source form, including but 33 | not limited to compiled object code, generated documentation, 34 | and conversions to other media types. 35 | 36 | "Work" shall mean the work of authorship, whether in Source or 37 | Object form, made available under the License, as indicated by a 38 | copyright notice that is included in or attached to the work 39 | (an example is provided in the Appendix below). 40 | 41 | "Derivative Works" shall mean any work, whether in Source or Object 42 | form, that is based on (or derived from) the Work and for which the 43 | editorial revisions, annotations, elaborations, or other modifications 44 | represent, as a whole, an original work of authorship. For the purposes 45 | of this License, Derivative Works shall not include works that remain 46 | separable from, or merely link (or bind by name) to the interfaces of, 47 | the Work and Derivative Works thereof. 
48 | 49 | "Contribution" shall mean any work of authorship, including 50 | the original version of the Work and any modifications or additions 51 | to that Work or Derivative Works thereof, that is intentionally 52 | submitted to Licensor for inclusion in the Work by the copyright owner 53 | or by an individual or Legal Entity authorized to submit on behalf of 54 | the copyright owner. For the purposes of this definition, "submitted" 55 | means any form of electronic, verbal, or written communication sent 56 | to the Licensor or its representatives, including but not limited to 57 | communication on electronic mailing lists, source code control systems, 58 | and issue tracking systems that are managed by, or on behalf of, the 59 | Licensor for the purpose of discussing and improving the Work, but 60 | excluding communication that is conspicuously marked or otherwise 61 | designated in writing by the copyright owner as "Not a Contribution." 62 | 63 | "Contributor" shall mean Licensor and any individual or Legal Entity 64 | on behalf of whom a Contribution has been received by Licensor and 65 | subsequently incorporated within the Work. 66 | 67 | 2. Grant of Copyright License. Subject to the terms and conditions of 68 | this License, each Contributor hereby grants to You a perpetual, 69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 70 | copyright license to reproduce, prepare Derivative Works of, 71 | publicly display, publicly perform, sublicense, and distribute the 72 | Work and such Derivative Works in Source or Object form. 73 | 74 | 3. Grant of Patent License. 
Subject to the terms and conditions of 75 | this License, each Contributor hereby grants to You a perpetual, 76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 77 | (except as stated in this section) patent license to make, have made, 78 | use, offer to sell, sell, import, and otherwise transfer the Work, 79 | where such license applies only to those patent claims licensable 80 | by such Contributor that are necessarily infringed by their 81 | Contribution(s) alone or by combination of their Contribution(s) 82 | with the Work to which such Contribution(s) was submitted. If You 83 | institute patent litigation against any entity (including a 84 | cross-claim or counterclaim in a lawsuit) alleging that the Work 85 | or a Contribution incorporated within the Work constitutes direct 86 | or contributory patent infringement, then any patent licenses 87 | granted to You under this License for that Work shall terminate 88 | as of the date such litigation is filed. 89 | 90 | 4. Redistribution. 
You may reproduce and distribute copies of the 91 | Work or Derivative Works thereof in any medium, with or without 92 | modifications, and in Source or Object form, provided that You 93 | meet the following conditions: 94 | 95 | (a) You must give any other recipients of the Work or 96 | Derivative Works a copy of this License; and 97 | 98 | (b) You must cause any modified files to carry prominent notices 99 | stating that You changed the files; and 100 | 101 | (c) You must retain, in the Source form of any Derivative Works 102 | that You distribute, all copyright, patent, trademark, and 103 | attribution notices from the Source form of the Work, 104 | excluding those notices that do not pertain to any part of 105 | the Derivative Works; and 106 | 107 | (d) If the Work includes a "NOTICE" text file as part of its 108 | distribution, then any Derivative Works that You distribute must 109 | include a readable copy of the attribution notices contained 110 | within such NOTICE file, excluding those notices that do not 111 | pertain to any part of the Derivative Works, in at least one 112 | of the following places: within a NOTICE text file distributed 113 | as part of the Derivative Works; within the Source form or 114 | documentation, if provided along with the Derivative Works; or, 115 | within a display generated by the Derivative Works, if and 116 | wherever such third-party notices normally appear. The contents 117 | of the NOTICE file are for informational purposes only and 118 | do not modify the License. You may add Your own attribution 119 | notices within Derivative Works that You distribute, alongside 120 | or as an addendum to the NOTICE text from the Work, provided 121 | that such additional attribution notices cannot be construed 122 | as modifying the License. 
123 | 124 | You may add Your own copyright statement to Your modifications and 125 | may provide additional or different license terms and conditions 126 | for use, reproduction, or distribution of Your modifications, or 127 | for any such Derivative Works as a whole, provided Your use, 128 | reproduction, and distribution of the Work otherwise complies with 129 | the conditions stated in this License. 130 | 131 | 5. Submission of Contributions. Unless You explicitly state otherwise, 132 | any Contribution intentionally submitted for inclusion in the Work 133 | by You to the Licensor shall be under the terms and conditions of 134 | this License, without any additional terms or conditions. 135 | Notwithstanding the above, nothing herein shall supersede or modify 136 | the terms of any separate license agreement you may have executed 137 | with Licensor regarding such Contributions. 138 | 139 | 6. Trademarks. This License does not grant permission to use the trade 140 | names, trademarks, service marks, or product names of the Licensor, 141 | except as required for reasonable and customary use in describing the 142 | origin of the Work and reproducing the content of the NOTICE file. 143 | 144 | 7. Disclaimer of Warranty. Unless required by applicable law or 145 | agreed to in writing, Licensor provides the Work (and each 146 | Contributor provides its Contributions) on an "AS IS" BASIS, 147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 148 | implied, including, without limitation, any warranties or conditions 149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 150 | PARTICULAR PURPOSE. You are solely responsible for determining the 151 | appropriateness of using or redistributing the Work and assume any 152 | risks associated with Your exercise of permissions under this License. 153 | 154 | 8. Limitation of Liability. 
In no event and under no legal theory, 155 | whether in tort (including negligence), contract, or otherwise, 156 | unless required by applicable law (such as deliberate and grossly 157 | negligent acts) or agreed to in writing, shall any Contributor be 158 | liable to You for damages, including any direct, indirect, special, 159 | incidental, or consequential damages of any character arising as a 160 | result of this License or out of the use or inability to use the 161 | Work (including but not limited to damages for loss of goodwill, 162 | work stoppage, computer failure or malfunction, or any and all 163 | other commercial damages or losses), even if such Contributor 164 | has been advised of the possibility of such damages. 165 | 166 | 9. Accepting Warranty or Additional Liability. While redistributing 167 | the Work or Derivative Works thereof, You may choose to offer, 168 | and charge a fee for, acceptance of support, warranty, indemnity, 169 | or other liability obligations and/or rights consistent with this 170 | License. However, in accepting such obligations, You may act only 171 | on Your own behalf and on Your sole responsibility, not on behalf 172 | of any other Contributor, and only if You agree to indemnify, 173 | defend, and hold each Contributor harmless for any liability 174 | incurred by, or claims asserted against, such Contributor by reason 175 | of your accepting any such warranty or additional liability. 176 | 177 | END OF TERMS AND CONDITIONS 178 | 179 | APPENDIX: How to apply the Apache License to your work. 180 | 181 | To apply the Apache License to your work, attach the following 182 | boilerplate notice, with the fields enclosed by brackets "[]" 183 | replaced with your own identifying information. (Don't include 184 | the brackets!) The text should be enclosed in the appropriate 185 | comment syntax for the file format. 
We also recommend that a 186 | file or class name and description of purpose be included on the 187 | same "printed page" as the copyright notice for easier 188 | identification within third-party archives. 189 | 190 | Copyright [yyyy] [name of copyright owner] 191 | 192 | Licensed under the Apache License, Version 2.0 (the "License"); 193 | you may not use this file except in compliance with the License. 194 | You may obtain a copy of the License at 195 | 196 | http://www.apache.org/licenses/LICENSE-2.0 197 | 198 | Unless required by applicable law or agreed to in writing, software 199 | distributed under the License is distributed on an "AS IS" BASIS, 200 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 201 | See the License for the specific language governing permissions and 202 | limitations under the License. 203 | -------------------------------------------------------------------------------- /bert/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WenRichard/KBQA-BERT/6e7079ede8979b1179600564ed9ecc58c7a4f877/bert/__init__.py -------------------------------------------------------------------------------- /bert/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WenRichard/KBQA-BERT/6e7079ede8979b1179600564ed9ecc58c7a4f877/bert/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /bert/__pycache__/modeling.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WenRichard/KBQA-BERT/6e7079ede8979b1179600564ed9ecc58c7a4f877/bert/__pycache__/modeling.cpython-36.pyc -------------------------------------------------------------------------------- /bert/__pycache__/optimization.cpython-36.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/WenRichard/KBQA-BERT/6e7079ede8979b1179600564ed9ecc58c7a4f877/bert/__pycache__/optimization.cpython-36.pyc -------------------------------------------------------------------------------- /bert/__pycache__/tokenization.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WenRichard/KBQA-BERT/6e7079ede8979b1179600564ed9ecc58c7a4f877/bert/__pycache__/tokenization.cpython-36.pyc -------------------------------------------------------------------------------- /bert/create_pretraining_data.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | """Create masked LM/next sentence masked_lm TF examples for BERT.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import collections 22 | import random 23 | 24 | import tokenization 25 | import tensorflow as tf 26 | 27 | flags = tf.flags 28 | 29 | FLAGS = flags.FLAGS 30 | 31 | flags.DEFINE_string("input_file", None, 32 | "Input raw text file (or comma-separated list of files).") 33 | 34 | flags.DEFINE_string( 35 | "output_file", None, 36 | "Output TF example file (or comma-separated list of files).") 37 | 38 | flags.DEFINE_string("vocab_file", None, 39 | "The vocabulary file that the BERT model was trained on.") 40 | 41 | flags.DEFINE_bool( 42 | "do_lower_case", True, 43 | "Whether to lower case the input text. Should be True for uncased " 44 | "models and False for cased models.") 45 | 46 | flags.DEFINE_integer("max_seq_length", 128, "Maximum sequence length.") 47 | 48 | flags.DEFINE_integer("max_predictions_per_seq", 20, 49 | "Maximum number of masked LM predictions per sequence.") 50 | 51 | flags.DEFINE_integer("random_seed", 12345, "Random seed for data generation.") 52 | 53 | flags.DEFINE_integer( 54 | "dupe_factor", 10, 55 | "Number of times to duplicate the input data (with different masks).") 56 | 57 | flags.DEFINE_float("masked_lm_prob", 0.15, "Masked LM probability.") 58 | 59 | flags.DEFINE_float( 60 | "short_seq_prob", 0.1, 61 | "Probability of creating sequences which are shorter than the " 62 | "maximum length.") 63 | 64 | 65 | class TrainingInstance(object): 66 | """A single training instance (sentence pair).""" 67 | 68 | def __init__(self, tokens, segment_ids, masked_lm_positions, masked_lm_labels, 69 | is_random_next): 70 | self.tokens = tokens 71 | self.segment_ids = segment_ids 72 | self.is_random_next = is_random_next 73 | self.masked_lm_positions = masked_lm_positions 74 | self.masked_lm_labels = masked_lm_labels 75 | 76 | def __str__(self): 77 | s 
= "" 78 | s += "tokens: %s\n" % (" ".join( 79 | [tokenization.printable_text(x) for x in self.tokens])) 80 | s += "segment_ids: %s\n" % (" ".join([str(x) for x in self.segment_ids])) 81 | s += "is_random_next: %s\n" % self.is_random_next 82 | s += "masked_lm_positions: %s\n" % (" ".join( 83 | [str(x) for x in self.masked_lm_positions])) 84 | s += "masked_lm_labels: %s\n" % (" ".join( 85 | [tokenization.printable_text(x) for x in self.masked_lm_labels])) 86 | s += "\n" 87 | return s 88 | 89 | def __repr__(self): 90 | return self.__str__() 91 | 92 | 93 | def write_instance_to_example_files(instances, tokenizer, max_seq_length, 94 | max_predictions_per_seq, output_files): 95 | """Create TF example files from `TrainingInstance`s.""" 96 | writers = [] 97 | for output_file in output_files: 98 | writers.append(tf.python_io.TFRecordWriter(output_file)) 99 | 100 | writer_index = 0 101 | 102 | total_written = 0 103 | for (inst_index, instance) in enumerate(instances): 104 | input_ids = tokenizer.convert_tokens_to_ids(instance.tokens) 105 | input_mask = [1] * len(input_ids) 106 | segment_ids = list(instance.segment_ids) 107 | assert len(input_ids) <= max_seq_length 108 | 109 | while len(input_ids) < max_seq_length: 110 | input_ids.append(0) 111 | input_mask.append(0) 112 | segment_ids.append(0) 113 | 114 | assert len(input_ids) == max_seq_length 115 | assert len(input_mask) == max_seq_length 116 | assert len(segment_ids) == max_seq_length 117 | 118 | masked_lm_positions = list(instance.masked_lm_positions) 119 | masked_lm_ids = tokenizer.convert_tokens_to_ids(instance.masked_lm_labels) 120 | masked_lm_weights = [1.0] * len(masked_lm_ids) 121 | 122 | while len(masked_lm_positions) < max_predictions_per_seq: 123 | masked_lm_positions.append(0) 124 | masked_lm_ids.append(0) 125 | masked_lm_weights.append(0.0) 126 | 127 | next_sentence_label = 1 if instance.is_random_next else 0 128 | 129 | features = collections.OrderedDict() 130 | features["input_ids"] = 
create_int_feature(input_ids) 131 | features["input_mask"] = create_int_feature(input_mask) 132 | features["segment_ids"] = create_int_feature(segment_ids) 133 | features["masked_lm_positions"] = create_int_feature(masked_lm_positions) 134 | features["masked_lm_ids"] = create_int_feature(masked_lm_ids) 135 | features["masked_lm_weights"] = create_float_feature(masked_lm_weights) 136 | features["next_sentence_labels"] = create_int_feature([next_sentence_label]) 137 | 138 | tf_example = tf.train.Example(features=tf.train.Features(feature=features)) 139 | 140 | writers[writer_index].write(tf_example.SerializeToString()) 141 | writer_index = (writer_index + 1) % len(writers) 142 | 143 | total_written += 1 144 | 145 | if inst_index < 20: 146 | tf.logging.info("*** Example ***") 147 | tf.logging.info("tokens: %s" % " ".join( 148 | [tokenization.printable_text(x) for x in instance.tokens])) 149 | 150 | for feature_name in features.keys(): 151 | feature = features[feature_name] 152 | values = [] 153 | if feature.int64_list.value: 154 | values = feature.int64_list.value 155 | elif feature.float_list.value: 156 | values = feature.float_list.value 157 | tf.logging.info( 158 | "%s: %s" % (feature_name, " ".join([str(x) for x in values]))) 159 | 160 | for writer in writers: 161 | writer.close() 162 | 163 | tf.logging.info("Wrote %d total instances", total_written) 164 | 165 | 166 | def create_int_feature(values): 167 | feature = tf.train.Feature(int64_list=tf.train.Int64List(value=list(values))) 168 | return feature 169 | 170 | 171 | def create_float_feature(values): 172 | feature = tf.train.Feature(float_list=tf.train.FloatList(value=list(values))) 173 | return feature 174 | 175 | 176 | def create_training_instances(input_files, tokenizer, max_seq_length, 177 | dupe_factor, short_seq_prob, masked_lm_prob, 178 | max_predictions_per_seq, rng): 179 | """Create `TrainingInstance`s from raw text.""" 180 | all_documents = [[]] 181 | 182 | # Input file format: 183 | # (1) One 
sentence per line. These should ideally be actual sentences, not 184 | # entire paragraphs or arbitrary spans of text. (Because we use the 185 | # sentence boundaries for the "next sentence prediction" task). 186 | # (2) Blank lines between documents. Document boundaries are needed so 187 | # that the "next sentence prediction" task doesn't span between documents. 188 | for input_file in input_files: 189 | with tf.gfile.GFile(input_file, "r") as reader: 190 | while True: 191 | line = tokenization.convert_to_unicode(reader.readline()) 192 | if not line: 193 | break 194 | line = line.strip() 195 | 196 | # Empty lines are used as document delimiters 197 | if not line: 198 | all_documents.append([]) 199 | tokens = tokenizer.tokenize(line) 200 | if tokens: 201 | all_documents[-1].append(tokens) 202 | 203 | # Remove empty documents 204 | all_documents = [x for x in all_documents if x] 205 | rng.shuffle(all_documents) 206 | 207 | vocab_words = list(tokenizer.vocab.keys()) 208 | instances = [] 209 | for _ in range(dupe_factor): 210 | for document_index in range(len(all_documents)): 211 | instances.extend( 212 | create_instances_from_document( 213 | all_documents, document_index, max_seq_length, short_seq_prob, 214 | masked_lm_prob, max_predictions_per_seq, vocab_words, rng)) 215 | 216 | rng.shuffle(instances) 217 | return instances 218 | 219 | 220 | def create_instances_from_document( 221 | all_documents, document_index, max_seq_length, short_seq_prob, 222 | masked_lm_prob, max_predictions_per_seq, vocab_words, rng): 223 | """Creates `TrainingInstance`s for a single document.""" 224 | document = all_documents[document_index] 225 | 226 | # Account for [CLS], [SEP], [SEP] 227 | max_num_tokens = max_seq_length - 3 228 | 229 | # We *usually* want to fill up the entire sequence since we are padding 230 | # to `max_seq_length` anyways, so short sequences are generally wasted 231 | # computation. 
However, we *sometimes* 232 | # (i.e., short_seq_prob == 0.1 == 10% of the time) want to use shorter 233 | # sequences to minimize the mismatch between pre-training and fine-tuning. 234 | # The `target_seq_length` is just a rough target however, whereas 235 | # `max_seq_length` is a hard limit. 236 | target_seq_length = max_num_tokens 237 | if rng.random() < short_seq_prob: 238 | target_seq_length = rng.randint(2, max_num_tokens) 239 | 240 | # We DON'T just concatenate all of the tokens from a document into a long 241 | # sequence and choose an arbitrary split point because this would make the 242 | # next sentence prediction task too easy. Instead, we split the input into 243 | # segments "A" and "B" based on the actual "sentences" provided by the user 244 | # input. 245 | instances = [] 246 | current_chunk = [] 247 | current_length = 0 248 | i = 0 249 | while i < len(document): 250 | segment = document[i] 251 | current_chunk.append(segment) 252 | current_length += len(segment) 253 | if i == len(document) - 1 or current_length >= target_seq_length: 254 | if current_chunk: 255 | # `a_end` is how many segments from `current_chunk` go into the `A` 256 | # (first) sentence. 257 | a_end = 1 258 | if len(current_chunk) >= 2: 259 | a_end = rng.randint(1, len(current_chunk) - 1) 260 | 261 | tokens_a = [] 262 | for j in range(a_end): 263 | tokens_a.extend(current_chunk[j]) 264 | 265 | tokens_b = [] 266 | # Random next 267 | is_random_next = False 268 | if len(current_chunk) == 1 or rng.random() < 0.5: 269 | is_random_next = True 270 | target_b_length = target_seq_length - len(tokens_a) 271 | 272 | # This should rarely go for more than one iteration for large 273 | # corpora. However, just to be careful, we try to make sure that 274 | # the random document is not the same as the document 275 | # we're processing. 
276 | for _ in range(10): 277 | random_document_index = rng.randint(0, len(all_documents) - 1) 278 | if random_document_index != document_index: 279 | break 280 | 281 | random_document = all_documents[random_document_index] 282 | random_start = rng.randint(0, len(random_document) - 1) 283 | for j in range(random_start, len(random_document)): 284 | tokens_b.extend(random_document[j]) 285 | if len(tokens_b) >= target_b_length: 286 | break 287 | # We didn't actually use these segments so we "put them back" so 288 | # they don't go to waste. 289 | num_unused_segments = len(current_chunk) - a_end 290 | i -= num_unused_segments 291 | # Actual next 292 | else: 293 | is_random_next = False 294 | for j in range(a_end, len(current_chunk)): 295 | tokens_b.extend(current_chunk[j]) 296 | truncate_seq_pair(tokens_a, tokens_b, max_num_tokens, rng) 297 | 298 | assert len(tokens_a) >= 1 299 | assert len(tokens_b) >= 1 300 | 301 | tokens = [] 302 | segment_ids = [] 303 | tokens.append("[CLS]") 304 | segment_ids.append(0) 305 | for token in tokens_a: 306 | tokens.append(token) 307 | segment_ids.append(0) 308 | 309 | tokens.append("[SEP]") 310 | segment_ids.append(0) 311 | 312 | for token in tokens_b: 313 | tokens.append(token) 314 | segment_ids.append(1) 315 | tokens.append("[SEP]") 316 | segment_ids.append(1) 317 | 318 | (tokens, masked_lm_positions, 319 | masked_lm_labels) = create_masked_lm_predictions( 320 | tokens, masked_lm_prob, max_predictions_per_seq, vocab_words, rng) 321 | instance = TrainingInstance( 322 | tokens=tokens, 323 | segment_ids=segment_ids, 324 | is_random_next=is_random_next, 325 | masked_lm_positions=masked_lm_positions, 326 | masked_lm_labels=masked_lm_labels) 327 | instances.append(instance) 328 | current_chunk = [] 329 | current_length = 0 330 | i += 1 331 | 332 | return instances 333 | 334 | 335 | def create_masked_lm_predictions(tokens, masked_lm_prob, 336 | max_predictions_per_seq, vocab_words, rng): 337 | """Creates the predictions for the masked LM 
objective.""" 338 | 339 | cand_indexes = [] 340 | for (i, token) in enumerate(tokens): 341 | if token == "[CLS]" or token == "[SEP]": 342 | continue 343 | cand_indexes.append(i) 344 | 345 | rng.shuffle(cand_indexes) 346 | 347 | output_tokens = list(tokens) 348 | 349 | masked_lm = collections.namedtuple("masked_lm", ["index", "label"]) # pylint: disable=invalid-name 350 | 351 | num_to_predict = min(max_predictions_per_seq, 352 | max(1, int(round(len(tokens) * masked_lm_prob)))) 353 | 354 | masked_lms = [] 355 | covered_indexes = set() 356 | for index in cand_indexes: 357 | if len(masked_lms) >= num_to_predict: 358 | break 359 | if index in covered_indexes: 360 | continue 361 | covered_indexes.add(index) 362 | 363 | masked_token = None 364 | # 80% of the time, replace with [MASK] 365 | if rng.random() < 0.8: 366 | masked_token = "[MASK]" 367 | else: 368 | # 10% of the time, keep original 369 | if rng.random() < 0.5: 370 | masked_token = tokens[index] 371 | # 10% of the time, replace with random word 372 | else: 373 | masked_token = vocab_words[rng.randint(0, len(vocab_words) - 1)] 374 | 375 | output_tokens[index] = masked_token 376 | 377 | masked_lms.append(masked_lm(index=index, label=tokens[index])) 378 | 379 | masked_lms = sorted(masked_lms, key=lambda x: x.index) 380 | 381 | masked_lm_positions = [] 382 | masked_lm_labels = [] 383 | for p in masked_lms: 384 | masked_lm_positions.append(p.index) 385 | masked_lm_labels.append(p.label) 386 | 387 | return (output_tokens, masked_lm_positions, masked_lm_labels) 388 | 389 | 390 | def truncate_seq_pair(tokens_a, tokens_b, max_num_tokens, rng): 391 | """Truncates a pair of sequences to a maximum sequence length.""" 392 | while True: 393 | total_length = len(tokens_a) + len(tokens_b) 394 | if total_length <= max_num_tokens: 395 | break 396 | 397 | trunc_tokens = tokens_a if len(tokens_a) > len(tokens_b) else tokens_b 398 | assert len(trunc_tokens) >= 1 399 | 400 | # We want to sometimes truncate from the front and 
sometimes from the 401 | # back to add more randomness and avoid biases. 402 | if rng.random() < 0.5: 403 | del trunc_tokens[0] 404 | else: 405 | trunc_tokens.pop() 406 | 407 | 408 | def main(_): 409 | tf.logging.set_verbosity(tf.logging.INFO) 410 | 411 | tokenizer = tokenization.FullTokenizer( 412 | vocab_file=FLAGS.vocab_file, do_lower_case=FLAGS.do_lower_case) 413 | 414 | input_files = [] 415 | for input_pattern in FLAGS.input_file.split(","): 416 | input_files.extend(tf.gfile.Glob(input_pattern)) 417 | 418 | tf.logging.info("*** Reading from input files ***") 419 | for input_file in input_files: 420 | tf.logging.info(" %s", input_file) 421 | 422 | rng = random.Random(FLAGS.random_seed) 423 | instances = create_training_instances( 424 | input_files, tokenizer, FLAGS.max_seq_length, FLAGS.dupe_factor, 425 | FLAGS.short_seq_prob, FLAGS.masked_lm_prob, FLAGS.max_predictions_per_seq, 426 | rng) 427 | 428 | output_files = FLAGS.output_file.split(",") 429 | tf.logging.info("*** Writing to output files ***") 430 | for output_file in output_files: 431 | tf.logging.info(" %s", output_file) 432 | 433 | write_instance_to_example_files(instances, tokenizer, FLAGS.max_seq_length, 434 | FLAGS.max_predictions_per_seq, output_files) 435 | 436 | 437 | if __name__ == "__main__": 438 | flags.mark_flag_as_required("input_file") 439 | flags.mark_flag_as_required("output_file") 440 | flags.mark_flag_as_required("vocab_file") 441 | tf.app.run() 442 | -------------------------------------------------------------------------------- /bert/extract_features.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """Extract pre-computed feature vectors from BERT.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import codecs 22 | import collections 23 | import json 24 | import re 25 | 26 | import modeling 27 | import tokenization 28 | import tensorflow as tf 29 | 30 | flags = tf.flags 31 | 32 | FLAGS = flags.FLAGS 33 | 34 | flags.DEFINE_string("input_file", None, "") 35 | 36 | flags.DEFINE_string("output_file", None, "") 37 | 38 | flags.DEFINE_string("layers", "-1,-2,-3,-4", "") 39 | 40 | flags.DEFINE_string( 41 | "bert_config_file", None, 42 | "The config json file corresponding to the pre-trained BERT model. " 43 | "This specifies the model architecture.") 44 | 45 | flags.DEFINE_integer( 46 | "max_seq_length", 128, 47 | "The maximum total input sequence length after WordPiece tokenization. " 48 | "Sequences longer than this will be truncated, and sequences shorter " 49 | "than this will be padded.") 50 | 51 | flags.DEFINE_string( 52 | "init_checkpoint", None, 53 | "Initial checkpoint (usually from a pre-trained BERT model).") 54 | 55 | flags.DEFINE_string("vocab_file", None, 56 | "The vocabulary file that the BERT model was trained on.") 57 | 58 | flags.DEFINE_bool( 59 | "do_lower_case", True, 60 | "Whether to lower case the input text. 
Should be True for uncased " 61 | "models and False for cased models.") 62 | 63 | flags.DEFINE_integer("batch_size", 32, "Batch size for predictions.") 64 | 65 | flags.DEFINE_bool("use_tpu", False, "Whether to use TPU or GPU/CPU.") 66 | 67 | flags.DEFINE_string("master", None, 68 | "If using a TPU, the address of the master.") 69 | 70 | flags.DEFINE_integer( 71 | "num_tpu_cores", 8, 72 | "Only used if `use_tpu` is True. Total number of TPU cores to use.") 73 | 74 | flags.DEFINE_bool( 75 | "use_one_hot_embeddings", False, 76 | "If True, tf.one_hot will be used for embedding lookups, otherwise " 77 | "tf.nn.embedding_lookup will be used. On TPUs, this should be True " 78 | "since it is much faster.") 79 | 80 | 81 | class InputExample(object): 82 | 83 | def __init__(self, unique_id, text_a, text_b): 84 | self.unique_id = unique_id 85 | self.text_a = text_a 86 | self.text_b = text_b 87 | 88 | 89 | class InputFeatures(object): 90 | """A single set of features of data.""" 91 | 92 | def __init__(self, unique_id, tokens, input_ids, input_mask, input_type_ids): 93 | self.unique_id = unique_id 94 | self.tokens = tokens 95 | self.input_ids = input_ids 96 | self.input_mask = input_mask 97 | self.input_type_ids = input_type_ids 98 | 99 | 100 | def input_fn_builder(features, seq_length): 101 | """Creates an `input_fn` closure to be passed to TPUEstimator.""" 102 | 103 | all_unique_ids = [] 104 | all_input_ids = [] 105 | all_input_mask = [] 106 | all_input_type_ids = [] 107 | 108 | for feature in features: 109 | all_unique_ids.append(feature.unique_id) 110 | all_input_ids.append(feature.input_ids) 111 | all_input_mask.append(feature.input_mask) 112 | all_input_type_ids.append(feature.input_type_ids) 113 | 114 | def input_fn(params): 115 | """The actual input function.""" 116 | batch_size = params["batch_size"] 117 | 118 | num_examples = len(features) 119 | 120 | # This is for demo purposes and does NOT scale to large data sets. 
We do 121 | # not use Dataset.from_generator() because that uses tf.py_func which is 122 | # not TPU compatible. The right way to load data is with TFRecordReader. 123 | d = tf.data.Dataset.from_tensor_slices({ 124 | "unique_ids": 125 | tf.constant(all_unique_ids, shape=[num_examples], dtype=tf.int32), 126 | "input_ids": 127 | tf.constant( 128 | all_input_ids, shape=[num_examples, seq_length], 129 | dtype=tf.int32), 130 | "input_mask": 131 | tf.constant( 132 | all_input_mask, 133 | shape=[num_examples, seq_length], 134 | dtype=tf.int32), 135 | "input_type_ids": 136 | tf.constant( 137 | all_input_type_ids, 138 | shape=[num_examples, seq_length], 139 | dtype=tf.int32), 140 | }) 141 | 142 | d = d.batch(batch_size=batch_size, drop_remainder=False) 143 | return d 144 | 145 | return input_fn 146 | 147 | 148 | def model_fn_builder(bert_config, init_checkpoint, layer_indexes, use_tpu, 149 | use_one_hot_embeddings): 150 | """Returns `model_fn` closure for TPUEstimator.""" 151 | 152 | def model_fn(features, labels, mode, params): # pylint: disable=unused-argument 153 | """The `model_fn` for TPUEstimator.""" 154 | 155 | unique_ids = features["unique_ids"] 156 | input_ids = features["input_ids"] 157 | input_mask = features["input_mask"] 158 | input_type_ids = features["input_type_ids"] 159 | 160 | model = modeling.BertModel( 161 | config=bert_config, 162 | is_training=False, 163 | input_ids=input_ids, 164 | input_mask=input_mask, 165 | token_type_ids=input_type_ids, 166 | use_one_hot_embeddings=use_one_hot_embeddings) 167 | 168 | if mode != tf.estimator.ModeKeys.PREDICT: 169 | raise ValueError("Only PREDICT modes are supported: %s" % (mode)) 170 | 171 | tvars = tf.trainable_variables() 172 | scaffold_fn = None 173 | (assignment_map, 174 | initialized_variable_names) = modeling.get_assignment_map_from_checkpoint( 175 | tvars, init_checkpoint) 176 | if use_tpu: 177 | 178 | def tpu_scaffold(): 179 | tf.train.init_from_checkpoint(init_checkpoint, assignment_map) 180 | return 
tf.train.Scaffold() 181 | 182 | scaffold_fn = tpu_scaffold 183 | else: 184 | tf.train.init_from_checkpoint(init_checkpoint, assignment_map) 185 | 186 | tf.logging.info("**** Trainable Variables ****") 187 | for var in tvars: 188 | init_string = "" 189 | if var.name in initialized_variable_names: 190 | init_string = ", *INIT_FROM_CKPT*" 191 | tf.logging.info(" name = %s, shape = %s%s", var.name, var.shape, 192 | init_string) 193 | 194 | all_layers = model.get_all_encoder_layers() 195 | 196 | predictions = { 197 | "unique_id": unique_ids, 198 | } 199 | 200 | for (i, layer_index) in enumerate(layer_indexes): 201 | predictions["layer_output_%d" % i] = all_layers[layer_index] 202 | 203 | output_spec = tf.contrib.tpu.TPUEstimatorSpec( 204 | mode=mode, predictions=predictions, scaffold_fn=scaffold_fn) 205 | return output_spec 206 | 207 | return model_fn 208 | 209 | 210 | def convert_examples_to_features(examples, seq_length, tokenizer): 211 | """Loads a data file into a list of `InputBatch`s.""" 212 | 213 | features = [] 214 | for (ex_index, example) in enumerate(examples): 215 | tokens_a = tokenizer.tokenize(example.text_a) 216 | 217 | tokens_b = None 218 | if example.text_b: 219 | tokens_b = tokenizer.tokenize(example.text_b) 220 | 221 | if tokens_b: 222 | # Modifies `tokens_a` and `tokens_b` in place so that the total 223 | # length is less than the specified length. 224 | # Account for [CLS], [SEP], [SEP] with "- 3" 225 | _truncate_seq_pair(tokens_a, tokens_b, seq_length - 3) 226 | else: 227 | # Account for [CLS] and [SEP] with "- 2" 228 | if len(tokens_a) > seq_length - 2: 229 | tokens_a = tokens_a[0:(seq_length - 2)] 230 | 231 | # The convention in BERT is: 232 | # (a) For sequence pairs: 233 | # tokens: [CLS] is this jack ##son ##ville ? [SEP] no it is not . [SEP] 234 | # type_ids: 0 0 0 0 0 0 0 0 1 1 1 1 1 1 235 | # (b) For single sequences: 236 | # tokens: [CLS] the dog is hairy . 
[SEP] 237 | # type_ids: 0 0 0 0 0 0 0 238 | # 239 | # Where "type_ids" are used to indicate whether this is the first 240 | # sequence or the second sequence. The embedding vectors for `type=0` and 241 | # `type=1` were learned during pre-training and are added to the wordpiece 242 | # embedding vector (and position vector). This is not *strictly* necessary 243 | # since the [SEP] token unambiguously separates the sequences, but it makes 244 | # it easier for the model to learn the concept of sequences. 245 | # 246 | # For classification tasks, the first vector (corresponding to [CLS]) is 247 | # used as as the "sentence vector". Note that this only makes sense because 248 | # the entire model is fine-tuned. 249 | tokens = [] 250 | input_type_ids = [] 251 | tokens.append("[CLS]") 252 | input_type_ids.append(0) 253 | for token in tokens_a: 254 | tokens.append(token) 255 | input_type_ids.append(0) 256 | tokens.append("[SEP]") 257 | input_type_ids.append(0) 258 | 259 | if tokens_b: 260 | for token in tokens_b: 261 | tokens.append(token) 262 | input_type_ids.append(1) 263 | tokens.append("[SEP]") 264 | input_type_ids.append(1) 265 | 266 | input_ids = tokenizer.convert_tokens_to_ids(tokens) 267 | 268 | # The mask has 1 for real tokens and 0 for padding tokens. Only real 269 | # tokens are attended to. 270 | input_mask = [1] * len(input_ids) 271 | 272 | # Zero-pad up to the sequence length. 
273 | while len(input_ids) < seq_length: 274 | input_ids.append(0) 275 | input_mask.append(0) 276 | input_type_ids.append(0) 277 | 278 | assert len(input_ids) == seq_length 279 | assert len(input_mask) == seq_length 280 | assert len(input_type_ids) == seq_length 281 | 282 | if ex_index < 5: 283 | tf.logging.info("*** Example ***") 284 | tf.logging.info("unique_id: %s" % (example.unique_id)) 285 | tf.logging.info("tokens: %s" % " ".join( 286 | [tokenization.printable_text(x) for x in tokens])) 287 | tf.logging.info("input_ids: %s" % " ".join([str(x) for x in input_ids])) 288 | tf.logging.info("input_mask: %s" % " ".join([str(x) for x in input_mask])) 289 | tf.logging.info( 290 | "input_type_ids: %s" % " ".join([str(x) for x in input_type_ids])) 291 | 292 | features.append( 293 | InputFeatures( 294 | unique_id=example.unique_id, 295 | tokens=tokens, 296 | input_ids=input_ids, 297 | input_mask=input_mask, 298 | input_type_ids=input_type_ids)) 299 | return features 300 | 301 | 302 | def _truncate_seq_pair(tokens_a, tokens_b, max_length): 303 | """Truncates a sequence pair in place to the maximum length.""" 304 | 305 | # This is a simple heuristic which will always truncate the longer sequence 306 | # one token at a time. This makes more sense than truncating an equal percent 307 | # of tokens from each, since if one sequence is very short then each token 308 | # that's truncated likely contains more information than a longer sequence. 
309 | while True: 310 | total_length = len(tokens_a) + len(tokens_b) 311 | if total_length <= max_length: 312 | break 313 | if len(tokens_a) > len(tokens_b): 314 | tokens_a.pop() 315 | else: 316 | tokens_b.pop() 317 | 318 | 319 | def read_examples(input_file): 320 | """Read a list of `InputExample`s from an input file.""" 321 | examples = [] 322 | unique_id = 0 323 | with tf.gfile.GFile(input_file, "r") as reader: 324 | while True: 325 | line = tokenization.convert_to_unicode(reader.readline()) 326 | if not line: 327 | break 328 | line = line.strip() 329 | text_a = None 330 | text_b = None 331 | m = re.match(r"^(.*) \|\|\| (.*)$", line) 332 | if m is None: 333 | text_a = line 334 | else: 335 | text_a = m.group(1) 336 | text_b = m.group(2) 337 | examples.append( 338 | InputExample(unique_id=unique_id, text_a=text_a, text_b=text_b)) 339 | unique_id += 1 340 | return examples 341 | 342 | 343 | def main(_): 344 | tf.logging.set_verbosity(tf.logging.INFO) 345 | 346 | layer_indexes = [int(x) for x in FLAGS.layers.split(",")] 347 | 348 | bert_config = modeling.BertConfig.from_json_file(FLAGS.bert_config_file) 349 | 350 | tokenizer = tokenization.FullTokenizer( 351 | vocab_file=FLAGS.vocab_file, do_lower_case=FLAGS.do_lower_case) 352 | 353 | is_per_host = tf.contrib.tpu.InputPipelineConfig.PER_HOST_V2 354 | run_config = tf.contrib.tpu.RunConfig( 355 | master=FLAGS.master, 356 | tpu_config=tf.contrib.tpu.TPUConfig( 357 | num_shards=FLAGS.num_tpu_cores, 358 | per_host_input_for_training=is_per_host)) 359 | 360 | examples = read_examples(FLAGS.input_file) 361 | 362 | features = convert_examples_to_features( 363 | examples=examples, seq_length=FLAGS.max_seq_length, tokenizer=tokenizer) 364 | 365 | unique_id_to_feature = {} 366 | for feature in features: 367 | unique_id_to_feature[feature.unique_id] = feature 368 | 369 | model_fn = model_fn_builder( 370 | bert_config=bert_config, 371 | init_checkpoint=FLAGS.init_checkpoint, 372 | layer_indexes=layer_indexes, 373 | 
use_tpu=FLAGS.use_tpu, 374 | use_one_hot_embeddings=FLAGS.use_one_hot_embeddings) 375 | 376 | # If TPU is not available, this will fall back to normal Estimator on CPU 377 | # or GPU. 378 | estimator = tf.contrib.tpu.TPUEstimator( 379 | use_tpu=FLAGS.use_tpu, 380 | model_fn=model_fn, 381 | config=run_config, 382 | predict_batch_size=FLAGS.batch_size) 383 | 384 | input_fn = input_fn_builder( 385 | features=features, seq_length=FLAGS.max_seq_length) 386 | 387 | with codecs.getwriter("utf-8")(tf.gfile.Open(FLAGS.output_file, 388 | "w")) as writer: 389 | for result in estimator.predict(input_fn, yield_single_examples=True): 390 | unique_id = int(result["unique_id"]) 391 | feature = unique_id_to_feature[unique_id] 392 | output_json = collections.OrderedDict() 393 | output_json["linex_index"] = unique_id 394 | all_features = [] 395 | for (i, token) in enumerate(feature.tokens): 396 | all_layers = [] 397 | for (j, layer_index) in enumerate(layer_indexes): 398 | layer_output = result["layer_output_%d" % j] 399 | layers = collections.OrderedDict() 400 | layers["index"] = layer_index 401 | layers["values"] = [ 402 | round(float(x), 6) for x in layer_output[i:(i + 1)].flat 403 | ] 404 | all_layers.append(layers) 405 | features = collections.OrderedDict() 406 | features["token"] = token 407 | features["layers"] = all_layers 408 | all_features.append(features) 409 | output_json["features"] = all_features 410 | writer.write(json.dumps(output_json) + "\n") 411 | 412 | 413 | if __name__ == "__main__": 414 | flags.mark_flag_as_required("input_file") 415 | flags.mark_flag_as_required("vocab_file") 416 | flags.mark_flag_as_required("bert_config_file") 417 | flags.mark_flag_as_required("init_checkpoint") 418 | flags.mark_flag_as_required("output_file") 419 | tf.app.run() 420 | -------------------------------------------------------------------------------- /bert/modeling_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # 
Copyright 2018 The Google AI Language Team Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | from __future__ import absolute_import 16 | from __future__ import division 17 | from __future__ import print_function 18 | 19 | import collections 20 | import json 21 | import random 22 | import re 23 | 24 | import modeling 25 | import six 26 | import tensorflow as tf 27 | 28 | 29 | class BertModelTest(tf.test.TestCase): 30 | 31 | class BertModelTester(object): 32 | 33 | def __init__(self, 34 | parent, 35 | batch_size=13, 36 | seq_length=7, 37 | is_training=True, 38 | use_input_mask=True, 39 | use_token_type_ids=True, 40 | vocab_size=99, 41 | hidden_size=32, 42 | num_hidden_layers=5, 43 | num_attention_heads=4, 44 | intermediate_size=37, 45 | hidden_act="gelu", 46 | hidden_dropout_prob=0.1, 47 | attention_probs_dropout_prob=0.1, 48 | max_position_embeddings=512, 49 | type_vocab_size=16, 50 | initializer_range=0.02, 51 | scope=None): 52 | self.parent = parent 53 | self.batch_size = batch_size 54 | self.seq_length = seq_length 55 | self.is_training = is_training 56 | self.use_input_mask = use_input_mask 57 | self.use_token_type_ids = use_token_type_ids 58 | self.vocab_size = vocab_size 59 | self.hidden_size = hidden_size 60 | self.num_hidden_layers = num_hidden_layers 61 | self.num_attention_heads = num_attention_heads 62 | self.intermediate_size = intermediate_size 63 | self.hidden_act = hidden_act 64 | 
self.hidden_dropout_prob = hidden_dropout_prob 65 | self.attention_probs_dropout_prob = attention_probs_dropout_prob 66 | self.max_position_embeddings = max_position_embeddings 67 | self.type_vocab_size = type_vocab_size 68 | self.initializer_range = initializer_range 69 | self.scope = scope 70 | 71 | def create_model(self): 72 | input_ids = BertModelTest.ids_tensor([self.batch_size, self.seq_length], 73 | self.vocab_size) 74 | 75 | input_mask = None 76 | if self.use_input_mask: 77 | input_mask = BertModelTest.ids_tensor( 78 | [self.batch_size, self.seq_length], vocab_size=2) 79 | 80 | token_type_ids = None 81 | if self.use_token_type_ids: 82 | token_type_ids = BertModelTest.ids_tensor( 83 | [self.batch_size, self.seq_length], self.type_vocab_size) 84 | 85 | config = modeling.BertConfig( 86 | vocab_size=self.vocab_size, 87 | hidden_size=self.hidden_size, 88 | num_hidden_layers=self.num_hidden_layers, 89 | num_attention_heads=self.num_attention_heads, 90 | intermediate_size=self.intermediate_size, 91 | hidden_act=self.hidden_act, 92 | hidden_dropout_prob=self.hidden_dropout_prob, 93 | attention_probs_dropout_prob=self.attention_probs_dropout_prob, 94 | max_position_embeddings=self.max_position_embeddings, 95 | type_vocab_size=self.type_vocab_size, 96 | initializer_range=self.initializer_range) 97 | 98 | model = modeling.BertModel( 99 | config=config, 100 | is_training=self.is_training, 101 | input_ids=input_ids, 102 | input_mask=input_mask, 103 | token_type_ids=token_type_ids, 104 | scope=self.scope) 105 | 106 | outputs = { 107 | "embedding_output": model.get_embedding_output(), 108 | "sequence_output": model.get_sequence_output(), 109 | "pooled_output": model.get_pooled_output(), 110 | "all_encoder_layers": model.get_all_encoder_layers(), 111 | } 112 | return outputs 113 | 114 | def check_output(self, result): 115 | self.parent.assertAllEqual( 116 | result["embedding_output"].shape, 117 | [self.batch_size, self.seq_length, self.hidden_size]) 118 | 119 | 
self.parent.assertAllEqual( 120 | result["sequence_output"].shape, 121 | [self.batch_size, self.seq_length, self.hidden_size]) 122 | 123 | self.parent.assertAllEqual(result["pooled_output"].shape, 124 | [self.batch_size, self.hidden_size]) 125 | 126 | def test_default(self): 127 | self.run_tester(BertModelTest.BertModelTester(self)) 128 | 129 | def test_config_to_json_string(self): 130 | config = modeling.BertConfig(vocab_size=99, hidden_size=37) 131 | obj = json.loads(config.to_json_string()) 132 | self.assertEqual(obj["vocab_size"], 99) 133 | self.assertEqual(obj["hidden_size"], 37) 134 | 135 | def run_tester(self, tester): 136 | with self.test_session() as sess: 137 | ops = tester.create_model() 138 | init_op = tf.group(tf.global_variables_initializer(), 139 | tf.local_variables_initializer()) 140 | sess.run(init_op) 141 | output_result = sess.run(ops) 142 | tester.check_output(output_result) 143 | 144 | self.assert_all_tensors_reachable(sess, [init_op, ops]) 145 | 146 | @classmethod 147 | def ids_tensor(cls, shape, vocab_size, rng=None, name=None): 148 | """Creates a random int32 tensor of the shape within the vocab size.""" 149 | if rng is None: 150 | rng = random.Random() 151 | 152 | total_dims = 1 153 | for dim in shape: 154 | total_dims *= dim 155 | 156 | values = [] 157 | for _ in range(total_dims): 158 | values.append(rng.randint(0, vocab_size - 1)) 159 | 160 | return tf.constant(value=values, dtype=tf.int32, shape=shape, name=name) 161 | 162 | def assert_all_tensors_reachable(self, sess, outputs): 163 | """Checks that all the tensors in the graph are reachable from outputs.""" 164 | graph = sess.graph 165 | 166 | ignore_strings = [ 167 | "^.*/assert_less_equal/.*$", 168 | "^.*/dilation_rate$", 169 | "^.*/Tensordot/concat$", 170 | "^.*/Tensordot/concat/axis$", 171 | "^testing/.*$", 172 | ] 173 | 174 | ignore_regexes = [re.compile(x) for x in ignore_strings] 175 | 176 | unreachable = self.get_unreachable_ops(graph, outputs) 177 | filtered_unreachable = [] 
178 | for x in unreachable: 179 | do_ignore = False 180 | for r in ignore_regexes: 181 | m = r.match(x.name) 182 | if m is not None: 183 | do_ignore = True 184 | if do_ignore: 185 | continue 186 | filtered_unreachable.append(x) 187 | unreachable = filtered_unreachable 188 | 189 | self.assertEqual( 190 | len(unreachable), 0, "The following ops are unreachable: %s" % 191 | (" ".join([x.name for x in unreachable]))) 192 | 193 | @classmethod 194 | def get_unreachable_ops(cls, graph, outputs): 195 | """Finds all of the tensors in graph that are unreachable from outputs.""" 196 | outputs = cls.flatten_recursive(outputs) 197 | output_to_op = collections.defaultdict(list) 198 | op_to_all = collections.defaultdict(list) 199 | assign_out_to_in = collections.defaultdict(list) 200 | 201 | for op in graph.get_operations(): 202 | for x in op.inputs: 203 | op_to_all[op.name].append(x.name) 204 | for y in op.outputs: 205 | output_to_op[y.name].append(op.name) 206 | op_to_all[op.name].append(y.name) 207 | if str(op.type) == "Assign": 208 | for y in op.outputs: 209 | for x in op.inputs: 210 | assign_out_to_in[y.name].append(x.name) 211 | 212 | assign_groups = collections.defaultdict(list) 213 | for out_name in assign_out_to_in.keys(): 214 | name_group = assign_out_to_in[out_name] 215 | for n1 in name_group: 216 | assign_groups[n1].append(out_name) 217 | for n2 in name_group: 218 | if n1 != n2: 219 | assign_groups[n1].append(n2) 220 | 221 | seen_tensors = {} 222 | stack = [x.name for x in outputs] 223 | while stack: 224 | name = stack.pop() 225 | if name in seen_tensors: 226 | continue 227 | seen_tensors[name] = True 228 | 229 | if name in output_to_op: 230 | for op_name in output_to_op[name]: 231 | if op_name in op_to_all: 232 | for input_name in op_to_all[op_name]: 233 | if input_name not in stack: 234 | stack.append(input_name) 235 | 236 | expanded_names = [] 237 | if name in assign_groups: 238 | for assign_name in assign_groups[name]: 239 | expanded_names.append(assign_name) 240 
| 241 | for expanded_name in expanded_names: 242 | if expanded_name not in stack: 243 | stack.append(expanded_name) 244 | 245 | unreachable_ops = [] 246 | for op in graph.get_operations(): 247 | is_unreachable = False 248 | all_names = [x.name for x in op.inputs] + [x.name for x in op.outputs] 249 | for name in all_names: 250 | if name not in seen_tensors: 251 | is_unreachable = True 252 | if is_unreachable: 253 | unreachable_ops.append(op) 254 | return unreachable_ops 255 | 256 | @classmethod 257 | def flatten_recursive(cls, item): 258 | """Flattens (potentially nested) a tuple/dictionary/list to a list.""" 259 | output = [] 260 | if isinstance(item, list): 261 | output.extend(item) 262 | elif isinstance(item, tuple): 263 | output.extend(list(item)) 264 | elif isinstance(item, dict): 265 | for (_, v) in six.iteritems(item): 266 | output.append(v) 267 | else: 268 | return [item] 269 | 270 | flat_output = [] 271 | for x in output: 272 | flat_output.extend(cls.flatten_recursive(x)) 273 | return flat_output 274 | 275 | 276 | if __name__ == "__main__": 277 | tf.test.main() 278 | -------------------------------------------------------------------------------- /bert/multilingual.md: -------------------------------------------------------------------------------- 1 | ## Models 2 | 3 | There are two multilingual models currently available. We do not plan to release 4 | more single-language models, but we may release `BERT-Large` versions of these 5 | two in the future: 6 | 7 | * **[`BERT-Base, Multilingual`](https://storage.googleapis.com/bert_models/2018_11_03/multilingual_L-12_H-768_A-12.zip)**: 8 | 102 languages, 12-layer, 768-hidden, 12-heads, 110M parameters 9 | * **[`BERT-Base, Chinese`](https://storage.googleapis.com/bert_models/2018_11_03/chinese_L-12_H-768_A-12.zip)**: 10 | Chinese Simplified and Traditional, 12-layer, 768-hidden, 12-heads, 110M 11 | parameters 12 | 13 | See the [list of languages](#list-of-languages) that the Multilingual model 14 | supports. 
The Multilingual model does include Chinese (and English), but if your 15 | fine-tuning data is Chinese-only, then the Chinese model will likely produce 16 | better results. 17 | 18 | ## Results 19 | 20 | To evaluate these systems, we use the 21 | [XNLI dataset](https://github.com/facebookresearch/XNLI), which is a 22 | version of [MultiNLI](https://www.nyu.edu/projects/bowman/multinli/) where the 23 | dev and test sets have been translated (by humans) into 15 languages. Note that 24 | the training set was *machine* translated (we used the translations provided by 25 | XNLI, not Google NMT). For clarity, we only report on 6 languages below: 26 | 27 | 28 | 29 | | System | English | Chinese | Spanish | German | Arabic | Urdu | 30 | | ------------------------------- | -------- | -------- | -------- | -------- | -------- | -------- | 31 | | XNLI Baseline - Translate Train | 73.7 | 67.0 | 68.8 | 66.5 | 65.8 | 56.6 | 32 | | XNLI Baseline - Translate Test | 73.7 | 68.3 | 70.7 | 68.7 | 66.8 | 59.3 | 33 | | BERT - Translate Train | **81.4** | **74.2** | **77.3** | **75.2** | **70.5** | 61.7 | 34 | | BERT - Translate Test | 81.4 | 70.1 | 74.9 | 74.4 | 70.4 | **62.1** | 35 | | BERT - Zero Shot | 81.4 | 63.8 | 74.3 | 70.5 | 62.1 | 58.3 | 36 | 37 | 38 | 39 | The first two rows are baselines from the XNLI paper and the last three rows are 40 | our results with BERT. 41 | 42 | **Translate Train** means that the MultiNLI training set was machine translated 43 | from English into the foreign language. So training and evaluation were both 44 | done in the foreign language. Unfortunately, training was done on 45 | machine-translated data, so it is impossible to quantify how much of the lower 46 | accuracy (compared to English) is due to the quality of the machine translation 47 | vs. the quality of the pre-trained model. 48 | 49 | **Translate Test** means that the XNLI test set was machine translated from the 50 | foreign language into English.
So training and evaluation were both done in 51 | English. However, test evaluation was done on machine-translated English, so the 52 | accuracy depends on the quality of the machine translation system. 53 | 54 | **Zero Shot** means that the Multilingual BERT system was fine-tuned on English 55 | MultiNLI, and then evaluated on the foreign language XNLI test. In this case, 56 | machine translation was not involved at all in either the pre-training or 57 | fine-tuning. 58 | 59 | Note that the English result is worse than the 84.2 MultiNLI baseline because 60 | this training used Multilingual BERT rather than English-only BERT. This implies 61 | that for high-resource languages, the Multilingual model is somewhat worse than 62 | a single-language model. However, it is not feasible for us to train and 63 | maintain dozens of single-language models. Therefore, if your goal is to maximize 64 | performance with a language other than English or Chinese, you might find it 65 | beneficial to run pre-training for additional steps starting from our 66 | Multilingual model on data from your language of interest. 67 | 68 | Here is a comparison of training Chinese models with the Multilingual 69 | `BERT-Base` and Chinese-only `BERT-Base`: 70 | 71 | System | Chinese 72 | ----------------------- | ------- 73 | XNLI Baseline | 67.0 74 | BERT Multilingual Model | 74.2 75 | BERT Chinese-only Model | 77.2 76 | 77 | Similar to English, the single-language model does 3% better than the 78 | Multilingual model. 79 | 80 | ## Fine-tuning Example 81 | 82 | The multilingual model does **not** require any special consideration or API 83 | changes. We did update the implementation of `BasicTokenizer` in 84 | `tokenization.py` to support Chinese character tokenization, so please update if 85 | you forked it. However, we did not change the tokenization API.
86 | 87 | To test the new models, we did modify `run_classifier.py` to add support for the 88 | [XNLI dataset](https://github.com/facebookresearch/XNLI). This is a 15-language 89 | version of MultiNLI where the dev/test sets have been human-translated, and the 90 | training set has been machine-translated. 91 | 92 | To run the fine-tuning code, please download the 93 | [XNLI dev/test set](https://s3.amazonaws.com/xnli/XNLI-1.0.zip) and the 94 | [XNLI machine-translated training set](https://s3.amazonaws.com/xnli/XNLI-MT-1.0.zip) 95 | and then unpack both .zip files into some directory `$XNLI_DIR`. 96 | 97 | To run fine-tuning on XNLI, note that the language is hard-coded into `run_classifier.py` 98 | (Chinese by default), so please modify `XnliProcessor` if you want to run on 99 | another language. 100 | 101 | This is a large dataset, so training will take a few hours on a GPU 102 | (or about 30 minutes on a Cloud TPU). To run an experiment quickly for 103 | debugging, just set `num_train_epochs` to a small value like `0.1`.
104 | 105 | ```shell 106 | export BERT_BASE_DIR=/path/to/bert/chinese_L-12_H-768_A-12 # or multilingual_L-12_H-768_A-12 107 | export XNLI_DIR=/path/to/xnli 108 | 109 | python run_classifier.py \ 110 | --task_name=XNLI \ 111 | --do_train=true \ 112 | --do_eval=true \ 113 | --data_dir=$XNLI_DIR \ 114 | --vocab_file=$BERT_BASE_DIR/vocab.txt \ 115 | --bert_config_file=$BERT_BASE_DIR/bert_config.json \ 116 | --init_checkpoint=$BERT_BASE_DIR/bert_model.ckpt \ 117 | --max_seq_length=128 \ 118 | --train_batch_size=32 \ 119 | --learning_rate=5e-5 \ 120 | --num_train_epochs=2.0 \ 121 | --output_dir=/tmp/xnli_output/ 122 | ``` 123 | 124 | With the Chinese-only model, the results should look something like this: 125 | 126 | ``` 127 | ***** Eval results ***** 128 | eval_accuracy = 0.774116 129 | eval_loss = 0.83554 130 | global_step = 24543 131 | loss = 0.74603 132 | ``` 133 | 134 | ## Details 135 | 136 | ### Data Source and Sampling 137 | 138 | The languages chosen were the 139 | [top 100 languages with the largest Wikipedias](https://meta.wikimedia.org/wiki/List_of_Wikipedias). 140 | The entire Wikipedia dump for each language (excluding user and talk pages) was 141 | taken as the training data for each language. 142 | 143 | However, the size of the Wikipedia for a given language varies greatly, and 144 | therefore low-resource languages may be "under-represented" in terms of the 145 | neural network model (under the assumption that languages are "competing" for 146 | limited model capacity to some extent). 147 | 148 | On the other hand, the size of a Wikipedia also correlates with the number of speakers of 149 | a language, and we also don't want to overfit the model by performing thousands 150 | of epochs over a tiny Wikipedia for a particular language. 151 | 152 | To balance these two factors, we performed exponentially smoothed weighting of 153 | the data during pre-training data creation (and WordPiece vocab creation).
In 154 | other words, let's say that the probability of a language is *P(L)*, e.g., 155 | *P(English) = 0.21* means that after concatenating all of the Wikipedias 156 | together, 21% of our data is English. We exponentiate each probability by some 157 | factor *S* and then re-normalize, and sample from that distribution. In our case 158 | we use *S=0.7*. So, high-resource languages like English will be under-sampled, 159 | and low-resource languages like Icelandic will be over-sampled. E.g., in the 160 | original distribution English would be sampled 1000x more than Icelandic, but 161 | after smoothing it's only sampled 100x more. 162 | 163 | ### Tokenization 164 | 165 | For tokenization, we use a 110k shared WordPiece vocabulary. The word counts are 166 | weighted the same way as the data, so low-resource languages are upweighted by 167 | some factor. We intentionally do *not* use any marker to denote the input 168 | language (so that zero-shot training can work). 169 | 170 | Because Chinese does not have whitespace characters, we add spaces around every 171 | character in the 172 | [CJK Unicode range](https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_\(Unicode_block\)) 173 | before applying WordPiece. This means that Chinese is effectively 174 | character-tokenized. Note that the CJK Unicode block only includes 175 | Chinese-origin characters and does *not* include Hangul Korean or 176 | Katakana/Hiragana Japanese, which are tokenized with whitespace+WordPiece like 177 | all other languages. 178 | 179 | For all other languages, we apply the 180 | [same recipe as English](https://github.com/google-research/bert#tokenization): 181 | (a) lower casing+accent removal, (b) punctuation splitting, (c) whitespace 182 | tokenization. We understand that accent markers have substantial meaning in some 183 | languages, but felt that the benefits of reducing the effective vocabulary make 184 | up for this. 
Generally the strong contextual models of BERT should make up for 185 | any ambiguity introduced by stripping accent markers. 186 | 187 | ### List of Languages 188 | 189 | The multilingual model supports the following languages. These languages were 190 | chosen because they are the top 100 languages with the largest Wikipedias: 191 | 192 | * Afrikaans 193 | * Albanian 194 | * Arabic 195 | * Aragonese 196 | * Armenian 197 | * Asturian 198 | * Azerbaijani 199 | * Bashkir 200 | * Basque 201 | * Bavarian 202 | * Belarusian 203 | * Bengali 204 | * Bishnupriya Manipuri 205 | * Bosnian 206 | * Breton 207 | * Bulgarian 208 | * Burmese 209 | * Catalan 210 | * Cebuano 211 | * Chechen 212 | * Chinese (Simplified) 213 | * Chinese (Traditional) 214 | * Chuvash 215 | * Croatian 216 | * Czech 217 | * Danish 218 | * Dutch 219 | * English 220 | * Estonian 221 | * Finnish 222 | * French 223 | * Galician 224 | * Georgian 225 | * German 226 | * Greek 227 | * Gujarati 228 | * Haitian 229 | * Hebrew 230 | * Hindi 231 | * Hungarian 232 | * Icelandic 233 | * Ido 234 | * Indonesian 235 | * Irish 236 | * Italian 237 | * Japanese 238 | * Javanese 239 | * Kannada 240 | * Kazakh 241 | * Kirghiz 242 | * Korean 243 | * Latin 244 | * Latvian 245 | * Lithuanian 246 | * Lombard 247 | * Low Saxon 248 | * Luxembourgish 249 | * Macedonian 250 | * Malagasy 251 | * Malay 252 | * Malayalam 253 | * Marathi 254 | * Minangkabau 255 | * Nepali 256 | * Newar 257 | * Norwegian (Bokmal) 258 | * Norwegian (Nynorsk) 259 | * Occitan 260 | * Persian (Farsi) 261 | * Piedmontese 262 | * Polish 263 | * Portuguese 264 | * Punjabi 265 | * Romanian 266 | * Russian 267 | * Scots 268 | * Serbian 269 | * Serbo-Croatian 270 | * Sicilian 271 | * Slovak 272 | * Slovenian 273 | * South Azerbaijani 274 | * Spanish 275 | * Sundanese 276 | * Swahili 277 | * Swedish 278 | * Tagalog 279 | * Tajik 280 | * Tamil 281 | * Tatar 282 | * Telugu 283 | * Turkish 284 | * Ukrainian 285 | * Urdu 286 | * Uzbek 287 | * Vietnamese 288 | * 
Volapük 289 | * Waray-Waray 290 | * Welsh 291 | * West 292 | * Western Punjabi 293 | * Yoruba 294 | 295 | The only language which we had to unfortunately exclude was Thai, since it is 296 | the only language (other than Chinese) that does not use whitespace to delimit 297 | words, and it has too many characters-per-word to use character-based 298 | tokenization. Our WordPiece algorithm is quadratic with respect to the size of 299 | the input token so very long character strings do not work with it. 300 | -------------------------------------------------------------------------------- /bert/optimization.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """Functions and classes related to optimization (weight updates).""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import re 22 | import tensorflow as tf 23 | 24 | 25 | def create_optimizer(loss, init_lr, num_train_steps, num_warmup_steps, use_tpu): 26 | """Creates an optimizer training op. The learning rate decays linearly from `init_lr` to 0 over `num_train_steps`, is replaced by a linear warmup while global_step < `num_warmup_steps` (if nonzero), and gradients are clipped to a global norm of 1.0.""" 27 | global_step = tf.train.get_or_create_global_step() 28 | 29 | learning_rate = tf.constant(value=init_lr, shape=[], dtype=tf.float32) 30 | 31 | # Implements linear decay of the learning rate.
32 | learning_rate = tf.train.polynomial_decay( 33 | learning_rate, 34 | global_step, 35 | num_train_steps, 36 | end_learning_rate=0.0, 37 | power=1.0, 38 | cycle=False) 39 | 40 | # Implements linear warmup. I.e., if global_step < num_warmup_steps, the 41 | # learning rate will be `global_step/num_warmup_steps * init_lr`. 42 | if num_warmup_steps: 43 | global_steps_int = tf.cast(global_step, tf.int32) 44 | warmup_steps_int = tf.constant(num_warmup_steps, dtype=tf.int32) 45 | 46 | global_steps_float = tf.cast(global_steps_int, tf.float32) 47 | warmup_steps_float = tf.cast(warmup_steps_int, tf.float32) 48 | 49 | warmup_percent_done = global_steps_float / warmup_steps_float 50 | warmup_learning_rate = init_lr * warmup_percent_done 51 | 52 | is_warmup = tf.cast(global_steps_int < warmup_steps_int, tf.float32) 53 | learning_rate = ( 54 | (1.0 - is_warmup) * learning_rate + is_warmup * warmup_learning_rate) 55 | 56 | # It is recommended that you use this optimizer for fine tuning, since this 57 | # is how the model was trained (note that the Adam m/v variables are NOT 58 | # loaded from init_checkpoint). 59 | optimizer = AdamWeightDecayOptimizer( 60 | learning_rate=learning_rate, 61 | weight_decay_rate=0.01, 62 | beta_1=0.9, 63 | beta_2=0.999, 64 | epsilon=1e-6, 65 | exclude_from_weight_decay=["LayerNorm", "layer_norm", "bias"]) 66 | 67 | if use_tpu: 68 | optimizer = tf.contrib.tpu.CrossShardOptimizer(optimizer) 69 | 70 | tvars = tf.trainable_variables() 71 | grads = tf.gradients(loss, tvars) 72 | 73 | # This is how the model was pre-trained.
74 | (grads, _) = tf.clip_by_global_norm(grads, clip_norm=1.0) 75 | 76 | train_op = optimizer.apply_gradients( 77 | zip(grads, tvars), global_step=global_step) 78 | # AdamWeightDecayOptimizer.apply_gradients() never increments global_step, so it is advanced explicitly here. 79 | new_global_step = global_step + 1 80 | train_op = tf.group(train_op, [global_step.assign(new_global_step)]) 81 | return train_op 82 | 83 | 84 | class AdamWeightDecayOptimizer(tf.train.Optimizer): 85 | """A basic Adam optimizer that includes "correct" L2 weight decay.""" 86 | 87 | def __init__(self, 88 | learning_rate, 89 | weight_decay_rate=0.0, 90 | beta_1=0.9, 91 | beta_2=0.999, 92 | epsilon=1e-6, 93 | exclude_from_weight_decay=None, 94 | name="AdamWeightDecayOptimizer"): 95 | """Constructs an AdamWeightDecayOptimizer.""" 96 | super(AdamWeightDecayOptimizer, self).__init__(False, name) # use_locking=False 97 | 98 | self.learning_rate = learning_rate 99 | self.weight_decay_rate = weight_decay_rate 100 | self.beta_1 = beta_1 101 | self.beta_2 = beta_2 102 | self.epsilon = epsilon 103 | self.exclude_from_weight_decay = exclude_from_weight_decay 104 | 105 | def apply_gradients(self, grads_and_vars, global_step=None, name=None): 106 | """See base class. Note: `global_step` is accepted but never incremented here; callers such as create_optimizer() must advance it themselves.""" 107 | assignments = [] 108 | for (grad, param) in grads_and_vars: 109 | if grad is None or param is None: 110 | continue 111 | 112 | param_name = self._get_variable_name(param.name) 113 | 114 | m = tf.get_variable( 115 | name=param_name + "/adam_m", 116 | shape=param.shape.as_list(), 117 | dtype=tf.float32, 118 | trainable=False, 119 | initializer=tf.zeros_initializer()) 120 | v = tf.get_variable( 121 | name=param_name + "/adam_v", 122 | shape=param.shape.as_list(), 123 | dtype=tf.float32, 124 | trainable=False, 125 | initializer=tf.zeros_initializer()) 126 | 127 | # Standard Adam update (note: no bias correction is applied to m or v).
128 | next_m = ( 129 | tf.multiply(self.beta_1, m) + tf.multiply(1.0 - self.beta_1, grad)) 130 | next_v = ( 131 | tf.multiply(self.beta_2, v) + tf.multiply(1.0 - self.beta_2, 132 | tf.square(grad))) 133 | 134 | update = next_m / (tf.sqrt(next_v) + self.epsilon) 135 | 136 | # Just adding the square of the weights to the loss function is *not* 137 | # the correct way of using L2 regularization/weight decay with Adam, 138 | # since that will interact with the m and v parameters in strange ways. 139 | # 140 | # Instead we want to decay the weights in a manner that doesn't interact 141 | # with the m/v parameters. This is equivalent to adding the square 142 | # of the weights to the loss with plain (non-momentum) SGD. 143 | if self._do_use_weight_decay(param_name): 144 | update += self.weight_decay_rate * param 145 | 146 | update_with_lr = self.learning_rate * update 147 | 148 | next_param = param - update_with_lr 149 | 150 | assignments.extend( 151 | [param.assign(next_param), 152 | m.assign(next_m), 153 | v.assign(next_v)]) 154 | return tf.group(*assignments, name=name) 155 | 156 | def _do_use_weight_decay(self, param_name): 157 | """Whether to use L2 weight decay for `param_name`: False when weight_decay_rate is 0 or any `exclude_from_weight_decay` regex matches the name.""" 158 | if not self.weight_decay_rate: 159 | return False 160 | if self.exclude_from_weight_decay: 161 | for r in self.exclude_from_weight_decay: 162 | if re.search(r, param_name) is not None: 163 | return False 164 | return True 165 | 166 | def _get_variable_name(self, param_name): 167 | """Get the variable name from the tensor name by stripping a trailing `:<digits>` suffix (e.g. `w:0` -> `w`).""" 168 | m = re.match("^(.*):\\d+$", param_name) 169 | if m is not None: 170 | param_name = m.group(1) 171 | return param_name 172 | -------------------------------------------------------------------------------- /bert/optimization_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors.
3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | from __future__ import absolute_import 16 | from __future__ import division 17 | from __future__ import print_function 18 | 19 | import optimization 20 | import tensorflow as tf 21 | 22 | 23 | class OptimizationTest(tf.test.TestCase): 24 | # test_adam: 100 steps of AdamWeightDecayOptimizer (lr=0.2, no weight decay) should drive w to x, the minimizer of mean((x - w)^2). 25 | def test_adam(self): 26 | with self.test_session() as sess: 27 | w = tf.get_variable( 28 | "w", 29 | shape=[3], 30 | initializer=tf.constant_initializer([0.1, -0.2, -0.1])) 31 | x = tf.constant([0.4, 0.2, -0.5]) 32 | loss = tf.reduce_mean(tf.square(x - w)) 33 | tvars = tf.trainable_variables() 34 | grads = tf.gradients(loss, tvars) 35 | global_step = tf.train.get_or_create_global_step() 36 | optimizer = optimization.AdamWeightDecayOptimizer(learning_rate=0.2) 37 | train_op = optimizer.apply_gradients(zip(grads, tvars), global_step) 38 | init_op = tf.group(tf.global_variables_initializer(), 39 | tf.local_variables_initializer()) 40 | sess.run(init_op) 41 | for _ in range(100): 42 | sess.run(train_op) 43 | w_np = sess.run(w) 44 | self.assertAllClose(w_np.flat, [0.4, 0.2, -0.5], rtol=1e-2, atol=1e-2) 45 | 46 | 47 | if __name__ == "__main__": 48 | tf.test.main() 49 | -------------------------------------------------------------------------------- /bert/requirements.txt: -------------------------------------------------------------------------------- 1 | tensorflow >= 1.11.0 # CPU Version of TensorFlow.
2 | # tensorflow-gpu >= 1.11.0 # GPU version of TensorFlow. 3 | -------------------------------------------------------------------------------- /bert/run_pretraining.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """Run masked LM/next sentence masked_lm pre-training for BERT.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import os 22 | import modeling 23 | import optimization 24 | import tensorflow as tf 25 | 26 | flags = tf.flags 27 | 28 | FLAGS = flags.FLAGS 29 | 30 | ## Required parameters 31 | flags.DEFINE_string( 32 | "bert_config_file", None, 33 | "The config json file corresponding to the pre-trained BERT model. 
" 34 | "This specifies the model architecture.") 35 | 36 | flags.DEFINE_string( 37 | "input_file", None, 38 | "Input TF example files (can be a glob or comma separated).") 39 | 40 | flags.DEFINE_string( 41 | "output_dir", None, 42 | "The output directory where the model checkpoints will be written.") 43 | 44 | ## Other parameters 45 | flags.DEFINE_string( 46 | "init_checkpoint", None, 47 | "Initial checkpoint (usually from a pre-trained BERT model).") 48 | 49 | flags.DEFINE_integer( 50 | "max_seq_length", 128, 51 | "The maximum total input sequence length after WordPiece tokenization. " 52 | "Sequences longer than this will be truncated, and sequences shorter " 53 | "than this will be padded. Must match data generation.") 54 | 55 | flags.DEFINE_integer( 56 | "max_predictions_per_seq", 20, 57 | "Maximum number of masked LM predictions per sequence. " 58 | "Must match data generation.") 59 | 60 | flags.DEFINE_bool("do_train", False, "Whether to run training.") 61 | 62 | flags.DEFINE_bool("do_eval", False, "Whether to run eval on the dev set.") 63 | 64 | flags.DEFINE_integer("train_batch_size", 32, "Total batch size for training.") 65 | 66 | flags.DEFINE_integer("eval_batch_size", 8, "Total batch size for eval.") 67 | 68 | flags.DEFINE_float("learning_rate", 5e-5, "The initial learning rate for Adam.") 69 | 70 | flags.DEFINE_integer("num_train_steps", 100000, "Number of training steps.") 71 | 72 | flags.DEFINE_integer("num_warmup_steps", 10000, "Number of warmup steps.") 73 | 74 | flags.DEFINE_integer("save_checkpoints_steps", 1000, 75 | "How often to save the model checkpoint.") 76 | 77 | flags.DEFINE_integer("iterations_per_loop", 1000, 78 | "How many steps to make in each estimator call.") 79 | 80 | flags.DEFINE_integer("max_eval_steps", 100, "Maximum number of eval steps.") 81 | 82 | flags.DEFINE_bool("use_tpu", False, "Whether to use TPU or GPU/CPU.") 83 | 84 | tf.flags.DEFINE_string( 85 | "tpu_name", None, 86 | "The Cloud TPU to use for training. 
This should be either the name " 87 | "used when creating the Cloud TPU, or a grpc://ip.address.of.tpu:8470 " 88 | "url.") 89 | 90 | tf.flags.DEFINE_string( 91 | "tpu_zone", None, 92 | "[Optional] GCE zone where the Cloud TPU is located in. If not " 93 | "specified, we will attempt to automatically detect the GCE project from " 94 | "metadata.") 95 | 96 | tf.flags.DEFINE_string( 97 | "gcp_project", None, 98 | "[Optional] Project name for the Cloud TPU-enabled project. If not " 99 | "specified, we will attempt to automatically detect the GCE project from " 100 | "metadata.") 101 | 102 | tf.flags.DEFINE_string("master", None, "[Optional] TensorFlow master URL.") 103 | 104 | flags.DEFINE_integer( 105 | "num_tpu_cores", 8, 106 | "Only used if `use_tpu` is True. Total number of TPU cores to use.") 107 | 108 | 109 | def model_fn_builder(bert_config, init_checkpoint, learning_rate, 110 | num_train_steps, num_warmup_steps, use_tpu, 111 | use_one_hot_embeddings): 112 | """Returns `model_fn` closure for TPUEstimator.""" 113 | 114 | def model_fn(features, labels, mode, params): # pylint: disable=unused-argument 115 | """The `model_fn` for TPUEstimator.""" 116 | 117 | tf.logging.info("*** Features ***") 118 | for name in sorted(features.keys()): 119 | tf.logging.info(" name = %s, shape = %s" % (name, features[name].shape)) 120 | 121 | input_ids = features["input_ids"] 122 | input_mask = features["input_mask"] 123 | segment_ids = features["segment_ids"] 124 | masked_lm_positions = features["masked_lm_positions"] 125 | masked_lm_ids = features["masked_lm_ids"] 126 | masked_lm_weights = features["masked_lm_weights"] 127 | next_sentence_labels = features["next_sentence_labels"] 128 | 129 | is_training = (mode == tf.estimator.ModeKeys.TRAIN) 130 | 131 | model = modeling.BertModel( 132 | config=bert_config, 133 | is_training=is_training, 134 | input_ids=input_ids, 135 | input_mask=input_mask, 136 | token_type_ids=segment_ids, 137 | use_one_hot_embeddings=use_one_hot_embeddings) 138 
| 139 | (masked_lm_loss, 140 | masked_lm_example_loss, masked_lm_log_probs) = get_masked_lm_output( 141 | bert_config, model.get_sequence_output(), model.get_embedding_table(), 142 | masked_lm_positions, masked_lm_ids, masked_lm_weights) 143 | 144 | (next_sentence_loss, next_sentence_example_loss, 145 | next_sentence_log_probs) = get_next_sentence_output( 146 | bert_config, model.get_pooled_output(), next_sentence_labels) 147 | 148 | total_loss = masked_lm_loss + next_sentence_loss 149 | 150 | tvars = tf.trainable_variables() 151 | 152 | initialized_variable_names = {} 153 | scaffold_fn = None 154 | if init_checkpoint: 155 | (assignment_map, initialized_variable_names 156 | ) = modeling.get_assignment_map_from_checkpoint(tvars, init_checkpoint) 157 | if use_tpu: 158 | 159 | def tpu_scaffold(): 160 | tf.train.init_from_checkpoint(init_checkpoint, assignment_map) 161 | return tf.train.Scaffold() 162 | 163 | scaffold_fn = tpu_scaffold 164 | else: 165 | tf.train.init_from_checkpoint(init_checkpoint, assignment_map) 166 | 167 | tf.logging.info("**** Trainable Variables ****") 168 | for var in tvars: 169 | init_string = "" 170 | if var.name in initialized_variable_names: 171 | init_string = ", *INIT_FROM_CKPT*" 172 | tf.logging.info(" name = %s, shape = %s%s", var.name, var.shape, 173 | init_string) 174 | 175 | output_spec = None 176 | if mode == tf.estimator.ModeKeys.TRAIN: 177 | train_op = optimization.create_optimizer( 178 | total_loss, learning_rate, num_train_steps, num_warmup_steps, use_tpu) 179 | 180 | output_spec = tf.contrib.tpu.TPUEstimatorSpec( 181 | mode=mode, 182 | loss=total_loss, 183 | train_op=train_op, 184 | scaffold_fn=scaffold_fn) 185 | elif mode == tf.estimator.ModeKeys.EVAL: 186 | 187 | def metric_fn(masked_lm_example_loss, masked_lm_log_probs, masked_lm_ids, 188 | masked_lm_weights, next_sentence_example_loss, 189 | next_sentence_log_probs, next_sentence_labels): 190 | """Computes the loss and accuracy of the model.""" 191 | masked_lm_log_probs = 
tf.reshape(masked_lm_log_probs, 192 | [-1, masked_lm_log_probs.shape[-1]]) 193 | masked_lm_predictions = tf.argmax( 194 | masked_lm_log_probs, axis=-1, output_type=tf.int32) 195 | masked_lm_example_loss = tf.reshape(masked_lm_example_loss, [-1]) 196 | masked_lm_ids = tf.reshape(masked_lm_ids, [-1]) 197 | masked_lm_weights = tf.reshape(masked_lm_weights, [-1]) 198 | masked_lm_accuracy = tf.metrics.accuracy( 199 | labels=masked_lm_ids, 200 | predictions=masked_lm_predictions, 201 | weights=masked_lm_weights) 202 | masked_lm_mean_loss = tf.metrics.mean( 203 | values=masked_lm_example_loss, weights=masked_lm_weights) 204 | 205 | next_sentence_log_probs = tf.reshape( 206 | next_sentence_log_probs, [-1, next_sentence_log_probs.shape[-1]]) 207 | next_sentence_predictions = tf.argmax( 208 | next_sentence_log_probs, axis=-1, output_type=tf.int32) 209 | next_sentence_labels = tf.reshape(next_sentence_labels, [-1]) 210 | next_sentence_accuracy = tf.metrics.accuracy( 211 | labels=next_sentence_labels, predictions=next_sentence_predictions) 212 | next_sentence_mean_loss = tf.metrics.mean( 213 | values=next_sentence_example_loss) 214 | 215 | return { 216 | "masked_lm_accuracy": masked_lm_accuracy, 217 | "masked_lm_loss": masked_lm_mean_loss, 218 | "next_sentence_accuracy": next_sentence_accuracy, 219 | "next_sentence_loss": next_sentence_mean_loss, 220 | } 221 | 222 | eval_metrics = (metric_fn, [ 223 | masked_lm_example_loss, masked_lm_log_probs, masked_lm_ids, 224 | masked_lm_weights, next_sentence_example_loss, 225 | next_sentence_log_probs, next_sentence_labels 226 | ]) 227 | output_spec = tf.contrib.tpu.TPUEstimatorSpec( 228 | mode=mode, 229 | loss=total_loss, 230 | eval_metrics=eval_metrics, 231 | scaffold_fn=scaffold_fn) 232 | else: 233 | raise ValueError("Only TRAIN and EVAL modes are supported: %s" % (mode)) 234 | 235 | return output_spec 236 | 237 | return model_fn 238 | 239 | 240 | def get_masked_lm_output(bert_config, input_tensor, output_weights, positions, 241 | 
label_ids, label_weights): 242 | """Get loss and log probs for the masked LM.""" 243 | input_tensor = gather_indexes(input_tensor, positions) 244 | 245 | with tf.variable_scope("cls/predictions"): 246 | # We apply one more non-linear transformation before the output layer. 247 | # This matrix is not used after pre-training. 248 | with tf.variable_scope("transform"): 249 | input_tensor = tf.layers.dense( 250 | input_tensor, 251 | units=bert_config.hidden_size, 252 | activation=modeling.get_activation(bert_config.hidden_act), 253 | kernel_initializer=modeling.create_initializer( 254 | bert_config.initializer_range)) 255 | input_tensor = modeling.layer_norm(input_tensor) 256 | 257 | # The output weights are the same as the input embeddings, but there is 258 | # an output-only bias for each token. 259 | output_bias = tf.get_variable( 260 | "output_bias", 261 | shape=[bert_config.vocab_size], 262 | initializer=tf.zeros_initializer()) 263 | logits = tf.matmul(input_tensor, output_weights, transpose_b=True) 264 | logits = tf.nn.bias_add(logits, output_bias) 265 | log_probs = tf.nn.log_softmax(logits, axis=-1) 266 | 267 | label_ids = tf.reshape(label_ids, [-1]) 268 | label_weights = tf.reshape(label_weights, [-1]) 269 | 270 | one_hot_labels = tf.one_hot( 271 | label_ids, depth=bert_config.vocab_size, dtype=tf.float32) 272 | 273 | # The `positions` tensor might be zero-padded (if the sequence is too 274 | # short to have the maximum number of predictions). The `label_weights` 275 | # tensor has a value of 1.0 for every real prediction and 0.0 for the 276 | # padding predictions. 
277 | per_example_loss = -tf.reduce_sum(log_probs * one_hot_labels, axis=[-1]) 278 | numerator = tf.reduce_sum(label_weights * per_example_loss) 279 | denominator = tf.reduce_sum(label_weights) + 1e-5 280 | loss = numerator / denominator 281 | 282 | return (loss, per_example_loss, log_probs) 283 | 284 | 285 | def get_next_sentence_output(bert_config, input_tensor, labels): 286 | """Get loss and log probs for the next sentence prediction.""" 287 | 288 | # Simple binary classification. Note that 0 is "next sentence" and 1 is 289 | # "random sentence". This weight matrix is not used after pre-training. 290 | with tf.variable_scope("cls/seq_relationship"): 291 | output_weights = tf.get_variable( 292 | "output_weights", 293 | shape=[2, bert_config.hidden_size], 294 | initializer=modeling.create_initializer(bert_config.initializer_range)) 295 | output_bias = tf.get_variable( 296 | "output_bias", shape=[2], initializer=tf.zeros_initializer()) 297 | 298 | logits = tf.matmul(input_tensor, output_weights, transpose_b=True) 299 | logits = tf.nn.bias_add(logits, output_bias) 300 | log_probs = tf.nn.log_softmax(logits, axis=-1) 301 | labels = tf.reshape(labels, [-1]) 302 | one_hot_labels = tf.one_hot(labels, depth=2, dtype=tf.float32) 303 | per_example_loss = -tf.reduce_sum(one_hot_labels * log_probs, axis=-1) 304 | loss = tf.reduce_mean(per_example_loss) 305 | return (loss, per_example_loss, log_probs) 306 | 307 | 308 | def gather_indexes(sequence_tensor, positions): 309 | """Gathers the vectors at the specific positions over a minibatch.""" 310 | sequence_shape = modeling.get_shape_list(sequence_tensor, expected_rank=3) 311 | batch_size = sequence_shape[0] 312 | seq_length = sequence_shape[1] 313 | width = sequence_shape[2] 314 | 315 | flat_offsets = tf.reshape( 316 | tf.range(0, batch_size, dtype=tf.int32) * seq_length, [-1, 1]) 317 | flat_positions = tf.reshape(positions + flat_offsets, [-1]) 318 | flat_sequence_tensor = tf.reshape(sequence_tensor, 319 | [batch_size * 
seq_length, width]) 320 | output_tensor = tf.gather(flat_sequence_tensor, flat_positions) 321 | return output_tensor 322 | 323 | 324 | def input_fn_builder(input_files, 325 | max_seq_length, 326 | max_predictions_per_seq, 327 | is_training, 328 | num_cpu_threads=4): 329 | """Creates an `input_fn` closure to be passed to TPUEstimator.""" 330 | 331 | def input_fn(params): 332 | """The actual input function.""" 333 | batch_size = params["batch_size"] 334 | 335 | name_to_features = { 336 | "input_ids": 337 | tf.FixedLenFeature([max_seq_length], tf.int64), 338 | "input_mask": 339 | tf.FixedLenFeature([max_seq_length], tf.int64), 340 | "segment_ids": 341 | tf.FixedLenFeature([max_seq_length], tf.int64), 342 | "masked_lm_positions": 343 | tf.FixedLenFeature([max_predictions_per_seq], tf.int64), 344 | "masked_lm_ids": 345 | tf.FixedLenFeature([max_predictions_per_seq], tf.int64), 346 | "masked_lm_weights": 347 | tf.FixedLenFeature([max_predictions_per_seq], tf.float32), 348 | "next_sentence_labels": 349 | tf.FixedLenFeature([1], tf.int64), 350 | } 351 | 352 | # For training, we want a lot of parallel reading and shuffling. 353 | # For eval, we want no shuffling and parallel reading doesn't matter. 354 | if is_training: 355 | d = tf.data.Dataset.from_tensor_slices(tf.constant(input_files)) 356 | d = d.repeat() 357 | d = d.shuffle(buffer_size=len(input_files)) 358 | 359 | # `cycle_length` is the number of parallel files that get read. 360 | cycle_length = min(num_cpu_threads, len(input_files)) 361 | 362 | # `sloppy` mode means that the interleaving is not exact. This adds 363 | # even more randomness to the training pipeline. 
364 | d = d.apply( 365 | tf.contrib.data.parallel_interleave( 366 | tf.data.TFRecordDataset, 367 | sloppy=is_training, 368 | cycle_length=cycle_length)) 369 | d = d.shuffle(buffer_size=100) 370 | else: 371 | d = tf.data.TFRecordDataset(input_files) 372 | # Since we evaluate for a fixed number of steps we don't want to encounter 373 | # out-of-range exceptions. 374 | d = d.repeat() 375 | 376 | # We must `drop_remainder` on training because the TPU requires fixed 377 | # size dimensions. For eval, we assume we are evaluating on the CPU or GPU 378 | # and we *don't* want to drop the remainder, otherwise we wont cover 379 | # every sample. 380 | d = d.apply( 381 | tf.contrib.data.map_and_batch( 382 | lambda record: _decode_record(record, name_to_features), 383 | batch_size=batch_size, 384 | num_parallel_batches=num_cpu_threads, 385 | drop_remainder=True)) 386 | return d 387 | 388 | return input_fn 389 | 390 | 391 | def _decode_record(record, name_to_features): 392 | """Decodes a record to a TensorFlow example.""" 393 | example = tf.parse_single_example(record, name_to_features) 394 | 395 | # tf.Example only supports tf.int64, but the TPU only supports tf.int32. 396 | # So cast all int64 to int32. 
397 | for name in list(example.keys()): 398 | t = example[name] 399 | if t.dtype == tf.int64: 400 | t = tf.to_int32(t) 401 | example[name] = t 402 | 403 | return example 404 | 405 | 406 | def main(_): 407 | tf.logging.set_verbosity(tf.logging.INFO) 408 | 409 | if not FLAGS.do_train and not FLAGS.do_eval: 410 | raise ValueError("At least one of `do_train` or `do_eval` must be True.") 411 | 412 | bert_config = modeling.BertConfig.from_json_file(FLAGS.bert_config_file) 413 | 414 | tf.gfile.MakeDirs(FLAGS.output_dir) 415 | 416 | input_files = [] 417 | for input_pattern in FLAGS.input_file.split(","): 418 | input_files.extend(tf.gfile.Glob(input_pattern)) 419 | 420 | tf.logging.info("*** Input Files ***") 421 | for input_file in input_files: 422 | tf.logging.info(" %s" % input_file) 423 | 424 | tpu_cluster_resolver = None 425 | if FLAGS.use_tpu and FLAGS.tpu_name: 426 | tpu_cluster_resolver = tf.contrib.cluster_resolver.TPUClusterResolver( 427 | FLAGS.tpu_name, zone=FLAGS.tpu_zone, project=FLAGS.gcp_project) 428 | 429 | is_per_host = tf.contrib.tpu.InputPipelineConfig.PER_HOST_V2 430 | run_config = tf.contrib.tpu.RunConfig( 431 | cluster=tpu_cluster_resolver, 432 | master=FLAGS.master, 433 | model_dir=FLAGS.output_dir, 434 | save_checkpoints_steps=FLAGS.save_checkpoints_steps, 435 | tpu_config=tf.contrib.tpu.TPUConfig( 436 | iterations_per_loop=FLAGS.iterations_per_loop, 437 | num_shards=FLAGS.num_tpu_cores, 438 | per_host_input_for_training=is_per_host)) 439 | 440 | model_fn = model_fn_builder( 441 | bert_config=bert_config, 442 | init_checkpoint=FLAGS.init_checkpoint, 443 | learning_rate=FLAGS.learning_rate, 444 | num_train_steps=FLAGS.num_train_steps, 445 | num_warmup_steps=FLAGS.num_warmup_steps, 446 | use_tpu=FLAGS.use_tpu, 447 | use_one_hot_embeddings=FLAGS.use_tpu) 448 | 449 | # If TPU is not available, this will fall back to normal Estimator on CPU 450 | # or GPU. 
451 | estimator = tf.contrib.tpu.TPUEstimator( 452 | use_tpu=FLAGS.use_tpu, 453 | model_fn=model_fn, 454 | config=run_config, 455 | train_batch_size=FLAGS.train_batch_size, 456 | eval_batch_size=FLAGS.eval_batch_size) 457 | 458 | if FLAGS.do_train: 459 | tf.logging.info("***** Running training *****") 460 | tf.logging.info(" Batch size = %d", FLAGS.train_batch_size) 461 | train_input_fn = input_fn_builder( 462 | input_files=input_files, 463 | max_seq_length=FLAGS.max_seq_length, 464 | max_predictions_per_seq=FLAGS.max_predictions_per_seq, 465 | is_training=True) 466 | estimator.train(input_fn=train_input_fn, max_steps=FLAGS.num_train_steps) 467 | 468 | if FLAGS.do_eval: 469 | tf.logging.info("***** Running evaluation *****") 470 | tf.logging.info(" Batch size = %d", FLAGS.eval_batch_size) 471 | 472 | eval_input_fn = input_fn_builder( 473 | input_files=input_files, 474 | max_seq_length=FLAGS.max_seq_length, 475 | max_predictions_per_seq=FLAGS.max_predictions_per_seq, 476 | is_training=False) 477 | 478 | result = estimator.evaluate( 479 | input_fn=eval_input_fn, steps=FLAGS.max_eval_steps) 480 | 481 | output_eval_file = os.path.join(FLAGS.output_dir, "eval_results.txt") 482 | with tf.gfile.GFile(output_eval_file, "w") as writer: 483 | tf.logging.info("***** Eval results *****") 484 | for key in sorted(result.keys()): 485 | tf.logging.info(" %s = %s", key, str(result[key])) 486 | writer.write("%s = %s\n" % (key, str(result[key]))) 487 | 488 | 489 | if __name__ == "__main__": 490 | flags.mark_flag_as_required("input_file") 491 | flags.mark_flag_as_required("bert_config_file") 492 | flags.mark_flag_as_required("output_dir") 493 | tf.app.run() 494 | -------------------------------------------------------------------------------- /bert/sample_text.txt: -------------------------------------------------------------------------------- 1 | This text is included to make sure Unicode is handled properly: 力加勝北区ᴵᴺᵀᵃছজটডণত 2 | Text should be one-sentence-per-line, with empty lines 
between documents. 3 | This sample text is public domain and was randomly selected from Project Gutenberg. 4 | 5 | The rain had only ceased with the gray streaks of morning at Blazing Star, and the settlement awoke to a moral sense of cleanliness, and the finding of forgotten knives, tin cups, and smaller camp utensils, where the heavy showers had washed away the debris and dust heaps before the cabin doors. 6 | Indeed, it was recorded in Blazing Star that a fortunate early riser had once picked up on the highway a solid chunk of gold quartz which the rain had freed from its incumbering soil, and washed into immediate and glittering popularity. 7 | Possibly this may have been the reason why early risers in that locality, during the rainy season, adopted a thoughtful habit of body, and seldom lifted their eyes to the rifted or india-ink washed skies above them. 8 | "Cass" Beard had risen early that morning, but not with a view to discovery. 9 | A leak in his cabin roof,--quite consistent with his careless, improvident habits,--had roused him at 4 A. M., with a flooded "bunk" and wet blankets. 10 | The chips from his wood pile refused to kindle a fire to dry his bed-clothes, and he had recourse to a more provident neighbor's to supply the deficiency. 11 | This was nearly opposite. 12 | Mr. Cassius crossed the highway, and stopped suddenly. 13 | Something glittered in the nearest red pool before him. 14 | Gold, surely! 15 | But, wonderful to relate, not an irregular, shapeless fragment of crude ore, fresh from Nature's crucible, but a bit of jeweler's handicraft in the form of a plain gold ring. 16 | Looking at it more attentively, he saw that it bore the inscription, "May to Cass." 17 | Like most of his fellow gold-seekers, Cass was superstitious. 18 | 19 | The fountain of classic wisdom, Hypatia herself.
20 | As the ancient sage--the name is unimportant to a monk--pumped water nightly that he might study by day, so I, the guardian of cloaks and parasols, at the sacred doors of her lecture-room, imbibe celestial knowledge. 21 | From my youth I felt in me a soul above the matter-entangled herd. 22 | She revealed to me the glorious fact, that I am a spark of Divinity itself. 23 | A fallen star, I am, sir!' continued he, pensively, stroking his lean stomach--'a fallen star!--fallen, if the dignity of philosophy will allow of the simile, among the hogs of the lower world--indeed, even into the hog-bucket itself. Well, after all, I will show you the way to the Archbishop's. 24 | There is a philosophic pleasure in opening one's treasures to the modest young. 25 | Perhaps you will assist me by carrying this basket of fruit?' And the little man jumped up, put his basket on Philammon's head, and trotted off up a neighbouring street. 26 | Philammon followed, half contemptuous, half wondering at what this philosophy might be, which could feed the self-conceit of anything so abject as his ragged little apish guide; 27 | but the novel roar and whirl of the street, the perpetual stream of busy faces, the line of curricles, palanquins, laden asses, camels, elephants, which met and passed him, and squeezed him up steps and into doorways, as they threaded their way through the great Moon-gate into the ample street beyond, drove everything from his mind but wondering curiosity, and a vague, helpless dread of that great living wilderness, more terrible than any dead wilderness of sand which he had left behind. 28 | Already he longed for the repose, the silence of the Laura--for faces which knew him and smiled upon him; but it was too late to turn back now. 
29 | His guide held on for more than a mile up the great main street, crossed in the centre of the city, at right angles, by one equally magnificent, at each end of which, miles away, appeared, dim and distant over the heads of the living stream of passengers, the yellow sand-hills of the desert; 30 | while at the end of the vista in front of them gleamed the blue harbour, through a network of countless masts. 31 | At last they reached the quay at the opposite end of the street; 32 | and there burst on Philammon's astonished eyes a vast semicircle of blue sea, ringed with palaces and towers. 33 | He stopped involuntarily; and his little guide stopped also, and looked askance at the young monk, to watch the effect which that grand panorama should produce on him. 34 | -------------------------------------------------------------------------------- /bert/tokenization.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | """Tokenization classes.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import collections 22 | import unicodedata 23 | import six 24 | import tensorflow as tf 25 | 26 | 27 | def convert_to_unicode(text): 28 | """Converts `text` to Unicode (if it's not already), assuming utf-8 input. Invalid UTF-8 bytes are silently dropped ("ignore").""" 29 | if six.PY3: 30 | if isinstance(text, str): 31 | return text 32 | elif isinstance(text, bytes): 33 | return text.decode("utf-8", "ignore") 34 | else: 35 | raise ValueError("Unsupported string type: %s" % (type(text))) 36 | elif six.PY2: 37 | if isinstance(text, str): 38 | return text.decode("utf-8", "ignore") 39 | elif isinstance(text, unicode): 40 | return text 41 | else: 42 | raise ValueError("Unsupported string type: %s" % (type(text))) 43 | else: 44 | raise ValueError("Not running on Python2 or Python 3?") 45 | 46 | 47 | def printable_text(text): 48 | """Returns text encoded in a way suitable for print or `tf.logging`.""" 49 | 50 | # These functions want `str` for both Python2 and Python3, but in one case 51 | # it's a Unicode string and in the other it's a byte string.
52 | if six.PY3: 53 | if isinstance(text, str): 54 | return text 55 | elif isinstance(text, bytes): 56 | return text.decode("utf-8", "ignore") 57 | else: 58 | raise ValueError("Unsupported string type: %s" % (type(text))) 59 | elif six.PY2: 60 | if isinstance(text, str): 61 | return text 62 | elif isinstance(text, unicode): 63 | return text.encode("utf-8") 64 | else: 65 | raise ValueError("Unsupported string type: %s" % (type(text))) 66 | else: 67 | raise ValueError("Not running on Python2 or Python 3?") 68 | 69 | 70 | def load_vocab(vocab_file): 71 | """Loads a vocabulary file (one token per line, surrounding whitespace stripped) into an OrderedDict mapping token -> 0-based line index.""" 72 | vocab = collections.OrderedDict() 73 | index = 0 74 | with tf.gfile.GFile(vocab_file, "r") as reader: 75 | while True: 76 | token = convert_to_unicode(reader.readline()) 77 | if not token: 78 | break 79 | token = token.strip() 80 | vocab[token] = index 81 | index += 1 82 | return vocab 83 | 84 | 85 | def convert_by_vocab(vocab, items): 86 | """Converts a sequence of [tokens|ids] using the vocab. Lookups use `vocab[item]`, so an item missing from a dict vocab raises KeyError.""" 87 | output = [] 88 | for item in items: 89 | output.append(vocab[item]) 90 | return output 91 | 92 | 93 | def convert_tokens_to_ids(vocab, tokens): 94 | return convert_by_vocab(vocab, tokens) 95 | 96 | 97 | def convert_ids_to_tokens(inv_vocab, ids): 98 | return convert_by_vocab(inv_vocab, ids) 99 | 100 | 101 | def whitespace_tokenize(text): 102 | """Runs basic whitespace cleaning and splitting on a piece of text.""" 103 | text = text.strip() 104 | if not text: 105 | return [] 106 | tokens = text.split() 107 | return tokens 108 | 109 | 110 | class FullTokenizer(object): 111 | """Runs end-to-end tokenization: BasicTokenizer followed by WordpieceTokenizer.""" 112 | 113 | def __init__(self, vocab_file, do_lower_case=True): 114 | self.vocab = load_vocab(vocab_file) 115 | self.inv_vocab = {v: k for k, v in self.vocab.items()} 116 | self.basic_tokenizer = 
BasicTokenizer(do_lower_case=do_lower_case) 117 | self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab) 118 | 119 | def tokenize(self, text): 120 | split_tokens
= [] 121 | for token in self.basic_tokenizer.tokenize(text): 122 | for sub_token in self.wordpiece_tokenizer.tokenize(token): 123 | split_tokens.append(sub_token) 124 | 125 | return split_tokens 126 | 127 | def convert_tokens_to_ids(self, tokens): 128 | return convert_by_vocab(self.vocab, tokens) 129 | 130 | def convert_ids_to_tokens(self, ids): 131 | return convert_by_vocab(self.inv_vocab, ids) 132 | 133 | 134 | class BasicTokenizer(object): 135 | """Runs basic tokenization (punctuation splitting, lower casing, etc.).""" 136 | 137 | def __init__(self, do_lower_case=True): 138 | """Constructs a BasicTokenizer. 139 | 140 | Args: 141 | do_lower_case: Whether to lower case the input. When True, accents are also stripped. 142 | """ 143 | self.do_lower_case = do_lower_case 144 | 145 | def tokenize(self, text): 146 | """Tokenizes a piece of text.""" 147 | text = convert_to_unicode(text) 148 | text = self._clean_text(text) 149 | 150 | # This was added on November 1st, 2018 for the multilingual and Chinese 151 | # models. This is also applied to the English models now, but it doesn't 152 | # matter since the English models were not trained on any Chinese data 153 | # and generally don't have any Chinese data in them (there are Chinese 154 | # characters in the vocabulary because Wikipedia does have some Chinese 155 | # words in the English Wikipedia.).
156 | text = self._tokenize_chinese_chars(text) 157 | 158 | orig_tokens = whitespace_tokenize(text) 159 | split_tokens = [] 160 | for token in orig_tokens: 161 | if self.do_lower_case: 162 | token = token.lower() 163 | token = self._run_strip_accents(token) # accents are stripped only when do_lower_case is set 164 | split_tokens.extend(self._run_split_on_punc(token)) 165 | 166 | output_tokens = whitespace_tokenize(" ".join(split_tokens)) 167 | return output_tokens 168 | 169 | def _run_strip_accents(self, text): 170 | """Strips accents from a piece of text (NFD-normalizes, then drops combining marks of Unicode category "Mn").""" 171 | text = unicodedata.normalize("NFD", text) 172 | output = [] 173 | for char in text: 174 | cat = unicodedata.category(char) 175 | if cat == "Mn": 176 | continue 177 | output.append(char) 178 | return "".join(output) 179 | 180 | def _run_split_on_punc(self, text): 181 | """Splits punctuation on a piece of text: each punctuation character becomes its own piece, other runs stay together.""" 182 | chars = list(text) 183 | i = 0 184 | start_new_word = True 185 | output = [] 186 | while i < len(chars): 187 | char = chars[i] 188 | if _is_punctuation(char): 189 | output.append([char]) 190 | start_new_word = True 191 | else: 192 | if start_new_word: 193 | output.append([]) 194 | start_new_word = False 195 | output[-1].append(char) 196 | i += 1 197 | 198 | return ["".join(x) for x in output] 199 | 200 | def _tokenize_chinese_chars(self, text): 201 | """Adds whitespace around any CJK character.""" 202 | output = [] 203 | for char in text: 204 | cp = ord(char) 205 | if self._is_chinese_char(cp): 206 | output.append(" ") 207 | output.append(char) 208 | output.append(" ") 209 | else: 210 | output.append(char) 211 | return "".join(output) 212 | 213 | def _is_chinese_char(self, cp): 214 | """Checks whether CP is the codepoint of a CJK character.""" 215 | # This defines a "chinese character" as anything in the CJK Unicode block: 216 | # 
https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block) 217 | # 218 | # Note that the CJK Unicode block is NOT all Japanese and Korean characters, 219 | # despite its name.
The modern Korean Hangul alphabet is a different block, 220 | # as is Japanese Hiragana and Katakana. Those alphabets are used to write 221 | # space-separated words, so they are not treated specially and handled 222 | # like all of the other languages. 223 | if ((cp >= 0x4E00 and cp <= 0x9FFF) or # 224 | (cp >= 0x3400 and cp <= 0x4DBF) or # 225 | (cp >= 0x20000 and cp <= 0x2A6DF) or # 226 | (cp >= 0x2A700 and cp <= 0x2B73F) or # 227 | (cp >= 0x2B740 and cp <= 0x2B81F) or # 228 | (cp >= 0x2B820 and cp <= 0x2CEAF) or 229 | (cp >= 0xF900 and cp <= 0xFAFF) or # 230 | (cp >= 0x2F800 and cp <= 0x2FA1F)): # 231 | return True 232 | 233 | return False 234 | 235 | def _clean_text(self, text): 236 | """Performs invalid character removal and whitespace cleanup on text: drops NUL, U+FFFD and control characters, and maps each whitespace character to a single space.""" 237 | output = [] 238 | for char in text: 239 | cp = ord(char) 240 | if cp == 0 or cp == 0xfffd or _is_control(char): 241 | continue 242 | if _is_whitespace(char): 243 | output.append(" ") 244 | else: 245 | output.append(char) 246 | return "".join(output) 247 | 248 | 249 | class WordpieceTokenizer(object): 250 | """Runs WordPiece tokenization.""" 251 | 252 | def __init__(self, vocab, unk_token="[UNK]", max_input_chars_per_word=100): 253 | self.vocab = vocab 254 | self.unk_token = unk_token 255 | self.max_input_chars_per_word = max_input_chars_per_word 256 | 257 | def tokenize(self, text): 258 | """Tokenizes a piece of text into its word pieces. 259 | 260 | This uses a greedy longest-match-first algorithm to perform tokenization 261 | using the given vocabulary. 262 | 263 | For example: 264 | input = "unaffable" 265 | output = ["un", "##aff", "##able"] 266 | 267 | Args: 268 | text: A single token or whitespace separated tokens. This should have 269 | already been passed through `BasicTokenizer`. 270 | 271 | Returns: 272 | A list of wordpiece tokens. 
A token longer than `max_input_chars_per_word`, or one that cannot be fully segmented with the vocab, is emitted as `unk_token`.
273 | """ 274 | 275 | text = convert_to_unicode(text) 276 | 277 | output_tokens = [] 278 | for token in whitespace_tokenize(text): 279 | chars = list(token) 280 | if len(chars) > self.max_input_chars_per_word: 281 | output_tokens.append(self.unk_token) 282 | continue 283 | 284 | is_bad = False 285 | start = 0 286 | sub_tokens = [] 287 | while start < len(chars): 288 | end = len(chars) 289 | cur_substr = None 290 | while start < end: # greedy longest match: shrink `end` until a vocab piece is found 291 | substr = "".join(chars[start:end]) 292 | if start > 0: 293 | substr = "##" + substr # non-initial pieces carry the "##" continuation prefix 294 | if substr in self.vocab: 295 | cur_substr = substr 296 | break 297 | end -= 1 298 | if cur_substr is None: 299 | is_bad = True # no vocab piece matches at `start`: the whole token becomes unk_token 300 | break 301 | sub_tokens.append(cur_substr) 302 | start = end 303 | 304 | if is_bad: 305 | output_tokens.append(self.unk_token) 306 | else: 307 | output_tokens.extend(sub_tokens) 308 | return output_tokens 309 | 310 | 311 | def _is_whitespace(char): 312 | """Checks whether `char` is a whitespace character.""" 313 | # \t, \n, and \r are technically control characters but we treat them 314 | # as whitespace since they are generally considered as such. 315 | if char == " " or char == "\t" or char == "\n" or char == "\r": 316 | return True 317 | cat = unicodedata.category(char) 318 | if cat == "Zs": 319 | return True 320 | return False 321 | 322 | 323 | def _is_control(char): 324 | """Checks whether `char` is a control character.""" 325 | # These are technically control characters but we count them as whitespace 326 | # characters. 
327 | if char == "\t" or char == "\n" or char == "\r": 328 | return False 329 | cat = unicodedata.category(char) 330 | if cat.startswith("C"): 331 | return True 332 | return False 333 | 334 | 335 | def _is_punctuation(char): 336 | """Checks whether `char` is a punctuation character.""" 337 | cp = ord(char) 338 | # We treat all non-letter/number ASCII as punctuation.
339 | # Characters such as "^", "$", and "`" are not in the Unicode 340 | # Punctuation class but we treat them as punctuation anyway, for 341 | # consistency. 342 | if ((cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64) or 343 | (cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126)): 344 | return True 345 | cat = unicodedata.category(char) 346 | if cat.startswith("P"): 347 | return True 348 | return False 349 | -------------------------------------------------------------------------------- /bert/tokenization_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License.
15 | from __future__ import absolute_import 16 | from __future__ import division 17 | from __future__ import print_function 18 | 19 | import os 20 | import tempfile 21 | 22 | import tokenization 23 | import tensorflow as tf 24 | 25 | 26 | class TokenizationTest(tf.test.TestCase): 27 | 28 | def test_full_tokenizer(self): 29 | vocab_tokens = [ 30 | "[UNK]", "[CLS]", "[SEP]", "want", "##want", "##ed", "wa", "un", "runn", 31 | "##ing", "," 32 | ] 33 | with tempfile.NamedTemporaryFile(delete=False) as vocab_writer: 34 | vocab_writer.write("".join([x + "\n" for x in vocab_tokens])) 35 | 36 | vocab_file = vocab_writer.name 37 | 38 | tokenizer = tokenization.FullTokenizer(vocab_file) 39 | os.unlink(vocab_file) 40 | 41 | tokens = tokenizer.tokenize(u"UNwant\u00E9d,running") 42 | self.assertAllEqual(tokens, ["un", "##want", "##ed", ",", "runn", "##ing"]) 43 | 44 | self.assertAllEqual( 45 | tokenizer.convert_tokens_to_ids(tokens), [7, 4, 5, 10, 8, 9]) 46 | 47 | def test_chinese(self): 48 | tokenizer = tokenization.BasicTokenizer() 49 | 50 | self.assertAllEqual( 51 | tokenizer.tokenize(u"ah\u535A\u63A8zz"), 52 | [u"ah", u"\u535A", u"\u63A8", u"zz"]) 53 | 54 | def test_basic_tokenizer_lower(self): 55 | tokenizer = tokenization.BasicTokenizer(do_lower_case=True) 56 | 57 | self.assertAllEqual( 58 | tokenizer.tokenize(u" \tHeLLo!how \n Are yoU? "), 59 | ["hello", "!", "how", "are", "you", "?"]) 60 | self.assertAllEqual(tokenizer.tokenize(u"H\u00E9llo"), ["hello"]) 61 | 62 | def test_basic_tokenizer_no_lower(self): 63 | tokenizer = tokenization.BasicTokenizer(do_lower_case=False) 64 | 65 | self.assertAllEqual( 66 | tokenizer.tokenize(u" \tHeLLo!how \n Are yoU? 
"), 67 | ["HeLLo", "!", "how", "Are", "yoU", "?"]) 68 | 69 | def test_wordpiece_tokenizer(self): 70 | vocab_tokens = [ 71 | "[UNK]", "[CLS]", "[SEP]", "want", "##want", "##ed", "wa", "un", "runn", 72 | "##ing" 73 | ] 74 | 75 | vocab = {} 76 | for (i, token) in enumerate(vocab_tokens): 77 | vocab[token] = i 78 | tokenizer = tokenization.WordpieceTokenizer(vocab=vocab) 79 | 80 | self.assertAllEqual(tokenizer.tokenize(""), []) 81 | 82 | self.assertAllEqual( 83 | tokenizer.tokenize("unwanted running"), 84 | ["un", "##want", "##ed", "runn", "##ing"]) 85 | 86 | self.assertAllEqual( 87 | tokenizer.tokenize("unwantedX running"), ["[UNK]", "runn", "##ing"]) 88 | 89 | def test_convert_tokens_to_ids(self): 90 | vocab_tokens = [ 91 | "[UNK]", "[CLS]", "[SEP]", "want", "##want", "##ed", "wa", "un", "runn", 92 | "##ing" 93 | ] 94 | 95 | vocab = {} 96 | for (i, token) in enumerate(vocab_tokens): 97 | vocab[token] = i 98 | 99 | self.assertAllEqual( 100 | tokenization.convert_tokens_to_ids( 101 | vocab, ["un", "##want", "##ed", "runn", "##ing"]), [7, 4, 5, 8, 9]) 102 | 103 | def test_is_whitespace(self): 104 | self.assertTrue(tokenization._is_whitespace(u" ")) 105 | self.assertTrue(tokenization._is_whitespace(u"\t")) 106 | self.assertTrue(tokenization._is_whitespace(u"\r")) 107 | self.assertTrue(tokenization._is_whitespace(u"\n")) 108 | self.assertTrue(tokenization._is_whitespace(u"\u00A0")) 109 | 110 | self.assertFalse(tokenization._is_whitespace(u"A")) 111 | self.assertFalse(tokenization._is_whitespace(u"-")) 112 | 113 | def test_is_control(self): 114 | self.assertTrue(tokenization._is_control(u"\u0005")) 115 | 116 | self.assertFalse(tokenization._is_control(u"A")) 117 | self.assertFalse(tokenization._is_control(u" ")) 118 | self.assertFalse(tokenization._is_control(u"\t")) 119 | self.assertFalse(tokenization._is_control(u"\r")) 120 | 121 | def test_is_punctuation(self): 122 | self.assertTrue(tokenization._is_punctuation(u"-")) 123 | 
self.assertTrue(tokenization._is_punctuation(u"$")) 124 | self.assertTrue(tokenization._is_punctuation(u"`")) 125 | self.assertTrue(tokenization._is_punctuation(u".")) 126 | 127 | self.assertFalse(tokenization._is_punctuation(u"A")) 128 | self.assertFalse(tokenization._is_punctuation(u" ")) 129 | 130 | 131 | if __name__ == "__main__": 132 | tf.test.main() 133 | -------------------------------------------------------------------------------- /conlleval.pl: -------------------------------------------------------------------------------- 1 | #!/usr/bin/perl -w 2 | # conlleval: evaluate result of processing CoNLL-2000 shared task 3 | # usage: conlleval [-l] [-r] [-d delimiterTag] [-o oTag] < file 4 | # README: http://cnts.uia.ac.be/conll2000/chunking/output.html 5 | # options: l: generate LaTeX output for tables like in 6 | # http://cnts.uia.ac.be/conll2003/ner/example.tex 7 | # r: accept raw result tags (without B- and I- prefix; 8 | # assumes one word per chunk) 9 | # d: alternative delimiter tag (default is single space) 10 | # o: alternative outside tag (default is O) 11 | # note: the file should contain lines with items separated 12 | # by $delimiter characters (default space). The final 13 | # two items should contain the correct tag and the 14 | # guessed tag in that order. Sentences should be 15 | # separated from each other by empty lines or lines 16 | # with $boundary fields (default -X-). 17 | # url: http://lcg-www.uia.ac.be/conll2000/chunking/ 18 | # started: 1998-09-25 19 | # version: 2004-01-26 20 | # author: Erik Tjong Kim Sang 21 | 22 | use strict; 23 | 24 | my $false = 0; 25 | my $true = 42; 26 | 27 | my $boundary = "-X-"; # sentence boundary 28 | my $correct; # current corpus chunk tag (I,O,B) 29 | my $correctChunk = 0; # number of correctly identified chunks 30 | my $correctTags = 0; # number of correct chunk tags 31 | my $correctType; # type of current corpus chunk tag (NP,VP,etc.) 
32 | my $delimiter = " "; # field delimiter 33 | my $FB1 = 0.0; # FB1 score (Van Rijsbergen 1979) 34 | my $firstItem; # first feature (for sentence boundary checks) 35 | my $foundCorrect = 0; # number of chunks in corpus 36 | my $foundGuessed = 0; # number of identified chunks 37 | my $guessed; # current guessed chunk tag 38 | my $guessedType; # type of current guessed chunk tag 39 | my $i; # miscellaneous counter 40 | my $inCorrect = $false; # currently processed chunk is correct until now 41 | my $lastCorrect = "O"; # previous chunk tag in corpus 42 | my $latex = 0; # generate LaTeX formatted output 43 | my $lastCorrectType = ""; # type of previously identified chunk tag 44 | my $lastGuessed = "O"; # previously identified chunk tag 45 | my $lastGuessedType = ""; # type of previous chunk tag in corpus 46 | my $lastType; # temporary storage for detecting duplicates 47 | my $line; # line 48 | my $nbrOfFeatures = -1; # number of features per line 49 | my $precision = 0.0; # precision score 50 | my $oTag = "O"; # outside tag, default O 51 | my $raw = 0; # raw input: add B to every token 52 | my $recall = 0.0; # recall score 53 | my $tokenCounter = 0; # token counter (ignores sentence breaks) 54 | 55 | my %correctChunk = (); # number of correctly identified chunks per type 56 | my %foundCorrect = (); # number of chunks in corpus per type 57 | my %foundGuessed = (); # number of identified chunks per type 58 | 59 | my @features; # features on line 60 | my @sortedTypes; # sorted list of chunk type names 61 | 62 | # sanity check 63 | while (@ARGV and $ARGV[0] =~ /^-/) { 64 | if ($ARGV[0] eq "-l") { $latex = 1; shift(@ARGV); } 65 | elsif ($ARGV[0] eq "-r") { $raw = 1; shift(@ARGV); } 66 | elsif ($ARGV[0] eq "-d") { 67 | shift(@ARGV); 68 | if (not defined $ARGV[0]) { 69 | die "conlleval: -d requires delimiter character"; 70 | } 71 | $delimiter = shift(@ARGV); 72 | } elsif ($ARGV[0] eq "-o") { 73 | shift(@ARGV); 74 | if (not defined $ARGV[0]) { 75 | die "conlleval: -o 
requires delimiter character"; 76 | } 77 | $oTag = shift(@ARGV); 78 | } else { die "conlleval: unknown argument $ARGV[0]\n"; } 79 | } 80 | if (@ARGV) { die "conlleval: unexpected command line argument\n"; } 81 | # process input 82 | while () { 83 | chomp($line = $_); 84 | @features = split(/$delimiter/,$line); 85 | # @features = split(/\t/,$line); 86 | if ($nbrOfFeatures < 0) { $nbrOfFeatures = $#features; } 87 | elsif ($nbrOfFeatures != $#features and @features != 0) { 88 | printf STDERR "unexpected number of features: %d (%d)\n", 89 | $#features+1,$nbrOfFeatures+1; 90 | exit(1); 91 | } 92 | if (@features == 0 or 93 | $features[0] eq $boundary) { @features = ($boundary,"O","O"); } 94 | if (@features < 2) { 95 | printf STDERR "feature length is %d. \n", @features; 96 | die "conlleval: unexpected number of features in line $line\n"; 97 | } 98 | if ($raw) { 99 | if ($features[$#features] eq $oTag) { $features[$#features] = "O"; } 100 | if ($features[$#features-1] eq $oTag) { $features[$#features-1] = "O"; } 101 | if ($features[$#features] ne "O") { 102 | $features[$#features] = "B-$features[$#features]"; 103 | } 104 | if ($features[$#features-1] ne "O") { 105 | $features[$#features-1] = "B-$features[$#features-1]"; 106 | } 107 | } 108 | # 20040126 ET code which allows hyphens in the types 109 | if ($features[$#features] =~ /^([^-]*)-(.*)$/) { 110 | $guessed = $1; 111 | $guessedType = $2; 112 | } else { 113 | $guessed = $features[$#features]; 114 | $guessedType = ""; 115 | } 116 | pop(@features); 117 | if ($features[$#features] =~ /^([^-]*)-(.*)$/) { 118 | $correct = $1; 119 | $correctType = $2; 120 | } else { 121 | $correct = $features[$#features]; 122 | $correctType = ""; 123 | } 124 | pop(@features); 125 | # ($guessed,$guessedType) = split(/-/,pop(@features)); 126 | # ($correct,$correctType) = split(/-/,pop(@features)); 127 | $guessedType = $guessedType ? $guessedType : ""; 128 | $correctType = $correctType ? 
$correctType : ""; 129 | $firstItem = shift(@features); 130 | 131 | # 1999-06-26 sentence breaks should always be counted as out of chunk 132 | if ( $firstItem eq $boundary ) { $guessed = "O"; } 133 | 134 | if ($inCorrect) { 135 | if ( &endOfChunk($lastCorrect,$correct,$lastCorrectType,$correctType) and 136 | &endOfChunk($lastGuessed,$guessed,$lastGuessedType,$guessedType) and 137 | $lastGuessedType eq $lastCorrectType) { 138 | $inCorrect=$false; 139 | $correctChunk++; 140 | $correctChunk{$lastCorrectType} = $correctChunk{$lastCorrectType} ? 141 | $correctChunk{$lastCorrectType}+1 : 1; 142 | } elsif ( 143 | &endOfChunk($lastCorrect,$correct,$lastCorrectType,$correctType) != 144 | &endOfChunk($lastGuessed,$guessed,$lastGuessedType,$guessedType) or 145 | $guessedType ne $correctType ) { 146 | $inCorrect=$false; 147 | } 148 | } 149 | 150 | if ( &startOfChunk($lastCorrect,$correct,$lastCorrectType,$correctType) and 151 | &startOfChunk($lastGuessed,$guessed,$lastGuessedType,$guessedType) and 152 | $guessedType eq $correctType) { $inCorrect = $true; } 153 | 154 | if ( &startOfChunk($lastCorrect,$correct,$lastCorrectType,$correctType) ) { 155 | $foundCorrect++; 156 | $foundCorrect{$correctType} = $foundCorrect{$correctType} ? 157 | $foundCorrect{$correctType}+1 : 1; 158 | } 159 | if ( &startOfChunk($lastGuessed,$guessed,$lastGuessedType,$guessedType) ) { 160 | $foundGuessed++; 161 | $foundGuessed{$guessedType} = $foundGuessed{$guessedType} ? 162 | $foundGuessed{$guessedType}+1 : 1; 163 | } 164 | if ( $firstItem ne $boundary ) { 165 | if ( $correct eq $guessed and $guessedType eq $correctType ) { 166 | $correctTags++; 167 | } 168 | $tokenCounter++; 169 | } 170 | 171 | $lastGuessed = $guessed; 172 | $lastCorrect = $correct; 173 | $lastGuessedType = $guessedType; 174 | $lastCorrectType = $correctType; 175 | } 176 | if ($inCorrect) { 177 | $correctChunk++; 178 | $correctChunk{$lastCorrectType} = $correctChunk{$lastCorrectType} ? 
179 | $correctChunk{$lastCorrectType}+1 : 1; 180 | } 181 | 182 | if (not $latex) { 183 | # compute overall precision, recall and FB1 (default values are 0.0) 184 | $precision = 100*$correctChunk/$foundGuessed if ($foundGuessed > 0); 185 | $recall = 100*$correctChunk/$foundCorrect if ($foundCorrect > 0); 186 | $FB1 = 2*$precision*$recall/($precision+$recall) 187 | if ($precision+$recall > 0); 188 | 189 | # print overall performance 190 | printf "processed $tokenCounter tokens with $foundCorrect phrases; "; 191 | printf "found: $foundGuessed phrases; correct: $correctChunk.\n"; 192 | if ($tokenCounter>0) { 193 | printf "accuracy: %6.2f%%; ",100*$correctTags/$tokenCounter; 194 | printf "precision: %6.2f%%; ",$precision; 195 | printf "recall: %6.2f%%; ",$recall; 196 | printf "FB1: %6.2f\n",$FB1; 197 | } 198 | } 199 | 200 | # sort chunk type names 201 | undef($lastType); 202 | @sortedTypes = (); 203 | foreach $i (sort (keys %foundCorrect,keys %foundGuessed)) { 204 | if (not($lastType) or $lastType ne $i) { 205 | push(@sortedTypes,($i)); 206 | } 207 | $lastType = $i; 208 | } 209 | # print performance per chunk type 210 | if (not $latex) { 211 | for $i (@sortedTypes) { 212 | $correctChunk{$i} = $correctChunk{$i} ? $correctChunk{$i} : 0; 213 | if (not($foundGuessed{$i})) { $foundGuessed{$i} = 0; $precision = 0.0; } 214 | else { $precision = 100*$correctChunk{$i}/$foundGuessed{$i}; } 215 | if (not($foundCorrect{$i})) { $recall = 0.0; } 216 | else { $recall = 100*$correctChunk{$i}/$foundCorrect{$i}; } 217 | if ($precision+$recall == 0.0) { $FB1 = 0.0; } 218 | else { $FB1 = 2*$precision*$recall/($precision+$recall); } 219 | printf "%17s: ",$i; 220 | printf "precision: %6.2f%%; ",$precision; 221 | printf "recall: %6.2f%%; ",$recall; 222 | printf "FB1: %6.2f %d\n",$FB1,$foundGuessed{$i}; 223 | } 224 | } else { 225 | print " & Precision & Recall & F\$_{\\beta=1} \\\\\\hline"; 226 | for $i (@sortedTypes) { 227 | $correctChunk{$i} = $correctChunk{$i} ? 
$correctChunk{$i} : 0; 228 | if (not($foundGuessed{$i})) { $precision = 0.0; } 229 | else { $precision = 100*$correctChunk{$i}/$foundGuessed{$i}; } 230 | if (not($foundCorrect{$i})) { $recall = 0.0; } 231 | else { $recall = 100*$correctChunk{$i}/$foundCorrect{$i}; } 232 | if ($precision+$recall == 0.0) { $FB1 = 0.0; } 233 | else { $FB1 = 2*$precision*$recall/($precision+$recall); } 234 | printf "\n%-7s & %6.2f\\%% & %6.2f\\%% & %6.2f \\\\", 235 | $i,$precision,$recall,$FB1; 236 | } 237 | print "\\hline\n"; 238 | $precision = 0.0; 239 | $recall = 0; 240 | $FB1 = 0.0; 241 | $precision = 100*$correctChunk/$foundGuessed if ($foundGuessed > 0); 242 | $recall = 100*$correctChunk/$foundCorrect if ($foundCorrect > 0); 243 | $FB1 = 2*$precision*$recall/($precision+$recall) 244 | if ($precision+$recall > 0); 245 | printf "Overall & %6.2f\\%% & %6.2f\\%% & %6.2f \\\\\\hline\n", 246 | $precision,$recall,$FB1; 247 | } 248 | 249 | exit 0; 250 | 251 | # endOfChunk: checks if a chunk ended between the previous and current word 252 | # arguments: previous and current chunk tags, previous and current types 253 | # note: this code is capable of handling other chunk representations 254 | # than the default CoNLL-2000 ones, see EACL'99 paper of Tjong 255 | # Kim Sang and Veenstra http://xxx.lanl.gov/abs/cs.CL/9907006 256 | 257 | sub endOfChunk { 258 | my $prevTag = shift(@_); 259 | my $tag = shift(@_); 260 | my $prevType = shift(@_); 261 | my $type = shift(@_); 262 | my $chunkEnd = $false; 263 | 264 | if ( $prevTag eq "B" and $tag eq "B" ) { $chunkEnd = $true; } 265 | if ( $prevTag eq "B" and $tag eq "O" ) { $chunkEnd = $true; } 266 | if ( $prevTag eq "I" and $tag eq "B" ) { $chunkEnd = $true; } 267 | if ( $prevTag eq "I" and $tag eq "O" ) { $chunkEnd = $true; } 268 | 269 | if ( $prevTag eq "E" and $tag eq "E" ) { $chunkEnd = $true; } 270 | if ( $prevTag eq "E" and $tag eq "I" ) { $chunkEnd = $true; } 271 | if ( $prevTag eq "E" and $tag eq "O" ) { $chunkEnd = $true; } 272 | if ( 
$prevTag eq "I" and $tag eq "O" ) { $chunkEnd = $true; } 273 | 274 | if ($prevTag ne "O" and $prevTag ne "." and $prevType ne $type) { 275 | $chunkEnd = $true; 276 | } 277 | 278 | # corrected 1998-12-22: these chunks are assumed to have length 1 279 | if ( $prevTag eq "]" ) { $chunkEnd = $true; } 280 | if ( $prevTag eq "[" ) { $chunkEnd = $true; } 281 | 282 | return($chunkEnd); 283 | } 284 | 285 | # startOfChunk: checks if a chunk started between the previous and current word 286 | # arguments: previous and current chunk tags, previous and current types 287 | # note: this code is capable of handling other chunk representations 288 | # than the default CoNLL-2000 ones, see EACL'99 paper of Tjong 289 | # Kim Sang and Veenstra http://xxx.lanl.gov/abs/cs.CL/9907006 290 | 291 | sub startOfChunk { 292 | my $prevTag = shift(@_); 293 | my $tag = shift(@_); 294 | my $prevType = shift(@_); 295 | my $type = shift(@_); 296 | my $chunkStart = $false; 297 | 298 | if ( $prevTag eq "B" and $tag eq "B" ) { $chunkStart = $true; } 299 | if ( $prevTag eq "I" and $tag eq "B" ) { $chunkStart = $true; } 300 | if ( $prevTag eq "O" and $tag eq "B" ) { $chunkStart = $true; } 301 | if ( $prevTag eq "O" and $tag eq "I" ) { $chunkStart = $true; } 302 | 303 | if ( $prevTag eq "E" and $tag eq "E" ) { $chunkStart = $true; } 304 | if ( $prevTag eq "E" and $tag eq "I" ) { $chunkStart = $true; } 305 | if ( $prevTag eq "O" and $tag eq "E" ) { $chunkStart = $true; } 306 | if ( $prevTag eq "O" and $tag eq "I" ) { $chunkStart = $true; } 307 | 308 | if ($tag ne "O" and $tag ne "." 
and $prevType ne $type) { 309 | $chunkStart = $true; 310 | } 311 | 312 | # corrected 1998-12-22: these chunks are assumed to have length 1 313 | if ( $tag eq "[" ) { $chunkStart = $true; } 314 | if ( $tag eq "]" ) { $chunkStart = $true; } 315 | 316 | return($chunkStart); 317 | } -------------------------------------------------------------------------------- /conlleval.py: -------------------------------------------------------------------------------- 1 | # Python version of the evaluation script from CoNLL'00- 2 | # Originates from: https://github.com/spyysalo/conlleval.py 3 | 4 | 5 | # Intentional differences: 6 | # - accept any space as delimiter by default 7 | # - optional file argument (default STDIN) 8 | # - option to set boundary (-b argument) 9 | # - LaTeX output (-l argument) not supported 10 | # - raw tags (-r argument) not supported 11 | 12 | # add function :evaluate(predicted_label, ori_label): which will not read from file 13 | 14 | import sys 15 | import re 16 | import codecs 17 | from collections import defaultdict, namedtuple 18 | 19 | ANY_SPACE = '' 20 | 21 | 22 | class FormatError(Exception): 23 | pass 24 | 25 | Metrics = namedtuple('Metrics', 'tp fp fn prec rec fscore') 26 | 27 | 28 | class EvalCounts(object): 29 | def __init__(self): 30 | self.correct_chunk = 0 # number of correctly identified chunks 31 | self.correct_tags = 0 # number of correct chunk tags 32 | self.found_correct = 0 # number of chunks in corpus 33 | self.found_guessed = 0 # number of identified chunks 34 | self.token_counter = 0 # token counter (ignores sentence breaks) 35 | 36 | # counts by type 37 | self.t_correct_chunk = defaultdict(int) 38 | self.t_found_correct = defaultdict(int) 39 | self.t_found_guessed = defaultdict(int) 40 | 41 | 42 | def parse_args(argv): 43 | import argparse 44 | parser = argparse.ArgumentParser( 45 | description='evaluate tagging results using CoNLL criteria', 46 | formatter_class=argparse.ArgumentDefaultsHelpFormatter 47 | ) 48 | arg = 
parser.add_argument 49 | arg('-b', '--boundary', metavar='STR', default='-X-', 50 | help='sentence boundary') 51 | arg('-d', '--delimiter', metavar='CHAR', default=ANY_SPACE, 52 | help='character delimiting items in input') 53 | arg('-o', '--otag', metavar='CHAR', default='O', 54 | help='alternative outside tag') 55 | arg('file', nargs='?', default=None) 56 | return parser.parse_args(argv) 57 | 58 | 59 | def parse_tag(t): 60 | m = re.match(r'^([^-]*)-(.*)$', t) 61 | return m.groups() if m else (t, '') 62 | 63 | 64 | def evaluate(iterable, options=None): 65 | if options is None: 66 | options = parse_args([]) # use defaults 67 | 68 | counts = EvalCounts() 69 | num_features = None # number of features per line 70 | in_correct = False # currently processed chunks is correct until now 71 | last_correct = 'O' # previous chunk tag in corpus 72 | last_correct_type = '' # type of previously identified chunk tag 73 | last_guessed = 'O' # previously identified chunk tag 74 | last_guessed_type = '' # type of previous chunk tag in corpus 75 | 76 | for line in iterable: 77 | line = line.rstrip('\r\n') 78 | 79 | if options.delimiter == ANY_SPACE: 80 | features = line.split() 81 | else: 82 | features = line.split(options.delimiter) 83 | 84 | if num_features is None: 85 | num_features = len(features) 86 | elif num_features != len(features) and len(features) != 0: 87 | raise FormatError('unexpected number of features: %d (%d)' % 88 | (len(features), num_features)) 89 | 90 | if len(features) == 0 or features[0] == options.boundary: 91 | features = [options.boundary, 'O', 'O'] 92 | if len(features) < 3: 93 | raise FormatError('unexpected number of features in line %s' % line) 94 | 95 | guessed, guessed_type = parse_tag(features.pop()) 96 | correct, correct_type = parse_tag(features.pop()) 97 | first_item = features.pop(0) 98 | 99 | if first_item == options.boundary: 100 | guessed = 'O' 101 | 102 | end_correct = end_of_chunk(last_correct, correct, 103 | last_correct_type, correct_type) 
104 | end_guessed = end_of_chunk(last_guessed, guessed, 105 | last_guessed_type, guessed_type) 106 | start_correct = start_of_chunk(last_correct, correct, 107 | last_correct_type, correct_type) 108 | start_guessed = start_of_chunk(last_guessed, guessed, 109 | last_guessed_type, guessed_type) 110 | 111 | if in_correct: 112 | if (end_correct and end_guessed and 113 | last_guessed_type == last_correct_type): 114 | in_correct = False 115 | counts.correct_chunk += 1 116 | counts.t_correct_chunk[last_correct_type] += 1 117 | elif (end_correct != end_guessed or guessed_type != correct_type): 118 | in_correct = False 119 | 120 | if start_correct and start_guessed and guessed_type == correct_type: 121 | in_correct = True 122 | 123 | if start_correct: 124 | counts.found_correct += 1 125 | counts.t_found_correct[correct_type] += 1 126 | if start_guessed: 127 | counts.found_guessed += 1 128 | counts.t_found_guessed[guessed_type] += 1 129 | if first_item != options.boundary: 130 | if correct == guessed and guessed_type == correct_type: 131 | counts.correct_tags += 1 132 | counts.token_counter += 1 133 | 134 | last_guessed = guessed 135 | last_correct = correct 136 | last_guessed_type = guessed_type 137 | last_correct_type = correct_type 138 | 139 | if in_correct: 140 | counts.correct_chunk += 1 141 | counts.t_correct_chunk[last_correct_type] += 1 142 | 143 | return counts 144 | 145 | 146 | 147 | def uniq(iterable): 148 | seen = set() 149 | return [i for i in iterable if not (i in seen or seen.add(i))] 150 | 151 | 152 | def calculate_metrics(correct, guessed, total): 153 | tp, fp, fn = correct, guessed-correct, total-correct 154 | p = 0 if tp + fp == 0 else 1.*tp / (tp + fp) 155 | r = 0 if tp + fn == 0 else 1.*tp / (tp + fn) 156 | f = 0 if p + r == 0 else 2 * p * r / (p + r) 157 | return Metrics(tp, fp, fn, p, r, f) 158 | 159 | 160 | def metrics(counts): 161 | c = counts 162 | overall = calculate_metrics( 163 | c.correct_chunk, c.found_guessed, c.found_correct 164 | ) 165 | 
by_type = {} 166 | for t in uniq(list(c.t_found_correct) + list(c.t_found_guessed)): 167 | by_type[t] = calculate_metrics( 168 | c.t_correct_chunk[t], c.t_found_guessed[t], c.t_found_correct[t] 169 | ) 170 | return overall, by_type 171 | 172 | 173 | def report(counts, out=None): 174 | if out is None: 175 | out = sys.stdout 176 | 177 | overall, by_type = metrics(counts) 178 | 179 | c = counts 180 | out.write('processed %d tokens with %d phrases; ' % 181 | (c.token_counter, c.found_correct)) 182 | out.write('found: %d phrases; correct: %d.\n' % 183 | (c.found_guessed, c.correct_chunk)) 184 | 185 | if c.token_counter > 0: 186 | out.write('accuracy: %6.2f%%; ' % 187 | (100.*c.correct_tags/c.token_counter)) 188 | out.write('precision: %6.2f%%; ' % (100.*overall.prec)) 189 | out.write('recall: %6.2f%%; ' % (100.*overall.rec)) 190 | out.write('FB1: %6.2f\n' % (100.*overall.fscore)) 191 | 192 | for i, m in sorted(by_type.items()): 193 | out.write('%17s: ' % i) 194 | out.write('precision: %6.2f%%; ' % (100.*m.prec)) 195 | out.write('recall: %6.2f%%; ' % (100.*m.rec)) 196 | out.write('FB1: %6.2f %d\n' % (100.*m.fscore, c.t_found_guessed[i])) 197 | 198 | 199 | def report_notprint(counts, out=None): 200 | if out is None: 201 | out = sys.stdout 202 | 203 | overall, by_type = metrics(counts) 204 | 205 | c = counts 206 | final_report = [] 207 | line = [] 208 | line.append('processed %d tokens with %d phrases; ' % 209 | (c.token_counter, c.found_correct)) 210 | line.append('found: %d phrases; correct: %d.\n' % 211 | (c.found_guessed, c.correct_chunk)) 212 | final_report.append("".join(line)) 213 | 214 | if c.token_counter > 0: 215 | line = [] 216 | line.append('accuracy: %6.2f%%; ' % 217 | (100.*c.correct_tags/c.token_counter)) 218 | line.append('precision: %6.2f%%; ' % (100.*overall.prec)) 219 | line.append('recall: %6.2f%%; ' % (100.*overall.rec)) 220 | line.append('FB1: %6.2f\n' % (100.*overall.fscore)) 221 | final_report.append("".join(line)) 222 | 223 | for i, m in 
sorted(by_type.items()): 224 | line = [] 225 | line.append('%17s: ' % i) 226 | line.append('precision: %6.2f%%; ' % (100.*m.prec)) 227 | line.append('recall: %6.2f%%; ' % (100.*m.rec)) 228 | line.append('FB1: %6.2f %d\n' % (100.*m.fscore, c.t_found_guessed[i])) 229 | final_report.append("".join(line)) 230 | return final_report 231 | 232 | 233 | def end_of_chunk(prev_tag, tag, prev_type, type_): 234 | # check if a chunk ended between the previous and current word 235 | # arguments: previous and current chunk tags, previous and current types 236 | chunk_end = False 237 | 238 | if prev_tag == 'E': chunk_end = True 239 | if prev_tag == 'S': chunk_end = True 240 | 241 | if prev_tag == 'B' and tag == 'B': chunk_end = True 242 | if prev_tag == 'B' and tag == 'S': chunk_end = True 243 | if prev_tag == 'B' and tag == 'O': chunk_end = True 244 | if prev_tag == 'I' and tag == 'B': chunk_end = True 245 | if prev_tag == 'I' and tag == 'S': chunk_end = True 246 | if prev_tag == 'I' and tag == 'O': chunk_end = True 247 | 248 | if prev_tag != 'O' and prev_tag != '.' 
and prev_type != type_: 249 | chunk_end = True 250 | 251 | # these chunks are assumed to have length 1 252 | if prev_tag == ']': chunk_end = True 253 | if prev_tag == '[': chunk_end = True 254 | 255 | return chunk_end 256 | 257 | 258 | def start_of_chunk(prev_tag, tag, prev_type, type_): 259 | # check if a chunk started between the previous and current word 260 | # arguments: previous and current chunk tags, previous and current types 261 | chunk_start = False 262 | 263 | if tag == 'B': chunk_start = True 264 | if tag == 'S': chunk_start = True 265 | 266 | if prev_tag == 'E' and tag == 'E': chunk_start = True 267 | if prev_tag == 'E' and tag == 'I': chunk_start = True 268 | if prev_tag == 'S' and tag == 'E': chunk_start = True 269 | if prev_tag == 'S' and tag == 'I': chunk_start = True 270 | if prev_tag == 'O' and tag == 'E': chunk_start = True 271 | if prev_tag == 'O' and tag == 'I': chunk_start = True 272 | 273 | if tag != 'O' and tag != '.' and prev_type != type_: 274 | chunk_start = True 275 | 276 | # these chunks are assumed to have length 1 277 | if tag == '[': chunk_start = True 278 | if tag == ']': chunk_start = True 279 | 280 | return chunk_start 281 | 282 | 283 | def return_report(input_file): 284 | with codecs.open(input_file, "r", "utf8") as f: 285 | counts = evaluate(f) 286 | return report_notprint(counts) 287 | 288 | 289 | def main(argv): 290 | args = parse_args(argv[1:]) 291 | 292 | if args.file is None: 293 | counts = evaluate(sys.stdin, args) 294 | else: 295 | with open(args.file) as f: 296 | counts = evaluate(f, args) 297 | report(counts) 298 | 299 | if __name__ == '__main__': 300 | sys.exit(main(sys.argv)) -------------------------------------------------------------------------------- /global_config.py: -------------------------------------------------------------------------------- 1 | # coding:utf-8 import os import logging from logging import handlers os.environ['NLS_LANG'] = 'SIMPLIFIED CHINESE_CHINA.UTF8' class Logger(object): # 
https://www.cnblogs.com/nancyzhu/p/8551506.html # 日志级别关系映射 level_relations = { 'debug': logging.DEBUG, 'info': logging.INFO, 'warning': logging.WARNING, 'error': logging.ERROR, 'crit': logging.CRITICAL } def __init__(self, filename, level='info', when='D', backCount=3, fmt='%(asctime)s|%(filename)s[line:%(lineno)d]|%(funcName)s|%(levelname)s|%(message)s'): self.logger = logging.getLogger(filename) # 设置日志格式 format_str = logging.Formatter(fmt) # 设置日志级别 self.logger.setLevel(self.level_relations.get(level)) # 往屏幕上输出 sh = logging.StreamHandler() # 设置屏幕上显示的格式 sh.setFormatter(format_str) # 往文件里写入#指定间隔时间自动生成文件的处理器 th = handlers.TimedRotatingFileHandler(filename=filename, when=when, backupCount=backCount, encoding='utf-8') # 实例化TimedRotatingFileHandler # interval是时间间隔,backupCount是备份文件的个数,如果超过这个个数,就会自动删除,when是间隔的时间单位,单位有以下几种: # S 秒 # M 分 # H 小时、 # D 天、 # W 每星期(interval==0时代表星期一) # midnight 每天凌晨 # 设置文件里写入的格式 th.setFormatter(format_str) # 把对象加到logger里 self.logger.addHandler(sh) self.logger.addHandler(th) #loginfo = Logger("recommend_articles.log", "info") if __name__ == '__main__': log = Logger('all.log',level='debug') log.logger.debug('debug') log.logger.info('info') log.logger.warning('警告') log.logger.error('报错') log.logger.critical('严重') Logger('error.log', level='error').logger.error('error') -------------------------------------------------------------------------------- /image/KB.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WenRichard/KBQA-BERT/6e7079ede8979b1179600564ed9ecc58c7a4f877/image/KB.png -------------------------------------------------------------------------------- /image/NER.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WenRichard/KBQA-BERT/6e7079ede8979b1179600564ed9ecc58c7a4f877/image/NER.jpg -------------------------------------------------------------------------------- /kbqa_test.py: 
-------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | import io 4 | import re 5 | import time 6 | import jieba 7 | import numpy as np 8 | import pandas as pd 9 | import urllib.request 10 | import urllib.parse 11 | import tensorflow as tf 12 | from Data.load_dbdata import upload_data 13 | from global_config import Logger 14 | 15 | from run_similarity import BertSim 16 | # 模块导入 https://blog.csdn.net/xiongchengluo1129/article/details/80453599 17 | 18 | loginfo = Logger("recommend_articles.log", "info") 19 | file = "./Data/NER_Data/q_t_a_testing_predict.txt" 20 | 21 | bs = BertSim() 22 | bs.set_mode(tf.estimator.ModeKeys.PREDICT) 23 | 24 | 25 | def dataset_test(): 26 | ''' 27 | 用训练问答对中的实体+属性,去知识库中进行问答测试准确率上限 28 | :return: 29 | ''' 30 | with open(file) as f: 31 | total = 0 32 | recall = 0 33 | correct = 0 34 | 35 | for line in f: 36 | question, entity, attribute, answer, ner = line.split("\t") 37 | ner = ner.replace("#", "").replace("[UNK]", "%") 38 | # case1: entity and attribute Exact Match 39 | sql_e1_a1 = "select * from nlpccQA where entity='"+entity+"' and attribute='"+attribute+"' limit 10" 40 | result_e1_a1 = upload_data(sql_e1_a1) 41 | 42 | # case2: entity Fuzzy Match and attribute Exact Match 43 | sql_e0_a1 = "select * from nlpccQA where entity like '%" + entity + "%' and attribute='" + attribute + "' limit 10" 44 | #result_e0_a1 = upload_data(sql_e0_a1, True) 45 | 46 | # case3: entity Exact Match and attribute Fuzzy Match 47 | sql_e1_a0 = "select * from nlpccQA where entity like '" + entity + "' and attribute='%" + attribute + "%' limit 10" 48 | #result_e1_a0 = upload_data(sql_e1_a0) 49 | 50 | if len(result_e1_a1) > 0: 51 | recall += 1 52 | for l in result_e1_a1: 53 | if l[2] == answer: 54 | correct += 1 55 | else: 56 | result_e0_a1 = upload_data(sql_e0_a1) 57 | if len(result_e0_a1) > 0: 58 | recall += 1 59 | for l in result_e0_a1: 60 | if l[2] == answer: 61 | correct += 1 62 
| else: 63 | result_e1_a0 = upload_data(sql_e1_a0) 64 | if len(result_e1_a0) > 0: 65 | recall += 1 66 | for l in result_e1_a0: 67 | if l[2] == answer: 68 | correct += 1 69 | else: 70 | loginfo.logger.info(sql_e1_a0) 71 | if total > 100: 72 | break 73 | total += 1 74 | time.sleep(1) 75 | loginfo.logger.info("total: {}, recall: {}, correct:{}, accuracy: {}%".format(total, recall, correct, correct * 100.0 / recall)) 76 | #loginfo.logger.info("total: {}, recall: {}, correct:{}, accuracy: {}%".format(total, recall, correct, correct*100.0/recall)) 77 | 78 | 79 | def estimate_answer(candidate, answer): 80 | ''' 81 | :param candidate: 82 | :param answer: 83 | :return: 84 | ''' 85 | candidate = candidate.strip().lower() 86 | answer = answer.strip().lower() 87 | if candidate == answer: 88 | return True 89 | 90 | if not answer.isdigit() and candidate.isdigit(): 91 | candidate_temp = "{:.5E}".format(int(candidate)) 92 | if candidate_temp == answer: 93 | return True 94 | candidate_temp == "{:.4E}".format(int(candidate)) 95 | if candidate_temp == answer: 96 | return True 97 | 98 | return False 99 | 100 | 101 | def kb_fuzzy_classify_test(): 102 | ''' 103 | 进行问答测试: 104 | 1、 实体检索:输入问题,ner得出实体集合,在数据库中检索与输入实体相关的所有三元组 105 | 2、 属性映射——bert分类/文本相似度 106 | + 非语义匹配:如果所得三元组的关系(attribute)属性是 输入问题 字符串的子集,将所得三元组的答案(answer)属性与正确答案匹配,correct +1 107 | + 语义匹配:利用bert计算输入问题(input question)与所得三元组的关系(attribute)属性的相似度,将最相似的三元组的答案作为答案,并与正确 108 | 的答案进行匹配,correct +1 109 | 3、 答案组合 110 | :return: 111 | ''' 112 | with open(file, encoding='utf-8') as f: 113 | total = 0 114 | recall = 0 115 | correct = 0 116 | ambiguity = 0 # 属性匹配正确但是答案不正确 117 | 118 | for line in f: 119 | try: 120 | total += 1 121 | question, entity, attribute, answer, ner = line.split("\t") 122 | ner = ner.replace("#", "").replace("[UNK]", "%").replace("\n", "") 123 | # case: entity Fuzzy Match 124 | # 找出所有包含这些实体的三元组 125 | sql_e0_a1 = "select * from nlpccQA where entity like '%" + ner + "%' order by length(entity) asc limit 20" 126 | # 
sql查出来的是tuple,要转换成list才不会报错 127 | result_e0_a1 = list(upload_data(sql_e0_a1)) 128 | 129 | if len(result_e0_a1) > 0: 130 | recall += 1 131 | 132 | flag_fuzzy = True 133 | # 非语义匹配,加快速度 134 | # l1[0]: entity 135 | # l1[1]: attribute 136 | # l1[2]: answer 137 | flag_ambiguity = True 138 | for l in result_e0_a1: 139 | if l[1] in question or l[1].lower() in question or l[1].upper() in question: 140 | flag_fuzzy = False 141 | 142 | if estimate_answer(l[2], answer): 143 | correct += 1 144 | flag_ambiguity = False 145 | else: 146 | loginfo.logger.info("\t".join(l)) 147 | 148 | # 非语义匹配成功,继续下一次 149 | if not flag_fuzzy: 150 | 151 | if flag_ambiguity: 152 | ambiguity += 1 153 | 154 | time.sleep(1) 155 | loginfo.logger.info("total: {}, recall: {}, correct:{}, accuracy: {}%, ambiguity:{}".format(total, recall, correct, correct * 100.0 / recall, ambiguity)) 156 | continue 157 | 158 | # 语义匹配 159 | result_df = pd.DataFrame(result_e0_a1, columns=['entity', 'attribute', 'value']) 160 | # loginfo.logger.info(result_df.head(100)) 161 | 162 | attribute_candicate_sim = [(k, bs.predict(question, k)[0][1]) for k in result_df['attribute'].tolist()] 163 | attribute_candicate_sort = sorted(attribute_candicate_sim, key=lambda candicate: candicate[1], reverse=True) 164 | loginfo.logger.info("\n".join([str(k)+" "+str(v) for (k, v) in attribute_candicate_sort])) 165 | 166 | answer_candicate_df = result_df[result_df["attribute"] == attribute_candicate_sort[0][0]] 167 | for row in answer_candicate_df.index: 168 | if estimate_answer(answer_candicate_df.loc[row, "value"], answer): 169 | correct += 1 170 | else: 171 | loginfo.logger.info("\t".join(answer_candicate_df.loc[row].tolist())) 172 | time.sleep(1) 173 | loginfo.logger.info("total: {}, recall: {}, correct:{}, accuracy: {}%, ambiguity:{}".format(total, recall, correct, correct * 100.0 / recall, ambiguity)) 174 | except Exception as e: 175 | loginfo.logger.info("the question id % d occur error %s" % (total, repr(e))) 176 | 177 | 178 | if __name__ 
== '__main__': 179 | kb_fuzzy_classify_test() -------------------------------------------------------------------------------- /lstm_crf_layer.py: -------------------------------------------------------------------------------- 1 | # encoding=utf-8 2 | 3 | """ 4 | bert-blstm-crf layer 5 | @Author:Macan 6 | """ 7 | 8 | import tensorflow as tf 9 | from tensorflow.contrib import rnn 10 | from tensorflow.contrib import crf 11 | 12 | class BLSTM_CRF(object): 13 | def __init__(self, embedded_chars, hidden_unit, cell_type, num_layers, dropout_rate, 14 | initializers, num_labels, seq_length, labels, lengths, is_training): 15 | """ 16 | BLSTM-CRF 网络 17 | :param embedded_chars: Fine-tuning embedding input 18 | :param hidden_unit: LSTM的隐含单元个数 19 | :param cell_type: RNN类型(LSTM OR GRU DICNN will be add in feature) 20 | :param num_layers: RNN的层数 21 | :param droupout_rate: droupout rate 22 | :param initializers: variable init class 23 | :param num_labels: 标签数量 24 | :param seq_length: 序列最大长度 25 | :param labels: 真实标签 26 | :param lengths: [batch_size] 每个batch下序列的真实长度 27 | :param is_training: 是否是训练过程 28 | """ 29 | self.hidden_unit = hidden_unit 30 | self.dropout_rate = dropout_rate 31 | self.cell_type = cell_type 32 | self.num_layers = num_layers 33 | self.embedded_chars = embedded_chars 34 | self.initializers = initializers 35 | self.seq_length = seq_length 36 | self.num_labels = num_labels 37 | self.labels = labels 38 | self.lengths = lengths 39 | self.embedding_dims = embedded_chars.shape[-1].value 40 | self.is_training = is_training 41 | 42 | def add_blstm_crf_layer(self, crf_only): 43 | """ 44 | blstm-crf网络 45 | :return: 46 | """ 47 | if self.is_training: 48 | # lstm input dropout rate i set 0.9 will get best score 49 | self.embedded_chars = tf.nn.dropout(self.embedded_chars, self.dropout_rate) 50 | 51 | if crf_only: 52 | logits = self.project_crf_layer(self.embedded_chars) 53 | else: 54 | #blstm 55 | lstm_output = self.blstm_layer(self.embedded_chars) 56 | #project 57 | logits 
= self.project_bilstm_layer(lstm_output) 58 | #crf 59 | loss, trans = self.crf_layer(logits) 60 | # CRF decode, pred_ids 是一条最大概率的标注路径 61 | pred_ids, _ = crf.crf_decode(potentials=logits, transition_params=trans, sequence_length=self.lengths) 62 | return ((loss, logits, trans, pred_ids)) 63 | 64 | def _witch_cell(self): 65 | """ 66 | RNN 类型 67 | :return: 68 | """ 69 | cell_tmp = None 70 | if self.cell_type == 'lstm': 71 | cell_tmp = rnn.BasicLSTMCell(self.hidden_unit) 72 | elif self.cell_type == 'gru': 73 | cell_tmp = rnn.GRUCell(self.hidden_unit) 74 | # 是否需要进行dropout 75 | if self.dropout_rate is not None: 76 | cell_tmp = rnn.DropoutWrapper(cell_tmp, output_keep_prob=self.dropout_rate) 77 | return cell_tmp 78 | 79 | def _bi_dir_rnn(self): 80 | """ 81 | 双向RNN 82 | :return: 83 | """ 84 | cell_fw = self._witch_cell() 85 | cell_bw = self._witch_cell() 86 | return cell_fw, cell_bw 87 | 88 | def blstm_layer(self, embedding_chars): 89 | """ 90 | 91 | :return: 92 | """ 93 | with tf.variable_scope('rnn_layer'): 94 | cell_fw, cell_bw = self._bi_dir_rnn() 95 | if self.num_layers > 1: 96 | cell_fw = rnn.MultiRNNCell([cell_fw] * self.num_layers, state_is_tuple=True) 97 | cell_bw = rnn.MultiRNNCell([cell_bw] * self.num_layers, state_is_tuple=True) 98 | 99 | outputs, _ = tf.nn.bidirectional_dynamic_rnn(cell_fw, cell_bw, embedding_chars, 100 | dtype=tf.float32) 101 | outputs = tf.concat(outputs, axis=2) 102 | return outputs 103 | 104 | def project_bilstm_layer(self, lstm_outputs, name=None): 105 | """ 106 | hidden layer between lstm layer and logits 107 | :param lstm_outputs: [batch_size, num_steps, emb_size] 108 | :return: [batch_size, num_steps, num_tags] 109 | """ 110 | with tf.variable_scope("project" if not name else name): 111 | with tf.variable_scope("hidden"): 112 | W = tf.get_variable("W", shape=[self.hidden_unit * 2, self.hidden_unit], 113 | dtype=tf.float32, initializer=self.initializers.xavier_initializer()) 114 | 115 | b = tf.get_variable("b", shape=[self.hidden_unit], 
dtype=tf.float32, 116 | initializer=tf.zeros_initializer()) 117 | output = tf.reshape(lstm_outputs, shape=[-1, self.hidden_unit * 2]) 118 | hidden = tf.tanh(tf.nn.xw_plus_b(output, W, b)) 119 | 120 | # project to score of tags 121 | with tf.variable_scope("logits"): 122 | W = tf.get_variable("W", shape=[self.hidden_unit, self.num_labels], 123 | dtype=tf.float32, initializer=self.initializers.xavier_initializer()) 124 | 125 | b = tf.get_variable("b", shape=[self.num_labels], dtype=tf.float32, 126 | initializer=tf.zeros_initializer()) 127 | 128 | pred = tf.nn.xw_plus_b(hidden, W, b) 129 | return tf.reshape(pred, [-1, self.seq_length, self.num_labels]) 130 | 131 | def project_crf_layer(self, embedding_chars, name=None): 132 | """ 133 | hidden layer between input layer and logits 134 | :param lstm_outputs: [batch_size, num_steps, emb_size] 135 | :return: [batch_size, num_steps, num_tags] 136 | """ 137 | with tf.variable_scope("project" if not name else name): 138 | with tf.variable_scope("logits"): 139 | W = tf.get_variable("W", shape=[self.embedding_dims, self.num_labels], 140 | dtype=tf.float32, initializer=self.initializers.xavier_initializer()) 141 | 142 | b = tf.get_variable("b", shape=[self.num_labels], dtype=tf.float32, 143 | initializer=tf.zeros_initializer()) 144 | output = tf.reshape(self.embedded_chars, shape=[-1, self.embedding_dims]) #[batch_size, embedding_dims] 145 | pred = tf.tanh(tf.nn.xw_plus_b(output, W, b)) 146 | return tf.reshape(pred, [-1, self.seq_length, self.num_labels]) 147 | 148 | def crf_layer(self, logits): 149 | """ 150 | calculate crf loss 151 | :param project_logits: [1, num_steps, num_tags] 152 | :return: scalar loss 153 | """ 154 | with tf.variable_scope("crf_loss"): 155 | trans = tf.get_variable( 156 | "transitions", 157 | shape=[self.num_labels, self.num_labels], 158 | initializer=self.initializers.xavier_initializer()) 159 | log_likelihood, trans = tf.contrib.crf.crf_log_likelihood( 160 | inputs=logits, 161 | 
tag_indices=self.labels, 162 | transition_params=trans, 163 | sequence_lengths=self.lengths) 164 | return tf.reduce_mean(-log_likelihood), trans -------------------------------------------------------------------------------- /run_ner.sh: -------------------------------------------------------------------------------- 1 | python run_ner.py \ 2 | --task_name=ner \ 3 | --data_dir=./Data/NER_Data \ 4 | --vocab_file=./ModelParams/chinese_L-12_H-768_A-12/vocab.txt \ 5 | --bert_config_file=./ModelParams/chinese_L-12_H-768_A-12/bert_config.json \ 6 | --output_dir=./Output/NER \ 7 | --init_checkpoint=./ModelParams/chinese_L-12_H-768_A-12/bert_model.ckpt \ 8 | --data_config_path=./Config/NER/ner_data.conf \ 9 | --do_train=True \ 10 | --do_eval=True \ 11 | --max_seq_length=128 \ 12 | --lstm_size=128 \ 13 | --num_layers=1 \ 14 | --train_batch_size=64 \ 15 | --eval_batch_size=8 \ 16 | --predict_batch_size=8 \ 17 | --learning_rate=5e-5 \ 18 | --num_train_epochs=1 \ 19 | --droupout_rate=0.5 \ 20 | --clip=5 21 | -------------------------------------------------------------------------------- /terminal_ner.sh: -------------------------------------------------------------------------------- 1 | python terminal_predict.py \ 2 | --task_name=ner \ 3 | --data_dir=./Data/NER_Data \ 4 | --vocab_file=./ModelParams/chinese_L-12_H-768_A-12/vocab.txt \ 5 | --bert_config_file=./ModelParams/chinese_L-12_H-768_A-12/bert_config.json \ 6 | --output_dir=./Output/NER \ 7 | --init_checkpoint=./ModelParams/chinese_L-12_H-768_A-12/bert_model.ckpt \ 8 | --data_config_path=./Config/NER/ner_data.conf \ 9 | --do_train=True \ 10 | --do_eval=True \ 11 | --max_seq_length=128 \ 12 | --lstm_size=128 \ 13 | --num_layers=1 \ 14 | --train_batch_size=64 \ 15 | --eval_batch_size=8 \ 16 | --predict_batch_size=8 \ 17 | --learning_rate=5e-5 \ 18 | --num_train_epochs=1 \ 19 | --droupout_rate=0.5 \ 20 | --clip=5 \ 21 | --do_predict_online=True \ 22 | 23 | 
-------------------------------------------------------------------------------- /terminal_predict.py: -------------------------------------------------------------------------------- 1 | # encoding=utf-8 2 | 3 | """ 4 | 基于命令行的在线预测方法 5 | @Author: Macan (ma_cancan@163.com) 6 | """ 7 | import pandas as pd 8 | import tensorflow as tf 9 | import numpy as np 10 | import codecs 11 | import pickle 12 | import os 13 | from datetime import time, timedelta, datetime 14 | 15 | from run_ner import create_model, InputFeatures, InputExample 16 | from bert import tokenization 17 | from bert import modeling 18 | 19 | 20 | os.environ['CUDA_VISIBLE_DEVICES'] = '0' 21 | 22 | flags = tf.flags 23 | 24 | FLAGS = flags.FLAGS 25 | 26 | flags.DEFINE_bool( 27 | "do_predict_outline", False, 28 | "Whether to do predict outline." 29 | ) 30 | flags.DEFINE_bool( 31 | "do_predict_online", False, 32 | "Whether to do predict online." 33 | ) 34 | 35 | # init mode and session 36 | # move something codes outside of function, so that this code will run only once during online prediction when predict_online is invoked. 37 | is_training=False 38 | use_one_hot_embeddings=False 39 | batch_size=1 40 | 41 | gpu_config = tf.ConfigProto() 42 | gpu_config.gpu_options.allow_growth = True 43 | sess=tf.Session(config=gpu_config) 44 | model=None 45 | 46 | global graph 47 | input_ids_p, input_mask_p, label_ids_p, segment_ids_p = None, None, None, None 48 | print(FLAGS.output_dir) 49 | print('checkpoint path:{}'.format(os.path.join(FLAGS.output_dir, "checkpoint"))) 50 | if not os.path.exists(os.path.join(FLAGS.output_dir, "checkpoint")): 51 | raise Exception("failed to get checkpoint. 
going to return ") 52 | 53 | # 加载label->id的词典 54 | with codecs.open(os.path.join(FLAGS.output_dir, 'label2id.pkl'), 'rb') as rf: 55 | label2id = pickle.load(rf) 56 | id2label = {value: key for key, value in label2id.items()} 57 | 58 | with codecs.open(os.path.join(FLAGS.output_dir, 'label_list.pkl'), 'rb') as rf: 59 | label_list = pickle.load(rf) 60 | num_labels = len(label_list) + 1 61 | 62 | graph = tf.get_default_graph() 63 | with graph.as_default(): 64 | print("going to restore checkpoint") 65 | #sess.run(tf.global_variables_initializer()) 66 | input_ids_p = tf.placeholder(tf.int32, [batch_size, FLAGS.max_seq_length], name="input_ids") 67 | input_mask_p = tf.placeholder(tf.int32, [batch_size, FLAGS.max_seq_length], name="input_mask") 68 | label_ids_p = tf.placeholder(tf.int32, [batch_size, FLAGS.max_seq_length], name="label_ids") 69 | segment_ids_p = tf.placeholder(tf.int32, [batch_size, FLAGS.max_seq_length], name="segment_ids") 70 | 71 | bert_config = modeling.BertConfig.from_json_file(FLAGS.bert_config_file) 72 | (total_loss, logits, trans, pred_ids) = create_model( 73 | bert_config, is_training, input_ids_p, input_mask_p, segment_ids_p, 74 | label_ids_p, num_labels, use_one_hot_embeddings) 75 | 76 | saver = tf.train.Saver() 77 | saver.restore(sess, tf.train.latest_checkpoint(FLAGS.output_dir)) 78 | 79 | 80 | tokenizer = tokenization.FullTokenizer( 81 | vocab_file=FLAGS.vocab_file, do_lower_case=FLAGS.do_lower_case) 82 | 83 | 84 | def predict_online(): 85 | """ 86 | do online prediction. each time make prediction for one instance. 87 | you can change to a batch if you want. 88 | 89 | :param line: a list. 
element is: [dummy_label,text_a,text_b] 90 | :return: 91 | """ 92 | def convert(line): 93 | feature = convert_single_example(0, line, label_list, FLAGS.max_seq_length, tokenizer, 'p') 94 | input_ids = np.reshape([feature.input_ids],(batch_size, FLAGS.max_seq_length)) 95 | input_mask = np.reshape([feature.input_mask],(batch_size, FLAGS.max_seq_length)) 96 | segment_ids = np.reshape([feature.segment_ids],(batch_size, FLAGS.max_seq_length)) 97 | label_ids =np.reshape([feature.label_ids],(batch_size, FLAGS.max_seq_length)) 98 | return input_ids, input_mask, segment_ids, label_ids 99 | 100 | global graph 101 | with graph.as_default(): 102 | print(id2label) 103 | while True: 104 | print('input the test sentence:') 105 | sentence = str(input()) 106 | start = datetime.now() 107 | if len(sentence) < 2: 108 | print(sentence) 109 | continue 110 | sentence = tokenizer.tokenize(sentence) 111 | # print('your input is:{}'.format(sentence)) 112 | input_ids, input_mask, segment_ids, label_ids = convert(sentence) 113 | 114 | feed_dict = {input_ids_p: input_ids, 115 | input_mask_p: input_mask, 116 | segment_ids_p:segment_ids, 117 | label_ids_p:label_ids} 118 | # run session get current feed_dict result 119 | pred_ids_result = sess.run([pred_ids], feed_dict) 120 | pred_label_result = convert_id_to_label(pred_ids_result, id2label) 121 | print(pred_label_result) 122 | #todo: 组合策略 123 | result = strage_combined_link_org_loc(sentence, pred_label_result[0], True) 124 | print('识别的实体有:{}'.format(''.join(result))) 125 | print('Time used: {} sec'.format((datetime.now() - start).seconds)) 126 | 127 | 128 | def predict_outline(): 129 | """ 130 | do online prediction. each time make prediction for one instance. 131 | you can change to a batch if you want. 132 | 133 | :param line: a list. 
element is: [dummy_label,text_a,text_b] 134 | :return: 135 | """ 136 | def convert(line): 137 | feature = convert_single_example(0, line, label_list, FLAGS.max_seq_length, tokenizer, 'p') 138 | input_ids = np.reshape([feature.input_ids],(batch_size, FLAGS.max_seq_length)) 139 | input_mask = np.reshape([feature.input_mask],(batch_size, FLAGS.max_seq_length)) 140 | segment_ids = np.reshape([feature.segment_ids],(batch_size, FLAGS.max_seq_length)) 141 | label_ids =np.reshape([feature.label_ids],(batch_size, FLAGS.max_seq_length)) 142 | return input_ids, input_mask, segment_ids, label_ids 143 | 144 | global graph 145 | with graph.as_default(): 146 | start = datetime.now() 147 | nlpcc_test_data = pd.read_csv("./Data/NER_Data/q_t_a_df_testing.csv") 148 | correct = 0 149 | test_size = nlpcc_test_data.shape[0] 150 | nlpcc_test_result = [] 151 | 152 | for row in nlpcc_test_data.index: 153 | question = nlpcc_test_data.loc[row,"q_str"] 154 | entity = nlpcc_test_data.loc[row,"t_str"].split("|||")[0].split(">")[1].strip() 155 | attribute = nlpcc_test_data.loc[row, "t_str"].split("|||")[1].strip() 156 | answer = nlpcc_test_data.loc[row, "t_str"].split("|||")[2].strip() 157 | 158 | sentence = str(question) 159 | start = datetime.now() 160 | if len(sentence) < 2: 161 | print(sentence) 162 | continue 163 | sentence = tokenizer.tokenize(sentence) 164 | input_ids, input_mask, segment_ids, label_ids = convert(sentence) 165 | 166 | feed_dict = {input_ids_p: input_ids, 167 | input_mask_p: input_mask, 168 | segment_ids_p:segment_ids, 169 | label_ids_p:label_ids} 170 | # run session get current feed_dict result 171 | pred_ids_result = sess.run([pred_ids], feed_dict) 172 | pred_label_result = convert_id_to_label(pred_ids_result, id2label) 173 | # print(pred_label_result) 174 | #todo: 组合策略 175 | result = strage_combined_link_org_loc(sentence, pred_label_result[0], False) 176 | if entity in result: 177 | correct += 1 178 | 
nlpcc_test_result.append(question+"\t"+entity+"\t"+attribute+"\t"+answer+"\t"+','.join(result)) 179 | with open("./Data/NER_Data/q_t_a_testing_predict.txt", "w") as f: 180 | f.write("\n".join(nlpcc_test_result)) 181 | print("accuracy: {}%, correct: {}, total: {}".format(correct*100.0/float(test_size), correct, test_size)) 182 | print('Time used: {} sec'.format((datetime.now() - start).seconds)) 183 | 184 | 185 | def convert_id_to_label(pred_ids_result, idx2label): 186 | """ 187 | 将id形式的结果转化为真实序列结果 188 | :param pred_ids_result: 189 | :param idx2label: 190 | :return: 191 | """ 192 | result = [] 193 | for row in range(batch_size): 194 | curr_seq = [] 195 | for ids in pred_ids_result[row][0]: 196 | if ids == 0: 197 | break 198 | curr_label = idx2label[ids] 199 | if curr_label in ['[CLS]', '[SEP]']: 200 | continue 201 | curr_seq.append(curr_label) 202 | result.append(curr_seq) 203 | return result 204 | 205 | 206 | def strage_combined_link_org_loc(tokens, tags, flag): 207 | """ 208 | 组合策略 209 | :param pred_label_result: 210 | :param types: 211 | :return: 212 | """ 213 | def print_output(data, type): 214 | line = [] 215 | for i in data: 216 | line.append(i.word) 217 | print('{}: {}'.format(type, ', '.join(line))) 218 | 219 | def string_output(data): 220 | line = [] 221 | for i in data: 222 | line.append(i.word) 223 | return line 224 | 225 | params = None 226 | eval = Result(params) 227 | if len(tokens) > len(tags): 228 | tokens = tokens[:len(tags)] 229 | person, loc, org = eval.get_result(tokens, tags) 230 | if flag: 231 | if len(loc) != 0: 232 | print_output(loc, 'LOC') 233 | if len(person) != 0: 234 | print_output(person, 'PER') 235 | if len(org) != 0: 236 | print_output(org, 'ORG') 237 | person_list = string_output(person) 238 | person_list.extend(string_output(loc)) 239 | person_list.extend(string_output(org)) 240 | return person_list 241 | 242 | 243 | def convert_single_example(ex_index, example, label_list, max_seq_length, tokenizer, mode): 244 | """ 245 | 
将一个样本进行分析,然后将字转化为id, 标签转化为id,然后结构化到InputFeatures对象中 246 | :param ex_index: index 247 | :param example: 一个样本 248 | :param label_list: 标签列表 249 | :param max_seq_length: 250 | :param tokenizer: 251 | :param mode: 252 | :return: 253 | """ 254 | label_map = {} 255 | # 1表示从1开始对label进行index化 256 | for (i, label) in enumerate(label_list, 1): 257 | label_map[label] = i 258 | # 保存label->index 的map 259 | if not os.path.exists(os.path.join(FLAGS.output_dir, 'label2id.pkl')): 260 | with codecs.open(os.path.join(FLAGS.output_dir, 'label2id.pkl'), 'wb') as w: 261 | pickle.dump(label_map, w) 262 | 263 | tokens = example 264 | # tokens = tokenizer.tokenize(example.text) 265 | # 序列截断 266 | if len(tokens) >= max_seq_length - 1: 267 | tokens = tokens[0:(max_seq_length - 2)] # -2 的原因是因为序列需要加一个句首和句尾标志 268 | ntokens = [] 269 | segment_ids = [] 270 | label_ids = [] 271 | ntokens.append("[CLS]") # 句子开始设置CLS 标志 272 | segment_ids.append(0) 273 | # append("O") or append("[CLS]") not sure! 274 | label_ids.append(label_map["[CLS]"]) # O OR CLS 没有任何影响,不过我觉得O 会减少标签个数,不过拒收和句尾使用不同的标志来标注,使用LCS 也没毛病 275 | for i, token in enumerate(tokens): 276 | ntokens.append(token) 277 | segment_ids.append(0) 278 | label_ids.append(0) 279 | ntokens.append("[SEP]") # 句尾添加[SEP] 标志 280 | segment_ids.append(0) 281 | # append("O") or append("[SEP]") not sure! 282 | label_ids.append(label_map["[SEP]"]) 283 | input_ids = tokenizer.convert_tokens_to_ids(ntokens) # 将序列中的字(ntokens)转化为ID形式 284 | input_mask = [1] * len(input_ids) 285 | 286 | # padding, 使用 287 | while len(input_ids) < max_seq_length: 288 | input_ids.append(0) 289 | input_mask.append(0) 290 | segment_ids.append(0) 291 | # we don't concerned about it! 
292 | label_ids.append(0) 293 | ntokens.append("**NULL**") 294 | # label_mask.append(0) 295 | # print(len(input_ids)) 296 | assert len(input_ids) == max_seq_length 297 | assert len(input_mask) == max_seq_length 298 | assert len(segment_ids) == max_seq_length 299 | assert len(label_ids) == max_seq_length 300 | # assert len(label_mask) == max_seq_length 301 | 302 | # 结构化为一个类 303 | feature = InputFeatures( 304 | input_ids=input_ids, 305 | input_mask=input_mask, 306 | segment_ids=segment_ids, 307 | label_ids=label_ids, 308 | # label_mask = label_mask 309 | ) 310 | return feature 311 | 312 | 313 | class Pair(object): 314 | def __init__(self, word, start, end, type, merge=False): 315 | self.__word = word 316 | self.__start = start 317 | self.__end = end 318 | self.__merge = merge 319 | self.__types = type 320 | 321 | @property 322 | def start(self): 323 | return self.__start 324 | @property 325 | def end(self): 326 | return self.__end 327 | @property 328 | def merge(self): 329 | return self.__merge 330 | @property 331 | def word(self): 332 | return self.__word 333 | 334 | @property 335 | def types(self): 336 | return self.__types 337 | @word.setter 338 | def word(self, word): 339 | self.__word = word 340 | @start.setter 341 | def start(self, start): 342 | self.__start = start 343 | @end.setter 344 | def end(self, end): 345 | self.__end = end 346 | @merge.setter 347 | def merge(self, merge): 348 | self.__merge = merge 349 | 350 | @types.setter 351 | def types(self, type): 352 | self.__types = type 353 | 354 | def __str__(self) -> str: 355 | line = [] 356 | line.append('entity:{}'.format(self.__word)) 357 | line.append('start:{}'.format(self.__start)) 358 | line.append('end:{}'.format(self.__end)) 359 | line.append('merge:{}'.format(self.__merge)) 360 | line.append('types:{}'.format(self.__types)) 361 | return '\t'.join(line) 362 | 363 | 364 | class Result(object): 365 | def __init__(self, config): 366 | self.config = config 367 | self.person = [] 368 | self.loc = [] 369 | 
self.org = [] 370 | self.others = [] 371 | def get_result(self, tokens, tags, config=None): 372 | # 先获取标注结果 373 | self.result_to_json(tokens, tags) 374 | return self.person, self.loc, self.org 375 | 376 | def result_to_json(self, string, tags): 377 | """ 378 | 将模型标注序列和输入序列结合 转化为结果 379 | :param string: 输入序列 380 | :param tags: 标注结果 381 | :return: 382 | """ 383 | item = {"entities": []} 384 | entity_name = "" 385 | entity_start = 0 386 | idx = 0 387 | last_tag = '' 388 | 389 | for char, tag in zip(string, tags): 390 | if tag[0] == "S": 391 | self.append(char, idx, idx+1, tag[2:]) 392 | item["entities"].append({"word": char, "start": idx, "end": idx+1, "type":tag[2:]}) 393 | elif tag[0] == "B": 394 | if entity_name != '': 395 | self.append(entity_name, entity_start, idx, last_tag[2:]) 396 | item["entities"].append({"word": entity_name, "start": entity_start, "end": idx, "type": last_tag[2:]}) 397 | entity_name = "" 398 | entity_name += char 399 | entity_start = idx 400 | elif tag[0] == "I": 401 | entity_name += char 402 | elif tag[0] == "O": 403 | if entity_name != '': 404 | self.append(entity_name, entity_start, idx, last_tag[2:]) 405 | item["entities"].append({"word": entity_name, "start": entity_start, "end": idx, "type": last_tag[2:]}) 406 | entity_name = "" 407 | else: 408 | entity_name = "" 409 | entity_start = idx 410 | idx += 1 411 | last_tag = tag 412 | if entity_name != '': 413 | self.append(entity_name, entity_start, idx, last_tag[2:]) 414 | item["entities"].append({"word": entity_name, "start": entity_start, "end": idx, "type": last_tag[2:]}) 415 | return item 416 | 417 | def append(self, word, start, end, tag): 418 | if tag == 'LOC': 419 | self.loc.append(Pair(word, start, end, 'LOC')) 420 | elif tag == 'PER': 421 | self.person.append(Pair(word, start, end, 'PER')) 422 | elif tag == 'ORG': 423 | self.org.append(Pair(word, start, end, 'ORG')) 424 | else: 425 | self.others.append(Pair(word, start, end, tag)) 426 | 427 | 428 | if __name__ == "__main__": 429 | 
if FLAGS.do_predict_outline: 430 | predict_outline() 431 | if FLAGS.do_predict_online: 432 | predict_online() 433 | 434 | 435 | -------------------------------------------------------------------------------- /tf_metrics.py: -------------------------------------------------------------------------------- 1 | """ 2 | Multiclass 3 | from: 4 | https://github.com/guillaumegenthial/tf_metrics/blob/master/tf_metrics/__init__.py 5 | 6 | """ 7 | 8 | __author__ = "Guillaume Genthial" 9 | 10 | import numpy as np 11 | import tensorflow as tf 12 | from tensorflow.python.ops.metrics_impl import _streaming_confusion_matrix 13 | 14 | 15 | def precision(labels, predictions, num_classes, pos_indices=None, 16 | weights=None, average='micro'): 17 | """Multi-class precision metric for Tensorflow 18 | Parameters 19 | ---------- 20 | labels : Tensor of tf.int32 or tf.int64 21 | The true labels 22 | predictions : Tensor of tf.int32 or tf.int64 23 | The predictions, same shape as labels 24 | num_classes : int 25 | The number of classes 26 | pos_indices : list of int, optional 27 | The indices of the positive classes, default is all 28 | weights : Tensor of tf.int32, optional 29 | Mask, must be of compatible shape with labels 30 | average : str, optional 31 | 'micro': counts the total number of true positives, false 32 | positives, and false negatives for the classes in 33 | `pos_indices` and infer the metric from it. 34 | 'macro': will compute the metric separately for each class in 35 | `pos_indices` and average. Will not account for class 36 | imbalance. 37 | 'weighted': will compute the metric separately for each class in 38 | `pos_indices` and perform a weighted average by the total 39 | number of true labels for each class. 
40 | Returns 41 | ------- 42 | tuple of (scalar float Tensor, update_op) 43 | """ 44 | cm, op = _streaming_confusion_matrix( 45 | labels, predictions, num_classes, weights) 46 | pr, _, _ = metrics_from_confusion_matrix( 47 | cm, pos_indices, average=average) 48 | op, _, _ = metrics_from_confusion_matrix( 49 | op, pos_indices, average=average) 50 | return (pr, op) 51 | 52 | 53 | def recall(labels, predictions, num_classes, pos_indices=None, weights=None, 54 | average='micro'): 55 | """Multi-class recall metric for Tensorflow 56 | Parameters 57 | ---------- 58 | labels : Tensor of tf.int32 or tf.int64 59 | The true labels 60 | predictions : Tensor of tf.int32 or tf.int64 61 | The predictions, same shape as labels 62 | num_classes : int 63 | The number of classes 64 | pos_indices : list of int, optional 65 | The indices of the positive classes, default is all 66 | weights : Tensor of tf.int32, optional 67 | Mask, must be of compatible shape with labels 68 | average : str, optional 69 | 'micro': counts the total number of true positives, false 70 | positives, and false negatives for the classes in 71 | `pos_indices` and infer the metric from it. 72 | 'macro': will compute the metric separately for each class in 73 | `pos_indices` and average. Will not account for class 74 | imbalance. 75 | 'weighted': will compute the metric separately for each class in 76 | `pos_indices` and perform a weighted average by the total 77 | number of true labels for each class. 
78 | Returns 79 | ------- 80 | tuple of (scalar float Tensor, update_op) 81 | """ 82 | cm, op = _streaming_confusion_matrix( 83 | labels, predictions, num_classes, weights) 84 | _, re, _ = metrics_from_confusion_matrix( 85 | cm, pos_indices, average=average) 86 | _, op, _ = metrics_from_confusion_matrix( 87 | op, pos_indices, average=average) 88 | return (re, op) 89 | 90 | 91 | def f1(labels, predictions, num_classes, pos_indices=None, weights=None, 92 | average='micro'): 93 | return fbeta(labels, predictions, num_classes, pos_indices, weights, 94 | average) 95 | 96 | 97 | def fbeta(labels, predictions, num_classes, pos_indices=None, weights=None, 98 | average='micro', beta=1): 99 | """Multi-class fbeta metric for Tensorflow 100 | Parameters 101 | ---------- 102 | labels : Tensor of tf.int32 or tf.int64 103 | The true labels 104 | predictions : Tensor of tf.int32 or tf.int64 105 | The predictions, same shape as labels 106 | num_classes : int 107 | The number of classes 108 | pos_indices : list of int, optional 109 | The indices of the positive classes, default is all 110 | weights : Tensor of tf.int32, optional 111 | Mask, must be of compatible shape with labels 112 | average : str, optional 113 | 'micro': counts the total number of true positives, false 114 | positives, and false negatives for the classes in 115 | `pos_indices` and infer the metric from it. 116 | 'macro': will compute the metric separately for each class in 117 | `pos_indices` and average. Will not account for class 118 | imbalance. 119 | 'weighted': will compute the metric separately for each class in 120 | `pos_indices` and perform a weighted average by the total 121 | number of true labels for each class. 
122 | beta : int, optional 123 | Weight of precision in harmonic mean 124 | Returns 125 | ------- 126 | tuple of (scalar float Tensor, update_op) 127 | """ 128 | cm, op = _streaming_confusion_matrix( 129 | labels, predictions, num_classes, weights) 130 | _, _, fbeta = metrics_from_confusion_matrix( 131 | cm, pos_indices, average=average, beta=beta) 132 | _, _, op = metrics_from_confusion_matrix( 133 | op, pos_indices, average=average, beta=beta) 134 | return (fbeta, op) 135 | 136 | 137 | def safe_div(numerator, denominator): 138 | """Safe division, return 0 if denominator is 0""" 139 | numerator, denominator = tf.to_float(numerator), tf.to_float(denominator) 140 | zeros = tf.zeros_like(numerator, dtype=numerator.dtype) 141 | denominator_is_zero = tf.equal(denominator, zeros) 142 | return tf.where(denominator_is_zero, zeros, numerator / denominator) 143 | 144 | 145 | def pr_re_fbeta(cm, pos_indices, beta=1): 146 | """Uses a confusion matrix to compute precision, recall and fbeta""" 147 | num_classes = cm.shape[0] 148 | neg_indices = [i for i in range(num_classes) if i not in pos_indices] 149 | cm_mask = np.ones([num_classes, num_classes]) 150 | cm_mask[neg_indices, neg_indices] = 0 151 | diag_sum = tf.reduce_sum(tf.diag_part(cm * cm_mask)) 152 | 153 | cm_mask = np.ones([num_classes, num_classes]) 154 | cm_mask[:, neg_indices] = 0 155 | tot_pred = tf.reduce_sum(cm * cm_mask) 156 | 157 | cm_mask = np.ones([num_classes, num_classes]) 158 | cm_mask[neg_indices, :] = 0 159 | tot_gold = tf.reduce_sum(cm * cm_mask) 160 | 161 | pr = safe_div(diag_sum, tot_pred) 162 | re = safe_div(diag_sum, tot_gold) 163 | fbeta = safe_div((1. 
+ beta**2) * pr * re, beta**2 * pr + re) 164 | 165 | return pr, re, fbeta 166 | 167 | 168 | def metrics_from_confusion_matrix(cm, pos_indices=None, average='micro', 169 | beta=1): 170 | """Precision, Recall and F1 from the confusion matrix 171 | Parameters 172 | ---------- 173 | cm : tf.Tensor of type tf.int32, of shape (num_classes, num_classes) 174 | The streaming confusion matrix. 175 | pos_indices : list of int, optional 176 | The indices of the positive classes 177 | beta : int, optional 178 | Weight of precision in harmonic mean 179 | average : str, optional 180 | 'micro', 'macro' or 'weighted' 181 | """ 182 | num_classes = cm.shape[0] 183 | if pos_indices is None: 184 | pos_indices = [i for i in range(num_classes)] 185 | 186 | if average == 'micro': 187 | return pr_re_fbeta(cm, pos_indices, beta) 188 | elif average in {'macro', 'weighted'}: 189 | precisions, recalls, fbetas, n_golds = [], [], [], [] 190 | for idx in pos_indices: 191 | pr, re, fbeta = pr_re_fbeta(cm, [idx], beta) 192 | precisions.append(pr) 193 | recalls.append(re) 194 | fbetas.append(fbeta) 195 | cm_mask = np.zeros([num_classes, num_classes]) 196 | cm_mask[idx, :] = 1 197 | n_golds.append(tf.to_float(tf.reduce_sum(cm * cm_mask))) 198 | 199 | if average == 'macro': 200 | pr = tf.reduce_mean(precisions) 201 | re = tf.reduce_mean(recalls) 202 | fbeta = tf.reduce_mean(fbetas) 203 | return pr, re, fbeta 204 | if average == 'weighted': 205 | n_gold = tf.reduce_sum(n_golds) 206 | pr_sum = sum(p * n for p, n in zip(precisions, n_golds)) 207 | pr = safe_div(pr_sum, n_gold) 208 | re_sum = sum(r * n for r, n in zip(recalls, n_golds)) 209 | re = safe_div(re_sum, n_gold) 210 | fbeta_sum = sum(f * n for f, n in zip(fbetas, n_golds)) 211 | fbeta = safe_div(fbeta_sum, n_gold) 212 | return pr, re, fbeta 213 | 214 | else: 215 | raise NotImplementedError() --------------------------------------------------------------------------------