├── .gitignore
├── CMeIE
│   ├── CMeIE_dev.json
│   ├── CMeIE_test.json
│   ├── CMeIE_train.json
│   ├── README.txt
│   └── schema.json
├── LICENSE
├── README.md
├── record
│   └── log.txt
├── requirements.txt
└── roberta_base.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | pip-wheel-metadata/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 |
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 |
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 |
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | *.py,cover
51 | .hypothesis/
52 | .pytest_cache/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | target/
76 |
77 | # Jupyter Notebook
78 | .ipynb_checkpoints
79 |
80 | # IPython
81 | profile_default/
82 | ipython_config.py
83 |
84 | # pyenv
85 | .python-version
86 |
87 | # pipenv
88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
91 | # install all needed dependencies.
92 | #Pipfile.lock
93 |
94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
95 | __pypackages__/
96 |
97 | # Celery stuff
98 | celerybeat-schedule
99 | celerybeat.pid
100 |
101 | # SageMath parsed files
102 | *.sage.py
103 |
104 | # Environments
105 | .env
106 | .venv
107 | env/
108 | venv/
109 | ENV/
110 | env.bak/
111 | venv.bak/
112 |
113 | # Spyder project settings
114 | .spyderproject
115 | .spyproject
116 |
117 | # Rope project settings
118 | .ropeproject
119 |
120 | # mkdocs documentation
121 | /site
122 |
123 | # mypy
124 | .mypy_cache/
125 | .dmypy.json
126 | dmypy.json
127 |
128 | # Pyre type checker
129 | .pyre/
130 |
--------------------------------------------------------------------------------
/CMeIE/README.txt:
--------------------------------------------------------------------------------
1 | 1. Manifest:
2 | - 53_schema.jsonl: SPO relation constraint table
3 | - CMeIE_train.jsonl: training set
4 | - CMeIE_dev.jsonl: validation set
5 | - CMeIE_test.jsonl: test set. When submitting, fill the "spo_list" field of every record with a list. Each extracted relation must contain the three fields "subject", "predicate" and "object", and "object" must be a dictionary (consistent with the training data): {"@value": "some string"}.
6 | - example_gold.jsonl: sample gold-standard file
7 | - example_pred.jsonl: sample submission file
8 | - README.txt: this file
9 |
10 | 2. The evaluation metric is strict Micro-F1.
11 |
12 | 3. The file submitted for this task must be named CMeIE_test.jsonl.
13 |
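An illustrative submission record (content invented; the format follows the description above):
{"text": "胃癌常见症状为腹痛。", "spo_list": [{"subject": "胃癌", "predicate": "临床表现", "object": {"@value": "腹痛"}}]}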
--------------------------------------------------------------------------------
/CMeIE/schema.json:
--------------------------------------------------------------------------------
1 | {"subject_type": "疾病", "predicate": "预防", "object_type": "其他"}
2 | {"subject_type": "疾病", "predicate": "阶段", "object_type": "其他"}
3 | {"subject_type": "疾病", "predicate": "就诊科室", "object_type": "其他"}
4 | {"subject_type": "其他", "predicate": "同义词", "object_type": "其他"}
5 | {"subject_type": "疾病", "predicate": "辅助治疗", "object_type": "其他治疗"}
6 | {"subject_type": "疾病", "predicate": "化疗", "object_type": "其他治疗"}
7 | {"subject_type": "疾病", "predicate": "放射治疗", "object_type": "其他治疗"}
8 | {"subject_type": "其他治疗", "predicate": "同义词", "object_type": "其他治疗"}
9 | {"subject_type": "疾病", "predicate": "手术治疗", "object_type": "手术治疗"}
10 | {"subject_type": "手术治疗", "predicate": "同义词", "object_type": "手术治疗"}
11 | {"subject_type": "疾病", "predicate": "实验室检查", "object_type": "检查"}
12 | {"subject_type": "疾病", "predicate": "影像学检查", "object_type": "检查"}
13 | {"subject_type": "疾病", "predicate": "辅助检查", "object_type": "检查"}
14 | {"subject_type": "疾病", "predicate": "组织学检查", "object_type": "检查"}
15 | {"subject_type": "检查", "predicate": "同义词", "object_type": "检查"}
16 | {"subject_type": "疾病", "predicate": "内窥镜检查", "object_type": "检查"}
17 | {"subject_type": "疾病", "predicate": "筛查", "object_type": "检查"}
18 | {"subject_type": "疾病", "predicate": "多发群体", "object_type": "流行病学"}
19 | {"subject_type": "疾病", "predicate": "发病率", "object_type": "流行病学"}
20 | {"subject_type": "疾病", "predicate": "发病年龄", "object_type": "流行病学"}
21 | {"subject_type": "疾病", "predicate": "多发地区", "object_type": "流行病学"}
22 | {"subject_type": "疾病", "predicate": "发病性别倾向", "object_type": "流行病学"}
23 | {"subject_type": "疾病", "predicate": "死亡率", "object_type": "流行病学"}
24 | {"subject_type": "疾病", "predicate": "多发季节", "object_type": "流行病学"}
25 | {"subject_type": "疾病", "predicate": "传播途径", "object_type": "流行病学"}
26 | {"subject_type": "流行病学", "predicate": "同义词", "object_type": "流行病学"}
27 | {"subject_type": "疾病", "predicate": "同义词", "object_type": "疾病"}
28 | {"subject_type": "疾病", "predicate": "并发症", "object_type": "疾病"}
29 | {"subject_type": "疾病", "predicate": "病理分型", "object_type": "疾病"}
30 | {"subject_type": "疾病", "predicate": "相关(导致)", "object_type": "疾病"}
31 | {"subject_type": "疾病", "predicate": "鉴别诊断", "object_type": "疾病"}
32 | {"subject_type": "疾病", "predicate": "相关(转化)", "object_type": "疾病"}
33 | {"subject_type": "疾病", "predicate": "相关(症状)", "object_type": "疾病"}
34 | {"subject_type": "疾病", "predicate": "临床表现", "object_type": "症状"}
35 | {"subject_type": "疾病", "predicate": "治疗后症状", "object_type": "症状"}
36 | {"subject_type": "疾病", "predicate": "侵及周围组织转移的症状", "object_type": "症状"}
37 | {"subject_type": "症状", "predicate": "同义词", "object_type": "症状"}
38 | {"subject_type": "疾病", "predicate": "病因", "object_type": "社会学"}
39 | {"subject_type": "疾病", "predicate": "高危因素", "object_type": "社会学"}
40 | {"subject_type": "疾病", "predicate": "风险评估因素", "object_type": "社会学"}
41 | {"subject_type": "疾病", "predicate": "病史", "object_type": "社会学"}
42 | {"subject_type": "疾病", "predicate": "遗传因素", "object_type": "社会学"}
43 | {"subject_type": "社会学", "predicate": "同义词", "object_type": "社会学"}
44 | {"subject_type": "疾病", "predicate": "发病机制", "object_type": "社会学"}
45 | {"subject_type": "疾病", "predicate": "病理生理", "object_type": "社会学"}
46 | {"subject_type": "疾病", "predicate": "药物治疗", "object_type": "药物"}
47 | {"subject_type": "药物", "predicate": "同义词", "object_type": "药物"}
48 | {"subject_type": "疾病", "predicate": "发病部位", "object_type": "部位"}
49 | {"subject_type": "疾病", "predicate": "转移部位", "object_type": "部位"}
50 | {"subject_type": "疾病", "predicate": "外侵部位", "object_type": "部位"}
51 | {"subject_type": "部位", "predicate": "同义词", "object_type": "部位"}
52 | {"subject_type": "疾病", "predicate": "预后状况", "object_type": "预后"}
53 | {"subject_type": "疾病", "predicate": "预后生存率", "object_type": "预后"}
54 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2022 Robin-WZQ
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # CBLUE_CMeIE_Model
2 |
3 | ## Mission - Task Description
4 | > Entity and relation extraction is an important subtask of information extraction, and in recent years many researchers have applied a variety of techniques to it. Bringing these techniques to the medical domain, unstructured and semi-structured medical text can be mined to build a medical knowledge graph that serves downstream tasks. Unstructured medical text, e.g. the paragraphs of medical textbooks, the topics under each disease in clinical practice, and the chief complaint, history of present illness and differential diagnosis in electronic medical records, consists of Chinese natural-language sentences or collections of sentences. Entity and relation extraction is the process of finding the medical entities in such text and determining the relation facts between entity pairs.
5 |
6 | For a more detailed description of the task, see the challenge page: [CBLUE benchmark](https://tianchi.aliyun.com/dataset/dataDetail?dataId=95414)
7 |
8 |
9 | ## Dataset
10 |
11 | The CMeIE dataset from CBLUE (Chinese Biomedical Language Understanding Evaluation), which is also the medical entity and relation extraction dataset used in CHIP2020/2021.
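Each line of the data files is one JSON record holding the raw text and its SPO triples. An illustrative (invented) record in the dataset format:

```json
{"text": "胃癌常见症状为腹痛。", "spo_list": [{"subject": "胃癌", "predicate": "临床表现", "object": {"@value": "腹痛"}}]}
```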
12 |
13 | ## Environment
14 | - python = 3.8.3
15 | - pytorch = 1.10.2+cu113
16 | - pytorch_pretrained_bert
17 | - tqdm
18 |
19 | ## Usage
20 |
21 | 0. Clone this repository:
22 | ```Shell
23 | git clone https://github.com/Robin-WZQ/CBLUE_CMeIE_model.git
24 | cd CBLUE_CMeIE_model
25 | ```
26 |
27 | 1. Install the dependencies:
28 | ```Shell
29 | pip install -r requirements.txt
30 | ```
31 |
32 | 2. Download the pretrained model
33 |
34 | https://drive.google.com/file/d/1eHM3l4fMo6DsQYGmey7UZGiTmQquHw25/view
35 |
36 | P.S. I used RoBERTa-base (Large was too big for my machine to handle -_-); other models can be used as well.
37 |
38 | 3. Make sure the directory layout looks like this:
39 |
40 | ```
41 | | chinese_roberta_wwm_ext_pytorch # Chinese pretrained model (path must match BERT_PATH in roberta_base.py)
42 | ----| bert_config.json
43 | ----| pytorch_model.bin
44 | ----| vocab.txt
45 | | CMeIE
46 | ----| CMeIE_train.json # training set
47 | ----| CMeIE_dev.json # dev set
48 | ----| CMeIE_test.json # test set
49 | ----| README.txt # data description
50 | ----| schema.json # relation constraints
51 | | record
52 | draw.py # plotting script
53 | roberta_base.py # main script
54 | README.md # this file
55 | requirements.txt # dependency list
56 | ```
57 |
58 | 4. Run:
59 | ```Shell
60 | python roberta_base.py
61 | ```
62 | 5. Generated files
63 | - roberta.pth # trained model weights
64 | - dev_pred.json # dev-set predictions
65 | - RE_pred.json # test-set predictions (rename to CMeIE_test.jsonl and zip it before submitting)
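
For example, the submission archive might be prepared like this (illustrative commands; adjust paths as needed):
```Shell
cp CMeIE/RE_pred.json CMeIE_test.jsonl
zip CMeIE_test.zip CMeIE_test.jsonl
```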
66 |
67 | ## Results
68 |
69 | (results figure)
70 |
71 |
72 | The results are somewhat underwhelming, but they can serve as a baseline.
73 |
74 | The model follows the paper "A Novel Cascade Binary Tagging Framework for Relational Triple Extraction" (CasRel).
75 |
76 | ## References
77 |
78 | https://github.com/CBLUEbenchmark/CBLUE
79 |
80 | https://github.com/ymcui/Chinese-BERT-wwm#%E6%A8%A1%E5%9E%8B%E5%AF%B9%E6%AF%94
81 |
82 | https://zhuanlan.zhihu.com/p/136277427
83 |
--------------------------------------------------------------------------------
/record/log.txt:
--------------------------------------------------------------------------------
1 | epoch: 2
2 | loss: 0.43698995719353356
3 | f1, pre, rec: 1.8843037497644266e-14 9.999999999e-11 9.422406482615572e-15
4 | epoch: 3
5 | loss: 0.2870555300017198
6 | f1, pre, rec: 1.8844812965230965e-14 1.0 9.422406482615572e-15
7 | epoch: 4
8 | loss: 0.2508908245464166
9 | f1, pre, rec: 1.8841262364578074e-14 4.9999999997500004e-11 9.422406482615572e-15
10 | epoch: 5
11 | loss: 0.2221885260939598
12 | f1, pre, rec: 0.00223546944860279 0.09756097561048979 0.001130688777923291
13 | epoch: 6
14 | loss: 0.185391767496864
15 | f1, pre, rec: 0.06620736010756063 0.30566330488756355 0.03712428154151477
16 | epoch: 7
17 | loss: 0.15719785166283448
18 | f1, pre, rec: 0.20451020164659106 0.2787445113026626 0.16150004711204033
19 | epoch: 8
20 | loss: 0.13639551517864068
21 | f1, pre, rec: 0.23789380642718305 0.360505166475328 0.1775181381324868
22 | epoch: 9
23 | loss: 0.12217968091368675
24 | f1, pre, rec: 0.22514803204459458 0.43185462319616696 0.15226608875907705
25 | epoch: 10
26 | loss: 0.10952602570255597
27 | f1, pre, rec: 0.2854753941710549 0.38982221497309755 0.22519551493452158
28 | epoch: 11
29 | loss: 0.10297784574329853
30 | f1, pre, rec: 0.31555765298097305 0.35575786237882057 0.28352021106191194
31 | epoch: 12
32 | loss: 0.09667525255431732
33 | f1, pre, rec: 0.3251311448813009 0.38156892612338944 0.2832375388674335
34 | epoch: 13
35 | loss: 0.09070725123087565
36 | f1, pre, rec: 0.3344887348353625 0.3933256909947855 0.29096391218317824
37 | epoch: 14
38 | loss: 0.08317714810371399
39 | f1, pre, rec: 0.3523270509368718 0.4226206169259238 0.30208235183266463
40 | epoch: 15
41 | loss: 0.0782144309207797
42 | f1, pre, rec: 0.36543253826985256 0.34623492705962267 0.3868840101762048
43 | epoch: 16
44 | loss: 0.07398854872832696
45 | f1, pre, rec: 0.37905642171551085 0.411103767349643 0.35164420993122253
46 | epoch: 17
47 | loss: 0.06968923959881067
48 | f1, pre, rec: 0.37655677655678266 0.39089434191848116 0.3632337699048397
49 | epoch: 18
50 | loss: 0.0648442996914188
51 | f1, pre, rec: 0.3963028351853838 0.44141122035859526 0.3595590313766196
52 | epoch: 19
53 | loss: 0.06146668403098981
54 | f1, pre, rec: 0.3768615902398033 0.34042714904614013 0.42202958635636084
55 | epoch: 20
56 | loss: 0.059042305971185365
57 | f1, pre, rec: 0.4011946241911458 0.4252400548696906 0.37972298124941695
58 | epoch: 21
59 | loss: 0.054907060861587524
60 | f1, pre, rec: 0.4018789665683934 0.42785699084912293 0.37887496466598153
61 | epoch: 22
62 | loss: 0.05235395323485136
63 | f1, pre, rec: 0.40784526883182126 0.392707606420103 0.42419673984736245
64 | epoch: 23
65 | loss: 0.05071420204515258
66 | f1, pre, rec: 0.4162441359048532 0.41868446139180726 0.4138320927164853
67 | epoch: 24
68 | loss: 0.04806953402236104
69 | f1, pre, rec: 0.43165600110492675 0.4220381706877978 0.4417224159050274
70 | epoch: 25
71 | loss: 0.04542926991979281
72 | f1, pre, rec: 0.4318417333961429 0.43176038428935265 0.4319231131631072
73 | epoch: 26
74 | loss: 0.04372004958490531
75 | f1, pre, rec: 0.43017329255861897 0.4231926337861298 0.43738810892302427
76 | epoch: 27
77 | loss: 0.04213489151249329
78 | f1, pre, rec: 0.4274379555309583 0.4187076366922782 0.43654009233958885
79 | epoch: 28
80 | loss: 0.040034455358982084
81 | f1, pre, rec: 0.44401858860729704 0.45165692007797803 0.436634316404415
82 | epoch: 29
83 | loss: 0.03817422329137723
84 | f1, pre, rec: 0.44188586511332084 0.42733914268110706 0.4574578347309954
85 | epoch: 30
86 | loss: 0.03523984232435093
87 | f1, pre, rec: 0.45134172942890583 0.4492749571929380 0.45684738278928729
88 | epoch: 31
89 | loss: 0.03201829484930484
90 | f1, pre, rec: 0.45714828355729115 0.4607580398162379 0.45359464807312305
91 | epoch: 32
92 | loss: 0.027590999432529014
93 | f1, pre, rec: 0.4652450276809569 0.5101742551995558 0.42758880618110406
94 | epoch: 33
95 | loss: 0.02752583069416384
96 | f1, pre, rec: 0.4634157788520255 0.46166074434262705 0.46518420804674016
97 | epoch: 34
98 | loss: 0.025788033933689197
99 | f1, pre, rec: 0.4717325227963575 0.4682510211659908 0.47526618298313883
100 | epoch: 35
101 | loss: 0.024115538876503705
102 | f1, pre, rec: 0.4683687943262462 0.47005789124039604 0.4666917930839587
103 | epoch: 36
104 | loss: 0.022993005256478984
105 | f1, pre, rec: 0.4663262825188479 0.4818127009646354 0.45180439084142604
106 | epoch: 37
107 | loss: 0.022681940294181305
108 | f1, pre, rec: 0.4721851781291725 0.4774150052971253 0.4670686893432633
109 | epoch: 38
110 | loss: 0.021437988535811504
111 | f1, pre, rec: 0.4659728550189491 0.47450673959758277 0.45774050692547386
112 | epoch: 39
113 | loss: 0.02119144581258297
114 | f1, pre, rec: 0.4752872515046555 0.4604647053626693 0.491095825873933
115 | epoch: 40
116 | loss: 0.01966104614858826
117 | f1, pre, rec: 0.4748313678704219 0.46279069767442343 0.48751531141053905
118 | epoch: 41
119 | loss: 0.019432777492329478
120 | f1, pre, rec: 0.48374043196334177 0.5012123661345776 0.46744558560256794
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | numpy
2 | tqdm
3 | matplotlib
4 | pytorch_pretrained_bert
5 | torch
--------------------------------------------------------------------------------
/roberta_base.py:
--------------------------------------------------------------------------------
1 | import json
2 | import os
3 | import time
4 | import unicodedata
5 |
6 | import numpy as np
7 | import torch
8 | import torch.nn as nn
9 | import torch.nn.functional as F
10 | from pytorch_pretrained_bert import BertModel, BertTokenizer
11 | from torch.utils.data import DataLoader, Dataset
12 | from tqdm import tqdm
13 |
14 | train_l = []
15 | val_f1_r = []
16 |
17 | def record():
18 | f1 = open("record/val_f1_1e4.txt","w")
19 | f2 = open("record/loss_1e4.txt",'w')
20 |
21 | for v in val_f1_r: # one entry per epoch
22 | f1.write(str(v)+"\n")
23 | for t in train_l: # one entry per 300 batches
24 | f2.write(str(t)+"\n")
25 | f1.close()
26 | f2.close()
27 |
28 | os.environ['CUDA_LAUNCH_BLOCKING'] = "1"
29 | BERT_PATH = "chinese_roberta_wwm_ext_pytorch/"
30 | maxlen = 256
31 |
32 | # Read a JSON-lines file into records of input text plus (subject, predicate, object) triples
33 | def load_data(filename):
34 | D = []
35 | with open(filename, encoding='utf-8') as f:
36 | for l in tqdm(f):
37 | l = json.loads(l)
38 | d = {'text': l['text'], 'spo_list': []}
39 | for spo in l['spo_list']:
40 | for k, v in spo['object'].items():
41 | d['spo_list'].append(
42 | (spo['subject'], spo['predicate'], v)
43 | )
44 | D.append(d)
45 | return D
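# A parsed record then looks like (illustrative example):
# {'text': '胃癌常见症状为腹痛。', 'spo_list': [('胃癌', '临床表现', '腹痛')]}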
46 |
47 | # Load the datasets
48 | train_data = load_data('CMeIE/CMeIE_train.json')
49 | valid_data = load_data('CMeIE/CMeIE_dev.json')
50 |
51 | def search(pattern, sequence):
52 | """从sequence中寻找子串pattern
53 | 如果找到,返回第一个下标;否则返回-1。
54 | """
55 | n = len(pattern)
56 | for i in range(len(sequence)):
57 | if sequence[i:i + n] == pattern:
58 | return i
59 | return -1
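# Works on strings as well as token-id lists, e.g. search("腹痛", "上腹痛三天") == 1.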
60 |
61 | train_data_new = [] # Build a filtered training set: drop samples whose entities end beyond position 250 (only a few are removed)
62 | for data in tqdm(train_data):
63 | flag = 1
64 | for s, p, o in data['spo_list']:
65 | s_begin = search(s, data['text'])
66 | o_begin = search(o, data['text'])
67 | if s_begin == -1 or o_begin == -1 or s_begin + len(s) > 250 or o_begin + len(o) > 250:
68 | flag = 0
69 | break
70 | if flag == 1:
71 | train_data_new.append(data)
72 | print(len(train_data_new))
73 |
74 | # Read the schema
75 | with open('CMeIE/schema.json', encoding='utf-8') as f:
76 | id2predicate, predicate2id, n = {}, {}, 0
77 | predicate2type = {}
78 | for l in f:
79 | l = json.loads(l)
80 | predicate2type[l['predicate']] = (l['subject_type'], l['object_type'])
81 | key = l['predicate']
82 | if key not in predicate2id:
83 | id2predicate[n] = key
84 | predicate2id[key] = n
85 | n += 1
86 | print(len(predicate2id))
87 |
88 |
89 | class OurTokenizer(BertTokenizer):
90 | def tokenize(self, text):
91 | R = []
92 | for c in text:
93 | if c in self.vocab:
94 | R.append(c)
95 | elif self._is_whitespace(c):
96 | R.append('[unused1]')
97 | else:
98 | R.append('[UNK]')
99 | return R
100 |
101 | def _is_whitespace(self, char):
102 | if char == " " or char == "\t" or char == "\n" or char == "\r":
103 | return True
104 | cat = unicodedata.category(char)
105 | if cat == "Zs":
106 | return True
107 | return False
108 |
109 | # Initialize the tokenizer
110 | tokenizer = OurTokenizer(vocab_file=BERT_PATH + "vocab.txt")
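# Char-by-char tokenization keeps token positions aligned with the raw text,
# e.g. tokenize("腹 痛") -> ['腹', '[unused1]', '痛'] (whitespace maps to [unused1]).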
111 |
112 | class TorchDataset(Dataset):
113 | def __init__(self, data):
114 | self.data = data
115 |
116 | def __getitem__(self, i):
117 | t = self.data[i]
118 | x = tokenizer.tokenize(t['text'])
119 | x = ["[CLS]"] + x + ["[SEP]"]
120 | token_ids = tokenizer.convert_tokens_to_ids(x)
121 | seg_ids = [0] * len(token_ids)
122 | assert len(token_ids) == len(t['text'])+2
123 | spoes = {}
124 | for s, p, o in t['spo_list']:
125 | s = tokenizer.tokenize(s)
126 | s = tokenizer.convert_tokens_to_ids(s)
127 | p = predicate2id[p] # predicate id
128 | o = tokenizer.tokenize(o)
129 | o = tokenizer.convert_tokens_to_ids(o)
130 | s_idx = search(s, token_ids) # start index of the subject
131 | o_idx = search(o, token_ids) # start index of the object
132 |
133 | if s_idx != -1 and o_idx != -1:
134 | s = (s_idx, s_idx + len(s) - 1)
135 | o = (o_idx, o_idx + len(o) - 1, p) # predict object span and predicate together
136 | if s not in spoes:
137 | spoes[s] = [] # one subject can have several objects
138 | spoes[s].append(o)
139 |
140 | if spoes:
141 | sub_labels = np.zeros((len(token_ids), 2))
142 | for s in spoes:
143 | sub_labels[s[0], 0] = 1
144 | sub_labels[s[1], 1] = 1
145 | # randomly sample one subject for this training example
146 | start, end = np.array(list(spoes.keys())).T
147 | start = np.random.choice(start)
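# pair it with the nearest subject end position that is >= the sampled start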
148 | end = sorted(end[end >= start])[0]
149 | sub_ids = (start, end)
150 | obj_labels = np.zeros((len(token_ids), len(predicate2id), 2))
151 | for o in spoes.get(sub_ids, []):
152 | obj_labels[o[0], o[2], 0] = 1
153 | obj_labels[o[1], o[2], 1] = 1
154 |
155 | token_ids = self.sequence_padding(token_ids, maxlen=maxlen)
156 | seg_ids = self.sequence_padding(seg_ids, maxlen=maxlen)
157 | sub_labels = self.sequence_padding(sub_labels, maxlen=maxlen, padding=np.zeros(2))
158 | sub_ids = np.array(sub_ids)
159 | obj_labels = self.sequence_padding(obj_labels, maxlen=maxlen,
160 | padding=np.zeros((len(predicate2id), 2)))
161 |
162 | return (torch.LongTensor(token_ids), torch.LongTensor(seg_ids), torch.LongTensor(sub_ids),
163 | torch.LongTensor(sub_labels), torch.LongTensor(obj_labels) )
164 |
165 | def __len__(self):
166 | data_len = len(self.data)
167 | return data_len
168 |
169 | def sequence_padding(self, x, maxlen, padding=0):
170 | output = np.concatenate([x, [padding]*(maxlen-len(x))]) if len(x) < maxlen else np.array(x[:maxlen])
171 | return output
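# ---------------------------------------------------------------------------
# NOTE: the original lines between sequence_padding and extract_spoes (the
# DataLoader construction, the CasRel-style network, and the device/optimizer
# setup) did not survive in this copy of the file. The sketch below is a
# reconstruction inferred from how these objects are used elsewhere in this
# file -- model(token_ids, seg_ids) returns subject predictions,
# model(token_ids, seg_ids, sub_ids) returns (sub_preds, obj_preds), and the
# globals DEVICE, net, optimizer and train_loader exist -- it is not the
# author's original code. Batch size and learning rate are assumptions.

class CasRelModel(nn.Module):
    """Cascade binary tagging (CasRel): tag subject start/end first, then tag
    object start/end per predicate, conditioned on one sampled subject."""

    def __init__(self, bert_path, num_predicates):
        super().__init__()
        self.bert = BertModel.from_pretrained(bert_path)
        hidden = 768  # RoBERTa-base hidden size
        self.num_predicates = num_predicates
        self.sub_output = nn.Linear(hidden, 2)                   # subject start/end
        self.obj_output = nn.Linear(hidden, num_predicates * 2)  # object start/end per predicate

    def forward(self, token_ids, seg_ids, sub_ids=None):
        mask = (token_ids > 0).long()
        seq_out, _ = self.bert(token_ids, seg_ids, attention_mask=mask,
                               output_all_encoded_layers=False)  # [B, L, H]
        sub_preds = torch.sigmoid(self.sub_output(seq_out))      # [B, L, 2]
        if sub_ids is None:
            return sub_preds
        # condition every position on the sampled subject span (mean of its
        # start and end vectors), as in the CasRel paper
        b = torch.arange(token_ids.size(0), device=token_ids.device)
        sub_vec = (seq_out[b, sub_ids[:, 0]] + seq_out[b, sub_ids[:, 1]]) / 2
        cond = seq_out + sub_vec.unsqueeze(1)
        obj_preds = torch.sigmoid(self.obj_output(cond))
        obj_preds = obj_preds.view(token_ids.size(0), -1, self.num_predicates, 2)
        return sub_preds, obj_preds  # [B, L, 2], [B, L, num_predicates, 2]

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
net = CasRelModel(BERT_PATH, len(predicate2id)).to(DEVICE)
# lr guessed from the record/*_1e4.txt file names
optimizer = torch.optim.Adam(net.parameters(), lr=1e-4)
train_loader = DataLoader(TorchDataset(train_data_new), batch_size=16, shuffle=True)
# ---------------------------------------------------------------------------

def extract_spoes(text, model, device):
    """Extract (subject, predicate, object) triples from one text.
    (Signature restored from the call sites below.)"""
    if len(text) > 254: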
247 | text = text[:254]
248 | tokens = tokenizer.tokenize(text)
249 | tokens = ["[CLS]"] + tokens + ["[SEP]"]
250 | token_ids = tokenizer.convert_tokens_to_ids(tokens)
251 | assert len(token_ids) == len(text) + 2
252 | seg_ids = [0] * len(token_ids)
253 |
254 | sub_preds = model(torch.LongTensor([token_ids]).to(device),
255 | torch.LongTensor([seg_ids]).to(device))
256 | sub_preds = sub_preds.detach().cpu().numpy() # [1, maxlen, 2]
257 | # print(sub_preds[0,])
258 | start = np.where(sub_preds[0, :, 0] > 0.2)[0]
259 | end = np.where(sub_preds[0, :, 1] > 0.2)[0]
260 | # print(start, end)
261 | tmp_print = []
262 | subjects = []
263 | for i in start:
264 | j = end[end>=i]
265 | if len(j) > 0:
266 | j = j[0]
267 | subjects.append((i, j))
268 | tmp_print.append(text[i-1: j])
269 |
270 | if subjects:
271 | spoes = []
272 | token_ids = np.repeat([token_ids], len(subjects), 0) # [len_subjects, seqlen]
273 | seg_ids = np.repeat([seg_ids], len(subjects), 0)
274 | subjects = np.array(subjects) # [len_subjects, 2]
275 | # feed the subjects in to extract objects and predicates
276 | _, object_preds = model(torch.LongTensor(token_ids).to(device),
277 | torch.LongTensor(seg_ids).to(device),
278 | torch.LongTensor(subjects).to(device))
279 | object_preds = object_preds.detach().cpu().numpy()
280 | # print(object_preds.shape)
281 | for sub, obj_pred in zip(subjects, object_preds):
282 | # obj_pred [maxlen, 55, 2]
283 | start = np.where(obj_pred[:, :, 0] > 0.2)
284 | end = np.where(obj_pred[:, :, 1] > 0.2)
285 | for _start, predicate1 in zip(*start):
286 | for _end, predicate2 in zip(*end):
287 | if _start <= _end and predicate1 == predicate2:
288 | spoes.append(
289 | ((sub[0]-1, sub[1]-1), predicate1, (_start-1, _end-1))
290 | )
291 | break
292 | return [(text[s[0]:s[1]+1], id2predicate[p], text[o[0]:o[1]+1])
293 | for s, p, o in spoes]
294 | else:
295 | return []
296 |
297 | def evaluate(data, model, device):
298 | """评估函数,计算f1、precision、recall
299 | """
300 | X, Y, Z = 1e-10, 1e-10, 1e-10
301 | f = open('CMeIE/dev_pred.json', 'w', encoding='utf-8')
302 | pbar = tqdm()
303 | for d in data:
304 | R = extract_spoes(d['text'], model, device)
305 | T = d['spo_list']
306 | # print(R, T)
307 | R = set(R)
308 | T = set(T)
309 | X += len(R & T)
310 | Y += len(R)
311 | Z += len(T)
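# running micro-average: X = correctly predicted triples, Y = predicted, Z = gold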
312 | f1, precision, recall = 2 * X / (Y + Z), X / Y, X / Z
313 | pbar.update()
314 | pbar.set_description(
315 | 'f1: %.5f, precision: %.5f, recall: %.5f' % (f1, precision, recall)
316 | )
317 | s = json.dumps({
318 | 'text': d['text'],
319 | 'spo_list': list(T),
320 | 'spo_list_pred': list(R),
321 | 'new': list(R - T),
322 | 'lack': list(T - R),
323 | }, ensure_ascii=False, indent=4)
324 | f.write(s + '\n')
325 | pbar.close()
326 | f.close()
327 |
328 | return f1, precision, recall
329 |
330 | def train(model, train_loader, optimizer, epoches, device):
331 | f1_max = 0.5428
332 | for _ in range(epoches):
333 | print('epoch: ', _ + 1)
334 | start = time.time()
335 | train_loss_sum = 0.0
336 | for batch_idx, x in tqdm(enumerate(train_loader)):
337 | token_ids, seg_ids, sub_ids = x[0].to(device), x[1].to(device), x[2].to(device)
338 | mask = (token_ids > 0).float()
339 | mask = mask.to(device) # zero-mask
340 | sub_labels, obj_labels = x[3].float().to(device), x[4].float().to(device)
341 | sub_preds, obj_preds = model(token_ids, seg_ids, sub_ids)
342 | # (batch_size, maxlen, 2), (batch_size, maxlen, 44, 2)
343 |
344 | # compute the loss
345 | loss_sub = F.binary_cross_entropy(sub_preds, sub_labels, reduction='none') #[bs, ml, 2]
346 | loss_sub = torch.mean(loss_sub, 2) # (batch_size, maxlen)
347 | loss_sub = torch.sum(loss_sub * mask) / torch.sum(mask)
348 | loss_obj = F.binary_cross_entropy(obj_preds, obj_labels, reduction='none') # [bs, ml, 44, 2]
349 | loss_obj = torch.sum(torch.mean(loss_obj, 3), 2) # (bs, maxlen)
350 | loss_obj = torch.sum(loss_obj * mask) / torch.sum(mask)
351 | loss = loss_sub + loss_obj
352 |
353 | optimizer.zero_grad()
354 | loss.backward()
355 | optimizer.step()
356 | train_loss_sum += loss.cpu().item()
357 | if (batch_idx + 1) % 300 == 0:
358 | print('loss: ', train_loss_sum / (batch_idx+1), 'time: ', time.time() - start)
359 | train_l.append((_ + 1,train_loss_sum / (batch_idx+1)))
360 |
361 |
362 | with torch.no_grad():
363 | val_f1, pre, rec = evaluate(valid_data, model, device)
364 | print("f1, pre, rec: ", val_f1, pre, rec)
365 | if val_f1 > f1_max:
366 | torch.save(model.state_dict(), "CMeIE/roberta.pth")
367 | f1_max = val_f1
368 |
369 | val_f1_r.append((_ + 1,val_f1))
370 |
371 | # After a completed run, the saved weights can be restored:
372 | # net.load_state_dict(torch.load("CMeIE/roberta.pth"))
373 | train(net, train_loader, optimizer, 5, DEVICE)
374 |
375 | def combine_spoes(spoes):
376 | """
377 | """
378 | new_spoes = {}
379 | for s, p, o in spoes:
380 | p1 = p
381 | p2 = '@value'
382 | if (s, p1) in new_spoes:
383 | new_spoes[(s, p1)][p2] = o
384 | else:
385 | new_spoes[(s, p1)] = {p2: o}
386 |
387 | return [(k[0], k[1], v) for k, v in new_spoes.items()]
388 |
389 | def predict_to_file(in_file, out_file):
390 | """预测结果到文件,方便提交
391 | """
392 | fw = open(out_file, 'w', encoding='utf-8')
393 | with open(in_file, encoding='utf-8') as fr:
394 | for l in tqdm(fr):
395 | l = json.loads(l)
396 | spoes = combine_spoes(extract_spoes(l['text'], net, DEVICE))
397 | spoes = [{
398 | 'subject': spo[0],
399 | 'subject_type': predicate2type[spo[1]][0],
400 | 'predicate': spo[1],
401 | 'object': spo[2],
402 | 'object_type': {
403 | k: predicate2type[spo[1]][1]
404 | for k in spo[2]
405 | }
406 | }
407 | for spo in spoes]
408 | l['spo_list'] = spoes
409 | s = json.dumps(l, ensure_ascii=False)
410 | fw.write(s + '\n')
411 | fw.close()
412 |
413 | predict_to_file('CMeIE/CMeIE_test.json', 'CMeIE/RE_pred.json')
414 | record()
415 |
--------------------------------------------------------------------------------