├── .gitignore ├── EDA.ipynb ├── README.md ├── ckbqa ├── __init__.py ├── dao │ ├── __init__.py │ ├── db.py │ ├── db_tools.py │ ├── mongo_models.py │ ├── mongo_utils.py │ ├── sqlite_models.py │ └── sqlite_utils.py ├── dataset │ ├── __init__.py │ ├── data_prepare.py │ └── kb_data_prepare.py ├── layers │ ├── __init__.py │ ├── losses.py │ └── modules.py ├── models │ ├── __init__.py │ ├── base_trainer.py │ ├── data_helper.py │ ├── entity_score │ │ ├── __init__.py │ │ └── model.py │ ├── evaluation_matrics.py │ ├── ner │ │ ├── __init__.py │ │ ├── crf.py │ │ └── model.py │ └── relation_score │ │ ├── __init__.py │ │ ├── model.py │ │ ├── predictor.py │ │ └── trainer.py ├── qa │ ├── __init__.py │ ├── algorithms.py │ ├── cache.py │ ├── el.py │ ├── lac_tools.py │ ├── neo4j_graph.py │ ├── qa.py │ └── relation_extractor.py └── utils │ ├── __init__.py │ ├── async_tools.py │ ├── decorators.py │ ├── gpu_selector.py │ ├── logger.py │ ├── saver.py │ ├── sequence.py │ └── tools.py ├── config.py ├── data.py ├── docs ├── CCKS 2019 知识图谱评测技术报告:实体、关系、事件及问答 .pdf └── bad_case.md ├── evaluate.py ├── examples ├── __init__.py ├── answer_format.py ├── bad_case.py ├── del_method.py ├── kb_data.py ├── lac_test.py ├── single_example.py └── top_path.py ├── manage.py ├── port_map.sh ├── qa.py └── tests ├── __init__.py ├── test_.py └── test_ceg.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | 131 | #pycharm 132 | .idea 133 | 134 | # 135 | local 136 | data/raw_data/* 137 | data/ 138 | output 139 | supports 140 | 141 | # 142 | *.txt 143 | *.csv 144 | *.tgz 145 | *.json 146 | *.old 147 | -------------------------------------------------------------------------------- /EDA.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# 数据分析" 8 | ] 9 | }, 10 | { 11 | "cell_type": "code", 12 | "execution_count": 1, 13 | "metadata": {}, 14 | "outputs": [], 15 | "source": [ 16 | "import pandas as pd\n", 17 | "from ckbqa.dataset.data_prepare import get_raw_data_df" 18 | ] 19 | }, 20 | { 21 | "cell_type": "code", 22 | "execution_count": 2, 23 | "metadata": {}, 24 | "outputs": [], 25 | "source": [ 26 | "data_df = get_raw_data_df()" 27 | ] 28 | }, 29 | { 30 | "cell_type": "code", 31 | "execution_count": 10, 32 | "metadata": {}, 33 | "outputs": [ 34 | { 35 | "data": { 36 | "text/html": [ 37 | "
\n", 38 | "\n", 51 | "\n", 52 | " \n", 53 | " \n", 54 | " \n", 55 | " \n", 56 | " \n", 57 | " \n", 58 | " \n", 59 | " \n", 60 | " \n", 61 | " \n", 62 | " \n", 63 | " \n", 64 | " \n", 65 | " \n", 66 | " \n", 67 | " \n", 68 | " \n", 69 | " \n", 70 | " \n", 71 | " \n", 72 | " \n", 73 | " \n", 74 | " \n", 75 | " \n", 76 | " \n", 77 | " \n", 78 | " \n", 79 | " \n", 80 | " \n", 81 | " \n", 82 | " \n", 83 | " \n", 84 | " \n", 85 | " \n", 86 | "
      question  answer  sparql
count  4000  4000  4000
unique  3988  2930  3965
top  [科学家牛顿的英文名是?]  <疾病>  select ?x where { <大华股份> <市盈率> ?x. }
freq  3  253  3
\n", 87 | "
" 88 | ], 89 | "text/plain": [ 90 | " question answer sparql\n", 91 | "count 4000 4000 4000\n", 92 | "unique 3988 2930 3965\n", 93 | "top [科学家牛顿的英文名是?] <疾病> select ?x where { <大华股份> <市盈率> ?x. }\n", 94 | "freq 3 253 3" 95 | ] 96 | }, 97 | "execution_count": 10, 98 | "metadata": {}, 99 | "output_type": "execute_result" 100 | } 101 | ], 102 | "source": [ 103 | "data_df.describe()" 104 | ] 105 | }, 106 | { 107 | "cell_type": "code", 108 | "execution_count": 11, 109 | "metadata": {}, 110 | "outputs": [ 111 | { 112 | "data": { 113 | "text/html": [ 114 | "
\n", 115 | "\n", 128 | "\n", 129 | " \n", 130 | " \n", 131 | " \n", 132 | " \n", 133 | " \n", 134 | " \n", 135 | " \n", 136 | " \n", 137 | " \n", 138 | " \n", 139 | " \n", 140 | " \n", 141 | " \n", 142 | " \n", 143 | " \n", 144 | " \n", 145 | " \n", 146 | " \n", 147 | " \n", 148 | " \n", 149 | " \n", 150 | " \n", 151 | " \n", 152 | " \n", 153 | " \n", 154 | " \n", 155 | " \n", 156 | " \n", 157 | " \n", 158 | " \n", 159 | " \n", 160 | " \n", 161 | " \n", 162 | " \n", 163 | " \n", 164 | " \n", 165 | " \n", 166 | " \n", 167 | " \n", 168 | " \n", 169 | "
   question  answer  sparql
0  [莫妮卡·贝鲁奇的代表作?]  <西西里的美丽传说>  select ?x where { <莫妮卡·贝鲁奇> <代表作品> ?x. }
1  [《湖上草》是谁的诗?]  <柳如是_(明末“秦淮八艳”之一)>  select ?x where { ?x <主要作品> <湖上草>. }
2  [龙卷风的英文名是什么?]  \"Tornado\"  select ?x where { <龙卷风_(一种自然天气现象)> <外文名> ?x. }
3  [新加坡的水域率是多少?]  \"1.444%\"  select ?x where { <新加坡> <水域率> ?x. }
4  [商朝在哪场战役中走向覆灭?]  <牧野之战>  select ?x where { <商朝> <灭亡> ?x. }
\n", 170 | "
" 171 | ], 172 | "text/plain": [ 173 | " question answer \\\n", 174 | "0 [莫妮卡·贝鲁奇的代表作?] <西西里的美丽传说> \n", 175 | "1 [《湖上草》是谁的诗?] <柳如是_(明末“秦淮八艳”之一)> \n", 176 | "2 [龙卷风的英文名是什么?] \"Tornado\" \n", 177 | "3 [新加坡的水域率是多少?] \"1.444%\" \n", 178 | "4 [商朝在哪场战役中走向覆灭?] <牧野之战> \n", 179 | "\n", 180 | " sparql \n", 181 | "0 select ?x where { <莫妮卡·贝鲁奇> <代表作品> ?x. } \n", 182 | "1 select ?x where { ?x <主要作品> <湖上草>. } \n", 183 | "2 select ?x where { <龙卷风_(一种自然天气现象)> <外文名> ?x. } \n", 184 | "3 select ?x where { <新加坡> <水域率> ?x. } \n", 185 | "4 select ?x where { <商朝> <灭亡> ?x. } " 186 | ] 187 | }, 188 | "execution_count": 11, 189 | "metadata": {}, 190 | "output_type": "execute_result" 191 | } 192 | ], 193 | "source": [ 194 | "data_df.head()" 195 | ] 196 | }, 197 | { 198 | "cell_type": "code", 199 | "execution_count": 5, 200 | "metadata": {}, 201 | "outputs": [], 202 | "source": [ 203 | "data_df['q_len'] = data_df['question'].apply(len)\n", 204 | "data_df['a_len'] = data_df['answer'].apply(len)" 205 | ] 206 | }, 207 | { 208 | "cell_type": "code", 209 | "execution_count": 7, 210 | "metadata": {}, 211 | "outputs": [ 212 | { 213 | "data": { 214 | "text/html": [ 215 | "
\n", 216 | "\n", 229 | "\n", 230 | " \n", 231 | " \n", 232 | " \n", 233 | " \n", 234 | " \n", 235 | " \n", 236 | " \n", 237 | " \n", 238 | " \n", 239 | " \n", 240 | " \n", 241 | " \n", 242 | " \n", 243 | " \n", 244 | " \n", 245 | " \n", 246 | " \n", 247 | " \n", 248 | " \n", 249 | " \n", 250 | " \n", 251 | " \n", 252 | " \n", 253 | " \n", 254 | " \n", 255 | " \n", 256 | " \n", 257 | " \n", 258 | " \n", 259 | " \n", 260 | " \n", 261 | " \n", 262 | " \n", 263 | " \n", 264 | " \n", 265 | " \n", 266 | " \n", 267 | " \n", 268 | " \n", 269 | " \n", 270 | " \n", 271 | " \n", 272 | " \n", 273 | " \n", 274 | " \n", 275 | " \n", 276 | " \n", 277 | " \n", 278 | " \n", 279 | " \n", 280 | " \n", 281 | " \n", 282 | "
   question  answer  sparql  q_len  a_len
0  [莫妮卡·贝鲁奇的代表作?]  <西西里的美丽传说>  select ?x where { <莫妮卡·贝鲁奇> <代表作品> ?x. }  1  10
1  [《湖上草》是谁的诗?]  <柳如是_(明末“秦淮八艳”之一)>  select ?x where { ?x <主要作品> <湖上草>. }  1  18
2  [龙卷风的英文名是什么?]  \"Tornado\"  select ?x where { <龙卷风_(一种自然天气现象)> <外文名> ?x. }  1  9
3  [新加坡的水域率是多少?]  \"1.444%\"  select ?x where { <新加坡> <水域率> ?x. }  1  8
4  [商朝在哪场战役中走向覆灭?]  <牧野之战>  select ?x where { <商朝> <灭亡> ?x. }  1  6
\n", 283 | "
" 284 | ], 285 | "text/plain": [ 286 | " question answer \\\n", 287 | "0 [莫妮卡·贝鲁奇的代表作?] <西西里的美丽传说> \n", 288 | "1 [《湖上草》是谁的诗?] <柳如是_(明末“秦淮八艳”之一)> \n", 289 | "2 [龙卷风的英文名是什么?] \"Tornado\" \n", 290 | "3 [新加坡的水域率是多少?] \"1.444%\" \n", 291 | "4 [商朝在哪场战役中走向覆灭?] <牧野之战> \n", 292 | "\n", 293 | " sparql q_len a_len \n", 294 | "0 select ?x where { <莫妮卡·贝鲁奇> <代表作品> ?x. } 1 10 \n", 295 | "1 select ?x where { ?x <主要作品> <湖上草>. } 1 18 \n", 296 | "2 select ?x where { <龙卷风_(一种自然天气现象)> <外文名> ?x. } 1 9 \n", 297 | "3 select ?x where { <新加坡> <水域率> ?x. } 1 8 \n", 298 | "4 select ?x where { <商朝> <灭亡> ?x. } 1 6 " 299 | ] 300 | }, 301 | "execution_count": 7, 302 | "metadata": {}, 303 | "output_type": "execute_result" 304 | } 305 | ], 306 | "source": [ 307 | "data_df.head()" 308 | ] 309 | }, 310 | { 311 | "cell_type": "code", 312 | "execution_count": null, 313 | "metadata": {}, 314 | "outputs": [], 315 | "source": [] 316 | } 317 | ], 318 | "metadata": { 319 | "kernelspec": { 320 | "display_name": "Python 3", 321 | "language": "python", 322 | "name": "python3" 323 | }, 324 | "language_info": { 325 | "codemirror_mode": { 326 | "name": "ipython", 327 | "version": 3 328 | }, 329 | "file_extension": ".py", 330 | "mimetype": "text/x-python", 331 | "name": "python", 332 | "nbconvert_exporter": "python", 333 | "pygments_lexer": "ipython3", 334 | "version": "3.7.1" 335 | } 336 | }, 337 | "nbformat": 4, 338 | "nbformat_minor": 2 339 | } 340 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # CCKS-2020 2 | 3 | - [CCKS 2020:新冠知识图谱构建与问答评测(四)新冠知识图谱问答评测 4 | ](https://biendata.com/competition/ccks_2020_7_4/) 5 | 6 | ### 0. 代码结构 7 | 8 | - 主要有三大模块: 9 | 1. dataset模块:数据预处理等与数据相关的 10 | 1. config.py 所有文件路径写这里,方便查找 11 | 2. models模块:各种深度学习模型 12 | 3. qa模块:将训练好的模型和策略组合应用,完成问答 13 | 14 | **代码运行必须从几个接口模块开始;小脚本测试,可模仿接口模块,新写一个接口模块,在其中导入函数运行** 15 | 16 | ├── dao //数据库接口,准备做辅助缓存 17 | ├── dataset //数据相关 18 | │ ├── data_prepare.py //train数据预处理,将给定训练数据做各种处理,构造字典等 19 | │ └── kb_data_prepare.py //图谱数据预处理,将给定图谱做各种转换,构造字典等 20 | ├── layers //损失函数等 21 | ├── models //模型 22 | │ ├── entity_score // 对识别出的主实体打分模型 23 | │ │ └── model.py // 模型定义 24 | │ ├── ner //主实体识别模型 25 | │ │ └── model.py // 模型定义 26 | │ ├── relation_score //对实体关联的关系打分的模型 27 | │ │ ├── model.py // 模型定义 28 | │ │ ├── predictor.py // 模型封装用作后续预测 29 | │ │ └── trainer.py // 模型训练 30 | │ ├── base_trainer.py //训练模块;模型初始化到训练 31 | │ ├── data_helper.py //train数据处理成合适的格式,feed给模型 32 | │ └── evaluation_matrics.py // 指标计算 33 | ├── qa //问答模块 34 | │ ├── algorithms.py // 后处理算法 35 | │ ├── cache.py // 大文件,在单例模式缓存;避免多次载入内存;ent2id等放在这里,提供给其他模块公共使用 36 | │ ├── el.py //entity link,实体链接(主实体识别模块) 37 | │ ├── entity_score.py // 对识别出的主实体打分模型 38 | │ ├── lac_tools.py // 分词模块自定义优化等 39 | │ ├── neo4j_graph.py //图数据库查询缓存等 40 | │ ├── qa.py //问答接口,将其他模块组装到这里完成问答 41 | │ └── relation_extractor.py //实体关联关系识别 42 | ├── utils //通用工具 43 | ├── docs //文档 44 | ├── examples //临时任务,模块试验等,单个脚本 45 | ├── tests //测试 46 | ├── config.py //所有数据路径和少量全局配置 47 | ├── data.py //所有数据处理的入口文件 48 | ├── evaluate.py //模块评测入口文件 49 | ├── manage.py //所有模型训练的入口文件 50 | ├── qa.py //问答入口文件 51 | ├── README.md //说明文档 52 | └── requirements.txt //依赖包 53 | 54 | 55 | ### 1. 
图数据库 56 | - 知识库管理系统 57 | - [gStore](http://gstore-pku.com/pcsite/index.html) 58 | - [gStore - github](https://github.com/pkumod/gStore/blob/master/docs/DOCKER_DEPLOY_CN.md) 59 | - SPARQL 60 | - 国际化资源标识符(Internationalized Resource Identifiers,简称IRI),与其相提并论的是URI(Uniform Resource Identifier,统一资源标志符)。 61 | 使用来表示一个IRI 62 | - Literal用于表示三元组中客体(Object),表示非IRI的数据,例如字符串(String),数字(xsd:integer),日期(xsd:date)等。 63 | 普通字符串等 "chat" 64 | - [RDF查询语言SPARQL - SimmerChan的文章 - 知乎 65 | ](https://zhuanlan.zhihu.com/p/32703794) 66 | 67 | neo4j数据库 68 | - 安装 69 | https://segmentfault.com/a/1190000015389941 70 | 71 | - 运行 72 | cd /home/wangshengguang/neo4j-community-3.4.5/bin 73 | ./neo4j start 74 | ./neo4j stop 75 | 76 | - 数据导入 77 | cd /home/wangshengguang/neo4j-community-3.4.5/bin 78 | ./neo4j-admin import --database=graph.db --nodes /home/wangshengguang/ccks-2020/data/graph_entity.csv --relationships /home/wangshengguang/ccks-2020/data/graph_relation.csv --ignore-duplicate-nodes=true --id-type INTEGER --ignore-missing-nodes=true 79 | 80 | - 创建索引 81 | CREATE CONSTRAINT ON (ent:Entity) ASSERT ent.id IS UNIQUE; 82 | CREATE INDEX ON :Entity(name) 83 | CREATE INDEX ON :Relation(name) 84 | 85 | 86 | - [neo4j学习笔记(三)——python接口-创建删除结点和关系](https://blog.csdn.net/qq_36591505/article/details/100987105) 87 | - [neo4j︱与python结合的py2neo使用教程(四)](https://blog.csdn.net/sinat_26917383/article/details/79901207) 88 | - [neo4j中文文档](http://neo4j.com.cn/public/docs/index.html) 89 | 90 | 91 | 92 | ## 2. 历年方案 93 | 94 | - [CCKS 2019 | 开放域中文KBQA系统 - 最AI的小PAI的文章 - 知乎 95 | ](https://zhuanlan.zhihu.com/p/92317079) 96 | 1. 首先在句子中找到主题实体。在这里,我们使用了比赛组织者提供的Entity-mention文件和一些外部工具,例如paddle-paddle。 97 | **2. 然后在关系识别模块中,通过提取知识图中的主题实体的子图来找到问题(也称为谓词)的关系。通过相似性评分模型获得所有关系的排名。** 98 | 3. 最后,在答案选择模块中,根据简单复杂问题分类器和一些规则,得出最终答案。 99 | 100 | - [百度智珠夺冠:在知识图谱领域百度持续领先 ](https://www.sohu.com/a/339187520_630344) 101 | 1. 实体链接组件把问题中提及的实体链接到了知识库,并识别问题的核心实体。为了提高链接的精度,链接组件综合考虑了实体的子图与问题的匹配度、实体的流行度、指称正确度等多种特征,最后利用 LambdaRank 算法对实体进行排序,得到得分最高的实体。 102 | 2. 子图排序组件目标是从多种角度计算问题与各个子图的匹配度,最后综合多个匹配度的得分,得到出得分最高的答案子图。 103 | 3. 针对千万级的图谱,百度智珠团队采用了自主研发的策略来进行子图生成时的剪枝,综合考虑了召回率、精确率和时间代价等因素,从而提高子图排序的效率和效果。 104 | 针对开放领域的子图匹配,采用字面匹配函数计算符号化的语义相似,应用 word2vec 框架计算浅层的语义匹配,最后应用 BERT 算法做深度语义对齐。 105 | 除此之外,方案还针对具体的特征类型的问题进行一系列的意图判断,进一步提升模型在真实的问答场景中的效果和精度,更好地控制返回的答案类型,更符合真实的问答产品的需要。 106 | 107 | - http://nlpprogress.com/ 108 | 109 | ## 3. 数据分析及预处理 110 | 以下提到所有路径都是/home/wangshengguang/下的相对路径 111 | ### 3.1 原始数据 112 | 原始数据存放在data/roaw_data 目录下 113 | 1. 问答数据 data/raw_data/ccks_2020_7_4_Data下 114 | 1. 有标注训练集15999:data/raw_data/ccks_2020_7_4_Data/task1-4_train_2020.txt 115 | 2. 无标注验证集(做提交)1529:data/raw_data/ccks_2020_7_4_Data/task1-4_valid_2020.questions 116 | 2. 图谱数据 data/raw_data/PKUBASE下 117 | 1. 所有三元组66499745:data/raw_data/PKUBASE/pkubase-complete.txt 118 | 119 | ### 3.2 数据分析 120 | 三元组中实体分两类:<实体>和"属性"; 121 | 122 | 123 | ### 3.3 预处理tips: 124 | 三元组中数据分两类,<实体>和"属性"; 125 | 预处理时将属性的双引号去掉(包括构建的字典和导入neo4j的数据全部双引号都被去掉了,主要考虑双引号作为json的key不方便保存),方便使用; 126 | kb_data_prepare.py->iter_triples 127 | 在最后提交时需要恢复 128 | 129 | 130 | ## 4. 目前方案 131 | 1. 先做NER识别 主实体 132 | 2. 查找实体的关系,做分类,挑选出top 路径 133 | 3. 
生成sparql查询结果 134 | 135 | 136 | ### 4.1 主实体识别模块 137 | EL 138 | 139 | 140 | ### 4.2 关系打分模块 141 | RelationExtractor 142 | 143 | 144 | ### 4.3 后处理模块 145 | Algorithms 146 | 147 | 148 | 149 | ## Others 150 | 151 | - 本地端口映射后在本机访问远程图数据库 152 | ssh -f wangshengguang@remotehost -N -L 7474:localhost:7474 153 | ssh -f wangshengguang@remotehost -N -L 7687:localhost:7687 154 | 155 | 访问:[http://localhost:7474/browser/](http://localhost:7474/browser/) 156 | 157 | 158 | 159 | - [Jupyter Notebook 有哪些奇技淫巧? - z.defying的回答 - 知乎 160 | ](https://www.zhihu.com/question/266988943/answer/1154607853) 161 | - [linux下查看文件编码及修改编码](https://blog.csdn.net/jnbbwyth/article/details/6991425) 162 | 163 | 164 | ## 执行顺序 165 | 1. 数据准备 data.py + neo4j 数据库安装及数据导入 166 | 2. 模型训练 manage.py 167 | 3. 组合以上模型得到问答结果 qa.py 168 | -------------------------------------------------------------------------------- /ckbqa/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WangShengguang/ccks-2020/d77f2de91284efb11e314fea21b4a3982ab78554/ckbqa/__init__.py -------------------------------------------------------------------------------- /ckbqa/dao/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WangShengguang/ccks-2020/d77f2de91284efb11e314fea21b4a3982ab78554/ckbqa/dao/__init__.py -------------------------------------------------------------------------------- /ckbqa/dao/db.py: -------------------------------------------------------------------------------- 1 | # -*-coding:utf-8 -*- 2 | """ 3 | 在此管理数据库连接 4 | 在连接池启用的情况下,session 只是对连接池中的连接进行操作: 5 | session = Session() 将 Session 实例化的过程只是从连接池中取一个连接,在用完之后 6 | 使用 session.close(),session.commit(),session.rooback()还回到连接池中,而不是真正的关闭连接。 7 | 连接池禁用之后,sqlalchemy 会在查询之前创建连接,在查询结束,调用 session.close() 的时候关闭数据库连接。 8 | http://qinfei.glrsmart.com/2017/11/17/python-sqlalchemy-shu-ju-ku-lian-jie-chi/ 9 | """ 10 | import logging 11 | from mongoengine import connect, register_connection 12 | 13 | from sqlalchemy import create_engine 14 | from sqlalchemy.orm import sessionmaker 15 | 16 | from config import data_dir 17 | 18 | # db_pool_config = { 19 | # "pool_pre_ping": True, # 检查连接,无效则重置 20 | # "pool_size": 5, # 连接数, 默认5 21 | # "max_overflow": 10, # 超出pool_size后最多达到几个连接,超过部分使用后直接关闭,不放回连接池,默认为10,-1不限制 22 | # "pool_recycle": 7200, # 连接重置周期,默认-1,推荐7200,表示连接在给定时间之后会被回收,勿超过mysql 默认8小时 23 | # "pool_timeout": 30 # 等待 pool_timeout 秒,没有获取到连接之后,放弃从池中获取连接 24 | # "echo": False 显示执行的SQL语句 #想查看语句直接print str(session.query(UserProfile).filter_by(user_id='xxxx')) 25 | # } 26 | 27 | 28 | # sqlalchemy 一次只允许操作一个数据库,必须指定database 29 | sqlite_db_engine = create_engine(f'sqlite:///{data_dir}/data.sqlite?charset=utf8mb4') 30 | 31 | 32 | class DB(object): 33 | def __init__(self): 34 | session_factory = sessionmaker(bind=sqlite_db_engine) # autocommit=True若无开启Transaction,会自动commit 35 | self.session = session_factory() 36 | # self.session = scoped_session(session_factory) # 多线程安全,每个线程使用单独数据库连接 37 | 38 | def select(self, sql, params=None): 39 | """ 40 | :param sql: type str 41 | :param params: Optional dictionary, or list of dictionaries 42 | :return: result = session.execute( 43 | "SELECT * FROM user WHERE id=:user_id", 44 | {"user_id":5} 45 | ) 46 | """ 47 | cursor = self.session.execute(sql, params) 48 | res = cursor.fetchall() 49 | self.session.commit() 50 | return res 51 | 52 | def execute(self, sql, params=None): 53 | """ 54 | :param sql: type str 55 | :param params: Optional dictionary or list 
of dictionaries 56 | :return: result = session.execute( 57 | "SELECT * FROM user WHERE id=:user_id", 58 | {"user_id":5} 59 | ) 60 | """ 61 | self.session.execute(sql, params) 62 | self.session.commit() 63 | 64 | def __enter__(self): 65 | return self 66 | 67 | def __exit__(self, exc_type, exc_val, exc_tb): 68 | try: 69 | self.session.commit() 70 | except Exception: 71 | self.session.rollback() 72 | logging.info('save stage final,save error, session rollback ') 73 | finally: 74 | self.session.close() 75 | if exc_type is not None or exc_val is not None or exc_tb is not None: 76 | logging.info('exc_type: {}, exc_val: {}, exc_tb: {}'.format(exc_type, exc_val, exc_tb)) 77 | 78 | def __del__(self): 79 | self.session.close() 80 | 81 | 82 | # mongodb 83 | class Mongo(object): 84 | mongodb = { 85 | "host": '127.0.0.1', 86 | "port": 27017, 87 | "username": None, 88 | "password": None, 89 | "db": 'kb' 90 | } 91 | connect(**mongodb) 92 | # # 连接数据库,不存在则会自动创建 93 | register_connection("whiski_db", "whiski", host=mongodb['host']) 94 | register_connection("recommend_db", "recommend", host=mongodb['host']) 95 | -------------------------------------------------------------------------------- /ckbqa/dao/db_tools.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import traceback 3 | from collections import Iterable 4 | from functools import wraps 5 | 6 | from sqlalchemy import exc 7 | 8 | 9 | def try_commit_rollback(expunge=None): 10 | """ 装饰继承自DB的类的对象方法 11 | :param expunge: bool,将SQLAlchemy的数据对象实例转为一个简单的对象(切断与数据库会话的联系) 12 | 只在且在查询时尽量使用,做查询且返回sqlalchemy对象时使用,不知道就是用不到 13 | """ 14 | 15 | def out_wrapper(func): 16 | @wraps(func) 17 | def wrapper(self, *args, **kwargs): 18 | res = None 19 | db_session = self.session 20 | try: 21 | res = func(self, *args, **kwargs) 22 | if expunge and res is not None: 23 | if isinstance(res, Iterable): 24 | [db_session.expunge(_data) for _data in res if isinstance(_data, BaseModel)] 25 | elif isinstance(res, BaseModel): 26 | db_session.expunge(res) 27 | else: 28 | raise TypeError("Please ensure object(s) is instance(s) of BaseModel") 29 | db_session.commit() 30 | except exc.IntegrityError as e: # duplicate key 31 | db_session.rollback() 32 | logging.error("sqlalchemy.exc.IntegrityError: {}".format(e)) 33 | except exc.DataError as e: # invalid input syntax for uuid: "7ae16-ake" 34 | db_session.rollback() 35 | logging.error("sqlalchemy.exc.DataError: {}".format(e)) 36 | except Exception: 37 | db_session.rollback() 38 | logging.error(traceback.format_exc()) 39 | return res 40 | 41 | return wrapper 42 | 43 | return out_wrapper if isinstance(expunge, bool) else out_wrapper(expunge) 44 | 45 | 46 | try_commit_rollback_expunge = try_commit_rollback(expunge=True) 47 | 48 | 49 | def try_commit_rollback_dbsession(func): 50 | """装饰传入db_session对象的方法""" 51 | 52 | @wraps(func) 53 | def wrapper(self, db_session, *args, **kwargs): 54 | res = None 55 | try: 56 | res = func(self, db_session, *args, **kwargs) 57 | db_session.commit() 58 | except exc.IntegrityError as e: # duplicate key 59 | db_session.rollback() 60 | logging.error("sqlalchemy.exc.IntegrityError: {}".format(e)) 61 | except exc.DataError as e: # invalid input syntax for uuid: "7ae16-ake" 62 | db_session.rollback() 63 | logging.error("sqlalchemy.exc.DataError: {}".format(e)) 64 | except Exception: 65 | db_session.rollback() 66 | logging.error(traceback.format_exc()) 67 | return res 68 | 69 | return wrapper 70 | 
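
A minimal usage sketch of the two decorator flavors defined above, applied to methods of a DB subclass. GraphDao and its queries are hypothetical and only illustrate the intended pattern; they are not code from this repository. Note that the expunge=True path assumes BaseModel (declared in sqlite_models.py) is importable inside db_tools.py.

# Illustrative sketch only -- GraphDao and its queries are hypothetical,
# not part of this repository.
from ckbqa.dao.db import DB
from ckbqa.dao.db_tools import try_commit_rollback, try_commit_rollback_expunge
from ckbqa.dao.sqlite_models import Graph


class GraphDao(DB):

    @try_commit_rollback
    def count_entities(self):
        # Bare decorator: commits on success, rolls back and logs on any exception.
        rows = self.select("SELECT COUNT(*) FROM graph")
        return rows[0][0] if rows else 0

    @try_commit_rollback_expunge
    def get_entity_by_name(self, name):
        # expunge=True detaches the returned ORM object from the session,
        # so its attributes stay readable after the session is committed.
        return self.session.query(Graph).filter_by(name=name).first()


# dao = GraphDao()
# total = dao.count_entities()
# ent = dao.get_entity_by_name('<莫妮卡·贝鲁奇>')
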
-------------------------------------------------------------------------------- /ckbqa/dao/mongo_models.py: -------------------------------------------------------------------------------- 1 | from mongoengine import Document, StringField, IntField, ListField 2 | 3 | 4 | class Entity2id(Document): 5 | meta = { 6 | 'db_alias': 'entity2id', 7 | 'collection': 'graph', 8 | 'strict': False, 9 | 'indexes': [ 10 | 'name', 11 | 'id' 12 | ] 13 | } 14 | 15 | name = StringField(required=True) 16 | id = IntField(required=True) 17 | 18 | 19 | class Relation2id(Document): 20 | meta = { 21 | 'db_alias': 'relation2id', 22 | 'collection': 'graph', 23 | 'strict': False, 24 | 'indexes': [ 25 | 'name', 26 | 'id' 27 | ] 28 | } 29 | 30 | name = StringField(required=True) 31 | id = IntField(required=True) 32 | 33 | 34 | class Graph(Document): 35 | meta = { 36 | 'db_alias': 'sub_graph', 37 | 'collection': 'graph', 38 | 'strict': False, 39 | 'indexes': [ 40 | 'entity_name', 41 | 'entity_id' 42 | ] 43 | } 44 | 45 | entity_name = StringField(required=True) 46 | entity_id = IntField(required=True) 47 | ins = ListField(required=True) 48 | outs = ListField(required=True) 49 | 50 | 51 | class SubGraph(Document): 52 | meta = { 53 | 'db_alias': 'sub_graph', 54 | 'collection': 'graph', 55 | 'strict': False, 56 | 'indexes': [ 57 | 'entity_name', 58 | 'entity_id' 59 | ] 60 | } 61 | 62 | entity_name = StringField(required=True) 63 | entity_id = IntField(required=True) 64 | sub_graph = ListField(required=True) 65 | hop = IntField(required=True) # 几跳 66 | -------------------------------------------------------------------------------- /ckbqa/dao/mongo_utils.py: -------------------------------------------------------------------------------- 1 | from .db import Mongo 2 | from .mongo_models import Graph, Entity2id 3 | 4 | 5 | class MongoDB(Mongo): 6 | def entity2id(self, entity_name): 7 | return Entity2id.objects(name_eq=entity_name).all() 8 | 9 | def save_graph(self, graph_node: Graph): 10 | return Graph.objects(graph_node).all() 11 | -------------------------------------------------------------------------------- /ckbqa/dao/sqlite_models.py: -------------------------------------------------------------------------------- 1 | # -*-coding:utf-8 -*- 2 | 3 | from sqlalchemy import Column 4 | from sqlalchemy.dialects.sqlite import CHAR, INTEGER, TEXT, SMALLINT 5 | from sqlalchemy.ext.declarative import declarative_base 6 | 7 | BaseModel = declarative_base() # 创建一个类,这个类的子类可以自动与一个表关联 8 | 9 | 10 | class Graph(BaseModel): 11 | __tablename__ = 'graph' 12 | 13 | id = Column(INTEGER, primary_key=True, comment="entity id") 14 | name = Column(CHAR(256), default=None, index=True, unique=True, comment="entity name") 15 | pure_name = Column(CHAR(256), default=None, index=True, comment="entity name") 16 | in_ids = Column(TEXT, default=[], comment="入度") 17 | out_ids = Column(TEXT, default=[], comment="出度") 18 | type = Column(SMALLINT, default=0, comment="entity:1, relation:2") 19 | 20 | 21 | class SubGraph(BaseModel): 22 | __tablename__ = 'sub_graph' 23 | 24 | id = Column(INTEGER, primary_key=True, comment="entity id") 25 | name = Column(CHAR(256), default=None, index=True, unique=True, comment="entity name") 26 | pure_name = Column(CHAR(256), default=None, index=True, comment="entity name") 27 | sub_graph_ids = Column(TEXT, default={}, comment="子图") 28 | sub_graph_names = Column(TEXT, default={}, comment="子图") 29 | type = Column(SMALLINT, index=True, nullable=False, comment="entity:1, relation:2") 30 | 31 | 32 | class Entity2id(BaseModel): 33 | 
__tablename__ = 'entity2id' 34 | 35 | entity_id = Column(INTEGER, default=None, primary_key=True, comment="entity id") 36 | entity_name = Column(CHAR(256), nullable=False, index=True, unique=True, comment="entity name") 37 | pure_name = Column(CHAR(256), nullable=False, index=True, comment="entity name") 38 | -------------------------------------------------------------------------------- /ckbqa/dao/sqlite_utils.py: -------------------------------------------------------------------------------- 1 | from .db import DB 2 | from .db_tools import try_commit_rollback, try_commit_rollback_expunge 3 | from .sqlite_models import Entity2id, SubGraph 4 | from ..utils.decorators import singleton 5 | 6 | 7 | @singleton 8 | class SqliteDB(DB): 9 | 10 | @try_commit_rollback 11 | def get_id_by_entity_name(self, entity_name, pure_name): 12 | sql = 'select entity_name from ' 13 | ws_video = self.select(sql) 14 | return ws_video 15 | 16 | @try_commit_rollback_expunge 17 | def get_subGraph_by_entity_ids(self, entity_ids): 18 | # sub_graph = self.session.query(SubGraph).filter_by(id=entity_id).first() 19 | sub_graph = self.session.query(SubGraph).filter(SubGraph.id.in_(entity_ids)) 20 | return sub_graph 21 | -------------------------------------------------------------------------------- /ckbqa/dataset/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WangShengguang/ccks-2020/d77f2de91284efb11e314fea21b4a3982ab78554/ckbqa/dataset/__init__.py -------------------------------------------------------------------------------- /ckbqa/dataset/data_prepare.py: -------------------------------------------------------------------------------- 1 | import re 2 | 3 | import pandas as pd 4 | from tqdm import tqdm 5 | 6 | from ckbqa.utils.tools import json_dump 7 | from config import DataConfig, raw_train_txt 8 | 9 | entity_pattern = re.compile(r'(<.*?>)') # 保留<> 10 | attr_pattern = re.compile(r'"(.*?)"') # 不保留"", "在json key中不便保存 11 | question_patten = re.compile(r'q\d{1,4}:(.*)') 12 | 13 | PAD = 0 14 | UNK = 1 15 | 16 | 17 | def load_data(tqdm_prefix=''): 18 | with open(raw_train_txt, 'r', encoding='utf-8') as f: 19 | lines = [line.strip() for line in f.readlines() if line.strip()] 20 | for q, sparql, a in tqdm([lines[i:i + 3] for i in range(0, len(lines), 3)], 21 | desc=tqdm_prefix): 22 | yield q, sparql, a 23 | 24 | 25 | def fit_on_texts(): 26 | q_entities_set = set() 27 | a_entities_set = set() 28 | words_set = set() 29 | for q, sparql, a in load_data(): 30 | a_entities = entity_pattern.findall(a) 31 | q_entities = entity_pattern.findall(sparql) 32 | q_text = question_patten.findall(q) 33 | entities = a_entities + q_entities 34 | words = list(q_text[0]) + [e for ent in entities for e in ent] 35 | 36 | q_entities_set.update(q_entities) 37 | a_entities_set.update(a_entities) 38 | words_set.update(words) 39 | word2id = {"PAD": PAD, "UNK": UNK} 40 | for w in sorted(words_set): 41 | word2id[w] = len(word2id) 42 | json_dump(word2id, DataConfig.word2id_json) 43 | # question entity 44 | q_entity2id = {"PAD": PAD, "UNK": UNK} 45 | for e in sorted(q_entities_set): 46 | q_entity2id[e] = len(q_entity2id) 47 | json_dump(q_entity2id, DataConfig.q_entity2id_json) 48 | # answer entity 49 | a_entity2id = {"PAD": PAD, "UNK": UNK} 50 | for e in sorted(a_entities_set): 51 | a_entity2id[e] = len(a_entity2id) 52 | json_dump(a_entity2id, DataConfig.a_entity2id_json) 53 | 54 | 55 | def data_convert(): 56 | """转化原始数据格式, 方便采样entity""" 57 | data = {'question': [], 
'q_entities': [], 'q_strs': [], 58 | 'a_entities': [], 'a_strs': []} 59 | for q, sparql, a in load_data(): 60 | q_text = question_patten.findall(q)[0] 61 | q_entities = entity_pattern.findall(sparql) 62 | q_strs = attr_pattern.findall(sparql) 63 | a_entities = entity_pattern.findall(a) 64 | a_strs = attr_pattern.findall(a) 65 | data['question'].append(q_text) 66 | data['q_entities'].append(q_entities) 67 | data['q_strs'].append(q_strs) 68 | data['a_entities'].append(a_entities) 69 | data['a_strs'].append(a_strs) 70 | pd.DataFrame(data).to_csv(DataConfig.data_csv, encoding='utf_8_sig', index=False) 71 | -------------------------------------------------------------------------------- /ckbqa/dataset/kb_data_prepare.py: -------------------------------------------------------------------------------- 1 | """ 2 | 与知识图谱相关的数据与处理 3 | """ 4 | import gc 5 | import logging 6 | import os 7 | import re 8 | from collections import defaultdict, Counter 9 | from pathlib import Path 10 | 11 | import pandas as pd 12 | from tqdm import tqdm 13 | 14 | from ckbqa.utils.tools import json_dump, get_file_linenums, pkl_dump, json_load, tqdm_iter_file 15 | from config import Config, mention2ent_txt, kb_triples_txt 16 | 17 | 18 | def iter_triples(tqdm_prefix=''): 19 | """迭代图谱的三元组;""" 20 | rdf_patten = re.compile(r'(["<].*?[>"])') 21 | ent_partten = re.compile('<(.*?)>') 22 | attr_partten = re.compile('"(.*?)"') 23 | 24 | def parse_entities(string): 25 | ents = ent_partten.findall(string) 26 | if ents: # 加上括号,使用上面会去掉括号 27 | ents = [f'<{_ent.strip()}>' for _ent in ents if _ent] 28 | else: # 去掉引号 29 | ents = [_ent.strip() for _ent in attr_partten.findall(string) if _ent] 30 | return ents 31 | 32 | line_num = get_file_linenums(kb_triples_txt) 33 | # f_record = open('fail.txt', 'w', encoding='utf-8') 34 | with open(kb_triples_txt, 'r', encoding='utf-8') as f: 35 | desc = Path(kb_triples_txt).name + '-' + 'iter_triples' 36 | for line in tqdm(f, total=line_num, desc=desc): 37 | # for idx, line in enumerate(f): 38 | # if idx < 6618182: 39 | # continue 40 | # 有些数据不以\t分割;有些数据连续两个<><>作为头实体; 41 | triples = line.rstrip('.\n').split('\t') 42 | if len(triples) != 3: 43 | triples = line.rstrip('.\n').split() 44 | if len(triples) != 3: 45 | triples = rdf_patten.findall(line) 46 | if len(triples) == 6: 47 | head_ent, rel, tail_ent = triples[:3] 48 | yield parse_entities(head_ent)[0], parse_entities(rel)[0], parse_entities(tail_ent)[0] 49 | head_ent, rel, tail_ent = triples[3:] 50 | yield parse_entities(head_ent)[0], parse_entities(rel)[0], parse_entities(tail_ent)[0] 51 | continue 52 | head_ent, rel, tail_ent = triples 53 | # 54 | head_ents = parse_entities(head_ent) 55 | tail_ents = parse_entities(tail_ent) 56 | rels = parse_entities(rel) 57 | for head_ent in head_ents: 58 | for rel in rels: 59 | for tail_ent in tail_ents: 60 | yield head_ent, rel, tail_ent 61 | 62 | 63 | def fit_triples(): 64 | """entity,relation map to id; 65 | 实体到id的映射词典 66 | """ 67 | logging.info('fit_triples start ...') 68 | entities = set() 69 | relations = set() 70 | for head_ent, rel, tail_ent in iter_triples(tqdm_prefix='fit_triples '): 71 | entities.add(head_ent) 72 | entities.add(tail_ent) 73 | relations.add(rel) 74 | logging.info('entity2id start ...') 75 | entity2id = {ent: idx + 1 for idx, ent in enumerate(sorted(entities))} 76 | json_dump(entity2id, Config.entity2id) 77 | id2entity = {id: ent for ent, id in entity2id.items()} 78 | pkl_dump(id2entity, Config.id2entity_pkl) 79 | logging.info('relation2id start ...') 80 | relation2id = {rel: idx + 1 for 
idx, rel in enumerate(sorted(relations))} 81 | json_dump(relation2id, Config.relation2id) 82 | id2relation = {id: rel for rel, id in relation2id.items()} 83 | pkl_dump(id2relation, Config.id2relation_pkl) 84 | 85 | 86 | def map_mention_entity(): 87 | """mention2ent; 4 min; mention到实体的映射""" 88 | logging.info('map_mention_entity start ...') 89 | ent2mention = defaultdict(set) 90 | mention2ent = defaultdict(set) 91 | # count = 0 92 | for line in tqdm_iter_file(mention2ent_txt, prefix='iter mention2ent.txt '): 93 | mention, ent, rank = line.split('\t') # 有部分数据有问题,mention为空字符串 94 | mention2ent[mention].add(ent) 95 | ent2mention[ent].add(mention) 96 | # if count == 13930117: 97 | # print(line) 98 | # import ipdb 99 | # ipdb.set_trace() 100 | # count += 1 101 | # 102 | mention2ent = {k: sorted(v) for k, v in mention2ent.items()} 103 | json_dump(mention2ent, Config.mention2ent_json) 104 | ent2mention = {k: sorted(v) for k, v in ent2mention.items()} 105 | json_dump(ent2mention, Config.ent2mention_json) 106 | 107 | 108 | def candidate_words(): 109 | """实体的属性和类型映射字典;作为尾实体候选; 7min""" 110 | logging.info('candidate_words gen start ...') 111 | ent2types_dict = defaultdict(set) 112 | ent2attrs_dict = defaultdict(set) 113 | all_attrs = set() 114 | for head_ent, rel, tail_ent in iter_triples(tqdm_prefix='candidate_words '): 115 | if rel == '<类型>': 116 | ent2types_dict[head_ent].add(tail_ent) 117 | elif not head_ent.startswith('<'): 118 | ent2attrs_dict[tail_ent].add(head_ent) 119 | all_attrs.add(head_ent) 120 | elif not tail_ent.startswith('<'): 121 | ent2attrs_dict[head_ent].add(tail_ent) 122 | all_attrs.add(tail_ent) 123 | json_dump({ent: sorted(_types) for ent, _types in ent2types_dict.items()}, 124 | Config.entity2types_json) 125 | json_dump({ent: sorted(attrs) for ent, attrs in ent2attrs_dict.items()}, 126 | Config.entity2attrs_json) 127 | json_dump(list(all_attrs), Config.all_attrs_json) 128 | 129 | 130 | def _get_top_counter(): 131 | """26G,高频实体和mention,作为后期筛选和lac字典; 132 | 统计实体和mention出现的次数; 方便取top作为最终自定义分词的词典; 133 | """ 134 | logging.info('kb_count_top_dict start ...') 135 | if not (os.path.isfile(Config.entity2count_json) and 136 | os.path.isfile(Config.relation2count_json)): 137 | entities = [] 138 | relations = [] 139 | for head_ent, rel, tail_ent in iter_triples(tqdm_prefix='kb_top_count_dict '): 140 | entities.extend([head_ent, tail_ent]) 141 | relations.append(rel) 142 | ent_counter = Counter(entities) 143 | json_dump(dict(ent_counter), Config.entity2count_json) 144 | del entities 145 | rel_counter = Counter(relations) 146 | del relations 147 | json_dump(dict(rel_counter), Config.relation2count_json) 148 | else: 149 | ent_counter = Counter(json_load(Config.entity2count_json)) 150 | rel_counter = Counter(json_load(Config.relation2count_json)) 151 | # 152 | if not os.path.isfile(Config.mention2count_json): 153 | mentions = [] 154 | for line in tqdm_iter_file(mention2ent_txt, prefix='count_top_dict iter mention2ent.txt '): 155 | mention, ent, rank = line.split('\t') # 有部分数据有问题,mention为空字符串 156 | mentions.append(mention) 157 | mention_counter = Counter(mentions) 158 | del mentions 159 | json_dump(dict(mention_counter), Config.mention2count_json) 160 | else: 161 | mention_counter = Counter(json_load(Config.mention2count_json)) 162 | return ent_counter, rel_counter, mention_counter 163 | 164 | 165 | def create_lac_custom_dict(): 166 | """生成自定义分词词典""" 167 | logging.info('create_lac_custom_dict start...') 168 | ent_counter, rel_counter, mention_counter = _get_top_counter() 169 | mention_count = 
mention_counter.most_common(50 * 10000) # 170 | customization_dict = {mention: 'MENTION' for mention, count in mention_count 171 | if 2 <= len(mention) <= 8} 172 | del mention_count, mention_counter 173 | logging.info('create ent&rel customization_dict ...') 174 | _ent_pattrn = re.compile(r'["<](.*?)[>"]') 175 | # customization_dict.update({' '.join(_ent_pattrn.findall(rel)): 'REL' 176 | # for rel, count in rel_counter.most_common(10000) # 10万 177 | # if 2 <= len(rel) <= 8}) 178 | # del rel_counter 179 | customization_dict.update({' '.join(_ent_pattrn.findall(ent)): 'ENT' 180 | for ent, count in ent_counter.most_common(50 * 10000) # 100万 181 | if 2 <= len(ent) <= 8}) 182 | q_entity2id = json_load(Config.q_entity2id_json) 183 | q_entity2id.update(json_load(Config.a_entity2id_json)) 184 | customization_dict.update({' '.join(_ent_pattrn.findall(ent)): 'ENT' 185 | for ent, _id in q_entity2id.items()}) 186 | del ent_counter 187 | with open(Config.lac_custom_dict_txt, 'w') as f: 188 | for e, t in customization_dict.items(): 189 | if len(e) >= 3: 190 | f.write(f"{e}/{t}\n") 191 | logging.info('attr_custom_dict gen start ...') 192 | entity2attrs = json_load(Config.entity2attrs_json) 193 | all_attrs = set() 194 | for attrs in entity2attrs.values(): 195 | all_attrs.update(attrs) 196 | name_patten = re.compile('"(.*?)"') 197 | with open(Config.lac_attr_custom_dict_txt, 'w') as f: 198 | for _attr in all_attrs: 199 | attr = ' '.join(name_patten.findall(_attr)) 200 | if len(attr) >= 2: 201 | f.write(f"{attr}/ATTR\n") 202 | 203 | 204 | def create_graph_csv(): 205 | """ 生成数据库导入文件 206 | cd /home/wangshengguang/neo4j-community-3.4.5/bin/ 207 | ./neo4j-admin import --database=graph.db --nodes /home/wangshengguang/ccks-2020/data/graph_entity.csv --relationships /home/wangshengguang/ccks-2020/data/graph_relation2.csv --ignore-duplicate-nodes=true --id-type INTEGER --ignore-missing-nodes=true 208 | CREATE CONSTRAINT ON (ent:Entity) ASSERT ent.id IS UNIQUE; 209 | CREATE CONSTRAINT ON (ent:Entity) ASSERT ent.name IS UNIQUE; 210 | CREATE CONSTRAINT ON (r:Relation) ASSERT r.name IS UNIQUE; 211 | """ 212 | logging.info('start load Config.id2entity ..') 213 | entity2id = json_load(Config.entity2id) 214 | pd.DataFrame.from_records( 215 | [(id, ent, 'Entity') for ent, id in entity2id.items()], 216 | columns=["id:ID(Entity)", "name:String", ":LABEL"]).to_csv( 217 | Config.graph_entity_csv, index=False, encoding='utf_8_sig') 218 | # 219 | records = [[entity2id[head_name], entity2id[tail_name], 'Relation', rel_name] 220 | for head_name, rel_name, tail_name in iter_triples(tqdm_prefix='gen relation csv')] 221 | del entity2id 222 | gc.collect() 223 | pd.DataFrame.from_records( 224 | records, columns=[":START_ID(Entity)", ":END_ID(Entity)", ":TYPE", "name:String"]).to_csv( 225 | Config.graph_relation_csv, index=False, encoding='utf_8_sig') 226 | 227 | # def kge_data(): 228 | # """17.3G,生成KGE训练数据""" 229 | # logging.info('kge_data start ...') 230 | # kge_data_dir = Path(Config.entity2id).parent.joinpath('ccks-2020') 231 | # kge_data_dir.mkdir(exist_ok=True) 232 | # 233 | # def data_df2file(data_df, save_path): 234 | # with open(save_path, 'w') as f: 235 | # f.write(f'{data_df.shape[0]}\n') 236 | # data_df.to_csv(save_path, mode='a', header=False, index=False, sep='\t') 237 | # 238 | # # 239 | # ent_most = set([ent for ent, count in Counter(json_load(Config.entity2count_json)).most_common(1000000)]) 240 | # entity2id = {ent: id + 1 for id, ent in enumerate(sorted(ent_most))} 241 | # data_df2file(data_df=pd.DataFrame([(ent, idx) 
for ent, idx in entity2id.items()]), 242 | # save_path=kge_data_dir.joinpath('entity2id.txt')) 243 | # # 244 | # rel_most = set([rel for rel, _count in Counter(json_load(Config.relation2count_json)).most_common(1 * 200000)]) 245 | # relation2id = {rel: id + 1 for id, rel in enumerate(sorted(rel_most))} 246 | # data_df2file(data_df=pd.DataFrame([(rel, idx) for rel, idx in relation2id.items()]), 247 | # save_path=kge_data_dir.joinpath('relation2id.txt')) 248 | # # 训练集、验证集、测试集 249 | # logging.info('kge train data start ...') 250 | # # head,rel,tail -> head,tail,rel #HACK 交换了顺序,为了配合KGE模型的训练数据需要 251 | # data_df = pd.DataFrame([(entity2id.get(head_ent, 0), entity2id.get(tail_ent, 0), 252 | # relation2id.get(rel, 0)) 253 | # for head_ent, rel, tail_ent in iter_triples()]) 254 | # del entity2id, relation2id 255 | # train_df, test_df = train_test_split(data_df, test_size=0.01) 256 | # data_df2file(test_df, save_path=kge_data_dir.joinpath('test2id.txt')) 257 | # del test_df 258 | # train_df, valid_df = train_test_split(train_df, test_size=0.01) 259 | # data_df2file(valid_df, save_path=kge_data_dir.joinpath('valid2id.txt')) 260 | # del valid_df 261 | # data_df2file(train_df, save_path=kge_data_dir.joinpath('train2id.txt')) 262 | # del train_df 263 | # # 264 | # src_dir = str(kge_data_dir) 265 | # dst_dir = str( 266 | # Path(Config.entity2id).parent.parent.parent.joinpath('KE').joinpath('benchmarks').joinpath('ccks-2020')) 267 | # logging.info(f'move file: {src_dir} to {dst_dir}') 268 | # shutil.copytree(src=src_dir, dst=dst_dir) # directory 269 | -------------------------------------------------------------------------------- /ckbqa/layers/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WangShengguang/ccks-2020/d77f2de91284efb11e314fea21b4a3982ab78554/ckbqa/layers/__init__.py -------------------------------------------------------------------------------- /ckbqa/layers/losses.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | class ContrastiveLoss(torch.nn.Module): 5 | """ 6 | Contrastive loss function. 
7 | Based on: http://yann.lecun.com/exdb/publis/pdf/hadsell-chopra-lecun-06.pdf 8 | """ 9 | 10 | def __init__(self, margin=1.0): 11 | super(ContrastiveLoss, self).__init__() 12 | self.margin = margin 13 | 14 | def forward(self, euclidean_distance, label): 15 | loss_contrastive = torch.mean(label * torch.pow(euclidean_distance, 2) + 16 | (1 - label) * torch.pow(torch.relu(self.margin - euclidean_distance), 2)) / 2 17 | 18 | return loss_contrastive 19 | 20 | -------------------------------------------------------------------------------- /ckbqa/layers/modules.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | 3 | import torch 4 | 5 | 6 | class TextCNN(nn.Module): 7 | def __init__(self, embed_dim, out_dim): 8 | super(TextCNN).__init__() 9 | input_chanel = 1 10 | kernel_num = 1 11 | filter_sizes = [2, 3, 4] 12 | self.convs = [nn.Conv2d(input_chanel, kernel_num, (kernel_size, embed_dim)) 13 | for kernel_size in filter_sizes] 14 | # self.drop_out = nn.Dropout(0.8) 15 | self.fc = nn.Linear(len(filter_sizes) * kernel_num, out_dim) 16 | 17 | def forward(self, x): 18 | xs = [] 19 | for i in range(len(self.convs)): 20 | x = self.convs[i](x) 21 | xs.append(x) 22 | x = torch.cat(xs, dim=1) 23 | x = self.fc(x) 24 | return x 25 | -------------------------------------------------------------------------------- /ckbqa/models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WangShengguang/ccks-2020/d77f2de91284efb11e314fea21b4a3982ab78554/ckbqa/models/__init__.py -------------------------------------------------------------------------------- /ckbqa/models/base_trainer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.optim import Adam 4 | from torch.optim.lr_scheduler import LambdaLR 5 | 6 | from config import Config 7 | 8 | 9 | class BaseTrainer(object): 10 | def __init__(self, model_name): 11 | self.global_step = 0 12 | self.device = Config.device 13 | self.model_name = model_name 14 | 15 | def backfoward(self, loss, model): 16 | if Config.gpu_nums > 1 and Config.multi_gpu: 17 | loss = loss.mean() # mean() to average on multi-gpu 18 | if Config.gradient_accumulation_steps > 1: 19 | loss = loss / Config.gradient_accumulation_steps 20 | # https://zhuanlan.zhihu.com/p/79887894 21 | loss.backward() 22 | self.global_step += 1 23 | if self.global_step % Config.gradient_accumulation_steps == 0: 24 | nn.utils.clip_grad_norm_(parameters=model.parameters(), max_norm=Config.clip_grad) 25 | self.optimizer.step() 26 | self.optimizer.zero_grad() 27 | return loss 28 | 29 | def init_model(self, model): 30 | model.to(self.device) # without this there is no error, but it runs in CPU (instead of GPU). 
31 | if Config.gpu_nums > 1 and Config.multi_gpu: 32 | model = torch.nn.DataParallel(model) 33 | self.optimizer = Adam(model.parameters(), lr=Config.learning_rate) 34 | self.scheduler = LambdaLR(self.optimizer, lr_lambda=lambda epoch: 1 / (1 + 0.05 * epoch)) 35 | return model 36 | -------------------------------------------------------------------------------- /ckbqa/models/data_helper.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | 3 | import torch 4 | from pytorch_transformers import BertTokenizer 5 | 6 | from ckbqa.utils.decorators import singleton 7 | from ckbqa.utils.sequence import pad_sequences 8 | from config import Config 9 | 10 | 11 | @singleton 12 | class DataHelper(object): 13 | def __init__(self, load_tokenizer=True): 14 | self.device = Config.device 15 | if load_tokenizer: 16 | self.load_tokenizer() 17 | 18 | def load_tokenizer(self): 19 | self.tokenizer = BertTokenizer.from_pretrained( 20 | Config.pretrained_model_name_or_path, do_lower_case=True) 21 | 22 | def batch_sent2tensor(self, sent_texts: List[str], pad=False): 23 | batch_token_ids = [self.sent2ids(sent_text) for sent_text in sent_texts] 24 | if pad: 25 | batch_token_ids = pad_sequences(batch_token_ids, maxlen=Config.max_len, padding='post') 26 | batch_token_ids = torch.tensor(batch_token_ids, dtype=torch.long).to(self.device) 27 | return batch_token_ids 28 | 29 | def sent2ids(self, sent_text): 30 | sent_tokens = ['[CLS]'] + self.tokenizer.tokenize(sent_text) + ["[SEP]"] 31 | token_ids = self.tokenizer.convert_tokens_to_ids(sent_tokens) 32 | return token_ids 33 | 34 | def data2tensor(self, batch_token_ids, pad=True, maxlen=Config.max_len): 35 | if pad: 36 | batch_token_ids = pad_sequences(batch_token_ids, maxlen=maxlen, padding='post') 37 | batch_token_ids = torch.tensor(batch_token_ids, dtype=torch.long).to(self.device) 38 | return batch_token_ids 39 | -------------------------------------------------------------------------------- /ckbqa/models/entity_score/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WangShengguang/ccks-2020/d77f2de91284efb11e314fea21b4a3982ab78554/ckbqa/models/entity_score/__init__.py -------------------------------------------------------------------------------- /ckbqa/models/entity_score/model.py: -------------------------------------------------------------------------------- 1 | ''' 2 | 从所有候选实体中筛选出topn个候选实体 3 | ''' 4 | 5 | import os 6 | 7 | import numpy as np 8 | from sklearn import preprocessing 9 | from sklearn.linear_model import LogisticRegression 10 | 11 | from ckbqa.dataset.data_prepare import question_patten, entity_pattern, attr_pattern, load_data 12 | from ckbqa.utils.tools import pkl_dump, pkl_load 13 | from config import Config 14 | 15 | 16 | class EntityScore(object): 17 | """ 18 | 主实体打分排序; 19 | """ 20 | 21 | def __init__(self, load_pretrain_model=False): 22 | if load_pretrain_model: 23 | self.model: LogisticRegression = pkl_load(Config.entity_score_model_pkl) 24 | 25 | def gen_train_data(self): 26 | X_train = [] 27 | Y_label = [] 28 | from ckbqa.qa.el import EL # 避免循环导入 29 | el = EL() 30 | for q, sparql, a in load_data(tqdm_prefix='EntityScore train data '): 31 | # a_entities = entity_pattern.findall(a) 32 | q_entities = set(entity_pattern.findall(sparql) + attr_pattern.findall(sparql)) # attr 33 | q_text = question_patten.findall(q)[0] 34 | candidate_entities = el.el(q_text) 35 | for ent_name, feature_dict in 
candidate_entities.items(): 36 | feature = feature_dict['feature'] 37 | label = 1 if ent_name in q_entities else 0 # 候选实体有的不在答案中 38 | X_train.append(feature) 39 | Y_label.append(label) 40 | pkl_dump({'x_data': X_train, 'y_label': Y_label}, Config.entity_score_data_pkl) 41 | 42 | def train(self): 43 | if not os.path.isfile(Config.entity_score_data_pkl): 44 | self.gen_train_data() 45 | data = pkl_load(Config.entity_score_data_pkl) 46 | X_train = preprocessing.scale(np.array(data['x_data'])) 47 | # Y_label = np.eye(2)[data['y_label']] # class_num==2 48 | Y_label = np.array(data['y_label']) # class_num==2 49 | print(f"X_train : {X_train.shape}, Y_label: {Y_label.shape}") 50 | print(f"sum(Y_label): {sum(Y_label)}") 51 | model = LogisticRegression(C=1e5, verbose=1) 52 | model.fit(X_train, Y_label) 53 | # model.predict() 54 | accuracy_score = model.score(X_train, Y_label) 55 | print(f"accuracy_score: {accuracy_score:.4f}") 56 | pkl_dump(model, Config.entity_score_model_pkl) 57 | 58 | def predict(self, features): 59 | preds = self.model.predict(np.asarray(features)) 60 | return preds 61 | -------------------------------------------------------------------------------- /ckbqa/models/evaluation_matrics.py: -------------------------------------------------------------------------------- 1 | def get_metrics(real_entities, pred_entities): 2 | if len(pred_entities) == 0: 3 | return 0, 0, 0 4 | real_entities_set = set(real_entities) 5 | pred_entities_set = set(pred_entities) 6 | TP = len(real_entities_set & pred_entities_set) # 负例 TP 7 | # FN = len(pred_entities - real_entities) # 预测出的负例, FN 8 | # FP = len(real_entities - pred_entities) # 没有预测出的正例 9 | # TP+FP = real_entities_set 10 | # TP+FN = pred_entities_set 11 | precision = TP / len(pred_entities) 12 | if TP: 13 | recall = 1 14 | else: 15 | recall = TP / len(real_entities_set) 16 | if (precision + recall): 17 | f1 = 2 * precision * recall / (precision + recall) 18 | else: 19 | f1 = 0 20 | return precision, recall, f1 21 | -------------------------------------------------------------------------------- /ckbqa/models/ner/__init__.py: -------------------------------------------------------------------------------- 1 | """主实体识别""" 2 | -------------------------------------------------------------------------------- /ckbqa/models/ner/crf.py: -------------------------------------------------------------------------------- 1 | """ 2 | CRF Module 3 | """ 4 | 5 | from typing import List, Optional 6 | 7 | import torch 8 | import torch.nn as nn 9 | 10 | 11 | class CRF(nn.Module): 12 | """Conditional random field. 13 | 14 | This module implements a conditional random field [LMP01]_. The forward computation 15 | of this class computes the log likelihood of the given sequence of tags and 16 | emission score tensor. This class also has `~CRF.decode` method which finds 17 | the best tag sequence given an emission score tensor using `Viterbi algorithm`_. 18 | 19 | Args: 20 | num_tags: Number of tags. 21 | batch_first: Whether the first dimension corresponds to the size of a minibatch. 22 | 23 | Attributes: 24 | start_transitions (`~torch.nn.Parameter`): Start transition score tensor of size 25 | ``(num_tags,)``. 26 | end_transitions (`~torch.nn.Parameter`): End transition score tensor of size 27 | ``(num_tags,)``. 28 | transitions (`~torch.nn.Parameter`): Transition score tensor of size 29 | ``(num_tags, num_tags)``. 30 | 31 | 32 | .. [LMP01] Lafferty, J., McCallum, A., Pereira, F. (2001). 
33 | "Conditional random fields: Probabilistic models for segmenting and 34 | labeling sequence data". *Proc. 18th International Conf. on Machine 35 | Learning*. Morgan Kaufmann. pp. 282–289. 36 | 37 | .. _Viterbi algorithm: https://en.wikipedia.org/wiki/Viterbi_algorithm 38 | """ 39 | 40 | def __init__(self, num_tags: int, batch_first: bool = False) -> None: 41 | if num_tags <= 0: 42 | raise ValueError(f'invalid number of tags: {num_tags}') 43 | super().__init__() 44 | self.num_tags = num_tags 45 | self.batch_first = batch_first 46 | self.start_transitions = nn.Parameter(torch.empty(num_tags)) 47 | self.end_transitions = nn.Parameter(torch.empty(num_tags)) 48 | self.transitions = nn.Parameter(torch.empty(num_tags, num_tags)) 49 | 50 | self.reset_parameters() 51 | 52 | def reset_parameters(self) -> None: 53 | """Initialize the transition parameters. 54 | 55 | The parameters will be initialized randomly from a uniform distribution 56 | between -0.1 and 0.1. 57 | """ 58 | nn.init.uniform_(self.start_transitions, -0.1, 0.1) 59 | nn.init.uniform_(self.end_transitions, -0.1, 0.1) 60 | nn.init.uniform_(self.transitions, -0.1, 0.1) 61 | 62 | def __repr__(self) -> str: 63 | return f'{self.__class__.__name__}(num_tags={self.num_tags})' 64 | 65 | def forward( 66 | self, 67 | emissions: torch.Tensor, 68 | tags: torch.LongTensor, 69 | mask: Optional[torch.ByteTensor] = None, 70 | reduction: str = 'sum', 71 | ) -> torch.Tensor: 72 | """Compute the conditional log likelihood of a sequence of tags given emission scores. 73 | 74 | Args: 75 | emissions (`~torch.Tensor`): Emission score tensor of size 76 | ``(seq_length, batch_size, num_tags)`` if ``batch_first`` is ``False``, 77 | ``(batch_size, seq_length, num_tags)`` otherwise. 78 | tags (`~torch.LongTensor`): Sequence of tags tensor of size 79 | ``(seq_length, batch_size)`` if ``batch_first`` is ``False``, 80 | ``(batch_size, seq_length)`` otherwise. 81 | mask (`~torch.ByteTensor`): Mask tensor of size ``(seq_length, batch_size)`` 82 | if ``batch_first`` is ``False``, ``(batch_size, seq_length)`` otherwise. 83 | reduction: Specifies the reduction to apply to the output: 84 | ``none|sum|mean|token_mean``. ``none``: no reduction will be applied. 85 | ``sum``: the output will be summed over batches. ``mean``: the output will be 86 | averaged over batches. ``token_mean``: the output will be averaged over tokens. 87 | 88 | Returns: 89 | `~torch.Tensor`: The log likelihood. This will have size ``(batch_size,)`` if 90 | reduction is ``none``, ``()`` otherwise. 
91 | """ 92 | self._validate(emissions, tags=tags, mask=mask) 93 | if reduction not in ('none', 'sum', 'mean', 'token_mean'): 94 | raise ValueError(f'invalid reduction: {reduction}') 95 | if mask is None: 96 | mask = torch.ones_like(tags, dtype=torch.uint8) 97 | 98 | if self.batch_first: 99 | emissions = emissions.transpose(0, 1) 100 | tags = tags.transpose(0, 1) 101 | mask = mask.transpose(0, 1) 102 | 103 | # shape: (batch_size,) 104 | numerator = self._compute_score(emissions, tags, mask) 105 | # shape: (batch_size,) 106 | denominator = self._compute_normalizer(emissions, mask) 107 | # shape: (batch_size,) 108 | llh = numerator - denominator 109 | 110 | if reduction == 'none': 111 | return llh 112 | if reduction == 'sum': 113 | return llh.sum() 114 | if reduction == 'mean': 115 | return llh.mean() 116 | assert reduction == 'token_mean' 117 | # return llh.sum() / mask.half().sum() 118 | return llh.sum() / mask.float().sum() 119 | 120 | def decode(self, emissions: torch.Tensor, 121 | mask: Optional[torch.ByteTensor] = None) -> List[List[int]]: 122 | """Find the most likely tag sequence using Viterbi algorithm. 123 | 124 | Args: 125 | emissions (`~torch.Tensor`): Emission score tensor of size 126 | ``(seq_length, batch_size, num_tags)`` if ``batch_first`` is ``False``, 127 | ``(batch_size, seq_length, num_tags)`` otherwise. 128 | mask (`~torch.ByteTensor`): Mask tensor of size ``(seq_length, batch_size)`` 129 | if ``batch_first`` is ``False``, ``(batch_size, seq_length)`` otherwise. 130 | 131 | Returns: 132 | List of list containing the best tag sequence for each batch. 133 | """ 134 | self._validate(emissions, mask=mask) 135 | if mask is None: 136 | mask = emissions.new_ones(emissions.shape[:2], dtype=torch.uint8) 137 | 138 | if self.batch_first: 139 | emissions = emissions.transpose(0, 1) 140 | mask = mask.transpose(0, 1) 141 | 142 | return self._viterbi_decode(emissions, mask) 143 | 144 | def _validate( 145 | self, 146 | emissions: torch.Tensor, 147 | tags: Optional[torch.LongTensor] = None, 148 | mask: Optional[torch.ByteTensor] = None) -> None: 149 | if emissions.dim() != 3: 150 | raise ValueError(f'emissions must have dimension of 3, got {emissions.dim()}') 151 | if emissions.size(2) != self.num_tags: 152 | raise ValueError( 153 | f'expected last dimension of emissions is {self.num_tags}, ' 154 | f'got {emissions.size(2)}') 155 | 156 | if tags is not None: 157 | if emissions.shape[:2] != tags.shape: 158 | raise ValueError( 159 | 'the first two dimensions of emissions and tags must match, ' 160 | f'got {tuple(emissions.shape[:2])} and {tuple(tags.shape)}') 161 | 162 | if mask is not None: 163 | if emissions.shape[:2] != mask.shape: 164 | raise ValueError( 165 | 'the first two dimensions of emissions and mask must match, ' 166 | f'got {tuple(emissions.shape[:2])} and {tuple(mask.shape)}') 167 | no_empty_seq = not self.batch_first and mask[0].all() 168 | no_empty_seq_bf = self.batch_first and mask[:, 0].all() 169 | if not no_empty_seq and not no_empty_seq_bf: 170 | raise ValueError('mask of the first timestep must all be on') 171 | 172 | def _compute_score( 173 | self, emissions: torch.Tensor, tags: torch.LongTensor, 174 | mask: torch.ByteTensor) -> torch.Tensor: 175 | # emissions: (seq_length, batch_size, num_tags) 176 | # tags: (seq_length, batch_size) 177 | # mask: (seq_length, batch_size) 178 | assert emissions.dim() == 3 and tags.dim() == 2 179 | assert emissions.shape[:2] == tags.shape 180 | assert emissions.size(2) == self.num_tags 181 | assert mask.shape == tags.shape 182 | 
assert mask[0].all() 183 | 184 | seq_length, batch_size = tags.shape 185 | # mask = mask.half() 186 | mask = mask.float() 187 | 188 | # Start transition score and first emission 189 | # shape: (batch_size,) 190 | score = self.start_transitions[tags[0]] 191 | score += emissions[0, torch.arange(batch_size), tags[0]] 192 | 193 | for i in range(1, seq_length): 194 | # Transition score to next tag, only added if next timestep is valid (mask == 1) 195 | # shape: (batch_size,) 196 | score += self.transitions[tags[i - 1], tags[i]] * mask[i] 197 | # Emission score for next tag, only added if next timestep is valid (mask == 1) 198 | # shape: (batch_size,) 199 | score += emissions[i, torch.arange(batch_size), tags[i]] * mask[i] 200 | 201 | # End transition score 202 | # shape: (batch_size,) 203 | seq_ends = mask.long().sum(dim=0) - 1 204 | # shape: (batch_size,) 205 | last_tags = tags[seq_ends, torch.arange(batch_size)] 206 | # shape: (batch_size,) 207 | score += self.end_transitions[last_tags] 208 | 209 | return score 210 | 211 | def _compute_normalizer( 212 | self, emissions: torch.Tensor, mask: torch.ByteTensor) -> torch.Tensor: 213 | # emissions: (seq_length, batch_size, num_tags) 214 | # mask: (seq_length, batch_size) 215 | assert emissions.dim() == 3 and mask.dim() == 2 216 | assert emissions.shape[:2] == mask.shape 217 | assert emissions.size(2) == self.num_tags 218 | assert mask[0].all() 219 | 220 | seq_length = emissions.size(0) 221 | 222 | # Start transition score and first emission; score has size of 223 | # (batch_size, num_tags) where for each batch, the j-th column stores 224 | # the score that the first timestep has tag j 225 | # shape: (batch_size, num_tags) 226 | score = self.start_transitions + emissions[0] 227 | 228 | for i in range(1, seq_length): 229 | # Broadcast score for every possible next tag 230 | # shape: (batch_size, num_tags, 1) 231 | broadcast_score = score.unsqueeze(2) 232 | 233 | # Broadcast emission score for every possible current tag 234 | # shape: (batch_size, 1, num_tags) 235 | broadcast_emissions = emissions[i].unsqueeze(1) 236 | 237 | # Compute the score tensor of size (batch_size, num_tags, num_tags) where 238 | # for each sample, entry at row i and column j stores the sum of scores of all 239 | # possible tag sequences so far that end with transitioning from tag i to tag j 240 | # and emitting 241 | # shape: (batch_size, num_tags, num_tags) 242 | next_score = broadcast_score + self.transitions + broadcast_emissions 243 | 244 | # Sum over all possible current tags, but we're in score space, so a sum 245 | # becomes a log-sum-exp: for each sample, entry i stores the sum of scores of 246 | # all possible tag sequences so far, that end in tag i 247 | # shape: (batch_size, num_tags) 248 | next_score = torch.logsumexp(next_score, dim=1) 249 | 250 | # Set score to the next score if this timestep is valid (mask == 1) 251 | # shape: (batch_size, num_tags) 252 | score = torch.where(mask[i].unsqueeze(1), next_score, score) 253 | 254 | # End transition score 255 | # shape: (batch_size, num_tags) 256 | score += self.end_transitions 257 | 258 | # Sum (log-sum-exp) over all possible tags 259 | # shape: (batch_size,) 260 | return torch.logsumexp(score, dim=1) 261 | 262 | def _viterbi_decode(self, emissions: torch.Tensor, 263 | mask: torch.ByteTensor) -> List[List[int]]: 264 | # emissions: (seq_length, batch_size, num_tags) 265 | # mask: (seq_length, batch_size) 266 | assert emissions.dim() == 3 and mask.dim() == 2 267 | assert emissions.shape[:2] == mask.shape 268 | assert 
emissions.size(2) == self.num_tags 269 | assert mask[0].all() 270 | 271 | seq_length, batch_size = mask.shape 272 | 273 | # Start transition and first emission 274 | # shape: (batch_size, num_tags) 275 | score = self.start_transitions + emissions[0] 276 | 277 | # shape: (seq_length-1, batch_size, num_tags) 278 | history = torch.empty([seq_length - 1, batch_size, self.num_tags], dtype=torch.long, 279 | device=torch.device('cuda')) 280 | 281 | # score is a tensor of size (batch_size, num_tags) where for every batch, 282 | # value at column j stores the score of the best tag sequence so far that ends 283 | # with tag j 284 | # history saves where the best tags candidate transitioned from; this is used 285 | # when we trace back the best tag sequence 286 | 287 | # Viterbi algorithm recursive case: we compute the score of the best tag sequence 288 | # for every possible next tag 289 | for i in range(1, seq_length): 290 | # Broadcast viterbi score for every possible next tag 291 | # shape: (batch_size, num_tags, 1) 292 | broadcast_score = score.unsqueeze(2) 293 | 294 | # Broadcast emission score for every possible current tag 295 | # shape: (batch_size, 1, num_tags) 296 | broadcast_emission = emissions[i].unsqueeze(1) 297 | 298 | # Compute the score tensor of size (batch_size, num_tags, num_tags) where 299 | # for each sample, entry at row i and column j stores the score of the best 300 | # tag sequence so far that ends with transitioning from tag i to tag j and emitting 301 | # shape: (batch_size, num_tags, num_tags) 302 | next_score = broadcast_score + self.transitions + broadcast_emission 303 | 304 | # Find the maximum score over all possible current tag 305 | # shape: (batch_size, num_tags), (batch_size, num_tags) 306 | next_score, indices = next_score.max(dim=1) 307 | 308 | # Set score to the next score if this timestep is valid (mask == 1) 309 | # and save the index that produces the next score 310 | # shape: (batch_size, num_tags) 311 | score = torch.where(mask[i].unsqueeze(1), next_score, score) 312 | # history.append(indices) 313 | history[i - 1] = indices 314 | # End transition score 315 | # shape: (batch_size, num_tags) 316 | score += self.end_transitions 317 | 318 | # Now, compute the best path for each sample 319 | 320 | # shape: (batch_size,) 321 | seq_ends = mask.long().sum(dim=0) - 1 322 | # best_tags_list = [] 323 | best_tags_all = torch.empty(batch_size, seq_length, dtype=torch.long, device=torch.device('cuda')) 324 | 325 | best_tags = torch.empty(seq_length, dtype=torch.long, device=torch.device('cuda')) 326 | for idx in range(batch_size): 327 | # Find the tag which maximizes the score at the last timestep; this is our best tag 328 | # for the last timestep 329 | _, best_last_tag = score[idx].max(dim=0) 330 | best_tags[0] = best_last_tag 331 | 332 | # We trace back where the best last tag comes from, append that to our best tag 333 | # sequence, and trace it back again, and so on 334 | for i, hist in enumerate(history[:seq_ends[idx]].flip([0])): 335 | best_last_tag = hist[idx][best_tags[i]] 336 | best_tags[i + 1] = best_last_tag 337 | 338 | # Reverse the order because we start from the last timestep 339 | best_tags = best_tags.flip([0]) 340 | best_tags_all[idx] = best_tags 341 | 342 | return best_tags_all 343 | -------------------------------------------------------------------------------- /ckbqa/models/ner/model.py: -------------------------------------------------------------------------------- 1 | """BERT + CRF for named entity recognition""" 2 | 3 | import torch.nn as nn 
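# A minimal usage sketch for the CRF layer defined in ckbqa/models/ner/crf.py above;
# the emissions, tags and shapes here are illustrative placeholders only:
#   import torch
#   from ckbqa.models.ner.crf import CRF
#   crf = CRF(num_tags=5, batch_first=True)
#   emissions = torch.randn(2, 7, 5)             # (batch_size, seq_length, num_tags)
#   tags = torch.randint(0, 5, (2, 7))           # gold tag ids
#   mask = torch.ones(2, 7, dtype=torch.uint8)   # 1 marks real tokens
#   loss = -crf(emissions, tags, mask, reduction='token_mean')  # negative log likelihood
#   paths = crf.decode(emissions, mask)          # Viterbi decoding; note that the decode
#                                                # buffers are allocated on 'cuda' as written,
#                                                # so this branch needs a GPU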
4 | from pytorch_transformers import BertPreTrainedModel, BertModel 5 | 6 | from .crf import CRF 7 | 8 | 9 | class BERTCRF(BertPreTrainedModel): 10 | """BERT model for token-level classification. 11 | This module is composed of the BERT model with a linear layer on top of 12 | the full hidden state of the last layer. 13 | 14 | Args: 15 | config: a BertConfig class instance with the configuration to build a new model. 16 | num_labels: the number of classes for the classifier. Default = 2. 17 | 18 | Inputs: 19 | input_ids: a torch.LongTensor of shape (batch_size, seq_length) 20 | with the word token indices in the vocabulary. 21 | token_type_ids: an optional torch.LongTensor of shape (batch_size, seq_length) with the token 22 | types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to 23 | a `sentence B` token. 24 | attention_mask: an optional torch.LongTensor of shape (batch_size, seq_length) with indices 25 | selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max 26 | input sequence length in the current batch. It's the mask that we typically use for attention when 27 | a batch has varying length sentences. 28 | labels: labels for the classification output: torch.LongTensor of shape (batch_size, seq_length) 29 | with indices selected in [0, ..., num_labels]. 30 | 31 | Returns: 32 | if labels is not None: 33 | Outputs the CrossEntropy classification loss of the output with the labels. 34 | if labels is None: 35 | Outputs the classification tag indices of shape (batch_size, seq_length). 36 | """ 37 | 38 | def __init__(self, config, num_labels): 39 | super(BERTCRF, self).__init__(config) 40 | self.num_labels = num_labels # ent label num 41 | self.bert = BertModel(config) 42 | self.dropout = nn.Dropout(config.hidden_dropout_prob) 43 | self.fc = nn.Linear(config.hidden_size, num_labels) 44 | self.crf = CRF(num_labels, batch_first=True) 45 | self.apply(self.init_bert_weights) 46 | # self.fc.weight.data.normal_(mean=0.0, std=0.02) 47 | # self.fc.bias.data.zero_() 48 | 49 | def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None): 50 | # shape: (batch_size, seq_length, hidden_size) 51 | outputs = self.bert(input_ids, token_type_ids, attention_mask) 52 | seq_output = outputs[0] 53 | # shape: (batch_size, seq_length, hidden_size) 54 | seq_output = self.dropout(seq_output) 55 | # shape: (batch_size, seq_length, num_labels) 56 | fc_output = self.fc(seq_output) 57 | 58 | if attention_mask is None: 59 | pred = self.crf.decode(fc_output) 60 | else: 61 | pred = self.crf.decode(fc_output, attention_mask) 62 | 63 | if labels is None: 64 | return pred 65 | else: 66 | if attention_mask is not None: 67 | # shape: (1,) 68 | loss = -self.crf(fc_output, labels, attention_mask, reduction='token_mean') 69 | else: 70 | loss = -self.crf(fc_output, labels, reduction='token_mean') 71 | return pred, loss 72 | -------------------------------------------------------------------------------- /ckbqa/models/relation_score/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WangShengguang/ccks-2020/d77f2de91284efb11e314fea21b4a3982ab78554/ckbqa/models/relation_score/__init__.py -------------------------------------------------------------------------------- /ckbqa/models/relation_score/model.py: -------------------------------------------------------------------------------- 1 | """ 2 | 主实体关联的关系打分,得到topn 3 | """ 4 | import torch 5 | from 
pytorch_transformers import BertModel 6 | from torch import nn 7 | 8 | from config import Config 9 | 10 | 11 | class BertMatch(nn.Module): 12 | def __init__(self): 13 | super().__init__() 14 | self.bert = BertModel.from_pretrained(Config.pretrained_model_name_or_path) 15 | self.soft_max = nn.Softmax(dim=-1) 16 | self.fc = nn.Linear(768 * 3, 2) 17 | self.dropout = nn.Dropout(0.8) 18 | self.loss = nn.CrossEntropyLoss() 19 | 20 | def encode(self, x): 21 | outputs = self.bert(x) 22 | pooled_out = outputs[1] 23 | return pooled_out 24 | 25 | def forward(self, input1_ids, input2_ids, labels=None): 26 | # shape: (batch_size, seq_length, hidden_size) 27 | seq1_output = self.encode(input1_ids) 28 | seq2_output = self.encode(input2_ids) 29 | feature = torch.cat([seq1_output, seq2_output, seq1_output - seq2_output], dim=-1) 30 | logistic = self.fc(feature) 31 | if labels is None: 32 | output = self.soft_max(logistic) 33 | # pred = torch.argmax(output, dim=-1) 34 | return output 35 | else: 36 | loss = self.loss(logistic, labels) 37 | return logistic, loss 38 | 39 | 40 | class BertMatch2(nn.Module): 41 | def __init__(self): 42 | super().__init__() 43 | self.bert = BertModel.from_pretrained(Config.pretrained_model_name_or_path) 44 | self.soft_max = nn.Softmax(dim=-1) 45 | self.fc1 = nn.Linear(768 * 3, 128) 46 | self.fc2 = nn.Linear(128, 2) 47 | self.dropout1 = nn.Dropout(0.8) 48 | self.dropout2 = nn.Dropout(0.8) 49 | self.loss = nn.CrossEntropyLoss() 50 | 51 | def encode(self, x): 52 | outputs = self.bert(x) 53 | pooled_out = outputs[1] 54 | return pooled_out 55 | 56 | def forward(self, input1_ids, input2_ids, labels=None): 57 | # shape: (batch_size, seq_length, hidden_size) 58 | seq1_output = self.encode(input1_ids) 59 | seq2_output = self.encode(input2_ids) 60 | feature = torch.cat([seq1_output, seq2_output, seq1_output - seq2_output], dim=-1) 61 | feature1 = self.fc1(feature) 62 | logistic = self.fc2(feature1) 63 | if labels is None: 64 | output = self.soft_max(logistic) 65 | # pred = torch.argmax(output, dim=-1) 66 | return output 67 | else: 68 | loss = self.loss(logistic, labels) 69 | return logistic, loss 70 | -------------------------------------------------------------------------------- /ckbqa/models/relation_score/predictor.py: -------------------------------------------------------------------------------- 1 | """现有模型封装;提供给预测使用""" 2 | import logging 3 | from typing import List 4 | 5 | from ckbqa.models.data_helper import DataHelper 6 | from ckbqa.models.relation_score.model import BertMatch, BertMatch2 7 | from ckbqa.utils.saver import Saver 8 | 9 | 10 | class RelationScorePredictor(object): 11 | def __init__(self, model_name): 12 | logging.info(f'BertMatchPredictor loaded sim_model init ...') 13 | self.data_helper = DataHelper(load_tokenizer=True) 14 | self.model = self.load_sim_model(model_name) 15 | logging.info(f'loaded sim_model: {model_name}') 16 | 17 | def load_sim_model(self, model_name: str): 18 | """文本相似度模型""" 19 | assert model_name in ['bert_match', 'bert_match2'] 20 | saver = Saver(model_name=model_name) 21 | if model_name == 'bert_match': 22 | model = BertMatch() 23 | elif model_name == 'bert_match2': 24 | model = BertMatch2() 25 | else: 26 | raise ValueError() 27 | saver.load_model(model) 28 | return model 29 | 30 | def iter_sample(self, q_text, sim_texts): 31 | batch_size = 32 32 | q_ids = self.data_helper.sent2ids(q_text) 33 | batch_q_sent_token_ids = [q_ids for _ in range(min(batch_size, len(sim_texts)))] 34 | q_tensors = self.data_helper.data2tensor(batch_q_sent_token_ids) 35 
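# The question token ids are tiled up to batch_size so that every (question, candidate)
# pair can be scored in a single forward pass; the candidate sentences are chunked below,
# and the question tensor is sliced to match the final, possibly shorter, chunk.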
| batch_sim_texts = [sim_texts[i:i + batch_size] for i in range(0, len(sim_texts), batch_size)] 36 | for batch_sim_text in batch_sim_texts: 37 | sim_sent_token_ids = [self.data_helper.sent2ids(sim_text) for sim_text in batch_sim_text] 38 | sim_tensors = self.data_helper.data2tensor(sim_sent_token_ids) 39 | yield q_tensors[:len(sim_tensors)], sim_tensors 40 | 41 | def predict(self, q_text: str, sim_texts: List[str]): 42 | all_preds = [] 43 | for q_tensors, sim_tensors in self.iter_sample(q_text, sim_texts): 44 | preds = self.model(q_tensors, sim_tensors) 45 | preds = preds[:, 1] 46 | all_preds.extend(preds.tolist()) 47 | return all_preds 48 | -------------------------------------------------------------------------------- /ckbqa/models/relation_score/trainer.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import random 4 | import re 5 | 6 | import numpy as np 7 | import pandas as pd 8 | from sklearn.model_selection import train_test_split 9 | 10 | from ckbqa.dataset.data_prepare import load_data, question_patten 11 | from ckbqa.models.base_trainer import BaseTrainer 12 | from ckbqa.models.data_helper import DataHelper 13 | from ckbqa.models.relation_score.model import BertMatch, BertMatch2 14 | from ckbqa.utils.saver import Saver 15 | from ckbqa.utils.tools import json_load 16 | from config import Config 17 | 18 | 19 | class RelationScoreTrainer(BaseTrainer): 20 | """关系打分模块训练""" 21 | 22 | def __init__(self, model_name): 23 | super().__init__(model_name) 24 | self.model_name = model_name 25 | self.data_helper = DataHelper() 26 | self.saver = Saver(model_name=model_name) 27 | 28 | def data2samples(self, neg_rate=3, test_size=0.1): 29 | if os.path.isfile(Config.get_relation_score_sample_csv('train', neg_rate)): 30 | return 31 | questions = [] 32 | sim_questions = [] 33 | labels = [] 34 | all_relations = list(json_load(Config.relation2id)) 35 | _entity_pattern = re.compile(r'["<](.*?)[>"]') 36 | for q, sparql, a in load_data(tqdm_prefix='data2samples '): 37 | q_text = question_patten.findall(q)[0] 38 | q_entities = _entity_pattern.findall(sparql) 39 | questions.append(q_text) 40 | sim_questions.append('的'.join(q_entities)) 41 | labels.append(1) 42 | # 43 | for neg_relation in random.sample(all_relations, neg_rate): 44 | questions.append(q_text) 45 | neg_question = '的'.join(q_entities[:-1] + [neg_relation]) # 随机替换 <关系> 46 | sim_questions.append(neg_question) 47 | labels.append(0) 48 | data_df = pd.DataFrame({'question': questions, 'sim_question': sim_questions, 'label': labels}) 49 | data_df.to_csv(Config.relation_score_sample_csv, encoding='utf_8_sig', index=False) 50 | train_df, test_df = train_test_split(data_df, test_size=test_size) 51 | test_df.to_csv(Config.get_relation_score_sample_csv('test', neg_rate), encoding='utf_8_sig', index=False) 52 | train_df.to_csv(Config.get_relation_score_sample_csv('train', neg_rate), encoding='utf_8_sig', index=False) 53 | 54 | def batch_iter(self, data_type, batch_size, _shuffle=True, fixed_seq_len=None): 55 | file_path = Config.get_relation_score_sample_csv(data_type=data_type, neg_rate=3) 56 | logging.info(f'* load data from {file_path}') 57 | data_df = pd.read_csv(file_path) 58 | x_data = data_df['question'].apply(self.data_helper.sent2ids) 59 | y_data = data_df['sim_question'].apply(self.data_helper.sent2ids) 60 | labels = data_df['label'] 61 | data_size = len(x_data) 62 | order = list(range(data_size)) 63 | if _shuffle: 64 | np.random.shuffle(order) 65 | for batch_step in 
range(data_size // batch_size + 1): 66 | batch_idxs = order[batch_step * batch_size:(batch_step + 1) * batch_size] 67 | if len(batch_idxs) != batch_size: # batch size 不可过大; 不足batch_size的数据丢弃(最后一batch) 68 | continue 69 | q_sents = self.data_helper.data2tensor([x_data[idx] for idx in batch_idxs]) 70 | a_sents = self.data_helper.data2tensor([y_data[idx] for idx in batch_idxs]) 71 | batch_labels = self.data_helper.data2tensor([labels[idx] for idx in batch_idxs], pad=False) 72 | yield q_sents, a_sents, batch_labels 73 | 74 | def train_match_model(self, mode='train'): 75 | """ 76 | :mode: train,test 77 | """ 78 | self.data2samples(neg_rate=3, test_size=0.1) 79 | if self.model_name == 'bert_match': # 单层 80 | model = BertMatch() 81 | elif self.model_name == 'bert_match2': # 双层 82 | model = BertMatch2() 83 | else: 84 | raise ValueError() 85 | for para in model.bert.parameters(): 86 | para.requires_grad = False # bert参数不训练 87 | model = self.init_model(model) 88 | model_path, epoch, step = self.saver.load_model(model, fail_ok=True) 89 | self.global_step = step 90 | for q_sents, a_sents, batch_labels in self.batch_iter( 91 | data_type='train', batch_size=32): 92 | pred, loss = model(q_sents, a_sents, batch_labels) 93 | # print(f' {str(datetime.datetime.now())[:19]} global_step: {self.global_step}, loss:{loss.item():.04f}') 94 | logging.info(f' global_step: {self.global_step}, loss:{loss.item():.04f}') 95 | self.backfoward(loss, model) 96 | if self.global_step % 100 == 0: 97 | model_path = self.saver.save(model, epoch=1, step=self.global_step) 98 | logging.info(f'save to {model_path}') 99 | # if mode == 'test': 100 | # output = torch.softmax(pred, dim=-1) 101 | # pred = torch.argmax(output, dim=-1) 102 | # print('labels: ', batch_labels) 103 | # print('pred: ', pred.tolist()) 104 | # print('--' * 10 + '\n\n') 105 | -------------------------------------------------------------------------------- /ckbqa/qa/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | 不做模型训练 3 | 4 | 将其他模型训练完后组合应用到这里,提取答案 5 | """ 6 | -------------------------------------------------------------------------------- /ckbqa/qa/algorithms.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import re 3 | from typing import Set 4 | 5 | legal_char_partten = re.compile(u"([^\u4e00-\u9fa5\u0030-\u0039\u0041-\u005a\u0061-\u007a])") # 中英文和数字 6 | 7 | 8 | def sequences_set_similar(s1: Set, s2: Set): 9 | overlap = len(s1 & s2) 10 | jaccard = overlap / len(s1 | s2) # 集合距离 11 | # 词向量相似度 12 | # wordvecsim = model.similarity(''.join(s1),''.join(s2)) 13 | return overlap, jaccard 14 | 15 | 16 | class Algorithms(object): 17 | def __init__(self): 18 | self.invalid_names = {'是什么', ''} # 汽车的动力是什么 19 | 20 | def get_most_overlap_path(self, q_text, paths): 21 | # 从排名前几的tuples里选择与问题overlap最多的 22 | max_score = 0 23 | top_path = [] 24 | q_words = set(q_text) 25 | pure_qtext = legal_char_partten.sub('', q_text) 26 | for path in paths: 27 | path_words = set() 28 | score5 = 0 29 | for ent_rel in path: 30 | if ent_rel.startswith('<'): 31 | ent_rel = ent_rel[1:-1].split('_')[0] 32 | path_words.update(set(ent_rel)) 33 | pure_ent = legal_char_partten.sub('', ent_rel) 34 | if pure_ent not in self.invalid_names and pure_ent in pure_qtext: 35 | score5 += 1 36 | # print(pure_ent) 37 | common_words = path_words & q_words # 交集 38 | score1 = len(common_words) # 主score,交集长度 ;权重排名:1>2>3>4 39 | score2 = 1 / len(path_words) # 候选词长度越短越好,权重4 40 | score3 = 1 / len(path) # 
跳数越少越好,权重2 41 | score4 = 0.5 if path[-1].startswith('<') else 0 # 实体大于属性,实体加分,权重3;会造成越长分数越高 42 | score = score1 + score2 + score3 + score4 + score5 43 | # log_str = (f'score : {score:.3f}, score1 : {score1}, score2 : {score2:.3f}, ' 44 | # f'score3 : {score3:.3f}, score4 : {score4}, score5: {score5}, ' 45 | # f'path: {path}') 46 | # logging.info(log_str) 47 | # print(log_str) 48 | if score > max_score: 49 | top_path = path 50 | max_score = score 51 | # logging.info(f'score: {max_score}, top_path: {top_path}') 52 | return top_path, max_score 53 | 54 | 55 | if __name__ == '__main__': 56 | # q_text = '怎样的检查项目能对小儿多源性房性心动过速、急性肾>功能不全以及动静脉血管瘤做出检测?' 57 | # out_paths = [['<肾功能不全>', '<>来源>']] 58 | # in_paths = [['肾功能', '<检查项目>', '<标签>']] 59 | # q_text = '演员梅艳芳有多高?' 60 | # out_paths = [['<梅艳芳>', '<身高>']] # 6.2 61 | # in_paths = [['<梅艳芳>', '<演员>']] # 8.2 62 | q_text = '叶文洁毕业于哪个大学?' 63 | out_paths = [['<叶文洁>', '<毕业院校>', '<学校代码>']] # 7.9333 64 | in_paths = [['<大学>', '<毕业于>', '<类型>']] # 7.9762 65 | algo = Algorithms() 66 | algo.get_most_overlap_path(q_text, out_paths) 67 | algo.get_most_overlap_path(q_text, in_paths) 68 | -------------------------------------------------------------------------------- /ckbqa/qa/cache.py: -------------------------------------------------------------------------------- 1 | from ckbqa.utils.decorators import singleton 2 | from ckbqa.utils.tools import json_load 3 | from config import Config 4 | 5 | 6 | @singleton 7 | class Memory(object): 8 | def __init__(self): 9 | self.entity2id = json_load(Config.entity2id) 10 | self.mention2entity = json_load(Config.mention2ent_json) 11 | # self.all_attrs = set(json_load(Config.all_attrs_json)) 12 | 13 | def get_entity_id(self, ent_name): 14 | if ent_name in self.entity2id: 15 | ent_id = self.entity2id[ent_name] 16 | else: 17 | # ent_name = ''.join(ent_patten.findall(ent_name)) 18 | ent_id = self.entity2id.get(ent_name[1:-1], 0) 19 | return ent_id 20 | -------------------------------------------------------------------------------- /ckbqa/qa/el.py: -------------------------------------------------------------------------------- 1 | """实体链接模块""" 2 | from collections import defaultdict 3 | from typing import List, Dict 4 | 5 | import numpy as np 6 | 7 | from ckbqa.models.entity_score.model import EntityScore 8 | from ckbqa.qa.algorithms import sequences_set_similar 9 | from ckbqa.qa.cache import Memory 10 | from ckbqa.qa.lac_tools import Ngram, BaiduLac 11 | from ckbqa.qa.neo4j_graph import GraphDB 12 | from ckbqa.utils.decorators import singleton 13 | 14 | 15 | @singleton 16 | class CEG(object): 17 | """ Candidate Entity Generation: NER 18 | 从mention出发,找到KB中所有可能的实体,组成候选实体集 (candidate entities); 19 | 最主流也最有效的方法就是Name Dictionary,说白了就是配别名 20 | """ 21 | 22 | def __init__(self): 23 | # self.jieba = JiebaLac(load_custom_dict=True) 24 | self.lac = BaiduLac(mode='lac', _load_customization=True) 25 | self.memory = Memory() 26 | self.ngram = Ngram() 27 | self.mention_stop_words = {'是什么', '在哪里', '哪里', '什么', '提出的', '有什么', '国家', '哪个', '所在', 28 | '培养出', '为什么', '什么时候', '人', '你知道', '都包括', '是谁', '告诉我', 29 | '又叫做'} 30 | self.entity_stop_words = {'有', '是', '《', '》', '率', '·'} 31 | 32 | self.main_tags = {'n', 'an', 'nw', 'nz', 'PER', 'LOC', 'ORG', 'TIME', 33 | 'vn', 'v'} 34 | 35 | def get_ent2mention(self, q_text: str) -> dict: 36 | ''' 37 | 标签 含义 标签 含义 标签 含义 标签 含义 38 | n 普通名词 f 方位名词 s 处所名词 nw 作品名 39 | nz 其他专名 v 普通动词 vd 动副词 vn 名动词 40 | a 形容词 ad 副形词 an 名形词 d 副词 41 | m 数量词 q 量词 r 代词 p 介词 42 | c 连词 u 助词 xc 其他虚词 w 标点符号 43 | PER 人名 LOC 地名 ORG 机构名 TIME 时间 44 | 
''' 45 | entity2mention = {} 46 | words, tags = self.lac.run(q_text) 47 | main_text = ''.join([word for word, tag in zip(words, tags) if tag in self.main_tags]) 48 | main_text_ngrams = list(self.ngram.get_all_grams(main_text)) 49 | q_text_grams = list(self.ngram.get_all_grams(q_text)) 50 | for gram in set(q_text_grams + main_text_ngrams): 51 | if f"<{gram}>" in self.memory.entity2id: 52 | entity2mention[f'<{gram}>'] = gram 53 | if gram in self.memory.entity2id: 54 | entity2mention[gram] = gram 55 | if gram in self.memory.mention2entity and gram not in self.mention_stop_words: 56 | for ent in self.memory.mention2entity[gram]: 57 | if f"<{ent}>" in self.memory.entity2id: 58 | entity2mention[f'<{ent}>'] = gram 59 | if ent in self.memory.entity2id: 60 | entity2mention[ent] = gram 61 | # print(gram) 62 | return entity2mention 63 | 64 | def seg_text(self, text: str) -> List[str]: 65 | return self.lac.run(text)[0] 66 | 67 | # def _get_ent2mention(self, text: str) -> Dict: 68 | # entity2mention = {} 69 | # lac_words, tags = self.lac.run(text) 70 | # jieba_words = [w for w in self.jieba.cut_for_search(text)] 71 | # for word in set(lac_words + jieba_words): 72 | # # if word in self.entity_stop_words: 73 | # # continue 74 | # if word in self.memory.mention2entity and word not in self.mention_stop_words: 75 | # for ent in self.memory.mention2entity[word]: 76 | # if f"<{ent}>" in self.memory.entity2id: 77 | # entity2mention[f'<{ent}>'] = word 78 | # if ent in self.memory.entity2id: 79 | # entity2mention[ent] = word 80 | # if f"<{word}>" in self.memory.entity2id: 81 | # entity2mention[f'<{word}>'] = word 82 | # if word in self.memory.entity2id: 83 | # entity2mention[word] = word 84 | # return entity2mention 85 | 86 | 87 | @singleton 88 | class ED(object): 89 | """ Entity Disambiguation: 90 | 从candidate entities中,选择最可能的实体作为预测实体。 91 | """ 92 | 93 | def __init__(self): 94 | self.memory = Memory() 95 | self.graph_db = GraphDB() 96 | self.ceg = CEG() 97 | self.mention_stop_words = {'是什么', '在哪里', '哪里', '什么', '提出的', '有什么', '国家', '哪个', '所在', 98 | '培养出', '为什么', '什么时候', '人', '你知道', '都包括', '是谁', '告诉我', 99 | '又叫做'} 100 | self.entity_stop_words = {'有', '是', '《', '》', '率', '·'} 101 | # self.entity_score = EntityScore() 102 | self.entity_score = EntityScore(load_pretrain_model=True) 103 | 104 | def ent_rel_similar(self, question: str, entity: str, relations: List) -> List: 105 | ''' 106 | 抽取每个实体或属性值2hop内的所有关系,来跟问题计算各种相似度特征 107 | input: 108 | question: python-str 109 | entity: python-str 110 | relations: python-dic key: 111 | output: 112 | [word_overlap,char_overlap,word_embedding_similarity,char_overlap_ratio] 113 | ''' 114 | # 得到主语-谓词的tokens及chars 115 | rel_tokens = set() 116 | for rel_name in set(relations): 117 | rel_tokens.update(self.ceg.seg_text(rel_name[1:-1])) 118 | rel_chars = set([char for rel_name in relations for char in rel_name]) 119 | # 120 | question_tokens = set(self.ceg.seg_text(question)) 121 | question_chars = set(question) 122 | # 123 | entity_tokens = set(self.ceg.seg_text(entity[1:-1])) 124 | entity_chars = set(entity[1:-1]) 125 | # 126 | qestion_ent_sim = (sequences_set_similar(question_tokens, entity_tokens) + 127 | sequences_set_similar(question_chars, entity_chars)) 128 | qestion_rel_sim = (sequences_set_similar(question_tokens, rel_tokens) + 129 | sequences_set_similar(question_chars, rel_chars)) 130 | # 实体名和问题的overlap除以实体名长度的比例 131 | return qestion_ent_sim + qestion_rel_sim 132 | 133 | def get_entity_feature(self, q_text: str, ent_name: str): 134 | # 得到实体两跳内的所有关系 135 | 
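# This yields 8 similarity values -- (overlap, jaccard) for question-vs-entity and
# question-vs-relations, at both token and character level -- plus one popularity value
# (the 1-hop relation count); get_candidate_entities concatenates them (popularity
# square-rooted) into the candidate's feature vector.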
one_hop_out_rel_names = self.graph_db.get_onehop_relations_by_entName(ent_name, direction='out') 136 | two_hop_out_rel_names = self.graph_db.get_twohop_relations_by_entName(ent_name, direction='out') 137 | out_rel_names = [rel for rels in two_hop_out_rel_names for rel in rels] + one_hop_out_rel_names 138 | # 139 | one_hop_in_rel_names = self.graph_db.get_onehop_relations_by_entName(ent_name, direction='in') 140 | two_hop_in_rel_names = self.graph_db.get_twohop_relations_by_entName(ent_name, direction='in') 141 | in_rel_names = [rel for rels in two_hop_in_rel_names for rel in rels] + one_hop_in_rel_names 142 | # 计算问题和主语实体及其两跳内关系间的相似度 143 | similar_feature = self.ent_rel_similar(q_text, ent_name, out_rel_names + in_rel_names) 144 | # 实体的流行度特征 145 | popular_feature = self.graph_db.get_onehop_relCount_by_entName(ent_name) 146 | return similar_feature, popular_feature 147 | 148 | def get_candidate_entities(self, q_text: str, entity2mention: dict) -> Dict: 149 | """ 150 | :param q_text: 151 | :return: candidate_subject: { ent_name: {'mention':mention_txt, 152 | 'feature': [feature1, feature2, ...] 153 | }, 154 | ... 155 | } 156 | """ 157 | candidate_subject = defaultdict(dict) # {ent:[mention, feature1, feature2, ...]} 158 | # 159 | for entity, mention in entity2mention.items(): 160 | similar_feature, popular_feature = self.get_entity_feature(q_text, entity) 161 | candidate_subject[entity]['mention'] = mention 162 | candidate_subject[entity]['feature'] = [sim for sim in similar_feature] + [popular_feature ** 0.5] 163 | 164 | # subject_score_topn打分做筛选 165 | top_subjects = self.subject_score_topn(candidate_subject) 166 | return top_subjects 167 | 168 | def subject_score_topn(self, candidate_entities: dict, top_k=10): 169 | ''' 170 | :candidate_entities {ent: {'mention':mention}, 171 | 'feature':[...], 172 | ent2:..., 173 | ... 
174 | } 175 | 输入候选主语和对应的特征,使用训练好的模型进行打分,排序后返回前topn个候选主语 176 | ''' 177 | top_k = len(candidate_entities) # TODO 后需删除;不做筛选,保留所有实体;目前筛选效果不好; 178 | if top_k >= len(candidate_entities): # 太少则不作筛选 179 | return candidate_entities 180 | entities = [] 181 | features = [] 182 | for ent, feature_dic in candidate_entities.items(): 183 | entities.append(ent) 184 | features.append(feature_dic['feature']) 185 | scores = self.entity_score.predict(features) 186 | # print(f'scores: {scores}') 187 | top_k_indexs = np.asarray(scores).argsort()[-top_k:][::-1] 188 | top_k_entities = {entities[i]: candidate_entities[entities[i]] 189 | for i in top_k_indexs} 190 | return top_k_entities 191 | 192 | 193 | @singleton 194 | class EL(object): 195 | """实体链接""" 196 | 197 | def __init__(self): 198 | self.ceg = CEG() # Candidate Entity Generation 199 | self.ed = ED() # Entity Disambiguation 200 | 201 | def el(self, q_text: str) -> Dict: 202 | ent2mention = self.ceg.get_ent2mention(q_text) 203 | candidate_entities = self.ed.get_candidate_entities(q_text, ent2mention) 204 | return candidate_entities 205 | 206 | # try 207 | # except: 208 | # import traceback,ipdb 209 | # traceback.print_exc() 210 | # ipdb.set_trace() 211 | -------------------------------------------------------------------------------- /ckbqa/qa/lac_tools.py: -------------------------------------------------------------------------------- 1 | import gc 2 | import logging 3 | import os 4 | 5 | import jieba 6 | from LAC import LAC 7 | from LAC.ahocorasick import Ahocorasick 8 | from tqdm import tqdm 9 | 10 | from ckbqa.utils.decorators import singleton 11 | from ckbqa.utils.tools import tqdm_iter_file, pkl_dump, json_load, json_dump 12 | from config import Config 13 | 14 | 15 | class Ngram(object): 16 | def __init__(self): 17 | pass 18 | 19 | def ngram(self, text, n): 20 | ngrams = [text[i:i + n] for i in range(len(text) - n + 1)] 21 | return ngrams 22 | 23 | def get_all_grams(self, text): 24 | for n in range(2, len(text)): 25 | for ngram in self.ngram(text, n): 26 | yield ngram 27 | # for gram in ngrams: 28 | # yield gram 29 | 30 | 31 | @singleton 32 | class JiebaLac(object): 33 | def __init__(self, load_custom_dict=True): 34 | # jieba.enable_paddle() # 启动paddle模式。 0.40版之后开始支持,早期版本不支持 35 | if load_custom_dict: 36 | self.load_custom_dict() 37 | 38 | def load_custom_dict(self): 39 | if not os.path.isfile(Config.jieba_custom_dict): 40 | all_ent_mention_words = set(json_load(Config.mention2count_json).keys()) 41 | entity_set = set(json_load(Config.entity2id).keys()) 42 | for ent in tqdm(entity_set, total=len(entity_set), desc='create jieba words '): 43 | if ent.startswith('<'): 44 | ent = ent[1:-1] 45 | all_ent_mention_words.add(ent) 46 | # FREQ,total, 47 | # 模仿jieba.add_word,重写逻辑,加速 48 | # jieba.dt.FREQ = {} 49 | # jieba.dt.total = 0 50 | for word in tqdm(all_ent_mention_words, desc='jieba custom create '): 51 | freq = len(word) * 3 52 | jieba.dt.FREQ[word] = freq 53 | jieba.dt.total += freq 54 | for ch in range(len(word)): 55 | wfrag = word[:ch + 1] 56 | if wfrag not in jieba.dt.FREQ: 57 | jieba.dt.FREQ[wfrag] = 0 58 | del all_ent_mention_words 59 | gc.collect() 60 | json_dump({'dt.FREQ': jieba.dt.FREQ, 'dt.total': jieba.dt.total}, 61 | Config.jieba_custom_dict) 62 | logging.info('create jieba custom dict done ...') 63 | # load 64 | jieba_custom = json_load(Config.jieba_custom_dict) 65 | jieba.dt.check_initialized() 66 | jieba.dt.FREQ = jieba_custom['dt.FREQ'] 67 | jieba.dt.total = jieba_custom['dt.total'] 68 | logging.info('load jieba custom dict done ...') 69 
| 70 | def cut_for_search(self, text): 71 | return jieba.cut_for_search(text) 72 | 73 | def cut(self, text): 74 | return jieba.cut(text) 75 | 76 | 77 | @singleton 78 | class BaiduLac(LAC): 79 | def __init__(self, model_path=None, mode='lac', use_cuda=False, 80 | _load_customization=True, reload=False, 81 | customization_path=Config.lac_custom_dict_txt): 82 | super().__init__(model_path=model_path, mode=mode, use_cuda=use_cuda) 83 | self.mode = mode # lac,seg 84 | self.reload = reload 85 | self.customization_path = customization_path 86 | if _load_customization: 87 | self._load_customization() 88 | 89 | def _save_customization(self): 90 | # lac = LAC(mode=self.mode) # 装载LAC模型,LAC(mode='seg') 91 | # lac.load_customization(self.customization_path) # 35万2min;100万>20min; 92 | parms_dict = {'dictitem': self.custom.dictitem, 93 | # 'ac': self.custom.ac, 94 | 'mode': self.mode} 95 | pkl_dump(parms_dict, Config.lac_model_pkl) 96 | return Config.lac_model_pkl 97 | 98 | def _load_customization(self): 99 | logging.info('暂停载入BaiduLac自定词典 ...') 100 | return 101 | # self.custom = Customization() 102 | # if self.reload or not os.path.isfile(Config.lac_model_pkl): 103 | # self.custom.load_customization(self.customization_path) # 35万2min;100万,20min; 104 | # self._save_customization() 105 | # else: 106 | # logging.info('load lac customization start ...') 107 | # parms_dict = pkl_load(Config.lac_model_pkl) 108 | # self.custom.dictitem = parms_dict['dictitem'] # dict 109 | # self.custom.ac = Ahocorasick() 110 | # for phrase in tqdm(parms_dict['dictitem'], desc='load baidu lac '): 111 | # self.custom.ac.add_word(phrase) 112 | # self.custom.ac.make() 113 | # logging.info('loaded lac customization Done ...') 114 | 115 | 116 | class Customization(object): 117 | """ 118 | 基于AC自动机实现用户干预的功能 119 | """ 120 | 121 | def __init__(self): 122 | self.dictitem = {} 123 | self.ac = None 124 | pass 125 | 126 | def load_customization(self, filename): 127 | """装载人工干预词典""" 128 | self.ac = Ahocorasick() 129 | 130 | for line in tqdm_iter_file(filename, prefix='load_customization '): 131 | words = line.strip().split() 132 | if len(words) == 0: 133 | continue 134 | 135 | phrase = "" 136 | tags = [] 137 | offset = [] 138 | for word in words: 139 | if word.rfind('/') < 1: 140 | phrase += word 141 | tags.append('') 142 | else: 143 | phrase += word[:word.rfind('/')] 144 | tags.append(word[word.rfind('/') + 1:]) 145 | offset.append(len(phrase)) 146 | 147 | if len(phrase) < 2 and tags[0] == '': 148 | continue 149 | 150 | self.dictitem[phrase] = (tags, offset) 151 | self.ac.add_word(phrase) 152 | self.ac.make() 153 | 154 | def parse_customization(self, query, lac_tags): 155 | """使用人工干预词典修正lac模型的输出""" 156 | 157 | def ac_postpress(ac_res): 158 | ac_res.sort() 159 | i = 1 160 | while i < len(ac_res): 161 | if ac_res[i - 1][0] < ac_res[i][0] and ac_res[i][0] <= ac_res[i - 1][1]: 162 | ac_res.pop(i) 163 | continue 164 | i += 1 165 | return ac_res 166 | 167 | if not self.ac: 168 | logging.warning("customization dict is not load") 169 | return 170 | 171 | ac_res = self.ac.search(query) 172 | 173 | ac_res = ac_postpress(ac_res) 174 | 175 | for begin, end in ac_res: 176 | phrase = query[begin:end + 1] 177 | index = begin 178 | 179 | tags, offsets = self.dictitem[phrase] 180 | for tag, offset in zip(tags, offsets): 181 | while index < begin + offset: 182 | if len(tag) == 0: 183 | lac_tags[index] = lac_tags[index][:-1] + 'I' 184 | else: 185 | lac_tags[index] = tag + "-I" 186 | index += 1 187 | 188 | lac_tags[begin] = lac_tags[begin][:-1] + 'B' 189 | 
for offset in offsets: 190 | index = begin + offset 191 | if index < len(lac_tags): 192 | lac_tags[index] = lac_tags[index][:-1] + 'B' 193 | -------------------------------------------------------------------------------- /ckbqa/qa/neo4j_graph.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | import logging 3 | import os 4 | import time 5 | 6 | from py2neo import Graph 7 | 8 | from ckbqa.qa.cache import Memory 9 | from ckbqa.utils.async_tools import apply_async 10 | from ckbqa.utils.decorators import singleton, synchronized 11 | from ckbqa.utils.tools import json_load, json_dump 12 | from config import Config 13 | 14 | graph = Graph('http://localhost:7474', username='neo4j', password='password') # bolt://localhost:7687 15 | 16 | 17 | @singleton 18 | class GraphDB(object): 19 | def __init__(self): 20 | self._one_hop_relNames_map = {'in': {}, 'out': {}} 21 | self._two_hop_relNames_map = {'in': {}, 'out': {}} 22 | self.all_directions = {'in', 'out'} 23 | self.load_cache() 24 | self.memory = Memory() 25 | self.async_cache() 26 | 27 | def __del__(self): 28 | # time.sleep(30) # 等待另一个线程写入完成,避免交叉影响;使用线程锁替代 29 | self.cache() # 保证程序结束前会正确保存文件 30 | 31 | def get_total_entity_count(self): 32 | count = (len(self._one_hop_relNames_map['out']) + 33 | len(self._one_hop_relNames_map['in']) + 34 | len(self._two_hop_relNames_map['out']) + 35 | len(self._two_hop_relNames_map['in'])) 36 | return count 37 | 38 | def load_cache(self): 39 | # self.queue = Queue(maxsize=1) 40 | if os.path.isfile(Config.neo4j_query_cache): 41 | data_dict = json_load(Config.neo4j_query_cache) 42 | self._one_hop_relNames_map = data_dict['_one_hop_relNames_map'] 43 | self._two_hop_relNames_map = data_dict['_two_hop_relNames_map'] 44 | self.total_count = self.get_total_entity_count() 45 | logging.info(f'load neo4j_query_cache, total: {self.total_count}') 46 | else: 47 | logging.info(f'not found neo4j_query_cache: {Config.neo4j_query_cache}') 48 | self.total_count = 0 49 | 50 | # def put_cache_sign(self): 51 | # self.queue.put(1) # 每次结束后检查 52 | 53 | def async_cache(self): 54 | def loop_check_cache(): 55 | while True: # self.queue 56 | self.cache() 57 | time.sleep(10 * 60) 58 | 59 | thr = apply_async(loop_check_cache, daemon=True) 60 | return thr 61 | 62 | @synchronized # 避免多线程交叉影响 63 | def cache(self): 64 | total = self.get_total_entity_count() 65 | if total > self.total_count: 66 | parms = {'_one_hop_relNames_map': self._one_hop_relNames_map, 67 | '_two_hop_relNames_map': self._two_hop_relNames_map} 68 | json_dump(parms, Config.neo4j_query_cache) 69 | logging.info(f'neo4j_query_cache , increase: {total - self.total_count}, total: {total}') 70 | self.total_count = total 71 | 72 | def get_onehop_relations_by_entName(self, ent_name: str, direction='out'): 73 | assert direction in self.all_directions 74 | if ent_name not in self._one_hop_relNames_map[direction]: 75 | ent_id = self.memory.get_entity_id(ent_name) 76 | start, end = ('ent:Entity', '') if direction == 'out' else ('', 'ent:Entity') 77 | cpl = f"MATCH p=({start})-[r1:Relation]->({end}) WHERE ent.id={ent_id} RETURN DISTINCT r1.name" 78 | _one_hop_relNames = [rel.data()['r1.name'] for rel in graph.run(cpl)] 79 | self._one_hop_relNames_map[direction][ent_name] = _one_hop_relNames 80 | rel_names = self._one_hop_relNames_map[direction][ent_name] 81 | return rel_names 82 | 83 | def get_twohop_relations_by_entName(self, ent_name: str, direction='out'): 84 | assert direction in self.all_directions 85 | if ent_name not in 
self._two_hop_relNames_map[direction]: 86 | ent_id = self.memory.get_entity_id(ent_name) 87 | start, end = ('ent:Entity', '') if direction == 'out' else ('', 'ent:Entity') 88 | cql = f"MATCH p=({start})-[r1:Relation]->()-[r2:Relation]->({end}) WHERE ent.id={ent_id} RETURN DISTINCT r1.name,r2.name" 89 | rels = [rel.data() for rel in graph.run(cql)] 90 | rel_names = [[d['r1.name'], d['r2.name']] for d in rels] 91 | self._two_hop_relNames_map[direction][ent_name] = rel_names 92 | rel_names = self._two_hop_relNames_map[direction][ent_name] 93 | return rel_names 94 | 95 | def get_onehop_relCount_by_entName(self, ent_name: str): 96 | '''根据实体名,得到与之相连的关系数量,代表实体在知识库中的流行度''' 97 | if ent_name not in self._one_hop_relNames_map['out']: 98 | rel_names = self.get_onehop_relations_by_entName(ent_name, direction='out') 99 | self._one_hop_relNames_map['out'][ent_name] = rel_names 100 | if ent_name not in self._one_hop_relNames_map['in']: 101 | rel_names = self.get_onehop_relations_by_entName(ent_name, direction='in') 102 | self._one_hop_relNames_map['in'][ent_name] = rel_names 103 | count = len(self._one_hop_relNames_map['out'][ent_name]) + len(self._one_hop_relNames_map['in'][ent_name]) 104 | return count 105 | 106 | def search_by_2path(self, ent_name, rel_name, direction='out'): 107 | ent_id = self.memory.get_entity_id(ent_name) 108 | assert direction in self.all_directions 109 | start, end = ('ent:Entity', 'target') if direction == 'out' else ('target', 'ent:Entity') 110 | cql = (f"match ({start})-[r1:Relation]-({end}) where ent.id={ent_id} " 111 | f" and r1.name='{rel_name}' return DISTINCT target.name") 112 | logging.info(f"{cql}; *ent_name:{ent_name}") 113 | res = graph.run(cql) 114 | target_name = [ent['target.name'] for ent in res] 115 | return target_name 116 | 117 | def search_by_3path(self, ent_name, rel1_name, rel2_name, direction='out'): 118 | ent_id = self.memory.get_entity_id(ent_name) 119 | assert direction in self.all_directions 120 | start, end = ('ent:Entity', 'target') if direction == 'out' else ('target', 'ent:Entity') 121 | cql = (f"match ({start})-[r1:Relation]-()-[r2:Relation]-({end}) where ent.id={ent_id} " 122 | f" and r1.name='{rel1_name}' and r2.name='{rel2_name}' return DISTINCT target.name") 123 | logging.info(f"{cql}; *ent_name:{ent_name}") 124 | res = graph.run(cql) 125 | target_name = [ent['target.name'] for ent in res] 126 | return target_name 127 | 128 | def search_by_4path(self, ent1_name, rel1_name, rel2_name, ent2_name, direction='out'): 129 | ent1_id = self.get_entity_id(ent1_name) 130 | ent2_id = self.get_entity_id(ent2_name) 131 | assert direction in self.all_directions 132 | start, end = ('ent:Entity', 'target') if direction == 'out' else ('target', 'ent:Entity') 133 | cql = (f"match ({start})-[r1:Relation]-({end})-[r2:Relation]-(ent2:Entity) " 134 | f" where ent.id={ent1_id} and r1.name='{rel1_name}' and r2.name='{rel2_name}' and ent2.id={ent2_id}" 135 | f"return DISTINCT target.name") 136 | logging.info(f"{cql}; *ent_name:{ent1_name},{ent2_name}") 137 | res = graph.run(cql) 138 | target_name = [ent['target.name'] for ent in res] 139 | return target_name 140 | -------------------------------------------------------------------------------- /ckbqa/qa/qa.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | from ckbqa.qa.algorithms import Algorithms 4 | from ckbqa.qa.cache import Memory 5 | from ckbqa.qa.el import EL 6 | from ckbqa.qa.neo4j_graph import GraphDB 7 | from ckbqa.qa.relation_extractor import 
RelationExtractor 8 | from ckbqa.utils.async_tools import async_init_singleton_class 9 | 10 | 11 | class QA(object): 12 | def __init__(self): 13 | async_init_singleton_class([Memory, GraphDB, RelationExtractor, EL]) 14 | logging.info('QA init start ...') 15 | self.graph_db = GraphDB() 16 | self.memory = Memory() 17 | self.el = EL() 18 | self.relation_extractor = RelationExtractor() 19 | self.algo = Algorithms() 20 | logging.info('QA init Done ...') 21 | 22 | def query_path(self, path, direction='out'): 23 | if len(path) == 2: 24 | ent_name, rel_name = path 25 | answer = self.graph_db.search_by_2path(ent_name, rel_name, direction=direction) 26 | elif len(path) == 3: 27 | ent_name, rel1_name, rel2_name = path 28 | answer = self.graph_db.search_by_3path(ent_name, rel1_name, rel2_name, direction=direction) 29 | elif len(path) == 4: 30 | ent1_name, rel1_name, rel2_name, ent2_name = path 31 | answer = self.graph_db.search_by_4path(ent1_name, rel1_name, rel2_name, ent2_name, direction=direction) 32 | else: 33 | logging.info(f'这个查询路径不规范: {path}') 34 | answer = [] 35 | return answer 36 | 37 | def run(self, q_text, return_candidates=False): 38 | logging.info(f"* q_text: {q_text}") 39 | candidate_entities = self.el.el(q_text) 40 | logging.info(f'* get_candidate_entities {len(candidate_entities)} ...') 41 | candidate_out_paths, candidate_in_paths = self.relation_extractor.get_ent_relations(q_text, candidate_entities) 42 | logging.info(f'* get candidate_out_paths: {len(candidate_out_paths)},' 43 | f'candidate_in_paths: {len(candidate_in_paths)} ...') 44 | # 生成cypher语句并查询 45 | top_out_path, max_out_score = self.algo.get_most_overlap_path(q_text, candidate_out_paths) 46 | logging.info(f'* get_most_overlap out path:{max_out_score:.4f}, {top_out_path} ...') 47 | top_in_path, max_in_score = self.algo.get_most_overlap_path(q_text, candidate_in_paths) 48 | logging.info(f'* get_most_overlap in path:{max_in_score:.4f}, {top_in_path} ...') 49 | # self.graph_db.cache() # 缓存neo4j查询结果 50 | if max_out_score >= max_in_score: # 分数相同时取out 51 | direction = 'out' 52 | top_path = top_out_path 53 | else: 54 | direction = 'in' 55 | top_path = top_in_path 56 | logging.info(f'* get_most_overlap_path: {top_path} ...') 57 | result_ents = self.query_path(top_path, direction=direction) 58 | if not result_ents and len(top_path) > 2: 59 | if direction == 'out': 60 | top_path = top_path[:2] 61 | else: 62 | top_path = top_path[-2:] 63 | result_ents = self.query_path(top_path) 64 | if not result_ents: 65 | top_path = top_path[0] + top_path[-1] 66 | result_ents = self.query_path(top_path, direction=direction) 67 | logging.info(f"* cypher result_ents: {result_ents}" + '\n' + '--' * 10 + '\n\n') 68 | if return_candidates: 69 | return result_ents, candidate_entities, candidate_out_paths, candidate_in_paths 70 | return result_ents 71 | 72 | # def evaluate(self, question, subject_entities, result_ents_entities): 73 | # q_text = question_patten.findall(question)[0] 74 | # candidate_entities = self.recognizer.get_candidate_entities(q_text) 75 | # precision, recall, f1 = get_metrics(subject_entities, candidate_entities) 76 | # logging.info(f'get_candidate_entities, precision: {precision}, recall: {recall}, f1: {f1}') 77 | -------------------------------------------------------------------------------- /ckbqa/qa/relation_extractor.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import logging 3 | from typing import Tuple, Dict, List 4 | 5 | import numpy as np 6 | 7 | from 
ckbqa.models.relation_score.predictor import RelationScorePredictor 8 | from ckbqa.qa.neo4j_graph import GraphDB 9 | from ckbqa.utils.decorators import singleton 10 | 11 | 12 | @singleton 13 | class RelationExtractor(object): 14 | 15 | def __init__(self): 16 | self.graph_db = GraphDB() 17 | self.sim_predictor = RelationScorePredictor(model_name='bert_match') # bert_match,bert_match2 18 | 19 | def get_relations(self, candidate_entities, ent_name, direction='out') -> Tuple: 20 | onehop_relations = self.graph_db.get_onehop_relations_by_entName(ent_name, direction=direction) 21 | twohop_relations = self.graph_db.get_twohop_relations_by_entName(ent_name, direction=direction) 22 | mention = candidate_entities[ent_name]['mention'] 23 | candidate_paths, candidate_sents = [], [] 24 | for rel_name in onehop_relations: 25 | _candidate_path = '的'.join([mention, rel_name[1:-1]]) # 查询的关系有<>,需要去掉 26 | candidate_sents.append(_candidate_path) 27 | candidate_paths.append([ent_name, rel_name]) 28 | for rels in twohop_relations: 29 | pure_rel_names = [rel_name[1:-1] for rel_name in rels] # python-list 关系名列表 30 | _candidate_path = '的'.join([mention] + pure_rel_names) 31 | candidate_sents.append(_candidate_path) 32 | candidate_paths.append([ent_name] + rels) 33 | return candidate_paths, candidate_sents 34 | 35 | def get_ent_relations(self, q_text: str, candidate_entities: Dict) -> Tuple: 36 | """ 37 | :param q_text: 38 | :top_k 10 ; out 10, in 10 39 | :param candidate_entities: {ent:[mention, feature1, feature2, ...]} 40 | :return: 41 | """ 42 | candidate_out_sents, candidate_out_paths = [], [] 43 | candidate_in_sents, candidate_in_paths = [], [] 44 | for entity in candidate_entities: 45 | candidate_paths, candidate_sents = self.get_relations(candidate_entities, entity, direction='out') 46 | candidate_out_sents.extend(candidate_sents) 47 | candidate_out_paths.extend(candidate_paths) 48 | candidate_paths, candidate_sents = self.get_relations(candidate_entities, entity, direction='in') 49 | candidate_in_sents.extend(candidate_sents) 50 | candidate_in_paths.extend(candidate_paths) 51 | if not candidate_out_sents and not candidate_in_sents: 52 | logging.info('* candidate_out_paths Empty ...') 53 | return candidate_out_paths, candidate_in_paths 54 | # 模型打分排序 55 | _out_paths = self.relation_score_topn(q_text, candidate_out_paths, candidate_out_sents) 56 | _in_paths = self.relation_score_topn(q_text, candidate_out_paths, candidate_out_sents) 57 | return _out_paths, _in_paths 58 | 59 | def relation_score_topn(self, q_text: str, candidate_paths, candidate_sents, top_k=10) -> List: 60 | top_k = len(candidate_sents) # TODO 后需删除;不做筛选,保留所有路径;目前筛选效果不好 61 | if top_k >= len(candidate_sents): # 太少则不作筛选 62 | return candidate_paths 63 | if candidate_sents: 64 | sim_scores = self.sim_predictor.predict(q_text, candidate_sents) # 目前算法不好,score==1 65 | sim_indexs = np.array(sim_scores).argsort()[-top_k:][::-1] 66 | _paths = [candidate_paths[i] for i in sim_indexs] 67 | else: 68 | _paths = [] 69 | return _paths 70 | # for i in in_sim_indexs: 71 | # try: 72 | # candidate_in_paths[i] 73 | # except: 74 | # import traceback 75 | # logging.info(traceback.format_exc()) 76 | # print(f'i: {i}, candidate_in_paths: {len(candidate_in_paths)}') 77 | # import ipdb 78 | # ipdb.set_trace() 79 | -------------------------------------------------------------------------------- /ckbqa/utils/__init__.py: -------------------------------------------------------------------------------- 
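A minimal end-to-end sketch of how the pieces above (CEG/ED entity linking, RelationExtractor path scoring, GraphDB queries) are wired together by ckbqa/qa/qa.py; it assumes a Neo4j instance at http://localhost:7474 and the prepared dictionaries referenced in config.py:

    from ckbqa.utils.logger import logging_config
    from ckbqa.qa.qa import QA

    logging_config('qa.log', stream_log=True)
    qa = QA()
    answers = qa.run('叶文洁毕业于哪个大学?')  # returns a list of answer entity names
    print(answers)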
https://raw.githubusercontent.com/WangShengguang/ccks-2020/d77f2de91284efb11e314fea21b4a3982ab78554/ckbqa/utils/__init__.py -------------------------------------------------------------------------------- /ckbqa/utils/async_tools.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import threading 3 | 4 | from ckbqa.utils.decorators import synchronized 5 | 6 | 7 | def apply_async(func, *args, daemon=False, **kwargs): 8 | thr = threading.Thread(target=func, args=args, kwargs=kwargs) 9 | if daemon: 10 | thr.setDaemon(True) # 随主线程一起结束 11 | thr.start() 12 | if not daemon: 13 | thr.join() 14 | return thr 15 | 16 | 17 | # for sc in A.__subclasses__():#所有子类 18 | # print(sc.__name__) 19 | @synchronized 20 | def async_init_singleton_class(classes=()): 21 | logging.info('async_init_singleton_class start ...') 22 | thrs = [threading.Thread(target=_singleton_class) 23 | for _singleton_class in classes] 24 | for t in thrs: 25 | t.start() 26 | for t in thrs: 27 | t.join() # 等待所有线程结束再往下继续 28 | logging.info('async_init_singleton_class done ...') 29 | -------------------------------------------------------------------------------- /ckbqa/utils/decorators.py: -------------------------------------------------------------------------------- 1 | import gc 2 | import logging 3 | import threading 4 | import traceback 5 | from functools import wraps 6 | 7 | 8 | def synchronized(func): 9 | func.__lock__ = threading.Lock() 10 | 11 | @wraps(func) 12 | def lock_func(*args, **kwargs): 13 | with func.__lock__: 14 | res = func(*args, **kwargs) 15 | return res 16 | 17 | return lock_func 18 | 19 | 20 | # singleton_classes = set() 21 | 22 | 23 | def singleton(cls): 24 | """ 25 | 只初始化一次;new和init都只调用一次 26 | """ 27 | instances = {} # 当类创建完成之后才有内容 28 | # singleton_classes.add(cls) 29 | 30 | @synchronized 31 | @wraps(cls) 32 | def get_instance(*args, **kw): 33 | if cls not in instances: # 保证只初始化一次 34 | logging.info(f"{cls.__name__} init start ...") 35 | instances[cls] = cls(*args, **kw) 36 | logging.info(f"{cls.__name__} init done ...") 37 | gc.collect() 38 | return instances[cls] 39 | 40 | return get_instance 41 | 42 | 43 | class Singleton(object): 44 | """ 45 | 保证只new一次;但会初始化多次,每个子类的init方法都会被调用 46 | 会造成对象虽然是同一个,但因为会不断地调用init方法,对象属性被不断的修改 47 | """ 48 | instance = None 49 | 50 | @synchronized 51 | def __new__(cls, *args, **kwargs): 52 | """ 53 | :type kwargs: object 54 | """ 55 | if cls.instance is None: # 保证只new一次;但会初始化多次,每个子类的init方法都会被调用,造成对象虽然是同一个但会不断地被修改 56 | cls.instance = super().__new__(cls) 57 | return cls.instance 58 | 59 | 60 | def try_catch_with_logging(default_response=None): 61 | def out_wrapper(func): 62 | @wraps(func) 63 | def wrapper(*args, **kwargs): 64 | try: 65 | res = func(*args, **kwargs) 66 | except Exception: 67 | res = default_response 68 | logging.error(traceback.format_exc()) 69 | return res 70 | 71 | return wrapper 72 | 73 | return out_wrapper 74 | -------------------------------------------------------------------------------- /ckbqa/utils/gpu_selector.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | import traceback 4 | 5 | 6 | def get_available_gpu(num_gpu=1, min_memory=1000, try_times=3, allow_gpus="0,1,2,3", verbose=False): 7 | ''' 8 | :param num_gpu: number of GPU you want to use 9 | :param min_memory: minimum memory MiB 10 | :param sample: try_times 11 | :param all_gpu_nums: accessible gpu nums 允许挑选的GPU编号列表 12 | :param verbose: verbose mode 13 | :return: str of best 
choices, e.x. '1, 2' 14 | ''' 15 | selected = "" 16 | while try_times: 17 | try_times -= 1 18 | info_text = os.popen('nvidia-smi --query-gpu=utilization.gpu,memory.free --format=csv').read() 19 | try: 20 | gpu_info = [(str(gpu_id), int(memory.replace('%', '').replace('MiB', '').split(',')[1].strip())) 21 | for gpu_id, memory in enumerate(info_text.split('\n')[1:-1])] 22 | except: 23 | if verbose: 24 | print(traceback.format_exc()) 25 | return "Not found gpu info ..." 26 | gpu_info.sort(key=lambda info: info[1], reverse=True) # 内存从高到低排序 27 | avilable_gpu = [] 28 | for gpu_id, memory in gpu_info: 29 | if gpu_id in allow_gpus: 30 | if memory >= min_memory: 31 | avilable_gpu.append(gpu_id) 32 | if avilable_gpu: 33 | selected = ",".join(avilable_gpu[:num_gpu]) 34 | else: 35 | print('No GPU available, will retry after 2.0 seconds ...') 36 | time.sleep(2) 37 | continue 38 | if verbose: 39 | print('Available GPU List') 40 | first_line = [['id', 'utilization.gpu(%)', 'memory.free(MiB)']] 41 | matrix = first_line + avilable_gpu 42 | s = [[str(e) for e in row] for row in matrix] 43 | lens = [max(map(len, col)) for col in zip(*s)] 44 | fmt = '\t'.join('{{:{}}}'.format(x) for x in lens) 45 | table = [fmt.format(*row) for row in s] 46 | print('\n'.join(table)) 47 | print('Select id #' + selected + ' for you.') 48 | return selected 49 | -------------------------------------------------------------------------------- /ckbqa/utils/logger.py: -------------------------------------------------------------------------------- 1 | import logging.handlers 2 | 3 | 4 | def logging_config(logging_name='./run.log', stream_log=True, log_level="info"): 5 | """ 6 | :param logging_name: log名 7 | :param stream_log: 是否把log信息输出到屏幕,标准输出 8 | :param level: fatal,error,warn,info,debug 9 | :return: None 10 | """ 11 | log_handles = [logging.handlers.RotatingFileHandler( 12 | logging_name, maxBytes=20 * 1024 * 1024, backupCount=5, encoding='utf-8')] 13 | if stream_log: 14 | log_handles.append(logging.StreamHandler()) 15 | logging_level = {"fatal": logging.FATAL, "error": logging.ERROR, "warn": logging.WARN, 16 | "info": logging.INFO, "debug": logging.DEBUG}[log_level] 17 | logging.basicConfig( 18 | handlers=log_handles, 19 | level=logging_level, 20 | format="%(asctime)s - %(levelname)s %(filename)s %(funcName)s %(lineno)s - %(message)s" 21 | ) 22 | 23 | # if __name__ == "__main__": 24 | # logging_config("./test.log", stream_log=True) # ../../log/test.log 25 | # logging.info("标准输出 log ...") 26 | # logging.debug("hello") 27 | -------------------------------------------------------------------------------- /ckbqa/utils/saver.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | 4 | import torch 5 | 6 | from config import Config, ckpt_dir 7 | 8 | 9 | class Saver(object): 10 | def __init__(self, model_name): 11 | # self.dataset = dataset 12 | self.model_name = model_name 13 | self.model_dir = os.path.join(ckpt_dir, model_name) 14 | 15 | def load_model(self, model: torch.nn.Module, mode="max_step", fail_ok=False, map_location=Config.device): 16 | model_path = os.path.join(self.model_dir, mode, f"{self.model_name}.bin") 17 | if os.path.isfile(model_path): 18 | ckpt = torch.load(model_path, map_location=map_location) 19 | model.load_state_dict(ckpt["net"]) 20 | step = ckpt["step"] 21 | epoch = ckpt["epoch"] 22 | logging.info("* load model from file: {},epoch:{}, step:{}".format(model_path, epoch, step)) 23 | else: 24 | if fail_ok: 25 | epoch = 0 26 | step = 0 27 | else: 
28 | raise ValueError(f'model path : {model_path} is not exist') 29 | logging.info("* Fail load model from file: {}".format(model_path)) 30 | return model_path, epoch, step 31 | 32 | def save(self, model, epoch, step=-1, mode="max_step", parms_dic: dict = None): 33 | model_path = os.path.join(self.model_dir, mode, f"{self.model_name}.bin") 34 | os.makedirs(os.path.dirname(model_path), exist_ok=True) 35 | state = {"net": model.state_dict(), 'epoch': epoch, "step": step} 36 | if isinstance(parms_dic, dict): 37 | state.update(parms_dic) 38 | torch.save(state, model_path) 39 | return model_path 40 | -------------------------------------------------------------------------------- /ckbqa/utils/sequence.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import six 3 | 4 | 5 | def pad_sequences(sequences, maxlen=None, dtype='int32', 6 | padding='pre', truncating='pre', value=0.): 7 | """Pads sequences to the same length. 8 | 9 | This function transforms a list of 10 | `num_samples` sequences (lists of integers) 11 | into a 2D Numpy array of shape `(num_samples, num_timesteps)`. 12 | `num_timesteps` is either the `maxlen` argument if provided, 13 | or the length of the longest sequence otherwise. 14 | 15 | Sequences that are shorter than `num_timesteps` 16 | are padded with `value` at the end. 17 | 18 | Sequences longer than `num_timesteps` are truncated 19 | so that they fit the desired length. 20 | The position where padding or truncation happens is determined by 21 | the arguments `padding` and `truncating`, respectively. 22 | 23 | Pre-padding is the default. 24 | 25 | # Arguments 26 | sequences: List of lists, where each element is a sequence. 27 | maxlen: Int, maximum length of all sequences. 28 | dtype: Type of the output sequences. 29 | To pad sequences with variable length strings, you can use `object`. 30 | padding: String, 'pre' or 'post': 31 | pad either before or after each sequence. 32 | truncating: String, 'pre' or 'post': 33 | remove values from sequences larger than 34 | `maxlen`, either at the beginning or at the end of the sequences. 35 | value: Float or String, padding value. 36 | 37 | # Returns 38 | x: Numpy array with shape `(len(sequences), maxlen)` 39 | 40 | # Raises 41 | ValueError: In case of invalid values for `truncating` or `padding`, 42 | or in case of invalid shape for a `sequences` entry. 43 | """ 44 | if not hasattr(sequences, '__len__'): 45 | raise ValueError('`sequences` must be iterable.') 46 | num_samples = len(sequences) 47 | lengths = [] 48 | for x in sequences: 49 | try: 50 | lengths.append(len(x)) 51 | except TypeError: 52 | raise ValueError('`sequences` must be a list of iterables. ' 53 | 'Found non-iterable: ' + str(x)) 54 | 55 | if maxlen is None: 56 | maxlen = np.max(lengths) 57 | 58 | # take the sample shape from the first non empty sequence 59 | # checking for consistency in the main loop below. 60 | sample_shape = tuple() 61 | for s in sequences: 62 | if len(s) > 0: 63 | sample_shape = np.asarray(s).shape[1:] 64 | break 65 | 66 | is_dtype_str = np.issubdtype(dtype, np.str_) or np.issubdtype(dtype, np.unicode_) 67 | if isinstance(value, six.string_types) and dtype != object and not is_dtype_str: 68 | raise ValueError("`dtype` {} is not compatible with `value`'s type: {}\n" 69 | "You should set `dtype=object` for variable length strings." 
70 | .format(dtype, type(value))) 71 | 72 | x = np.full((num_samples, maxlen) + sample_shape, value, dtype=dtype) 73 | for idx, s in enumerate(sequences): 74 | if not len(s): 75 | continue # empty list/array was found 76 | if truncating == 'pre': 77 | trunc = s[-maxlen:] 78 | elif truncating == 'post': 79 | trunc = s[:maxlen] 80 | else: 81 | raise ValueError('Truncating type "%s" ' 82 | 'not understood' % truncating) 83 | 84 | # check `trunc` has expected shape 85 | trunc = np.asarray(trunc, dtype=dtype) 86 | if trunc.shape[1:] != sample_shape: 87 | raise ValueError('Shape of sample %s of sequence at position %s ' 88 | 'is different from expected shape %s' % 89 | (trunc.shape[1:], idx, sample_shape)) 90 | 91 | if padding == 'post': 92 | x[idx, :len(trunc)] = trunc 93 | elif padding == 'pre': 94 | x[idx, -len(trunc):] = trunc 95 | else: 96 | raise ValueError('Padding type "%s" not understood' % padding) 97 | return x 98 | -------------------------------------------------------------------------------- /ckbqa/utils/tools.py: -------------------------------------------------------------------------------- 1 | import gc 2 | import logging 3 | import os 4 | import pickle 5 | import platform 6 | import re 7 | import sys 8 | import threading 9 | import time 10 | import traceback 11 | 12 | import orjson # faster than json,ujson 13 | import psutil 14 | from tqdm import tqdm 15 | 16 | 17 | def pkl_load(file_path: str): 18 | with open(file_path, 'rb') as f: 19 | gc.disable() 20 | obj = pickle.load(f) 21 | gc.enable() 22 | return obj 23 | 24 | 25 | def pkl_dump(obj: object, file_path: str): 26 | # limit = {'default': sys.getrecursionlimit(), 27 | # 'common': 10 * 10000, 28 | # 'max': resource.getrlimit(resource.RLIMIT_STACK)[0] 29 | # } 30 | # sys.setrecursionlimit(limit[recursionlimit]) 31 | with open(file_path, 'wb') as f: 32 | gc.disable() 33 | pickle.dump(obj, f) 34 | gc.enable() 35 | 36 | 37 | def json_load(path): 38 | with open(path, 'rb') as f: 39 | obj = orjson.loads(f.read()) 40 | gc.collect() 41 | return obj 42 | 43 | 44 | def json_dump(dict_data, save_path, override_exist=True): 45 | if override_exist or not os.path.isfile(save_path): 46 | strs = orjson.dumps(dict_data, option=orjson.OPT_INDENT_2) 47 | with open(save_path, "wb") as f: 48 | f.write(strs) 49 | # if save_memory: 50 | # json.dump(dict_data, f, ensure_ascii=False) 51 | # else: 52 | # json.dump(dict_data, f, ensure_ascii=False, indent=indent, sort_keys=sort_keys) 53 | 54 | 55 | def get_file_linenums(file_name): 56 | if platform.system() in ['Linux', 'Darwin']: # Linux,Mac 57 | num_str = os.popen(f'wc -l {file_name}').read() 58 | line_num = int(re.findall('\d+', num_str)[0]) 59 | else: # Windows 60 | line_num = sum([1 for _ in open(file_name, encoding='utf-8')]) 61 | return line_num 62 | 63 | 64 | def tqdm_iter_file(file_path, prefix=''): 65 | line_num = get_file_linenums(file_path) 66 | with open(file_path, 'r', encoding='utf-8') as f: 67 | for line in tqdm(f, total=line_num, desc=prefix): 68 | yield line 69 | 70 | 71 | def byte2human(bytes, unit='B', precision=2): 72 | unit_mount_map = {'B': 1, 'KB': 1024, 'MB': 1024 * 1024, 'GB': 1024 * 1024 * 1024} 73 | memo = bytes / unit_mount_map[unit] 74 | memo = round(memo, precision) 75 | return memo 76 | 77 | 78 | def get_var_size(var, unit='B'): 79 | size = sys.getsizeof(var) 80 | readable_size = f"{byte2human(size):.2f} {unit}" 81 | return readable_size 82 | 83 | 84 | class ShowTime(object): 85 | ''' 86 | 用上下文管理器计时 87 | ''' 88 | 89 | def __init__(self, prefix=""): 90 | self.prefix = 
prefix 91 | 92 | def __enter__(self): 93 | self.start_timestamp = time.time() 94 | return self 95 | 96 | def __exit__(self, exc_type, exc_val, exc_tb): 97 | self.runtime = time.time() - self.start_timestamp 98 | print("{} take time: {:.2f} s".format(self.prefix, self.runtime)) 99 | if exc_type is not None: 100 | print(exc_type, exc_val, exc_tb) 101 | print(traceback.format_exc()) 102 | return self 103 | 104 | 105 | class ProcessManager(object): 106 | def __init__(self, check_secends=20, memo_unit='GB', precision=2): 107 | self.pid = os.getpid() 108 | self.p = psutil.Process(self.pid) 109 | self.check_secends = check_secends 110 | self.memo_unit = memo_unit 111 | self.precision = precision 112 | self.start_time = time.time() 113 | 114 | def kill(self): 115 | child_poll = self.p.children(recursive=True) 116 | for p in child_poll: 117 | if not 'SYSTEM' in p.username: 118 | print(f'kill sub process: PID: {p.pid} user: {p.username()} name: {p.name()}') 119 | p.kill() 120 | self.p.kill() 121 | print(f'kill {self.pid}') 122 | 123 | def get_memory_info(self): 124 | memo = byte2human(self.p.memory_info().rss, self.memo_unit) 125 | info = psutil.virtual_memory() 126 | total_memo = byte2human(info.total, self.memo_unit) 127 | used = byte2human(info.used, self.memo_unit) 128 | free = byte2human(info.free, self.memo_unit) 129 | available = byte2human(info.available, self.memo_unit) 130 | cur_pid_percent = info.percent 131 | return memo, used, free, available, total_memo, cur_pid_percent 132 | 133 | def task(self): 134 | while True: 135 | memo, used, free, available, total_memo, cur_pid_percent = self.get_memory_info() 136 | print('--' * 20) 137 | print(f'PID: {self.pid} name: {self.p.name()}') 138 | print(f'当前进程内存占用 :\t {memo:.2f} {self.memo_unit}') 139 | print(f'used :\t {used:.2f} {self.memo_unit}') 140 | print(f'free :\t {free:.2f} {self.memo_unit}') 141 | print(f'total :\t {total_memo} {self.memo_unit}') 142 | print(f'内存占比 :\t {cur_pid_percent} %') 143 | print(f'运行时间 :\t {(time.time() - self.start_time) / 60:.2f} min') 144 | # print('cpu个数:', psutil.cpu_count()) 145 | if cur_pid_percent > 95: 146 | logging.info(f'内存占比过高: {cur_pid_percent}%, kill {self.pid}') 147 | self.kill() # 停止进程 148 | time.sleep(self.check_secends) 149 | 150 | def run(self): 151 | thr = threading.Thread(target=self.task) 152 | thr.setDaemon(True) # 跟随主线程结束 153 | thr.start() 154 | -------------------------------------------------------------------------------- /config.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | from pathlib import Path 4 | 5 | import arrow 6 | import torch 7 | 8 | cur_dir = os.path.dirname(os.path.abspath(__file__)) 9 | root_dir = cur_dir 10 | data_dir = os.path.join(root_dir, "data") 11 | raw_data_dir = os.path.join(data_dir, 'raw_data') 12 | 13 | output_dir = os.path.join(root_dir, "output") 14 | ckpt_dir = os.path.join(output_dir, "ckpt") 15 | result_dir = os.path.join(output_dir, 'result') 16 | 17 | # 原始数据 18 | mention2ent_txt = os.path.join(raw_data_dir, 'PKUBASE', 'pkubase-mention2ent.txt') 19 | kb_triples_txt = os.path.join(raw_data_dir, 'PKUBASE', 'pkubase-complete2.txt') 20 | # 问答原始数据 21 | raw_train_txt = os.path.join(raw_data_dir, 'ccks_2020_7_4_Data', 'task1-4_train_2020.txt') 22 | valid_question_txt = os.path.join(raw_data_dir, 'ccks_2020_7_4_Data', 'task1-4_valid_2020.questions') 23 | 24 | 25 | class DataConfig(object): 26 | """ 27 | 原始数据经过处理后生成的数据 28 | """ 29 | word2id_json = os.path.join(data_dir, 'word2id.json') 30 | 
q_entity2id_json = os.path.join(data_dir, 'q_entity2id.json') 31 | a_entity2id_json = os.path.join(data_dir, 'a_entity2id.json') 32 | # 33 | data_csv = os.path.join(data_dir, 'data.csv') # 训练数据做了一点格式转换 34 | # 35 | mention2ent_json = os.path.join(data_dir, 'mention2ent.json') 36 | ent2mention_json = os.path.join(data_dir, 'ent2mention.json') 37 | entity2id = os.path.join(data_dir, 'entity2id.json') 38 | id2entity_pkl = os.path.join(data_dir, 'id2entity.pkl') 39 | relation2id = os.path.join(data_dir, 'relation2id.json') 40 | id2relation_pkl = os.path.join(data_dir, 'id2relation.pkl') 41 | # count 42 | entity2count_json = os.path.join(data_dir, 'entity2count.json') 43 | relation2count_json = os.path.join(data_dir, 'relation2count.json') 44 | mention2count_json = os.path.join(data_dir, 'mention2count.json') 45 | # 46 | lac_custom_dict_txt = os.path.join(data_dir, 'lac_custom_dict.txt') 47 | lac_attr_custom_dict_txt = os.path.join(data_dir, 'lac_attr_custom_dict.txt') 48 | jieba_custom_dict = os.path.join(data_dir, 'jieba_custom_dict.json') 49 | # graph_pkl = os.path.join(data_dir, 'graph.pkl') 50 | graph_entity_csv = os.path.join(data_dir, 'graph_entity.csv') # 图谱导入 51 | graph_relation_csv = os.path.join(data_dir, 'graph_relation.csv') # 图谱导入 52 | entity2types_json = os.path.join(data_dir, 'entity2type.json') 53 | entity2attrs_json = os.path.join(data_dir, 'entity2attr.json') 54 | all_attrs_json = os.path.join(data_dir, 'all_attrs.json') # 所有属性 55 | # 56 | lac_model_pkl = os.path.join(data_dir, 'lac_model.pkl') 57 | # EntityScore model 58 | entity_score_model_pkl = os.path.join(data_dir, 'entity_score_model.pkl') 59 | entity_score_data_pkl = os.path.join(data_dir, 'entity_score_data.pkl') 60 | # 61 | neo4j_query_cache = os.path.join(data_dir, 'neo4j_query_cache.json') 62 | 63 | # 64 | relation_score_sample_csv = os.path.join(data_dir, 'sample.csv') 65 | 66 | @staticmethod 67 | def get_relation_score_sample_csv(data_type, neg_rate): 68 | if data_type == 'train': 69 | file_path = Path(DataConfig.relation_score_sample_csv).with_name(f'train.1_{neg_rate}.csv') 70 | else: 71 | file_path = Path(DataConfig.relation_score_sample_csv).with_name('test.csv') 72 | return str(file_path) 73 | 74 | 75 | class TorchConfig(object): 76 | # device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 77 | device = torch.device('cpu') 78 | # device = "cpu" 79 | gpu_nums = torch.cuda.device_count() 80 | multi_gpu = True 81 | gradient_accumulation_steps = 1 82 | clip_grad = 2 83 | 84 | 85 | class Parms(object): 86 | # 87 | learning_rate = 0.001 88 | # # 89 | # min_epoch_nums = 1 90 | # max_epoch_nums = 10 91 | # # 92 | # embedding_dim = 128 # entity enbedding dim, relation enbedding dim , word enbedding dim 93 | max_len = 50 # max sentence length 94 | # batch_size = 32 95 | # # subtask = 'general' 96 | # test_batch_size = 128 97 | 98 | 99 | class Config(TorchConfig, DataConfig, Parms): 100 | pretrained_model_name_or_path = os.path.join(data_dir, 'bert-base-chinese-pytorch') # 'bert-base-chinese' 101 | # load_pretrain = True 102 | # rand_seed = 1234 103 | # load_model_mode = "min_loss" 104 | # load_model_mode = "max_step" 105 | # load_model_mode = "max_acc" # mrr 106 | # 107 | # train_count = 1000 # TODO for debug 108 | # test_count = 10 # 10*2*13589 109 | 110 | 111 | class ResultSaver(object): 112 | """输出文件管理;自动生成新文件名;避免覆盖 113 | 自动查找已存在的文件 114 | """ 115 | 116 | def __init__(self, find_exist_path=False): 117 | os.makedirs(result_dir, exist_ok=True) 118 | self.find_exist_path = find_exist_path 119 | 120 | 
def _get_new_path(self, file_name): 121 | date_str = arrow.now().format("YYYYMMDD") 122 | # date_str = '20200609' #临时修改 123 | num = 1 124 | path = os.path.join(result_dir, f"{date_str}-{num}-{file_name}") 125 | while os.path.isfile(path): 126 | path = os.path.join(result_dir, f"{date_str}-{num}-{file_name}") 127 | num += 1 128 | return path 129 | 130 | def _find_paths(self, file_name): 131 | paths = [str(_path) for _path in 132 | Path(result_dir).rglob(f'*{file_name}')] 133 | _paths = sorted(paths, reverse=True) 134 | return _paths 135 | 136 | def get_path(self, file_name): 137 | if self.find_exist_path: 138 | path = self._find_paths(file_name) 139 | else: 140 | path = self._get_new_path(file_name) 141 | logging.info(f'* get path: {path}') 142 | return path 143 | 144 | @property 145 | def train_result_csv(self): 146 | file_name = 'train_answer_result.csv' 147 | path = self.get_path(file_name) 148 | return path 149 | 150 | @property 151 | def valid_result_csv(self): 152 | file_name = 'valid_result.csv' 153 | path = self.get_path(file_name) 154 | return path 155 | 156 | @property 157 | def submit_result_txt(self): 158 | file_name = 'submit_result.txt' 159 | path = self.get_path(file_name) 160 | return path 161 | -------------------------------------------------------------------------------- /data.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | from ckbqa.utils.logger import logging_config 4 | 5 | 6 | def create_db_tabels(): 7 | from sqlalchemy_utils import database_exists, create_database 8 | from ckbqa.dao.db import sqlite_db_engine 9 | if not database_exists(sqlite_db_engine.url): 10 | create_database(sqlite_db_engine.url) 11 | from ckbqa.dao.sqlite_models import BaseModel 12 | BaseModel.metadata.create_all(sqlite_db_engine) # 创建表 13 | 14 | 15 | def data_prepare(): 16 | logging_config('data_prepare.log', stream_log=True) 17 | from ckbqa.dataset.data_prepare import fit_on_texts, data_convert 18 | # map_mention_entity() 19 | # data2samples(neg_rate=3) 20 | data_convert() 21 | fit_on_texts() 22 | create_db_tabels() 23 | 24 | 25 | def kb_data_prepare(): 26 | logging_config('kb_data_prepare.log', stream_log=True) 27 | from ckbqa.dataset.kb_data_prepare import (candidate_words, fit_triples) 28 | from ckbqa.dataset.kb_data_prepare import create_graph_csv 29 | fit_triples() # 生成字典 30 | candidate_words() # 属性 31 | # create_lac_custom_dict() # 自定义分词词典 32 | 33 | create_graph_csv() # 生成数据库导入文件 34 | # from examples.lac_test import lac_model 35 | # lac_model() 36 | 37 | 38 | def main(): 39 | parser = argparse.ArgumentParser(description="基础,通用parser") 40 | # logging config 日志配置 41 | parser.add_argument('--stream_log', action="store_true", help="是否将日志信息输出到标准输出") # log print到屏幕 42 | # 43 | group = parser.add_mutually_exclusive_group(required=True) # 一组互斥参数,且至少需要互斥参数中的一个 44 | 45 | group.add_argument('--data_prepare', action="store_true", help="训练集数据预处理") 46 | group.add_argument('--kb_data_prepare', action="store_true", help="知识库数据预处理") 47 | group.add_argument('--task', action="store_true", help="临时组织的任务,调用多个函数") 48 | # parse args 49 | args = parser.parse_args() 50 | # 51 | # from ckbqa.utils.tools import ProcessManager 52 | # ProcessManager().run() 53 | if args.data_prepare: 54 | data_prepare() 55 | elif args.kb_data_prepare: 56 | kb_data_prepare() 57 | elif args.task: 58 | task() 59 | 60 | 61 | def task(): 62 | pass 63 | 64 | 65 | if __name__ == '__main__': 66 | """ 代码执行入口 67 | examples: 68 | python manage.py --data_prepare 69 | nohup python 
manage.py --kb_data_prepare &>kb_data_prepare.out & 70 | """ 71 | # from ckbqa.utils.tools import ProcessManager #实时查看内存占用情况 72 | # ProcessManager().run() 73 | main() 74 | -------------------------------------------------------------------------------- /docs/CCKS 2019 知识图谱评测技术报告:实体、关系、事件及问答 .pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WangShengguang/ccks-2020/d77f2de91284efb11e314fea21b4a3982ab78554/docs/CCKS 2019 知识图谱评测技术报告:实体、关系、事件及问答 .pdf -------------------------------------------------------------------------------- /docs/bad_case.md: -------------------------------------------------------------------------------- 1 | # bad case 2 | 3 | ## 1. 主实体识别错误 4 | 5 | question | 错误描述 | 解决办法|是否解决| 6 | ---|---|---|---| 7 | 被誉为万岛之国的是哪个国家? | 切分错误 | 增加词典or NER | 8 | LOF的全称是什么?|切分对了,mention:但LOF无对应实体|| 9 | 恒铭达的全称是什么?|库里面就不存在这个实体||| 10 | 列举唐朝的代表诗人?|库里不存在这个实体||| 11 | 迪梵是哪个公司旗下的品牌 ?|已近正确,但是训练集实体和给定图谱不一致(训练集:,知识库:)||| 12 | K2属于哪个山脉?|mention2ent里面无提及||| 13 | 哪位作家的原名叫张煐?|需要内容理解,已经识别出张煐等,但不是标准实体||| 14 | 谁下令修建了万里长城?|需要内容理解,已经识别出万里长城等,但不是标准实体||| 15 | 谁废除了分封制?|提供的标准划分,根本不存在||| 16 | 被称为冰上沙皇的是哪位运动员?|需要ner标注,但mention里没有,需要增加mention||| 17 | Hermès的总部在哪里?|||| 18 | 《水浒传》所属文学体裁的三要素是什么?|已经识别出水浒传,但标注为<水浒传(竖版)>||| 19 | 20 | - 改进点: 21 | - 增加mention2entity的内容(例子看k2没有指向;IBM没有指向), 22 | 原名->别名 23 | 别号->别名 24 | 25 | 26 | ## 2. 关系分类错误 27 | question | 错误描述 | 解决办法|是否解决| 28 | ---|---|---|---| 29 | 超级碗是什么体育项目联赛的年度冠军赛? | 分类后主实体不见了 | 分类算法 | 30 | 31 | 32 | 33 | 34 | ## 3. 后处理错误 35 | 36 | question | 错误描述 | 解决办法|是否解决| 37 | ---|---|---|---| 38 | 叶文洁毕业于哪个大学? | 主实体不见了,使用了最长匹配 | 增加词典or NER | 39 | 40 | -------------------------------------------------------------------------------- /evaluate.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import re 3 | 4 | import pandas as pd 5 | from tqdm import tqdm 6 | 7 | from ckbqa.utils.logger import logging_config 8 | from config import ResultSaver 9 | 10 | 11 | def train_data(): 12 | logging_config('train_evaluate.log', stream_log=True) 13 | from ckbqa.models.evaluation_matrics import get_metrics 14 | # 15 | partten = re.compile(r'["<](.*?)[>"]') 16 | # 17 | _paths = ResultSaver(find_exist_path=True).train_result_csv 18 | print(_paths) 19 | train_df = pd.read_csv(_paths[0]) 20 | ceg_precisions, ceg_recalls, ceg_f1_scores = [], [], [] 21 | answer_precisions, answer_recalls, answer_f1_scores = [], [], [] 22 | for index, row in tqdm(train_df.iterrows(), total=train_df.shape[0], desc='evaluate '): 23 | subject_entities = partten.findall(row['standard_subject_entities']) # 匹配文字 24 | if not subject_entities: 25 | subject_entities = eval(row['standard_subject_entities']) 26 | # 修复之前把实体<>去掉造成的问题;问题解析时去掉,但预测时未去掉; 27 | # 所以需要匹配文字,不匹配 <>, "" 28 | # CEG Candidate Entity Generation 29 | candidate_entities = eval(row['candidate_entities']) + partten.findall(row['candidate_entities']) 30 | precision, recall, f1 = get_metrics(subject_entities, candidate_entities) 31 | ceg_precisions.append(precision) 32 | ceg_recalls.append(recall) 33 | ceg_f1_scores.append(f1) 34 | # Answer 35 | standard_entities = eval(row['standard_answer_entities']) 36 | result_entities = eval(row['result_entities']) 37 | precision, recall, f1 = get_metrics(standard_entities, result_entities) 38 | answer_precisions.append(precision) 39 | answer_recalls.append(recall) 40 | answer_f1_scores.append(f1) 41 | # 42 | # print(f"question: {row['question']}\n" 43 | # f"subject_entities: {subject_entities}, 
candidate_entities: {candidate_entities}" 44 | # f"precision: {precision:.4f}, recall: {recall:.4f}, f1: {f1:.4f}\n\n") 45 | # import time 46 | # time.sleep(2) 47 | ave_ceg_precision = sum(ceg_precisions) / len(ceg_precisions) 48 | ave_ceg_recall = sum(ceg_recalls) / len(ceg_recalls) 49 | ave_ceg_f1_score = sum(ceg_f1_scores) / len(ceg_f1_scores) 50 | print(f"ave_ceg_precision: {ave_ceg_precision:.3f}, " 51 | f"ave_ceg_recall: {ave_ceg_recall:.3f}, " 52 | f"ave_ceg_f1_score:{ave_ceg_f1_score:.3f}") 53 | # 54 | ave_answer_precision = sum(answer_precisions) / len(answer_precisions) 55 | ave_answer_recall = sum(answer_recalls) / len(answer_recalls) 56 | ave_answer_f1_score = sum(answer_f1_scores) / len(answer_f1_scores) 57 | print(f"ave_result_precision: {ave_answer_precision:.3f}, " 58 | f"ave_result_recall: {ave_answer_recall:.3f}, " 59 | f"ave_result_f1_score:{ave_answer_f1_score:.3f}") 60 | 61 | 62 | def ceg(): 63 | logging_config('train_evaluate.log', stream_log=True) 64 | from ckbqa.models.evaluation_matrics import get_metrics 65 | from ckbqa.qa.el import CEG 66 | from ckbqa.dataset.data_prepare import load_data, question_patten, entity_pattern, attr_pattern # 67 | ceg = CEG() # Candidate Entity Generation 68 | ceg_precisions, ceg_recalls, ceg_f1_scores = [], [], [] 69 | ceg_csv = "./ceg.csv" 70 | data = [] 71 | for q, sparql, a in load_data(tqdm_prefix='ceg evaluate '): 72 | q_entities = entity_pattern.findall(sparql) + attr_pattern.findall(sparql) 73 | 74 | q_text = ''.join(question_patten.findall(q)) 75 | # 修复之前把实体<>去掉造成的问题;问题解析时去掉,但预测时未去掉; 76 | # 所以需要匹配文字,不匹配 <>, "" 77 | ent2mention = ceg.get_ent2mention(q_text) 78 | # CEG Candidate Entity Generation 79 | precision, recall, f1 = get_metrics(q_entities, ent2mention) 80 | ceg_precisions.append(precision) 81 | ceg_recalls.append(recall) 82 | ceg_f1_scores.append(f1) 83 | # 84 | data.append([q, q_entities, list(ent2mention.keys())]) 85 | if recall == 0: 86 | # ceg.memory.entity2id 87 | # ceg.memory.mention2entity 88 | print(f"question: {q}\n" 89 | f"subject_entities: {q_entities}, candidate_entities: {ent2mention}" 90 | f"precision: {precision:.4f}, recall: {recall:.4f}, f1: {f1:.4f}\n\n") 91 | # import ipdb 92 | # ipdb.set_trace() 93 | print('\n\n') 94 | # import time 95 | # time.sleep(2) 96 | pd.DataFrame(data, columns=['question', 'q_entities', 'ceg']).to_csv( 97 | ceg_csv, index=False, encoding='utf_8_sig') 98 | ave_precision = sum(ceg_precisions) / len(ceg_precisions) 99 | ave_recall = sum(ceg_recalls) / len(ceg_recalls) 100 | ave_f1_score = sum(ceg_f1_scores) / len(ceg_f1_scores) 101 | print(f"ave_precision: {ave_precision:.3f}, " 102 | f"ave_recall: {ave_recall:.3f}, " 103 | f"ave_f1_score:{ave_f1_score:.3f}") 104 | 105 | 106 | def main(): 107 | parser = argparse.ArgumentParser(description="基础,通用parser") 108 | # logging config 日志配置 109 | parser.add_argument('--stream_log', action="store_true", help="是否将日志信息输出到标准输出") # log print到屏幕 110 | # 111 | group = parser.add_mutually_exclusive_group(required=True) # 一组互斥参数,且至少需要互斥参数中的一个 112 | 113 | group.add_argument('--ceg', action="store_true", help="ceg Candidate Entity Generation评价") 114 | group.add_argument('--train_data', action="store_true", help="train_answer_data评价") 115 | # parse args 116 | args = parser.parse_args() 117 | # 118 | # from ckbqa.utils.tools import ProcessManager 119 | # ProcessManager().run() 120 | if args.ceg: 121 | ceg() 122 | elif args.train_data: 123 | train_data() 124 | elif args.task: 125 | task() 126 | 127 | 128 | def task(): 129 | logging_config('ceg.log', 
stream_log=True) 130 | ceg() 131 | 132 | 133 | if __name__ == '__main__': 134 | """ 135 | example: 136 | nohup python qa.py --train_data &>train_data.out & 137 | nohup python qa.py --ceg &>ceg.out & 138 | """ 139 | # from ckbqa.utils.tools import ProcessManager #实时查看内存占用情况 140 | # ProcessManager().run() 141 | main() 142 | -------------------------------------------------------------------------------- /examples/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | def add_root_path(): 3 | import sys 4 | from pathlib import Path 5 | cur_dir = Path(__file__).absolute().parent 6 | times = 10 7 | while not str(cur_dir).endswith('ccks-2020') and times: 8 | cur_dir = cur_dir.parent 9 | times -= 1 10 | print(cur_dir) 11 | sys.path.append(str(cur_dir)) 12 | 13 | 14 | add_root_path() 15 | -------------------------------------------------------------------------------- /examples/answer_format.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | 3 | 4 | def format(): 5 | data_df = pd.read_csv('./answer.txt') 6 | 7 | def func(text: str): 8 | if isinstance(text, str): 9 | ents = [] 10 | for ent in set(text.split('\t')): 11 | if ent.startswith('<') or ent.startswith('"'): 12 | ents.append(ent) 13 | else: 14 | ents.append('"' + ent + '"') 15 | print('"' + ent + '"') 16 | return '\t'.join(ents) 17 | else: 18 | print(text) 19 | return "" 20 | 21 | with open('./result.txt', 'w', encoding='utf-8') as f: 22 | for answer in data_df['answer']: 23 | line = func(answer) 24 | f.write(line + '\n') 25 | # data_df[['commit']].to_csv('./result.txt', encoding='utf_8_sig', index=False) 26 | 27 | 28 | if __name__ == '__main__': 29 | format() 30 | -------------------------------------------------------------------------------- /examples/bad_case.py: -------------------------------------------------------------------------------- 1 | ''' 2 | 3 | 2020-06-09 00:32:53,712 - INFO qa.py run 59 - * get_most_overlap_path: ['<吉林大学>', '<校歌>'] ... 4 | 2020-06-09 00:32:53,712 - INFO neo4j_graph.py search_by_2path 83 - match (ent:Entity)-[r1:Relation]-(target) where ent.id=5199304 and r1.name='<校歌>' return DISTINCT target.name 5 | 2020-06-09 00:32:53,892 - INFO qa.py run 72 - * cypher answer: ['<吉林大学校歌>'] 6 | ''' 7 | 8 | ''' 9 | 没有执行到 qa.run内部,搞不懂 10 | 11 | 12 | question : q185:林清玄有哪些别名? 13 | sparql : select ?x where { <林清玄> <别名> ?x. } 14 | standard answer : "秦情、林漓、林大悲等" 15 | **************************************************************************************************** 16 | 17 | question : q186:北京大学的发展定位是什么 18 | sparql : select ?x where { <北京大学> <发展定位> ?x . } 19 | standard answer : <世界一流大学> 20 | **************************************************************************************************** 21 | 22 | question : q187:中华人民共和国的国家领导人是谁? 23 | sparql : select ?x where { <中华人民共和国> <国家领袖> ?x . } 24 | standard answer : <习近平> <中华人民共和国主席> <中华人民共和国国务院总理> <李克强_(中共中 央政治局常委,国务院总理)> <张德江_(中央政治局常委)> <中华人民共和国全国人民代表大会常务委员会委员长> 25 | 26 | **************************************************************************************************** 27 | question : q188:墨冰仙是哪个门派的? 28 | sparql : select ?x where { <墨冰仙> <门派> ?x. } 29 | standard answer : <蜀山派_(游戏《仙剑奇侠传》中的武林门派)> 30 | 31 | 32 | 33 | **************************************************************************************************** 34 | question : q189:在金庸小说《天龙八部》中,斗转星移的修习者是谁? 35 | sparql : select ?x where { <斗转星移_(金庸小说武功心法)> <主要修习者> ?x. 
} 36 | standard answer : "慕容博" "慕容龙城" "慕容复" 37 | 38 | 39 | 40 | **************************************************************************************************** 41 | question : q190:《基督山伯爵》的作者是谁? 42 | sparql : select ?x where { <基督山伯爵_(大仲马代表作之一)> <作者> ?x. } 43 | standard answer : <亚历山大·仲马_(法国作家、大仲马)> 44 | 45 | 46 | 47 | **************************************************************************************************** 48 | question : q191:奥运会多久举办一次? 49 | sparql : select ?x where { <奥林匹克运动会> <举办时间> ?x. } 50 | standard answer : "每四年一届" 51 | 52 | 53 | 54 | **************************************************************************************************** 55 | question : q192:音乐剧歌剧魅影的主要演员? 56 | sparql : select ?x where { <歌剧魅影_(安德鲁·劳埃德·韦伯创作的音乐剧)> <主演> ?x. } 57 | standard answer : <莎拉·布莱曼> 58 | 59 | 60 | 61 | **************************************************************************************************** 62 | question : q193:索尼公司的经营范围有哪些? 63 | sparql : select ?x where { <索尼_(世界著名企业)> <经营范围> ?x. } 64 | standard answer : <信息技术_(用于管理和处理信息所采用各种技术总称)> <金融_(专业词语汉语词语 )> <电子_(基本粒子之一)> <娱乐_(汉语词语)> 65 | 66 | 67 | 68 | **************************************************************************************************** 69 | question : q194:2014年世界杯的冠军是哪只队? 70 | sparql : select ?x where { <2014年巴西世界杯> <冠军> ?x. } 71 | standard answer : <德国国家男子足球队> 72 | 73 | 74 | 75 | **************************************************************************************************** 76 | question : q195:泰铢的通胀率是多少? 77 | sparql : select ?x where { <泰铢> <通胀率> ?x. } 78 | standard answer : "5.1%" 79 | 80 | 81 | 82 | **************************************************************************************************** 83 | question : q196:小龙虾的中文学名? 84 | sparql : select ?x where { <小龙虾_(甲壳纲螯虾科水生动物)> <中文学名> ?x. } 85 | standard answer : "克氏原螯虾" 86 | 87 | 88 | 89 | **************************************************************************************************** 90 | question : q197:纽约的气候条件是怎样的? 91 | sparql : select ?x where { <纽约_(美国第一大城市)> <气候条件> ?x. } 92 | standard answer : <温带大陆性气候> 93 | 94 | 95 | 96 | **************************************************************************************************** 97 | question : q198:海贼王动画里黑胡子的声优是谁? 98 | sparql : select ?x where { <黑胡子_(动漫《海贼王》中人物)> <配音> ?x. } 99 | standard answer : <大冢明夫> 100 | 101 | 102 | 103 | **************************************************************************************************** 104 | question : q199:大连在哪个省份? 105 | sparql : select ?x where { <大连_(辽宁省辖市)> <所属地区> ?x. } 106 | standard answer : <辽宁> 107 | 108 | 109 | 110 | **************************************************************************************************** 111 | question : q200:新加坡的主要学府有哪些? 112 | sparql : select ?x where { <新加坡> <主要学府> ?x. } 113 | standard answer : "新加坡国立大学" "南洋理工大学" 114 | 115 | 116 | 117 | **************************************************************************************************** 118 | question : q201:大禹建立了哪个朝代? 119 | sparql : select ?x where { ?x <开国君主> <禹_(夏朝开国君主)>. } 120 | standard answer : <夏_(中国第一个朝代)> 121 | 122 | 123 | 124 | **************************************************************************************************** 125 | question : q202:菲尔·奈特创立了什么品牌? 126 | sparql : select ?x where { ?x <创始人> <菲尔·奈特>. } 127 | standard answer : 128 | 129 | 130 | 131 | **************************************************************************************************** 132 | question : q203:戊戌变法开始的标志是什么? 133 | sparql : select ?x where { <戊戌变法> <序幕> ?x. 
} 134 | standard answer : "公车上书" 135 | 136 | 137 | 138 | **************************************************************************************************** 139 | question : q204:肯德基的口号是什么? 140 | sparql : select ?x where { <肯德基> <公司口号> ?x . } 141 | standard answer : "生活如此多娇" 142 | 143 | 144 | 145 | **************************************************************************************************** 146 | question : q205:你知道彩虹又叫做什么吗? 147 | sparql : select ?x where { <彩虹_(雨后光学现象)> <又称> ?x. } 148 | standard answer : <天虹> 149 | 150 | 151 | 152 | **************************************************************************************************** 153 | question : q206:巴金的中文名是什么? 154 | sparql : select ?x where { <巴金> <中文名> ?x. } 155 | standard answer : "李尧棠" 156 | 157 | 158 | 159 | **************************************************************************************************** 160 | question : q207:海贼王里布鲁克的梦想是什么? 161 | sparql : select ?x where { <布鲁克_(日本动漫《海贼王》中主要人物)> <梦想> ?x. } 162 | standard answer : "回到双子跟鲸鱼拉布重逢" 163 | 164 | 165 | 166 | **************************************************************************************************** 167 | question : q208:你知道铁的熔点吗? 168 | sparql : select ?x where { <铁_(最常见的金属元素)> <熔点> ?x. } 169 | standard answer : "1535℃" 170 | 171 | 172 | 173 | **************************************************************************************************** 174 | question : q209:书籍《假如给我三天光明》是什么体裁? 175 | sparql : select ?x where { <假如给我三天光明> <文学体裁> ?x. } 176 | standard answer : <散文_(文学体裁)> 177 | 178 | 179 | 180 | **************************************************************************************************** 181 | question : q210:上海戏剧学院现任校长是谁? 182 | sparql : select ?y where { <上海戏剧学院> <现任校长> ?y. } 183 | standard answer : <韩生_(中国著名舞美设计家、上海戏剧学院院长)> 184 | ''' 185 | 186 | ''' 187 | 188 | 正确识别却查询失败 189 | 190 | - * q_text: linux的创始人毕业于哪个学校? 191 | 2020-06-14 09:59:32,694 - INFO qa.py run 40 - * get_candidate_entities 63 ... 192 | 2020-06-14 09:59:32,700 - INFO qa.py run 42 - * get candidate_out_paths: 1279,candidate_in_paths: 1279 ... 193 | 2020-06-14 09:59:32,707 - INFO qa.py run 46 - * get_most_overlap out path:13.9167, ['', '<创始人>', '<毕业院校>'] ... 194 | 2020-06-14 09:59:32,714 - INFO qa.py run 48 - * get_most_overlap in path:13.9167, ['', '<创始人>', '<毕业院校>'] ... 195 | 2020-06-14 09:59:32,714 - INFO qa.py run 56 - * get_most_overlap_path: ['', '<创始人>', '<毕业院校>'] ... 196 | 2020-06-14 09:59:32,714 - INFO neo4j_graph.py search_by_3path 123 - match (target)-[r1:Relation]-()-[r2:Relation]-(ent:Entity) where ent.id=3068935 and r1.name='<创始人>' and r2.name='<毕业院校>' return DISTINCT target.name; *ent_name: 197 | 2020-06-14 09:59:32,731 - INFO neo4j_graph.py search_by_2path 112 - match (ent:Entity)-[r1:Relation]-(target) where ent.id=3068935 and r1.name='<创始人>' return DISTINCT target.name; *ent_name: 198 | 2020-06-14 09:59:32,736 - INFO qa.py run 64 - * cypher answer: ['<林纳斯·托瓦兹>'] 199 | 200 | 201 | ------ 202 | - * q_text: 易水歌的作者的主要事迹是什么? 203 | 2020-06-14 10:07:50,878 - INFO qa.py run 40 - * get_candidate_entities 27 ... 204 | 2020-06-14 10:07:50,880 - INFO qa.py run 42 - * get candidate_out_paths: 414,candidate_in_paths: 414 ... 205 | 2020-06-14 10:07:50,883 - INFO qa.py run 46 - * get_most_overlap out path:12.9444, ['<易水歌>', '<作者>', '<主要事迹>'] ... 206 | 2020-06-14 10:07:50,885 - INFO qa.py run 48 - * get_most_overlap in path:12.9444, ['<易水歌>', '<作者>', '<主要事迹>'] ... 207 | 2020-06-14 10:07:50,885 - INFO qa.py run 56 - * get_most_overlap_path: ['<易水歌>', '<作者>', '<主要事迹>'] ... 
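(a guess at what goes wrong here, judging from the two cypher queries below: the 3-hop match on <主要事迹> seems to return nothing, so the 2-hop fallback answer — the author <荆轲_(战国时期著名刺客)> himself — is what comes back, and the final <主要事迹> hop is lost)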
208 | 2020-06-14 10:07:50,885 - INFO neo4j_graph.py search_by_3path 123 - match (target)-[r1:Relation]-()-[r2:Relation]-(ent:Entity) where ent.id=7957014 and r1.name='<作者>' and r2.name='<主要事迹>' return DISTINCT target.name; *ent_name:<易水歌> 209 | 2020-06-14 10:07:50,893 - INFO neo4j_graph.py search_by_2path 112 - match (ent:Entity)-[r1:Relation]-(target) where ent.id=7957014 and r1.name='<作者>' return DISTINCT target.name; *ent_name:<易水歌> 210 | 2020-06-14 10:07:50,896 - INFO qa.py run 64 - * cypher answer: ['<荆轲_(战国时期著名刺客)>'] 211 | ''' 212 | -------------------------------------------------------------------------------- /examples/del_method.py: -------------------------------------------------------------------------------- 1 | import time 2 | 3 | 4 | class DelMethod(object): 5 | 6 | def __del__(self): 7 | seconds = 10 8 | time.sleep(seconds) 9 | print(f'sleep seconds : {seconds}') 10 | 11 | 12 | def main(): 13 | DelMethod() 14 | 15 | 16 | if __name__ == '__main__': 17 | main() 18 | -------------------------------------------------------------------------------- /examples/kb_data.py: -------------------------------------------------------------------------------- 1 | def add_root_path(): 2 | import sys 3 | from pathlib import Path 4 | cur_dir = Path(__file__).absolute().parent 5 | times = 10 6 | while not str(cur_dir).endswith('ccks-2020') and times: 7 | cur_dir = cur_dir.parent 8 | times -= 1 9 | print(cur_dir) 10 | sys.path.append(str(cur_dir)) 11 | 12 | 13 | add_root_path() 14 | 15 | import re 16 | from collections import Counter 17 | 18 | import pandas as pd 19 | 20 | from ckbqa.dataset.data_prepare import load_data 21 | from ckbqa.dataset.kb_data_prepare import iter_triples 22 | from ckbqa.utils.tools import json_load, json_dump 23 | from config import mention2ent_txt 24 | 25 | 26 | def triples2csv(): 27 | """26G""" 28 | data_df = pd.DataFrame([(head_ent, rel, tail_ent) for head_ent, rel, tail_ent in iter_triples()], 29 | columns=['head', 'rel', 'tail']) 30 | data_df.to_csv('./triples.csv', encoding='utf_8_sig', index=False) 31 | 32 | 33 | def most_ents(): 34 | """26G""" 35 | 36 | entities = [] 37 | for head_ent, rel, tail_ent in iter_triples(): 38 | entities.extend([head_ent, rel, tail_ent]) 39 | print(f' len(entities): {len(entities)}') 40 | ent_counter = Counter(entities) 41 | del entities 42 | ent_count = ent_counter.most_common(1000000) # 百万 43 | print(ent_count[:10]) 44 | print(ent_count[-10:]) 45 | json_dump(dict(ent_counter), 'ent2count.json') 46 | # 47 | print('-----' * 10) 48 | mentions = [] 49 | with open(mention2ent_txt, 'r', encoding='utf-8') as f: 50 | for line in f: 51 | mention, ent, rank = line.split('\t') # 有部分数据有问题,mention为空字符串 52 | mentions.append(mention) 53 | print(f"len(mentions): {len(mentions)}") 54 | mention_counter = Counter(mentions) 55 | del mentions 56 | mention_count = mention_counter.most_common(1000000) # 百万 57 | print(mention_count[:10]) 58 | print(mention_count[-10:]) 59 | json_dump(dict(mention_counter), 'mention2count.json') 60 | 61 | import ipdb 62 | ipdb.set_trace() 63 | 64 | 65 | def lac_test(): 66 | from LAC import LAC 67 | print('start ...') 68 | mention_count = Counter(json_load('mention2count.json')).most_common(10000) 69 | customization_dict = {mention: 'MENTION' for mention, count in mention_count 70 | if len(mention) >= 2} 71 | print('second ...') 72 | ent_count = Counter(json_load('ent2count.json')).most_common(300000) 73 | ent_pattrn = re.compile(r'["<](.*?)[>"]') 74 | customization_dict.update({' '.join(ent_pattrn.findall(ent)): 'ENT' 
for ent, count in ent_count}) 75 | with open('./customization_dict.txt', 'w') as f: 76 | f.write('\n'.join([f"{e}/{t}" for e, t in customization_dict.items() 77 | if len(e) >= 3])) 78 | import time 79 | before = time.time() 80 | print(f'before ...{before}') 81 | lac = LAC(mode='lac') 82 | lac.load_customization('./customization_dict.txt') # 20万21s;30万47s 83 | lac_raw = LAC(mode='lac') 84 | print(f'after ...{time.time() - before}') 85 | ## 86 | test_count = 10 87 | for q, sparql, a in load_data(): 88 | q_text = q.split(':')[1] 89 | print('---' * 10) 90 | print(q_text) 91 | print(sparql) 92 | print(a) 93 | words, tags = lac_raw.run(q_text) 94 | print(list(zip(words, tags))) 95 | words, tags = lac.run(q_text) 96 | print(list(zip(words, tags))) 97 | if not test_count: 98 | break 99 | test_count -= 1 100 | import ipdb 101 | ipdb.set_trace() 102 | 103 | 104 | class Node(object): 105 | def __init__(self, data): 106 | self.data = data 107 | self.ins = set() # 入度节点,关系和实体统一处理 108 | self.outs = set() # 出度节点,关系实体统一处理 109 | 110 | 111 | def create_graph(): 112 | nodes = {} 113 | for head_ent, rel, tail_ent in iter_triples(): 114 | if head_ent not in nodes: 115 | head_node = Node(head_ent) 116 | else: 117 | head_node = nodes[head_ent] 118 | head_node.outs.add(rel) 119 | # 120 | if rel not in nodes: 121 | rel_node = Node(rel) 122 | else: 123 | rel_node = nodes[rel] 124 | rel_node.ins.add(head_ent) 125 | rel_node.outs.add(tail_ent) 126 | # 127 | if tail_ent not in nodes: 128 | tail_node = Node(tail_ent) 129 | else: 130 | tail_node = nodes[tail_ent] 131 | tail_node.ins.add(rel) 132 | 133 | 134 | if __name__ == '__main__': 135 | # triples2csv() 136 | # most_ents() 137 | lac_test() 138 | -------------------------------------------------------------------------------- /examples/lac_test.py: -------------------------------------------------------------------------------- 1 | def add_root_path(): 2 | import sys 3 | from pathlib import Path 4 | cur_dir = Path(__file__).absolute().parent 5 | times = 10 6 | while not str(cur_dir).endswith('ccks-2020') and times: 7 | cur_dir = cur_dir.parent 8 | times -= 1 9 | print(cur_dir) 10 | sys.path.append(str(cur_dir)) 11 | 12 | 13 | add_root_path() 14 | 15 | import logging 16 | 17 | from ckbqa.qa.lac_tools import BaiduLac, JiebaLac 18 | from ckbqa.utils.logger import logging_config 19 | from ckbqa.utils.tools import pkl_dump, ProcessManager 20 | from config import Config 21 | 22 | logging_config('lac_test.log', stream_log=True) 23 | 24 | 25 | def lac_model(): 26 | # self.lac_seg = LAC(mode='seg') 27 | logging.info(f' load lac_custom_dict from {Config.lac_custom_dict_txt} start ...') 28 | baidu_lac = BaiduLac(mode='lac', _load_customization=True, reload=True) # 装载LAC模型 29 | save_path = baidu_lac._save_customization() # Config.lac_custom_dict_txt 30 | logging.info(f'load lac_custom_dict done, save to {save_path}...') 31 | 32 | 33 | def lac_test(): 34 | logging.info(f' load lac_custom_dict from {Config.lac_custom_dict_txt} start ...') 35 | save_path = './test.pkl' 36 | baidu_lac = BaiduLac(mode='lac', _load_customization=True, reload=True) # 装载LAC模型 37 | logging.info(f'load lac_custom_dict done, save to {save_path}...') 38 | 39 | 40 | def jieba_test(): 41 | # ProcessManager().run() 42 | jieba_model = JiebaLac() 43 | q_text1 = '被誉为万岛之国的是哪个国家?' 
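    # cut_for_search below is presumably jieba's search-engine mode: on top of the precise
    # segmentation it also emits the shorter sub-words of long tokens, which should make it
    # easier to hit entries in the mention dictionary.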
44 | q_text2 = '支付宝(中国)网络技术有限公司的工商注册号' 45 | for q_text in [q_text1, q_text2]: 46 | res = [x for x in jieba_model.cut_for_search(q_text)] 47 | print(res) 48 | print('-' * 10) 49 | # import ipdb 50 | # ipdb.set_trace() 51 | # pkl_dump(jieba_model, './jieba.pkl') 52 | 53 | 54 | if __name__ == '__main__': 55 | # lac_model() 56 | jieba_test() 57 | -------------------------------------------------------------------------------- /examples/single_example.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import threading 3 | from functools import wraps 4 | 5 | 6 | def synchronized(func): 7 | func.__lock__ = threading.Lock() 8 | 9 | def lock_func(*args, **kwargs): 10 | with func.__lock__: 11 | return func(*args, **kwargs) 12 | 13 | return lock_func 14 | 15 | 16 | def singleton(cls): 17 | instances = {} 18 | print(f"*********************************** instances: {instances}") 19 | 20 | @synchronized 21 | @wraps(cls) 22 | def get_instance(*args, **kw): 23 | if cls not in instances: 24 | logging.info(f"{cls.__name__} async init start ...") 25 | instances[cls] = cls(*args, **kw) 26 | logging.info(f"{cls.__name__} async init done ...") 27 | return instances[cls] 28 | 29 | return get_instance 30 | 31 | 32 | class Singleton(object): 33 | instance = None 34 | 35 | @synchronized 36 | def __new__(cls, *args, **kwargs): 37 | """ 38 | :type kwargs: object 39 | """ 40 | if cls.instance is None: 41 | cls.instance = super().__new__(cls) 42 | return cls.instance 43 | 44 | def _init__(self, *args, **kwargs): 45 | pass 46 | 47 | # def __init__(self, num): # 每次初始化后重新设置,重新初始化实例的属性值 48 | # self.a = num + 5 49 | # # 50 | # def printf(self): 51 | # print(self.a) 52 | 53 | 54 | @singleton 55 | class A(object): 56 | def __init__(self, data: int): 57 | super().__init__() 58 | self.data = data 59 | 60 | 61 | class B(Singleton): 62 | def __init__(self, data: int): 63 | super().__init__() 64 | self.data = data ** data 65 | 66 | 67 | def test(): 68 | a = Singleton(3) 69 | print(id(a), a, a.a) 70 | b = Singleton(4) 71 | print(id(b), b, b.a) 72 | 73 | 74 | def main(): 75 | a = A(2) 76 | print(f"a.data: {a.data}") 77 | a2 = A(4) 78 | print(f'a2.data: {a2.data}') 79 | b = B(3) 80 | print(f'b.data: {b.data}') 81 | b2 = B(6) 82 | print(f'b2.data: {b2.data}') 83 | print(singleton) 84 | 85 | 86 | if __name__ == '__main__': 87 | main() 88 | -------------------------------------------------------------------------------- /examples/top_path.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | 4 | def add_root_path(): 5 | import sys 6 | from pathlib import Path 7 | cur_dir = Path(__file__).absolute().parent 8 | times = 10 9 | while not str(cur_dir).endswith('ccks-2020') and times: 10 | cur_dir = cur_dir.parent 11 | times -= 1 12 | print(cur_dir) 13 | sys.path.append(str(cur_dir)) 14 | 15 | 16 | add_root_path() 17 | 18 | 19 | # --------------------- 20 | 21 | 22 | def get_most_overlap_path(q_text, paths): 23 | # 从排名前几的tuples里选择与问题overlap最多的 24 | max_score = 0 25 | top_path = paths[0] 26 | q_words = set(q_text) 27 | for path in paths: 28 | path_words = set() 29 | for ent_rel in path: 30 | if ent_rel.startswith('<'): 31 | ent_rel = ent_rel[1:-1].split('_')[0] 32 | path_words.update(set(ent_rel)) 33 | common_words = path_words & q_words # 交集 34 | score1 = len(common_words) 35 | score2 = len(common_words) / len(path_words) # 跳数 or len(path_words) 36 | score3 = len(common_words) / len(path) 37 | score = score1 + score2 # + score3 
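        # score1: raw character overlap with the question; score2: that overlap normalized by
        # the path's own character set, which favours shorter/denser paths; score3 (overlap per
        # path element) is computed but left out of the final score.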
38 | if score > max_score: 39 | logging.info(f'score :{score}, score1 :{score1}, score2 :{score2}, score3 :{score3}, ' 40 | f'top_path: {top_path}') 41 | top_path = path 42 | max_score = score 43 | return top_path 44 | 45 | 46 | def main(): 47 | q='莫妮卡·贝鲁奇的代表作?' 48 | 49 | 50 | if __name__ == '__main__': 51 | main() -------------------------------------------------------------------------------- /manage.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | from ckbqa.utils.logger import logging_config 4 | 5 | 6 | def set_envs(cpu_only, allow_gpus): 7 | import os 8 | import random 9 | rand_seed = 1234 10 | random.seed(rand_seed) 11 | if cpu_only: 12 | os.environ["CUDA_VISIBLE_DEVICES"] = "" 13 | print("CPU only ...") 14 | elif allow_gpus: 15 | from ckbqa.utils.gpu_selector import get_available_gpu 16 | available_gpu = get_available_gpu(num_gpu=1, allow_gpus=allow_gpus) # default allow_gpus 0,1,2,3 17 | os.environ["CUDA_VISIBLE_DEVICES"] = available_gpu 18 | print("* using GPU: {} ".format(available_gpu)) # config前不可logging,否则config失效 19 | # 进程内存管理 20 | # from ckbqa.utils.tools import ProcessManager 21 | # ProcessManager().run() 22 | 23 | 24 | def run(model_name, mode): 25 | logging_config(f'{model_name}_{mode}.log', stream_log=True) 26 | if model_name in ['bert_match', 'bert_match2']: 27 | from ckbqa.models.relation_score.trainer import RelationScoreTrainer 28 | RelationScoreTrainer(model_name).train_match_model() 29 | elif model_name == 'entity_score': 30 | from ckbqa.models.entity_score.model import EntityScore 31 | EntityScore().train() 32 | 33 | 34 | def main(): 35 | ''' Parse command line arguments and execute the code 36 | --stream_log, --relative_path, --log_level 37 | --allow_gpus, --cpu_only 38 | ''' 39 | parser = argparse.ArgumentParser(description="基础,通用parser") 40 | # logging config 日志配置 41 | parser.add_argument('--stream_log', action="store_true", help="是否将日志信息输出到标准输出") # log print到屏幕 42 | parser.add_argument('--allow_gpus', default="0,1,2,3", type=str, 43 | help="指定GPU编号,0 or 0,1,2 or 0,1,2...7 | nvidia-smi 查看GPU占用情况") 44 | parser.add_argument('--cpu_only', action="store_true", help="CPU only, not to use GPU ") 45 | # 46 | group = parser.add_mutually_exclusive_group(required=True) # 一组互斥参数,且至少需要互斥参数中的一个 47 | # 48 | all_models = ['bert_match', 'bert_match2', 'entity_score'] 49 | group.add_argument('--train', type=str, choices=all_models, help="训练") 50 | group.add_argument('--test', type=str, choices=all_models, help="测试") 51 | # parse args 52 | args = parser.parse_args() 53 | # 54 | set_envs(args.cpu_only, args.allow_gpus) # 设置环境变量等 55 | # 56 | if args.train: 57 | model_name = args.train 58 | mode = 'train' 59 | elif args.test: 60 | model_name = args.test 61 | mode = 'test' 62 | else: 63 | raise ValueError('must set model name ') 64 | run(model_name, mode) 65 | 66 | 67 | if __name__ == '__main__': 68 | """ 代码执行入口 69 | examples: 70 | nohup python manage.py --train bert_match &>bert_match.out& 71 | nohup python manage.py --train entity_score &>entity_score.out& 72 | """ 73 | # from ckbqa.utils.tools import ProcessManager #实时查看内存占用情况 74 | # ProcessManager().run() 75 | main() 76 | -------------------------------------------------------------------------------- /port_map.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | ssh -f wangshengguang@remotehost -N -L 7474:localhost:7474 3 | ssh -f wangshengguang@remotehost -N -L 7687:localhost:7687 4 | 5 | 
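# The two ssh tunnels above map the remote Neo4j HTTP (7474) and Bolt (7687) ports to localhost.
# Below is a minimal, hypothetical Python sketch (not the project's ckbqa/qa/neo4j_graph.py, which
# is not included in this dump) of issuing the 1-hop cypher seen in the bad-case logs through the
# forwarded Bolt port, using the official `neo4j` driver; the user/password values are placeholders.

from neo4j import GraphDatabase

driver = GraphDatabase.driver("bolt://localhost:7687", auth=("neo4j", "password"))


def search_one_hop(ent_id: int, rel_name: str):
    # same shape as the logged query: 1-hop from an entity id through a named relation
    cypher = ("match (ent:Entity)-[r1:Relation]-(target) "
              "where ent.id=$ent_id and r1.name=$rel_name "
              "return DISTINCT target.name")
    with driver.session() as session:
        return [record["target.name"] for record in session.run(cypher, ent_id=ent_id, rel_name=rel_name)]


# e.g. search_one_hop(5199304, '<校歌>') mirrors the query logged for the 吉林大学 example in examples/bad_case.py.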
-------------------------------------------------------------------------------- /qa.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import logging 3 | import traceback 4 | 5 | import pandas as pd 6 | from tqdm import tqdm 7 | 8 | from ckbqa.utils.logger import logging_config 9 | from config import valid_question_txt, ResultSaver 10 | 11 | 12 | def train_qa(): 13 | """训练数据进行回答;做指标测试""" 14 | logging_config('train_qa.log', stream_log=True) 15 | from ckbqa.qa.qa import QA 16 | from ckbqa.dataset.data_prepare import load_data, question_patten, entity_pattern, attr_pattern 17 | # from ckbqa.qa.evaluation_matrics import get_metrics 18 | # 19 | logging.info('* start run ...') 20 | qa = QA() 21 | data = [] 22 | for question, sparql, answer in load_data(tqdm_prefix='test qa'): 23 | print('\n' * 2) 24 | print('*****' * 20) 25 | print(f" question : {question}") 26 | print(f" sparql : {sparql}") 27 | print(f" standard answer : {answer}") 28 | q_text = question_patten.findall(question)[0] 29 | standard_subject_entities = entity_pattern.findall(sparql) + attr_pattern.findall(sparql) 30 | standard_answer_entities = entity_pattern.findall(answer) + attr_pattern.findall(answer) 31 | try: 32 | (result_entities, candidate_entities, 33 | candidate_out_paths, candidate_in_paths) = qa.run(q_text, return_candidates=True) 34 | except KeyboardInterrupt: 35 | exit('Ctrl C , exit') 36 | except: 37 | logging.info(f'ERROR: {traceback.format_exc()}') 38 | result_entities = [] 39 | candidate_entities = [] 40 | # print(f" result answer : {result_entities}") 41 | # precision, recall, f1 = get_metrics(subject_entities, candidate_entities) 42 | # if recall == 0 or len(set(standard_entities) & set(candidate_entities)) == 0: 43 | # print(f"question: {question}\n" 44 | # f"subject_entities: {subject_entities}, candidate_entities: {candidate_entities}" 45 | # f"precision: {precision:.4f}, recall: {recall:.4f}, f1: {f1:.4f}\n\n") 46 | # import ipdb 47 | # ipdb.set_trace() 48 | data.append( 49 | [question, standard_subject_entities, list(candidate_entities), standard_answer_entities, result_entities]) 50 | data_df = pd.DataFrame(data, columns=['question', 'standard_subject_entities', 'candidate_entities', 51 | 'standard_answer_entities', 'result_entities']) 52 | data_df.to_csv(ResultSaver().train_result_csv, index=False, encoding='utf_8_sig') 53 | 54 | 55 | def valid_qa(): 56 | """验证数据答案;验证集做回答;得到提交数据 """ 57 | logging_config('valid_qa.log', stream_log=True) 58 | from ckbqa.qa.qa import QA 59 | from ckbqa.dataset.data_prepare import question_patten 60 | # 61 | data_df = pd.DataFrame([question.strip() for question in open(valid_question_txt, 'r', encoding='utf-8')], 62 | columns=['question']) 63 | logging.info(f"data_df.shape: {data_df.shape}") 64 | qa = QA() 65 | valid_datas = {'question': [], 'result': []} 66 | with open(ResultSaver().submit_result_txt, 'w', encoding='utf-8') as f: 67 | for index, row in tqdm(data_df.iterrows(), total=data_df.shape[0], desc='qa answer'): 68 | question = row['question'] 69 | q_text = question_patten.findall(question)[0] 70 | try: 71 | result_entities = qa.run(q_text) 72 | result_entities = [ent if ent.startswith('<') else f'"{ent}"' for ent in result_entities] 73 | f.write('\t'.join(result_entities) + '\n') 74 | except KeyboardInterrupt: 75 | exit('Ctrl C , exit') 76 | except: 77 | result_entities = [] 78 | f.write('""\n') 79 | logging.info(traceback.format_exc()) 80 | f.flush() 81 | valid_datas['question'].append(question) 82 | 
valid_datas['result'].append(result_entities) 83 | pd.DataFrame(valid_datas).to_csv(ResultSaver().valid_result_csv, index=False, encoding='utf_8_sig') 84 | 85 | 86 | def valid2submit(): 87 | ''' 88 | 验证数据输出转化为答案提交 89 | ''' 90 | logging_config('valid2submit.log', stream_log=True) 91 | data_df = pd.read_csv(ResultSaver(find_exist_path=True).valid_result_csv[0]) 92 | with open(ResultSaver().submit_result_txt, 'w', encoding='utf-8') as f: 93 | for index, row in data_df.iterrows(): 94 | ents = [] 95 | for ent in eval(row['result'])[:10]: # 只保留前10个 96 | if ent.startswith('<'): 97 | ents.append(ent) 98 | elif ent.startswith('"'): 99 | ent = ent.strip('"') 100 | ents.append(f'"{ent}"') 101 | if ents: 102 | f.write('\t'.join(ents) + '\n') 103 | else: 104 | f.write('""\n') 105 | f.flush() 106 | 107 | 108 | def test(): 109 | logging_config('test.log', stream_log=True) 110 | from ckbqa.qa.qa import QA 111 | from ckbqa.dataset.data_prepare import question_patten 112 | # from ckbqa.qa.evaluation_matrics import get_metrics 113 | # 114 | logging.info('* start run ...') 115 | qa = QA() 116 | q188 = 'q188:墨冰仙是哪个门派的?' 117 | q189 = 'q189:在金庸小说《天龙八部》中,斗转星移的修习者是谁?' 118 | q190 = 'q190:《基督山伯爵》的作者是谁?' # 这里跑没问题 119 | for question in [q188, q189, q190]: 120 | q_text = question_patten.findall(question)[0] 121 | (result_entities, candidate_entities, 122 | candidate_out_paths, candidate_in_paths) = qa.run(q_text, return_candidates=True) 123 | print(question) 124 | import ipdb 125 | ipdb.set_trace() 126 | 127 | 128 | def main(): 129 | parser = argparse.ArgumentParser(description="基础,通用parser") 130 | # logging config 日志配置 131 | parser.add_argument('--stream_log', action="store_true", help="是否将日志信息输出到标准输出") # log print到屏幕 132 | # 133 | group = parser.add_mutually_exclusive_group(required=True) # 一组互斥参数,且至少需要互斥参数中的一个 134 | 135 | group.add_argument('--train_qa', action="store_true", help="训练集答案") 136 | group.add_argument('--valid_qa', action="store_true", help="验证数据(待提交数据)答案") 137 | group.add_argument('--valid2submit', action="store_true", help="验证数据处理后提交") 138 | group.add_argument('--task', action="store_true", help="临时组织的任务,调用多个函数") 139 | group.add_argument('--test', action="store_true", help="临时测试") 140 | # parse args 141 | args = parser.parse_args() 142 | # 143 | # from ckbqa.utils.tools import ProcessManager 144 | # ProcessManager().run() 145 | if args.train_qa: 146 | train_qa() 147 | elif args.valid_qa: 148 | valid_qa() 149 | elif args.valid2submit: 150 | valid2submit() 151 | elif args.test: 152 | test() 153 | elif args.task: 154 | task() 155 | 156 | 157 | def task(): 158 | logging_config('qa_task.log', stream_log=True) 159 | train_qa() 160 | valid_qa() 161 | 162 | 163 | if __name__ == '__main__': 164 | """ 165 | example: 166 | nohup python qa.py --train_qa &>train_qa.out & 167 | nohup python qa.py --valid_qa &>valid_qa.out & 168 | nohup python qa.py --valid2submit &>valid2submit.out & 169 | nohup python qa.py --task &>qa_task.out & 170 | """ 171 | # from ckbqa.utils.tools import ProcessManager #实时查看内存占用情况 172 | # ProcessManager().run() 173 | main() 174 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WangShengguang/ccks-2020/d77f2de91284efb11e314fea21b4a3982ab78554/tests/__init__.py -------------------------------------------------------------------------------- /tests/test_.py: 
-------------------------------------------------------------------------------- 1 | import re 2 | 3 | import hanlp 4 | import jieba 5 | 6 | 7 | def add_root_path(): 8 | import sys 9 | from pathlib import Path 10 | cur_dir = Path(__file__).absolute().parent 11 | times = 10 12 | while not str(cur_dir).endswith('ccks-2020') and times: 13 | cur_dir = cur_dir.parent 14 | times -= 1 15 | print(cur_dir) 16 | sys.path.append(str(cur_dir)) 17 | 18 | 19 | add_root_path() 20 | 21 | from ckbqa.dataset.data_prepare import load_data, question_patten 22 | from ckbqa.utils.logger import logging_config 23 | from ckbqa.utils.tools import json_load 24 | 25 | # 26 | logging_config('test.log', stream_log=True) 27 | 28 | from config import Config 29 | 30 | 31 | def ner_train_data(): 32 | ent2mention = json_load(Config.ent2mention_json) 33 | # recognizer = hanlp.load(hanlp.pretrained.ner.MSRA_NER_BERT_BASE_ZH) 34 | tokenizer = hanlp.load('PKU_NAME_MERGED_SIX_MONTHS_CONVSEG') 35 | recognizer = hanlp.load(hanlp.pretrained.ner.MSRA_NER_BERT_BASE_ZH) 36 | from LAC import LAC 37 | # 装载LAC模型 38 | lac = LAC(mode='lac') 39 | jieba.enable_paddle() # 启动paddle模式。 0.40版之后开始支持,早期版本不支持 40 | _ent_patten = re.compile(r'["<](.*?)[>"]') 41 | for q, sparql, a in load_data(): 42 | q_text = question_patten.findall(q)[0] 43 | hanlp_entities = recognizer([list(q_text)]) 44 | hanlp_words = tokenizer(q_text) 45 | lac_results = lac.run(q_text) 46 | q_entities = _ent_patten.findall(sparql) 47 | jieba_results = list(jieba.cut_for_search(q_text)) 48 | mentions = [ent2mention.get(ent) for ent in q_entities] 49 | print(f"q_text: {q_text}\nq_entities: {q_entities}, " 50 | f"\nlac_results:{lac_results}" 51 | f"\nhanlp_words: {hanlp_words}, " 52 | f"\njieba_results: {jieba_results}, " 53 | f"\nhanlp_entities: {hanlp_entities}, " 54 | f"\nmentions: {mentions}") 55 | import ipdb 56 | ipdb.set_trace() 57 | 58 | 59 | if __name__ == '__main__': 60 | ner_train_data() 61 | -------------------------------------------------------------------------------- /tests/test_ceg.py: -------------------------------------------------------------------------------- 1 | def add_root_path(): 2 | import sys 3 | from pathlib import Path 4 | cur_dir = Path(__file__).absolute().parent 5 | times = 10 6 | while not str(cur_dir).endswith('ccks-2020') and times: 7 | cur_dir = cur_dir.parent 8 | times -= 1 9 | print(cur_dir) 10 | sys.path.append(str(cur_dir)) 11 | 12 | 13 | add_root_path() 14 | import logging 15 | 16 | from ckbqa.dataset.data_prepare import load_data 17 | from ckbqa.qa.el import EL 18 | from ckbqa.utils.logger import logging_config 19 | 20 | logging_config('test.log', stream_log=True) 21 | 22 | 23 | def test_recgnizer(): 24 | print('start') 25 | logging.info('test start ...') 26 | el = EL() 27 | for q, sparql, a in load_data(): 28 | q_text = q.split(':')[1] 29 | rec_entities = el.ceg.get_ent2mention(q_text) 30 | print(rec_entities) 31 | print(q) 32 | print(sparql) 33 | print(a) 34 | import ipdb 35 | ipdb.set_trace() 36 | 37 | 38 | if __name__ == '__main__': 39 | test_recgnizer() 40 | --------------------------------------------------------------------------------
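Note: evaluate.py and qa.py import get_metrics from ckbqa.models.evaluation_matrics, but that module is not included in this dump. For reference, a minimal set-overlap sketch of a get_metrics(gold_entities, predicted_entities) matching the call signature used above could look like the following; it is an illustration under that assumption, not the project's actual implementation.

def get_metrics(gold_entities, predicted_entities):
    """Set-overlap precision / recall / F1 over entity strings (illustrative sketch)."""
    gold, pred = set(gold_entities), set(predicted_entities)
    if not gold or not pred:
        return 0.0, 0.0, 0.0
    hit = len(gold & pred)
    precision = hit / len(pred)
    recall = hit / len(gold)
    f1 = 2 * precision * recall / (precision + recall) if hit else 0.0
    return precision, recall, f1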