├── .gitignore
├── EDA.ipynb
├── README.md
├── ckbqa
│   ├── __init__.py
│   ├── dao
│   │   ├── __init__.py
│   │   ├── db.py
│   │   ├── db_tools.py
│   │   ├── mongo_models.py
│   │   ├── mongo_utils.py
│   │   ├── sqlite_models.py
│   │   └── sqlite_utils.py
│   ├── dataset
│   │   ├── __init__.py
│   │   ├── data_prepare.py
│   │   └── kb_data_prepare.py
│   ├── layers
│   │   ├── __init__.py
│   │   ├── losses.py
│   │   └── modules.py
│   ├── models
│   │   ├── __init__.py
│   │   ├── base_trainer.py
│   │   ├── data_helper.py
│   │   ├── entity_score
│   │   │   ├── __init__.py
│   │   │   └── model.py
│   │   ├── evaluation_matrics.py
│   │   ├── ner
│   │   │   ├── __init__.py
│   │   │   ├── crf.py
│   │   │   └── model.py
│   │   └── relation_score
│   │       ├── __init__.py
│   │       ├── model.py
│   │       ├── predictor.py
│   │       └── trainer.py
│   ├── qa
│   │   ├── __init__.py
│   │   ├── algorithms.py
│   │   ├── cache.py
│   │   ├── el.py
│   │   ├── lac_tools.py
│   │   ├── neo4j_graph.py
│   │   ├── qa.py
│   │   └── relation_extractor.py
│   └── utils
│       ├── __init__.py
│       ├── async_tools.py
│       ├── decorators.py
│       ├── gpu_selector.py
│       ├── logger.py
│       ├── saver.py
│       ├── sequence.py
│       └── tools.py
├── config.py
├── data.py
├── docs
│   ├── CCKS 2019 知识图谱评测技术报告:实体、关系、事件及问答 .pdf
│   └── bad_case.md
├── evaluate.py
├── examples
│   ├── __init__.py
│   ├── answer_format.py
│   ├── bad_case.py
│   ├── del_method.py
│   ├── kb_data.py
│   ├── lac_test.py
│   ├── single_example.py
│   └── top_path.py
├── manage.py
├── port_map.sh
├── qa.py
└── tests
    ├── __init__.py
    ├── test_.py
    └── test_ceg.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | pip-wheel-metadata/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 |
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 |
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 |
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | *.py,cover
51 | .hypothesis/
52 | .pytest_cache/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | target/
76 |
77 | # Jupyter Notebook
78 | .ipynb_checkpoints
79 |
80 | # IPython
81 | profile_default/
82 | ipython_config.py
83 |
84 | # pyenv
85 | .python-version
86 |
87 | # pipenv
88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
91 | # install all needed dependencies.
92 | #Pipfile.lock
93 |
94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
95 | __pypackages__/
96 |
97 | # Celery stuff
98 | celerybeat-schedule
99 | celerybeat.pid
100 |
101 | # SageMath parsed files
102 | *.sage.py
103 |
104 | # Environments
105 | .env
106 | .venv
107 | env/
108 | venv/
109 | ENV/
110 | env.bak/
111 | venv.bak/
112 |
113 | # Spyder project settings
114 | .spyderproject
115 | .spyproject
116 |
117 | # Rope project settings
118 | .ropeproject
119 |
120 | # mkdocs documentation
121 | /site
122 |
123 | # mypy
124 | .mypy_cache/
125 | .dmypy.json
126 | dmypy.json
127 |
128 | # Pyre type checker
129 | .pyre/
130 |
131 | #pycharm
132 | .idea
133 |
134 | #
135 | local
136 | data/raw_data/*
137 | data/
138 | output
139 | supports
140 |
141 | #
142 | *.txt
143 | *.csv
144 | *.tgz
145 | *.json
146 | *.old
147 |
--------------------------------------------------------------------------------
/EDA.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {},
6 | "source": [
7 | "# 数据分析"
8 | ]
9 | },
10 | {
11 | "cell_type": "code",
12 | "execution_count": 1,
13 | "metadata": {},
14 | "outputs": [],
15 | "source": [
16 | "import pandas as pd\n",
17 | "from ckbqa.dataset.data_prepare import get_raw_data_df"
18 | ]
19 | },
20 | {
21 | "cell_type": "code",
22 | "execution_count": 2,
23 | "metadata": {},
24 | "outputs": [],
25 | "source": [
26 | "data_df = get_raw_data_df()"
27 | ]
28 | },
29 | {
30 | "cell_type": "code",
31 | "execution_count": 10,
32 | "metadata": {},
33 | "outputs": [
34 | {
35 | "data": {
36 | "text/html": [
37 | "
\n",
38 | "\n",
51 | "
\n",
52 | " \n",
53 | " \n",
54 | " | \n",
55 | " question | \n",
56 | " answer | \n",
57 | " sparql | \n",
58 | "
\n",
59 | " \n",
60 | " \n",
61 | " \n",
62 | " count | \n",
63 | " 4000 | \n",
64 | " 4000 | \n",
65 | " 4000 | \n",
66 | "
\n",
67 | " \n",
68 | " unique | \n",
69 | " 3988 | \n",
70 | " 2930 | \n",
71 | " 3965 | \n",
72 | "
\n",
73 | " \n",
74 | " top | \n",
75 | " [科学家牛顿的英文名是?] | \n",
76 | " <疾病> | \n",
77 | " select ?x where { <大华股份> <市盈率> ?x. } | \n",
78 | "
\n",
79 | " \n",
80 | " freq | \n",
81 | " 3 | \n",
82 | " 253 | \n",
83 | " 3 | \n",
84 | "
\n",
85 | " \n",
86 | "
\n",
87 | "
"
88 | ],
89 | "text/plain": [
90 | " question answer sparql\n",
91 | "count 4000 4000 4000\n",
92 | "unique 3988 2930 3965\n",
93 | "top [科学家牛顿的英文名是?] <疾病> select ?x where { <大华股份> <市盈率> ?x. }\n",
94 | "freq 3 253 3"
95 | ]
96 | },
97 | "execution_count": 10,
98 | "metadata": {},
99 | "output_type": "execute_result"
100 | }
101 | ],
102 | "source": [
103 | "data_df.describe()"
104 | ]
105 | },
106 | {
107 | "cell_type": "code",
108 | "execution_count": 11,
109 | "metadata": {},
110 | "outputs": [
111 | {
112 | "data": {
113 | "text/html": [
114 | "\n",
115 | "\n",
128 | "
\n",
129 | " \n",
130 | " \n",
131 | " | \n",
132 | " question | \n",
133 | " answer | \n",
134 | " sparql | \n",
135 | "
\n",
136 | " \n",
137 | " \n",
138 | " \n",
139 | " 0 | \n",
140 | " [莫妮卡·贝鲁奇的代表作?] | \n",
141 | " <西西里的美丽传说> | \n",
142 | " select ?x where { <莫妮卡·贝鲁奇> <代表作品> ?x. } | \n",
143 | "
\n",
144 | " \n",
145 | " 1 | \n",
146 | " [《湖上草》是谁的诗?] | \n",
147 | " <柳如是_(明末“秦淮八艳”之一)> | \n",
148 | " select ?x where { ?x <主要作品> <湖上草>. } | \n",
149 | "
\n",
150 | " \n",
151 | " 2 | \n",
152 | " [龙卷风的英文名是什么?] | \n",
153 | " \"Tornado\" | \n",
154 | " select ?x where { <龙卷风_(一种自然天气现象)> <外文名> ?x. } | \n",
155 | "
\n",
156 | " \n",
157 | " 3 | \n",
158 | " [新加坡的水域率是多少?] | \n",
159 | " \"1.444%\" | \n",
160 | " select ?x where { <新加坡> <水域率> ?x. } | \n",
161 | "
\n",
162 | " \n",
163 | " 4 | \n",
164 | " [商朝在哪场战役中走向覆灭?] | \n",
165 | " <牧野之战> | \n",
166 | " select ?x where { <商朝> <灭亡> ?x. } | \n",
167 | "
\n",
168 | " \n",
169 | "
\n",
170 | "
"
171 | ],
172 | "text/plain": [
173 | " question answer \\\n",
174 | "0 [莫妮卡·贝鲁奇的代表作?] <西西里的美丽传说> \n",
175 | "1 [《湖上草》是谁的诗?] <柳如是_(明末“秦淮八艳”之一)> \n",
176 | "2 [龙卷风的英文名是什么?] \"Tornado\" \n",
177 | "3 [新加坡的水域率是多少?] \"1.444%\" \n",
178 | "4 [商朝在哪场战役中走向覆灭?] <牧野之战> \n",
179 | "\n",
180 | " sparql \n",
181 | "0 select ?x where { <莫妮卡·贝鲁奇> <代表作品> ?x. } \n",
182 | "1 select ?x where { ?x <主要作品> <湖上草>. } \n",
183 | "2 select ?x where { <龙卷风_(一种自然天气现象)> <外文名> ?x. } \n",
184 | "3 select ?x where { <新加坡> <水域率> ?x. } \n",
185 | "4 select ?x where { <商朝> <灭亡> ?x. } "
186 | ]
187 | },
188 | "execution_count": 11,
189 | "metadata": {},
190 | "output_type": "execute_result"
191 | }
192 | ],
193 | "source": [
194 | "data_df.head()"
195 | ]
196 | },
197 | {
198 | "cell_type": "code",
199 | "execution_count": 5,
200 | "metadata": {},
201 | "outputs": [],
202 | "source": [
203 | "data_df['q_len'] = data_df['question'].apply(len)\n",
204 | "data_df['a_len'] = data_df['answer'].apply(len)"
205 | ]
206 | },
207 | {
208 | "cell_type": "code",
209 | "execution_count": 7,
210 | "metadata": {},
211 | "outputs": [
212 | {
213 | "data": {
214 | "text/html": [
215 | "\n",
216 | "\n",
229 | "
\n",
230 | " \n",
231 | " \n",
232 | " | \n",
233 | " question | \n",
234 | " answer | \n",
235 | " sparql | \n",
236 | " q_len | \n",
237 | " a_len | \n",
238 | "
\n",
239 | " \n",
240 | " \n",
241 | " \n",
242 | " 0 | \n",
243 | " [莫妮卡·贝鲁奇的代表作?] | \n",
244 | " <西西里的美丽传说> | \n",
245 | " select ?x where { <莫妮卡·贝鲁奇> <代表作品> ?x. } | \n",
246 | " 1 | \n",
247 | " 10 | \n",
248 | "
\n",
249 | " \n",
250 | " 1 | \n",
251 | " [《湖上草》是谁的诗?] | \n",
252 | " <柳如是_(明末“秦淮八艳”之一)> | \n",
253 | " select ?x where { ?x <主要作品> <湖上草>. } | \n",
254 | " 1 | \n",
255 | " 18 | \n",
256 | "
\n",
257 | " \n",
258 | " 2 | \n",
259 | " [龙卷风的英文名是什么?] | \n",
260 | " \"Tornado\" | \n",
261 | " select ?x where { <龙卷风_(一种自然天气现象)> <外文名> ?x. } | \n",
262 | " 1 | \n",
263 | " 9 | \n",
264 | "
\n",
265 | " \n",
266 | " 3 | \n",
267 | " [新加坡的水域率是多少?] | \n",
268 | " \"1.444%\" | \n",
269 | " select ?x where { <新加坡> <水域率> ?x. } | \n",
270 | " 1 | \n",
271 | " 8 | \n",
272 | "
\n",
273 | " \n",
274 | " 4 | \n",
275 | " [商朝在哪场战役中走向覆灭?] | \n",
276 | " <牧野之战> | \n",
277 | " select ?x where { <商朝> <灭亡> ?x. } | \n",
278 | " 1 | \n",
279 | " 6 | \n",
280 | "
\n",
281 | " \n",
282 | "
\n",
283 | "
"
284 | ],
285 | "text/plain": [
286 | " question answer \\\n",
287 | "0 [莫妮卡·贝鲁奇的代表作?] <西西里的美丽传说> \n",
288 | "1 [《湖上草》是谁的诗?] <柳如是_(明末“秦淮八艳”之一)> \n",
289 | "2 [龙卷风的英文名是什么?] \"Tornado\" \n",
290 | "3 [新加坡的水域率是多少?] \"1.444%\" \n",
291 | "4 [商朝在哪场战役中走向覆灭?] <牧野之战> \n",
292 | "\n",
293 | " sparql q_len a_len \n",
294 | "0 select ?x where { <莫妮卡·贝鲁奇> <代表作品> ?x. } 1 10 \n",
295 | "1 select ?x where { ?x <主要作品> <湖上草>. } 1 18 \n",
296 | "2 select ?x where { <龙卷风_(一种自然天气现象)> <外文名> ?x. } 1 9 \n",
297 | "3 select ?x where { <新加坡> <水域率> ?x. } 1 8 \n",
298 | "4 select ?x where { <商朝> <灭亡> ?x. } 1 6 "
299 | ]
300 | },
301 | "execution_count": 7,
302 | "metadata": {},
303 | "output_type": "execute_result"
304 | }
305 | ],
306 | "source": [
307 | "data_df.head()"
308 | ]
309 | },
310 | {
311 | "cell_type": "code",
312 | "execution_count": null,
313 | "metadata": {},
314 | "outputs": [],
315 | "source": []
316 | }
317 | ],
318 | "metadata": {
319 | "kernelspec": {
320 | "display_name": "Python 3",
321 | "language": "python",
322 | "name": "python3"
323 | },
324 | "language_info": {
325 | "codemirror_mode": {
326 | "name": "ipython",
327 | "version": 3
328 | },
329 | "file_extension": ".py",
330 | "mimetype": "text/x-python",
331 | "name": "python",
332 | "nbconvert_exporter": "python",
333 | "pygments_lexer": "ipython3",
334 | "version": "3.7.1"
335 | }
336 | },
337 | "nbformat": 4,
338 | "nbformat_minor": 2
339 | }
340 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # CCKS-2020
2 |
3 | - [CCKS 2020: COVID-19 Knowledge Graph Construction and QA Evaluation, Task 4: Knowledge Graph Question Answering](https://biendata.com/competition/ccks_2020_7_4/)
4 |
5 |
6 | ### 0. Code structure
7 |
8 | - There are three main modules:
9 |     1. dataset: data preprocessing and everything else data related
10 |         1. config.py: all file paths are defined here, so they are easy to find
11 |     2. models: the various deep-learning models
12 |     3. qa: combines the trained models and strategies to answer questions
13 |
14 | **The code must be run from one of the entry-point modules; for small test scripts, imitate an entry module: write a new one and import the functions you need into it.**
15 |
16 |     ├── dao                       // database access layer, intended as an auxiliary cache
17 |     ├── dataset                   // everything data related
18 |     │   ├── data_prepare.py       // preprocessing of the training data: clean it and build the dictionaries
19 |     │   └── kb_data_prepare.py    // preprocessing of the knowledge graph: convert it and build the dictionaries
20 |     ├── layers                    // loss functions etc.
21 |     ├── models                    // models
22 |     │   ├── entity_score          // scoring model for the recognized subject entities
23 |     │   │   └── model.py          // model definition
24 |     │   ├── ner                   // subject-entity recognition model
25 |     │   │   └── model.py          // model definition
26 |     │   ├── relation_score        // scoring model for the relations attached to an entity
27 |     │   │   ├── model.py          // model definition
28 |     │   │   ├── predictor.py      // wraps the model for later prediction
29 |     │   │   └── trainer.py        // model training
30 |     │   ├── base_trainer.py       // training module: from model initialization to training
31 |     │   ├── data_helper.py        // turns the training data into a suitable format and feeds it to the models
32 |     │   └── evaluation_matrics.py // metric computation
33 |     ├── qa                        // question-answering module
34 |     │   ├── algorithms.py         // post-processing algorithms
35 |     │   ├── cache.py              // large files cached in a singleton so they are not loaded into memory repeatedly; ent2id etc. live here and are shared by the other modules
36 |     │   ├── el.py                 // entity linking (subject-entity recognition module)
37 |     │   ├── entity_score.py       // scoring of the recognized subject entities
38 |     │   ├── lac_tools.py          // custom tweaks to the LAC tokenizer
39 |     │   ├── neo4j_graph.py        // graph-database queries and caching
40 |     │   ├── qa.py                 // QA interface; the other modules are assembled here to answer questions
41 |     │   └── relation_extractor.py // recognition of the relations attached to an entity
42 |     ├── utils                     // general utilities
43 |     ├── docs                      // documentation
44 |     ├── examples                  // one-off tasks and module experiments, single scripts
45 |     ├── tests                     // tests
46 |     ├── config.py                 // all data paths and a few global settings
47 |     ├── data.py                   // entry point for all data preprocessing
48 |     ├── evaluate.py               // entry point for module evaluation
49 |     ├── manage.py                 // entry point for all model training
50 |     ├── qa.py                     // QA entry point
51 |     ├── README.md                 // this document
52 |     └── requirements.txt          // dependencies
53 |
54 |
55 | ### 1. Graph database
56 | - Knowledge-base management systems
57 |   - [gStore](http://gstore-pku.com/pcsite/index.html)
58 |   - [gStore - github](https://github.com/pkumod/gStore/blob/master/docs/DOCKER_DEPLOY_CN.md)
59 | - SPARQL
60 |   - Internationalized Resource Identifiers (IRIs) are the generalized counterpart of URIs (Uniform Resource Identifiers).
61 |     In the triples an IRI is written inside angle brackets, e.g. `<莫妮卡·贝鲁奇>`.
62 |   - A Literal represents a triple object that is not an IRI, e.g. a string, a number (xsd:integer) or a date (xsd:date).
63 |     Plain strings are written in double quotes, e.g. "chat".
64 |   - [SPARQL, the RDF query language - article by SimmerChan on Zhihu](https://zhuanlan.zhihu.com/p/32703794)
65 |
66 |
67 | neo4j database
68 | - Installation
69 | https://segmentfault.com/a/1190000015389941
70 |
71 | - Start / stop
72 | cd /home/wangshengguang/neo4j-community-3.4.5/bin
73 | ./neo4j start
74 | ./neo4j stop
75 |
76 | - Data import
77 | cd /home/wangshengguang/neo4j-community-3.4.5/bin
78 | ./neo4j-admin import --database=graph.db --nodes /home/wangshengguang/ccks-2020/data/graph_entity.csv --relationships /home/wangshengguang/ccks-2020/data/graph_relation.csv --ignore-duplicate-nodes=true --id-type INTEGER --ignore-missing-nodes=true
79 |
80 | - Create indexes
81 | CREATE CONSTRAINT ON (ent:Entity) ASSERT ent.id IS UNIQUE;
82 | CREATE INDEX ON :Entity(name)
83 | CREATE INDEX ON :Relation(name)
84 |
85 |
86 | - [neo4j学习笔记(三)——python接口-创建删除结点和关系](https://blog.csdn.net/qq_36591505/article/details/100987105)
87 | - [neo4j︱与python结合的py2neo使用教程(四)](https://blog.csdn.net/sinat_26917383/article/details/79901207)
88 | - [neo4j中文文档](http://neo4j.com.cn/public/docs/index.html)
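
A minimal py2neo query sketch (illustrative only; it assumes the bolt endpoint from the port mapping in "Others" below, the Entity/Relation labels created above, and placeholder credentials):

    from py2neo import Graph

    graph = Graph("bolt://localhost:7687", auth=("neo4j", "password"))  # credentials are placeholders
    # one-hop neighbours of a subject entity, looked up by name
    records = graph.run(
        "MATCH (h:Entity {name: $name})-[r:Relation]->(t:Entity) RETURN r.name AS rel, t.name AS tail",
        name="<莫妮卡·贝鲁奇>").data()
    print(records[:5])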
89 |
90 |
91 |
92 | ## 2. Solutions from previous years
93 |
94 | - [CCKS 2019 | An open-domain Chinese KBQA system - article by 最AI的小PAI on Zhihu](https://zhuanlan.zhihu.com/p/92317079)
95 |     1. First, find the subject entity in the question. Here the entity-mention file provided by the organizers was used, plus some external tools such as PaddlePaddle.
96 |     **2. Then, in the relation-recognition module, the relation of the question (also called the predicate) is found by extracting the subject entity's subgraph from the knowledge graph. All relations are ranked with a similarity scoring model.**
97 |     3. Finally, in the answer-selection module, the final answer is derived using a simple/complex question classifier and a few rules.
98 |
99 | - [Baidu wins the title: Baidu keeps leading the knowledge-graph field](https://www.sohu.com/a/339187520_630344)
100 |     1. The entity-linking component links the entities mentioned in the question to the knowledge base and identifies the core entity of the question. To improve linking precision it combines several features, such as the match between the entity's subgraph and the question, the entity's popularity and the correctness of the mention, and finally ranks the entities with LambdaRank to keep the top-scoring one.
101 |     2. The subgraph-ranking component measures the match between the question and each candidate subgraph from several angles and combines those scores to select the best-scoring answer subgraph.
102 |     3. For a graph with tens of millions of triples, the Baidu team prunes during subgraph generation with an in-house strategy that balances recall, precision and runtime, which makes subgraph ranking more efficient and more effective.
103 |        For open-domain subgraph matching they use literal matching for symbolic similarity, word2vec for shallow semantic matching, and finally BERT for deep semantic alignment.
104 |        On top of that, the solution runs a series of intent checks for specific question types, which further improves accuracy in real QA scenarios and gives better control over the returned answer types.
105 |
106 |
107 | - http://nlpprogress.com/
108 |
109 | ## 3. Data analysis and preprocessing
110 | All paths below are relative to /home/wangshengguang/.
111 | ### 3.1 Raw data
112 | The raw data are stored under data/raw_data.
113 | 1. QA data, under data/raw_data/ccks_2020_7_4_Data
114 |     1. Labelled training set, 15999 samples: data/raw_data/ccks_2020_7_4_Data/task1-4_train_2020.txt
115 |     2. Unlabelled validation set (for submission), 1529 samples: data/raw_data/ccks_2020_7_4_Data/task1-4_valid_2020.questions
116 | 2. Knowledge-graph data, under data/raw_data/PKUBASE
117 |     1. All 66499745 triples: data/raw_data/PKUBASE/pkubase-complete.txt
118 |
119 | ### 3.2 Data analysis
120 | The terms in the triples come in two kinds: entities `<实体>` and attribute literals `"属性"`.
121 |
122 |
123 | ### 3.3 Preprocessing tips
124 | Triple terms come in two kinds, entities `<实体>` and attribute literals `"属性"`.
125 | During preprocessing the double quotes around attributes are stripped (in the generated dictionaries as well as in the data imported into neo4j; the main reason is that double quotes are awkward to keep in JSON keys), which makes them easier to work with;
126 | see kb_data_prepare.py -> iter_triples.
127 | The quotes have to be restored for the final submission (see the sketch below).
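
A rough sketch of this quote handling (illustrative; the real logic lives in kb_data_prepare.py -> iter_triples):

    def normalize(term: str) -> str:
        """Entities keep their angle brackets; attribute literals lose their double quotes."""
        return term if term.startswith('<') else term.strip('"')

    def restore_for_submission(value: str) -> str:
        """Put the double quotes back on attribute values before submitting."""
        return value if value.startswith('<') else f'"{value}"'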
128 |
129 |
130 | ## 4. Current approach
131 | 1. Run NER to recognize the subject entity
132 | 2. Look up the entity's relations, classify them and pick the top paths
133 | 3. Build the SPARQL query and fetch the result (a rough sketch follows)
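
A very rough end-to-end sketch of these three steps (illustrative only: `EL.el()` exists in ckbqa/qa/el.py, but the ranking shown here is a toy and the RelationExtractor call is a hypothetical interface; only single-hop questions are covered):

    from ckbqa.qa.el import EL
    from ckbqa.qa.relation_extractor import RelationExtractor

    def answer_sparql(question: str) -> str:
        candidates = EL().el(question)  # {entity_name: {'feature': [...], ...}}
        subject = max(candidates, key=lambda name: sum(candidates[name]['feature']))  # toy ranking
        relation = RelationExtractor().top_relation(subject, question)  # hypothetical method name
        # one-triple SPARQL in the same shape as the training data
        return f"select ?x where {{ {subject} {relation} ?x. }}"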
134 |
135 |
136 | ### 4.1 Subject-entity recognition module
137 | EL
138 |
139 |
140 | ### 4.2 Relation scoring module
141 | RelationExtractor
142 |
143 |
144 | ### 4.3 Post-processing module
145 | Algorithms
146 |
147 |
148 |
149 | ## Others
150 |
151 | - Forward the remote ports locally so the remote graph database can be reached from this machine
152 | ssh -f wangshengguang@remotehost -N -L 7474:localhost:7474
153 | ssh -f wangshengguang@remotehost -N -L 7687:localhost:7687
154 |
155 | Then open: [http://localhost:7474/browser/](http://localhost:7474/browser/)
156 |
157 |
158 |
159 | - [Jupyter Notebook tips and tricks - answer by z.defying on Zhihu](https://www.zhihu.com/question/266988943/answer/1154607853)
160 | - [Checking and changing file encodings on Linux](https://blog.csdn.net/jnbbwyth/article/details/6991425)
162 |
163 |
164 | ## Execution order
165 | 1. Data preparation: data.py, plus installing neo4j and importing the data
166 | 2. Model training: manage.py
167 | 3. Combine the trained models to answer questions: qa.py
168 |
--------------------------------------------------------------------------------
/ckbqa/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/WangShengguang/ccks-2020/d77f2de91284efb11e314fea21b4a3982ab78554/ckbqa/__init__.py
--------------------------------------------------------------------------------
/ckbqa/dao/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/WangShengguang/ccks-2020/d77f2de91284efb11e314fea21b4a3982ab78554/ckbqa/dao/__init__.py
--------------------------------------------------------------------------------
/ckbqa/dao/db.py:
--------------------------------------------------------------------------------
1 | # -*-coding:utf-8 -*-
2 | """
3 | 在此管理数据库连接
4 | 在连接池启用的情况下,session 只是对连接池中的连接进行操作:
5 | session = Session() 将 Session 实例化的过程只是从连接池中取一个连接,在用完之后
6 | 使用 session.close(),session.commit(),session.rooback()还回到连接池中,而不是真正的关闭连接。
7 | 连接池禁用之后,sqlalchemy 会在查询之前创建连接,在查询结束,调用 session.close() 的时候关闭数据库连接。
8 | http://qinfei.glrsmart.com/2017/11/17/python-sqlalchemy-shu-ju-ku-lian-jie-chi/
9 | """
10 | import logging
11 | from mongoengine import connect, register_connection
12 |
13 | from sqlalchemy import create_engine
14 | from sqlalchemy.orm import sessionmaker
15 |
16 | from config import data_dir
17 |
18 | # db_pool_config = {
19 | #     "pool_pre_ping": True,  # test the connection first and reset it if it is stale
20 | #     "pool_size": 5,  # number of pooled connections, default 5
21 | #     "max_overflow": 10,  # extra connections allowed beyond pool_size; they are closed after use instead of being returned to the pool; default 10, -1 for unlimited
22 | #     "pool_recycle": 7200,  # recycle period, default -1; 7200 is recommended so connections are recycled before MySQL's default 8-hour timeout
23 | #     "pool_timeout": 30,  # give up after waiting this many seconds for a connection from the pool
24 | #     "echo": False,  # log the executed SQL; to inspect a statement you can also print(str(session.query(UserProfile).filter_by(user_id='xxxx')))
25 | # }
26 |
27 |
28 | # sqlalchemy operates on a single database at a time, so the database has to be specified
29 | sqlite_db_engine = create_engine(f'sqlite:///{data_dir}/data.sqlite')  # SQLite stores text as UTF-8; a charset URL argument is not a valid pysqlite connect option
30 |
31 |
32 | class DB(object):
33 | def __init__(self):
34 |         session_factory = sessionmaker(bind=sqlite_db_engine)  # with autocommit=True, work outside an explicit Transaction is committed automatically
35 |         self.session = session_factory()
36 |         # self.session = scoped_session(session_factory)  # thread-safe: each thread gets its own session/connection
37 |
38 | def select(self, sql, params=None):
39 | """
40 | :param sql: type str
41 | :param params: Optional dictionary, or list of dictionaries
42 | :return: result = session.execute(
43 | "SELECT * FROM user WHERE id=:user_id",
44 | {"user_id":5}
45 | )
46 | """
47 | cursor = self.session.execute(sql, params)
48 | res = cursor.fetchall()
49 | self.session.commit()
50 | return res
51 |
52 | def execute(self, sql, params=None):
53 | """
54 | :param sql: type str
55 | :param params: Optional dictionary or list of dictionaries
56 | :return: result = session.execute(
57 | "SELECT * FROM user WHERE id=:user_id",
58 | {"user_id":5}
59 | )
60 | """
61 | self.session.execute(sql, params)
62 | self.session.commit()
63 |
64 | def __enter__(self):
65 | return self
66 |
67 | def __exit__(self, exc_type, exc_val, exc_tb):
68 | try:
69 | self.session.commit()
70 | except Exception:
71 | self.session.rollback()
72 |             logging.info('final save failed, session rolled back')
73 | finally:
74 | self.session.close()
75 | if exc_type is not None or exc_val is not None or exc_tb is not None:
76 | logging.info('exc_type: {}, exc_val: {}, exc_tb: {}'.format(exc_type, exc_val, exc_tb))
77 |
78 | def __del__(self):
79 | self.session.close()
80 |
81 |
82 | # mongodb
83 | class Mongo(object):
84 | mongodb = {
85 | "host": '127.0.0.1',
86 | "port": 27017,
87 | "username": None,
88 | "password": None,
89 | "db": 'kb'
90 | }
91 | connect(**mongodb)
92 |     # # connect to the databases; they are created automatically if they do not exist
93 | register_connection("whiski_db", "whiski", host=mongodb['host'])
94 | register_connection("recommend_db", "recommend", host=mongodb['host'])
95 |
--------------------------------------------------------------------------------
/ckbqa/dao/db_tools.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import traceback
3 | from collections.abc import Iterable
4 | from functools import wraps
5 |
6 | from sqlalchemy import exc
7 |
8 | from .sqlite_models import BaseModel  # referenced by the expunge isinstance checks below
9 | def try_commit_rollback(expunge=None):
10 | """ 装饰继承自DB的类的对象方法
11 | :param expunge: bool,将SQLAlchemy的数据对象实例转为一个简单的对象(切断与数据库会话的联系)
12 | 只在且在查询时尽量使用,做查询且返回sqlalchemy对象时使用,不知道就是用不到
13 | """
14 |
15 | def out_wrapper(func):
16 | @wraps(func)
17 | def wrapper(self, *args, **kwargs):
18 | res = None
19 | db_session = self.session
20 | try:
21 | res = func(self, *args, **kwargs)
22 | if expunge and res is not None:
23 | if isinstance(res, Iterable):
24 | [db_session.expunge(_data) for _data in res if isinstance(_data, BaseModel)]
25 | elif isinstance(res, BaseModel):
26 | db_session.expunge(res)
27 | else:
28 | raise TypeError("Please ensure object(s) is instance(s) of BaseModel")
29 | db_session.commit()
30 | except exc.IntegrityError as e: # duplicate key
31 | db_session.rollback()
32 | logging.error("sqlalchemy.exc.IntegrityError: {}".format(e))
33 | except exc.DataError as e: # invalid input syntax for uuid: "7ae16-ake"
34 | db_session.rollback()
35 | logging.error("sqlalchemy.exc.DataError: {}".format(e))
36 | except Exception:
37 | db_session.rollback()
38 | logging.error(traceback.format_exc())
39 | return res
40 |
41 | return wrapper
42 |
43 | return out_wrapper if isinstance(expunge, bool) else out_wrapper(expunge)
44 |
45 |
46 | try_commit_rollback_expunge = try_commit_rollback(expunge=True)
47 |
48 |
49 | def try_commit_rollback_dbsession(func):
50 | """装饰传入db_session对象的方法"""
51 |
52 | @wraps(func)
53 | def wrapper(self, db_session, *args, **kwargs):
54 | res = None
55 | try:
56 | res = func(self, db_session, *args, **kwargs)
57 | db_session.commit()
58 | except exc.IntegrityError as e: # duplicate key
59 | db_session.rollback()
60 | logging.error("sqlalchemy.exc.IntegrityError: {}".format(e))
61 | except exc.DataError as e: # invalid input syntax for uuid: "7ae16-ake"
62 | db_session.rollback()
63 | logging.error("sqlalchemy.exc.DataError: {}".format(e))
64 | except Exception:
65 | db_session.rollback()
66 | logging.error(traceback.format_exc())
67 | return res
68 |
69 | return wrapper
70 |
--------------------------------------------------------------------------------
/ckbqa/dao/mongo_models.py:
--------------------------------------------------------------------------------
1 | from mongoengine import Document, StringField, IntField, ListField
2 |
3 |
4 | class Entity2id(Document):
5 | meta = {
6 | 'db_alias': 'entity2id',
7 | 'collection': 'graph',
8 | 'strict': False,
9 | 'indexes': [
10 | 'name',
11 | 'id'
12 | ]
13 | }
14 |
15 | name = StringField(required=True)
16 | id = IntField(required=True)
17 |
18 |
19 | class Relation2id(Document):
20 | meta = {
21 | 'db_alias': 'relation2id',
22 | 'collection': 'graph',
23 | 'strict': False,
24 | 'indexes': [
25 | 'name',
26 | 'id'
27 | ]
28 | }
29 |
30 | name = StringField(required=True)
31 | id = IntField(required=True)
32 |
33 |
34 | class Graph(Document):
35 | meta = {
36 | 'db_alias': 'sub_graph',
37 | 'collection': 'graph',
38 | 'strict': False,
39 | 'indexes': [
40 | 'entity_name',
41 | 'entity_id'
42 | ]
43 | }
44 |
45 | entity_name = StringField(required=True)
46 | entity_id = IntField(required=True)
47 | ins = ListField(required=True)
48 | outs = ListField(required=True)
49 |
50 |
51 | class SubGraph(Document):
52 | meta = {
53 | 'db_alias': 'sub_graph',
54 | 'collection': 'graph',
55 | 'strict': False,
56 | 'indexes': [
57 | 'entity_name',
58 | 'entity_id'
59 | ]
60 | }
61 |
62 | entity_name = StringField(required=True)
63 | entity_id = IntField(required=True)
64 | sub_graph = ListField(required=True)
65 |     hop = IntField(required=True)  # number of hops
66 |
--------------------------------------------------------------------------------
/ckbqa/dao/mongo_utils.py:
--------------------------------------------------------------------------------
1 | from .db import Mongo
2 | from .mongo_models import Graph, Entity2id
3 |
4 |
5 | class MongoDB(Mongo):
6 | def entity2id(self, entity_name):
7 |         return Entity2id.objects(name=entity_name).all()  # exact match on the name field
8 |
9 | def save_graph(self, graph_node: Graph):
10 |         return graph_node.save()  # persist the Graph document
11 |
--------------------------------------------------------------------------------
/ckbqa/dao/sqlite_models.py:
--------------------------------------------------------------------------------
1 | # -*-coding:utf-8 -*-
2 |
3 | from sqlalchemy import Column
4 | from sqlalchemy.dialects.sqlite import CHAR, INTEGER, TEXT, SMALLINT
5 | from sqlalchemy.ext.declarative import declarative_base
6 |
7 | BaseModel = declarative_base()  # base class whose subclasses are automatically mapped to tables
8 |
9 |
10 | class Graph(BaseModel):
11 | __tablename__ = 'graph'
12 |
13 | id = Column(INTEGER, primary_key=True, comment="entity id")
14 | name = Column(CHAR(256), default=None, index=True, unique=True, comment="entity name")
15 | pure_name = Column(CHAR(256), default=None, index=True, comment="entity name")
16 |     in_ids = Column(TEXT, default=[], comment="incoming edges (in-degree)")
17 |     out_ids = Column(TEXT, default=[], comment="outgoing edges (out-degree)")
18 | type = Column(SMALLINT, default=0, comment="entity:1, relation:2")
19 |
20 |
21 | class SubGraph(BaseModel):
22 | __tablename__ = 'sub_graph'
23 |
24 | id = Column(INTEGER, primary_key=True, comment="entity id")
25 | name = Column(CHAR(256), default=None, index=True, unique=True, comment="entity name")
26 | pure_name = Column(CHAR(256), default=None, index=True, comment="entity name")
27 |     sub_graph_ids = Column(TEXT, default={}, comment="sub-graph (ids)")
28 |     sub_graph_names = Column(TEXT, default={}, comment="sub-graph (names)")
29 | type = Column(SMALLINT, index=True, nullable=False, comment="entity:1, relation:2")
30 |
31 |
32 | class Entity2id(BaseModel):
33 | __tablename__ = 'entity2id'
34 |
35 | entity_id = Column(INTEGER, default=None, primary_key=True, comment="entity id")
36 | entity_name = Column(CHAR(256), nullable=False, index=True, unique=True, comment="entity name")
37 | pure_name = Column(CHAR(256), nullable=False, index=True, comment="entity name")
38 |
--------------------------------------------------------------------------------
/ckbqa/dao/sqlite_utils.py:
--------------------------------------------------------------------------------
1 | from .db import DB
2 | from .db_tools import try_commit_rollback, try_commit_rollback_expunge
3 | from .sqlite_models import Entity2id, SubGraph
4 | from ..utils.decorators import singleton
5 |
6 |
7 | @singleton
8 | class SqliteDB(DB):
9 |
10 |     @try_commit_rollback
11 |     def get_id_by_entity_name(self, entity_name, pure_name):
12 |         # minimal lookup sketch against the entity2id table defined in sqlite_models
13 |         sql = 'select entity_id from entity2id where entity_name = :entity_name or pure_name = :pure_name'
14 |         rows = self.select(sql, {'entity_name': entity_name, 'pure_name': pure_name})
15 |         return rows
15 |
16 | @try_commit_rollback_expunge
17 | def get_subGraph_by_entity_ids(self, entity_ids):
18 | # sub_graph = self.session.query(SubGraph).filter_by(id=entity_id).first()
19 | sub_graph = self.session.query(SubGraph).filter(SubGraph.id.in_(entity_ids))
20 | return sub_graph
21 |
--------------------------------------------------------------------------------
/ckbqa/dataset/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/WangShengguang/ccks-2020/d77f2de91284efb11e314fea21b4a3982ab78554/ckbqa/dataset/__init__.py
--------------------------------------------------------------------------------
/ckbqa/dataset/data_prepare.py:
--------------------------------------------------------------------------------
1 | import re
2 |
3 | import pandas as pd
4 | from tqdm import tqdm
5 |
6 | from ckbqa.utils.tools import json_dump
7 | from config import DataConfig, raw_train_txt
8 |
9 | entity_pattern = re.compile(r'(<.*?>)')  # keep the angle brackets <>
10 | attr_pattern = re.compile(r'"(.*?)"')  # drop the double quotes; " is awkward to keep in a json key
11 | question_patten = re.compile(r'q\d{1,4}:(.*)')
12 |
13 | PAD = 0
14 | UNK = 1
15 |
16 |
17 | def load_data(tqdm_prefix=''):
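    # Each sample occupies three consecutive non-empty lines in the raw file, roughly:
    #   q1:莫妮卡·贝鲁奇的代表作?
    #   select ?x where { <莫妮卡·贝鲁奇> <代表作品> ?x. }
    #   <西西里的美丽传说>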
18 | with open(raw_train_txt, 'r', encoding='utf-8') as f:
19 | lines = [line.strip() for line in f.readlines() if line.strip()]
20 | for q, sparql, a in tqdm([lines[i:i + 3] for i in range(0, len(lines), 3)],
21 | desc=tqdm_prefix):
22 | yield q, sparql, a
23 |
24 |
25 | def fit_on_texts():
26 | q_entities_set = set()
27 | a_entities_set = set()
28 | words_set = set()
29 | for q, sparql, a in load_data():
30 | a_entities = entity_pattern.findall(a)
31 | q_entities = entity_pattern.findall(sparql)
32 | q_text = question_patten.findall(q)
33 | entities = a_entities + q_entities
34 | words = list(q_text[0]) + [e for ent in entities for e in ent]
35 |
36 | q_entities_set.update(q_entities)
37 | a_entities_set.update(a_entities)
38 | words_set.update(words)
39 | word2id = {"PAD": PAD, "UNK": UNK}
40 | for w in sorted(words_set):
41 | word2id[w] = len(word2id)
42 | json_dump(word2id, DataConfig.word2id_json)
43 | # question entity
44 | q_entity2id = {"PAD": PAD, "UNK": UNK}
45 | for e in sorted(q_entities_set):
46 | q_entity2id[e] = len(q_entity2id)
47 | json_dump(q_entity2id, DataConfig.q_entity2id_json)
48 | # answer entity
49 | a_entity2id = {"PAD": PAD, "UNK": UNK}
50 | for e in sorted(a_entities_set):
51 | a_entity2id[e] = len(a_entity2id)
52 | json_dump(a_entity2id, DataConfig.a_entity2id_json)
53 |
54 |
55 | def data_convert():
56 | """转化原始数据格式, 方便采样entity"""
57 | data = {'question': [], 'q_entities': [], 'q_strs': [],
58 | 'a_entities': [], 'a_strs': []}
59 | for q, sparql, a in load_data():
60 | q_text = question_patten.findall(q)[0]
61 | q_entities = entity_pattern.findall(sparql)
62 | q_strs = attr_pattern.findall(sparql)
63 | a_entities = entity_pattern.findall(a)
64 | a_strs = attr_pattern.findall(a)
65 | data['question'].append(q_text)
66 | data['q_entities'].append(q_entities)
67 | data['q_strs'].append(q_strs)
68 | data['a_entities'].append(a_entities)
69 | data['a_strs'].append(a_strs)
70 | pd.DataFrame(data).to_csv(DataConfig.data_csv, encoding='utf_8_sig', index=False)
71 |
--------------------------------------------------------------------------------
/ckbqa/dataset/kb_data_prepare.py:
--------------------------------------------------------------------------------
1 | """
2 | Data preprocessing related to the knowledge graph
3 | """
4 | import gc
5 | import logging
6 | import os
7 | import re
8 | from collections import defaultdict, Counter
9 | from pathlib import Path
10 |
11 | import pandas as pd
12 | from tqdm import tqdm
13 |
14 | from ckbqa.utils.tools import json_dump, get_file_linenums, pkl_dump, json_load, tqdm_iter_file
15 | from config import Config, mention2ent_txt, kb_triples_txt
16 |
17 |
18 | def iter_triples(tqdm_prefix=''):
19 | """迭代图谱的三元组;"""
20 | rdf_patten = re.compile(r'(["<].*?[>"])')
21 | ent_partten = re.compile('<(.*?)>')
22 | attr_partten = re.compile('"(.*?)"')
23 |
24 | def parse_entities(string):
25 | ents = ent_partten.findall(string)
26 |         if ents:  # put the angle brackets back; the pattern above strips them
27 | ents = [f'<{_ent.strip()}>' for _ent in ents if _ent]
28 |         else:  # attribute literal: drop the double quotes
29 | ents = [_ent.strip() for _ent in attr_partten.findall(string) if _ent]
30 | return ents
31 |
32 | line_num = get_file_linenums(kb_triples_txt)
33 | # f_record = open('fail.txt', 'w', encoding='utf-8')
34 | with open(kb_triples_txt, 'r', encoding='utf-8') as f:
35 | desc = Path(kb_triples_txt).name + '-' + 'iter_triples'
36 | for line in tqdm(f, total=line_num, desc=desc):
37 | # for idx, line in enumerate(f):
38 | # if idx < 6618182:
39 | # continue
40 |             # some lines are not separated by \t; some use two consecutive <...><...> as the head entity
41 | triples = line.rstrip('.\n').split('\t')
42 | if len(triples) != 3:
43 | triples = line.rstrip('.\n').split()
44 | if len(triples) != 3:
45 | triples = rdf_patten.findall(line)
46 | if len(triples) == 6:
47 | head_ent, rel, tail_ent = triples[:3]
48 | yield parse_entities(head_ent)[0], parse_entities(rel)[0], parse_entities(tail_ent)[0]
49 | head_ent, rel, tail_ent = triples[3:]
50 | yield parse_entities(head_ent)[0], parse_entities(rel)[0], parse_entities(tail_ent)[0]
51 | continue
52 | head_ent, rel, tail_ent = triples
53 | #
54 | head_ents = parse_entities(head_ent)
55 | tail_ents = parse_entities(tail_ent)
56 | rels = parse_entities(rel)
57 | for head_ent in head_ents:
58 | for rel in rels:
59 | for tail_ent in tail_ents:
60 | yield head_ent, rel, tail_ent
61 |
62 |
63 | def fit_triples():
64 | """entity,relation map to id;
65 | 实体到id的映射词典
66 | """
67 | logging.info('fit_triples start ...')
68 | entities = set()
69 | relations = set()
70 | for head_ent, rel, tail_ent in iter_triples(tqdm_prefix='fit_triples '):
71 | entities.add(head_ent)
72 | entities.add(tail_ent)
73 | relations.add(rel)
74 | logging.info('entity2id start ...')
75 | entity2id = {ent: idx + 1 for idx, ent in enumerate(sorted(entities))}
76 | json_dump(entity2id, Config.entity2id)
77 | id2entity = {id: ent for ent, id in entity2id.items()}
78 | pkl_dump(id2entity, Config.id2entity_pkl)
79 | logging.info('relation2id start ...')
80 | relation2id = {rel: idx + 1 for idx, rel in enumerate(sorted(relations))}
81 | json_dump(relation2id, Config.relation2id)
82 | id2relation = {id: rel for rel, id in relation2id.items()}
83 | pkl_dump(id2relation, Config.id2relation_pkl)
84 |
85 |
86 | def map_mention_entity():
87 | """mention2ent; 4 min; mention到实体的映射"""
88 | logging.info('map_mention_entity start ...')
89 | ent2mention = defaultdict(set)
90 | mention2ent = defaultdict(set)
91 | # count = 0
92 | for line in tqdm_iter_file(mention2ent_txt, prefix='iter mention2ent.txt '):
93 |         mention, ent, rank = line.split('\t')  # some rows are broken: the mention is an empty string
94 | mention2ent[mention].add(ent)
95 | ent2mention[ent].add(mention)
96 | # if count == 13930117:
97 | # print(line)
98 | # import ipdb
99 | # ipdb.set_trace()
100 | # count += 1
101 | #
102 | mention2ent = {k: sorted(v) for k, v in mention2ent.items()}
103 | json_dump(mention2ent, Config.mention2ent_json)
104 | ent2mention = {k: sorted(v) for k, v in ent2mention.items()}
105 | json_dump(ent2mention, Config.ent2mention_json)
106 |
107 |
108 | def candidate_words():
109 | """实体的属性和类型映射字典;作为尾实体候选; 7min"""
110 | logging.info('candidate_words gen start ...')
111 | ent2types_dict = defaultdict(set)
112 | ent2attrs_dict = defaultdict(set)
113 | all_attrs = set()
114 | for head_ent, rel, tail_ent in iter_triples(tqdm_prefix='candidate_words '):
115 | if rel == '<类型>':
116 | ent2types_dict[head_ent].add(tail_ent)
117 | elif not head_ent.startswith('<'):
118 | ent2attrs_dict[tail_ent].add(head_ent)
119 | all_attrs.add(head_ent)
120 | elif not tail_ent.startswith('<'):
121 | ent2attrs_dict[head_ent].add(tail_ent)
122 | all_attrs.add(tail_ent)
123 | json_dump({ent: sorted(_types) for ent, _types in ent2types_dict.items()},
124 | Config.entity2types_json)
125 | json_dump({ent: sorted(attrs) for ent, attrs in ent2attrs_dict.items()},
126 | Config.entity2attrs_json)
127 | json_dump(list(all_attrs), Config.all_attrs_json)
128 |
129 |
130 | def _get_top_counter():
131 | """26G,高频实体和mention,作为后期筛选和lac字典;
132 | 统计实体和mention出现的次数; 方便取top作为最终自定义分词的词典;
133 | """
134 | logging.info('kb_count_top_dict start ...')
135 | if not (os.path.isfile(Config.entity2count_json) and
136 | os.path.isfile(Config.relation2count_json)):
137 | entities = []
138 | relations = []
139 | for head_ent, rel, tail_ent in iter_triples(tqdm_prefix='kb_top_count_dict '):
140 | entities.extend([head_ent, tail_ent])
141 | relations.append(rel)
142 | ent_counter = Counter(entities)
143 | json_dump(dict(ent_counter), Config.entity2count_json)
144 | del entities
145 | rel_counter = Counter(relations)
146 | del relations
147 | json_dump(dict(rel_counter), Config.relation2count_json)
148 | else:
149 | ent_counter = Counter(json_load(Config.entity2count_json))
150 | rel_counter = Counter(json_load(Config.relation2count_json))
151 | #
152 | if not os.path.isfile(Config.mention2count_json):
153 | mentions = []
154 | for line in tqdm_iter_file(mention2ent_txt, prefix='count_top_dict iter mention2ent.txt '):
155 |             mention, ent, rank = line.split('\t')  # some rows are broken: the mention is an empty string
156 | mentions.append(mention)
157 | mention_counter = Counter(mentions)
158 | del mentions
159 | json_dump(dict(mention_counter), Config.mention2count_json)
160 | else:
161 | mention_counter = Counter(json_load(Config.mention2count_json))
162 | return ent_counter, rel_counter, mention_counter
163 |
164 |
165 | def create_lac_custom_dict():
166 | """生成自定义分词词典"""
167 | logging.info('create_lac_custom_dict start...')
168 | ent_counter, rel_counter, mention_counter = _get_top_counter()
169 |     mention_count = mention_counter.most_common(50 * 10000)  # top 500,000 mentions
170 | customization_dict = {mention: 'MENTION' for mention, count in mention_count
171 | if 2 <= len(mention) <= 8}
172 | del mention_count, mention_counter
173 | logging.info('create ent&rel customization_dict ...')
174 | _ent_pattrn = re.compile(r'["<](.*?)[>"]')
175 | # customization_dict.update({' '.join(_ent_pattrn.findall(rel)): 'REL'
176 |     #                               for rel, count in rel_counter.most_common(10000)  # top 10,000 relations
177 | # if 2 <= len(rel) <= 8})
178 | # del rel_counter
179 | customization_dict.update({' '.join(_ent_pattrn.findall(ent)): 'ENT'
180 |                                    for ent, count in ent_counter.most_common(50 * 10000)  # top 500,000 entities
181 | if 2 <= len(ent) <= 8})
182 | q_entity2id = json_load(Config.q_entity2id_json)
183 | q_entity2id.update(json_load(Config.a_entity2id_json))
184 | customization_dict.update({' '.join(_ent_pattrn.findall(ent)): 'ENT'
185 | for ent, _id in q_entity2id.items()})
186 | del ent_counter
187 | with open(Config.lac_custom_dict_txt, 'w') as f:
188 | for e, t in customization_dict.items():
189 | if len(e) >= 3:
190 | f.write(f"{e}/{t}\n")
191 | logging.info('attr_custom_dict gen start ...')
192 | entity2attrs = json_load(Config.entity2attrs_json)
193 | all_attrs = set()
194 | for attrs in entity2attrs.values():
195 | all_attrs.update(attrs)
196 | name_patten = re.compile('"(.*?)"')
197 | with open(Config.lac_attr_custom_dict_txt, 'w') as f:
198 | for _attr in all_attrs:
199 | attr = ' '.join(name_patten.findall(_attr))
200 | if len(attr) >= 2:
201 | f.write(f"{attr}/ATTR\n")
202 |
203 |
204 | def create_graph_csv():
205 | """ 生成数据库导入文件
206 | cd /home/wangshengguang/neo4j-community-3.4.5/bin/
207 | ./neo4j-admin import --database=graph.db --nodes /home/wangshengguang/ccks-2020/data/graph_entity.csv --relationships /home/wangshengguang/ccks-2020/data/graph_relation2.csv --ignore-duplicate-nodes=true --id-type INTEGER --ignore-missing-nodes=true
208 | CREATE CONSTRAINT ON (ent:Entity) ASSERT ent.id IS UNIQUE;
209 | CREATE CONSTRAINT ON (ent:Entity) ASSERT ent.name IS UNIQUE;
210 | CREATE CONSTRAINT ON (r:Relation) ASSERT r.name IS UNIQUE;
211 | """
212 | logging.info('start load Config.id2entity ..')
213 | entity2id = json_load(Config.entity2id)
214 | pd.DataFrame.from_records(
215 | [(id, ent, 'Entity') for ent, id in entity2id.items()],
216 | columns=["id:ID(Entity)", "name:String", ":LABEL"]).to_csv(
217 | Config.graph_entity_csv, index=False, encoding='utf_8_sig')
218 | #
219 | records = [[entity2id[head_name], entity2id[tail_name], 'Relation', rel_name]
220 | for head_name, rel_name, tail_name in iter_triples(tqdm_prefix='gen relation csv')]
221 | del entity2id
222 | gc.collect()
223 | pd.DataFrame.from_records(
224 | records, columns=[":START_ID(Entity)", ":END_ID(Entity)", ":TYPE", "name:String"]).to_csv(
225 | Config.graph_relation_csv, index=False, encoding='utf_8_sig')
226 |
227 | # def kge_data():
228 | # """17.3G,生成KGE训练数据"""
229 | # logging.info('kge_data start ...')
230 | # kge_data_dir = Path(Config.entity2id).parent.joinpath('ccks-2020')
231 | # kge_data_dir.mkdir(exist_ok=True)
232 | #
233 | # def data_df2file(data_df, save_path):
234 | # with open(save_path, 'w') as f:
235 | # f.write(f'{data_df.shape[0]}\n')
236 | # data_df.to_csv(save_path, mode='a', header=False, index=False, sep='\t')
237 | #
238 | # #
239 | # ent_most = set([ent for ent, count in Counter(json_load(Config.entity2count_json)).most_common(1000000)])
240 | # entity2id = {ent: id + 1 for id, ent in enumerate(sorted(ent_most))}
241 | # data_df2file(data_df=pd.DataFrame([(ent, idx) for ent, idx in entity2id.items()]),
242 | # save_path=kge_data_dir.joinpath('entity2id.txt'))
243 | # #
244 | # rel_most = set([rel for rel, _count in Counter(json_load(Config.relation2count_json)).most_common(1 * 200000)])
245 | # relation2id = {rel: id + 1 for id, rel in enumerate(sorted(rel_most))}
246 | # data_df2file(data_df=pd.DataFrame([(rel, idx) for rel, idx in relation2id.items()]),
247 | # save_path=kge_data_dir.joinpath('relation2id.txt'))
248 | #     # train / validation / test split
249 | # logging.info('kge train data start ...')
250 | #     # head,rel,tail -> head,tail,rel  # HACK: order swapped to match the KGE model's training-data format
251 | # data_df = pd.DataFrame([(entity2id.get(head_ent, 0), entity2id.get(tail_ent, 0),
252 | # relation2id.get(rel, 0))
253 | # for head_ent, rel, tail_ent in iter_triples()])
254 | # del entity2id, relation2id
255 | # train_df, test_df = train_test_split(data_df, test_size=0.01)
256 | # data_df2file(test_df, save_path=kge_data_dir.joinpath('test2id.txt'))
257 | # del test_df
258 | # train_df, valid_df = train_test_split(train_df, test_size=0.01)
259 | # data_df2file(valid_df, save_path=kge_data_dir.joinpath('valid2id.txt'))
260 | # del valid_df
261 | # data_df2file(train_df, save_path=kge_data_dir.joinpath('train2id.txt'))
262 | # del train_df
263 | # #
264 | # src_dir = str(kge_data_dir)
265 | # dst_dir = str(
266 | # Path(Config.entity2id).parent.parent.parent.joinpath('KE').joinpath('benchmarks').joinpath('ccks-2020'))
267 | # logging.info(f'move file: {src_dir} to {dst_dir}')
268 | # shutil.copytree(src=src_dir, dst=dst_dir) # directory
269 |
--------------------------------------------------------------------------------
/ckbqa/layers/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/WangShengguang/ccks-2020/d77f2de91284efb11e314fea21b4a3982ab78554/ckbqa/layers/__init__.py
--------------------------------------------------------------------------------
/ckbqa/layers/losses.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
4 | class ContrastiveLoss(torch.nn.Module):
5 | """
6 | Contrastive loss function.
7 | Based on: http://yann.lecun.com/exdb/publis/pdf/hadsell-chopra-lecun-06.pdf
8 | """
9 |
10 | def __init__(self, margin=1.0):
11 | super(ContrastiveLoss, self).__init__()
12 | self.margin = margin
13 |
14 | def forward(self, euclidean_distance, label):
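        # Computes mean(label * d^2 + (1 - label) * max(margin - d, 0)^2) / 2,
        # where d is the precomputed euclidean distance and label == 1 marks a matching pair.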
15 | loss_contrastive = torch.mean(label * torch.pow(euclidean_distance, 2) +
16 | (1 - label) * torch.pow(torch.relu(self.margin - euclidean_distance), 2)) / 2
17 |
18 | return loss_contrastive
19 |
20 |
--------------------------------------------------------------------------------
/ckbqa/layers/modules.py:
--------------------------------------------------------------------------------
1 | from torch import nn
2 |
3 | import torch
4 |
5 |
6 | class TextCNN(nn.Module):
7 |     def __init__(self, embed_dim, out_dim):
8 |         super(TextCNN, self).__init__()
9 |         input_channel = 1
10 |         kernel_num = 1
11 |         filter_sizes = [2, 3, 4]
12 |         self.convs = nn.ModuleList([nn.Conv2d(input_channel, kernel_num, (kernel_size, embed_dim))
13 |                                     for kernel_size in filter_sizes])
14 |         # self.drop_out = nn.Dropout(0.8)
15 |         self.fc = nn.Linear(len(filter_sizes) * kernel_num, out_dim)
16 |
17 |     def forward(self, x):
18 |         # x: (batch, 1, seq_len, embed_dim); each conv branch sees the same input
19 |         xs = []
20 |         for conv in self.convs:
21 |             h = torch.relu(conv(x)).squeeze(3)  # (batch, kernel_num, seq_len - k + 1)
22 |             h = torch.max(h, dim=2)[0]  # max-pool over time -> (batch, kernel_num)
23 |             xs.append(h)
24 |         out = torch.cat(xs, dim=1)  # (batch, len(filter_sizes) * kernel_num)
25 |         return self.fc(out)
25 |
--------------------------------------------------------------------------------
/ckbqa/models/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/WangShengguang/ccks-2020/d77f2de91284efb11e314fea21b4a3982ab78554/ckbqa/models/__init__.py
--------------------------------------------------------------------------------
/ckbqa/models/base_trainer.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from torch.optim import Adam
4 | from torch.optim.lr_scheduler import LambdaLR
5 |
6 | from config import Config
7 |
8 |
9 | class BaseTrainer(object):
10 | def __init__(self, model_name):
11 | self.global_step = 0
12 | self.device = Config.device
13 | self.model_name = model_name
14 |
15 | def backfoward(self, loss, model):
16 | if Config.gpu_nums > 1 and Config.multi_gpu:
17 | loss = loss.mean() # mean() to average on multi-gpu
18 | if Config.gradient_accumulation_steps > 1:
19 | loss = loss / Config.gradient_accumulation_steps
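            # e.g. with gradient_accumulation_steps = 4, four scaled mini-batch losses are accumulated
            # and the optimizer only steps every 4th call, giving a 4x larger effective batch.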
20 | # https://zhuanlan.zhihu.com/p/79887894
21 | loss.backward()
22 | self.global_step += 1
23 | if self.global_step % Config.gradient_accumulation_steps == 0:
24 | nn.utils.clip_grad_norm_(parameters=model.parameters(), max_norm=Config.clip_grad)
25 | self.optimizer.step()
26 | self.optimizer.zero_grad()
27 | return loss
28 |
29 | def init_model(self, model):
30 | model.to(self.device) # without this there is no error, but it runs in CPU (instead of GPU).
31 | if Config.gpu_nums > 1 and Config.multi_gpu:
32 | model = torch.nn.DataParallel(model)
33 | self.optimizer = Adam(model.parameters(), lr=Config.learning_rate)
34 | self.scheduler = LambdaLR(self.optimizer, lr_lambda=lambda epoch: 1 / (1 + 0.05 * epoch))
35 | return model
36 |
--------------------------------------------------------------------------------
/ckbqa/models/data_helper.py:
--------------------------------------------------------------------------------
1 | from typing import List
2 |
3 | import torch
4 | from pytorch_transformers import BertTokenizer
5 |
6 | from ckbqa.utils.decorators import singleton
7 | from ckbqa.utils.sequence import pad_sequences
8 | from config import Config
9 |
10 |
11 | @singleton
12 | class DataHelper(object):
13 | def __init__(self, load_tokenizer=True):
14 | self.device = Config.device
15 | if load_tokenizer:
16 | self.load_tokenizer()
17 |
18 | def load_tokenizer(self):
19 | self.tokenizer = BertTokenizer.from_pretrained(
20 | Config.pretrained_model_name_or_path, do_lower_case=True)
21 |
22 | def batch_sent2tensor(self, sent_texts: List[str], pad=False):
23 | batch_token_ids = [self.sent2ids(sent_text) for sent_text in sent_texts]
24 | if pad:
25 | batch_token_ids = pad_sequences(batch_token_ids, maxlen=Config.max_len, padding='post')
26 | batch_token_ids = torch.tensor(batch_token_ids, dtype=torch.long).to(self.device)
27 | return batch_token_ids
28 |
29 | def sent2ids(self, sent_text):
30 | sent_tokens = ['[CLS]'] + self.tokenizer.tokenize(sent_text) + ["[SEP]"]
31 | token_ids = self.tokenizer.convert_tokens_to_ids(sent_tokens)
32 | return token_ids
33 |
34 | def data2tensor(self, batch_token_ids, pad=True, maxlen=Config.max_len):
35 | if pad:
36 | batch_token_ids = pad_sequences(batch_token_ids, maxlen=maxlen, padding='post')
37 | batch_token_ids = torch.tensor(batch_token_ids, dtype=torch.long).to(self.device)
38 | return batch_token_ids
39 |
--------------------------------------------------------------------------------
/ckbqa/models/entity_score/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/WangShengguang/ccks-2020/d77f2de91284efb11e314fea21b4a3982ab78554/ckbqa/models/entity_score/__init__.py
--------------------------------------------------------------------------------
/ckbqa/models/entity_score/model.py:
--------------------------------------------------------------------------------
1 | '''
2 | Pick the top-n candidate entities out of all candidates
3 | '''
4 |
5 | import os
6 |
7 | import numpy as np
8 | from sklearn import preprocessing
9 | from sklearn.linear_model import LogisticRegression
10 |
11 | from ckbqa.dataset.data_prepare import question_patten, entity_pattern, attr_pattern, load_data
12 | from ckbqa.utils.tools import pkl_dump, pkl_load
13 | from config import Config
14 |
15 |
16 | class EntityScore(object):
17 | """
18 |     Score and rank the candidate subject entities.
19 | """
20 |
21 | def __init__(self, load_pretrain_model=False):
22 | if load_pretrain_model:
23 | self.model: LogisticRegression = pkl_load(Config.entity_score_model_pkl)
24 |
25 | def gen_train_data(self):
26 | X_train = []
27 | Y_label = []
28 |         from ckbqa.qa.el import EL  # imported here to avoid a circular import
29 | el = EL()
30 | for q, sparql, a in load_data(tqdm_prefix='EntityScore train data '):
31 | # a_entities = entity_pattern.findall(a)
32 | q_entities = set(entity_pattern.findall(sparql) + attr_pattern.findall(sparql)) # attr
33 | q_text = question_patten.findall(q)[0]
34 | candidate_entities = el.el(q_text)
35 | for ent_name, feature_dict in candidate_entities.items():
36 | feature = feature_dict['feature']
37 |                 label = 1 if ent_name in q_entities else 0  # some candidate entities are not part of the gold query
38 | X_train.append(feature)
39 | Y_label.append(label)
40 | pkl_dump({'x_data': X_train, 'y_label': Y_label}, Config.entity_score_data_pkl)
41 |
42 | def train(self):
43 | if not os.path.isfile(Config.entity_score_data_pkl):
44 | self.gen_train_data()
45 | data = pkl_load(Config.entity_score_data_pkl)
46 | X_train = preprocessing.scale(np.array(data['x_data']))
47 | # Y_label = np.eye(2)[data['y_label']] # class_num==2
48 | Y_label = np.array(data['y_label']) # class_num==2
49 | print(f"X_train : {X_train.shape}, Y_label: {Y_label.shape}")
50 | print(f"sum(Y_label): {sum(Y_label)}")
51 | model = LogisticRegression(C=1e5, verbose=1)
52 | model.fit(X_train, Y_label)
53 | # model.predict()
54 | accuracy_score = model.score(X_train, Y_label)
55 | print(f"accuracy_score: {accuracy_score:.4f}")
56 | pkl_dump(model, Config.entity_score_model_pkl)
57 |
58 | def predict(self, features):
59 | preds = self.model.predict(np.asarray(features))
60 | return preds
61 |
--------------------------------------------------------------------------------
/ckbqa/models/evaluation_matrics.py:
--------------------------------------------------------------------------------
1 | def get_metrics(real_entities, pred_entities):
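    # Illustrative example: real = {'<A>', '<B>'}, pred = {'<A>', '<C>'} -> TP = 1, precision = 0.5,
    # recall = 1 (by this scheme recall is 1 as soon as any gold entity is recovered), f1 = 2/3.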
2 | if len(pred_entities) == 0:
3 | return 0, 0, 0
4 | real_entities_set = set(real_entities)
5 | pred_entities_set = set(pred_entities)
6 |     TP = len(real_entities_set & pred_entities_set)  # true positives
7 |     # FP = len(pred_entities_set - real_entities_set)  # predicted but not in the gold set
8 |     # FN = len(real_entities_set - pred_entities_set)  # gold entities that were not predicted
9 |     # TP + FP = len(pred_entities_set)
10 |     # TP + FN = len(real_entities_set)
11 | precision = TP / len(pred_entities)
12 | if TP:
13 | recall = 1
14 | else:
15 | recall = TP / len(real_entities_set)
16 | if (precision + recall):
17 | f1 = 2 * precision * recall / (precision + recall)
18 | else:
19 | f1 = 0
20 | return precision, recall, f1
21 |
--------------------------------------------------------------------------------
/ckbqa/models/ner/__init__.py:
--------------------------------------------------------------------------------
1 | """主实体识别"""
2 |
--------------------------------------------------------------------------------
/ckbqa/models/ner/crf.py:
--------------------------------------------------------------------------------
1 | """
2 | CRF Module
3 | """
4 |
5 | from typing import List, Optional
6 |
7 | import torch
8 | import torch.nn as nn
9 |
10 |
11 | class CRF(nn.Module):
12 | """Conditional random field.
13 |
14 | This module implements a conditional random field [LMP01]_. The forward computation
15 | of this class computes the log likelihood of the given sequence of tags and
16 | emission score tensor. This class also has `~CRF.decode` method which finds
17 | the best tag sequence given an emission score tensor using `Viterbi algorithm`_.
18 |
19 | Args:
20 | num_tags: Number of tags.
21 | batch_first: Whether the first dimension corresponds to the size of a minibatch.
22 |
23 | Attributes:
24 | start_transitions (`~torch.nn.Parameter`): Start transition score tensor of size
25 | ``(num_tags,)``.
26 | end_transitions (`~torch.nn.Parameter`): End transition score tensor of size
27 | ``(num_tags,)``.
28 | transitions (`~torch.nn.Parameter`): Transition score tensor of size
29 | ``(num_tags, num_tags)``.
30 |
31 |
32 | .. [LMP01] Lafferty, J., McCallum, A., Pereira, F. (2001).
33 | "Conditional random fields: Probabilistic models for segmenting and
34 | labeling sequence data". *Proc. 18th International Conf. on Machine
35 | Learning*. Morgan Kaufmann. pp. 282–289.
36 |
37 | .. _Viterbi algorithm: https://en.wikipedia.org/wiki/Viterbi_algorithm
38 | """
39 |
40 | def __init__(self, num_tags: int, batch_first: bool = False) -> None:
41 | if num_tags <= 0:
42 | raise ValueError(f'invalid number of tags: {num_tags}')
43 | super().__init__()
44 | self.num_tags = num_tags
45 | self.batch_first = batch_first
46 | self.start_transitions = nn.Parameter(torch.empty(num_tags))
47 | self.end_transitions = nn.Parameter(torch.empty(num_tags))
48 | self.transitions = nn.Parameter(torch.empty(num_tags, num_tags))
49 |
50 | self.reset_parameters()
51 |
52 | def reset_parameters(self) -> None:
53 | """Initialize the transition parameters.
54 |
55 | The parameters will be initialized randomly from a uniform distribution
56 | between -0.1 and 0.1.
57 | """
58 | nn.init.uniform_(self.start_transitions, -0.1, 0.1)
59 | nn.init.uniform_(self.end_transitions, -0.1, 0.1)
60 | nn.init.uniform_(self.transitions, -0.1, 0.1)
61 |
62 | def __repr__(self) -> str:
63 | return f'{self.__class__.__name__}(num_tags={self.num_tags})'
64 |
65 | def forward(
66 | self,
67 | emissions: torch.Tensor,
68 | tags: torch.LongTensor,
69 | mask: Optional[torch.ByteTensor] = None,
70 | reduction: str = 'sum',
71 | ) -> torch.Tensor:
72 | """Compute the conditional log likelihood of a sequence of tags given emission scores.
73 |
74 | Args:
75 | emissions (`~torch.Tensor`): Emission score tensor of size
76 | ``(seq_length, batch_size, num_tags)`` if ``batch_first`` is ``False``,
77 | ``(batch_size, seq_length, num_tags)`` otherwise.
78 | tags (`~torch.LongTensor`): Sequence of tags tensor of size
79 | ``(seq_length, batch_size)`` if ``batch_first`` is ``False``,
80 | ``(batch_size, seq_length)`` otherwise.
81 | mask (`~torch.ByteTensor`): Mask tensor of size ``(seq_length, batch_size)``
82 | if ``batch_first`` is ``False``, ``(batch_size, seq_length)`` otherwise.
83 | reduction: Specifies the reduction to apply to the output:
84 | ``none|sum|mean|token_mean``. ``none``: no reduction will be applied.
85 | ``sum``: the output will be summed over batches. ``mean``: the output will be
86 | averaged over batches. ``token_mean``: the output will be averaged over tokens.
87 |
88 | Returns:
89 | `~torch.Tensor`: The log likelihood. This will have size ``(batch_size,)`` if
90 | reduction is ``none``, ``()`` otherwise.
91 | """
92 | self._validate(emissions, tags=tags, mask=mask)
93 | if reduction not in ('none', 'sum', 'mean', 'token_mean'):
94 | raise ValueError(f'invalid reduction: {reduction}')
95 | if mask is None:
96 | mask = torch.ones_like(tags, dtype=torch.uint8)
97 |
98 | if self.batch_first:
99 | emissions = emissions.transpose(0, 1)
100 | tags = tags.transpose(0, 1)
101 | mask = mask.transpose(0, 1)
102 |
103 | # shape: (batch_size,)
104 | numerator = self._compute_score(emissions, tags, mask)
105 | # shape: (batch_size,)
106 | denominator = self._compute_normalizer(emissions, mask)
107 | # shape: (batch_size,)
108 | llh = numerator - denominator
109 |
110 | if reduction == 'none':
111 | return llh
112 | if reduction == 'sum':
113 | return llh.sum()
114 | if reduction == 'mean':
115 | return llh.mean()
116 | assert reduction == 'token_mean'
117 | # return llh.sum() / mask.half().sum()
118 | return llh.sum() / mask.float().sum()
119 |
120 | def decode(self, emissions: torch.Tensor,
121 | mask: Optional[torch.ByteTensor] = None) -> List[List[int]]:
122 | """Find the most likely tag sequence using Viterbi algorithm.
123 |
124 | Args:
125 | emissions (`~torch.Tensor`): Emission score tensor of size
126 | ``(seq_length, batch_size, num_tags)`` if ``batch_first`` is ``False``,
127 | ``(batch_size, seq_length, num_tags)`` otherwise.
128 | mask (`~torch.ByteTensor`): Mask tensor of size ``(seq_length, batch_size)``
129 | if ``batch_first`` is ``False``, ``(batch_size, seq_length)`` otherwise.
130 |
131 | Returns:
132 | List of list containing the best tag sequence for each batch.
133 | """
134 | self._validate(emissions, mask=mask)
135 | if mask is None:
136 | mask = emissions.new_ones(emissions.shape[:2], dtype=torch.uint8)
137 |
138 | if self.batch_first:
139 | emissions = emissions.transpose(0, 1)
140 | mask = mask.transpose(0, 1)
141 |
142 | return self._viterbi_decode(emissions, mask)
143 |
144 | def _validate(
145 | self,
146 | emissions: torch.Tensor,
147 | tags: Optional[torch.LongTensor] = None,
148 | mask: Optional[torch.ByteTensor] = None) -> None:
149 | if emissions.dim() != 3:
150 | raise ValueError(f'emissions must have dimension of 3, got {emissions.dim()}')
151 | if emissions.size(2) != self.num_tags:
152 | raise ValueError(
153 | f'expected last dimension of emissions is {self.num_tags}, '
154 | f'got {emissions.size(2)}')
155 |
156 | if tags is not None:
157 | if emissions.shape[:2] != tags.shape:
158 | raise ValueError(
159 | 'the first two dimensions of emissions and tags must match, '
160 | f'got {tuple(emissions.shape[:2])} and {tuple(tags.shape)}')
161 |
162 | if mask is not None:
163 | if emissions.shape[:2] != mask.shape:
164 | raise ValueError(
165 | 'the first two dimensions of emissions and mask must match, '
166 | f'got {tuple(emissions.shape[:2])} and {tuple(mask.shape)}')
167 | no_empty_seq = not self.batch_first and mask[0].all()
168 | no_empty_seq_bf = self.batch_first and mask[:, 0].all()
169 | if not no_empty_seq and not no_empty_seq_bf:
170 | raise ValueError('mask of the first timestep must all be on')
171 |
172 | def _compute_score(
173 | self, emissions: torch.Tensor, tags: torch.LongTensor,
174 | mask: torch.ByteTensor) -> torch.Tensor:
175 | # emissions: (seq_length, batch_size, num_tags)
176 | # tags: (seq_length, batch_size)
177 | # mask: (seq_length, batch_size)
178 | assert emissions.dim() == 3 and tags.dim() == 2
179 | assert emissions.shape[:2] == tags.shape
180 | assert emissions.size(2) == self.num_tags
181 | assert mask.shape == tags.shape
182 | assert mask[0].all()
183 |
184 | seq_length, batch_size = tags.shape
185 | # mask = mask.half()
186 | mask = mask.float()
187 |
188 | # Start transition score and first emission
189 | # shape: (batch_size,)
190 | score = self.start_transitions[tags[0]]
191 | score += emissions[0, torch.arange(batch_size), tags[0]]
192 |
193 | for i in range(1, seq_length):
194 | # Transition score to next tag, only added if next timestep is valid (mask == 1)
195 | # shape: (batch_size,)
196 | score += self.transitions[tags[i - 1], tags[i]] * mask[i]
197 | # Emission score for next tag, only added if next timestep is valid (mask == 1)
198 | # shape: (batch_size,)
199 | score += emissions[i, torch.arange(batch_size), tags[i]] * mask[i]
200 |
201 | # End transition score
202 | # shape: (batch_size,)
203 | seq_ends = mask.long().sum(dim=0) - 1
204 | # shape: (batch_size,)
205 | last_tags = tags[seq_ends, torch.arange(batch_size)]
206 | # shape: (batch_size,)
207 | score += self.end_transitions[last_tags]
208 |
209 | return score
210 |
211 | def _compute_normalizer(
212 | self, emissions: torch.Tensor, mask: torch.ByteTensor) -> torch.Tensor:
213 | # emissions: (seq_length, batch_size, num_tags)
214 | # mask: (seq_length, batch_size)
215 | assert emissions.dim() == 3 and mask.dim() == 2
216 | assert emissions.shape[:2] == mask.shape
217 | assert emissions.size(2) == self.num_tags
218 | assert mask[0].all()
219 |
220 | seq_length = emissions.size(0)
221 |
222 | # Start transition score and first emission; score has size of
223 | # (batch_size, num_tags) where for each batch, the j-th column stores
224 | # the score that the first timestep has tag j
225 | # shape: (batch_size, num_tags)
226 | score = self.start_transitions + emissions[0]
227 |
228 | for i in range(1, seq_length):
229 | # Broadcast score for every possible next tag
230 | # shape: (batch_size, num_tags, 1)
231 | broadcast_score = score.unsqueeze(2)
232 |
233 | # Broadcast emission score for every possible current tag
234 | # shape: (batch_size, 1, num_tags)
235 | broadcast_emissions = emissions[i].unsqueeze(1)
236 |
237 | # Compute the score tensor of size (batch_size, num_tags, num_tags) where
238 | # for each sample, entry at row i and column j stores the sum of scores of all
239 | # possible tag sequences so far that end with transitioning from tag i to tag j
240 | # and emitting
241 | # shape: (batch_size, num_tags, num_tags)
242 | next_score = broadcast_score + self.transitions + broadcast_emissions
243 |
244 | # Sum over all possible current tags, but we're in score space, so a sum
245 | # becomes a log-sum-exp: for each sample, entry i stores the sum of scores of
246 | # all possible tag sequences so far, that end in tag i
247 | # shape: (batch_size, num_tags)
248 | next_score = torch.logsumexp(next_score, dim=1)
249 |
250 | # Set score to the next score if this timestep is valid (mask == 1)
251 | # shape: (batch_size, num_tags)
252 | score = torch.where(mask[i].unsqueeze(1), next_score, score)
253 |
254 | # End transition score
255 | # shape: (batch_size, num_tags)
256 | score += self.end_transitions
257 |
258 | # Sum (log-sum-exp) over all possible tags
259 | # shape: (batch_size,)
260 | return torch.logsumexp(score, dim=1)
261 |
262 | def _viterbi_decode(self, emissions: torch.Tensor,
263 |                         mask: torch.ByteTensor) -> torch.LongTensor:
264 | # emissions: (seq_length, batch_size, num_tags)
265 | # mask: (seq_length, batch_size)
266 | assert emissions.dim() == 3 and mask.dim() == 2
267 | assert emissions.shape[:2] == mask.shape
268 | assert emissions.size(2) == self.num_tags
269 | assert mask[0].all()
270 |
271 | seq_length, batch_size = mask.shape
272 |
273 | # Start transition and first emission
274 | # shape: (batch_size, num_tags)
275 | score = self.start_transitions + emissions[0]
276 |
277 | # shape: (seq_length-1, batch_size, num_tags)
278 | history = torch.empty([seq_length - 1, batch_size, self.num_tags], dtype=torch.long,
279 |                               device=emissions.device)  # follow the input tensor's device instead of hard-coding CUDA
280 |
281 | # score is a tensor of size (batch_size, num_tags) where for every batch,
282 | # value at column j stores the score of the best tag sequence so far that ends
283 | # with tag j
284 | # history saves where the best tags candidate transitioned from; this is used
285 | # when we trace back the best tag sequence
286 |
287 | # Viterbi algorithm recursive case: we compute the score of the best tag sequence
288 | # for every possible next tag
289 | for i in range(1, seq_length):
290 | # Broadcast viterbi score for every possible next tag
291 | # shape: (batch_size, num_tags, 1)
292 | broadcast_score = score.unsqueeze(2)
293 |
294 | # Broadcast emission score for every possible current tag
295 | # shape: (batch_size, 1, num_tags)
296 | broadcast_emission = emissions[i].unsqueeze(1)
297 |
298 | # Compute the score tensor of size (batch_size, num_tags, num_tags) where
299 | # for each sample, entry at row i and column j stores the score of the best
300 | # tag sequence so far that ends with transitioning from tag i to tag j and emitting
301 | # shape: (batch_size, num_tags, num_tags)
302 | next_score = broadcast_score + self.transitions + broadcast_emission
303 |
304 | # Find the maximum score over all possible current tag
305 | # shape: (batch_size, num_tags), (batch_size, num_tags)
306 | next_score, indices = next_score.max(dim=1)
307 |
308 | # Set score to the next score if this timestep is valid (mask == 1)
309 | # and save the index that produces the next score
310 | # shape: (batch_size, num_tags)
311 | score = torch.where(mask[i].unsqueeze(1), next_score, score)
312 | # history.append(indices)
313 | history[i - 1] = indices
314 | # End transition score
315 | # shape: (batch_size, num_tags)
316 | score += self.end_transitions
317 |
318 | # Now, compute the best path for each sample
319 |
320 | # shape: (batch_size,)
321 | seq_ends = mask.long().sum(dim=0) - 1
322 | # best_tags_list = []
323 |         best_tags_all = torch.empty(batch_size, seq_length, dtype=torch.long, device=emissions.device)
324 |
325 |         best_tags = torch.empty(seq_length, dtype=torch.long, device=emissions.device)
326 | for idx in range(batch_size):
327 | # Find the tag which maximizes the score at the last timestep; this is our best tag
328 | # for the last timestep
329 | _, best_last_tag = score[idx].max(dim=0)
330 | best_tags[0] = best_last_tag
331 |
332 | # We trace back where the best last tag comes from, append that to our best tag
333 | # sequence, and trace it back again, and so on
334 | for i, hist in enumerate(history[:seq_ends[idx]].flip([0])):
335 | best_last_tag = hist[idx][best_tags[i]]
336 | best_tags[i + 1] = best_last_tag
337 |
338 | # Reverse the order because we start from the last timestep
339 |             best_tags[:seq_ends[idx] + 1] = best_tags[:seq_ends[idx] + 1].flip([0])  # reverse only the valid prefix; padded positions are ignored
340 | best_tags_all[idx] = best_tags
341 |
342 | return best_tags_all
343 |
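
A minimal usage sketch of this CRF layer (the `num_tags` and `batch_first` constructor arguments are inferred from `__repr__` above and from the `BERTCRF` model below; tensor sizes are illustrative):

import torch

num_tags, batch_size, seq_len = 5, 2, 7
crf = CRF(num_tags, batch_first=True)

emissions = torch.randn(batch_size, seq_len, num_tags)      # per-token tag scores from an encoder
tags = torch.randint(0, num_tags, (batch_size, seq_len))    # gold tag ids
mask = torch.ones(batch_size, seq_len, dtype=torch.uint8)   # 1 = real token, 0 = padding (bool also works on newer PyTorch)

loss = -crf(emissions, tags, mask, reduction='token_mean')  # negative log likelihood for training
best_tags = crf.decode(emissions, mask)                      # (batch_size, seq_len) Viterbi decoding
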
--------------------------------------------------------------------------------
/ckbqa/models/ner/model.py:
--------------------------------------------------------------------------------
1 | """BERT + CRF for named entity recognition"""
2 |
3 | import torch.nn as nn
4 | from pytorch_transformers import BertPreTrainedModel, BertModel
5 |
6 | from .crf import CRF
7 |
8 |
9 | class BERTCRF(BertPreTrainedModel):
10 | """BERT model for token-level classification.
11 | This module is composed of the BERT model with a linear layer on top of
12 | the full hidden state of the last layer.
13 |
14 | Args:
15 | config: a BertConfig class instance with the configuration to build a new model.
16 | num_labels: the number of classes for the classifier. Default = 2.
17 |
18 | Inputs:
19 | input_ids: a torch.LongTensor of shape (batch_size, seq_length)
20 | with the word token indices in the vocabulary.
21 | token_type_ids: an optional torch.LongTensor of shape (batch_size, seq_length) with the token
22 | types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to
23 | a `sentence B` token.
24 | attention_mask: an optional torch.LongTensor of shape (batch_size, seq_length) with indices
25 | selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max
26 | input sequence length in the current batch. It's the mask that we typically use for attention when
27 | a batch has varying length sentences.
28 | labels: labels for the classification output: torch.LongTensor of shape (batch_size, seq_length)
29 |             with indices selected in [0, ..., num_labels - 1].
30 |
31 | Returns:
32 |         if labels is not None:
33 |             Outputs a tuple of the decoded tag indices and the negative CRF log-likelihood loss.
34 |         if labels is None:
35 |             Outputs the decoded tag indices of shape (batch_size, seq_length).
36 | """
37 |
38 | def __init__(self, config, num_labels):
39 | super(BERTCRF, self).__init__(config)
40 | self.num_labels = num_labels # ent label num
41 | self.bert = BertModel(config)
42 | self.dropout = nn.Dropout(config.hidden_dropout_prob)
43 | self.fc = nn.Linear(config.hidden_size, num_labels)
44 | self.crf = CRF(num_labels, batch_first=True)
45 | self.apply(self.init_bert_weights)
46 | # self.fc.weight.data.normal_(mean=0.0, std=0.02)
47 | # self.fc.bias.data.zero_()
48 |
49 | def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None):
50 | # shape: (batch_size, seq_length, hidden_size)
51 | outputs = self.bert(input_ids, token_type_ids, attention_mask)
52 | seq_output = outputs[0]
53 | # shape: (batch_size, seq_length, hidden_size)
54 | seq_output = self.dropout(seq_output)
55 | # shape: (batch_size, seq_length, num_labels)
56 | fc_output = self.fc(seq_output)
57 |
58 | if attention_mask is None:
59 | pred = self.crf.decode(fc_output)
60 | else:
61 | pred = self.crf.decode(fc_output, attention_mask)
62 |
63 | if labels is None:
64 | return pred
65 | else:
66 | if attention_mask is not None:
67 | # shape: (1,)
68 | loss = -self.crf(fc_output, labels, attention_mask, reduction='token_mean')
69 | else:
70 | loss = -self.crf(fc_output, labels, reduction='token_mean')
71 | return pred, loss
72 |
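
A shape-level sketch of calling BERTCRF (assuming a pytorch_transformers version that provides `BertConfig.from_pretrained` and the `init_bert_weights` hook used in `__init__`; the checkpoint name and sizes are illustrative, and the BERT weights here are randomly initialized rather than loaded):

import torch
from pytorch_transformers import BertConfig

num_labels = 7                                              # e.g. a small BIO tag set
config = BertConfig.from_pretrained('bert-base-chinese')    # assumed checkpoint name
model = BERTCRF(config, num_labels)

input_ids = torch.randint(1, config.vocab_size, (2, 16))    # (batch_size, seq_length)
attention_mask = torch.ones(2, 16, dtype=torch.uint8)
labels = torch.randint(0, num_labels, (2, 16))

pred, loss = model(input_ids, attention_mask=attention_mask, labels=labels)  # training
pred = model(input_ids, attention_mask=attention_mask)                        # inference: tags only
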
--------------------------------------------------------------------------------
/ckbqa/models/relation_score/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/WangShengguang/ccks-2020/d77f2de91284efb11e314fea21b4a3982ab78554/ckbqa/models/relation_score/__init__.py
--------------------------------------------------------------------------------
/ckbqa/models/relation_score/model.py:
--------------------------------------------------------------------------------
1 | """
2 | Score the relations attached to the main entity and keep the top-n candidates.
3 | """
4 | import torch
5 | from pytorch_transformers import BertModel
6 | from torch import nn
7 |
8 | from config import Config
9 |
10 |
11 | class BertMatch(nn.Module):
12 | def __init__(self):
13 | super().__init__()
14 | self.bert = BertModel.from_pretrained(Config.pretrained_model_name_or_path)
15 | self.soft_max = nn.Softmax(dim=-1)
16 | self.fc = nn.Linear(768 * 3, 2)
17 | self.dropout = nn.Dropout(0.8)
18 | self.loss = nn.CrossEntropyLoss()
19 |
20 | def encode(self, x):
21 | outputs = self.bert(x)
22 | pooled_out = outputs[1]
23 | return pooled_out
24 |
25 | def forward(self, input1_ids, input2_ids, labels=None):
26 |         # shape: (batch_size, hidden_size) -- pooled [CLS] outputs from encode()
27 | seq1_output = self.encode(input1_ids)
28 | seq2_output = self.encode(input2_ids)
29 | feature = torch.cat([seq1_output, seq2_output, seq1_output - seq2_output], dim=-1)
30 | logistic = self.fc(feature)
31 | if labels is None:
32 | output = self.soft_max(logistic)
33 | # pred = torch.argmax(output, dim=-1)
34 | return output
35 | else:
36 | loss = self.loss(logistic, labels)
37 | return logistic, loss
38 |
39 |
40 | class BertMatch2(nn.Module):
41 | def __init__(self):
42 | super().__init__()
43 | self.bert = BertModel.from_pretrained(Config.pretrained_model_name_or_path)
44 | self.soft_max = nn.Softmax(dim=-1)
45 | self.fc1 = nn.Linear(768 * 3, 128)
46 | self.fc2 = nn.Linear(128, 2)
47 | self.dropout1 = nn.Dropout(0.8)
48 | self.dropout2 = nn.Dropout(0.8)
49 | self.loss = nn.CrossEntropyLoss()
50 |
51 | def encode(self, x):
52 | outputs = self.bert(x)
53 | pooled_out = outputs[1]
54 | return pooled_out
55 |
56 | def forward(self, input1_ids, input2_ids, labels=None):
57 |         # shape: (batch_size, hidden_size) -- pooled [CLS] outputs from encode()
58 | seq1_output = self.encode(input1_ids)
59 | seq2_output = self.encode(input2_ids)
60 | feature = torch.cat([seq1_output, seq2_output, seq1_output - seq2_output], dim=-1)
61 | feature1 = self.fc1(feature)
62 | logistic = self.fc2(feature1)
63 | if labels is None:
64 | output = self.soft_max(logistic)
65 | # pred = torch.argmax(output, dim=-1)
66 | return output
67 | else:
68 | loss = self.loss(logistic, labels)
69 | return logistic, loss
70 |
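
A minimal sketch of how these matchers are called (token ids would normally come from `DataHelper.sent2ids` / `data2tensor`; the ids below are random placeholders):

import torch

model = BertMatch()                               # BertMatch2 has the same interface
q_ids = torch.randint(1, 1000, (4, 32))           # (batch_size, seq_len) question token ids
cand_ids = torch.randint(1, 1000, (4, 32))        # candidate path rendered as text, e.g. "实体的关系"

probs = model(q_ids, cand_ids)                    # inference: softmax over (no-match, match)
match_scores = probs[:, 1]                        # probability of the "match" class, as used in the predictor

labels = torch.ones(4, dtype=torch.long)          # training: 1 = matching pair, 0 = negative sample
logits, loss = model(q_ids, cand_ids, labels)
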
--------------------------------------------------------------------------------
/ckbqa/models/relation_score/predictor.py:
--------------------------------------------------------------------------------
1 | """现有模型封装;提供给预测使用"""
2 | import logging
3 | from typing import List
4 |
5 | from ckbqa.models.data_helper import DataHelper
6 | from ckbqa.models.relation_score.model import BertMatch, BertMatch2
7 | from ckbqa.utils.saver import Saver
8 |
9 |
10 | class RelationScorePredictor(object):
11 | def __init__(self, model_name):
12 |         logging.info('RelationScorePredictor: loading sim_model ...')
13 | self.data_helper = DataHelper(load_tokenizer=True)
14 | self.model = self.load_sim_model(model_name)
15 | logging.info(f'loaded sim_model: {model_name}')
16 |
17 | def load_sim_model(self, model_name: str):
18 |         """Load the text similarity (relation scoring) model."""
19 | assert model_name in ['bert_match', 'bert_match2']
20 | saver = Saver(model_name=model_name)
21 | if model_name == 'bert_match':
22 | model = BertMatch()
23 | elif model_name == 'bert_match2':
24 | model = BertMatch2()
25 | else:
26 | raise ValueError()
27 | saver.load_model(model)
28 | return model
29 |
30 | def iter_sample(self, q_text, sim_texts):
31 | batch_size = 32
32 | q_ids = self.data_helper.sent2ids(q_text)
33 | batch_q_sent_token_ids = [q_ids for _ in range(min(batch_size, len(sim_texts)))]
34 | q_tensors = self.data_helper.data2tensor(batch_q_sent_token_ids)
35 | batch_sim_texts = [sim_texts[i:i + batch_size] for i in range(0, len(sim_texts), batch_size)]
36 | for batch_sim_text in batch_sim_texts:
37 | sim_sent_token_ids = [self.data_helper.sent2ids(sim_text) for sim_text in batch_sim_text]
38 | sim_tensors = self.data_helper.data2tensor(sim_sent_token_ids)
39 | yield q_tensors[:len(sim_tensors)], sim_tensors
40 |
41 | def predict(self, q_text: str, sim_texts: List[str]):
42 | all_preds = []
43 | for q_tensors, sim_tensors in self.iter_sample(q_text, sim_texts):
44 | preds = self.model(q_tensors, sim_tensors)
45 | preds = preds[:, 1]
46 | all_preds.extend(preds.tolist())
47 | return all_preds
48 |
--------------------------------------------------------------------------------
/ckbqa/models/relation_score/trainer.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import os
3 | import random
4 | import re
5 |
6 | import numpy as np
7 | import pandas as pd
8 | from sklearn.model_selection import train_test_split
9 |
10 | from ckbqa.dataset.data_prepare import load_data, question_patten
11 | from ckbqa.models.base_trainer import BaseTrainer
12 | from ckbqa.models.data_helper import DataHelper
13 | from ckbqa.models.relation_score.model import BertMatch, BertMatch2
14 | from ckbqa.utils.saver import Saver
15 | from ckbqa.utils.tools import json_load
16 | from config import Config
17 |
18 |
19 | class RelationScoreTrainer(BaseTrainer):
20 |     """Trainer for the relation scoring module."""
21 |
22 | def __init__(self, model_name):
23 | super().__init__(model_name)
24 | self.model_name = model_name
25 | self.data_helper = DataHelper()
26 | self.saver = Saver(model_name=model_name)
27 |
28 | def data2samples(self, neg_rate=3, test_size=0.1):
29 | if os.path.isfile(Config.get_relation_score_sample_csv('train', neg_rate)):
30 | return
31 | questions = []
32 | sim_questions = []
33 | labels = []
34 | all_relations = list(json_load(Config.relation2id))
35 | _entity_pattern = re.compile(r'["<](.*?)[>"]')
36 | for q, sparql, a in load_data(tqdm_prefix='data2samples '):
37 | q_text = question_patten.findall(q)[0]
38 | q_entities = _entity_pattern.findall(sparql)
39 | questions.append(q_text)
40 | sim_questions.append('的'.join(q_entities))
41 | labels.append(1)
42 | #
43 | for neg_relation in random.sample(all_relations, neg_rate):
44 | questions.append(q_text)
45 | neg_question = '的'.join(q_entities[:-1] + [neg_relation]) # 随机替换 <关系>
46 | sim_questions.append(neg_question)
47 | labels.append(0)
48 | data_df = pd.DataFrame({'question': questions, 'sim_question': sim_questions, 'label': labels})
49 | data_df.to_csv(Config.relation_score_sample_csv, encoding='utf_8_sig', index=False)
50 | train_df, test_df = train_test_split(data_df, test_size=test_size)
51 | test_df.to_csv(Config.get_relation_score_sample_csv('test', neg_rate), encoding='utf_8_sig', index=False)
52 | train_df.to_csv(Config.get_relation_score_sample_csv('train', neg_rate), encoding='utf_8_sig', index=False)
53 |
54 | def batch_iter(self, data_type, batch_size, _shuffle=True, fixed_seq_len=None):
55 | file_path = Config.get_relation_score_sample_csv(data_type=data_type, neg_rate=3)
56 | logging.info(f'* load data from {file_path}')
57 | data_df = pd.read_csv(file_path)
58 | x_data = data_df['question'].apply(self.data_helper.sent2ids)
59 | y_data = data_df['sim_question'].apply(self.data_helper.sent2ids)
60 | labels = data_df['label']
61 | data_size = len(x_data)
62 | order = list(range(data_size))
63 | if _shuffle:
64 | np.random.shuffle(order)
65 | for batch_step in range(data_size // batch_size + 1):
66 | batch_idxs = order[batch_step * batch_size:(batch_step + 1) * batch_size]
67 | if len(batch_idxs) != batch_size: # batch size 不可过大; 不足batch_size的数据丢弃(最后一batch)
68 | continue
69 | q_sents = self.data_helper.data2tensor([x_data[idx] for idx in batch_idxs])
70 | a_sents = self.data_helper.data2tensor([y_data[idx] for idx in batch_idxs])
71 | batch_labels = self.data_helper.data2tensor([labels[idx] for idx in batch_idxs], pad=False)
72 | yield q_sents, a_sents, batch_labels
73 |
74 | def train_match_model(self, mode='train'):
75 | """
76 | :mode: train,test
77 | """
78 | self.data2samples(neg_rate=3, test_size=0.1)
79 |         if self.model_name == 'bert_match':  # single FC layer
80 |             model = BertMatch()
81 |         elif self.model_name == 'bert_match2':  # two FC layers
82 |             model = BertMatch2()
83 |         else:
84 |             raise ValueError()
85 |         for para in model.bert.parameters():
86 |             para.requires_grad = False  # freeze the BERT encoder; only the classification head is trained
87 | model = self.init_model(model)
88 | model_path, epoch, step = self.saver.load_model(model, fail_ok=True)
89 | self.global_step = step
90 | for q_sents, a_sents, batch_labels in self.batch_iter(
91 | data_type='train', batch_size=32):
92 | pred, loss = model(q_sents, a_sents, batch_labels)
93 | # print(f' {str(datetime.datetime.now())[:19]} global_step: {self.global_step}, loss:{loss.item():.04f}')
94 | logging.info(f' global_step: {self.global_step}, loss:{loss.item():.04f}')
95 | self.backfoward(loss, model)
96 | if self.global_step % 100 == 0:
97 | model_path = self.saver.save(model, epoch=1, step=self.global_step)
98 | logging.info(f'save to {model_path}')
99 | # if mode == 'test':
100 | # output = torch.softmax(pred, dim=-1)
101 | # pred = torch.argmax(output, dim=-1)
102 | # print('labels: ', batch_labels)
103 | # print('pred: ', pred.tolist())
104 | # print('--' * 10 + '\n\n')
105 |
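
For reference, the positive/negative pair construction in `data2samples` works roughly like this (toy values; `all_relations` stands in for the keys of `Config.relation2id`):

import random

q_text = '叶文洁毕业于哪个大学'
q_entities = ['叶文洁', '毕业院校']              # entities/relations parsed from the gold SPARQL
all_relations = ['出生地', '国籍', '导师']        # hypothetical stand-in for list(json_load(Config.relation2id))

positive = '的'.join(q_entities)                 # '叶文洁的毕业院校'  -> label 1
negatives = ['的'.join(q_entities[:-1] + [rel])  # swap the last relation -> label 0
             for rel in random.sample(all_relations, 2)]
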
--------------------------------------------------------------------------------
/ckbqa/qa/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | No model training happens in this package.
3 | 
4 | The separately trained models are combined here to extract the final answers.
5 | """
6 |
--------------------------------------------------------------------------------
/ckbqa/qa/algorithms.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import re
3 | from typing import Set
4 |
5 | legal_char_partten = re.compile(u"([^\u4e00-\u9fa5\u0030-\u0039\u0041-\u005a\u0061-\u007a])")  # matches any char that is not Chinese, a letter, or a digit (used to strip them)
6 |
7 |
8 | def sequences_set_similar(s1: Set, s2: Set):
9 | overlap = len(s1 & s2)
10 |     jaccard = overlap / len(s1 | s2)  # Jaccard similarity of the two sets
11 |     # word-embedding similarity (currently unused)
12 | # wordvecsim = model.similarity(''.join(s1),''.join(s2))
13 | return overlap, jaccard
14 |
15 |
16 | class Algorithms(object):
17 | def __init__(self):
18 | self.invalid_names = {'是什么', ''} # 汽车的动力是什么
19 |
20 | def get_most_overlap_path(self, q_text, paths):
21 |         # from the top-ranked candidate paths, pick the one with the largest overlap with the question
22 | max_score = 0
23 | top_path = []
24 | q_words = set(q_text)
25 | pure_qtext = legal_char_partten.sub('', q_text)
26 | for path in paths:
27 | path_words = set()
28 | score5 = 0
29 | for ent_rel in path:
30 | if ent_rel.startswith('<'):
31 | ent_rel = ent_rel[1:-1].split('_')[0]
32 | path_words.update(set(ent_rel))
33 | pure_ent = legal_char_partten.sub('', ent_rel)
34 | if pure_ent not in self.invalid_names and pure_ent in pure_qtext:
35 | score5 += 1
36 | # print(pure_ent)
37 |             common_words = path_words & q_words  # characters shared with the question
38 |             score1 = len(common_words)  # main score: size of the overlap; weight order 1 > 2 > 3 > 4
39 |             score2 = 1 / len(path_words)  # shorter candidate paths are better, weight 4
40 |             score3 = 1 / len(path)  # fewer hops are better, weight 2
41 |             score4 = 0.5 if path[-1].startswith('<') else 0  # entities beat attribute values: bonus if the tail is an entity, weight 3; can favour longer paths
42 | score = score1 + score2 + score3 + score4 + score5
43 | # log_str = (f'score : {score:.3f}, score1 : {score1}, score2 : {score2:.3f}, '
44 | # f'score3 : {score3:.3f}, score4 : {score4}, score5: {score5}, '
45 | # f'path: {path}')
46 | # logging.info(log_str)
47 | # print(log_str)
48 | if score > max_score:
49 | top_path = path
50 | max_score = score
51 | # logging.info(f'score: {max_score}, top_path: {top_path}')
52 | return top_path, max_score
53 |
54 |
55 | if __name__ == '__main__':
56 | # q_text = '怎样的检查项目能对小儿多源性房性心动过速、急性肾>功能不全以及动静脉血管瘤做出检测?'
57 | # out_paths = [['<肾功能不全>', '<>来源>']]
58 | # in_paths = [['肾功能', '<检查项目>', '<标签>']]
59 | # q_text = '演员梅艳芳有多高?'
60 | # out_paths = [['<梅艳芳>', '<身高>']] # 6.2
61 | # in_paths = [['<梅艳芳>', '<演员>']] # 8.2
62 | q_text = '叶文洁毕业于哪个大学?'
63 | out_paths = [['<叶文洁>', '<毕业院校>', '<学校代码>']] # 7.9333
64 | in_paths = [['<大学>', '<毕业于>', '<类型>']] # 7.9762
65 | algo = Algorithms()
66 | algo.get_most_overlap_path(q_text, out_paths)
67 | algo.get_most_overlap_path(q_text, in_paths)
68 |
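
As a quick check of `sequences_set_similar` on character sets (the same feature later used for entity disambiguation):

s1 = set('叶文洁毕业于哪个大学')   # 10 distinct characters
s2 = set('叶文洁的毕业院校')       # 8 distinct characters
overlap, jaccard = sequences_set_similar(s1, s2)
# overlap == 5 ('叶', '文', '洁', '毕', '业'); jaccard == 5 / 13
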
--------------------------------------------------------------------------------
/ckbqa/qa/cache.py:
--------------------------------------------------------------------------------
1 | from ckbqa.utils.decorators import singleton
2 | from ckbqa.utils.tools import json_load
3 | from config import Config
4 |
5 |
6 | @singleton
7 | class Memory(object):
8 | def __init__(self):
9 | self.entity2id = json_load(Config.entity2id)
10 | self.mention2entity = json_load(Config.mention2ent_json)
11 | # self.all_attrs = set(json_load(Config.all_attrs_json))
12 |
13 | def get_entity_id(self, ent_name):
14 | if ent_name in self.entity2id:
15 | ent_id = self.entity2id[ent_name]
16 | else:
17 | # ent_name = ''.join(ent_patten.findall(ent_name))
18 | ent_id = self.entity2id.get(ent_name[1:-1], 0)
19 | return ent_id
20 |
--------------------------------------------------------------------------------
/ckbqa/qa/el.py:
--------------------------------------------------------------------------------
1 | """实体链接模块"""
2 | from collections import defaultdict
3 | from typing import List, Dict
4 |
5 | import numpy as np
6 |
7 | from ckbqa.models.entity_score.model import EntityScore
8 | from ckbqa.qa.algorithms import sequences_set_similar
9 | from ckbqa.qa.cache import Memory
10 | from ckbqa.qa.lac_tools import Ngram, BaiduLac
11 | from ckbqa.qa.neo4j_graph import GraphDB
12 | from ckbqa.utils.decorators import singleton
13 |
14 |
15 | @singleton
16 | class CEG(object):
17 | """ Candidate Entity Generation: NER
18 |     Starting from a mention, find all possible KB entities it may refer to and collect them as the candidate entity set;
19 |     the most common and most effective approach is a Name Dictionary, i.e. an alias lookup table.
20 | """
21 |
22 | def __init__(self):
23 | # self.jieba = JiebaLac(load_custom_dict=True)
24 | self.lac = BaiduLac(mode='lac', _load_customization=True)
25 | self.memory = Memory()
26 | self.ngram = Ngram()
27 | self.mention_stop_words = {'是什么', '在哪里', '哪里', '什么', '提出的', '有什么', '国家', '哪个', '所在',
28 | '培养出', '为什么', '什么时候', '人', '你知道', '都包括', '是谁', '告诉我',
29 | '又叫做'}
30 | self.entity_stop_words = {'有', '是', '《', '》', '率', '·'}
31 |
32 | self.main_tags = {'n', 'an', 'nw', 'nz', 'PER', 'LOC', 'ORG', 'TIME',
33 | 'vn', 'v'}
34 |
35 | def get_ent2mention(self, q_text: str) -> dict:
36 | '''
37 |         LAC POS tag legend (tag: meaning):
38 |         n: common noun    f: locative noun     s: place noun       nw: work title
39 |         nz: other proper  v: verb              vd: adverb-verb     vn: verbal noun
40 |         a: adjective      ad: adverbial adj.   an: nominal adj.    d: adverb
41 |         m: numeral        q: measure word      r: pronoun          p: preposition
42 |         c: conjunction    u: particle          xc: other function  w: punctuation
43 |         PER: person name  LOC: place name      ORG: organization   TIME: time
44 | '''
45 | entity2mention = {}
46 | words, tags = self.lac.run(q_text)
47 | main_text = ''.join([word for word, tag in zip(words, tags) if tag in self.main_tags])
48 | main_text_ngrams = list(self.ngram.get_all_grams(main_text))
49 | q_text_grams = list(self.ngram.get_all_grams(q_text))
50 | for gram in set(q_text_grams + main_text_ngrams):
51 | if f"<{gram}>" in self.memory.entity2id:
52 | entity2mention[f'<{gram}>'] = gram
53 | if gram in self.memory.entity2id:
54 | entity2mention[gram] = gram
55 | if gram in self.memory.mention2entity and gram not in self.mention_stop_words:
56 | for ent in self.memory.mention2entity[gram]:
57 | if f"<{ent}>" in self.memory.entity2id:
58 | entity2mention[f'<{ent}>'] = gram
59 | if ent in self.memory.entity2id:
60 | entity2mention[ent] = gram
61 | # print(gram)
62 | return entity2mention
63 |
64 | def seg_text(self, text: str) -> List[str]:
65 | return self.lac.run(text)[0]
66 |
67 | # def _get_ent2mention(self, text: str) -> Dict:
68 | # entity2mention = {}
69 | # lac_words, tags = self.lac.run(text)
70 | # jieba_words = [w for w in self.jieba.cut_for_search(text)]
71 | # for word in set(lac_words + jieba_words):
72 | # # if word in self.entity_stop_words:
73 | # # continue
74 | # if word in self.memory.mention2entity and word not in self.mention_stop_words:
75 | # for ent in self.memory.mention2entity[word]:
76 | # if f"<{ent}>" in self.memory.entity2id:
77 | # entity2mention[f'<{ent}>'] = word
78 | # if ent in self.memory.entity2id:
79 | # entity2mention[ent] = word
80 | # if f"<{word}>" in self.memory.entity2id:
81 | # entity2mention[f'<{word}>'] = word
82 | # if word in self.memory.entity2id:
83 | # entity2mention[word] = word
84 | # return entity2mention
85 |
86 |
87 | @singleton
88 | class ED(object):
89 | """ Entity Disambiguation:
90 |     From the candidate entities, choose the most likely one as the predicted entity.
91 | """
92 |
93 | def __init__(self):
94 | self.memory = Memory()
95 | self.graph_db = GraphDB()
96 | self.ceg = CEG()
97 | self.mention_stop_words = {'是什么', '在哪里', '哪里', '什么', '提出的', '有什么', '国家', '哪个', '所在',
98 | '培养出', '为什么', '什么时候', '人', '你知道', '都包括', '是谁', '告诉我',
99 | '又叫做'}
100 | self.entity_stop_words = {'有', '是', '《', '》', '率', '·'}
101 | # self.entity_score = EntityScore()
102 | self.entity_score = EntityScore(load_pretrain_model=True)
103 |
104 | def ent_rel_similar(self, question: str, entity: str, relations: List) -> List:
105 | '''
106 |         Extract all relations within 2 hops of the entity/attribute value and compute similarity features against the question.
107 |         input:
108 |             question: python-str
109 |             entity: python-str
110 |             relations: python-list of relation names
111 | output:
112 | [word_overlap,char_overlap,word_embedding_similarity,char_overlap_ratio]
113 | '''
114 | # 得到主语-谓词的tokens及chars
115 | rel_tokens = set()
116 | for rel_name in set(relations):
117 | rel_tokens.update(self.ceg.seg_text(rel_name[1:-1]))
118 | rel_chars = set([char for rel_name in relations for char in rel_name])
119 | #
120 | question_tokens = set(self.ceg.seg_text(question))
121 | question_chars = set(question)
122 | #
123 | entity_tokens = set(self.ceg.seg_text(entity[1:-1]))
124 | entity_chars = set(entity[1:-1])
125 | #
126 | qestion_ent_sim = (sequences_set_similar(question_tokens, entity_tokens) +
127 | sequences_set_similar(question_chars, entity_chars))
128 | qestion_rel_sim = (sequences_set_similar(question_tokens, rel_tokens) +
129 | sequences_set_similar(question_chars, rel_chars))
130 | # 实体名和问题的overlap除以实体名长度的比例
131 | return qestion_ent_sim + qestion_rel_sim
132 |
133 | def get_entity_feature(self, q_text: str, ent_name: str):
134 | # 得到实体两跳内的所有关系
135 | one_hop_out_rel_names = self.graph_db.get_onehop_relations_by_entName(ent_name, direction='out')
136 | two_hop_out_rel_names = self.graph_db.get_twohop_relations_by_entName(ent_name, direction='out')
137 | out_rel_names = [rel for rels in two_hop_out_rel_names for rel in rels] + one_hop_out_rel_names
138 | #
139 | one_hop_in_rel_names = self.graph_db.get_onehop_relations_by_entName(ent_name, direction='in')
140 | two_hop_in_rel_names = self.graph_db.get_twohop_relations_by_entName(ent_name, direction='in')
141 | in_rel_names = [rel for rels in two_hop_in_rel_names for rel in rels] + one_hop_in_rel_names
142 | # 计算问题和主语实体及其两跳内关系间的相似度
143 | similar_feature = self.ent_rel_similar(q_text, ent_name, out_rel_names + in_rel_names)
144 | # 实体的流行度特征
145 | popular_feature = self.graph_db.get_onehop_relCount_by_entName(ent_name)
146 | return similar_feature, popular_feature
147 |
148 | def get_candidate_entities(self, q_text: str, entity2mention: dict) -> Dict:
149 | """
150 | :param q_text:
151 | :return: candidate_subject: { ent_name: {'mention':mention_txt,
152 | 'feature': [feature1, feature2, ...]
153 | },
154 | ...
155 | }
156 | """
157 | candidate_subject = defaultdict(dict) # {ent:[mention, feature1, feature2, ...]}
158 | #
159 | for entity, mention in entity2mention.items():
160 | similar_feature, popular_feature = self.get_entity_feature(q_text, entity)
161 | candidate_subject[entity]['mention'] = mention
162 | candidate_subject[entity]['feature'] = [sim for sim in similar_feature] + [popular_feature ** 0.5]
163 |
164 | # subject_score_topn打分做筛选
165 | top_subjects = self.subject_score_topn(candidate_subject)
166 | return top_subjects
167 |
168 | def subject_score_topn(self, candidate_entities: dict, top_k=10):
169 | '''
170 | :candidate_entities {ent: {'mention':mention},
171 | 'feature':[...],
172 | ent2:...,
173 | ...
174 | }
175 |         Given the candidate subjects and their features, score them with the trained model and return the top-n after sorting.
176 |         '''
177 |         top_k = len(candidate_entities)  # TODO remove later: keep all entities, no filtering; filtering currently hurts quality
178 |         if top_k >= len(candidate_entities):  # too few candidates, skip filtering
179 | return candidate_entities
180 | entities = []
181 | features = []
182 | for ent, feature_dic in candidate_entities.items():
183 | entities.append(ent)
184 | features.append(feature_dic['feature'])
185 | scores = self.entity_score.predict(features)
186 | # print(f'scores: {scores}')
187 | top_k_indexs = np.asarray(scores).argsort()[-top_k:][::-1]
188 | top_k_entities = {entities[i]: candidate_entities[entities[i]]
189 | for i in top_k_indexs}
190 | return top_k_entities
191 |
192 |
193 | @singleton
194 | class EL(object):
195 | """实体链接"""
196 |
197 | def __init__(self):
198 | self.ceg = CEG() # Candidate Entity Generation
199 | self.ed = ED() # Entity Disambiguation
200 |
201 | def el(self, q_text: str) -> Dict:
202 | ent2mention = self.ceg.get_ent2mention(q_text)
203 | candidate_entities = self.ed.get_candidate_entities(q_text, ent2mention)
204 | return candidate_entities
205 |
206 | # try
207 | # except:
208 | # import traceback,ipdb
209 | # traceback.print_exc()
210 | # ipdb.set_trace()
211 |
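
A toy illustration of the n-gram lookup in `CEG.get_ent2mention` (the two dictionaries below are hypothetical stand-ins for `Memory.entity2id` and `Memory.mention2entity`; stop-word filtering is omitted):

entity2id = {'<梅艳芳>': 1, '<身高>': 2}
mention2entity = {'梅艳芳': ['梅艳芳']}

q_text = '演员梅艳芳有多高'
grams = {q_text[i:i + n] for n in range(2, len(q_text)) for i in range(len(q_text) - n + 1)}

entity2mention = {}
for gram in grams:
    if f'<{gram}>' in entity2id:                 # the gram itself is a KB entity
        entity2mention[f'<{gram}>'] = gram
    for ent in mention2entity.get(gram, []):     # or it is a known alias of an entity
        if f'<{ent}>' in entity2id:
            entity2mention[f'<{ent}>'] = gram
# entity2mention -> {'<梅艳芳>': '梅艳芳'}
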
--------------------------------------------------------------------------------
/ckbqa/qa/lac_tools.py:
--------------------------------------------------------------------------------
1 | import gc
2 | import logging
3 | import os
4 |
5 | import jieba
6 | from LAC import LAC
7 | from LAC.ahocorasick import Ahocorasick
8 | from tqdm import tqdm
9 |
10 | from ckbqa.utils.decorators import singleton
11 | from ckbqa.utils.tools import tqdm_iter_file, pkl_dump, json_load, json_dump
12 | from config import Config
13 |
14 |
15 | class Ngram(object):
16 | def __init__(self):
17 | pass
18 |
19 | def ngram(self, text, n):
20 | ngrams = [text[i:i + n] for i in range(len(text) - n + 1)]
21 | return ngrams
22 |
23 | def get_all_grams(self, text):
24 | for n in range(2, len(text)):
25 | for ngram in self.ngram(text, n):
26 | yield ngram
27 | # for gram in ngrams:
28 | # yield gram
29 |
30 |
31 | @singleton
32 | class JiebaLac(object):
33 | def __init__(self, load_custom_dict=True):
34 | # jieba.enable_paddle() # 启动paddle模式。 0.40版之后开始支持,早期版本不支持
35 | if load_custom_dict:
36 | self.load_custom_dict()
37 |
38 | def load_custom_dict(self):
39 | if not os.path.isfile(Config.jieba_custom_dict):
40 | all_ent_mention_words = set(json_load(Config.mention2count_json).keys())
41 | entity_set = set(json_load(Config.entity2id).keys())
42 | for ent in tqdm(entity_set, total=len(entity_set), desc='create jieba words '):
43 | if ent.startswith('<'):
44 | ent = ent[1:-1]
45 | all_ent_mention_words.add(ent)
46 | # FREQ,total,
47 | # 模仿jieba.add_word,重写逻辑,加速
48 | # jieba.dt.FREQ = {}
49 | # jieba.dt.total = 0
50 | for word in tqdm(all_ent_mention_words, desc='jieba custom create '):
51 | freq = len(word) * 3
52 | jieba.dt.FREQ[word] = freq
53 | jieba.dt.total += freq
54 | for ch in range(len(word)):
55 | wfrag = word[:ch + 1]
56 | if wfrag not in jieba.dt.FREQ:
57 | jieba.dt.FREQ[wfrag] = 0
58 | del all_ent_mention_words
59 | gc.collect()
60 | json_dump({'dt.FREQ': jieba.dt.FREQ, 'dt.total': jieba.dt.total},
61 | Config.jieba_custom_dict)
62 | logging.info('create jieba custom dict done ...')
63 | # load
64 | jieba_custom = json_load(Config.jieba_custom_dict)
65 | jieba.dt.check_initialized()
66 | jieba.dt.FREQ = jieba_custom['dt.FREQ']
67 | jieba.dt.total = jieba_custom['dt.total']
68 | logging.info('load jieba custom dict done ...')
69 |
70 | def cut_for_search(self, text):
71 | return jieba.cut_for_search(text)
72 |
73 | def cut(self, text):
74 | return jieba.cut(text)
75 |
76 |
77 | @singleton
78 | class BaiduLac(LAC):
79 | def __init__(self, model_path=None, mode='lac', use_cuda=False,
80 | _load_customization=True, reload=False,
81 | customization_path=Config.lac_custom_dict_txt):
82 | super().__init__(model_path=model_path, mode=mode, use_cuda=use_cuda)
83 | self.mode = mode # lac,seg
84 | self.reload = reload
85 | self.customization_path = customization_path
86 | if _load_customization:
87 | self._load_customization()
88 |
89 | def _save_customization(self):
90 | # lac = LAC(mode=self.mode) # 装载LAC模型,LAC(mode='seg')
91 | # lac.load_customization(self.customization_path) # 35万2min;100万>20min;
92 | parms_dict = {'dictitem': self.custom.dictitem,
93 | # 'ac': self.custom.ac,
94 | 'mode': self.mode}
95 | pkl_dump(parms_dict, Config.lac_model_pkl)
96 | return Config.lac_model_pkl
97 |
98 | def _load_customization(self):
99 |         logging.info('Loading of the BaiduLac custom dictionary is currently disabled ...')
100 | return
101 | # self.custom = Customization()
102 | # if self.reload or not os.path.isfile(Config.lac_model_pkl):
103 | # self.custom.load_customization(self.customization_path) # 35万2min;100万,20min;
104 | # self._save_customization()
105 | # else:
106 | # logging.info('load lac customization start ...')
107 | # parms_dict = pkl_load(Config.lac_model_pkl)
108 | # self.custom.dictitem = parms_dict['dictitem'] # dict
109 | # self.custom.ac = Ahocorasick()
110 | # for phrase in tqdm(parms_dict['dictitem'], desc='load baidu lac '):
111 | # self.custom.ac.add_word(phrase)
112 | # self.custom.ac.make()
113 | # logging.info('loaded lac customization Done ...')
114 |
115 |
116 | class Customization(object):
117 | """
118 |     User-dictionary intervention implemented with an Aho-Corasick automaton.
119 | """
120 |
121 | def __init__(self):
122 | self.dictitem = {}
123 | self.ac = None
124 | pass
125 |
126 | def load_customization(self, filename):
127 |         """Load the user intervention dictionary."""
128 | self.ac = Ahocorasick()
129 |
130 | for line in tqdm_iter_file(filename, prefix='load_customization '):
131 | words = line.strip().split()
132 | if len(words) == 0:
133 | continue
134 |
135 | phrase = ""
136 | tags = []
137 | offset = []
138 | for word in words:
139 | if word.rfind('/') < 1:
140 | phrase += word
141 | tags.append('')
142 | else:
143 | phrase += word[:word.rfind('/')]
144 | tags.append(word[word.rfind('/') + 1:])
145 | offset.append(len(phrase))
146 |
147 | if len(phrase) < 2 and tags[0] == '':
148 | continue
149 |
150 | self.dictitem[phrase] = (tags, offset)
151 | self.ac.add_word(phrase)
152 | self.ac.make()
153 |
154 | def parse_customization(self, query, lac_tags):
155 |         """Use the user intervention dictionary to correct the LAC model output."""
156 |
157 | def ac_postpress(ac_res):
158 | ac_res.sort()
159 | i = 1
160 | while i < len(ac_res):
161 | if ac_res[i - 1][0] < ac_res[i][0] and ac_res[i][0] <= ac_res[i - 1][1]:
162 | ac_res.pop(i)
163 | continue
164 | i += 1
165 | return ac_res
166 |
167 | if not self.ac:
168 |             logging.warning("customization dict is not loaded")
169 | return
170 |
171 | ac_res = self.ac.search(query)
172 |
173 | ac_res = ac_postpress(ac_res)
174 |
175 | for begin, end in ac_res:
176 | phrase = query[begin:end + 1]
177 | index = begin
178 |
179 | tags, offsets = self.dictitem[phrase]
180 | for tag, offset in zip(tags, offsets):
181 | while index < begin + offset:
182 | if len(tag) == 0:
183 | lac_tags[index] = lac_tags[index][:-1] + 'I'
184 | else:
185 | lac_tags[index] = tag + "-I"
186 | index += 1
187 |
188 | lac_tags[begin] = lac_tags[begin][:-1] + 'B'
189 | for offset in offsets:
190 | index = begin + offset
191 | if index < len(lac_tags):
192 | lac_tags[index] = lac_tags[index][:-1] + 'B'
193 |
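
Note that `Ngram.get_all_grams` enumerates all n-grams with 2 <= n < len(text), so the full string itself is excluded; for example:

ngram = Ngram()
ngram.ngram('知识图谱', 2)             # ['知识', '识图', '图谱']
list(ngram.get_all_grams('知识图谱'))  # ['知识', '识图', '图谱', '知识图', '识图谱']
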
--------------------------------------------------------------------------------
/ckbqa/qa/neo4j_graph.py:
--------------------------------------------------------------------------------
1 | # coding: utf-8
2 | import logging
3 | import os
4 | import time
5 |
6 | from py2neo import Graph
7 |
8 | from ckbqa.qa.cache import Memory
9 | from ckbqa.utils.async_tools import apply_async
10 | from ckbqa.utils.decorators import singleton, synchronized
11 | from ckbqa.utils.tools import json_load, json_dump
12 | from config import Config
13 |
14 | graph = Graph('http://localhost:7474', username='neo4j', password='password') # bolt://localhost:7687
15 |
16 |
17 | @singleton
18 | class GraphDB(object):
19 | def __init__(self):
20 | self._one_hop_relNames_map = {'in': {}, 'out': {}}
21 | self._two_hop_relNames_map = {'in': {}, 'out': {}}
22 | self.all_directions = {'in', 'out'}
23 | self.load_cache()
24 | self.memory = Memory()
25 | self.async_cache()
26 |
27 | def __del__(self):
28 | # time.sleep(30) # 等待另一个线程写入完成,避免交叉影响;使用线程锁替代
29 | self.cache() # 保证程序结束前会正确保存文件
30 |
31 | def get_total_entity_count(self):
32 | count = (len(self._one_hop_relNames_map['out']) +
33 | len(self._one_hop_relNames_map['in']) +
34 | len(self._two_hop_relNames_map['out']) +
35 | len(self._two_hop_relNames_map['in']))
36 | return count
37 |
38 | def load_cache(self):
39 | # self.queue = Queue(maxsize=1)
40 | if os.path.isfile(Config.neo4j_query_cache):
41 | data_dict = json_load(Config.neo4j_query_cache)
42 | self._one_hop_relNames_map = data_dict['_one_hop_relNames_map']
43 | self._two_hop_relNames_map = data_dict['_two_hop_relNames_map']
44 | self.total_count = self.get_total_entity_count()
45 | logging.info(f'load neo4j_query_cache, total: {self.total_count}')
46 | else:
47 | logging.info(f'not found neo4j_query_cache: {Config.neo4j_query_cache}')
48 | self.total_count = 0
49 |
50 | # def put_cache_sign(self):
51 | # self.queue.put(1) # 每次结束后检查
52 |
53 | def async_cache(self):
54 | def loop_check_cache():
55 | while True: # self.queue
56 | self.cache()
57 | time.sleep(10 * 60)
58 |
59 | thr = apply_async(loop_check_cache, daemon=True)
60 | return thr
61 |
62 | @synchronized # 避免多线程交叉影响
63 | def cache(self):
64 | total = self.get_total_entity_count()
65 | if total > self.total_count:
66 | parms = {'_one_hop_relNames_map': self._one_hop_relNames_map,
67 | '_two_hop_relNames_map': self._two_hop_relNames_map}
68 | json_dump(parms, Config.neo4j_query_cache)
69 | logging.info(f'neo4j_query_cache , increase: {total - self.total_count}, total: {total}')
70 | self.total_count = total
71 |
72 | def get_onehop_relations_by_entName(self, ent_name: str, direction='out'):
73 | assert direction in self.all_directions
74 | if ent_name not in self._one_hop_relNames_map[direction]:
75 | ent_id = self.memory.get_entity_id(ent_name)
76 | start, end = ('ent:Entity', '') if direction == 'out' else ('', 'ent:Entity')
77 |             cql = f"MATCH p=({start})-[r1:Relation]->({end}) WHERE ent.id={ent_id} RETURN DISTINCT r1.name"
78 |             _one_hop_relNames = [rel.data()['r1.name'] for rel in graph.run(cql)]
79 | self._one_hop_relNames_map[direction][ent_name] = _one_hop_relNames
80 | rel_names = self._one_hop_relNames_map[direction][ent_name]
81 | return rel_names
82 |
83 | def get_twohop_relations_by_entName(self, ent_name: str, direction='out'):
84 | assert direction in self.all_directions
85 | if ent_name not in self._two_hop_relNames_map[direction]:
86 | ent_id = self.memory.get_entity_id(ent_name)
87 | start, end = ('ent:Entity', '') if direction == 'out' else ('', 'ent:Entity')
88 | cql = f"MATCH p=({start})-[r1:Relation]->()-[r2:Relation]->({end}) WHERE ent.id={ent_id} RETURN DISTINCT r1.name,r2.name"
89 | rels = [rel.data() for rel in graph.run(cql)]
90 | rel_names = [[d['r1.name'], d['r2.name']] for d in rels]
91 | self._two_hop_relNames_map[direction][ent_name] = rel_names
92 | rel_names = self._two_hop_relNames_map[direction][ent_name]
93 | return rel_names
94 |
95 | def get_onehop_relCount_by_entName(self, ent_name: str):
96 |         '''Return the number of relations connected to the entity, a proxy for its popularity in the knowledge base.'''
97 | if ent_name not in self._one_hop_relNames_map['out']:
98 | rel_names = self.get_onehop_relations_by_entName(ent_name, direction='out')
99 | self._one_hop_relNames_map['out'][ent_name] = rel_names
100 | if ent_name not in self._one_hop_relNames_map['in']:
101 | rel_names = self.get_onehop_relations_by_entName(ent_name, direction='in')
102 | self._one_hop_relNames_map['in'][ent_name] = rel_names
103 | count = len(self._one_hop_relNames_map['out'][ent_name]) + len(self._one_hop_relNames_map['in'][ent_name])
104 | return count
105 |
106 | def search_by_2path(self, ent_name, rel_name, direction='out'):
107 | ent_id = self.memory.get_entity_id(ent_name)
108 | assert direction in self.all_directions
109 | start, end = ('ent:Entity', 'target') if direction == 'out' else ('target', 'ent:Entity')
110 | cql = (f"match ({start})-[r1:Relation]-({end}) where ent.id={ent_id} "
111 | f" and r1.name='{rel_name}' return DISTINCT target.name")
112 | logging.info(f"{cql}; *ent_name:{ent_name}")
113 | res = graph.run(cql)
114 | target_name = [ent['target.name'] for ent in res]
115 | return target_name
116 |
117 | def search_by_3path(self, ent_name, rel1_name, rel2_name, direction='out'):
118 | ent_id = self.memory.get_entity_id(ent_name)
119 | assert direction in self.all_directions
120 | start, end = ('ent:Entity', 'target') if direction == 'out' else ('target', 'ent:Entity')
121 | cql = (f"match ({start})-[r1:Relation]-()-[r2:Relation]-({end}) where ent.id={ent_id} "
122 | f" and r1.name='{rel1_name}' and r2.name='{rel2_name}' return DISTINCT target.name")
123 | logging.info(f"{cql}; *ent_name:{ent_name}")
124 | res = graph.run(cql)
125 | target_name = [ent['target.name'] for ent in res]
126 | return target_name
127 |
128 | def search_by_4path(self, ent1_name, rel1_name, rel2_name, ent2_name, direction='out'):
129 |         ent1_id = self.memory.get_entity_id(ent1_name)  # entity-id lookup lives on Memory, not on GraphDB
130 |         ent2_id = self.memory.get_entity_id(ent2_name)
131 | assert direction in self.all_directions
132 | start, end = ('ent:Entity', 'target') if direction == 'out' else ('target', 'ent:Entity')
133 | cql = (f"match ({start})-[r1:Relation]-({end})-[r2:Relation]-(ent2:Entity) "
134 | f" where ent.id={ent1_id} and r1.name='{rel1_name}' and r2.name='{rel2_name}' and ent2.id={ent2_id}"
135 |                f" return DISTINCT target.name")  # leading space keeps 'return' from fusing with the previous clause
136 | logging.info(f"{cql}; *ent_name:{ent1_name},{ent2_name}")
137 | res = graph.run(cql)
138 | target_name = [ent['target.name'] for ent in res]
139 | return target_name
140 |
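
For direction='out', search_by_2path therefore issues a Cypher query of the following shape (the entity id and relation name are hypothetical; for direction='in' the pattern is reversed):

# shape of the query string built by search_by_2path for direction='out'
cql = ("match (ent:Entity)-[r1:Relation]-(target) "
       "where ent.id=12345 and r1.name='<身高>' return DISTINCT target.name")
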
--------------------------------------------------------------------------------
/ckbqa/qa/qa.py:
--------------------------------------------------------------------------------
1 | import logging
2 |
3 | from ckbqa.qa.algorithms import Algorithms
4 | from ckbqa.qa.cache import Memory
5 | from ckbqa.qa.el import EL
6 | from ckbqa.qa.neo4j_graph import GraphDB
7 | from ckbqa.qa.relation_extractor import RelationExtractor
8 | from ckbqa.utils.async_tools import async_init_singleton_class
9 |
10 |
11 | class QA(object):
12 | def __init__(self):
13 | async_init_singleton_class([Memory, GraphDB, RelationExtractor, EL])
14 | logging.info('QA init start ...')
15 | self.graph_db = GraphDB()
16 | self.memory = Memory()
17 | self.el = EL()
18 | self.relation_extractor = RelationExtractor()
19 | self.algo = Algorithms()
20 | logging.info('QA init Done ...')
21 |
22 | def query_path(self, path, direction='out'):
23 | if len(path) == 2:
24 | ent_name, rel_name = path
25 | answer = self.graph_db.search_by_2path(ent_name, rel_name, direction=direction)
26 | elif len(path) == 3:
27 | ent_name, rel1_name, rel2_name = path
28 | answer = self.graph_db.search_by_3path(ent_name, rel1_name, rel2_name, direction=direction)
29 | elif len(path) == 4:
30 | ent1_name, rel1_name, rel2_name, ent2_name = path
31 | answer = self.graph_db.search_by_4path(ent1_name, rel1_name, rel2_name, ent2_name, direction=direction)
32 | else:
33 |             logging.info(f'Malformed query path: {path}')
34 | answer = []
35 | return answer
36 |
37 | def run(self, q_text, return_candidates=False):
38 | logging.info(f"* q_text: {q_text}")
39 | candidate_entities = self.el.el(q_text)
40 | logging.info(f'* get_candidate_entities {len(candidate_entities)} ...')
41 | candidate_out_paths, candidate_in_paths = self.relation_extractor.get_ent_relations(q_text, candidate_entities)
42 | logging.info(f'* get candidate_out_paths: {len(candidate_out_paths)},'
43 | f'candidate_in_paths: {len(candidate_in_paths)} ...')
44 | # 生成cypher语句并查询
45 | top_out_path, max_out_score = self.algo.get_most_overlap_path(q_text, candidate_out_paths)
46 | logging.info(f'* get_most_overlap out path:{max_out_score:.4f}, {top_out_path} ...')
47 | top_in_path, max_in_score = self.algo.get_most_overlap_path(q_text, candidate_in_paths)
48 | logging.info(f'* get_most_overlap in path:{max_in_score:.4f}, {top_in_path} ...')
49 | # self.graph_db.cache() # 缓存neo4j查询结果
50 | if max_out_score >= max_in_score: # 分数相同时取out
51 | direction = 'out'
52 | top_path = top_out_path
53 | else:
54 | direction = 'in'
55 | top_path = top_in_path
56 | logging.info(f'* get_most_overlap_path: {top_path} ...')
57 | result_ents = self.query_path(top_path, direction=direction)
58 | if not result_ents and len(top_path) > 2:
59 | if direction == 'out':
60 | top_path = top_path[:2]
61 | else:
62 | top_path = top_path[-2:]
63 | result_ents = self.query_path(top_path)
64 | if not result_ents:
65 |             top_path = [top_path[0], top_path[-1]]  # keep a 2-element path (entity, relation) rather than concatenating strings
66 | result_ents = self.query_path(top_path, direction=direction)
67 | logging.info(f"* cypher result_ents: {result_ents}" + '\n' + '--' * 10 + '\n\n')
68 | if return_candidates:
69 | return result_ents, candidate_entities, candidate_out_paths, candidate_in_paths
70 | return result_ents
71 |
72 | # def evaluate(self, question, subject_entities, result_ents_entities):
73 | # q_text = question_patten.findall(question)[0]
74 | # candidate_entities = self.recognizer.get_candidate_entities(q_text)
75 | # precision, recall, f1 = get_metrics(subject_entities, candidate_entities)
76 | # logging.info(f'get_candidate_entities, precision: {precision}, recall: {recall}, f1: {f1}')
77 |
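
A minimal sketch of running the whole pipeline end to end (this assumes the knowledge base has been imported into Neo4j and that the caches and trained models referenced in Config are in place):

from ckbqa.qa.qa import QA
from ckbqa.utils.logger import logging_config

logging_config('./qa.log', stream_log=True)
qa = QA()                                     # initializes Memory, GraphDB, EL, RelationExtractor
answers = qa.run('演员梅艳芳有多高?')          # -> list of answer entity/attribute names
answers, cand_ents, out_paths, in_paths = qa.run('演员梅艳芳有多高?', return_candidates=True)
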
--------------------------------------------------------------------------------
/ckbqa/qa/relation_extractor.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | import logging
3 | from typing import Tuple, Dict, List
4 |
5 | import numpy as np
6 |
7 | from ckbqa.models.relation_score.predictor import RelationScorePredictor
8 | from ckbqa.qa.neo4j_graph import GraphDB
9 | from ckbqa.utils.decorators import singleton
10 |
11 |
12 | @singleton
13 | class RelationExtractor(object):
14 |
15 | def __init__(self):
16 | self.graph_db = GraphDB()
17 | self.sim_predictor = RelationScorePredictor(model_name='bert_match') # bert_match,bert_match2
18 |
19 | def get_relations(self, candidate_entities, ent_name, direction='out') -> Tuple:
20 | onehop_relations = self.graph_db.get_onehop_relations_by_entName(ent_name, direction=direction)
21 | twohop_relations = self.graph_db.get_twohop_relations_by_entName(ent_name, direction=direction)
22 | mention = candidate_entities[ent_name]['mention']
23 | candidate_paths, candidate_sents = [], []
24 | for rel_name in onehop_relations:
25 | _candidate_path = '的'.join([mention, rel_name[1:-1]]) # 查询的关系有<>,需要去掉
26 | candidate_sents.append(_candidate_path)
27 | candidate_paths.append([ent_name, rel_name])
28 | for rels in twohop_relations:
29 | pure_rel_names = [rel_name[1:-1] for rel_name in rels] # python-list 关系名列表
30 | _candidate_path = '的'.join([mention] + pure_rel_names)
31 | candidate_sents.append(_candidate_path)
32 | candidate_paths.append([ent_name] + rels)
33 | return candidate_paths, candidate_sents
34 |
35 | def get_ent_relations(self, q_text: str, candidate_entities: Dict) -> Tuple:
36 | """
37 | :param q_text:
38 | :top_k 10 ; out 10, in 10
39 | :param candidate_entities: {ent:[mention, feature1, feature2, ...]}
40 | :return:
41 | """
42 | candidate_out_sents, candidate_out_paths = [], []
43 | candidate_in_sents, candidate_in_paths = [], []
44 | for entity in candidate_entities:
45 | candidate_paths, candidate_sents = self.get_relations(candidate_entities, entity, direction='out')
46 | candidate_out_sents.extend(candidate_sents)
47 | candidate_out_paths.extend(candidate_paths)
48 | candidate_paths, candidate_sents = self.get_relations(candidate_entities, entity, direction='in')
49 | candidate_in_sents.extend(candidate_sents)
50 | candidate_in_paths.extend(candidate_paths)
51 | if not candidate_out_sents and not candidate_in_sents:
52 | logging.info('* candidate_out_paths Empty ...')
53 | return candidate_out_paths, candidate_in_paths
54 |         # rank the candidate paths with the scoring model
55 |         _out_paths = self.relation_score_topn(q_text, candidate_out_paths, candidate_out_sents)
56 |         _in_paths = self.relation_score_topn(q_text, candidate_in_paths, candidate_in_sents)
57 | return _out_paths, _in_paths
58 |
59 | def relation_score_topn(self, q_text: str, candidate_paths, candidate_sents, top_k=10) -> List:
60 |         top_k = len(candidate_sents)  # TODO remove later: keep all paths, no filtering; filtering currently hurts quality
61 |         if top_k >= len(candidate_sents):  # too few candidates, skip filtering
62 | return candidate_paths
63 | if candidate_sents:
64 | sim_scores = self.sim_predictor.predict(q_text, candidate_sents) # 目前算法不好,score==1
65 | sim_indexs = np.array(sim_scores).argsort()[-top_k:][::-1]
66 | _paths = [candidate_paths[i] for i in sim_indexs]
67 | else:
68 | _paths = []
69 | return _paths
70 | # for i in in_sim_indexs:
71 | # try:
72 | # candidate_in_paths[i]
73 | # except:
74 | # import traceback
75 | # logging.info(traceback.format_exc())
76 | # print(f'i: {i}, candidate_in_paths: {len(candidate_in_paths)}')
77 | # import ipdb
78 | # ipdb.set_trace()
79 |
--------------------------------------------------------------------------------
/ckbqa/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/WangShengguang/ccks-2020/d77f2de91284efb11e314fea21b4a3982ab78554/ckbqa/utils/__init__.py
--------------------------------------------------------------------------------
/ckbqa/utils/async_tools.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import threading
3 |
4 | from ckbqa.utils.decorators import synchronized
5 |
6 |
7 | def apply_async(func, *args, daemon=False, **kwargs):
8 | thr = threading.Thread(target=func, args=args, kwargs=kwargs)
9 | if daemon:
10 |         thr.setDaemon(True)  # daemon thread exits together with the main thread
11 | thr.start()
12 | if not daemon:
13 | thr.join()
14 | return thr
15 |
16 |
17 | # for sc in A.__subclasses__():#所有子类
18 | # print(sc.__name__)
19 | @synchronized
20 | def async_init_singleton_class(classes=()):
21 | logging.info('async_init_singleton_class start ...')
22 | thrs = [threading.Thread(target=_singleton_class)
23 | for _singleton_class in classes]
24 | for t in thrs:
25 | t.start()
26 | for t in thrs:
27 | t.join() # 等待所有线程结束再往下继续
28 | logging.info('async_init_singleton_class done ...')
29 |
--------------------------------------------------------------------------------
/ckbqa/utils/decorators.py:
--------------------------------------------------------------------------------
1 | import gc
2 | import logging
3 | import threading
4 | import traceback
5 | from functools import wraps
6 |
7 |
8 | def synchronized(func):
9 | func.__lock__ = threading.Lock()
10 |
11 | @wraps(func)
12 | def lock_func(*args, **kwargs):
13 | with func.__lock__:
14 | res = func(*args, **kwargs)
15 | return res
16 |
17 | return lock_func
18 |
19 |
20 | # singleton_classes = set()
21 |
22 |
23 | def singleton(cls):
24 | """
25 |     Initialize only once: both __new__ and __init__ run a single time.
26 |     """
27 |     instances = {}  # populated once the decorated class is first instantiated
28 | # singleton_classes.add(cls)
29 |
30 | @synchronized
31 | @wraps(cls)
32 | def get_instance(*args, **kw):
33 | if cls not in instances: # 保证只初始化一次
34 | logging.info(f"{cls.__name__} init start ...")
35 | instances[cls] = cls(*args, **kw)
36 | logging.info(f"{cls.__name__} init done ...")
37 | gc.collect()
38 | return instances[cls]
39 |
40 | return get_instance
41 |
42 |
43 | class Singleton(object):
44 | """
45 | 保证只new一次;但会初始化多次,每个子类的init方法都会被调用
46 | 会造成对象虽然是同一个,但因为会不断地调用init方法,对象属性被不断的修改
47 | """
48 | instance = None
49 |
50 | @synchronized
51 | def __new__(cls, *args, **kwargs):
52 | """
53 | :type kwargs: object
54 | """
55 |         if cls.instance is None:  # only __new__ once; __init__ will still run on every call, repeatedly modifying the shared instance
56 | cls.instance = super().__new__(cls)
57 | return cls.instance
58 |
59 |
60 | def try_catch_with_logging(default_response=None):
61 | def out_wrapper(func):
62 | @wraps(func)
63 | def wrapper(*args, **kwargs):
64 | try:
65 | res = func(*args, **kwargs)
66 | except Exception:
67 | res = default_response
68 | logging.error(traceback.format_exc())
69 | return res
70 |
71 | return wrapper
72 |
73 | return out_wrapper
74 |
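
For reference, the `singleton` decorator caches the first instance and returns it on every later call:

@singleton
class Cache:
    def __init__(self):
        self.store = {}

a = Cache()
b = Cache()
assert a is b        # __init__ ran exactly once; both names point to the same instance
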
--------------------------------------------------------------------------------
/ckbqa/utils/gpu_selector.py:
--------------------------------------------------------------------------------
1 | import os
2 | import time
3 | import traceback
4 |
5 |
6 | def get_available_gpu(num_gpu=1, min_memory=1000, try_times=3, allow_gpus="0,1,2,3", verbose=False):
7 | '''
8 | :param num_gpu: number of GPU you want to use
9 | :param min_memory: minimum memory MiB
10 | :param sample: try_times
11 | :param all_gpu_nums: accessible gpu nums 允许挑选的GPU编号列表
12 | :param verbose: verbose mode
13 | :return: str of best choices, e.x. '1, 2'
14 | '''
15 | selected = ""
16 | while try_times:
17 | try_times -= 1
18 | info_text = os.popen('nvidia-smi --query-gpu=utilization.gpu,memory.free --format=csv').read()
19 | try:
20 | gpu_info = [(str(gpu_id), int(memory.replace('%', '').replace('MiB', '').split(',')[1].strip()))
21 | for gpu_id, memory in enumerate(info_text.split('\n')[1:-1])]
22 | except:
23 | if verbose:
24 | print(traceback.format_exc())
25 |             return "No GPU info found ..."
26 |         gpu_info.sort(key=lambda info: info[1], reverse=True)  # sort by free memory, descending
27 |         available_gpu = []
28 |         for gpu_id, memory in gpu_info:
29 |             if gpu_id in allow_gpus:
30 |                 if memory >= min_memory:
31 |                     available_gpu.append(gpu_id)
32 |         if available_gpu:
33 |             selected = ",".join(available_gpu[:num_gpu])
34 |         else:
35 |             print('No GPU available, will retry after 2.0 seconds ...')
36 |             time.sleep(2)
37 |             continue
38 |         if verbose:
39 |             print('Available GPU List')
40 |             first_line = [['id', 'memory.free(MiB)']]  # only free memory is parsed above
41 |             matrix = first_line + [[gpu_id, memory] for gpu_id, memory in gpu_info if gpu_id in available_gpu]
42 | s = [[str(e) for e in row] for row in matrix]
43 | lens = [max(map(len, col)) for col in zip(*s)]
44 | fmt = '\t'.join('{{:{}}}'.format(x) for x in lens)
45 | table = [fmt.format(*row) for row in s]
46 | print('\n'.join(table))
47 | print('Select id #' + selected + ' for you.')
48 | return selected
49 |
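# Minimal usage sketch (editor's addition), mirroring how manage.py consumes this helper.
if __name__ == '__main__':
    gpu_ids = get_available_gpu(num_gpu=1, min_memory=1000, allow_gpus="0,1,2,3", verbose=True)
    os.environ["CUDA_VISIBLE_DEVICES"] = gpu_ids  # must be set before torch initializes CUDA
    print(f'CUDA_VISIBLE_DEVICES={gpu_ids}')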
--------------------------------------------------------------------------------
/ckbqa/utils/logger.py:
--------------------------------------------------------------------------------
1 | import logging.handlers
2 |
3 |
4 | def logging_config(logging_name='./run.log', stream_log=True, log_level="info"):
5 | """
6 |     :param logging_name: path of the log file
7 |     :param stream_log: whether to also print log messages to stdout
8 |     :param log_level: one of fatal, error, warn, info, debug
9 | :return: None
10 | """
11 | log_handles = [logging.handlers.RotatingFileHandler(
12 | logging_name, maxBytes=20 * 1024 * 1024, backupCount=5, encoding='utf-8')]
13 | if stream_log:
14 | log_handles.append(logging.StreamHandler())
15 | logging_level = {"fatal": logging.FATAL, "error": logging.ERROR, "warn": logging.WARN,
16 | "info": logging.INFO, "debug": logging.DEBUG}[log_level]
17 | logging.basicConfig(
18 | handlers=log_handles,
19 | level=logging_level,
20 | format="%(asctime)s - %(levelname)s %(filename)s %(funcName)s %(lineno)s - %(message)s"
21 | )
22 |
23 | # if __name__ == "__main__":
24 | # logging_config("./test.log", stream_log=True) # ../../log/test.log
25 | #     logging.info("log to stdout ...")
26 | # logging.debug("hello")
27 |
--------------------------------------------------------------------------------
/ckbqa/utils/saver.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import os
3 |
4 | import torch
5 |
6 | from config import Config, ckpt_dir
7 |
8 |
9 | class Saver(object):
10 | def __init__(self, model_name):
11 | # self.dataset = dataset
12 | self.model_name = model_name
13 | self.model_dir = os.path.join(ckpt_dir, model_name)
14 |
15 | def load_model(self, model: torch.nn.Module, mode="max_step", fail_ok=False, map_location=Config.device):
16 | model_path = os.path.join(self.model_dir, mode, f"{self.model_name}.bin")
17 | if os.path.isfile(model_path):
18 | ckpt = torch.load(model_path, map_location=map_location)
19 | model.load_state_dict(ckpt["net"])
20 | step = ckpt["step"]
21 | epoch = ckpt["epoch"]
22 | logging.info("* load model from file: {},epoch:{}, step:{}".format(model_path, epoch, step))
23 | else:
24 | if fail_ok:
25 | epoch = 0
26 | step = 0
27 | else:
28 |                 raise ValueError(f'model path: {model_path} does not exist')
29 |             logging.info("* Failed to load model from file: {}".format(model_path))
30 | return model_path, epoch, step
31 |
32 | def save(self, model, epoch, step=-1, mode="max_step", parms_dic: dict = None):
33 | model_path = os.path.join(self.model_dir, mode, f"{self.model_name}.bin")
34 | os.makedirs(os.path.dirname(model_path), exist_ok=True)
35 | state = {"net": model.state_dict(), 'epoch': epoch, "step": step}
36 | if isinstance(parms_dic, dict):
37 | state.update(parms_dic)
38 | torch.save(state, model_path)
39 | return model_path
40 |
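# Minimal round-trip sketch (editor's addition): the model name and the tiny linear module
# below are illustrative only, not part of the original project.
if __name__ == '__main__':
    import torch.nn as nn

    saver = Saver(model_name='demo_model')
    net = nn.Linear(4, 2)
    saved_path = saver.save(net, epoch=1, step=100)        # writes <ckpt_dir>/demo_model/max_step/demo_model.bin
    _, epoch, step = saver.load_model(net, fail_ok=True)   # restores the weights and returns (path, epoch, step)
    print(saved_path, epoch, step)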
--------------------------------------------------------------------------------
/ckbqa/utils/sequence.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import six
3 |
4 |
5 | def pad_sequences(sequences, maxlen=None, dtype='int32',
6 | padding='pre', truncating='pre', value=0.):
7 | """Pads sequences to the same length.
8 |
9 | This function transforms a list of
10 | `num_samples` sequences (lists of integers)
11 | into a 2D Numpy array of shape `(num_samples, num_timesteps)`.
12 | `num_timesteps` is either the `maxlen` argument if provided,
13 | or the length of the longest sequence otherwise.
14 |
15 | Sequences that are shorter than `num_timesteps`
16 |     are padded with `value` until they reach that length.
17 |
18 | Sequences longer than `num_timesteps` are truncated
19 | so that they fit the desired length.
20 | The position where padding or truncation happens is determined by
21 | the arguments `padding` and `truncating`, respectively.
22 |
23 | Pre-padding is the default.
24 |
25 | # Arguments
26 | sequences: List of lists, where each element is a sequence.
27 | maxlen: Int, maximum length of all sequences.
28 | dtype: Type of the output sequences.
29 | To pad sequences with variable length strings, you can use `object`.
30 | padding: String, 'pre' or 'post':
31 | pad either before or after each sequence.
32 | truncating: String, 'pre' or 'post':
33 | remove values from sequences larger than
34 | `maxlen`, either at the beginning or at the end of the sequences.
35 | value: Float or String, padding value.
36 |
37 | # Returns
38 | x: Numpy array with shape `(len(sequences), maxlen)`
39 |
40 | # Raises
41 | ValueError: In case of invalid values for `truncating` or `padding`,
42 | or in case of invalid shape for a `sequences` entry.
43 | """
44 | if not hasattr(sequences, '__len__'):
45 | raise ValueError('`sequences` must be iterable.')
46 | num_samples = len(sequences)
47 | lengths = []
48 | for x in sequences:
49 | try:
50 | lengths.append(len(x))
51 | except TypeError:
52 | raise ValueError('`sequences` must be a list of iterables. '
53 | 'Found non-iterable: ' + str(x))
54 |
55 | if maxlen is None:
56 | maxlen = np.max(lengths)
57 |
58 | # take the sample shape from the first non empty sequence
59 | # checking for consistency in the main loop below.
60 | sample_shape = tuple()
61 | for s in sequences:
62 | if len(s) > 0:
63 | sample_shape = np.asarray(s).shape[1:]
64 | break
65 |
66 | is_dtype_str = np.issubdtype(dtype, np.str_) or np.issubdtype(dtype, np.unicode_)
67 | if isinstance(value, six.string_types) and dtype != object and not is_dtype_str:
68 | raise ValueError("`dtype` {} is not compatible with `value`'s type: {}\n"
69 | "You should set `dtype=object` for variable length strings."
70 | .format(dtype, type(value)))
71 |
72 | x = np.full((num_samples, maxlen) + sample_shape, value, dtype=dtype)
73 | for idx, s in enumerate(sequences):
74 | if not len(s):
75 | continue # empty list/array was found
76 | if truncating == 'pre':
77 | trunc = s[-maxlen:]
78 | elif truncating == 'post':
79 | trunc = s[:maxlen]
80 | else:
81 | raise ValueError('Truncating type "%s" '
82 | 'not understood' % truncating)
83 |
84 | # check `trunc` has expected shape
85 | trunc = np.asarray(trunc, dtype=dtype)
86 | if trunc.shape[1:] != sample_shape:
87 | raise ValueError('Shape of sample %s of sequence at position %s '
88 | 'is different from expected shape %s' %
89 | (trunc.shape[1:], idx, sample_shape))
90 |
91 | if padding == 'post':
92 | x[idx, :len(trunc)] = trunc
93 | elif padding == 'pre':
94 | x[idx, -len(trunc):] = trunc
95 | else:
96 | raise ValueError('Padding type "%s" not understood' % padding)
97 | return x
98 |
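# Minimal usage sketch (editor's addition) showing the default pre-padding / pre-truncating behaviour.
if __name__ == '__main__':
    batch = pad_sequences([[1, 2, 3], [4, 5]], maxlen=4)
    print(batch)  # [[0 1 2 3]
                  #  [0 0 4 5]]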
--------------------------------------------------------------------------------
/ckbqa/utils/tools.py:
--------------------------------------------------------------------------------
1 | import gc
2 | import logging
3 | import os
4 | import pickle
5 | import platform
6 | import re
7 | import sys
8 | import threading
9 | import time
10 | import traceback
11 |
12 | import orjson # faster than json,ujson
13 | import psutil
14 | from tqdm import tqdm
15 |
16 |
17 | def pkl_load(file_path: str):
18 | with open(file_path, 'rb') as f:
19 | gc.disable()
20 | obj = pickle.load(f)
21 | gc.enable()
22 | return obj
23 |
24 |
25 | def pkl_dump(obj: object, file_path: str):
26 | # limit = {'default': sys.getrecursionlimit(),
27 | # 'common': 10 * 10000,
28 | # 'max': resource.getrlimit(resource.RLIMIT_STACK)[0]
29 | # }
30 | # sys.setrecursionlimit(limit[recursionlimit])
31 | with open(file_path, 'wb') as f:
32 | gc.disable()
33 | pickle.dump(obj, f)
34 | gc.enable()
35 |
36 |
37 | def json_load(path):
38 | with open(path, 'rb') as f:
39 | obj = orjson.loads(f.read())
40 | gc.collect()
41 | return obj
42 |
43 |
44 | def json_dump(dict_data, save_path, override_exist=True):
45 | if override_exist or not os.path.isfile(save_path):
46 | strs = orjson.dumps(dict_data, option=orjson.OPT_INDENT_2)
47 | with open(save_path, "wb") as f:
48 | f.write(strs)
49 | # if save_memory:
50 | # json.dump(dict_data, f, ensure_ascii=False)
51 | # else:
52 | # json.dump(dict_data, f, ensure_ascii=False, indent=indent, sort_keys=sort_keys)
53 |
54 |
55 | def get_file_linenums(file_name):
56 | if platform.system() in ['Linux', 'Darwin']: # Linux,Mac
57 | num_str = os.popen(f'wc -l {file_name}').read()
58 |         line_num = int(re.findall(r'\d+', num_str)[0])
59 | else: # Windows
60 | line_num = sum([1 for _ in open(file_name, encoding='utf-8')])
61 | return line_num
62 |
63 |
64 | def tqdm_iter_file(file_path, prefix=''):
65 | line_num = get_file_linenums(file_path)
66 | with open(file_path, 'r', encoding='utf-8') as f:
67 | for line in tqdm(f, total=line_num, desc=prefix):
68 | yield line
69 |
70 |
71 | def byte2human(bytes, unit='B', precision=2):
72 | unit_mount_map = {'B': 1, 'KB': 1024, 'MB': 1024 * 1024, 'GB': 1024 * 1024 * 1024}
73 | memo = bytes / unit_mount_map[unit]
74 | memo = round(memo, precision)
75 | return memo
76 |
77 |
78 | def get_var_size(var, unit='B'):
79 | size = sys.getsizeof(var)
80 |     readable_size = f"{byte2human(size, unit):.2f} {unit}"  # convert into the requested unit
81 | return readable_size
82 |
83 |
84 | class ShowTime(object):
85 | '''
86 |     Time a block of code with a context manager.
87 | '''
88 |
89 | def __init__(self, prefix=""):
90 | self.prefix = prefix
91 |
92 | def __enter__(self):
93 | self.start_timestamp = time.time()
94 | return self
95 |
96 | def __exit__(self, exc_type, exc_val, exc_tb):
97 | self.runtime = time.time() - self.start_timestamp
98 |         print("{} took: {:.2f} s".format(self.prefix, self.runtime))
99 |         if exc_type is not None:
100 |             print(exc_type, exc_val, exc_tb)
101 |             print(traceback.format_exc())
102 |         return False  # do not suppress exceptions; returning a truthy value here would swallow them
103 |
104 |
105 | class ProcessManager(object):
106 | def __init__(self, check_secends=20, memo_unit='GB', precision=2):
107 | self.pid = os.getpid()
108 | self.p = psutil.Process(self.pid)
109 | self.check_secends = check_secends
110 | self.memo_unit = memo_unit
111 | self.precision = precision
112 | self.start_time = time.time()
113 |
114 | def kill(self):
115 | child_poll = self.p.children(recursive=True)
116 | for p in child_poll:
117 |             if 'SYSTEM' not in p.username():  # username is a method and must be called
118 | print(f'kill sub process: PID: {p.pid} user: {p.username()} name: {p.name()}')
119 | p.kill()
120 | self.p.kill()
121 | print(f'kill {self.pid}')
122 |
123 | def get_memory_info(self):
124 | memo = byte2human(self.p.memory_info().rss, self.memo_unit)
125 | info = psutil.virtual_memory()
126 | total_memo = byte2human(info.total, self.memo_unit)
127 | used = byte2human(info.used, self.memo_unit)
128 | free = byte2human(info.free, self.memo_unit)
129 | available = byte2human(info.available, self.memo_unit)
130 | cur_pid_percent = info.percent
131 | return memo, used, free, available, total_memo, cur_pid_percent
132 |
133 | def task(self):
134 | while True:
135 | memo, used, free, available, total_memo, cur_pid_percent = self.get_memory_info()
136 | print('--' * 20)
137 | print(f'PID: {self.pid} name: {self.p.name()}')
138 |             print(f'current process memory :\t {memo:.2f} {self.memo_unit}')
139 |             print(f'used :\t {used:.2f} {self.memo_unit}')
140 |             print(f'free :\t {free:.2f} {self.memo_unit}')
141 |             print(f'total :\t {total_memo} {self.memo_unit}')
142 |             print(f'system memory usage :\t {cur_pid_percent} %')
143 |             print(f'running time :\t {(time.time() - self.start_time) / 60:.2f} min')
144 |             # print('cpu count:', psutil.cpu_count())
145 |             if cur_pid_percent > 95:  # system-wide memory usage, from psutil.virtual_memory().percent
146 |                 logging.info(f'memory usage too high: {cur_pid_percent}%, kill {self.pid}')
147 |                 self.kill()  # stop this process
148 | time.sleep(self.check_secends)
149 |
150 | def run(self):
151 | thr = threading.Thread(target=self.task)
152 |         thr.daemon = True  # exit together with the main thread (setDaemon is deprecated)
153 | thr.start()
154 |
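# Minimal usage sketch (editor's addition) for the helpers above.
if __name__ == '__main__':
    with ShowTime(prefix='toy loop'):                 # prints "toy loop took: ... s" on exit
        total = sum(i * i for i in range(10 ** 6))
    print(get_var_size(total), '|', byte2human(1536, unit='KB'), 'KB')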
--------------------------------------------------------------------------------
/config.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import os
3 | from pathlib import Path
4 |
5 | import arrow
6 | import torch
7 |
8 | cur_dir = os.path.dirname(os.path.abspath(__file__))
9 | root_dir = cur_dir
10 | data_dir = os.path.join(root_dir, "data")
11 | raw_data_dir = os.path.join(data_dir, 'raw_data')
12 |
13 | output_dir = os.path.join(root_dir, "output")
14 | ckpt_dir = os.path.join(output_dir, "ckpt")
15 | result_dir = os.path.join(output_dir, 'result')
16 |
17 | # raw data
18 | mention2ent_txt = os.path.join(raw_data_dir, 'PKUBASE', 'pkubase-mention2ent.txt')
19 | kb_triples_txt = os.path.join(raw_data_dir, 'PKUBASE', 'pkubase-complete2.txt')
20 | # raw QA data
21 | raw_train_txt = os.path.join(raw_data_dir, 'ccks_2020_7_4_Data', 'task1-4_train_2020.txt')
22 | valid_question_txt = os.path.join(raw_data_dir, 'ccks_2020_7_4_Data', 'task1-4_valid_2020.questions')
23 |
24 |
25 | class DataConfig(object):
26 | """
27 |     Files generated from the raw data after preprocessing.
28 | """
29 | word2id_json = os.path.join(data_dir, 'word2id.json')
30 | q_entity2id_json = os.path.join(data_dir, 'q_entity2id.json')
31 | a_entity2id_json = os.path.join(data_dir, 'a_entity2id.json')
32 | #
33 |     data_csv = os.path.join(data_dir, 'data.csv')  # training data after a small format conversion
34 | #
35 | mention2ent_json = os.path.join(data_dir, 'mention2ent.json')
36 | ent2mention_json = os.path.join(data_dir, 'ent2mention.json')
37 | entity2id = os.path.join(data_dir, 'entity2id.json')
38 | id2entity_pkl = os.path.join(data_dir, 'id2entity.pkl')
39 | relation2id = os.path.join(data_dir, 'relation2id.json')
40 | id2relation_pkl = os.path.join(data_dir, 'id2relation.pkl')
41 | # count
42 | entity2count_json = os.path.join(data_dir, 'entity2count.json')
43 | relation2count_json = os.path.join(data_dir, 'relation2count.json')
44 | mention2count_json = os.path.join(data_dir, 'mention2count.json')
45 | #
46 | lac_custom_dict_txt = os.path.join(data_dir, 'lac_custom_dict.txt')
47 | lac_attr_custom_dict_txt = os.path.join(data_dir, 'lac_attr_custom_dict.txt')
48 | jieba_custom_dict = os.path.join(data_dir, 'jieba_custom_dict.json')
49 | # graph_pkl = os.path.join(data_dir, 'graph.pkl')
50 |     graph_entity_csv = os.path.join(data_dir, 'graph_entity.csv')  # graph-database import file
51 |     graph_relation_csv = os.path.join(data_dir, 'graph_relation.csv')  # graph-database import file
52 | entity2types_json = os.path.join(data_dir, 'entity2type.json')
53 | entity2attrs_json = os.path.join(data_dir, 'entity2attr.json')
54 |     all_attrs_json = os.path.join(data_dir, 'all_attrs.json')  # all attributes
55 | #
56 | lac_model_pkl = os.path.join(data_dir, 'lac_model.pkl')
57 | # EntityScore model
58 | entity_score_model_pkl = os.path.join(data_dir, 'entity_score_model.pkl')
59 | entity_score_data_pkl = os.path.join(data_dir, 'entity_score_data.pkl')
60 | #
61 | neo4j_query_cache = os.path.join(data_dir, 'neo4j_query_cache.json')
62 |
63 | #
64 | relation_score_sample_csv = os.path.join(data_dir, 'sample.csv')
65 |
66 | @staticmethod
67 | def get_relation_score_sample_csv(data_type, neg_rate):
68 | if data_type == 'train':
69 | file_path = Path(DataConfig.relation_score_sample_csv).with_name(f'train.1_{neg_rate}.csv')
70 | else:
71 | file_path = Path(DataConfig.relation_score_sample_csv).with_name('test.csv')
72 | return str(file_path)
73 |
74 |
75 | class TorchConfig(object):
76 | # device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
77 | device = torch.device('cpu')
78 | # device = "cpu"
79 | gpu_nums = torch.cuda.device_count()
80 | multi_gpu = True
81 | gradient_accumulation_steps = 1
82 | clip_grad = 2
83 |
84 |
85 | class Parms(object):
86 | #
87 | learning_rate = 0.001
88 | # #
89 | # min_epoch_nums = 1
90 | # max_epoch_nums = 10
91 | # #
92 |     # embedding_dim = 128  # entity / relation / word embedding dim
93 | max_len = 50 # max sentence length
94 | # batch_size = 32
95 | # # subtask = 'general'
96 | # test_batch_size = 128
97 |
98 |
99 | class Config(TorchConfig, DataConfig, Parms):
100 | pretrained_model_name_or_path = os.path.join(data_dir, 'bert-base-chinese-pytorch') # 'bert-base-chinese'
101 | # load_pretrain = True
102 | # rand_seed = 1234
103 | # load_model_mode = "min_loss"
104 | # load_model_mode = "max_step"
105 | # load_model_mode = "max_acc" # mrr
106 | #
107 | # train_count = 1000 # TODO for debug
108 | # test_count = 10 # 10*2*13589
109 |
110 |
111 | class ResultSaver(object):
112 |     """Output file management: generates fresh, date-stamped file names to avoid overwriting,
113 |     or looks up result files that already exist (find_exist_path=True).
114 | """
115 |
116 | def __init__(self, find_exist_path=False):
117 | os.makedirs(result_dir, exist_ok=True)
118 | self.find_exist_path = find_exist_path
119 |
120 | def _get_new_path(self, file_name):
121 | date_str = arrow.now().format("YYYYMMDD")
122 |         # date_str = '20200609'  # temporary override
123 |         num = 1
124 |         path = os.path.join(result_dir, f"{date_str}-{num}-{file_name}")
125 |         while os.path.isfile(path):
126 |             num += 1
127 |             path = os.path.join(result_dir, f"{date_str}-{num}-{file_name}")
128 | return path
129 |
130 | def _find_paths(self, file_name):
131 | paths = [str(_path) for _path in
132 | Path(result_dir).rglob(f'*{file_name}')]
133 | _paths = sorted(paths, reverse=True)
134 | return _paths
135 |
136 | def get_path(self, file_name):
137 | if self.find_exist_path:
138 | path = self._find_paths(file_name)
139 | else:
140 | path = self._get_new_path(file_name)
141 | logging.info(f'* get path: {path}')
142 | return path
143 |
144 | @property
145 | def train_result_csv(self):
146 | file_name = 'train_answer_result.csv'
147 | path = self.get_path(file_name)
148 | return path
149 |
150 | @property
151 | def valid_result_csv(self):
152 | file_name = 'valid_result.csv'
153 | path = self.get_path(file_name)
154 | return path
155 |
156 | @property
157 | def submit_result_txt(self):
158 | file_name = 'submit_result.txt'
159 | path = self.get_path(file_name)
160 | return path
161 |
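# Minimal usage sketch (editor's addition). Note the asymmetry: with the default
# find_exist_path=False the properties return a single new path string, while
# find_exist_path=True returns a (possibly empty) list of existing result files,
# which is why callers such as evaluate.py index the result with [0].
if __name__ == '__main__':
    new_csv = ResultSaver().train_result_csv                        # e.g. output/result/<YYYYMMDD>-1-train_answer_result.csv
    existing = ResultSaver(find_exist_path=True).train_result_csv   # list of matching files, sorted newest first
    print(new_csv, existing)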
--------------------------------------------------------------------------------
/data.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 | from ckbqa.utils.logger import logging_config
4 |
5 |
6 | def create_db_tables():
7 | from sqlalchemy_utils import database_exists, create_database
8 | from ckbqa.dao.db import sqlite_db_engine
9 | if not database_exists(sqlite_db_engine.url):
10 | create_database(sqlite_db_engine.url)
11 | from ckbqa.dao.sqlite_models import BaseModel
12 |     BaseModel.metadata.create_all(sqlite_db_engine)  # create the tables
13 |
14 |
15 | def data_prepare():
16 | logging_config('data_prepare.log', stream_log=True)
17 | from ckbqa.dataset.data_prepare import fit_on_texts, data_convert
18 | # map_mention_entity()
19 | # data2samples(neg_rate=3)
20 | data_convert()
21 | fit_on_texts()
22 |     create_db_tables()
23 |
24 |
25 | def kb_data_prepare():
26 | logging_config('kb_data_prepare.log', stream_log=True)
27 | from ckbqa.dataset.kb_data_prepare import (candidate_words, fit_triples)
28 | from ckbqa.dataset.kb_data_prepare import create_graph_csv
29 |     fit_triples()  # build the dictionaries
30 |     candidate_words()  # attributes
31 |     # create_lac_custom_dict()  # custom word-segmentation dictionary
32 |
33 |     create_graph_csv()  # generate the graph-database import files
34 | # from examples.lac_test import lac_model
35 | # lac_model()
36 |
37 |
38 | def main():
39 |     parser = argparse.ArgumentParser(description="basic, general-purpose parser")
40 |     # logging config
41 |     parser.add_argument('--stream_log', action="store_true", help="also print log messages to stdout")
42 |     #
43 |     group = parser.add_mutually_exclusive_group(required=True)  # a group of mutually exclusive arguments; exactly one is required
44 |
45 |     group.add_argument('--data_prepare', action="store_true", help="preprocess the training data")
46 |     group.add_argument('--kb_data_prepare', action="store_true", help="preprocess the knowledge-base data")
47 |     group.add_argument('--task', action="store_true", help="ad-hoc task that calls several functions")
48 | # parse args
49 | args = parser.parse_args()
50 | #
51 | # from ckbqa.utils.tools import ProcessManager
52 | # ProcessManager().run()
53 | if args.data_prepare:
54 | data_prepare()
55 | elif args.kb_data_prepare:
56 | kb_data_prepare()
57 | elif args.task:
58 | task()
59 |
60 |
61 | def task():
62 | pass
63 |
64 |
65 | if __name__ == '__main__':
66 |     """ Entry point
67 |     examples:
68 |         python data.py --data_prepare
69 |         nohup python data.py --kb_data_prepare &>kb_data_prepare.out &
70 |     """
71 |     # from ckbqa.utils.tools import ProcessManager  # monitor memory usage in real time
72 | # ProcessManager().run()
73 | main()
74 |
--------------------------------------------------------------------------------
/docs/CCKS 2019 知识图谱评测技术报告:实体、关系、事件及问答 .pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/WangShengguang/ccks-2020/d77f2de91284efb11e314fea21b4a3982ab78554/docs/CCKS 2019 知识图谱评测技术报告:实体、关系、事件及问答 .pdf
--------------------------------------------------------------------------------
/docs/bad_case.md:
--------------------------------------------------------------------------------
1 | # bad case
2 |
3 | ## 1. Main-entity recognition errors
4 |
5 | question | Error description | Fix | Resolved? |
6 | ---|---|---|---|
7 | 被誉为万岛之国的是哪个国家? | Segmentation error | Add dictionary entries or NER |
8 | LOF的全称是什么?| Segmentation is correct, but the mention LOF has no matching entity ||
9 | 恒铭达的全称是什么?| The entity does not exist in the KB |||
10 | 列举唐朝的代表诗人?| The entity does not exist in the KB |||
11 | 迪梵是哪个公司旗下的品牌 ?| Almost correct, but the training-set entity and the given KB disagree (training set: , KB: ) |||
12 | K2属于哪个山脉?| No mapping in mention2ent |||
13 | 哪位作家的原名叫张煐?| Needs content understanding; 张煐 etc. are recognized, but they are not standard entities |||
14 | 谁下令修建了万里长城?| Needs content understanding; 万里长城 etc. are recognized, but they are not standard entities |||
15 | 谁废除了分封制?| The provided gold segmentation simply does not exist |||
16 | 被称为冰上沙皇的是哪位运动员?| Needs NER labeling, but it is missing from the mention list; a mention has to be added |||
17 | Hermès的总部在哪里?||||
18 | 《水浒传》所属文学体裁的三要素是什么?| 水浒传 is recognized, but it is labeled as <水浒传(竖版)> |||
19 |
20 | - Improvements:
21 |     - Extend the mention2entity mapping (for example, K2 and IBM currently have no mapping); a sketch follows below,
22 | 原名 (original name) -> 别名 (alias)
23 | 别号 (alternative name) -> 别名 (alias)
24 |
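A minimal sketch of the proposed mention2entity extension (editor's addition). It assumes `mention2ent.json` maps a mention string to a list of entity names, and that alias-like predicates such as <别名>, <原名> and <别号> can be read from the KB triples via `iter_triples`; both assumptions should be checked against the actual data.

```python
from ckbqa.dataset.kb_data_prepare import iter_triples
from ckbqa.utils.tools import json_load, json_dump
from config import Config

ALIAS_RELATIONS = {'<别名>', '<原名>', '<别号>'}  # predicates treated as alias links (assumption)

def extend_mention2ent():
    mention2ent = json_load(Config.mention2ent_json)  # assumed structure: mention -> [entity, ...]
    for head_ent, rel, tail_ent in iter_triples():
        if rel in ALIAS_RELATIONS:
            mention = tail_ent.strip('<>"')  # use the alias value as an extra mention
            ents = mention2ent.setdefault(mention, [])
            if head_ent not in ents:
                ents.append(head_ent)
    json_dump(mention2ent, Config.mention2ent_json)
```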
25 |
26 | ## 2. Relation classification errors
27 | question | Error description | Fix | Resolved? |
28 | ---|---|---|---|
29 | 超级碗是什么体育项目联赛的年度冠军赛? | The main entity disappears after classification | Classification algorithm |
30 |
31 |
32 |
33 |
34 | ## 3. Post-processing errors
35 |
36 | question | Error description | Fix | Resolved? |
37 | ---|---|---|---|
38 | 叶文洁毕业于哪个大学? | The main entity disappears; longest-match was used | Add dictionary entries or NER |
39 |
40 |
--------------------------------------------------------------------------------
/evaluate.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import re
3 |
4 | import pandas as pd
5 | from tqdm import tqdm
6 |
7 | from ckbqa.utils.logger import logging_config
8 | from config import ResultSaver
9 |
10 |
11 | def train_data():
12 | logging_config('train_evaluate.log', stream_log=True)
13 | from ckbqa.models.evaluation_matrics import get_metrics
14 | #
15 |     pattern = re.compile(r'["<](.*?)[>"]')
16 |     #
17 |     _paths = ResultSaver(find_exist_path=True).train_result_csv
18 |     print(_paths)
19 |     train_df = pd.read_csv(_paths[0])
20 |     ceg_precisions, ceg_recalls, ceg_f1_scores = [], [], []
21 |     answer_precisions, answer_recalls, answer_f1_scores = [], [], []
22 |     for index, row in tqdm(train_df.iterrows(), total=train_df.shape[0], desc='evaluate '):
23 |         subject_entities = pattern.findall(row['standard_subject_entities'])  # match the bare text
24 |         if not subject_entities:
25 |             subject_entities = eval(row['standard_subject_entities'])
26 |         # Fixes an earlier issue where <> was stripped from entities: it was removed when parsing the
27 |         # question but not at prediction time, so match the bare text without <> or ""
28 |         # CEG Candidate Entity Generation
29 |         candidate_entities = eval(row['candidate_entities']) + pattern.findall(row['candidate_entities'])
30 | precision, recall, f1 = get_metrics(subject_entities, candidate_entities)
31 | ceg_precisions.append(precision)
32 | ceg_recalls.append(recall)
33 | ceg_f1_scores.append(f1)
34 | # Answer
35 | standard_entities = eval(row['standard_answer_entities'])
36 | result_entities = eval(row['result_entities'])
37 | precision, recall, f1 = get_metrics(standard_entities, result_entities)
38 | answer_precisions.append(precision)
39 | answer_recalls.append(recall)
40 | answer_f1_scores.append(f1)
41 | #
42 | # print(f"question: {row['question']}\n"
43 | # f"subject_entities: {subject_entities}, candidate_entities: {candidate_entities}"
44 | # f"precision: {precision:.4f}, recall: {recall:.4f}, f1: {f1:.4f}\n\n")
45 | # import time
46 | # time.sleep(2)
47 | ave_ceg_precision = sum(ceg_precisions) / len(ceg_precisions)
48 | ave_ceg_recall = sum(ceg_recalls) / len(ceg_recalls)
49 | ave_ceg_f1_score = sum(ceg_f1_scores) / len(ceg_f1_scores)
50 | print(f"ave_ceg_precision: {ave_ceg_precision:.3f}, "
51 | f"ave_ceg_recall: {ave_ceg_recall:.3f}, "
52 | f"ave_ceg_f1_score:{ave_ceg_f1_score:.3f}")
53 | #
54 | ave_answer_precision = sum(answer_precisions) / len(answer_precisions)
55 | ave_answer_recall = sum(answer_recalls) / len(answer_recalls)
56 | ave_answer_f1_score = sum(answer_f1_scores) / len(answer_f1_scores)
57 | print(f"ave_result_precision: {ave_answer_precision:.3f}, "
58 | f"ave_result_recall: {ave_answer_recall:.3f}, "
59 | f"ave_result_f1_score:{ave_answer_f1_score:.3f}")
60 |
61 |
62 | def ceg():
63 | logging_config('train_evaluate.log', stream_log=True)
64 | from ckbqa.models.evaluation_matrics import get_metrics
65 | from ckbqa.qa.el import CEG
66 | from ckbqa.dataset.data_prepare import load_data, question_patten, entity_pattern, attr_pattern #
67 | ceg = CEG() # Candidate Entity Generation
68 | ceg_precisions, ceg_recalls, ceg_f1_scores = [], [], []
69 | ceg_csv = "./ceg.csv"
70 | data = []
71 | for q, sparql, a in load_data(tqdm_prefix='ceg evaluate '):
72 | q_entities = entity_pattern.findall(sparql) + attr_pattern.findall(sparql)
73 |
74 | q_text = ''.join(question_patten.findall(q))
75 |         # Fixes an earlier issue where <> was stripped from entities: it was removed when parsing the
76 |         # question but not at prediction time, so match the bare text without <> or ""
77 | ent2mention = ceg.get_ent2mention(q_text)
78 | # CEG Candidate Entity Generation
79 | precision, recall, f1 = get_metrics(q_entities, ent2mention)
80 | ceg_precisions.append(precision)
81 | ceg_recalls.append(recall)
82 | ceg_f1_scores.append(f1)
83 | #
84 | data.append([q, q_entities, list(ent2mention.keys())])
85 | if recall == 0:
86 | # ceg.memory.entity2id
87 | # ceg.memory.mention2entity
88 | print(f"question: {q}\n"
89 | f"subject_entities: {q_entities}, candidate_entities: {ent2mention}"
90 | f"precision: {precision:.4f}, recall: {recall:.4f}, f1: {f1:.4f}\n\n")
91 | # import ipdb
92 | # ipdb.set_trace()
93 | print('\n\n')
94 | # import time
95 | # time.sleep(2)
96 | pd.DataFrame(data, columns=['question', 'q_entities', 'ceg']).to_csv(
97 | ceg_csv, index=False, encoding='utf_8_sig')
98 | ave_precision = sum(ceg_precisions) / len(ceg_precisions)
99 | ave_recall = sum(ceg_recalls) / len(ceg_recalls)
100 | ave_f1_score = sum(ceg_f1_scores) / len(ceg_f1_scores)
101 | print(f"ave_precision: {ave_precision:.3f}, "
102 | f"ave_recall: {ave_recall:.3f}, "
103 | f"ave_f1_score:{ave_f1_score:.3f}")
104 |
105 |
106 | def main():
107 |     parser = argparse.ArgumentParser(description="basic, general-purpose parser")
108 |     # logging config
109 |     parser.add_argument('--stream_log', action="store_true", help="also print log messages to stdout")
110 |     #
111 |     group = parser.add_mutually_exclusive_group(required=True)  # a group of mutually exclusive arguments; exactly one is required
112 |     group.add_argument('--task', action="store_true", help="ad-hoc task that calls several functions")
113 |     group.add_argument('--ceg', action="store_true", help="CEG (Candidate Entity Generation) evaluation")
114 |     group.add_argument('--train_data', action="store_true", help="train_answer_data evaluation")
115 | # parse args
116 | args = parser.parse_args()
117 | #
118 | # from ckbqa.utils.tools import ProcessManager
119 | # ProcessManager().run()
120 | if args.ceg:
121 | ceg()
122 | elif args.train_data:
123 | train_data()
124 | elif args.task:
125 | task()
126 |
127 |
128 | def task():
129 | logging_config('ceg.log', stream_log=True)
130 | ceg()
131 |
132 |
133 | if __name__ == '__main__':
134 | """
135 | example:
136 |     nohup python evaluate.py --train_data &>train_data.out &
137 |     nohup python evaluate.py --ceg &>ceg.out &
138 | """
139 |     # from ckbqa.utils.tools import ProcessManager  # monitor memory usage in real time
140 | # ProcessManager().run()
141 | main()
142 |
--------------------------------------------------------------------------------
/examples/__init__.py:
--------------------------------------------------------------------------------
1 |
2 | def add_root_path():
3 | import sys
4 | from pathlib import Path
5 | cur_dir = Path(__file__).absolute().parent
6 | times = 10
7 | while not str(cur_dir).endswith('ccks-2020') and times:
8 | cur_dir = cur_dir.parent
9 | times -= 1
10 | print(cur_dir)
11 | sys.path.append(str(cur_dir))
12 |
13 |
14 | add_root_path()
15 |
--------------------------------------------------------------------------------
/examples/answer_format.py:
--------------------------------------------------------------------------------
1 | import pandas as pd
2 |
3 |
4 | def format():
5 | data_df = pd.read_csv('./answer.txt')
6 |
7 | def func(text: str):
8 | if isinstance(text, str):
9 | ents = []
10 | for ent in set(text.split('\t')):
11 | if ent.startswith('<') or ent.startswith('"'):
12 | ents.append(ent)
13 | else:
14 | ents.append('"' + ent + '"')
15 | print('"' + ent + '"')
16 | return '\t'.join(ents)
17 | else:
18 | print(text)
19 | return ""
20 |
21 | with open('./result.txt', 'w', encoding='utf-8') as f:
22 | for answer in data_df['answer']:
23 | line = func(answer)
24 | f.write(line + '\n')
25 | # data_df[['commit']].to_csv('./result.txt', encoding='utf_8_sig', index=False)
26 |
27 |
28 | if __name__ == '__main__':
29 | format()
30 |
--------------------------------------------------------------------------------
/examples/bad_case.py:
--------------------------------------------------------------------------------
1 | '''
2 |
3 | 2020-06-09 00:32:53,712 - INFO qa.py run 59 - * get_most_overlap_path: ['<吉林大学>', '<校歌>'] ...
4 | 2020-06-09 00:32:53,712 - INFO neo4j_graph.py search_by_2path 83 - match (ent:Entity)-[r1:Relation]-(target) where ent.id=5199304 and r1.name='<校歌>' return DISTINCT target.name
5 | 2020-06-09 00:32:53,892 - INFO qa.py run 72 - * cypher answer: ['<吉林大学校歌>']
6 | '''
7 |
8 | '''
9 | Did not reach the inside of qa.run; cause not yet understood
10 |
11 |
12 | question : q185:林清玄有哪些别名?
13 | sparql : select ?x where { <林清玄> <别名> ?x. }
14 | standard answer : "秦情、林漓、林大悲等"
15 | ****************************************************************************************************
16 |
17 | question : q186:北京大学的发展定位是什么
18 | sparql : select ?x where { <北京大学> <发展定位> ?x . }
19 | standard answer : <世界一流大学>
20 | ****************************************************************************************************
21 |
22 | question : q187:中华人民共和国的国家领导人是谁?
23 | sparql : select ?x where { <中华人民共和国> <国家领袖> ?x . }
24 | standard answer : <习近平> <中华人民共和国主席> <中华人民共和国国务院总理> <李克强_(中共中 央政治局常委,国务院总理)> <张德江_(中央政治局常委)> <中华人民共和国全国人民代表大会常务委员会委员长>
25 |
26 | ****************************************************************************************************
27 | question : q188:墨冰仙是哪个门派的?
28 | sparql : select ?x where { <墨冰仙> <门派> ?x. }
29 | standard answer : <蜀山派_(游戏《仙剑奇侠传》中的武林门派)>
30 |
31 |
32 |
33 | ****************************************************************************************************
34 | question : q189:在金庸小说《天龙八部》中,斗转星移的修习者是谁?
35 | sparql : select ?x where { <斗转星移_(金庸小说武功心法)> <主要修习者> ?x. }
36 | standard answer : "慕容博" "慕容龙城" "慕容复"
37 |
38 |
39 |
40 | ****************************************************************************************************
41 | question : q190:《基督山伯爵》的作者是谁?
42 | sparql : select ?x where { <基督山伯爵_(大仲马代表作之一)> <作者> ?x. }
43 | standard answer : <亚历山大·仲马_(法国作家、大仲马)>
44 |
45 |
46 |
47 | ****************************************************************************************************
48 | question : q191:奥运会多久举办一次?
49 | sparql : select ?x where { <奥林匹克运动会> <举办时间> ?x. }
50 | standard answer : "每四年一届"
51 |
52 |
53 |
54 | ****************************************************************************************************
55 | question : q192:音乐剧歌剧魅影的主要演员?
56 | sparql : select ?x where { <歌剧魅影_(安德鲁·劳埃德·韦伯创作的音乐剧)> <主演> ?x. }
57 | standard answer : <莎拉·布莱曼>
58 |
59 |
60 |
61 | ****************************************************************************************************
62 | question : q193:索尼公司的经营范围有哪些?
63 | sparql : select ?x where { <索尼_(世界著名企业)> <经营范围> ?x. }
64 | standard answer : <信息技术_(用于管理和处理信息所采用各种技术总称)> <金融_(专业词语汉语词语 )> <电子_(基本粒子之一)> <娱乐_(汉语词语)>
65 |
66 |
67 |
68 | ****************************************************************************************************
69 | question : q194:2014年世界杯的冠军是哪只队?
70 | sparql : select ?x where { <2014年巴西世界杯> <冠军> ?x. }
71 | standard answer : <德国国家男子足球队>
72 |
73 |
74 |
75 | ****************************************************************************************************
76 | question : q195:泰铢的通胀率是多少?
77 | sparql : select ?x where { <泰铢> <通胀率> ?x. }
78 | standard answer : "5.1%"
79 |
80 |
81 |
82 | ****************************************************************************************************
83 | question : q196:小龙虾的中文学名?
84 | sparql : select ?x where { <小龙虾_(甲壳纲螯虾科水生动物)> <中文学名> ?x. }
85 | standard answer : "克氏原螯虾"
86 |
87 |
88 |
89 | ****************************************************************************************************
90 | question : q197:纽约的气候条件是怎样的?
91 | sparql : select ?x where { <纽约_(美国第一大城市)> <气候条件> ?x. }
92 | standard answer : <温带大陆性气候>
93 |
94 |
95 |
96 | ****************************************************************************************************
97 | question : q198:海贼王动画里黑胡子的声优是谁?
98 | sparql : select ?x where { <黑胡子_(动漫《海贼王》中人物)> <配音> ?x. }
99 | standard answer : <大冢明夫>
100 |
101 |
102 |
103 | ****************************************************************************************************
104 | question : q199:大连在哪个省份?
105 | sparql : select ?x where { <大连_(辽宁省辖市)> <所属地区> ?x. }
106 | standard answer : <辽宁>
107 |
108 |
109 |
110 | ****************************************************************************************************
111 | question : q200:新加坡的主要学府有哪些?
112 | sparql : select ?x where { <新加坡> <主要学府> ?x. }
113 | standard answer : "新加坡国立大学" "南洋理工大学"
114 |
115 |
116 |
117 | ****************************************************************************************************
118 | question : q201:大禹建立了哪个朝代?
119 | sparql : select ?x where { ?x <开国君主> <禹_(夏朝开国君主)>. }
120 | standard answer : <夏_(中国第一个朝代)>
121 |
122 |
123 |
124 | ****************************************************************************************************
125 | question : q202:菲尔·奈特创立了什么品牌?
126 | sparql : select ?x where { ?x <创始人> <菲尔·奈特>. }
127 | standard answer :
128 |
129 |
130 |
131 | ****************************************************************************************************
132 | question : q203:戊戌变法开始的标志是什么?
133 | sparql : select ?x where { <戊戌变法> <序幕> ?x. }
134 | standard answer : "公车上书"
135 |
136 |
137 |
138 | ****************************************************************************************************
139 | question : q204:肯德基的口号是什么?
140 | sparql : select ?x where { <肯德基> <公司口号> ?x . }
141 | standard answer : "生活如此多娇"
142 |
143 |
144 |
145 | ****************************************************************************************************
146 | question : q205:你知道彩虹又叫做什么吗?
147 | sparql : select ?x where { <彩虹_(雨后光学现象)> <又称> ?x. }
148 | standard answer : <天虹>
149 |
150 |
151 |
152 | ****************************************************************************************************
153 | question : q206:巴金的中文名是什么?
154 | sparql : select ?x where { <巴金> <中文名> ?x. }
155 | standard answer : "李尧棠"
156 |
157 |
158 |
159 | ****************************************************************************************************
160 | question : q207:海贼王里布鲁克的梦想是什么?
161 | sparql : select ?x where { <布鲁克_(日本动漫《海贼王》中主要人物)> <梦想> ?x. }
162 | standard answer : "回到双子跟鲸鱼拉布重逢"
163 |
164 |
165 |
166 | ****************************************************************************************************
167 | question : q208:你知道铁的熔点吗?
168 | sparql : select ?x where { <铁_(最常见的金属元素)> <熔点> ?x. }
169 | standard answer : "1535℃"
170 |
171 |
172 |
173 | ****************************************************************************************************
174 | question : q209:书籍《假如给我三天光明》是什么体裁?
175 | sparql : select ?x where { <假如给我三天光明> <文学体裁> ?x. }
176 | standard answer : <散文_(文学体裁)>
177 |
178 |
179 |
180 | ****************************************************************************************************
181 | question : q210:上海戏剧学院现任校长是谁?
182 | sparql : select ?y where { <上海戏剧学院> <现任校长> ?y. }
183 | standard answer : <韩生_(中国著名舞美设计家、上海戏剧学院院长)>
184 | '''
185 |
186 | '''
187 |
188 | Correctly recognized, but the query failed
189 |
190 | - * q_text: linux的创始人毕业于哪个学校?
191 | 2020-06-14 09:59:32,694 - INFO qa.py run 40 - * get_candidate_entities 63 ...
192 | 2020-06-14 09:59:32,700 - INFO qa.py run 42 - * get candidate_out_paths: 1279,candidate_in_paths: 1279 ...
193 | 2020-06-14 09:59:32,707 - INFO qa.py run 46 - * get_most_overlap out path:13.9167, ['', '<创始人>', '<毕业院校>'] ...
194 | 2020-06-14 09:59:32,714 - INFO qa.py run 48 - * get_most_overlap in path:13.9167, ['', '<创始人>', '<毕业院校>'] ...
195 | 2020-06-14 09:59:32,714 - INFO qa.py run 56 - * get_most_overlap_path: ['', '<创始人>', '<毕业院校>'] ...
196 | 2020-06-14 09:59:32,714 - INFO neo4j_graph.py search_by_3path 123 - match (target)-[r1:Relation]-()-[r2:Relation]-(ent:Entity) where ent.id=3068935 and r1.name='<创始人>' and r2.name='<毕业院校>' return DISTINCT target.name; *ent_name:
197 | 2020-06-14 09:59:32,731 - INFO neo4j_graph.py search_by_2path 112 - match (ent:Entity)-[r1:Relation]-(target) where ent.id=3068935 and r1.name='<创始人>' return DISTINCT target.name; *ent_name:
198 | 2020-06-14 09:59:32,736 - INFO qa.py run 64 - * cypher answer: ['<林纳斯·托瓦兹>']
199 |
200 |
201 | ------
202 | - * q_text: 易水歌的作者的主要事迹是什么?
203 | 2020-06-14 10:07:50,878 - INFO qa.py run 40 - * get_candidate_entities 27 ...
204 | 2020-06-14 10:07:50,880 - INFO qa.py run 42 - * get candidate_out_paths: 414,candidate_in_paths: 414 ...
205 | 2020-06-14 10:07:50,883 - INFO qa.py run 46 - * get_most_overlap out path:12.9444, ['<易水歌>', '<作者>', '<主要事迹>'] ...
206 | 2020-06-14 10:07:50,885 - INFO qa.py run 48 - * get_most_overlap in path:12.9444, ['<易水歌>', '<作者>', '<主要事迹>'] ...
207 | 2020-06-14 10:07:50,885 - INFO qa.py run 56 - * get_most_overlap_path: ['<易水歌>', '<作者>', '<主要事迹>'] ...
208 | 2020-06-14 10:07:50,885 - INFO neo4j_graph.py search_by_3path 123 - match (target)-[r1:Relation]-()-[r2:Relation]-(ent:Entity) where ent.id=7957014 and r1.name='<作者>' and r2.name='<主要事迹>' return DISTINCT target.name; *ent_name:<易水歌>
209 | 2020-06-14 10:07:50,893 - INFO neo4j_graph.py search_by_2path 112 - match (ent:Entity)-[r1:Relation]-(target) where ent.id=7957014 and r1.name='<作者>' return DISTINCT target.name; *ent_name:<易水歌>
210 | 2020-06-14 10:07:50,896 - INFO qa.py run 64 - * cypher answer: ['<荆轲_(战国时期著名刺客)>']
211 | '''
212 |
--------------------------------------------------------------------------------
/examples/del_method.py:
--------------------------------------------------------------------------------
1 | import time
2 |
3 |
4 | class DelMethod(object):
5 |
6 | def __del__(self):
7 | seconds = 10
8 | time.sleep(seconds)
9 | print(f'sleep seconds : {seconds}')
10 |
11 |
12 | def main():
13 | DelMethod()
14 |
15 |
16 | if __name__ == '__main__':
17 | main()
18 |
--------------------------------------------------------------------------------
/examples/kb_data.py:
--------------------------------------------------------------------------------
1 | def add_root_path():
2 | import sys
3 | from pathlib import Path
4 | cur_dir = Path(__file__).absolute().parent
5 | times = 10
6 | while not str(cur_dir).endswith('ccks-2020') and times:
7 | cur_dir = cur_dir.parent
8 | times -= 1
9 | print(cur_dir)
10 | sys.path.append(str(cur_dir))
11 |
12 |
13 | add_root_path()
14 |
15 | import re
16 | from collections import Counter
17 |
18 | import pandas as pd
19 |
20 | from ckbqa.dataset.data_prepare import load_data
21 | from ckbqa.dataset.kb_data_prepare import iter_triples
22 | from ckbqa.utils.tools import json_load, json_dump
23 | from config import mention2ent_txt
24 |
25 |
26 | def triples2csv():
27 | """26G"""
28 | data_df = pd.DataFrame([(head_ent, rel, tail_ent) for head_ent, rel, tail_ent in iter_triples()],
29 | columns=['head', 'rel', 'tail'])
30 | data_df.to_csv('./triples.csv', encoding='utf_8_sig', index=False)
31 |
32 |
33 | def most_ents():
34 | """26G"""
35 |
36 | entities = []
37 | for head_ent, rel, tail_ent in iter_triples():
38 | entities.extend([head_ent, rel, tail_ent])
39 | print(f' len(entities): {len(entities)}')
40 | ent_counter = Counter(entities)
41 | del entities
42 |     ent_count = ent_counter.most_common(1000000)  # one million
43 | print(ent_count[:10])
44 | print(ent_count[-10:])
45 | json_dump(dict(ent_counter), 'ent2count.json')
46 | #
47 | print('-----' * 10)
48 | mentions = []
49 | with open(mention2ent_txt, 'r', encoding='utf-8') as f:
50 | for line in f:
51 |             mention, ent, rank = line.split('\t')  # some lines are malformed: the mention is an empty string
52 | mentions.append(mention)
53 | print(f"len(mentions): {len(mentions)}")
54 | mention_counter = Counter(mentions)
55 | del mentions
56 |     mention_count = mention_counter.most_common(1000000)  # one million
57 | print(mention_count[:10])
58 | print(mention_count[-10:])
59 | json_dump(dict(mention_counter), 'mention2count.json')
60 |
61 | import ipdb
62 | ipdb.set_trace()
63 |
64 |
65 | def lac_test():
66 | from LAC import LAC
67 | print('start ...')
68 | mention_count = Counter(json_load('mention2count.json')).most_common(10000)
69 | customization_dict = {mention: 'MENTION' for mention, count in mention_count
70 | if len(mention) >= 2}
71 | print('second ...')
72 | ent_count = Counter(json_load('ent2count.json')).most_common(300000)
73 | ent_pattrn = re.compile(r'["<](.*?)[>"]')
74 | customization_dict.update({' '.join(ent_pattrn.findall(ent)): 'ENT' for ent, count in ent_count})
75 | with open('./customization_dict.txt', 'w') as f:
76 | f.write('\n'.join([f"{e}/{t}" for e, t in customization_dict.items()
77 | if len(e) >= 3]))
78 | import time
79 | before = time.time()
80 | print(f'before ...{before}')
81 | lac = LAC(mode='lac')
82 |     lac.load_customization('./customization_dict.txt')  # 200k entries: 21 s; 300k entries: 47 s
83 | lac_raw = LAC(mode='lac')
84 | print(f'after ...{time.time() - before}')
85 | ##
86 | test_count = 10
87 | for q, sparql, a in load_data():
88 | q_text = q.split(':')[1]
89 | print('---' * 10)
90 | print(q_text)
91 | print(sparql)
92 | print(a)
93 | words, tags = lac_raw.run(q_text)
94 | print(list(zip(words, tags)))
95 | words, tags = lac.run(q_text)
96 | print(list(zip(words, tags)))
97 | if not test_count:
98 | break
99 | test_count -= 1
100 | import ipdb
101 | ipdb.set_trace()
102 |
103 |
104 | class Node(object):
105 | def __init__(self, data):
106 | self.data = data
107 |         self.ins = set()  # incoming nodes; relations and entities are handled uniformly
108 |         self.outs = set()  # outgoing nodes; relations and entities are handled uniformly
109 |
110 |
111 | def create_graph():
112 | nodes = {}
113 | for head_ent, rel, tail_ent in iter_triples():
114 | if head_ent not in nodes:
115 |             head_node = nodes[head_ent] = Node(head_ent)  # register the node in the dict
116 | else:
117 | head_node = nodes[head_ent]
118 | head_node.outs.add(rel)
119 | #
120 | if rel not in nodes:
121 |             rel_node = nodes[rel] = Node(rel)  # register the node in the dict
122 | else:
123 | rel_node = nodes[rel]
124 | rel_node.ins.add(head_ent)
125 | rel_node.outs.add(tail_ent)
126 | #
127 | if tail_ent not in nodes:
128 |             tail_node = nodes[tail_ent] = Node(tail_ent)  # register the node in the dict
129 | else:
130 | tail_node = nodes[tail_ent]
131 | tail_node.ins.add(rel)
132 |     return nodes
133 |
134 | if __name__ == '__main__':
135 | # triples2csv()
136 | # most_ents()
137 | lac_test()
138 |
--------------------------------------------------------------------------------
/examples/lac_test.py:
--------------------------------------------------------------------------------
1 | def add_root_path():
2 | import sys
3 | from pathlib import Path
4 | cur_dir = Path(__file__).absolute().parent
5 | times = 10
6 | while not str(cur_dir).endswith('ccks-2020') and times:
7 | cur_dir = cur_dir.parent
8 | times -= 1
9 | print(cur_dir)
10 | sys.path.append(str(cur_dir))
11 |
12 |
13 | add_root_path()
14 |
15 | import logging
16 |
17 | from ckbqa.qa.lac_tools import BaiduLac, JiebaLac
18 | from ckbqa.utils.logger import logging_config
19 | from ckbqa.utils.tools import pkl_dump, ProcessManager
20 | from config import Config
21 |
22 | logging_config('lac_test.log', stream_log=True)
23 |
24 |
25 | def lac_model():
26 | # self.lac_seg = LAC(mode='seg')
27 | logging.info(f' load lac_custom_dict from {Config.lac_custom_dict_txt} start ...')
28 |     baidu_lac = BaiduLac(mode='lac', _load_customization=True, reload=True)  # load the LAC model
29 | save_path = baidu_lac._save_customization() # Config.lac_custom_dict_txt
30 | logging.info(f'load lac_custom_dict done, save to {save_path}...')
31 |
32 |
33 | def lac_test():
34 | logging.info(f' load lac_custom_dict from {Config.lac_custom_dict_txt} start ...')
35 | save_path = './test.pkl'
36 |     baidu_lac = BaiduLac(mode='lac', _load_customization=True, reload=True)  # load the LAC model
37 | logging.info(f'load lac_custom_dict done, save to {save_path}...')
38 |
39 |
40 | def jieba_test():
41 | # ProcessManager().run()
42 | jieba_model = JiebaLac()
43 | q_text1 = '被誉为万岛之国的是哪个国家?'
44 | q_text2 = '支付宝(中国)网络技术有限公司的工商注册号'
45 | for q_text in [q_text1, q_text2]:
46 | res = [x for x in jieba_model.cut_for_search(q_text)]
47 | print(res)
48 | print('-' * 10)
49 | # import ipdb
50 | # ipdb.set_trace()
51 | # pkl_dump(jieba_model, './jieba.pkl')
52 |
53 |
54 | if __name__ == '__main__':
55 | # lac_model()
56 | jieba_test()
57 |
--------------------------------------------------------------------------------
/examples/single_example.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import threading
3 | from functools import wraps
4 |
5 |
6 | def synchronized(func):
7 | func.__lock__ = threading.Lock()
8 |
9 | def lock_func(*args, **kwargs):
10 | with func.__lock__:
11 | return func(*args, **kwargs)
12 |
13 | return lock_func
14 |
15 |
16 | def singleton(cls):
17 | instances = {}
18 | print(f"*********************************** instances: {instances}")
19 |
20 | @synchronized
21 | @wraps(cls)
22 | def get_instance(*args, **kw):
23 | if cls not in instances:
24 | logging.info(f"{cls.__name__} async init start ...")
25 | instances[cls] = cls(*args, **kw)
26 | logging.info(f"{cls.__name__} async init done ...")
27 | return instances[cls]
28 |
29 | return get_instance
30 |
31 |
32 | class Singleton(object):
33 | instance = None
34 |
35 | @synchronized
36 | def __new__(cls, *args, **kwargs):
37 | """
38 | :type kwargs: object
39 | """
40 | if cls.instance is None:
41 | cls.instance = super().__new__(cls)
42 | return cls.instance
43 |
44 |     def __init__(self, *args, **kwargs):
45 | pass
46 |
47 | # def __init__(self, num): # 每次初始化后重新设置,重新初始化实例的属性值
48 | # self.a = num + 5
49 | # #
50 | # def printf(self):
51 | # print(self.a)
52 |
53 |
54 | @singleton
55 | class A(object):
56 | def __init__(self, data: int):
57 | super().__init__()
58 | self.data = data
59 |
60 |
61 | class B(Singleton):
62 | def __init__(self, data: int):
63 | super().__init__()
64 | self.data = data ** data
65 |
66 |
67 | def test():
68 | a = Singleton(3)
69 | print(id(a), a, a.a)
70 | b = Singleton(4)
71 | print(id(b), b, b.a)
72 |
73 |
74 | def main():
75 | a = A(2)
76 | print(f"a.data: {a.data}")
77 | a2 = A(4)
78 | print(f'a2.data: {a2.data}')
79 | b = B(3)
80 | print(f'b.data: {b.data}')
81 | b2 = B(6)
82 | print(f'b2.data: {b2.data}')
83 | print(singleton)
84 |
85 |
86 | if __name__ == '__main__':
87 | main()
88 |
--------------------------------------------------------------------------------
/examples/top_path.py:
--------------------------------------------------------------------------------
1 | import logging
2 |
3 |
4 | def add_root_path():
5 | import sys
6 | from pathlib import Path
7 | cur_dir = Path(__file__).absolute().parent
8 | times = 10
9 | while not str(cur_dir).endswith('ccks-2020') and times:
10 | cur_dir = cur_dir.parent
11 | times -= 1
12 | print(cur_dir)
13 | sys.path.append(str(cur_dir))
14 |
15 |
16 | add_root_path()
17 |
18 |
19 | # ---------------------
20 |
21 |
22 | def get_most_overlap_path(q_text, paths):
23 |     # from the top-ranked candidate paths, pick the one that overlaps most with the question
24 | max_score = 0
25 | top_path = paths[0]
26 | q_words = set(q_text)
27 | for path in paths:
28 | path_words = set()
29 | for ent_rel in path:
30 | if ent_rel.startswith('<'):
31 | ent_rel = ent_rel[1:-1].split('_')[0]
32 | path_words.update(set(ent_rel))
33 |         common_words = path_words & q_words  # set intersection
34 |         score1 = len(common_words)
35 |         score2 = len(common_words) / len(path_words)  # hop count or len(path_words)
36 | score3 = len(common_words) / len(path)
37 | score = score1 + score2 # + score3
38 | if score > max_score:
39 | logging.info(f'score :{score}, score1 :{score1}, score2 :{score2}, score3 :{score3}, '
40 | f'top_path: {top_path}')
41 | top_path = path
42 | max_score = score
43 | return top_path
44 |
45 |
46 | def main():
47 |     q = '莫妮卡·贝鲁奇的代表作?'
48 |     # hypothetical candidate paths, for illustration only (editor's addition)
49 |     print(get_most_overlap_path(q, [['<莫妮卡·贝鲁奇>', '<代表作品>'], ['<莫妮卡·贝鲁奇>', '<主要成就>']]))
50 | if __name__ == '__main__':
51 | main()
--------------------------------------------------------------------------------
/manage.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 | from ckbqa.utils.logger import logging_config
4 |
5 |
6 | def set_envs(cpu_only, allow_gpus):
7 | import os
8 | import random
9 | rand_seed = 1234
10 | random.seed(rand_seed)
11 | if cpu_only:
12 | os.environ["CUDA_VISIBLE_DEVICES"] = ""
13 | print("CPU only ...")
14 | elif allow_gpus:
15 | from ckbqa.utils.gpu_selector import get_available_gpu
16 | available_gpu = get_available_gpu(num_gpu=1, allow_gpus=allow_gpus) # default allow_gpus 0,1,2,3
17 | os.environ["CUDA_VISIBLE_DEVICES"] = available_gpu
18 |         print("* using GPU: {} ".format(available_gpu))  # do not call logging before logging_config, otherwise the config will not take effect
19 |     # process memory management
20 | # from ckbqa.utils.tools import ProcessManager
21 | # ProcessManager().run()
22 |
23 |
24 | def run(model_name, mode):
25 | logging_config(f'{model_name}_{mode}.log', stream_log=True)
26 | if model_name in ['bert_match', 'bert_match2']:
27 | from ckbqa.models.relation_score.trainer import RelationScoreTrainer
28 | RelationScoreTrainer(model_name).train_match_model()
29 | elif model_name == 'entity_score':
30 | from ckbqa.models.entity_score.model import EntityScore
31 | EntityScore().train()
32 |
33 |
34 | def main():
35 | ''' Parse command line arguments and execute the code
36 | --stream_log, --relative_path, --log_level
37 | --allow_gpus, --cpu_only
38 | '''
39 |     parser = argparse.ArgumentParser(description="basic, general-purpose parser")
40 |     # logging config
41 |     parser.add_argument('--stream_log', action="store_true", help="also print log messages to stdout")
42 |     parser.add_argument('--allow_gpus', default="0,1,2,3", type=str,
43 |                         help="GPU ids to pick from, e.g. 0 or 0,1,2 or 0,1,2...7 | check usage with nvidia-smi")
44 |     parser.add_argument('--cpu_only', action="store_true", help="CPU only, do not use the GPU")
45 |     #
46 |     group = parser.add_mutually_exclusive_group(required=True)  # a group of mutually exclusive arguments; exactly one is required
47 |     #
48 |     all_models = ['bert_match', 'bert_match2', 'entity_score']
49 |     group.add_argument('--train', type=str, choices=all_models, help="train")
50 |     group.add_argument('--test', type=str, choices=all_models, help="test")
51 | # parse args
52 | args = parser.parse_args()
53 | #
54 | set_envs(args.cpu_only, args.allow_gpus) # 设置环境变量等
55 | #
56 | if args.train:
57 | model_name = args.train
58 | mode = 'train'
59 | elif args.test:
60 | model_name = args.test
61 | mode = 'test'
62 | else:
63 | raise ValueError('must set model name ')
64 | run(model_name, mode)
65 |
66 |
67 | if __name__ == '__main__':
68 |     """ Entry point
69 | examples:
70 | nohup python manage.py --train bert_match &>bert_match.out&
71 | nohup python manage.py --train entity_score &>entity_score.out&
72 | """
73 |     # from ckbqa.utils.tools import ProcessManager  # monitor memory usage in real time
74 | # ProcessManager().run()
75 | main()
76 |
--------------------------------------------------------------------------------
/port_map.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | ssh -f wangshengguang@remotehost -N -L 7474:localhost:7474
3 | ssh -f wangshengguang@remotehost -N -L 7687:localhost:7687
4 |
5 |
--------------------------------------------------------------------------------
/qa.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import logging
3 | import traceback
4 |
5 | import pandas as pd
6 | from tqdm import tqdm
7 |
8 | from ckbqa.utils.logger import logging_config
9 | from config import valid_question_txt, ResultSaver
10 |
11 |
12 | def train_qa():
13 |     """Answer the training questions; used for metric evaluation."""
14 | logging_config('train_qa.log', stream_log=True)
15 | from ckbqa.qa.qa import QA
16 | from ckbqa.dataset.data_prepare import load_data, question_patten, entity_pattern, attr_pattern
17 | # from ckbqa.qa.evaluation_matrics import get_metrics
18 | #
19 | logging.info('* start run ...')
20 | qa = QA()
21 | data = []
22 | for question, sparql, answer in load_data(tqdm_prefix='test qa'):
23 | print('\n' * 2)
24 | print('*****' * 20)
25 | print(f" question : {question}")
26 | print(f" sparql : {sparql}")
27 | print(f" standard answer : {answer}")
28 | q_text = question_patten.findall(question)[0]
29 | standard_subject_entities = entity_pattern.findall(sparql) + attr_pattern.findall(sparql)
30 | standard_answer_entities = entity_pattern.findall(answer) + attr_pattern.findall(answer)
31 | try:
32 | (result_entities, candidate_entities,
33 | candidate_out_paths, candidate_in_paths) = qa.run(q_text, return_candidates=True)
34 | except KeyboardInterrupt:
35 | exit('Ctrl C , exit')
36 | except:
37 | logging.info(f'ERROR: {traceback.format_exc()}')
38 | result_entities = []
39 | candidate_entities = []
40 | # print(f" result answer : {result_entities}")
41 | # precision, recall, f1 = get_metrics(subject_entities, candidate_entities)
42 | # if recall == 0 or len(set(standard_entities) & set(candidate_entities)) == 0:
43 | # print(f"question: {question}\n"
44 | # f"subject_entities: {subject_entities}, candidate_entities: {candidate_entities}"
45 | # f"precision: {precision:.4f}, recall: {recall:.4f}, f1: {f1:.4f}\n\n")
46 | # import ipdb
47 | # ipdb.set_trace()
48 | data.append(
49 | [question, standard_subject_entities, list(candidate_entities), standard_answer_entities, result_entities])
50 | data_df = pd.DataFrame(data, columns=['question', 'standard_subject_entities', 'candidate_entities',
51 | 'standard_answer_entities', 'result_entities'])
52 | data_df.to_csv(ResultSaver().train_result_csv, index=False, encoding='utf_8_sig')
53 |
54 |
55 | def valid_qa():
56 |     """Answer the validation questions (the data to be submitted); produces the submission data."""
57 | logging_config('valid_qa.log', stream_log=True)
58 | from ckbqa.qa.qa import QA
59 | from ckbqa.dataset.data_prepare import question_patten
60 | #
61 | data_df = pd.DataFrame([question.strip() for question in open(valid_question_txt, 'r', encoding='utf-8')],
62 | columns=['question'])
63 | logging.info(f"data_df.shape: {data_df.shape}")
64 | qa = QA()
65 | valid_datas = {'question': [], 'result': []}
66 | with open(ResultSaver().submit_result_txt, 'w', encoding='utf-8') as f:
67 | for index, row in tqdm(data_df.iterrows(), total=data_df.shape[0], desc='qa answer'):
68 | question = row['question']
69 | q_text = question_patten.findall(question)[0]
70 | try:
71 | result_entities = qa.run(q_text)
72 | result_entities = [ent if ent.startswith('<') else f'"{ent}"' for ent in result_entities]
73 | f.write('\t'.join(result_entities) + '\n')
74 | except KeyboardInterrupt:
75 | exit('Ctrl C , exit')
76 | except:
77 | result_entities = []
78 | f.write('""\n')
79 | logging.info(traceback.format_exc())
80 | f.flush()
81 | valid_datas['question'].append(question)
82 | valid_datas['result'].append(result_entities)
83 | pd.DataFrame(valid_datas).to_csv(ResultSaver().valid_result_csv, index=False, encoding='utf_8_sig')
84 |
85 |
86 | def valid2submit():
87 | '''
88 |     Convert the validation output into the answer-submission format.
89 | '''
90 | logging_config('valid2submit.log', stream_log=True)
91 | data_df = pd.read_csv(ResultSaver(find_exist_path=True).valid_result_csv[0])
92 | with open(ResultSaver().submit_result_txt, 'w', encoding='utf-8') as f:
93 | for index, row in data_df.iterrows():
94 | ents = []
95 |         for ent in eval(row['result'])[:10]:  # keep only the first 10
96 | if ent.startswith('<'):
97 | ents.append(ent)
98 | elif ent.startswith('"'):
99 | ent = ent.strip('"')
100 | ents.append(f'"{ent}"')
101 | if ents:
102 | f.write('\t'.join(ents) + '\n')
103 | else:
104 | f.write('""\n')
105 | f.flush()
106 |
107 |
108 | def test():
109 | logging_config('test.log', stream_log=True)
110 | from ckbqa.qa.qa import QA
111 | from ckbqa.dataset.data_prepare import question_patten
112 | # from ckbqa.qa.evaluation_matrics import get_metrics
113 | #
114 | logging.info('* start run ...')
115 | qa = QA()
116 | q188 = 'q188:墨冰仙是哪个门派的?'
117 | q189 = 'q189:在金庸小说《天龙八部》中,斗转星移的修习者是谁?'
118 | q190 = 'q190:《基督山伯爵》的作者是谁?'  # this one runs without issues
119 | for question in [q188, q189, q190]:
120 | q_text = question_patten.findall(question)[0]
121 | (result_entities, candidate_entities,
122 | candidate_out_paths, candidate_in_paths) = qa.run(q_text, return_candidates=True)
123 | print(question)
124 | import ipdb
125 | ipdb.set_trace()
126 |
127 |
128 | def main():
129 | parser = argparse.ArgumentParser(description="basic, general-purpose parser")
130 | # logging config
131 | parser.add_argument('--stream_log', action="store_true", help="also print log messages to stdout")  # print logs to the screen
132 | #
133 | group = parser.add_mutually_exclusive_group(required=True)  # mutually exclusive flags; exactly one must be given
134 |
135 | group.add_argument('--train_qa', action="store_true", help="answer the training-set questions")
136 | group.add_argument('--valid_qa', action="store_true", help="answer the validation (to-be-submitted) questions")
137 | group.add_argument('--valid2submit', action="store_true", help="convert saved validation results into the submission file")
138 | group.add_argument('--task', action="store_true", help="ad-hoc task that chains several of the functions above")
139 | group.add_argument('--test', action="store_true", help="ad-hoc test")
140 | # parse args
141 | args = parser.parse_args()
142 | #
143 | # from ckbqa.utils.tools import ProcessManager
144 | # ProcessManager().run()
145 | if args.train_qa:
146 | train_qa()
147 | elif args.valid_qa:
148 | valid_qa()
149 | elif args.valid2submit:
150 | valid2submit()
151 | elif args.test:
152 | test()
153 | elif args.task:
154 | task()
155 |
156 |
157 | def task():
158 | logging_config('qa_task.log', stream_log=True)
159 | train_qa()
160 | valid_qa()
161 |
162 |
163 | if __name__ == '__main__':
164 | """
165 | example:
166 | nohup python qa.py --train_qa &>train_qa.out &
167 | nohup python qa.py --valid_qa &>valid_qa.out &
168 | nohup python qa.py --valid2submit &>valid2submit.out &
169 | nohup python qa.py --task &>qa_task.out &
170 | """
171 | # from ckbqa.utils.tools import ProcessManager  # monitor memory usage in real time
172 | # ProcessManager().run()
173 | main()
174 |
--------------------------------------------------------------------------------
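Editor's sketch (not part of the repository): the qa.py entry points above rely on two string conventions. Question lines look like qNNN:<question text> and are parsed with question_patten from ckbqa/dataset/data_prepare.py (its real definition is not shown in this excerpt), and submission lines are tab-separated entities where <...> URIs are kept as-is and anything else is wrapped in double quotes. A minimal, hedged illustration of both conventions, using an assumed stand-in regex:

import re

# Stand-in for question_patten (assumption; the repo's actual pattern may differ):
# strip the leading question id ("q190:" or "q190:") and keep the question text.
question_patten_sketch = re.compile(r'q\d+[::](.+)')


def format_submit_line(entities):
    # Mirrors the formatting in valid_qa()/valid2submit(): keep <entity> URIs,
    # quote literal values, join with tabs, and write '""' when there is no answer.
    ents = [ent if ent.startswith('<') else '"%s"' % ent.strip('"') for ent in entities]
    return '\t'.join(ents) if ents else '""'


if __name__ == '__main__':
    print(question_patten_sketch.findall('q190:《基督山伯爵》的作者是谁?')[0])
    print(format_submit_line(['<大仲马>', '法国作家']))  # -> <大仲马>\t"法国作家"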
/tests/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/WangShengguang/ccks-2020/d77f2de91284efb11e314fea21b4a3982ab78554/tests/__init__.py
--------------------------------------------------------------------------------
/tests/test_.py:
--------------------------------------------------------------------------------
1 | import re
2 |
3 | import hanlp
4 | import jieba
5 |
6 |
7 | def add_root_path():
8 | import sys
9 | from pathlib import Path
10 | cur_dir = Path(__file__).absolute().parent
11 | times = 10
12 | while not str(cur_dir).endswith('ccks-2020') and times:
13 | cur_dir = cur_dir.parent
14 | times -= 1
15 | print(cur_dir)
16 | sys.path.append(str(cur_dir))
17 |
18 |
19 | add_root_path()
20 |
21 | from ckbqa.dataset.data_prepare import load_data, question_patten
22 | from ckbqa.utils.logger import logging_config
23 | from ckbqa.utils.tools import json_load
24 |
25 | #
26 | logging_config('test.log', stream_log=True)
27 |
28 | from config import Config
29 |
30 |
31 | def ner_train_data():
32 | ent2mention = json_load(Config.ent2mention_json)
33 | # recognizer = hanlp.load(hanlp.pretrained.ner.MSRA_NER_BERT_BASE_ZH)
34 | tokenizer = hanlp.load('PKU_NAME_MERGED_SIX_MONTHS_CONVSEG')
35 | recognizer = hanlp.load(hanlp.pretrained.ner.MSRA_NER_BERT_BASE_ZH)
36 | from LAC import LAC
37 | # load the LAC model
38 | lac = LAC(mode='lac')
39 | jieba.enable_paddle()  # enable paddle mode; supported since jieba 0.40, not available in earlier versions
40 | _ent_patten = re.compile(r'["<](.*?)[>"]')
41 | for q, sparql, a in load_data():
42 | q_text = question_patten.findall(q)[0]
43 | hanlp_entities = recognizer([list(q_text)])
44 | hanlp_words = tokenizer(q_text)
45 | lac_results = lac.run(q_text)
46 | q_entities = _ent_patten.findall(sparql)
47 | jieba_results = list(jieba.cut_for_search(q_text))
48 | mentions = [ent2mention.get(ent) for ent in q_entities]
49 | print(f"q_text: {q_text}\nq_entities: {q_entities}, "
50 | f"\nlac_results:{lac_results}"
51 | f"\nhanlp_words: {hanlp_words}, "
52 | f"\njieba_results: {jieba_results}, "
53 | f"\nhanlp_entities: {hanlp_entities}, "
54 | f"\nmentions: {mentions}")
55 | import ipdb
56 | ipdb.set_trace()
57 |
58 |
59 | if __name__ == '__main__':
60 | ner_train_data()
61 |
--------------------------------------------------------------------------------
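A short hedged illustration of the _ent_patten regex used in ner_train_data() above: it pulls both <entity> URIs and "literal" values out of a SPARQL string in a single findall pass. The query below is made up for illustration and is not taken from the dataset.

import re

_ent_patten = re.compile(r'["<](.*?)[>"]')  # same pattern as in tests/test_.py

sparql = 'select ?x where { <基督山伯爵> <作者> ?x . ?x <国籍> "法国" . }'
print(_ent_patten.findall(sparql))
# -> ['基督山伯爵', '作者', '国籍', '法国']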
/tests/test_ceg.py:
--------------------------------------------------------------------------------
1 | def add_root_path():
2 | import sys
3 | from pathlib import Path
4 | cur_dir = Path(__file__).absolute().parent
5 | times = 10
6 | while not str(cur_dir).endswith('ccks-2020') and times:
7 | cur_dir = cur_dir.parent
8 | times -= 1
9 | print(cur_dir)
10 | sys.path.append(str(cur_dir))
11 |
12 |
13 | add_root_path()
14 | import logging
15 |
16 | from ckbqa.dataset.data_prepare import load_data
17 | from ckbqa.qa.el import EL
18 | from ckbqa.utils.logger import logging_config
19 |
20 | logging_config('test.log', stream_log=True)
21 |
22 |
23 | def test_recognizer():
24 | print('start')
25 | logging.info('test start ...')
26 | el = EL()
27 | for q, sparql, a in load_data():
28 | q_text = q.split(':')[1]
29 | rec_entities = el.ceg.get_ent2mention(q_text)
30 | print(rec_entities)
31 | print(q)
32 | print(sparql)
33 | print(a)
34 | import ipdb
35 | ipdb.set_trace()
36 |
37 |
38 | if __name__ == '__main__':
39 | test_recognizer()
40 |
--------------------------------------------------------------------------------
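el.ceg.get_ent2mention used above is defined in ckbqa/qa/el.py, which is not part of this excerpt. As a rough mental model only (an assumption, not the repository's implementation), candidate entity generation can be pictured as a substring scan of the question over a mention-to-entity table, e.g. one built by inverting the ent2mention mapping loaded in tests/test_.py:

# Toy sketch only; the real EL/CEG code in ckbqa/qa/el.py may work very differently.
def toy_get_ent2mention(q_text, mention2ent):
    # mention2ent: assumed mention -> [candidate entities] table (e.g. an inverted
    # version of the ent2mention json used in tests/test_.py).
    candidates = {}
    for mention, ents in mention2ent.items():
        if mention and mention in q_text:  # naive substring match over known mentions
            candidates[mention] = ents
    return candidates


if __name__ == '__main__':
    mention2ent = {'基督山伯爵': ['<基督山伯爵_(小说)>'], '作者': ['<作者>']}
    print(toy_get_ent2mention('《基督山伯爵》的作者是谁?', mention2ent))
    # -> {'基督山伯爵': ['<基督山伯爵_(小说)>'], '作者': ['<作者>']}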