├── 20220716-jd.png
├── Dockerfile
├── build-shell.sh
├── ccks2020-baseline.ipynb
├── ccks2020-datagrand-qq.png
├── readme.md
├── requirements.txt
├── the-land-of-future.png
├── 封底.jpeg
└── 封面.png
/20220716-jd.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wgwang/ccks2020-baseline/96c8f990f87afc3a12f1895bab21c095d8cb410e/20220716-jd.png
--------------------------------------------------------------------------------
/Dockerfile:
--------------------------------------------------------------------------------
1 | FROM nvidia/cuda:8.0-cudnn6-devel-ubuntu16.04
2 | LABEL maintainer="hemengjie@datagrand.com"
3 |
4 | # Install basic dependencies
5 | RUN apt-get update && apt-get install -y --no-install-recommends \
6 | build-essential \
7 | cmake \
8 | git \
9 | wget \
10 | libopencv-dev \
11 | libsnappy-dev \
12 | python-dev \
13 | python-pip \
14 | tzdata \
15 | vim
16 |
17 |
18 | # Install anaconda for python 3.6
19 | RUN wget --quiet https://repo.continuum.io/archive/Anaconda3-5.0.1-Linux-x86_64.sh -O ~/anaconda.sh && \
20 | /bin/bash ~/anaconda.sh -b -p /opt/conda && \
21 | rm ~/anaconda.sh && \
22 | echo "export PATH=/opt/conda/bin:$PATH" >> ~/.bashrc
23 |
24 |
25 | # Set timezone
26 | RUN ln -sf /usr/share/zoneinfo/Asia/Shanghai /etc/localtime
27 |
28 | # Set the locale
29 | RUN apt-get install -y locales
30 | RUN locale-gen en_US.UTF-8
31 | ENV LANG en_US.UTF-8
32 | ENV LANGUAGE en_US:en
33 | ENV LC_ALL en_US.UTF-8
34 |
35 |
36 | COPY ./requirements.txt /root/doc_compare/
37 | COPY ./start_server.py /root/doc_compare/
38 | COPY ./data /root/doc_compare/data
39 | COPY ./src /root/doc_compare/src
40 |
41 |
42 |
43 | WORKDIR /root/doc_compare
44 | # pip dependencies
45 |
46 | # (no RUN cd needed: the WORKDIR instruction above already sets /root/doc_compare)
47 | RUN /opt/conda/bin/conda install tensorflow-gpu -y
48 |
49 | RUN pip install -r requirements.txt -i http://pypi.douban.com/simple/ --trusted-host pypi.douban.com
50 | RUN rm /root/doc_compare/requirements.txt
51 |
52 | EXPOSE 11000
53 | CMD ["python","start_server.py"]
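54 | 
55 | # Illustrative usage (an assumption, not part of the original image definition): after
56 | # building the image via build-shell.sh, the service could be started with something like
57 | #   docker run -p 11000:11000 dockerhub.datagrand.com/gfyuqing/doc_compare:<tag>
58 | # For GPU access, this CUDA 8 era base image additionally needs the NVIDIA container
59 | # runtime (nvidia-docker).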
--------------------------------------------------------------------------------
/build-shell.sh:
--------------------------------------------------------------------------------
1 | TIMENOW=$(date +%y.%m.%d)
2 |
3 | # -f selects the Dockerfile, -t sets the image name; the part after the colon is the version tag. Please avoid tag collisions, e.g. rec_action_pipe:17.08.01.1311
4 | docker build -f Dockerfile -t dockerhub.datagrand.com/gfyuqing/doc_compare:${TIMENOW} .
5 |
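6 | # Illustrative invocation (an assumption, not part of the original script): run from the
7 | # repository root, then check the freshly tagged image:
8 | #   bash build-shell.sh
9 | #   docker images | grep doc_compare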
--------------------------------------------------------------------------------
/ccks2020-baseline.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {},
6 | "source": [
7 |     "# CCKS 2020: Ontology-Based Automated Construction of a Financial Knowledge Graph (Evaluation Task)\n",
8 |     "\n",
9 |     "Competition background\n",
10 |     "Financial research reports are studies of the macro economy, finance, industries, industry chains and individual companies produced by financial research institutions. They are usually written by professionals, draw on comprehensive data collection and in-depth analysis, and are therefore high-quality and reliable. Because they combine data and knowledge from many fields (industry, economics, finance, policy, society), they are a key source for building industry knowledge graphs. At the same time, the breadth of the material, the amount of domain expertise involved, and the fact that different institutions and analysts phrase the same content slightly differently make automated knowledge-graph construction from research reports difficult; solving these problems would greatly advance automated knowledge-graph construction.\n",
11 |     " \n",
12 |     "This evaluation follows the design of the TAC KBP Cold Start track and targets automated construction of a knowledge graph from financial research reports. Starting from a predefined graph schema and a small seed knowledge graph, systems build the graph from unstructured text. The schema defines 10 entity types (e.g. organization, product, business, risk), 19 relations between entities (e.g. (organization, produces-and-sells, product), (organization, invests-in, organization)), and attributes on several entity types (e.g. (organization, English name), (report, rating)). Given the schema and the seed graph, the task is to automatically extract schema-conforming entities, relations and attribute values from report text, i.e. to build the financial knowledge graph automatically. Such graphs are widely used in the financial industry and by regulators, governments, research institutions and companies, for risk monitoring, intelligent investment research, intelligent supervision, intelligent risk control and more, and have great academic and industrial value.\n",
13 |     " \n",
14 |     "The evaluation places no restriction on the models, algorithms or techniques used. Teams are encouraged to build unsupervised, weakly supervised, distantly supervised or semi-supervised systems and to construct the knowledge graph iteratively, jointly advancing knowledge-graph technology.\n",
15 |     "\n",
16 |     "Competition task\n",
17 |     "Following the TAC KBP Cold Start design, systems start from the predefined graph schema and a small seed knowledge graph and build a knowledge graph from unstructured text. No restriction is placed on the models, algorithms or techniques used; unsupervised, weakly supervised, distantly supervised and semi-supervised approaches are all welcome.\n",
18 |     "\n",
19 |     "Organizer contact: wangwenguang@datagrand.com kdd.wang@gmail.com\n",
20 |     "\n",
21 |     "\n",
22 |     "Reference: https://www.biendata.com/competition/ccks_2020_5/"
23 | ]
24 | },
25 | {
26 | "cell_type": "code",
27 | "execution_count": null,
28 | "metadata": {},
29 | "outputs": [],
30 | "source": [
31 | "import json\n",
32 | "import logging\n",
33 | "import os\n",
34 | "import random\n",
35 | "import re\n",
36 | "import base64\n",
37 | "from collections import defaultdict\n",
38 | "from pathlib import Path\n",
39 | "\n",
40 | "# import attr\n",
41 | "import tqdm\n",
42 | "import hanlp\n",
43 | "import numpy as np\n",
44 | "import torch\n",
45 | "import torch.optim\n",
46 | "import torch.utils.data\n",
47 | "from torch.nn import functional as F\n",
48 | "from torchcrf import CRF\n",
49 | "from pytorch_transformers import BertModel, BertTokenizer\n",
50 | "import jieba\n",
51 | "from jieba.analyse.tfidf import TFIDF\n",
52 | "from jieba.posseg import POSTokenizer\n",
53 | "import jieba.posseg as pseg\n",
54 | "from itertools import product\n",
55 | "from IPython.display import HTML\n"
56 | ]
57 | },
58 | {
59 | "cell_type": "markdown",
60 | "metadata": {},
61 | "source": [
62 | "# 预处理函数\n",
63 | "\n",
64 | "对文章进行预处理,切分句子和子句等"
65 | ]
66 | },
67 | {
68 | "cell_type": "code",
69 | "execution_count": null,
70 | "metadata": {},
71 | "outputs": [],
72 | "source": [
73 | "def split_to_sents(content, filter_length=(2, 1000)):\n",
74 | " content = re.sub(r\"\\s*\", \"\", content)\n",
75 | " content = re.sub(\"([。!…??!;;])\", \"\\\\1\\1\", content)\n",
76 | " sents = content.split(\"\\1\")\n",
77 | " sents = [_[: filter_length[1]] for _ in sents]\n",
78 | " return [_ for _ in sents\n",
79 | " if filter_length[0] <= len(_) <= filter_length[1]]\n",
80 | "\n",
81 | "def split_to_subsents(content, filter_length=(2, 1000)):\n",
82 | " content = re.sub(r\"\\s*\", \"\", content)\n",
83 | " content = re.sub(\"([。!…??!;;,,])\", \"\\\\1\\1\", content)\n",
84 | " sents = content.split(\"\\1\")\n",
85 | " sents = [_[: filter_length[1]] for _ in sents]\n",
86 | " return [_ for _ in sents\n",
87 | " if filter_length[0] <= len(_) <= filter_length[1]]"
88 | ]
89 | },
90 | {
91 | "cell_type": "code",
92 | "execution_count": null,
93 | "metadata": {},
94 | "outputs": [],
95 | "source": [
96 | "def read_json(file_path):\n",
97 | " with open(file_path, mode='r', encoding='utf8') as f:\n",
98 | " return json.load(f)"
99 | ]
100 | },
101 | {
102 | "cell_type": "markdown",
103 | "metadata": {},
104 | "source": [
105 | "# 预训练模型配置\n",
106 | "\n",
107 | "参考 https://github.com/huggingface/pytorch-transformers 下载预训练模型,并配置下面参数为相关路径\n",
108 | "\n",
109 | "```python\n",
110 | "PRETRAINED_BERT_MODEL_DIR = '/you/path/to/bert-base-chinese/' \n",
111 | "```"
112 | ]
113 | },
114 | {
115 | "cell_type": "code",
116 | "execution_count": null,
117 | "metadata": {},
118 | "outputs": [],
119 | "source": [
120 | "PRETRAINED_BERT_MODEL_DIR = '/home/wangwenguang/bigdata/wke-data/pretrained_models/bert-base-chinese/' "
121 | ]
122 | },
123 | {
124 | "cell_type": "markdown",
125 | "metadata": {},
126 | "source": [
127 | "# 一些参数"
128 | ]
129 | },
130 | {
131 | "cell_type": "code",
132 | "execution_count": null,
133 | "metadata": {},
134 | "outputs": [],
135 | "source": [
136 | "DATA_DIR = './data' # 输入数据文件夹\n",
137 | "OUT_DIR = './output' # 输出文件夹\n",
138 | "\n",
139 | "Path(OUT_DIR).mkdir(exist_ok=True)\n",
140 | "\n",
141 | "BATCH_SIZE = 32\n",
142 | "TOTAL_EPOCH_NUMS = 10\n",
143 | "if torch.cuda.is_available():\n",
144 | " DEVICE = 'cuda:0'\n",
145 | "else:\n",
146 | " DEVICE = 'cpu'\n",
147 | "YANBAO_DIR_PATH = str(Path(DATA_DIR, 'yanbao_txt'))\n",
148 | "SAVE_MODEL_DIR = str(OUT_DIR)"
149 | ]
150 | },
151 | {
152 | "cell_type": "markdown",
153 | "metadata": {},
154 | "source": [
155 | "## 读入原始数据\n",
156 | "\n",
157 | "- 读入:所有研报内容\n",
158 | "- 读入:原始训练实体数据"
159 | ]
160 | },
161 | {
162 | "cell_type": "code",
163 | "execution_count": null,
164 | "metadata": {},
165 | "outputs": [],
166 | "source": [
167 | "yanbao_texts = []\n",
168 | "for yanbao_file_path in Path(YANBAO_DIR_PATH).glob('*.txt'):\n",
169 | " with open(yanbao_file_path) as f:\n",
170 | " yanbao_texts.append(f.read())\n",
171 | "# if len(yanbao_texts) == 10:\n",
172 | "# break\n",
173 | "\n",
174 | "# 来做官方的实体训练集,后续会混合来自第三方工具,规则,训练数据来扩充模型训练数据\n",
175 | "to_be_trained_entities = read_json(Path(DATA_DIR, 'entities.json'))"
176 | ]
177 | },
178 | {
179 | "cell_type": "markdown",
180 | "metadata": {},
181 | "source": [
182 | "# 用hanlp进行实体识别\n",
183 | "\n",
184 | "hanlp支持对人物、机构的实体识别,可以使用它来对其中的两个实体类型进行识别:人物、机构。\n",
185 | "\n",
186 | "hanlp见[https://github.com/hankcs/HanLP](https://github.com/hankcs/HanLP)"
187 | ]
188 | },
189 | {
190 | "cell_type": "code",
191 | "execution_count": null,
192 | "metadata": {},
193 | "outputs": [],
194 | "source": [
195 | "## NER by third party tool\n",
196 | "class HanlpNER:\n",
197 | " def __init__(self):\n",
198 | " self.recognizer = hanlp.load(hanlp.pretrained.ner.MSRA_NER_BERT_BASE_ZH)\n",
199 | " self.max_sent_len = 126\n",
200 | " self.ent_type_map = {\n",
201 | " 'NR': '人物',\n",
202 | " 'NT': '机构'\n",
203 | " }\n",
204 | " self.black_list = {'公司'}\n",
205 | "\n",
206 | " def recognize(self, sent):\n",
207 | " entities_dict = {}\n",
208 | " for result in self.recognizer.predict([list(sent)]):\n",
209 | " for entity, hanlp_ent_type, _, _ in result:\n",
210 | " if not re.findall(r'^[\\.\\s\\da-zA-Z]{1,2}$', entity) and \\\n",
211 | " len(entity) > 1 and entity not in self.black_list \\\n",
212 | " and hanlp_ent_type in self.ent_type_map:\n",
213 | " entities_dict.setdefault(self.ent_type_map[hanlp_ent_type], []).append(entity)\n",
214 | " return entities_dict"
215 | ]
216 | },
217 | {
218 | "cell_type": "code",
219 | "execution_count": null,
220 | "metadata": {},
221 | "outputs": [],
222 | "source": [
223 | "def aug_entities_by_third_party_tool():\n",
224 | " hanlpner = HanlpNER()\n",
225 | " entities_by_third_party_tool = defaultdict(list)\n",
226 | " for file in tqdm.tqdm(list(Path(DATA_DIR, 'yanbao_txt').glob('*.txt'))[:]):\n",
227 | " with open(file, encoding='utf-8') as f:\n",
228 | " sents = [[]]\n",
229 | " cur_sent_len = 0\n",
230 | " for line in f:\n",
231 | " for sent in split_to_subsents(line):\n",
232 | " sent = sent[:hanlpner.max_sent_len]\n",
233 | " if cur_sent_len + len(sent) > hanlpner.max_sent_len:\n",
234 | " sents.append([sent])\n",
235 | " cur_sent_len = len(sent)\n",
236 | " else:\n",
237 | " sents[-1].append(sent)\n",
238 | " cur_sent_len += len(sent)\n",
239 | " sents = [''.join(_) for _ in sents]\n",
240 | " sents = [_ for _ in sents if _]\n",
241 | " for sent in sents:\n",
242 | " entities_dict = hanlpner.recognize(sent)\n",
243 | " for ent_type, ents in entities_dict.items():\n",
244 | " entities_by_third_party_tool[ent_type] += ents\n",
245 | "\n",
246 | " for ent_type, ents in entities_by_third_party_tool.items():\n",
247 | " entities_by_third_party_tool[ent_type] = list([ent for ent in set(ents) if len(ent) > 1])\n",
248 | " return entities_by_third_party_tool"
249 | ]
250 | },
251 | {
252 | "cell_type": "code",
253 | "execution_count": null,
254 | "metadata": {
255 | "scrolled": true
256 | },
257 | "outputs": [],
258 | "source": [
259 | "# 此任务十分慢, 但是只需要运行一次\n",
260 | "entities_by_third_party_tool = aug_entities_by_third_party_tool()\n",
261 | "for ent_type, ents in entities_by_third_party_tool.items():\n",
262 | " to_be_trained_entities[ent_type] = list(set(to_be_trained_entities[ent_type] + ents))"
263 | ]
264 | },
265 | {
266 | "cell_type": "code",
267 | "execution_count": null,
268 | "metadata": {
269 | "scrolled": true
270 | },
271 | "outputs": [],
272 | "source": [
273 | "for k, v in entities_by_third_party_tool.items():\n",
274 | " print(k)\n",
275 | " print(set(v))"
276 | ]
277 | },
278 | {
279 | "cell_type": "markdown",
280 | "metadata": {},
281 | "source": [
282 | "## 通过规则抽取实体\n",
283 | "\n",
284 | "- 机构\n",
285 | "- 研报\n",
286 | "- 文章\n",
287 | "- 风险"
288 | ]
289 | },
290 | {
291 | "cell_type": "code",
292 | "execution_count": null,
293 | "metadata": {},
294 | "outputs": [],
295 | "source": [
296 | "def aug_entities_by_rules(yanbao_dir):\n",
297 | " entities_by_rule = defaultdict(list)\n",
298 | " for file in list(yanbao_dir.glob('*.txt'))[:]:\n",
299 | " with open(file, encoding='utf-8') as f:\n",
300 | " found_yanbao = False\n",
301 | " found_fengxian = False\n",
302 | " for lidx, line in enumerate(f):\n",
303 | " # 公司的标题\n",
304 | " ret = re.findall('^[\\((]*[\\d一二三四五六七八九十①②③④⑤]*[\\))\\.\\s]*(.*有限公司)$', line)\n",
305 | " if ret:\n",
306 | " entities_by_rule['机构'].append(ret[0])\n",
307 | " \n",
308 | " # 研报\n",
309 | " if not found_yanbao and lidx <= 5 and len(line) > 10:\n",
310 | " may_be_yanbao = line.strip()\n",
311 | " if not re.findall(r'\\d{4}\\s*[年-]\\s*\\d{1,2}\\s*[月-]\\s*\\d{1,2}\\s*日?', may_be_yanbao) \\\n",
312 | " and not re.findall('^[\\d一二三四五六七八九十]+\\s*[\\.、]\\s*.*$', may_be_yanbao) \\\n",
313 | " and not re.findall('[\\((]\\d+\\.*[A-Z]*[\\))]', may_be_yanbao) \\\n",
314 | " and len(may_be_yanbao) > 5:\n",
315 | " entities_by_rule['研报'].append(may_be_yanbao)\n",
316 | " found_yanbao = True\n",
317 | "\n",
318 | " # 文章\n",
319 | " for sent in split_to_sents(line):\n",
320 | " results = re.findall('《(.*?)》', sent)\n",
321 | " for result in results:\n",
322 | " entities_by_rule['文章'].append(result) \n",
323 | "\n",
324 | " # 风险\n",
325 | " for sent in split_to_sents(line):\n",
326 | " if found_fengxian:\n",
327 | " sent = sent.split(':')[0]\n",
328 | " fengxian_entities = re.split('以及|、|,|;|。', sent)\n",
329 | " fengxian_entities = [re.sub('^[■]+[\\d一二三四五六七八九十①②③④⑤]+', '', ent) for ent in fengxian_entities]\n",
330 | " fengxian_entities = [re.sub('^[\\((]*[\\d一二三四五六七八九十①②③④⑤]+[\\))\\.\\s]+', '', ent) for ent in fengxian_entities]\n",
331 | " fengxian_entities = [_ for _ in fengxian_entities if len(_) >=4]\n",
332 | " entities_by_rule['风险'] += fengxian_entities\n",
333 | " found_fengxian = False\n",
334 | " if not found_fengxian and re.findall('^\\s*[\\d一二三四五六七八九十]*\\s*[\\.、]*\\s*风险提示[::]*$', sent):\n",
335 | " found_fengxian = True\n",
336 | " \n",
337 | " results = re.findall('^\\s*[\\d一二三四五六七八九十]*\\s*[\\.、]*\\s*风险提示[::]*(.{5,})$', sent)\n",
338 | " if results:\n",
339 | " fengxian_entities = re.split('以及|、|,|;|。', results[0])\n",
340 | " fengxian_entities = [re.sub('^[■]+[\\d一二三四五六七八九十①②③④⑤]+', '', ent) for ent in fengxian_entities]\n",
341 | " fengxian_entities = [re.sub('^[\\((]*[\\d一二三四五六七八九十①②③④⑤]+[\\))\\.\\s]+', '', ent) for ent in fengxian_entities]\n",
342 | " fengxian_entities = [_ for _ in fengxian_entities if len(_) >=4]\n",
343 | " entities_by_rule['风险'] += fengxian_entities\n",
344 | " \n",
345 | " for ent_type, ents in entities_by_rule.items():\n",
346 | " entities_by_rule[ent_type] = list(set(ents))\n",
347 | " return entities_by_rule"
348 | ]
349 | },
350 | {
351 | "cell_type": "code",
352 | "execution_count": null,
353 | "metadata": {},
354 | "outputs": [],
355 | "source": [
356 | "# 通过规则来寻找新的实体\n",
357 | "entities_by_rule = aug_entities_by_rules(Path(DATA_DIR, 'yanbao_txt'))\n",
358 | "for ent_type, ents in entities_by_rule.items():\n",
359 | " to_be_trained_entities[ent_type] = list(set(to_be_trained_entities[ent_type] + ents))"
360 | ]
361 | },
362 | {
363 | "cell_type": "code",
364 | "execution_count": null,
365 | "metadata": {
366 | "scrolled": true
367 | },
368 | "outputs": [],
369 | "source": [
370 | "for k, v in entities_by_rule.items():\n",
371 | " print(k)\n",
372 | " print(set(v))"
373 | ]
374 | },
375 | {
376 | "cell_type": "markdown",
377 | "metadata": {},
378 | "source": [
379 | "# 定义NER模型\n"
380 | ]
381 | },
382 | {
383 | "cell_type": "code",
384 | "execution_count": null,
385 | "metadata": {},
386 | "outputs": [],
387 | "source": [
388 | "class BertCRF(torch.nn.Module):\n",
389 | " def __init__(self, pretrained_bert_model_file_path, num_tags: int, batch_first: bool = False, hidden_size=768):\n",
390 | " super(BertCRF, self).__init__()\n",
391 | " self.bert_module = BertModel.from_pretrained(pretrained_bert_model_file_path)\n",
392 | " self.tag_linear = torch.nn.Linear(hidden_size, num_tags)\n",
393 | " self.crf_module = CRF(num_tags, batch_first)\n",
394 | "\n",
395 | " def forward(self,\n",
396 | " inputs_ids,\n",
397 | " tags,\n",
398 | " mask = None,\n",
399 | " token_type_ids=None,\n",
400 | " reduction = 'mean'\n",
401 | " ) -> torch.Tensor:\n",
402 | " bert_outputs = self.bert_module.forward(inputs_ids, attention_mask=mask, token_type_ids=token_type_ids)[0]\n",
403 | " bert_outputs = F.dropout(bert_outputs, p=0.2, training=self.training)\n",
404 | " bert_outputs = self.tag_linear(bert_outputs)\n",
405 | " score = -self.crf_module.forward(bert_outputs, tags=tags, mask=mask, reduction=reduction)\n",
406 | " return score\n",
407 | "\n",
408 | " def decode(self,\n",
409 | " input_ids,\n",
410 | " attention_mask=None,\n",
411 | " token_type_ids=None\n",
412 | " ):\n",
413 | " bert_outputs = self.bert_module.forward(\n",
414 | " input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids\n",
415 | " )[0]\n",
416 | " bert_outputs = self.tag_linear(bert_outputs)\n",
417 | " best_tags_list = self.crf_module.decode(bert_outputs, mask=attention_mask)\n",
418 | " return best_tags_list\n"
419 | ]
420 | },
421 | {
422 | "cell_type": "markdown",
423 | "metadata": {},
424 | "source": [
425 | "## 定义NER标签"
426 | ]
427 | },
428 | {
429 | "cell_type": "code",
430 | "execution_count": null,
431 | "metadata": {},
432 | "outputs": [],
433 | "source": [
434 | "chinese_entity_type_vs_english_entity_type = {\n",
435 | " '人物': 'People',\n",
436 | " '行业': 'Industry',\n",
437 | " '业务': 'Business',\n",
438 | " '研报': 'Report',\n",
439 | " '机构': 'Organization',\n",
440 | " '风险': 'Risk',\n",
441 | " '文章': 'Article',\n",
442 | " '指标': 'Indicator',\n",
443 | " '品牌': 'Brand',\n",
444 | " '产品': 'Product'\n",
445 | "}\n",
446 | "english_entity_type_vs_chinese_entity_type = {v: k for k, v in chinese_entity_type_vs_english_entity_type.items()}\n",
447 | "\n",
448 | "START_TAG = \"[CLS]\"\n",
449 | "END_TAG = \"[SEP]\"\n",
450 | "O = \"O\"\n",
451 | "\n",
452 | "BPeople = \"B-People\"\n",
453 | "IPeople = \"I-People\"\n",
454 | "BIndustry = \"B-Industry\"\n",
455 | "IIndustry = \"I-Industry\"\n",
456 | "BBusiness = 'B-Business'\n",
457 | "IBusiness = 'I-Business'\n",
458 | "BProduct = 'B-Product'\n",
459 | "IProduct = 'I-Product'\n",
460 | "BReport = 'B-Report'\n",
461 | "IReport = 'I-Report'\n",
462 | "BOrganization = 'B-Organization'\n",
463 | "IOrganization = 'I-Organization'\n",
464 | "BRisk = 'B-Risk'\n",
465 | "IRisk = 'I-Risk'\n",
466 | "BArticle = 'B-Article'\n",
467 | "IArticle = 'I-Article'\n",
468 | "BIndicator = 'B-Indicator'\n",
469 | "IIndicator = 'I-Indicator'\n",
470 | "BBrand = 'B-Brand'\n",
471 | "IBrand = 'I-Brand'\n",
472 | "\n",
473 | "PAD = \"[PAD]\"\n",
474 | "UNK = \"[UNK]\"\n",
475 | "tag2idx = {\n",
476 | " START_TAG: 0,\n",
477 | " END_TAG: 1,\n",
478 | " O: 2,\n",
479 | " BPeople: 3,\n",
480 | " IPeople: 4,\n",
481 | " BIndustry: 5,\n",
482 | " IIndustry: 6,\n",
483 | " BBusiness: 7,\n",
484 | " IBusiness: 8,\n",
485 | " BProduct: 9,\n",
486 | " IProduct: 10,\n",
487 | " BReport: 11,\n",
488 | " IReport: 12,\n",
489 | " BOrganization: 13,\n",
490 | " IOrganization: 14,\n",
491 | " BRisk: 15,\n",
492 | " IRisk: 16,\n",
493 | " BArticle: 17,\n",
494 | " IArticle: 18,\n",
495 | " BIndicator: 19,\n",
496 | " IIndicator: 20,\n",
497 | " BBrand: 21,\n",
498 | " IBrand: 22,\n",
499 | "}\n",
500 | "tag2id = tag2idx\n",
501 | "idx2tag = {v: k for k, v in tag2idx.items()}"
502 | ]
503 | },
504 | {
505 | "cell_type": "markdown",
506 | "metadata": {},
507 | "source": [
508 | "## 预处理数据函数\n",
509 | "\n",
510 | "`preprocess_data` 函数中的 `for_train` 参数比较重要,指示是否是训练集\n",
511 | "\n",
512 | "由于给定的训练数据实体部分没有给定出现的位置,这里需要自行查找到实体出现的位置\n",
513 | "\n",
514 | "- 如果是训练集, 按照`entities_json`中的内容在文章中寻找位置并标注, 并将训练数据处理成bio形式\n",
515 | "- 测试数据仅仅做了分句并转化成token id"
516 | ]
517 | },
518 | {
519 | "cell_type": "code",
520 | "execution_count": null,
521 | "metadata": {},
522 | "outputs": [],
523 | "source": [
524 | "class Article:\n",
525 | " def __init__(self, text):\n",
526 | " self._text = text\n",
527 | " self.para_texts = self.split_into_paras(self._text)\n",
528 | " self.sent_texts = [self.split_into_sentence(para) for para in self.para_texts]\n",
529 | "\n",
530 | " def fix_text(self, text: str) -> str:\n",
531 | " paras = text.split('\\n')\n",
532 | " paras = list(filter(lambda para: len(para.strip()) != 0, paras))\n",
533 | " return '\\n'.join(paras)\n",
534 | "\n",
535 | " def split_into_paras(self, text: str):\n",
536 | " paras = list(filter(lambda para: len(para.strip()) != 0, text.split('\\n')))\n",
537 | " return paras\n",
538 | "\n",
539 | " def split_into_sentence(self, one_para_text: str, splited_puncs = None):\n",
540 | " if splited_puncs is None:\n",
541 | " splited_puncs = ['。', '?', '!']\n",
542 | " splited_re_pattern = '[' + ''.join(splited_puncs) + ']'\n",
543 | "\n",
544 | " para = one_para_text\n",
545 | " sentences = re.split(splited_re_pattern, para)\n",
546 | " sentences = list(filter(lambda sent: len(sent) != 0, sentences))\n",
547 | "\n",
548 | " return sentences\n",
549 | "\n",
550 | " def find_sents_by_entity_name(self, entity_text):\n",
551 | " ret_sents = []\n",
552 | " if entity_text not in self._text:\n",
553 | " return []\n",
554 | " else:\n",
555 | " for para in self.split_into_paras(self._text):\n",
556 | " if entity_text not in para:\n",
557 | " continue\n",
558 | " else:\n",
559 | " for sent in self.split_into_sentence(para):\n",
560 | " if entity_text in sent:\n",
561 | " ret_sents.append(sent)\n",
562 | " return ret_sents"
563 | ]
564 | },
565 | {
566 | "cell_type": "code",
567 | "execution_count": null,
568 | "metadata": {},
569 | "outputs": [],
570 | "source": [
571 | "def _find_all_start_end(source, target):\n",
572 | " if not target:\n",
573 | " return []\n",
574 | " occurs = []\n",
575 | " offset = 0\n",
576 | " while offset < len(source):\n",
577 | " found = source[offset:].find(target)\n",
578 | " if found == -1:\n",
579 | " break\n",
580 | " else:\n",
581 | " occurs.append([offset + found, offset + found + len(target) - 1])\n",
582 | " offset += (found + len(target))\n",
583 | " return occurs\n",
584 | "\n",
585 | "def preprocess_data(entities_json,\n",
586 | " article_texts,\n",
587 | " tokenizer: BertTokenizer,\n",
588 | " for_train: bool = True):\n",
589 | " \"\"\"\n",
590 | " [{\n",
591 | " 'sent': xxx, 'entity_name': yyy, 'entity_type': zzz, 'start_token_id': 0, 'end_token_id': 5,\n",
592 | " 'start_index': 0, 'end_index': 2, \n",
593 | " 'sent_tokens': ['token1', 'token2'], 'entity_tokens': ['token3', 'token4']\n",
594 | " }]\n",
595 | " \"\"\"\n",
596 | "\n",
597 | " preprocessed_datas = []\n",
598 | "\n",
599 | " all_sents = []\n",
600 | " for article in tqdm.tqdm([Article(t) for t in article_texts]):\n",
601 | " for para_text in article.para_texts:\n",
602 | " for sent in article.split_into_sentence(para_text):\n",
603 | " sent_tokens = list(sent)\n",
604 | " entity_labels = []\n",
605 | " for entity_type, entities in entities_json.items():\n",
606 | " for entity_name in entities:\n",
607 | " if entity_name not in sent:\n",
608 | " continue\n",
609 | " all_sents.append(sent)\n",
610 | " start_end_indexes = _find_all_start_end(sent, entity_name)\n",
611 | " assert len(start_end_indexes) >= 1\n",
612 | " for str_start_index, str_end_index in start_end_indexes:\n",
613 | " entity_tokens = list(entity_name)\n",
614 | "\n",
615 | " one_entity_label = {\n",
616 | " 'entity_type': entity_type,\n",
617 | " 'start_token_id': str_start_index,\n",
618 | " 'end_token_id': str_end_index,\n",
619 | " 'start_index': str_start_index,\n",
620 | " 'end_index': str_end_index,\n",
621 | " 'entity_tokens': entity_tokens,\n",
622 | " 'entity_name': entity_name\n",
623 | " }\n",
624 | " entity_labels.append(one_entity_label)\n",
625 | "\n",
626 | " if not entity_labels:\n",
627 | " tags = [O for _ in range(len(sent_tokens))]\n",
628 | " tag_ids = [tag2idx[O] for _ in range(len(sent_tokens))]\n",
629 | " if for_train:\n",
630 | " continue\n",
631 | " else:\n",
632 | " tags = []\n",
633 | " tag_ids = []\n",
634 | " for sent_token_index in range(len(sent_tokens)):\n",
635 | " tag = O\n",
636 | " for entity_label in entity_labels:\n",
637 | " if sent_token_index == entity_label['start_token_id']:\n",
638 | " tag = f'B-{chinese_entity_type_vs_english_entity_type[entity_label[\"entity_type\"]]}'\n",
639 | " elif entity_label['start_token_id'] < sent_token_index < entity_label[\"end_token_id\"]:\n",
640 | " tag = f'I-{chinese_entity_type_vs_english_entity_type[entity_label[\"entity_type\"]]}'\n",
641 | " tag_id = tag2idx[tag]\n",
642 | " tags.append(tag)\n",
643 | " tag_ids.append(tag_id)\n",
644 | " assert len(sent_tokens) == len(tags) == len(tag_ids)\n",
645 | " not_o_indexes = [index for index, tag in enumerate(tags) if tag != O]\n",
646 | " all_entities = [sent_tokens[index] for index in not_o_indexes]\n",
647 | " all_entities2 = entity_labels\n",
648 | "\n",
649 | " preprocessed_datas.append({\n",
650 | " 'sent': sent,\n",
651 | " 'sent_tokens': sent_tokens,\n",
652 | " 'sent_token_ids': tokenizer.convert_tokens_to_ids(sent_tokens),\n",
653 | " 'entity_labels': entity_labels,\n",
654 | " 'tags': tags,\n",
655 | " 'tag_ids': tag_ids\n",
656 | " })\n",
657 | " return preprocessed_datas"
658 | ]
659 | },
660 | {
661 | "cell_type": "markdown",
662 | "metadata": {},
663 | "source": [
664 | "# 定义dataset 以及 dataloader"
665 | ]
666 | },
667 | {
668 | "cell_type": "code",
669 | "execution_count": null,
670 | "metadata": {},
671 | "outputs": [],
672 | "source": [
673 | "class MyDataset(torch.utils.data.Dataset):\n",
674 | " def __init__(self, preprocessed_datas, tokenizer: BertTokenizer, max_length=512 ):\n",
675 | " self.preprocessed_datas = preprocessed_datas\n",
676 | " self.tokenizer = tokenizer\n",
677 | " self.max_length = max_length\n",
678 | "\n",
679 | " def pad_sent_ids(self, sent_ids, max_length, padded_token_id):\n",
680 | " mask = [1] * (min(len(sent_ids), max_length)) + [0] * (max_length - len(sent_ids))\n",
681 | " sent_ids = sent_ids[:max_length] + [padded_token_id] * (max_length - len(sent_ids))\n",
682 | " return sent_ids, mask\n",
683 | "\n",
684 | " def process_one_preprocessed_data(self, preprocessed_data):\n",
685 | " import copy\n",
686 | " preprocessed_data = copy.deepcopy(preprocessed_data)\n",
687 | " \n",
688 | " sent_token_ids = self.tokenizer.convert_tokens_to_ids(\n",
689 | " [START_TAG]) + preprocessed_data['sent_token_ids'] + self.tokenizer.convert_tokens_to_ids([END_TAG])\n",
690 | "\n",
691 | " sent_token_ids, mask = self.pad_sent_ids(\n",
692 | " sent_token_ids, max_length=self.max_length, padded_token_id=self.tokenizer.pad_token_id)\n",
693 | "\n",
694 | " sent_token_ids = np.array(sent_token_ids)\n",
695 | " mask = np.array(mask)\n",
696 | " \n",
697 | " preprocessed_data['sent'] = '^' + preprocessed_data['sent'] + '$'\n",
698 | " preprocessed_data['sent_tokens'] = [START_TAG] + preprocessed_data['sent_tokens'] + [END_TAG]\n",
699 | " preprocessed_data['sent_token_ids'] = sent_token_ids\n",
700 | " \n",
701 | " \n",
702 | " tags = [START_TAG] + preprocessed_data['tags'] + [END_TAG]\n",
703 | " tag_ids = [tag2idx[START_TAG]] + preprocessed_data['tag_ids'] + [tag2idx[END_TAG]]\n",
704 | " tag_ids, _ = self.pad_sent_ids(tag_ids, max_length=self.max_length, padded_token_id=tag2idx[O])\n",
705 | " tag_ids = np.array(tag_ids)\n",
706 | " \n",
707 | "\n",
708 | " for entity_label in preprocessed_data['entity_labels']:\n",
709 | " entity_label['start_token_id'] += 1\n",
710 | " entity_label['end_token_id'] += 1\n",
711 | " entity_label['start_index'] += 1\n",
712 | " entity_label['end_index'] += 1\n",
713 | " \n",
714 | " \n",
715 | " preprocessed_data['tags'] = tags\n",
716 | " preprocessed_data['tag_ids'] = tag_ids\n",
717 | "\n",
718 | " not_o_indexes = [index for index, tag in enumerate(preprocessed_data['tags']) if tag != O]\n",
719 | "\n",
720 | " not_o_indexes_str = not_o_indexes\n",
721 | " all_entities = [preprocessed_data['sent_tokens'][index] for index in not_o_indexes]\n",
722 | " all_entities2 = preprocessed_data['entity_labels']\n",
723 | " all_entities3 = [preprocessed_data['sent'][index] for index in not_o_indexes_str]\n",
724 | " \n",
725 | " preprocessed_data.update({'mask': mask})\n",
726 | "\n",
727 | " return preprocessed_data\n",
728 | "\n",
729 | " def __getitem__(self, item):\n",
730 | " return self.process_one_preprocessed_data(\n",
731 | " self.preprocessed_datas[item]\n",
732 | " )\n",
733 | "\n",
734 | " def __len__(self):\n",
735 | " return len(self.preprocessed_datas)\n",
736 | "\n",
737 | "\n",
738 | "def custom_collate_fn(data):\n",
739 | " # copy from torch official,无需深究\n",
740 | " from torch._six import container_abcs, string_classes\n",
741 | "\n",
742 | " r\"\"\"Converts each NumPy array data field into a tensor\"\"\"\n",
743 | " np_str_obj_array_pattern = re.compile(r'[SaUO]')\n",
744 | " elem_type = type(data)\n",
745 | " if isinstance(data, torch.Tensor):\n",
746 | " return data\n",
747 | " elif elem_type.__module__ == 'numpy' and elem_type.__name__ != 'str_' and elem_type.__name__ != 'string_':\n",
748 | " # array of string classes and object\n",
749 | " if elem_type.__name__ == 'ndarray' and np_str_obj_array_pattern.search(data.dtype.str) is not None:\n",
750 | " return data\n",
751 | " return torch.as_tensor(data)\n",
752 | " elif isinstance(data, container_abcs.Mapping):\n",
753 | " tmp_dict = {}\n",
754 | " for key in data:\n",
755 | " if key in ['sent_token_ids', 'tag_ids', 'mask']:\n",
756 | " tmp_dict[key] = custom_collate_fn(data[key])\n",
757 | " if key == 'mask':\n",
758 | " tmp_dict[key] = tmp_dict[key].byte()\n",
759 | " else:\n",
760 | " tmp_dict[key] = data[key]\n",
761 | " return tmp_dict\n",
762 | " elif isinstance(data, tuple) and hasattr(data, '_fields'): # namedtuple\n",
763 | " return elem_type(*(custom_collate_fn(d) for d in data))\n",
764 | " elif isinstance(data, container_abcs.Sequence) and not isinstance(data, string_classes):\n",
765 | " return [custom_collate_fn(d) for d in data]\n",
766 | " else:\n",
767 | " return data\n",
768 | "\n",
769 | "\n",
770 | "def build_dataloader(preprocessed_datas, tokenizer: BertTokenizer, batch_size=32, shuffle=True):\n",
771 | " dataset = MyDataset(preprocessed_datas, tokenizer)\n",
772 | " import torch.utils.data\n",
773 | " dataloader = torch.utils.data.DataLoader(\n",
774 | " dataset, batch_size=batch_size, collate_fn=custom_collate_fn, shuffle=shuffle)\n",
775 | " return dataloader"
776 | ]
777 | },
778 | {
779 | "cell_type": "markdown",
780 | "metadata": {},
781 | "source": [
782 | "# 定义训练时评价指标\n",
783 | "\n",
784 | "仅供训练时参考, 包含实体的precision,recall以及f1。\n",
785 | "\n",
786 | "只有和标注的数据完全相同才算是1,否则为0"
787 | ]
788 | },
789 | {
790 | "cell_type": "code",
791 | "execution_count": null,
792 | "metadata": {},
793 | "outputs": [],
794 | "source": [
795 | "# 训练时指标\n",
796 | "class EvaluateScores:\n",
797 | " def __init__(self, entities_json, predict_entities_json):\n",
798 | " self.entities_json = entities_json\n",
799 | " self.predict_entities_json = predict_entities_json\n",
800 | "\n",
801 | " def compute_entities_score(self):\n",
802 | " return evaluate_entities(self.entities_json, self.predict_entities_json, list(set(self.entities_json.keys())))\n",
803 | " \n",
804 | "def _compute_metrics(ytrue, ypred):\n",
805 | " ytrue = set(ytrue)\n",
806 | " ypred = set(ypred)\n",
807 | " tr = len(ytrue)\n",
808 | " pr = len(ypred)\n",
809 | " hit = len(ypred.intersection(ytrue))\n",
810 | " p = hit / pr if pr!=0 else 0\n",
811 | " r = hit / tr if tr!=0 else 0\n",
812 | " f1 = 2 * p * r / (p + r) if (p+r)!=0 else 0\n",
813 | " return {\n",
814 | " 'p': p,\n",
815 | " 'r': r,\n",
816 | " 'f': f1,\n",
817 | " }\n",
818 | "\n",
819 | "\n",
820 | "def evaluate_entities(true_entities, pred_entities, entity_types):\n",
821 | " scores = []\n",
822 | "\n",
823 | " ps2 = []\n",
824 | " rs2 = []\n",
825 | " fs2 = []\n",
826 | " \n",
827 | " for ent_type in entity_types:\n",
828 | "\n",
829 | " true_entities_list = true_entities.get(ent_type, [])\n",
830 | " pred_entities_list = pred_entities.get(ent_type, [])\n",
831 | " s = _compute_metrics(true_entities_list, pred_entities_list)\n",
832 | " scores.append(s)\n",
833 | " ps = [i['p'] for i in scores]\n",
834 | " rs = [i['r'] for i in scores]\n",
835 | " fs = [i['f'] for i in scores]\n",
836 | " s = {\n",
837 | " 'p': sum(ps) / len(ps),\n",
838 | " 'r': sum(rs) / len(rs),\n",
839 | " 'f': sum(fs) / len(fs),\n",
840 | " }\n",
841 | " return s"
842 | ]
843 | },
844 | {
845 | "cell_type": "markdown",
846 | "metadata": {},
847 | "source": [
848 | "## 定义ner train loop, evaluate loop ,test loop"
849 | ]
850 | },
851 | {
852 | "cell_type": "code",
853 | "execution_count": null,
854 | "metadata": {},
855 | "outputs": [],
856 | "source": [
857 | "def train(model: BertCRF, optimizer, data_loader: torch.utils.data.DataLoader, logger: logging.Logger, epoch_id,\n",
858 | " device='cpu'):\n",
859 | " pbar = tqdm.tqdm(data_loader)\n",
860 | " for batch_id, one_data in enumerate(pbar):\n",
861 | " model.train()\n",
862 | "\n",
863 | " sent_token_ids = torch.stack([d['sent_token_ids'] for d in one_data]).to(device)\n",
864 | " tag_ids = torch.stack([d['tag_ids'] for d in one_data]).to(device)\n",
865 | " mask = torch.stack([d['mask'] for d in one_data]).to(device)\n",
866 | "\n",
867 | " loss = model.forward(sent_token_ids, tag_ids, mask)\n",
868 | " optimizer.zero_grad()\n",
869 | " loss.backward()\n",
870 | " optimizer.step()\n",
871 | " pbar.set_description('epoch: {}, loss: {:.3f}'.format(epoch_id, loss.item()))\n",
872 | "\n",
873 | "\n",
874 | "def evaluate(\n",
875 | " model, data_loader: torch.utils.data.DataLoader, logger: logging.Logger,\n",
876 | " tokenizer, device='cpu',\n",
877 | "):\n",
878 | " founded_entities_json = defaultdict(set)\n",
879 | " golden_entities_json = defaultdict(set)\n",
880 | " for batch_id, one_data in enumerate(data_loader):\n",
881 | " model.eval()\n",
882 | " sent_token_ids = torch.stack([d['sent_token_ids'] for d in one_data]).to(device)\n",
883 | " tag_ids = torch.stack([d['tag_ids'] for d in one_data]).to(device)\n",
884 | " mask = torch.stack([d['mask'] for d in one_data]).to(device)\n",
885 | "\n",
886 | " best_tag_ids_list = model.decode(sent_token_ids, attention_mask=mask)\n",
887 | " best_tags_list = [[idx2tag[idx] for idx in idxs] for idxs in best_tag_ids_list]\n",
888 | "\n",
889 | " for data, best_tags in zip(one_data, best_tag_ids_list):\n",
890 | "\n",
891 | " for entity_label in data['entity_labels']:\n",
892 | " golden_entities_json[entity_label['entity_type']].add(entity_label['entity_name'])\n",
893 | "\n",
894 | " record = False\n",
895 | " for token_index, tag_id in enumerate(best_tags):\n",
896 | " tag = idx2tag[tag_id]\n",
897 | " if tag.startswith('B'):\n",
898 | " start_token_index = token_index\n",
899 | " entity_type = tag[2:]\n",
900 | " record = True\n",
901 | " elif record and tag == O:\n",
902 | " end_token_index = token_index\n",
903 | "\n",
904 | " str_start_index = start_token_index\n",
905 | " str_end_index = end_token_index\n",
906 | "\n",
907 | " entity_name = data['sent'][str_start_index: str_end_index]\n",
908 | "\n",
909 | " entity_type = english_entity_type_vs_chinese_entity_type[entity_type]\n",
910 | " founded_entities_json[entity_type].add(entity_name)\n",
911 | " record = False\n",
912 | " evaluate_tool = EvaluateScores(golden_entities_json, founded_entities_json)\n",
913 | " scores = evaluate_tool.compute_entities_score()\n",
914 | " return scores['f']\n",
915 | "\n",
916 | "\n",
917 | "def test(model, data_loader: torch.utils.data.DataLoader, logger: logging.Logger, device):\n",
918 | " founded_entities = []\n",
919 | " for batch_id, one_data in enumerate(tqdm.tqdm(data_loader)):\n",
920 | " model.eval()\n",
921 | " sent_token_ids = torch.stack([d['sent_token_ids'] for d in one_data]).to(device)\n",
922 | " mask = torch.stack([d['mask'] for d in one_data]).to(device)\n",
923 | "\n",
924 | " with torch.no_grad():\n",
925 | " best_tag_ids_list = model.decode(sent_token_ids, attention_mask=mask, token_type_ids=None)\n",
926 | "\n",
927 | " for data, best_tags in zip(one_data, best_tag_ids_list):\n",
928 | " record = False\n",
929 | " for token_index, tag_id in enumerate(best_tags):\n",
930 | " tag = idx2tag[tag_id]\n",
931 | " if tag.startswith('B'):\n",
932 | " start_token_index = token_index\n",
933 | " entity_type = tag[2:]\n",
934 | " record = True\n",
935 | " elif record and tag == O:\n",
936 | " end_token_index = token_index\n",
937 | " entity_name = data['sent_tokens'][start_token_index: end_token_index + 1]\n",
938 | " founded_entities.append((entity_name, entity_type, data['sent']))\n",
939 | " record = False\n",
940 | " result = defaultdict(list)\n",
941 | " for entity_name, entity_type, sent in founded_entities:\n",
942 | " entity = ''.join(entity_name).replace('##', '')\n",
943 | " entity = entity.replace('[CLS]', '')\n",
944 | " entity = entity.replace('[UNK]', '')\n",
945 | " entity = entity.replace('[SEP]', '')\n",
946 | " if len(entity) > 1:\n",
947 | " result[english_entity_type_vs_chinese_entity_type[entity_type]].append((entity, sent))\n",
948 | "\n",
949 | " for ent_type, ents in result.items():\n",
950 | " result[ent_type] = list(set(ents))\n",
951 | " return result"
952 | ]
953 | },
954 | {
955 | "cell_type": "markdown",
956 | "metadata": {},
957 | "source": [
958 | "# ner主要训练流程\n",
959 | "\n",
960 | "- 分隔训练集验证集,并处理成dataset dataloader\n",
961 | "- 训练,验证,保存模型"
962 | ]
963 | },
964 | {
965 | "cell_type": "code",
966 | "execution_count": null,
967 | "metadata": {},
968 | "outputs": [],
969 | "source": [
970 | "def main_train(logger, tokenizer, model, to_be_trained_entities, yanbao_texts):\n",
971 | " entities_json = to_be_trained_entities\n",
972 | " train_entities_json = {k: [] for k in entities_json}\n",
973 | " dev_entities_json = {k: [] for k in entities_json}\n",
974 | "\n",
975 | " train_proportion = 0.9\n",
976 | " for entity_type, entities in entities_json.items():\n",
977 | " entities = entities.copy()\n",
978 | " random.shuffle(entities)\n",
979 | " \n",
980 | " train_entities_json[entity_type] = entities[: int(len(entities) * train_proportion)]\n",
981 | " dev_entities_json[entity_type] = entities[int(len(entities) * train_proportion):]\n",
982 | "\n",
983 | " \n",
984 | " train_preprocessed_datas = preprocess_data(train_entities_json, yanbao_texts, tokenizer)\n",
985 | " train_dataloader = build_dataloader(train_preprocessed_datas, tokenizer, batch_size=BATCH_SIZE)\n",
986 | " \n",
987 | " dev_preprocessed_datas = preprocess_data(dev_entities_json, yanbao_texts, tokenizer)\n",
988 | " dev_dataloader = build_dataloader(dev_preprocessed_datas, tokenizer, batch_size=BATCH_SIZE)\n",
989 | "\n",
990 | " model = model.to(DEVICE)\n",
991 | " for name, param in model.named_parameters():\n",
992 | " if \"bert_module\" in name:\n",
993 | " param.requires_grad = False\n",
994 | " else:\n",
995 | " param.requires_grad = True\n",
996 | " optimizer = torch.optim.Adam([para for para in model.parameters() if para.requires_grad],\n",
997 | " lr=0.001,\n",
998 | " weight_decay=0.0005)\n",
999 | " best_evaluate_score = 0\n",
1000 | " for epoch in range(TOTAL_EPOCH_NUMS):\n",
1001 | " train(model, optimizer, train_dataloader, logger=logger, epoch_id=epoch, device=DEVICE)\n",
1002 | " evaluate_score = evaluate(model, dev_dataloader, logger=logger, tokenizer=tokenizer, device=DEVICE)\n",
1003 | " print('评估分数:', evaluate_score)\n",
1004 | " if evaluate_score >= best_evaluate_score:\n",
1005 | " best_evaluate_score = evaluate_score\n",
1006 | " save_model_path = os.path.join(SAVE_MODEL_DIR, 'finnal_ccks_model.pth')\n",
1007 | " logger.info('saving model to {}'.format(save_model_path))\n",
1008 | " torch.save(model.cpu().state_dict(), save_model_path)\n",
1009 | " model.to(DEVICE)"
1010 | ]
1011 | },
1012 | {
1013 | "cell_type": "markdown",
1014 | "metadata": {},
1015 | "source": [
1016 | "## 准备训练ner模型"
1017 | ]
1018 | },
1019 | {
1020 | "cell_type": "code",
1021 | "execution_count": null,
1022 | "metadata": {},
1023 | "outputs": [],
1024 | "source": [
1025 | "logger = logging.getLogger(__name__)\n",
1026 | "\n",
1027 | "tokenizer = BertTokenizer.from_pretrained(\n",
1028 | " os.path.join(PRETRAINED_BERT_MODEL_DIR, 'vocab.txt')\n",
1029 | ")\n",
1030 | "\n",
1031 | "model = BertCRF(\n",
1032 | " pretrained_bert_model_file_path=PRETRAINED_BERT_MODEL_DIR,\n",
1033 | " num_tags=len(tag2id), batch_first=True\n",
1034 | ")\n",
1035 | "\n",
1036 | "save_model_path = os.path.join(SAVE_MODEL_DIR, 'finnal_ccks_model.pth')\n",
1037 | "if Path(save_model_path).exists():\n",
1038 | " model_state_dict = torch.load(save_model_path, map_location='cpu')\n",
1039 | " model.load_state_dict(model_state_dict)"
1040 | ]
1041 | },
1042 | {
1043 | "cell_type": "code",
1044 | "execution_count": null,
1045 | "metadata": {},
1046 | "outputs": [],
1047 | "source": [
1048 | "# 训练数据在main_train函数中处理并生成dataset dataloader,此处无需生成\n",
1049 | "\n",
1050 | "# 测试数据在此处处理并生成dataset dataloader\n",
1051 | "test_preprocessed_datas = preprocess_data({}, yanbao_texts, tokenizer, for_train=False)\n",
1052 | "test_dataloader = build_dataloader(test_preprocessed_datas, tokenizer, batch_size=BATCH_SIZE)"
1053 | ]
1054 | },
1055 | {
1056 | "cell_type": "markdown",
1057 | "metadata": {},
1058 | "source": [
1059 | "## 整个训练流程是:\n",
1060 | "\n",
1061 | "- 使用数据集增强得到更多的实体\n",
1062 | "- 使用增强过后的实体来指导训练\n",
1063 | "\n",
1064 | "\n",
1065 | "- 训练后的模型重新对所有文档中进行预测,得到新的实体,加入到实体数据集中\n",
1066 | "- 使用扩增后的实体数据集来进行二次训练,再得到新的实体,再增强实体数据集\n",
1067 | "- (模型预测出来的数据需要`review_model_predict_entities`后处理形成提交格式)\n",
1068 | "\n",
1069 | "\n",
1070 | "- 如果提交结果,需要`extract_entities`函数删除提交数据中那些出现在训练数据中的实体"
1071 | ]
1072 | },
1073 | {
1074 | "cell_type": "markdown",
1075 | "metadata": {},
1076 | "source": [
1077 | "### 模型预测结果后处理函数\n",
1078 | "\n",
1079 | "- `review_model_predict_entities`函数将模型预测结果后处理,从而生成提交文件格式"
1080 | ]
1081 | },
1082 | {
1083 | "cell_type": "code",
1084 | "execution_count": null,
1085 | "metadata": {},
1086 | "outputs": [],
1087 | "source": [
1088 | "def review_model_predict_entities(model_predict_entities):\n",
1089 | " word_tag_map = POSTokenizer().word_tag_tab\n",
1090 | " idf_freq = TFIDF().idf_freq\n",
1091 | " reviewed_entities = defaultdict(list)\n",
1092 | " for ent_type, ent_and_sent_list in model_predict_entities.items():\n",
1093 | " for ent, sent in ent_and_sent_list:\n",
1094 | " start = sent.lower().find(ent)\n",
1095 | " if start == -1:\n",
1096 | " continue\n",
1097 | " start += 1\n",
1098 | " end = start + len(ent) - 1\n",
1099 | " tokens = jieba.lcut(sent)\n",
1100 | " offset = 0\n",
1101 | " selected_tokens = []\n",
1102 | " for token in tokens:\n",
1103 | " offset += len(token)\n",
1104 | " if offset >= start:\n",
1105 | " selected_tokens.append(token)\n",
1106 | " if offset >= end:\n",
1107 | " break\n",
1108 | "\n",
1109 | " fixed_entity = ''.join(selected_tokens)\n",
1110 | " fixed_entity = re.sub(r'\\d*\\.?\\d+%$', '', fixed_entity)\n",
1111 | " if ent_type == '人物':\n",
1112 | " if len(fixed_entity) >= 10:\n",
1113 | " continue\n",
1114 | " if len(fixed_entity) <= 1:\n",
1115 | " continue\n",
1116 | " if re.findall(r'^\\d+$', fixed_entity):\n",
1117 | " continue\n",
1118 | " if word_tag_map.get(fixed_entity, '') == 'v' and idf_freq[fixed_entity] < 7:\n",
1119 | " continue\n",
1120 | " reviewed_entities[ent_type].append(fixed_entity)\n",
1121 | " return reviewed_entities"
1122 | ]
1123 | },
1124 | {
1125 | "cell_type": "markdown",
1126 | "metadata": {},
1127 | "source": [
1128 | "- `extract_entities` 删除与训练集中重复的实体"
1129 | ]
1130 | },
1131 | {
1132 | "cell_type": "code",
1133 | "execution_count": null,
1134 | "metadata": {},
1135 | "outputs": [],
1136 | "source": [
1137 | "def extract_entities(to_be_trained_entities):\n",
1138 | " test_entities = to_be_trained_entities\n",
1139 | " train_entities = read_json(Path(DATA_DIR, 'entities.json'))\n",
1140 | "\n",
1141 | " for ent_type, ents in test_entities.items():\n",
1142 | " test_entities[ent_type] = list(set(ents) - set(train_entities[ent_type]))\n",
1143 | "\n",
1144 | " for ent_type in train_entities.keys():\n",
1145 | " if ent_type not in test_entities:\n",
1146 | " test_entities[ent_type] = []\n",
1147 | " return test_entities"
1148 | ]
1149 | },
1150 | {
1151 | "cell_type": "code",
1152 | "execution_count": null,
1153 | "metadata": {
1154 | "scrolled": false
1155 | },
1156 | "outputs": [],
1157 | "source": [
1158 | "# 循环轮次数目\n",
1159 | "nums_round = 1\n",
1160 | "for i in range(nums_round):\n",
1161 | " # train\n",
1162 | " main_train(logger, tokenizer, model, to_be_trained_entities, yanbao_texts) \n",
1163 | " \n",
1164 | " model = model.to(DEVICE)\n",
1165 | " model_predict_entities = test(model, test_dataloader, logger=logger, device=DEVICE)\n",
1166 | " \n",
1167 | " # 修复训练预测结果\n",
1168 | " reviewed_entities = review_model_predict_entities(model_predict_entities)\n",
1169 | " \n",
1170 | " # 将训练预测结果再次放入训练集中, 重新训练或者直接出结果\n",
1171 | " for ent_type, ents in reviewed_entities.items():\n",
1172 | " to_be_trained_entities[ent_type] = list(set(to_be_trained_entities[ent_type] + ents))\n",
1173 | "\n",
1174 | "# 创造出提交结果\n",
1175 | "submit_entities = extract_entities(to_be_trained_entities)"
1176 | ]
1177 | },
1178 | {
1179 | "cell_type": "markdown",
1180 | "metadata": {},
1181 | "source": [
1182 | "# 属性抽取\n",
1183 | "\n",
1184 | "通过规则抽取属性\n",
1185 | "\n",
1186 | "- 研报时间\n",
1187 | "- 研报评级\n",
1188 | "- 文章时间"
1189 | ]
1190 | },
1191 | {
1192 | "cell_type": "code",
1193 | "execution_count": null,
1194 | "metadata": {},
1195 | "outputs": [],
1196 | "source": [
1197 | "def find_article_time(yanbao_txt, entity):\n",
1198 | " str_start_index = yanbao_txt.index(entity)\n",
1199 | " str_end_index = str_start_index + len(entity)\n",
1200 | " para_start_index = yanbao_txt.rindex('\\n', 0, str_start_index)\n",
1201 | " para_end_index = yanbao_txt.index('\\n', str_end_index)\n",
1202 | "\n",
1203 | " para = yanbao_txt[para_start_index + 1: para_end_index].strip()\n",
1204 | " if len(entity) > 5:\n",
1205 | " ret = re.findall(r'(\\d{4})\\s*[年-]\\s*(\\d{1,2})\\s*[月-]\\s*(\\d{1,2})\\s*日?', para)\n",
1206 | " if ret:\n",
1207 | " year, month, day = ret[0]\n",
1208 | " time = '{}/{}/{}'.format(year, month.lstrip(), day.lstrip())\n",
1209 | " return time\n",
1210 | "\n",
1211 | " start_index = 0\n",
1212 | " time = None\n",
1213 | " min_gap = float('inf')\n",
1214 | " for word, poseg in pseg.cut(para):\n",
1215 | " if poseg in ['t', 'TIME'] and str_start_index <= start_index < str_end_index:\n",
1216 | " gap = abs(start_index - (str_start_index + str_end_index) // 2)\n",
1217 | " if gap < min_gap:\n",
1218 | " min_gap = gap\n",
1219 | " time = word\n",
1220 | " start_index += len(word)\n",
1221 | " return time\n",
1222 | "\n",
1223 | "\n",
1224 | "def find_yanbao_time(yanbao_txt, entity):\n",
1225 | " paras = [para.strip() for para in yanbao_txt.split('\\n') if para.strip()][:5]\n",
1226 | " for para in paras:\n",
1227 | " ret = re.findall(r'(\\d{4})\\s*[\\./年-]\\s*(\\d{1,2})\\s*[\\./月-]\\s*(\\d{1,2})\\s*日?', para)\n",
1228 | " if ret:\n",
1229 | " year, month, day = ret[0]\n",
1230 | " time = '{}/{}/{}'.format(year, month.lstrip(), day.lstrip())\n",
1231 | " return time\n",
1232 | " return None"
1233 | ]
1234 | },
1235 | {
1236 | "cell_type": "code",
1237 | "execution_count": null,
1238 | "metadata": {},
1239 | "outputs": [],
1240 | "source": [
1241 | "def extract_attrs(entities_json):\n",
1242 | " train_attrs = read_json(Path(DATA_DIR, 'attrs.json'))['attrs']\n",
1243 | "\n",
1244 | " seen_pingjis = []\n",
1245 | " for attr in train_attrs:\n",
1246 | " if attr[1] == '评级':\n",
1247 | " seen_pingjis.append(attr[2])\n",
1248 | " article_entities = entities_json.get('文章', [])\n",
1249 | " yanbao_entities = entities_json.get('研报', [])\n",
1250 | "\n",
1251 | " attrs_json = []\n",
1252 | " for file_path in tqdm.tqdm(list(Path(DATA_DIR, 'yanbao_txt').glob('*.txt'))):\n",
1253 | " yanbao_txt = '\\n' + Path(file_path).open().read() + '\\n'\n",
1254 | " for entity in article_entities:\n",
1255 | " if entity not in yanbao_txt:\n",
1256 | " continue\n",
1257 | " time = find_article_time(yanbao_txt, entity)\n",
1258 | " if time:\n",
1259 | " attrs_json.append([entity, '发布时间', time])\n",
1260 | "\n",
1261 | " yanbao_txt = '\\n'.join(\n",
1262 | " [para.strip() for para in yanbao_txt.split('\\n') if\n",
1263 | " len(para.strip()) != 0])\n",
1264 | " for entity in yanbao_entities:\n",
1265 | " if entity not in yanbao_txt:\n",
1266 | " continue\n",
1267 | "\n",
1268 | " paras = yanbao_txt.split('\\n')\n",
1269 | " for para_id, para in enumerate(paras):\n",
1270 | " if entity in para:\n",
1271 | " break\n",
1272 | "\n",
1273 | " paras = paras[: para_id + 5]\n",
1274 | " for para in paras:\n",
1275 | " for pingji in seen_pingjis:\n",
1276 | " if pingji in para:\n",
1277 | " if '上次' in para:\n",
1278 | " attrs_json.append([entity, '上次评级', pingji])\n",
1279 | " continue\n",
1280 | " elif '维持' in para:\n",
1281 | " attrs_json.append([entity, '上次评级', pingji])\n",
1282 | " attrs_json.append([entity, '评级', pingji])\n",
1283 | "\n",
1284 | " time = find_yanbao_time(yanbao_txt, entity)\n",
1285 | " if time:\n",
1286 | " attrs_json.append([entity, '发布时间', time])\n",
1287 | " attrs_json = list(set(tuple(_) for _ in attrs_json) - set(tuple(_) for _ in train_attrs))\n",
1288 | " \n",
1289 | " return attrs_json"
1290 | ]
1291 | },
1292 | {
1293 | "cell_type": "code",
1294 | "execution_count": null,
1295 | "metadata": {},
1296 | "outputs": [],
1297 | "source": [
1298 | "train_attrs = read_json(Path(DATA_DIR, 'attrs.json'))['attrs']\n",
1299 | "submit_attrs = extract_attrs(submit_entities)"
1300 | ]
1301 | },
1302 | {
1303 | "cell_type": "code",
1304 | "execution_count": null,
1305 | "metadata": {},
1306 | "outputs": [],
1307 | "source": [
1308 | "submit_attrs"
1309 | ]
1310 | },
1311 | {
1312 | "cell_type": "markdown",
1313 | "metadata": {},
1314 | "source": [
1315 | "# 关系抽取\n",
1316 | "\n",
1317 | "- 对于研报实体,整个文档抽取特定类型(行业,机构,指标)的关系实体\n",
1318 | "- 其他的实体仅考虑与其出现在同一句话中的其他实体组织成特定关系"
1319 | ]
1320 | },
1321 | {
1322 | "cell_type": "code",
1323 | "execution_count": null,
1324 | "metadata": {},
1325 | "outputs": [],
1326 | "source": [
1327 | "def extract_relations(schema, entities_json):\n",
1328 | " relation_by_rules = []\n",
1329 | " relation_schema = schema['relationships']\n",
1330 | " unique_s_o_types = []\n",
1331 | " so_type_cnt = defaultdict(int)\n",
1332 | " for s_type, p, o_type in schema['relationships']:\n",
1333 | " so_type_cnt[(s_type, o_type)] += 1\n",
1334 | " for (s_type, o_type), cnt in so_type_cnt.items():\n",
1335 | " if cnt == 1 and s_type != o_type:\n",
1336 | " unique_s_o_types.append((s_type, o_type))\n",
1337 | "\n",
1338 | " for path in tqdm.tqdm(list(Path(DATA_DIR, 'yanbao_txt').glob('*.txt'))):\n",
1339 | " with open(path) as f:\n",
1340 | " entity_dict_in_file = defaultdict(lambda: defaultdict(list))\n",
1341 | " main_org = None\n",
1342 | " for line_idx, line in enumerate(f.readlines()):\n",
1343 | " for sent_idx, sent in enumerate(split_to_sents(line)):\n",
1344 | " for ent_type, ents in entities_json.items():\n",
1345 | " for ent in ents:\n",
1346 | " if ent in sent:\n",
1347 | " if ent_type == '机构' and len(line) - len(ent) < 3 or \\\n",
1348 | " re.findall('[\\((]\\d+\\.*[A-Z]*[\\))]', line):\n",
1349 | " main_org = ent\n",
1350 | " else:\n",
1351 | " if main_org and '客户' in sent:\n",
1352 | " relation_by_rules.append([ent, '客户', main_org])\n",
1353 | " entity_dict_in_file[ent_type][\n",
1354 | " ('test', ent)].append(\n",
1355 | " [line_idx, sent_idx, sent,\n",
1356 | " sent.find(ent)]\n",
1357 | " )\n",
1358 | "\n",
1359 | " for s_type, p, o_type in relation_schema:\n",
1360 | " s_ents = entity_dict_in_file[s_type]\n",
1361 | " o_ents = entity_dict_in_file[o_type]\n",
1362 | " if o_type == '业务' and not '业务' in line:\n",
1363 | " continue\n",
1364 | " if o_type == '行业' and not '行业' in line:\n",
1365 | " continue\n",
1366 | " if o_type == '文章' and not ('《' in line or not '》' in line):\n",
1367 | " continue\n",
1368 | " if s_ents and o_ents:\n",
1369 | " for (s_ent_src, s_ent), (o_ent_src, o_ent) in product(s_ents, o_ents):\n",
1370 | " if s_ent != o_ent:\n",
1371 | " s_occs = [tuple(_[:2]) for _ in\n",
1372 | " s_ents[(s_ent_src, s_ent)]]\n",
1373 | " o_occs = [tuple(_[:2]) for _ in\n",
1374 | " o_ents[(o_ent_src, o_ent)]]\n",
1375 | " intersection = set(s_occs) & set(o_occs)\n",
1376 | " if s_type == '研报' and s_ent_src == 'test':\n",
1377 | " relation_by_rules.append([s_ent, p, o_ent])\n",
1378 | " continue\n",
1379 | " if not intersection:\n",
1380 | " continue\n",
1381 | " if (s_type, o_type) in unique_s_o_types and s_ent_src == 'test':\n",
1382 | " relation_by_rules.append([s_ent, p, o_ent])\n",
1383 | "\n",
1384 | " train_relations = read_json(Path(DATA_DIR, 'relationships.json'))['relationships']\n",
1385 | " result_relations_set = list(set(tuple(_) for _ in relation_by_rules) - set(tuple(_) for _ in train_relations))\n",
1386 | " return result_relations_set"
1387 | ]
1388 | },
1389 | {
1390 | "cell_type": "code",
1391 | "execution_count": null,
1392 | "metadata": {},
1393 | "outputs": [],
1394 | "source": [
1395 | "schema = read_json(Path(DATA_DIR, 'schema.json'))\n",
1396 | "submit_relations = extract_relations(schema, submit_entities)"
1397 | ]
1398 | },
1399 | {
1400 | "cell_type": "code",
1401 | "execution_count": null,
1402 | "metadata": {},
1403 | "outputs": [],
1404 | "source": [
1405 | "submit_relations"
1406 | ]
1407 | },
1408 | {
1409 | "cell_type": "markdown",
1410 | "metadata": {},
1411 | "source": [
1412 | "## 生成提交文件\n",
1413 | "\n",
1414 | "根据biendata的要求生成提交文件\n",
1415 | "\n",
1416 | "参考:https://www.biendata.com/competition/ccks_2020_5/make-submission/"
1417 | ]
1418 | },
1419 | {
1420 | "cell_type": "code",
1421 | "execution_count": null,
1422 | "metadata": {},
1423 | "outputs": [],
1424 | "source": [
1425 | "final_answer = {'attrs': submit_attrs,\n",
1426 | " 'entities': submit_entities,\n",
1427 | " 'relationships': submit_relations,\n",
1428 | " }\n",
1429 | "\n",
1430 | "\n",
1431 | "with open('output/answers.json', mode='w') as fw:\n",
1432 | " json.dump(final_answer, fw, ensure_ascii=False, indent=4)\n"
1433 | ]
1434 | },
1435 | {
1436 | "cell_type": "code",
1437 | "execution_count": null,
1438 | "metadata": {},
1439 | "outputs": [],
1440 | "source": [
1441 | "with open('output/answers.json', 'rb') as fb:\n",
1442 | " data = fb.read()\n",
1443 | "\n",
1444 | "b64 = base64.b64encode(data)\n",
1445 | "payload = b64.decode()\n",
1446 | "html = '{title}'\n",
1447 | "html = html.format(payload=payload,title='answers.json',filename='answers.json')\n",
1448 | "HTML(html)"
1449 | ]
1450 | }
1451 | ],
1452 | "metadata": {
1453 | "kernelspec": {
1454 | "display_name": "Python 3",
1455 | "language": "python",
1456 | "name": "python3"
1457 | },
1458 | "language_info": {
1459 | "codemirror_mode": {
1460 | "name": "ipython",
1461 | "version": 3
1462 | },
1463 | "file_extension": ".py",
1464 | "mimetype": "text/x-python",
1465 | "name": "python",
1466 | "nbconvert_exporter": "python",
1467 | "pygments_lexer": "ipython3",
1468 | "version": "3.7.4"
1469 | }
1470 | },
1471 | "nbformat": 4,
1472 | "nbformat_minor": 4
1473 | }
1474 |
--------------------------------------------------------------------------------
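The notebook's preprocessing cell splits articles into sentences and sub-sentences by appending a `\x01` sentinel after each closing punctuation mark and then splitting on it. A toy check with a made-up sentence, assuming the cell defining `split_to_sents` / `split_to_subsents` has already been executed:

```python
# Made-up example text; run after the preprocessing cell of the notebook.
text = "公司发布年报!业绩超预期,但存在原材料价格上涨的风险。详见《2019年年度报告》。"

print(split_to_sents(text))
# ['公司发布年报!', '业绩超预期,但存在原材料价格上涨的风险。', '详见《2019年年度报告》。']

print(split_to_subsents(text))
# ['公司发布年报!', '业绩超预期,', '但存在原材料价格上涨的风险。', '详见《2019年年度报告》。']
```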
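The training-time metric (`EvaluateScores` / `evaluate_entities`) scores each entity type by exact set match and then macro-averages precision, recall and F1 over the types. A toy illustration with made-up entity names, assuming the metric cell has been run:

```python
# Made-up gold and predicted entity sets; run after the metric cell.
true_entities = {"机构": ["A公司", "B银行"], "人物": ["张三"]}
pred_entities = {"机构": ["A公司", "C集团"], "人物": ["张三"]}

print(evaluate_entities(true_entities, pred_entities, list(true_entities.keys())))
# 机构: p = r = f = 0.5;  人物: p = r = f = 1.0  ->  macro p = r = f = 0.75
```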
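After training, the notebook tags whole documents through its `test` loop; for quick experiments it can be handy to decode a single sentence directly. A minimal sketch (the helper `tag_sentence` below is hypothetical), assuming the cells that define `BertCRF`, `tokenizer`, `START_TAG`/`END_TAG`, `idx2tag` and the loaded `model` have been executed and that `model` already sits on `device`:

```python
import torch

def tag_sentence(sent, model, tokenizer, device="cpu"):
    # Character-level tokens framed by [CLS]/[SEP], exactly as in MyDataset.
    tokens = [START_TAG] + list(sent) + [END_TAG]
    ids = torch.tensor([tokenizer.convert_tokens_to_ids(tokens)], device=device)
    mask = torch.ones_like(ids, dtype=torch.uint8)  # the CRF expects a byte mask
    model.eval()
    with torch.no_grad():
        best = model.decode(ids, attention_mask=mask)[0]
    return list(zip(tokens, [idx2tag[i] for i in best]))

# Example with a made-up sentence:
# for ch, tag in tag_sentence("A公司发布2019年年报。", model, tokenizer, DEVICE):
#     print(ch, tag)
```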
/ccks2020-datagrand-qq.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wgwang/ccks2020-baseline/96c8f990f87afc3a12f1895bab21c095d8cb410e/ccks2020-datagrand-qq.png
--------------------------------------------------------------------------------
/readme.md:
--------------------------------------------------------------------------------
1 | # CCKS 2020 Baseline: 基于本体的金融知识图谱自动化构建技术评测的指引
2 |
3 |
4 | ## 竞赛主页
5 |
6 | https://www.biendata.com/competition/ccks_2020_5/
7 |
8 | ## docker
9 |
10 | 提供一个打好的镜像可直接运行,见:
11 |
12 | https://hub.docker.com/repository/docker/wgwang/ccks2020-baseline/general
13 |
14 |
15 | ## 竞赛数据集
16 |
17 | 竞赛的数据集已经在 SciDB 和 OpenKG 上公开发布,详情见:
18 |
19 | [**中文领域最大规模的金融研报知识图谱数据集FR2KG**](https://mp.weixin.qq.com/s/cgQq27rmU_8LqM1t07A6Pg)
20 |
21 | ## 参考书籍
22 |
23 | 《[**知识图谱:认知智能理论与实战**](https://item.jd.com/13172503.html)》一书是绝佳的参考教材:
24 | - 一百多张精美配图详细解析数十个知识图谱前沿算法
25 | - 首创知识图谱建模方法论——六韬法及模式设计工程模型
26 | - 全面涵盖知识图谱模式设计、构建、存储和应用技术
27 | - 前沿理论研究成果和产业落地实践相结合,学术界和企业界顶尖专家学者联合力荐
28 | - 全彩印刷,书籍质量一流,看起来赏心悦目
29 | - 促进知识图谱落地应用,推动人工智能向认知智能蓬勃发展
30 | - 学术界:
31 | - 中国中文信息学会会士,中国计算机学会 NLPCC杰出贡献奖获得者冯志伟作序
32 | - 北京大学万小军教授、同济大学王昊奋研究员、清华大学李涓子教授、复旦大学肖仰华教授、浙江大学陈华钧教授、复旦大学黄萱菁教授共同倾力推荐
33 | - 产业界:
34 | - 微软-仪电人工智能创新院总经理朱琳女士作序
35 | - Google知识图谱团队于志伟先生、微创医疗集团乐承筠博士、盛大Alex Lu先生、达观数据陈运文博士、微软陈宏刚博士、中国平安郭敏先生、神策数据桑文锋先生共同力荐
36 |
37 |
38 |
39 | 
40 | 
41 | 
42 |
43 |
44 |
45 | ## 竞赛背景
46 |
47 | 金融研报是各类金融研究结构对宏观经济、金融、行业、产业链以及公司的研究报告。报告通常是有专业人员撰写,对宏观、行业和公司的数据信息搜集全面、研究深入,质量高,内容可靠。报告内容往往包含产业、经济、金融、政策、社会等多领域的数据与知识,是构建行业知识图谱非常关键的数据来源。另一方面,由于研报本身所容纳的数据与知识涉及面广泛,专业知识众多,不同的研究结构和专业认识对相同的内容的表达方式也会略有差异。这些特点导致了从研报自动化构建知识图谱困难重重,解决这些问题则能够极大促进自动化构建知识图谱方面的技术进步。
48 |
49 | 本评测任务参考 TAC KBP 中的 Cold Start 评测任务的方案,围绕金融研报知识图谱的自动化图谱构建所展开。评测从预定义图谱模式(Schema)和少量的种子知识图谱开始,从非结构化的文本数据中构建知识图谱。其中图谱模式包括 10 种实体类型,如机构、产品、业务、风险等;19 个实体间的关系,如(机构,生产销售,产品)、(机构,投资,机构)等;以及若干实体类型带有属性,如(机构,英文名)、(研报,评级)等。在给定图谱模式和种子知识图谱的条件下,评测内容为自动地从研报文本中抽取出符合图谱模式的实体、关系和属性值,实现金融知识图谱的自动化构建。所构建的图谱在大金融行业、监管部门、政府、行业研究机构和行业公司等应用非常广泛,如风险监测、智能投研、智能监管、智能风控等,具有巨大的学术价值和产业价值。
50 |
51 | 评测本身不限制各参赛队伍使用的模型、算法和技术。希望各参赛队伍发挥聪明才智,构建各类无监督、弱监督、远程监督、半监督等系统,迭代的实现知识图谱的自动化构建,共同促进知识图谱技术的进步。
52 |
53 | ## 竞赛任务
54 |
55 | 本评测任务参考 TAC KBP 中的 Cold Start 评测任务的方案,围绕金融研报知识图谱的自动化图谱构建所展开。评测从预定义图谱模式(Schema)和少量的种子知识图谱开始,从非结构化的文本数据中构建知识图谱。评测本身不限制各参赛队伍使用的模型、算法和技术。希望各参赛队伍发挥聪明才智,构建各类无监督、弱监督、远程监督、半监督等系统,迭代的实现知识图谱的自动化构建,共同促进知识图谱技术的进步。
56 |
57 |
58 | ## 联系
59 |
60 | 达观数据 wangwenguang@datagrand.com
61 |
62 | ## 参与
63 |
64 | - 关于baseline的任何问题可以使用issue进行交流,有任何改进的想法可以使用pr参与
65 | - 有竞赛、数据集和书籍有关的问题、想法,欢迎扫描下面二维码关注“走向未来”公众号留言
66 | 
67 |
68 |
69 |
70 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | jieba
2 | catalogue
3 | Cython
4 | fasttext
5 | jsonlines
6 | jsonschema
7 | jupyter
8 | matplotlib
9 | networkx
10 | numpy==1.16.4
11 | paddlepaddle-tiny
12 | pandas
13 | protobuf
14 | pyhanlp
15 | pytorch-crf
16 | pytorch-transformers
17 | requests
18 | scikit-learn
19 | scipy
20 | sentencepiece
21 | torch
22 | TorchCRF
23 | torchvision
24 | tqdm
25 | transformers
26 | hanlp
27 |
--------------------------------------------------------------------------------
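requirements.txt pins only numpy; everything else floats. One detail worth knowing: the notebook imports `from torchcrf import CRF`, which is provided by the `pytorch-crf` entry, while `TorchCRF` is a separate package with a different module name. A small sketch (not part of the repo) to confirm that the imports the baseline notebook needs actually resolve in a fresh environment:

```python
# Import check: every module the baseline notebook imports should resolve
# after `pip install -r requirements.txt`.
import importlib

for module in ("jieba", "hanlp", "numpy", "torch", "torchcrf", "pytorch_transformers"):
    try:
        importlib.import_module(module)
        print("ok  ", module)
    except ImportError as err:
        print("FAIL", module, "->", err)
```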
/the-land-of-future.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wgwang/ccks2020-baseline/96c8f990f87afc3a12f1895bab21c095d8cb410e/the-land-of-future.png
--------------------------------------------------------------------------------
/封底.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wgwang/ccks2020-baseline/96c8f990f87afc3a12f1895bab21c095d8cb410e/封底.jpeg
--------------------------------------------------------------------------------
/封面.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wgwang/ccks2020-baseline/96c8f990f87afc3a12f1895bab21c095d8cb410e/封面.png
--------------------------------------------------------------------------------