├── .gitignore
├── CMeIE
│   ├── CMeIE_dev.json
│   ├── CMeIE_test.json
│   ├── CMeIE_train.json
│   ├── README.txt
│   └── schema.json
├── LICENSE
├── README.md
├── record
│   └── log.txt
├── requirements.txt
└── roberta_base.py

--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
.python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

--------------------------------------------------------------------------------
/CMeIE/README.txt:
--------------------------------------------------------------------------------
1. Manifest:
- 53_schema.jsonl: table of SPO relation constraints
- CMeIE_train.jsonl: training set
- CMeIE_dev.jsonl: validation set
- CMeIE_test.jsonl: test set. When submitting, fill in the "spo_list" field (a list) for every record. Each extracted relation must contain the three fields "subject", "predicate" and "object", and "object" must be a dict (consistent with the training data): {"@value": "some string"}.
- example_gold.jsonl: example of the gold-standard answers
- example_pred.jsonl: example of a submission
- README.txt: this file

2. The evaluation metric is strict Micro-F1.

3. The file submitted for this task must be named: CMeIE_test.jsonl
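For reference, one line of the submitted file could look like this (the triple
below is illustrative, not taken from the data):
{"text": "胃癌的常见症状包括腹痛。", "spo_list": [{"subject": "胃癌", "predicate": "临床表现", "object": {"@value": "腹痛"}}]}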

--------------------------------------------------------------------------------
/CMeIE/schema.json:
--------------------------------------------------------------------------------
{"subject_type": "疾病", "predicate": "预防", "object_type": "其他"}
{"subject_type": "疾病", "predicate": "阶段", "object_type": "其他"}
{"subject_type": "疾病", "predicate": "就诊科室", "object_type": "其他"}
{"subject_type": "其他", "predicate": "同义词", "object_type": "其他"}
{"subject_type": "疾病", "predicate": "辅助治疗", "object_type": "其他治疗"}
{"subject_type": "疾病", "predicate": "化疗", "object_type": "其他治疗"}
{"subject_type": "疾病", "predicate": "放射治疗", "object_type": "其他治疗"}
{"subject_type": "其他治疗", "predicate": "同义词", "object_type": "其他治疗"}
{"subject_type": "疾病", "predicate": "手术治疗", "object_type": "手术治疗"}
{"subject_type": "手术治疗", "predicate": "同义词", "object_type": "手术治疗"}
{"subject_type": "疾病", "predicate": "实验室检查", "object_type": "检查"}
{"subject_type": "疾病", "predicate": "影像学检查", "object_type": "检查"}
{"subject_type": "疾病", "predicate": "辅助检查", "object_type": "检查"}
{"subject_type": "疾病", "predicate": "组织学检查", "object_type": "检查"}
{"subject_type": "检查", "predicate": "同义词", "object_type": "检查"}
{"subject_type": "疾病", "predicate": "内窥镜检查", "object_type": "检查"}
{"subject_type": "疾病", "predicate": "筛查", "object_type": "检查"}
{"subject_type": "疾病", "predicate": "多发群体", "object_type": "流行病学"}
{"subject_type": "疾病", "predicate": "发病率", "object_type": "流行病学"}
{"subject_type": "疾病", "predicate": "发病年龄", "object_type": "流行病学"}
{"subject_type": "疾病", "predicate": "多发地区", "object_type": "流行病学"}
{"subject_type": "疾病", "predicate": "发病性别倾向", "object_type": "流行病学"}
{"subject_type": "疾病", "predicate": "死亡率", "object_type": "流行病学"}
{"subject_type": "疾病", "predicate": "多发季节", "object_type": "流行病学"}
{"subject_type": "疾病", "predicate": "传播途径", "object_type": "流行病学"}
{"subject_type": "流行病学", "predicate": "同义词", "object_type": "流行病学"}
{"subject_type": "疾病", "predicate": "同义词", "object_type": "疾病"}
{"subject_type": "疾病", "predicate": "并发症", "object_type": "疾病"}
{"subject_type": "疾病", "predicate": "病理分型", "object_type": "疾病"}
{"subject_type": "疾病", "predicate": "相关(导致)", "object_type": "疾病"}
{"subject_type": "疾病", "predicate": "鉴别诊断", "object_type": "疾病"}
{"subject_type": "疾病", "predicate": "相关(转化)", "object_type": "疾病"}
{"subject_type": "疾病", "predicate": "相关(症状)", "object_type": "疾病"}
{"subject_type": "疾病", "predicate": "临床表现", "object_type": "症状"}
{"subject_type": "疾病", "predicate": "治疗后症状", "object_type": "症状"}
{"subject_type": "疾病", "predicate": "侵及周围组织转移的症状", "object_type": "症状"}
{"subject_type": "症状", "predicate": "同义词", "object_type": "症状"}
{"subject_type": "疾病", "predicate": "病因", "object_type": "社会学"}
{"subject_type": "疾病", "predicate": "高危因素", "object_type": "社会学"}
{"subject_type": "疾病", "predicate": "风险评估因素", "object_type": "社会学"}
{"subject_type": "疾病", "predicate": "病史", "object_type": "社会学"}
{"subject_type": "疾病", "predicate": "遗传因素", "object_type": "社会学"}
{"subject_type": "社会学", "predicate": "同义词", "object_type": "社会学"}
{"subject_type": "疾病", "predicate": "发病机制", "object_type": "社会学"}
{"subject_type": "疾病", "predicate": "病理生理", "object_type": "社会学"}
{"subject_type": "疾病", "predicate": "药物治疗", "object_type": "药物"}
{"subject_type": "药物", "predicate": "同义词", "object_type": "药物"}
{"subject_type": "疾病", "predicate": "发病部位", "object_type": "部位"} 49 | {"subject_type": "疾病", "predicate": "转移部位", "object_type": "部位"} 50 | {"subject_type": "疾病", "predicate": "外侵部位", "object_type": "部位"} 51 | {"subject_type": "部位", "predicate": "同义词", "object_type": "部位"} 52 | {"subject_type": "疾病", "predicate": "预后状况", "object_type": "预后"} 53 | {"subject_type": "疾病", "predicate": "预后生存率", "object_type": "预后"} 54 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Robin-WZQ 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # CBLUE_CMeIE_Model 2 | 3 | ## Mission - 任务描述 4 | > 实体和关系抽取作为信息抽取的重要子任务,近些年众多学者利用多种技术在该领域开展深入研究。将这些技术应用于医学领域,抽取非结构化和半结构化的医学文本构建成医学知识图谱,可服务于下游子任务。非结构化的医学文本,如医学教材每一个自然段落,临床实践中每种疾病下的主题,电子病历数据中的主诉、现病史、鉴别诊断等,都是由中文自然语言句子或句子集合组成。实体关系抽取是从非结构化医学文本中找出医学实体,并确定实体对关系事实的过程。 5 | 6 | 关于任务更详尽的描述可参考比赛链接: [中文医疗信息处理挑战榜](https://tianchi.aliyun.com/dataset/dataDetail?dataId=95414) 7 | 8 | 9 | ## Dataset - 数据集 10 | 11 | 中文医疗信息处理挑战榜CBLUE 中CMeIE数据集,同样是 CHIP2020/2021 的医学实体关系抽取数据集。 12 | 13 | ## Enviroment - 环境 14 | - python = 3.8.3 15 | - pytorch = 1.10.2+cu113 16 | - pytorch_pretrained_bert 17 | - tqdm 18 | 19 | ## Usage - 使用方法 20 | 21 | 0. 下载本仓库: 22 | ```Shell 23 | git clone https://github.com/Robin-WZQ/CBLUE_CMeIE_model.git 24 | cd CBLUE_CMeIE_model 25 | ``` 26 | 27 | 1. 安装依赖: 28 | ```Shell 29 | pip install -r requirements.txt 30 | ``` 31 | 32 | 2. 下载预训练模型 33 | 34 | https://drive.google.com/file/d/1eHM3l4fMo6DsQYGmey7UZGiTmQquHw25/view 35 | 36 | P.S. 我使用的是Roberta-base(Large太大电脑跑不动了-_-),也可以使用别的模型。 37 | 38 | 3.保证目录如下所示: 39 | 40 | ``` 41 | | Chinese_roberta_wwm_ext_pytorch # 中文预训练模型 42 | ----| bert_config.json 43 | ----| pytorch_model.bin 44 | ----| vocab.txt 45 | | CMeIE 46 | ----| CMeIE_train.json # 训练集 47 | ----| CMeIE_dev.json # 开发集 48 | ----| CMeIE_test.json # 测试集 49 | ----| README.txt # 数据说明文件 50 | ----| schema.json # 关系约束 51 | | record 52 | draw.py # 绘图函数 53 | roberta_base.py # 主函数 54 | README.md # 说明文件 55 | requirements.txt # 配置文件 56 | ``` 57 | 58 | 4. 运行 59 | ```Shell 60 | run roberta_base.py 61 | ``` 62 | 5. 
5. Generated files:
- roberta.pth    # the trained model
- dev_pred.json  # predictions on the dev set
- RE_pred.json   # predictions on the test set (rename it to CMeIE_test.jsonl and zip it before submitting)

## Result

[pic1: result figure]

The results are a bit weak, but they can serve as a baseline.

The reference paper is "A Novel Cascade Binary Tagging Framework for Relational Triple Extraction" (CASREL).

## Reference

https://github.com/CBLUEbenchmark/CBLUE

https://github.com/ymcui/Chinese-BERT-wwm#%E6%A8%A1%E5%9E%8B%E5%AF%B9%E6%AF%94

https://zhuanlan.zhihu.com/p/136277427

--------------------------------------------------------------------------------
/record/log.txt:
--------------------------------------------------------------------------------
epoch: 2
loss: 0.43698995719353356
f1, pre, rec: 1.8843037497644266e-14 9.999999999e-11 9.422406482615572e-15
epoch: 3
loss: 0.2870555300017198
f1, pre, rec: 1.8844812965230965e-14 1.0 9.422406482615572e-15
epoch: 4
loss: 0.2508908245464166
f1, pre, rec: 1.8841262364578074e-14 4.9999999997500004e-11 9.422406482615572e-15
epoch: 5
loss: 0.2221885260939598
f1, pre, rec: 0.00223546944860279 0.09756097561048979 0.001130688777923291
epoch: 6
loss: 0.185391767496864
f1, pre, rec: 0.06620736010756063 0.30566330488756355 0.03712428154151477
epoch: 7
loss: 0.15719785166283448
f1, pre, rec: 0.20451020164659106 0.2787445113026626 0.16150004711204033
epoch: 8
loss: 0.13639551517864068
f1, pre, rec: 0.23789380642718305 0.360505166475328 0.1775181381324868
epoch: 9
loss: 0.12217968091368675
f1, pre, rec: 0.22514803204459458 0.43185462319616696 0.15226608875907705
epoch: 10
loss: 0.10952602570255597
f1, pre, rec: 0.2854753941710549 0.38982221497309755 0.22519551493452158
epoch: 11
loss: 0.10297784574329853
f1, pre, rec: 0.31555765298097305 0.35575786237882057 0.28352021106191194
epoch: 12
loss: 0.09667525255431732
f1, pre, rec: 0.3251311448813009 0.38156892612338944 0.2832375388674335
epoch: 13
loss: 0.09070725123087565
f1, pre, rec: 0.3344887348353625 0.3933256909947855 0.29096391218317824
epoch: 14
loss: 0.08317714810371399
f1, pre, rec: 0.3523270509368718 0.4226206169259238 0.30208235183266463
epoch: 15
loss: 0.0782144309207797
f1, pre, rec: 0.36543253826985256 0.34623492705962267 0.3868840101762048
epoch: 16
loss: 0.07398854872832696
f1, pre, rec: 0.37905642171551085 0.411103767349643 0.35164420993122253
epoch: 17
loss: 0.06968923959881067
f1, pre, rec: 0.37655677655678266 0.39089434191848116 0.3632337699048397
epoch: 18
loss: 0.0648442996914188
f1, pre, rec: 0.3963028351853838 0.44141122035859526 0.3595590313766196
epoch: 19
loss: 0.06146668403098981
f1, pre, rec: 0.3768615902398033 0.34042714904614013 0.42202958635636084
epoch: 20
loss: 0.059042305971185365
f1, pre, rec: 0.4011946241911458 0.4252400548696906 0.37972298124941695
epoch: 21
loss: 0.054907060861587524
f1, pre, rec: 0.4018789665683934 0.42785699084912293 0.37887496466598153
epoch: 22
loss: 0.05235395323485136
f1, pre, rec: 0.40784526883182126 0.392707606420103 0.42419673984736245
epoch: 23
loss: 0.05071420204515258
f1, pre, rec: 0.4162441359048532 0.41868446139180726 0.4138320927164853
epoch: 24
loss: 0.04806953402236104
f1, pre, rec: 0.43165600110492675 0.4220381706877978 0.4417224159050274
epoch: 25
loss: 0.04542926991979281
f1, pre, rec: 0.4318417333961429 0.43176038428935265 0.4319231131631072
epoch: 26
loss: 0.04372004958490531
f1, pre, rec: 0.43017329255861897 0.4231926337861298 0.43738810892302427
epoch: 27
loss: 0.04213489151249329
f1, pre, rec: 0.4274379555309583 0.4187076366922782 0.43654009233958885
epoch: 28
loss: 0.040034455358982084
f1, pre, rec: 0.44401858860729704 0.45165692007797803 0.436634316404415
epoch: 29
loss: 0.03817422329137723
f1, pre, rec: 0.44188586511332084 0.42733914268110706 0.4574578347309954
epoch: 30
loss: 0.03523984232435093
f1, pre, rec: 0.45134172942890583 0.4492749571929380 0.45684738278928729
epoch: 31
loss: 0.03201829484930484
f1, pre, rec: 0.45714828355729115 0.4607580398162379 0.45359464807312305
epoch: 32
loss: 0.027590999432529014
f1, pre, rec: 0.4652450276809569 0.5101742551995558 0.42758880618110406
epoch: 33
loss: 0.02752583069416384
f1, pre, rec: 0.4634157788520255 0.46166074434262705 0.46518420804674016
epoch: 34
loss: 0.025788033933689197
f1, pre, rec: 0.4717325227963575 0.4682510211659908 0.47526618298313883
epoch: 35
loss: 0.024115538876503705
f1, pre, rec: 0.4683687943262462 0.47005789124039604 0.4666917930839587
epoch: 36
loss: 0.022993005256478984
f1, pre, rec: 0.4663262825188479 0.4818127009646354 0.45180439084142604
epoch: 37
loss: 0.022681940294181305
f1, pre, rec: 0.4721851781291725 0.4774150052971253 0.4670686893432633
epoch: 38
loss: 0.021437988535811504
f1, pre, rec: 0.4659728550189491 0.47450673959758277 0.45774050692547386
epoch: 39
loss: 0.02119144581258297
f1, pre, rec: 0.4752872515046555 0.4604647053626693 0.491095825873933
epoch: 40
loss: 0.01966104614858826
f1, pre, rec: 0.4748313678704219 0.46279069767442343 0.48751531141053905
epoch: 41
loss: 0.019432777492329478
f1, pre, rec: 0.48374043196334177 0.5012123661345776 0.46744558560256794

--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
tqdm
matplotlib
pytorch_pretrained_bert
torch

--------------------------------------------------------------------------------
/roberta_base.py:
--------------------------------------------------------------------------------
import json
import os
import time
import unicodedata

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from pytorch_pretrained_bert import BertModel, BertTokenizer
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm

train_l = []    # (epoch, running training loss), appended every 300 batches
val_f1_r = []   # (epoch, dev F1), appended once per epoch

def record():
    # dump the loss and dev-F1 histories; the two lists have different
    # lengths (loss is logged every 300 batches, F1 once per epoch),
    # so they are written out independently
    with open("record/val_f1_1e4.txt", "w") as f1:
        for v in val_f1_r:
            f1.write(str(v) + "\n")
    with open("record/loss_1e4.txt", "w") as f2:
        for l in train_l:
            f2.write(str(l) + "\n")

os.environ['CUDA_LAUNCH_BLOCKING'] = "1"
BERT_PATH = "chinese_roberta_wwm_ext_pytorch/"
maxlen = 256

# Read a jsonl file: each line carries a text and its spo_list
def load_data(filename):
    D = []
    with open(filename, encoding='utf-8') as f:
        for l in tqdm(f):
            l = json.loads(l)
            d = {'text': l['text'], 'spo_list': []}
            for spo in l['spo_list']:
                for k, v in spo['object'].items():
                    d['spo_list'].append(
                        (spo['subject'], spo['predicate'], v)
                    )
            D.append(d)
    return D

# Load the datasets
train_data = load_data('CMeIE/CMeIE_train.json')
valid_data = load_data('CMeIE/CMeIE_dev.json')
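# Shape of one loaded record, for reference (the sample values are made up):
#   {'text': '胃癌的常见症状包括腹痛。',
#    'spo_list': [('胃癌', '临床表现', '腹痛')]}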
def search(pattern, sequence):
    """Find the sub-sequence `pattern` inside `sequence`.
    Returns the first start index, or -1 if not found.
    """
    n = len(pattern)
    for i in range(len(sequence)):
        if sequence[i:i + n] == pattern:
            return i
    return -1

# Build a filtered training set: drop samples whose entities end beyond
# position 250 (as it turns out, very few samples are removed)
train_data_new = []
for data in tqdm(train_data):
    flag = 1
    for s, p, o in data['spo_list']:
        s_begin = search(s, data['text'])
        o_begin = search(o, data['text'])
        if s_begin == -1 or o_begin == -1 or s_begin + len(s) > 250 or o_begin + len(o) > 250:
            flag = 0
            break
    if flag == 1:
        train_data_new.append(data)
print(len(train_data_new))

# Read the schema
with open('CMeIE/schema.json', encoding='utf-8') as f:
    id2predicate, predicate2id, n = {}, {}, 0
    predicate2type = {}
    for l in f:
        l = json.loads(l)
        predicate2type[l['predicate']] = (l['subject_type'], l['object_type'])
        key = l['predicate']
        if key not in predicate2id:
            id2predicate[n] = key
            predicate2id[key] = n
            n += 1
    print(len(predicate2id))


class OurTokenizer(BertTokenizer):
    def tokenize(self, text):
        # character-level tokenization, so token positions line up
        # with character positions in the original text
        R = []
        for c in text:
            if c in self.vocab:
                R.append(c)
            elif self._is_whitespace(c):
                R.append('[unused1]')
            else:
                R.append('[UNK]')
        return R

    def _is_whitespace(self, char):
        if char == " " or char == "\t" or char == "\n" or char == "\r":
            return True
        cat = unicodedata.category(char)
        if cat == "Zs":
            return True
        return False

# Initialize the tokenizer
tokenizer = OurTokenizer(vocab_file=BERT_PATH + "vocab.txt")

class TorchDataset(Dataset):
    def __init__(self, data):
        self.data = data

    def __getitem__(self, i):
        t = self.data[i]
        x = tokenizer.tokenize(t['text'])
        x = ["[CLS]"] + x + ["[SEP]"]
        token_ids = tokenizer.convert_tokens_to_ids(x)
        seg_ids = [0] * len(token_ids)
        assert len(token_ids) == len(t['text']) + 2
        spoes = {}
        for s, p, o in t['spo_list']:
            s = tokenizer.tokenize(s)
            s = tokenizer.convert_tokens_to_ids(s)
            p = predicate2id[p]  # relation id
            o = tokenizer.tokenize(o)
            o = tokenizer.convert_tokens_to_ids(o)
            s_idx = search(s, token_ids)  # subject start position
            o_idx = search(o, token_ids)  # object start position

            if s_idx != -1 and o_idx != -1:
                s = (s_idx, s_idx + len(s) - 1)
                o = (o_idx, o_idx + len(o) - 1, p)  # predict o and p jointly
                if s not in spoes:
                    spoes[s] = []  # one subject may have several objects
                spoes[s].append(o)

        if spoes:
            sub_labels = np.zeros((len(token_ids), 2))
            for s in spoes:
                sub_labels[s[0], 0] = 1
                sub_labels[s[1], 1] = 1
            # randomly pick one subject for the object tagger
            start, end = np.array(list(spoes.keys())).T
            start = np.random.choice(start)
            end = sorted(end[end >= start])[0]
            sub_ids = (start, end)
            obj_labels = np.zeros((len(token_ids), len(predicate2id), 2))
            for o in spoes.get(sub_ids, []):
                obj_labels[o[0], o[2], 0] = 1
                obj_labels[o[1], o[2], 1] = 1
        else:
            # degenerate sample with no locatable triple: emit all-zero
            # labels so the batch still collates instead of crashing
            sub_labels = np.zeros((len(token_ids), 2))
            sub_ids = (0, 0)
            obj_labels = np.zeros((len(token_ids), len(predicate2id), 2))

        token_ids = self.sequence_padding(token_ids, maxlen=maxlen)
        seg_ids = self.sequence_padding(seg_ids, maxlen=maxlen)
        sub_labels = self.sequence_padding(sub_labels, maxlen=maxlen, padding=np.zeros(2))
        sub_ids = np.array(sub_ids)
        obj_labels = self.sequence_padding(obj_labels, maxlen=maxlen,
                                           padding=np.zeros((len(predicate2id), 2)))

        return (torch.LongTensor(token_ids), torch.LongTensor(seg_ids), torch.LongTensor(sub_ids),
                torch.LongTensor(sub_labels), torch.LongTensor(obj_labels))

    def __len__(self):
        return len(self.data)

    def sequence_padding(self, x, maxlen, padding=0):
        # pad to maxlen, or truncate; the truncation branch was cut off
        # in this copy of the file and is reconstructed here
        output = np.concatenate([x, [padding] * (maxlen - len(x))]) \
            if len(x) < maxlen else np.array(x[:maxlen])
        return output
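# ---------------------------------------------------------------------------
# NOTE: the lines that originally stood here (the DataLoader, the CASREL-style
# model and the optimizer) were lost from this copy of the file. The block
# below is a best-effort sketch reconstructed from how train_loader, net,
# optimizer, DEVICE and the model's forward signature are used further down;
# the layer sizes, batch size and learning rate are assumptions, not the
# author's verified values.
# ---------------------------------------------------------------------------
train_loader = DataLoader(TorchDataset(train_data_new), batch_size=16, shuffle=True)

class Model(nn.Module):
    """Cascade binary tagging (CASREL): first tag subject start/end
    positions, then, conditioned on one chosen subject span, tag object
    start/end positions for every predicate."""
    def __init__(self):
        super(Model, self).__init__()
        self.bert = BertModel.from_pretrained(BERT_PATH)
        self.sub_output = nn.Linear(768, 2)                      # subject start/end scores
        self.obj_output = nn.Linear(768, len(predicate2id) * 2)  # object start/end per predicate

    def forward(self, token_ids, seg_ids, sub_ids=None):
        out, _ = self.bert(token_ids, token_type_ids=seg_ids,
                           attention_mask=(token_ids > 0).long(),
                           output_all_encoded_layers=False)      # [bs, seqlen, 768]
        sub_preds = torch.sigmoid(self.sub_output(out))          # [bs, seqlen, 2]
        if sub_ids is None:
            return sub_preds
        # fuse the chosen subject span (mean of its start/end vectors)
        # into every token representation
        idx = torch.arange(out.size(0), device=out.device)
        sub_feat = (out[idx, sub_ids[:, 0]] + out[idx, sub_ids[:, 1]]) / 2
        out = out + sub_feat.unsqueeze(1)
        obj_preds = torch.sigmoid(self.obj_output(out))
        obj_preds = obj_preds.view(out.size(0), -1, len(predicate2id), 2)  # [bs, seqlen, 44, 2]
        return sub_preds, obj_preds

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
net = Model().to(DEVICE)
optimizer = torch.optim.Adam(net.parameters(), lr=1e-4)  # lr guessed from the record file names (*_1e4)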
def extract_spoes(text, model, device):
    """Extract (subject, predicate, object) triples from `text`."""
    # truncate so that [CLS]/[SEP] still fit within maxlen
    if len(text) > 254:
        text = text[:254]
    tokens = tokenizer.tokenize(text)
    tokens = ["[CLS]"] + tokens + ["[SEP]"]
    token_ids = tokenizer.convert_tokens_to_ids(tokens)
    assert len(token_ids) == len(text) + 2
    seg_ids = [0] * len(token_ids)

    sub_preds = model(torch.LongTensor([token_ids]).to(device),
                      torch.LongTensor([seg_ids]).to(device))
    sub_preds = sub_preds.detach().cpu().numpy()  # [1, seqlen, 2]
    start = np.where(sub_preds[0, :, 0] > 0.2)[0]
    end = np.where(sub_preds[0, :, 1] > 0.2)[0]
    tmp_print = []
    subjects = []
    for i in start:
        j = end[end >= i]
        if len(j) > 0:
            j = j[0]
            subjects.append((i, j))
            tmp_print.append(text[i - 1: j])

    if subjects:
        spoes = []
        token_ids = np.repeat([token_ids], len(subjects), 0)  # [len_subjects, seqlen]
        seg_ids = np.repeat([seg_ids], len(subjects), 0)
        subjects = np.array(subjects)  # [len_subjects, 2]
        # feed the subjects in to extract objects and predicates
        _, object_preds = model(torch.LongTensor(token_ids).to(device),
                                torch.LongTensor(seg_ids).to(device),
                                torch.LongTensor(subjects).to(device))
        object_preds = object_preds.detach().cpu().numpy()
        for sub, obj_pred in zip(subjects, object_preds):
            # obj_pred: [seqlen, 44, 2]
            start = np.where(obj_pred[:, :, 0] > 0.2)
            end = np.where(obj_pred[:, :, 1] > 0.2)
            for _start, predicate1 in zip(*start):
                for _end, predicate2 in zip(*end):
                    if _start <= _end and predicate1 == predicate2:
                        spoes.append(
                            ((sub[0] - 1, sub[1] - 1), predicate1, (_start - 1, _end - 1))
                        )
                        break
        return [(text[s[0]:s[1] + 1], id2predicate[p], text[o[0]:o[1] + 1])
                for s, p, o in spoes]
    else:
        return []

def evaluate(data, model, device):
    """Evaluation: compute f1, precision and recall on `data`
    and dump per-sample predictions to dev_pred.json.
    """
    X, Y, Z = 1e-10, 1e-10, 1e-10
    f = open('CMeIE/dev_pred.json', 'w', encoding='utf-8')
    pbar = tqdm()
    for d in data:
        R = extract_spoes(d['text'], model, device)
        T = d['spo_list']
        R = set(R)
        T = set(T)
        X += len(R & T)
        Y += len(R)
        Z += len(T)
        f1, precision, recall = 2 * X / (Y + Z), X / Y, X / Z
        pbar.update()
        pbar.set_description(
            'f1: %.5f, precision: %.5f, recall: %.5f' % (f1, precision, recall)
        )
        s = json.dumps({
            'text': d['text'],
            'spo_list': list(T),
            'spo_list_pred': list(R),
            'new': list(R - T),
            'lack': list(T - R),
        }, ensure_ascii=False, indent=4)
        f.write(s + '\n')
    pbar.close()
    f.close()

    return f1, precision, recall
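# evaluate() accumulates micro-averaged counts over the whole dev set: with
# X = |pred ∩ gold|, Y = |pred| and Z = |gold| summed over all samples,
#   precision = X / Y, recall = X / Z, f1 = 2 * X / (Y + Z)
# e.g. (made-up counts) X=40, Y=80, Z=100 gives precision 0.5, recall 0.4
# and f1 = 80 / 180 ≈ 0.444.
# Each record written to dev_pred.json pairs the gold triples ("spo_list")
# with the predictions ("spo_list_pred"); "new" holds predicted triples
# absent from the gold set and "lack" holds gold triples the model missed.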
def train(model, train_loader, optimizer, epochs, device):
    f1_max = 0.5428  # only checkpoints that beat this dev F1 are saved
    for _ in range(epochs):
        print('epoch: ', _ + 1)
        start = time.time()
        train_loss_sum = 0.0
        for batch_idx, x in tqdm(enumerate(train_loader)):
            token_ids, seg_ids, sub_ids = x[0].to(device), x[1].to(device), x[2].to(device)
            mask = (token_ids > 0).float()
            mask = mask.to(device)  # zero-mask over padding positions
            sub_labels, obj_labels = x[3].float().to(device), x[4].float().to(device)
            sub_preds, obj_preds = model(token_ids, seg_ids, sub_ids)
            # (batch_size, maxlen, 2), (batch_size, maxlen, 44, 2)

            # compute the loss, ignoring padded positions
            loss_sub = F.binary_cross_entropy(sub_preds, sub_labels, reduction='none')  # [bs, ml, 2]
            loss_sub = torch.mean(loss_sub, 2)  # (batch_size, maxlen)
            loss_sub = torch.sum(loss_sub * mask) / torch.sum(mask)
            loss_obj = F.binary_cross_entropy(obj_preds, obj_labels, reduction='none')  # [bs, ml, 44, 2]
            loss_obj = torch.sum(torch.mean(loss_obj, 3), 2)  # (bs, maxlen)
            loss_obj = torch.sum(loss_obj * mask) / torch.sum(mask)
            loss = loss_sub + loss_obj

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            train_loss_sum += loss.cpu().item()
            if (batch_idx + 1) % 300 == 0:
                print('loss: ', train_loss_sum / (batch_idx + 1), 'time: ', time.time() - start)
                train_l.append((_ + 1, train_loss_sum / (batch_idx + 1)))


        with torch.no_grad():
            val_f1, pre, rec = evaluate(valid_data, model, device)
        print("f1, pre, rec: ", val_f1, pre, rec)
        if val_f1 > f1_max:
            torch.save(model.state_dict(), "CMeIE/roberta.pth")
            f1_max = val_f1

        val_f1_r.append((_ + 1, val_f1))

# to resume after a finished run, restore the saved weights first:
# net.load_state_dict(torch.load("CMeIE/roberta.pth"))
train(net, train_loader, optimizer, 5, DEVICE)

def combine_spoes(spoes):
    """Group the extracted triples and wrap each object value into the
    {"@value": ...} dict that the submission format expects.
    """
    new_spoes = {}
    for s, p, o in spoes:
        p1 = p
        p2 = '@value'
        if (s, p1) in new_spoes:
            new_spoes[(s, p1)][p2] = o
        else:
            new_spoes[(s, p1)] = {p2: o}

    return [(k[0], k[1], v) for k, v in new_spoes.items()]

def predict_to_file(in_file, out_file):
    """Write predictions to a file in the submission format."""
    fw = open(out_file, 'w', encoding='utf-8')
    with open(in_file, encoding='utf-8') as fr:
        for l in tqdm(fr):
            l = json.loads(l)
            spoes = combine_spoes(extract_spoes(l['text'], net, DEVICE))
            spoes = [{
                'subject': spo[0],
                'subject_type': predicate2type[spo[1]][0],
                'predicate': spo[1],
                'object': spo[2],
                'object_type': {
                    k: predicate2type[spo[1]][1]
                    for k in spo[2]
                }
            }
                for spo in spoes]
            l['spo_list'] = spoes
            s = json.dumps(l, ensure_ascii=False)
            fw.write(s + '\n')
    fw.close()

predict_to_file('CMeIE/CMeIE_test.json', 'CMeIE/RE_pred.json')
record()
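# Quick illustration of combine_spoes (the triple below is made up):
#   combine_spoes([('胃癌', '临床表现', '腹痛')])
#   -> [('胃癌', '临床表现', {'@value': '腹痛'})]
# predict_to_file then attaches subject_type / object_type looked up in
# predicate2type, matching the submission format described in CMeIE/README.txt.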