├── README.md
├── competition_predict.py
├── convert_test_data.py
├── data
│   ├── mid_data
│   │   ├── crf_ent2id.json
│   │   ├── mrc_ent2id.json
│   │   └── span_ent2id.json
│   └── raw_data
│       ├── dev.json
│       ├── pseudo.json
│       ├── stack.json
│       ├── test.json
│       └── train.json
├── main.py
├── md_files
│   ├── 1.png
│   ├── 10.png
│   ├── 11.png
│   ├── 12.png
│   ├── 13.png
│   ├── 2.png
│   ├── 3.png
│   ├── 4.png
│   ├── 5.png
│   ├── 6.png
│   ├── 7.png
│   ├── 8.png
│   └── 9.png
├── run.sh
└── src
    ├── preprocess
    │   ├── __pycache__
    │   │   └── processor.cpython-36.pyc
    │   ├── convert_raw_data.py
    │   └── processor.py
    └── utils
        ├── __pycache__
        │   ├── attack_train_utils.cpython-36.pyc
        │   ├── dataset_utils.cpython-36.pyc
        │   ├── evaluator.cpython-36.pyc
        │   ├── functions_utils.cpython-36.pyc
        │   ├── model_utils.cpython-36.pyc
        │   ├── options.cpython-36.pyc
        │   └── trainer.cpython-36.pyc
        ├── attack_train_utils.py
        ├── dataset_utils.py
        ├── evaluator.py
        ├── functions_utils.py
        ├── model_utils.py
        ├── options.py
        └── trainer.py


/README.md:
--------------------------------------------------------------------------------
# Chinese-DeepNER-Pytorch

## Open-source champion solution for the Tianchi TCM drug package-insert entity recognition challenge

### Contributors:
zxx飞翔的鱼: https://github.com/z814081807

我是蛋糕王: https://github.com/WuHuRestaurant

数青峰: https://github.com/zchaizju

### After the official dataset is released, DeepNER will be optimized and upgraded. The upgrade will cover the complete data processing, training, validation, testing and deployment pipeline, with detailed code comments, model descriptions and experimental results. The goal is a more general, out-of-the-box, pretraining-based Chinese NER solution. Stars welcome!

(The code framework is built on **pytorch and transformers**. It is highly **reusable, decoupled and readable**, and is easy to adapt to other NLP tasks.)

## Environment

```python
python3.7
pytorch==1.6.0 +
transformers==2.10.0
pytorch-crf==0.7.2
```

## Project layout

```shell
DeepNER
│
├── data                                    # data folder
│   ├── mid_data                            # intermediate data
│   │   ├── crf_ent2id.json                 # schema for the CRF model
│   │   ├── span_ent2id.json                # schema for the Span model
│   │   └── mrc_ent2id.json                 # schema for the MRC model
│   │
│   └── raw_data                            # converted data
│       ├── dev.json                        # converted validation set
│       ├── test.json                       # converted preliminary-round test set
│       ├── pseudo.json                     # converted semi-supervised (pseudo-labeled) data
│       ├── stack.json                      # converted full dataset
│       └── train.json                      # converted training set
│
├── out                                     # trained models
│   ├── ...
│   └── ...
│
├── src
│   ├── preprocess
│   │   ├── convert_raw_data.py             # processes and converts the raw data
│   │   └── processor.py                    # converts data into BERT model inputs
│   └── utils
│       ├── attack_train_utils.py           # adversarial training: FGM / PGD
│       ├── dataset_utils.py                # torch Dataset
│       ├── evaluator.py                    # model evaluation
│       ├── functions_utils.py              # helper functions shared across files
│       ├── model_utils.py                  # Span & CRF & MRC models (pytorch)
│       ├── options.py                      # command-line arguments
│       └── trainer.py                      # trainer
│
├── competition_predict.py                  # semi-final inference and submission
├── README.md                               # ...
├── convert_test_data.py                    # converts the semi-final test set to JSON
├── run.sh                                  # run script
└── main.py                                 # main entry point (mainly training/evaluation)
```

## Usage

### Pretrained models

* Tencent pretrained model UER-large (24 layers): https://github.com/dbiir/UER-py/wiki/Modelzoo

* HIT pretrained models: https://github.com/ymcui/Chinese-BERT-wwm

Baidu Netdisk download:

Link: https://pan.baidu.com/s/1axdkovbzGaszl8bXIn4sPw
Extraction code: jjba

(Note: two [unused] tokens in vocab.txt must be replaced by hand with [INV] and [BLANK].)

Tips: uer, roberta-wwm and roberta-wwm-large are recommended.

### Data conversion

**Note: the converted data is already provided, so this step does not need to be run.**

```shell
python src/preprocess/convert_raw_data.py
```

### Training

```shell
bash run.sh
```

**Note: BERT_DIR in the script is the folder that holds BERT. Download BERT into that folder first.**

##### BERT-CRF training

```python
task_type='crf'
mode='train' or 'stack'    train: single-model training and validation; stack: 5-fold training and validation

swa_start: epoch at which SWA (stochastic weight averaging) starts
attack_train: 'pgd' / 'fgm' / ''    adversarial training; fgm roughly doubles training time and pgd roughly triples it; pgd gave a clear gain on this dataset
```

##### BERT-SPAN training

```python
task_type='span'
mode: same as above
attack_train: same as above
loss_type: 'ce': cross entropy; 'ls_ce': label smoothing; 'focal': focal loss
```

##### BERT-MRC training

```python
task_type='mrc'
mode: same as above
attack_train: same as above
loss_type: same as above
```

### Predicting the semi-final test files (after the models above are trained)

**Note: this cannot run yet because the data is not available. It will run once the official data is released.**

```shell
# convert_test_data
python convert_test_data.py
# predict
python competition_predict.py
```

# Competition background
## Task description
Artificial intelligence has accelerated the preservation and innovative development of traditional Chinese medicine (TCM). Information extraction from TCM texts is the core of building a TCM knowledge graph and lays the foundation for higher-level applications such as clinical decision support systems (CDSS). This NER challenge requires extracting key information from TCM drug package inserts, covering 13 entity types including drugs, drug ingredients, diseases, symptoms and syndromes, in order to build a TCM drug knowledge base.

## Exploratory data analysis
The training data for this competition has three characteristics:

- TCM drug package inserts are mostly long texts

- Annotated samples are scarce in the medical domain

- The label distribution is imbalanced


# Core ideas
## Data preprocessing
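The package-insert text is first pre-cleaned and then long texts are split. Pre-cleaning filters out invalid characters. Long texts are handled with a two-level splitting strategy. Split sentences can be too short, so short pieces are merged, keeping each merged text within the configured maximum length. In addition, all annotated data is used to build an entity knowledge base that serves as a domain prior dictionary.

## Baseline: BERT-CRF


- Baseline details
  - Pretrained model: UER-large (24 layers) [1]. UER continues pretraining under the RoBERTa-wwm framework on a large, high-quality Chinese corpus and ranked first as a single model on CLUE tasks
  - Differential learning rates: 2e-5 for the BERT layers, 2e-3 for the other layers
  - Parameter initialization: the non-BERT modules use the same initialization scheme as BERT
  - Weight averaging (SWA): a weighted average of the model weights from the last few epochs gives a smoother, better-performing model
- Baseline bad-case analysis


## Optimization 1: Adversarial training
- Motivation: use adversarial training to address the model's poor robustness and improve generalization
- Adversarial training injects noise during training, which regularizes the parameters and improves robustness and generalization
  - Fast Gradient Method (FGM): adds a perturbation to the embedding layer along the gradient direction
  - Projected Gradient Descent (PGD) [2]: perturbs iteratively and projects each perturbation back into a bounded region

## Optimization 2: Mixed-precision training (FP16)
- Motivation: adversarial training lowers computational efficiency, so mixed-precision training is used to cut training time
- Mixed-precision training
  - Store values and perform multiplications in FP16 to save memory and speed things up
  - Accumulate in FP32 to avoid rounding errors
- Loss scaling
  - Multiply the loss by 2^k before backpropagation so that small FP16 gradients do not underflow
  - Scale the weight gradients back down after backpropagation

## Optimization 3: Multi-model fusion
- Motivation: the baseline's errors are concentrated in ambiguous cases, so a multi-level medical NER system is used to resolve the ambiguity
- Method: a differentiated, multi-level model fusion system
  - Different model architectures: BERT-CRF & BERT-SPAN & BERT-MRC
  - Different training data: different random seeds and different sentence-split lengths (256, 512)
  - Multi-level fusion strategy

- Fusion model 1: BERT-SPAN
  - Replaces the CRF module with span pointers to speed up training
  - Predicts entity start and end positions with a half-pointer, half-tagging structure and assigns the entity type during tagging
  - Uses strict decoding: among overlapping entities only the one with the largest logit is kept, which protects precision
  - Uses label smoothing to reduce overfitting


- Fusion model 2: BERT-MRC
  - Treats NER as machine reading comprehension
    - query: the description of the entity type is used as the query
    - doc: the original text after sentence splitting is used as the doc
  - One sample is built per entity type, which produces many negative samples during training; about 30% of them can be randomly kept and the rest discarded to preserve efficiency
  - At prediction time a sample is built for every type and the decoded output is not constrained, which protects recall
  - Uses label smoothing to reduce overfitting
  - MRC performed poorly on this dataset and is slow to train and run, so it is used only to raise recall. The code is provided for learning purposes and is not recommended for everyday use


- Multi-level fusion strategy
  - First level, probability fusion: the 5-fold cross-validation models of CRF/SPAN/MRC are fused by averaging their logits and then decoding entities
  - Second level, voting fusion: the probability-fused CRF/SPAN/MRC models are combined by voting to produce the final result


## Optimization 4: Semi-supervised learning
- Motivation: to make up for the scarcity of annotated medical text, semi-supervised learning (pseudo-labeling) is used to exploit the 500 unannotated documents of the preliminary-round test set
- Strategy: dynamic pseudo-labeling
  - First train a baseline model M on the original annotated data
  - Use M to predict the preliminary-round test set and obtain pseudo labels
  - Add the pseudo-labeled data to the training set with a dynamic, learnable weight (alpha in the figure), and train it together with the real labeled data to obtain model M'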
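As a rough illustration of down-weighting pseudo-labeled data, the sketch below assumes that per-example losses are already computed. It also assumes each example carries the `pseudo` flag (0 or 1) that `convert_raw_data.py` stores. It uses a fixed weight `alpha`, which the tip below also allows. `weighted_pseudo_loss` is a hypothetical helper, not the repository's implementation, and it does not reproduce the learnable alpha.

```python
from typing import Sequence


def weighted_pseudo_loss(losses: Sequence[float],
                         pseudo_flags: Sequence[int],
                         alpha: float = 0.5) -> float:
    """Weighted mean of per-example losses (hypothetical helper).

    Real examples (pseudo == 0) get weight 1.0; pseudo-labeled examples
    (pseudo == 1) get weight `alpha`, so alpha < 1 lowers their influence.
    """
    if len(losses) != len(pseudo_flags):
        raise ValueError("losses and pseudo_flags must have the same length")
    weights = [alpha if flag else 1.0 for flag in pseudo_flags]
    total_weight = sum(weights)
    if total_weight == 0:
        return 0.0  # empty batch, or only pseudo examples with alpha == 0
    # Normalizing by the total weight keeps the loss scale independent of
    # how many pseudo-labeled examples land in a batch.
    return sum(w * loss for w, loss in zip(weights, losses)) / total_weight
```

For example, `weighted_pseudo_loss([0.8, 1.2, 2.0, 3.0], [0, 0, 1, 1], alpha=0.5)` halves the weight of the last two losses and returns 1.5. In a PyTorch loop the same weighting would typically be applied to an unreduced (per-example) loss before calling `.backward()`.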
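The tip below recommends generating pseudo labels with a fused model. The two-level fusion from Optimization 3 averages the fold models' logits and then votes across the fused CRF/SPAN/MRC results, and it can be sketched in plain Python. `average_logits` and `vote` are hypothetical helpers. The voting rule (keep an entity if at least a `threshold` fraction of models predicted it) is an assumption and may differ from what the repository's `ensemble_vote` does.

```python
from collections import Counter
from typing import List, Sequence, Tuple

# (entity type, start offset, end offset, entity text), mirroring one .ann row
Entity = Tuple[str, int, int, str]


def average_logits(fold_logits: Sequence[Sequence[float]]) -> List[float]:
    """Level 1: element-wise mean of the logits from the k fold models."""
    n = len(fold_logits)
    return [sum(values) / n for values in zip(*fold_logits)]


def vote(model_outputs: Sequence[Sequence[Entity]], threshold: float) -> List[Entity]:
    """Level 2: keep entities predicted by at least `threshold` of the models."""
    # set() so a model that outputs the same entity twice only votes once
    counts = Counter(ent for output in model_outputs for ent in set(output))
    min_votes = threshold * len(model_outputs)
    return sorted(ent for ent, count in counts.items() if count >= min_votes)
```

With three fused models and `threshold=0.6`, an entity needs two votes; with `threshold=0.9` it needs all three. The decoding step between the two levels (CRF or span decoding on the averaged logits) is omitted here.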
- Tips: using a multi-model fused baseline reduces pseudo-label noise. The weight can also be fixed; which works better has to be found by experiment. In essence this lowers the loss weight of the pseudo labels, which is one way to mitigate pseudo-label noise.

## Other attempts without clear gains
- Dynamically weighted output of BERT's last four layers: no clear gain
- Adding a BiLSTM / IDCNN module after the BERT output: severe overfitting and much slower training
- Data augmentation by randomly replacing entity words of the same type to expand the training data
- Focal loss / dice loss in the BERT-SPAN / MRC models to mitigate label imbalance
- Correcting the model output with the constructed domain dictionary

## Final online score: 72.90%; Rank 1 in the semi-final, Rank 1 in the final


# Ref
[1] Zhao et al., UER: An Open-Source Toolkit for Pre-training Models, EMNLP-IJCNLP, 2019.
[2] Madry et al., Towards Deep Learning Models Resistant to Adversarial Attacks, ICLR, 2018.

--------------------------------------------------------------------------------
/competition_predict.py:
--------------------------------------------------------------------------------
import os
import json
import torch
from collections import defaultdict
from transformers import BertTokenizer
from src.utils.model_utils import CRFModel, SpanModel, EnsembleCRFModel, EnsembleSpanModel
from src.utils.evaluator import crf_decode, span_decode
from src.utils.functions_utils import load_model_and_parallel, ensemble_vote
from src.preprocess.processor import cut_sent, fine_grade_tokenize

MID_DATA_DIR = "./data/mid_data"
RAW_DATA_DIR = "./data/raw_data_random"
SUBMIT_DIR = "./result"
GPU_IDS = "0"

LAMBDA = 0.3
THRESHOLD = 0.9
MAX_SEQ_LEN = 512

TASK_TYPE = "crf"  # choose crf or span
VOTE = True  # choose True or False
VERSION = "mixed"  # choose single, ensemble or mixed; with mixed, VOTE and TASK_TYPE are ignored.

# single_predict
BERT_TYPE = "uer_large"  # roberta_wwm / ernie_1 / uer_large

BERT_DIR = f"./bert/torch_{BERT_TYPE}"
with open('./best_ckpt_path.txt', 'r', encoding='utf-8') as f:
    CKPT_PATH = f.read().strip()

# ensemble_predict
BERT_DIR_LIST = ["./bert/torch_uer_large", "./bert/torch_roberta_wwm"]

with open('./best_ckpt_path.txt', 'r', encoding='utf-8') as f:
    ENSEMBLE_DIR_LIST = f.readlines()
print('ENSEMBLE_DIR_LIST:{}'.format(ENSEMBLE_DIR_LIST))


# mixed_predict
MIX_BERT_DIR = "./bert/torch_uer_large"

with open('./best_ckpt_path.txt', 'r', encoding='utf-8') as f:
    MIX_DIR_LIST = f.readlines()
print('MIX_DIR_LIST:{}'.format(MIX_DIR_LIST))


def prepare_info():
    info_dict = {}
    with open(os.path.join(MID_DATA_DIR, f'{TASK_TYPE}_ent2id.json'), encoding='utf-8') as f:
        ent2id = json.load(f)

    with open(os.path.join(RAW_DATA_DIR, 'test.json'), encoding='utf-8') as f:
        info_dict['examples'] = json.load(f)

    info_dict['id2ent'] = {ent2id[key]: key for key in ent2id.keys()}

    info_dict['tokenizer'] = BertTokenizer(os.path.join(BERT_DIR, 'vocab.txt'))

    return info_dict


def mixed_prepare_info(mixed='crf'):
    info_dict = {}
    with open(os.path.join(MID_DATA_DIR, f'{mixed}_ent2id.json'), encoding='utf-8') as f:
        ent2id = json.load(f)

    with open(os.path.join(RAW_DATA_DIR, 'test.json'), encoding='utf-8') as f:
        info_dict['examples'] = json.load(f)

    info_dict['id2ent'] = {ent2id[key]: key for key in ent2id.keys()}

    info_dict['tokenizer'] = BertTokenizer(os.path.join(BERT_DIR, 'vocab.txt'))

    return info_dict


def base_predict(model, device, info_dict, ensemble=False, mixed=''):
    labels = defaultdict(list)

    tokenizer = info_dict['tokenizer']
    id2ent = info_dict['id2ent']

    with torch.no_grad():
        for _ex in info_dict['examples']:
            ex_idx = _ex['id']
            raw_text = _ex['text']

            if not len(raw_text):
                labels[ex_idx] = []
                print('{} is empty'.format(ex_idx))
                continue

            sentences = cut_sent(raw_text, MAX_SEQ_LEN)

            # offset of the current sentence inside raw_text
            start_index = 0

            for sent in sentences:

                sent_tokens = fine_grade_tokenize(sent, tokenizer)

                encode_dict = tokenizer.encode_plus(text=sent_tokens,
                                                    max_length=MAX_SEQ_LEN,
                                                    is_pretokenized=True,
                                                    pad_to_max_length=False,
                                                    return_tensors='pt',
                                                    return_token_type_ids=True,
                                                    return_attention_mask=True)

                model_inputs = {'token_ids': encode_dict['input_ids'],
                                'attention_masks': encode_dict['attention_mask'],
                                'token_type_ids': encode_dict['token_type_ids']}

                for key in model_inputs:
                    model_inputs[key] = model_inputs[key].to(device)

                if ensemble:
                    if TASK_TYPE == 'crf':
                        if VOTE:
                            decode_entities = model.vote_entities(model_inputs, sent, id2ent, THRESHOLD)
                        else:
                            pred_tokens = model.predict(model_inputs)[0]
                            decode_entities = crf_decode(pred_tokens, sent, id2ent)
                    else:
                        if VOTE:
                            decode_entities = model.vote_entities(model_inputs, sent, id2ent, THRESHOLD)
                        else:
                            start_logits, end_logits = model.predict(model_inputs)
                            start_logits = start_logits[0].cpu().numpy()[1:1 + len(sent)]
                            end_logits = end_logits[0].cpu().numpy()[1:1 + len(sent)]

                            decode_entities = span_decode(start_logits, end_logits, sent, id2ent)

                else:

                    if mixed:
                        if mixed == 'crf':
                            pred_tokens = model(**model_inputs)[0][0]
                            decode_entities = crf_decode(pred_tokens, sent, id2ent)
                        else:
                            start_logits, end_logits = model(**model_inputs)

                            start_logits = start_logits[0].cpu().numpy()[1:1 + len(sent)]
                            end_logits = end_logits[0].cpu().numpy()[1:1 + len(sent)]

                            decode_entities = span_decode(start_logits, end_logits, sent, id2ent)

                    else:
                        if TASK_TYPE == 'crf':
                            pred_tokens = model(**model_inputs)[0][0]
                            decode_entities = crf_decode(pred_tokens, sent, id2ent)
                        else:
                            start_logits, end_logits = model(**model_inputs)

                            start_logits = start_logits[0].cpu().numpy()[1:1+len(sent)]
                            end_logits = end_logits[0].cpu().numpy()[1:1+len(sent)]

                            decode_entities = span_decode(start_logits, end_logits, sent, id2ent)


                # map sentence-level offsets back to offsets in raw_text
                for _ent_type in decode_entities:
                    for _ent in decode_entities[_ent_type]:
                        tmp_start = _ent[1] + start_index
                        tmp_end = tmp_start + len(_ent[0])

                        assert raw_text[tmp_start: tmp_end] == _ent[0]

                        labels[ex_idx].append((_ent_type, tmp_start, tmp_end, _ent[0]))

                start_index += len(sent)

            if not len(labels[ex_idx]):
                labels[ex_idx] = []

    return labels


def single_predict():
    save_dir = os.path.join(SUBMIT_DIR, VERSION)

    if not os.path.exists(save_dir):
        os.makedirs(save_dir, exist_ok=True)

    info_dict = prepare_info()

    if TASK_TYPE == 'crf':
        model = CRFModel(bert_dir=BERT_DIR, num_tags=len(info_dict['id2ent']))
    else:
        model = SpanModel(bert_dir=BERT_DIR, num_tags=len(info_dict['id2ent'])+1)

    print(f'Load model from {CKPT_PATH}')
    model, device = load_model_and_parallel(model, GPU_IDS, CKPT_PATH)
    model.eval()

    labels = base_predict(model, device, info_dict)

    # one brat-style .ann file per document: T<n>\t<type> <start> <end>\t<text>
    for key in labels.keys():
        with open(os.path.join(save_dir, f'{key}.ann'), 'w', encoding='utf-8') as f:
            if not len(labels[key]):
                print(key)
                f.write("")
            else:
                for idx, _label in enumerate(labels[key]):
                    f.write(f'T{idx + 1}\t{_label[0]} {_label[1]} {_label[2]}\t{_label[3]}\n')


def ensemble_predict():
    save_dir = os.path.join(SUBMIT_DIR, VERSION)
    if not os.path.exists(save_dir):
        os.makedirs(save_dir, exist_ok=True)

    info_dict = prepare_info()

    model_path_list = [x.strip() for x in ENSEMBLE_DIR_LIST]
    print('model_path_list:{}'.format(model_path_list))

    device = torch.device(f'cuda:{GPU_IDS[0]}')


    if TASK_TYPE == 'crf':
        model = EnsembleCRFModel(model_path_list=model_path_list,
                                 bert_dir_list=BERT_DIR_LIST,
                                 num_tags=len(info_dict['id2ent']),
                                 device=device,
                                 lamb=LAMBDA)
    else:
        model = EnsembleSpanModel(model_path_list=model_path_list,
                                  bert_dir_list=BERT_DIR_LIST,
                                  num_tags=len(info_dict['id2ent'])+1,
                                  device=device)


    labels = base_predict(model, device, info_dict, ensemble=True)


    for key in labels.keys():
        with open(os.path.join(save_dir, f'{key}.ann'), 'w', encoding='utf-8') as f:
            if not len(labels[key]):
                print(key)
                f.write("")
            else:
                for idx, _label in enumerate(labels[key]):
                    f.write(f'T{idx + 1}\t{_label[0]} {_label[1]} {_label[2]}\t{_label[3]}\n')

def mixed_predict():
    save_dir = os.path.join(SUBMIT_DIR, VERSION)

    if not os.path.exists(save_dir):
        os.makedirs(save_dir, exist_ok=True)

    model_path_list = [x.strip() for x in MIX_DIR_LIST]
    print('model_path_list:{}'.format(model_path_list))

    all_labels = []

    # checkpoints 0-4 are loaded as SPAN models, all later ones as CRF models
    for i, model_path in enumerate(model_path_list):
        if i <= 4:
            info_dict = mixed_prepare_info(mixed='span')

            model = SpanModel(bert_dir=MIX_BERT_DIR, num_tags=len(info_dict['id2ent']) + 1)
            print(f'Load model from {model_path}')
            model, device = load_model_and_parallel(model, GPU_IDS, model_path)
            model.eval()
            labels = base_predict(model, device, info_dict, ensemble=False, mixed='span')

        else:
            info_dict = mixed_prepare_info(mixed='crf')

            model = CRFModel(bert_dir=MIX_BERT_DIR, num_tags=len(info_dict['id2ent']))
            print(f'Load model from {model_path}')
            model, device = load_model_and_parallel(model, GPU_IDS, model_path)
            model.eval()
            labels = base_predict(model, device, info_dict, ensemble=False, mixed='crf')

        all_labels.append(labels)

    labels = ensemble_vote(all_labels, THRESHOLD)

    # for key in labels.keys():

    for key in range(1500, 1997):
        with open(os.path.join(save_dir, f'{key}.ann'), 'w', encoding='utf-8') as f:
            if not len(labels[key]):
                print(key)
                f.write("")
            else:
                for idx, _label in enumerate(labels[key]):
                    f.write(f'T{idx + 1}\t{_label[0]} {_label[1]} {_label[2]}\t{_label[3]}\n')



if __name__ == '__main__':
    assert VERSION in ['single', 'ensemble', 'mixed'], 'VERSION mismatch'

    if VERSION == 'single':
        single_predict()
    elif VERSION == 'ensemble':
        if VOTE:
            print("--------Start voting:--------")
        ensemble_predict()

    elif VERSION == 'mixed':
        print("--------Start mixed voting:--------")
        mixed_predict()

    # Compress the results into result.zip
    import zipfile

    def zip_file(src_dir):
        zip_name = src_dir + '.zip'
        z = zipfile.ZipFile(zip_name, 'w', zipfile.ZIP_DEFLATED)
        for dirpath, dirnames, filenames in os.walk(src_dir):
            fpath = dirpath.replace(src_dir, '')
            fpath = fpath and fpath + os.sep or ''
            for filename in filenames:
                z.write(os.path.join(dirpath, filename), fpath + filename)
        print('==Compression succeeded==')
        z.close()

    zip_file('./result')

--------------------------------------------------------------------------------
/convert_test_data.py:
--------------------------------------------------------------------------------
import os
import json
from tqdm import trange


def save_info(data_dir, data, desc):
    with open(os.path.join(data_dir, f'{desc}.json'), 'w', encoding='utf-8') as f:
        json.dump(data, f, ensure_ascii=False, indent=2)


def convert_test_data_to_json(test_dir, save_dir):

    test_examples = []


    # process test examples
    for i in trange(1500, 1997):
        with open(os.path.join(test_dir, f'{i}.txt'), encoding='utf-8') as f:
            text = f.read()

        test_examples.append({'id': i,
                              'text': text})

    save_info(save_dir, test_examples, 'test')


if __name__ == '__main__':
    test_dir = './tcdata/juesai'
    save_dir = './data/raw_data_random'
    convert_test_data_to_json(test_dir, save_dir)
    print('Test data conversion finished')

--------------------------------------------------------------------------------
/data/mid_data/crf_ent2id.json:
--------------------------------------------------------------------------------
{
    "O": 0,
    "B-DRUG_GROUP": 1,
    "B-DRUG_DOSAGE": 2,
    "B-FOOD": 3,
    "B-DRUG_EFFICACY": 4,
    "B-FOOD_GROUP": 5,
    "B-SYMPTOM": 6,
    "B-DISEASE_GROUP": 7,
    "B-SYNDROME": 8,
    "B-PERSON_GROUP": 9,
    "B-DRUG_TASTE": 10,
    "B-DRUG_INGREDIENT": 11,
    "B-DRUG": 12,
    "B-DISEASE": 13,
    "I-DRUG_GROUP": 14,
    "I-DRUG_DOSAGE": 15,
    "I-FOOD": 16,
    "I-DRUG_EFFICACY": 17,
    "I-FOOD_GROUP": 18,
    "I-SYMPTOM": 19,
    "I-DISEASE_GROUP": 20,
    "I-SYNDROME": 21,
    "I-PERSON_GROUP": 22,
    "I-DRUG_TASTE": 23,
    "I-DRUG_INGREDIENT": 24,
    "I-DRUG": 25,
    "I-DISEASE": 26,
    "E-DRUG_GROUP": 27,
    "E-DRUG_DOSAGE": 28,
    "E-FOOD": 29,
    "E-DRUG_EFFICACY": 30,
    "E-FOOD_GROUP": 31,
    "E-SYMPTOM": 32,
    "E-DISEASE_GROUP": 33,
    "E-SYNDROME": 34,
    "E-PERSON_GROUP": 35,
    "E-DRUG_TASTE": 36,
    "E-DRUG_INGREDIENT": 37,
    "E-DRUG": 38,
    "E-DISEASE": 39,
    "S-DRUG_GROUP": 40,
    "S-DRUG_DOSAGE": 41,
    "S-FOOD": 42,
    "S-DRUG_EFFICACY": 43,
    "S-FOOD_GROUP": 44,
    "S-SYMPTOM": 45,
    "S-DISEASE_GROUP": 46,
    "S-SYNDROME": 47,
    "S-PERSON_GROUP": 48,
    "S-DRUG_TASTE": 49,
    "S-DRUG_INGREDIENT": 50,
    "S-DRUG": 51,
    "S-DISEASE": 52
}
--------------------------------------------------------------------------------
/data/mid_data/mrc_ent2id.json:
--------------------------------------------------------------------------------
{
    "DRUG": "找出药物:用于预防、治疗、诊断疾病并具有康复与保健作用的物质。",
    "DRUG_INGREDIENT": "找出药物成分:中药组成成分,指中药复方中所含有的所有与该复方临床应用目的密切相关的药理活性成分。",
    "DISEASE": "找出疾病:指人体在一定原因的损害性作用下,因自稳调节紊乱而发生的异常生命活动过程,会影响生物体的部分或是所有器官。",
    "SYMPTOM": "找出症状:指疾病过程中机体内的一系列机能、代谢和形态结构异常变化所引起的病人主观上的异常感觉或某些客观病态改变。",
    "SYNDROME": "找出症候:概括为一系列有相互关联的症状总称,是指不同症状和体征的综合表现。",
    "DISEASE_GROUP": "找出疾病分组:疾病涉及有人体组织部位的疾病名称的统称概念,非某项具体医学疾病。",
    "FOOD": "找出食物:指能够满足机体正常生理和生化能量需求,并能延续正常寿命的物质。",
    "FOOD_GROUP": "找出食物分组:中医中饮食养生中,将食物分为寒热温凉四性,同时中医药禁忌中对于具有某类共同属性食物的统称,记为食物分组。",
    "PERSON_GROUP": "找出人群:中医药的适用及禁忌范围内相关特定人群。",
    "DRUG_GROUP": "找出药品分组:具有某一类共同属性的药品类统称概念,非某项具体药品名。例子:止咳药、退烧药",
    "DRUG_DOSAGE": "找出药物剂量:药物在供给临床使用前,均必须制成适合于医疗和预防应用的形式,成为药物剂型。",
    "DRUG_TASTE": "找出药物性味:药品的性质和气味。例子:味甘、酸涩、气凉。",
    "DRUG_EFFICACY": "找出中药功效:药品的主治功能和效果的统称。例子:滋阴补肾、去瘀生新、活血化瘀"
}
--------------------------------------------------------------------------------
/data/mid_data/span_ent2id.json:
--------------------------------------------------------------------------------
{
    "DRUG_GROUP": 1,
    "DRUG_DOSAGE": 2,
    "FOOD": 3,
    "DRUG_EFFICACY": 4,
    "FOOD_GROUP": 5,
    "SYMPTOM": 6,
    "DISEASE_GROUP": 7,
    "SYNDROME": 8,
    "PERSON_GROUP": 9,
    "DRUG_TASTE": 10,
    "DRUG_INGREDIENT": 11,
    "DRUG": 12,
    "DISEASE": 13
}
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
import time
import os
import json
import logging
from torch.utils.data import DataLoader
from sklearn.model_selection import KFold
from src.utils.trainer import train
from src.utils.options import Args
from src.utils.model_utils import build_model
from src.utils.dataset_utils import NERDataset
from src.utils.evaluator import crf_evaluation, span_evaluation, mrc_evaluation
from src.utils.functions_utils import set_seed, get_model_path_list, load_model_and_parallel, get_time_dif
from src.preprocess.processor import NERProcessor, convert_examples_to_features

logger = logging.getLogger(__name__)
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
    datefmt="%m/%d/%Y %H:%M:%S",
    level=logging.INFO
)

def train_base(opt, train_examples, dev_examples=None):
    with open(os.path.join(opt.mid_data_dir, f'{opt.task_type}_ent2id.json'), encoding='utf-8') as f:
        ent2id = json.load(f)

    train_features = convert_examples_to_features(opt.task_type, train_examples,
                                                  opt.max_seq_len, opt.bert_dir, ent2id)[0]

    train_dataset = NERDataset(opt.task_type, train_features, 'train', use_type_embed=opt.use_type_embed)

    if opt.task_type == 'crf':
        model = build_model('crf', opt.bert_dir, num_tags=len(ent2id),
                            dropout_prob=opt.dropout_prob)
    elif opt.task_type == 'mrc':
        model = build_model('mrc', opt.bert_dir,
                            dropout_prob=opt.dropout_prob,
                            use_type_embed=opt.use_type_embed,
                            loss_type=opt.loss_type)
    else:
        model = build_model('span', opt.bert_dir, num_tags=len(ent2id)+1,
                            dropout_prob=opt.dropout_prob,
                            loss_type=opt.loss_type)

    train(opt, model, train_dataset)

    if dev_examples is not None:

        dev_features, dev_callback_info = convert_examples_to_features(opt.task_type, dev_examples,
                                                                       opt.max_seq_len, opt.bert_dir, ent2id)

        dev_dataset = NERDataset(opt.task_type, dev_features, 'dev', use_type_embed=opt.use_type_embed)

        dev_loader = DataLoader(dev_dataset, batch_size=opt.eval_batch_size,
                                shuffle=False, num_workers=0)

        dev_info = (dev_loader, dev_callback_info)

        model_path_list = get_model_path_list(opt.output_dir)

        metric_str = ''

        max_f1 = 0.
        max_f1_step = 0

        max_f1_path = ''

        # evaluate every saved checkpoint and keep track of the best F1
        for idx, model_path in enumerate(model_path_list):

            tmp_step = model_path.split('/')[-2].split('-')[-1]


            model, device = load_model_and_parallel(model, opt.gpu_ids[0],
                                                    ckpt_path=model_path)

            if opt.task_type == 'crf':
                tmp_metric_str, tmp_f1 = crf_evaluation(model, dev_info, device, ent2id)
            elif opt.task_type == 'mrc':
                tmp_metric_str, tmp_f1 = mrc_evaluation(model, dev_info, device)
            else:
                tmp_metric_str, tmp_f1 = span_evaluation(model, dev_info, device, ent2id)

            logger.info(f'In step {tmp_step}:\n {tmp_metric_str}')

            metric_str += f'In step {tmp_step}:\n {tmp_metric_str}' + '\n\n'

            if tmp_f1 > max_f1:
                max_f1 = tmp_f1
                max_f1_step = tmp_step
                max_f1_path = model_path

        max_metric_str = f'Max f1 is: {max_f1}, in step {max_f1_step}'

        logger.info(max_metric_str)

        metric_str += max_metric_str + '\n'

        eval_save_path = os.path.join(opt.output_dir, 'eval_metric.txt')

        with open(eval_save_path, 'a', encoding='utf-8') as f1:
            f1.write(metric_str)

        with open('./best_ckpt_path.txt', 'a', encoding='utf-8') as f2:
            f2.write(max_f1_path + '\n')

        # delete every checkpoint directory except the best one
        del_dir_list = [os.path.join(opt.output_dir, path.split('/')[-2])
                        for path in model_path_list if path != max_f1_path]

        import shutil
        for x in del_dir_list:
            shutil.rmtree(x)
            logger.info('{} deleted'.format(x))


def training(opt):
    if opt.task_type == 'mrc':
        # reserve 62 tokens for the MRC query
        processor = NERProcessor(opt.max_seq_len-62)
    else:
        processor = NERProcessor(opt.max_seq_len)

    train_raw_examples = processor.read_json(os.path.join(opt.raw_data_dir, 'train.json'))

    # add pseudo data to train data
    pseudo_raw_examples = processor.read_json(os.path.join(opt.raw_data_dir, 'pseudo.json'))
    train_raw_examples = train_raw_examples + pseudo_raw_examples

    train_examples = processor.get_examples(train_raw_examples, 'train')

    dev_examples = None
    if opt.eval_model:
        dev_raw_examples = processor.read_json(os.path.join(opt.raw_data_dir, 'dev.json'))
        dev_examples = processor.get_examples(dev_raw_examples, 'dev')

    train_base(opt, train_examples, dev_examples)


def stacking(opt):
    logger.info('Start to KFold stack attribution model')

    if opt.task_type == 'mrc':
        # reserve 62 tokens for the MRC query
        processor = NERProcessor(opt.max_seq_len-62)
    else:
        processor = NERProcessor(opt.max_seq_len)

    kf = KFold(5, shuffle=True, random_state=42)

    stack_raw_examples = processor.read_json(os.path.join(opt.raw_data_dir, 'stack.json'))

    pseudo_raw_examples = processor.read_json(os.path.join(opt.raw_data_dir, 'pseudo.json'))

    base_output_dir = opt.output_dir

    for i, (train_ids, dev_ids) in enumerate(kf.split(stack_raw_examples)):
        logger.info(f'Start to train the {i} fold')
        train_raw_examples = [stack_raw_examples[_idx] for _idx in train_ids]

        # add pseudo data to train data
        train_raw_examples = train_raw_examples + pseudo_raw_examples
        train_examples = processor.get_examples(train_raw_examples, 'train')

        dev_raw_examples = [stack_raw_examples[_idx] for _idx in dev_ids]
        dev_info = processor.get_examples(dev_raw_examples, 'dev')

        tmp_output_dir = os.path.join(base_output_dir, f'v{i}')

        opt.output_dir = tmp_output_dir

        train_base(opt, train_examples, dev_info)

if __name__ == '__main__':
    start_time = time.time()
    logging.info('----------------Start timing----------------')
    logging.info('----------------------------------------')

    args = Args().get_parser()

    assert args.mode in ['train', 'stack'], 'mode mismatch'
    assert args.task_type in ['crf', 'span', 'mrc']

    args.output_dir = os.path.join(args.output_dir, args.bert_type)

    set_seed(args.seed)

    if args.attack_train != '':
        args.output_dir += f'_{args.attack_train}'

    if args.weight_decay:
        args.output_dir += '_wd'

    if args.use_fp16:
        args.output_dir += '_fp16'

    if args.task_type == 'span':
        args.output_dir += f'_{args.loss_type}'

    if args.task_type == 'mrc':
        if args.use_type_embed:
            args.output_dir += f'_embed'
        args.output_dir += f'_{args.loss_type}'

    args.output_dir += f'_{args.task_type}'

    if args.mode == 'stack':
        args.output_dir += '_stack'

    if not os.path.exists(args.output_dir):
        os.makedirs(args.output_dir, exist_ok=True)

    logger.info(f'{args.mode} {args.task_type} in max_seq_len {args.max_seq_len}')

    if args.mode == 'train':
        training(args)
    else:
        stacking(args)

    time_dif = get_time_dif(start_time)
    logging.info("----------Container run time for this job: {}-----------".format(time_dif))

--------------------------------------------------------------------------------
/md_files/1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chizhu/DeepNER/8c4abc21676af50ede29dce90bfac4892b36a1c5/md_files/1.png
--------------------------------------------------------------------------------
/md_files/10.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chizhu/DeepNER/8c4abc21676af50ede29dce90bfac4892b36a1c5/md_files/10.png
--------------------------------------------------------------------------------
/md_files/11.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chizhu/DeepNER/8c4abc21676af50ede29dce90bfac4892b36a1c5/md_files/11.png
--------------------------------------------------------------------------------
/md_files/12.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chizhu/DeepNER/8c4abc21676af50ede29dce90bfac4892b36a1c5/md_files/12.png
--------------------------------------------------------------------------------
/md_files/13.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chizhu/DeepNER/8c4abc21676af50ede29dce90bfac4892b36a1c5/md_files/13.png
--------------------------------------------------------------------------------
/md_files/2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chizhu/DeepNER/8c4abc21676af50ede29dce90bfac4892b36a1c5/md_files/2.png
--------------------------------------------------------------------------------
/md_files/3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chizhu/DeepNER/8c4abc21676af50ede29dce90bfac4892b36a1c5/md_files/3.png
--------------------------------------------------------------------------------
/md_files/4.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chizhu/DeepNER/8c4abc21676af50ede29dce90bfac4892b36a1c5/md_files/4.png
--------------------------------------------------------------------------------
/md_files/5.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chizhu/DeepNER/8c4abc21676af50ede29dce90bfac4892b36a1c5/md_files/5.png
--------------------------------------------------------------------------------
/md_files/6.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chizhu/DeepNER/8c4abc21676af50ede29dce90bfac4892b36a1c5/md_files/6.png
--------------------------------------------------------------------------------
/md_files/7.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chizhu/DeepNER/8c4abc21676af50ede29dce90bfac4892b36a1c5/md_files/7.png
--------------------------------------------------------------------------------
/md_files/8.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chizhu/DeepNER/8c4abc21676af50ede29dce90bfac4892b36a1c5/md_files/8.png
--------------------------------------------------------------------------------
/md_files/9.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chizhu/DeepNER/8c4abc21676af50ede29dce90bfac4892b36a1c5/md_files/9.png
--------------------------------------------------------------------------------
/run.sh:
--------------------------------------------------------------------------------
#!/usr/bin/env bash

export MID_DATA_DIR="./data/mid_data"
export RAW_DATA_DIR="./data/raw_data"
export OUTPUT_DIR="./out"

export GPU_IDS="0"
export BERT_TYPE="roberta_wwm"  # roberta_wwm / roberta_wwm_large / uer_large
export BERT_DIR="../bert/torch_$BERT_TYPE"

export MODE="train"
export TASK_TYPE="crf"

python main.py \
    --gpu_ids=$GPU_IDS \
    --output_dir=$OUTPUT_DIR \
    --mid_data_dir=$MID_DATA_DIR \
    --mode=$MODE \
    --task_type=$TASK_TYPE \
    --raw_data_dir=$RAW_DATA_DIR \
    --bert_dir=$BERT_DIR \
    --bert_type=$BERT_TYPE \
    --train_epochs=10 \
    --swa_start=5 \
    --attack_train="" \
    --train_batch_size=24 \
    --dropout_prob=0.1 \
    --max_seq_len=512 \
    --lr=2e-5 \
| --other_lr=2e-3 \ 31 | --seed=123 \ 32 | --weight_decay=0.01 \ 33 | --loss_type='ls_ce' \ 34 | --eval_model \ 35 | #--use_fp16 -------------------------------------------------------------------------------- /src/preprocess/__pycache__/processor.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chizhu/DeepNER/8c4abc21676af50ede29dce90bfac4892b36a1c5/src/preprocess/__pycache__/processor.cpython-36.pyc -------------------------------------------------------------------------------- /src/preprocess/convert_raw_data.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | from tqdm import trange 4 | from sklearn.model_selection import train_test_split, KFold 5 | 6 | 7 | def save_info(data_dir, data, desc): 8 | with open(os.path.join(data_dir, f'{desc}.json'), 'w', encoding='utf-8') as f: 9 | json.dump(data, f, ensure_ascii=False, indent=2) 10 | 11 | 12 | def convert_data_to_json(base_dir, save_data=False, save_dict=False): 13 | stack_examples = [] 14 | pseudo_examples = [] 15 | test_examples = [] 16 | 17 | stack_dir = os.path.join(base_dir, 'train') 18 | pseudo_dir = os.path.join(base_dir, 'pseudo') 19 | test_dir = os.path.join(base_dir, 'test') 20 | 21 | 22 | # process train examples 23 | for i in trange(1000): 24 | with open(os.path.join(stack_dir, f'{i}.txt'), encoding='utf-8') as f: 25 | text = f.read() 26 | # each .ann line 'T1<TAB>TYPE start end<TAB>text' is parsed into [id, type, start, end, text] 27 | labels = [] 28 | with open(os.path.join(stack_dir, f'{i}.ann'), encoding='utf-8') as f: 29 | for line in f.readlines(): 30 | tmp_label = line.strip().split('\t') 31 | assert len(tmp_label) == 3 32 | tmp_mid = tmp_label[1].split() 33 | tmp_label = [tmp_label[0]] + tmp_mid + [tmp_label[2]] 34 | 35 | labels.append(tmp_label) 36 | tmp_label[2] = int(tmp_label[2]) 37 | tmp_label[3] = int(tmp_label[3]) 38 | 39 | assert text[tmp_label[2]:tmp_label[3]] == tmp_label[-1], 'label {} in doc {}: 
entity offsets do not match the text'.format(tmp_label, i) 40 | 41 |
stack_examples.append({'id': i, 42 | 'text': text, 43 | 'labels': labels, 44 | 'pseudo': 0}) 45 | 46 | 47 | # entity knowledge base for distant supervision: with 10-fold splitting, each training text's candidate_entities are the multi-character entities from the other nine folds that occur in it 48 | kf = KFold(10) 49 | entities = set() 50 | ent_types = set() 51 | for _now_id, _candidate_id in kf.split(stack_examples): 52 | now = [stack_examples[_id] for _id in _now_id] 53 | candidate = [stack_examples[_id] for _id in _candidate_id] 54 | now_entities = set() 55 | 56 | for _ex in now: 57 | for _label in _ex['labels']: 58 | ent_types.add(_label[1]) 59 | 60 | if len(_label[-1]) > 1: 61 | now_entities.add(_label[-1]) 62 | entities.add(_label[-1]) 63 | # print(len(now_entities)) 64 | for _ex in candidate: 65 | text = _ex['text'] 66 | candidate_entities = [] 67 | 68 | for _ent in now_entities: 69 | if _ent in text: 70 | candidate_entities.append(_ent) 71 | 72 | _ex['candidate_entities'] = candidate_entities 73 | assert len(ent_types) == 13 74 | 75 | # process pseudo-labelled examples (test texts labelled by the preliminary-round model) 76 | for i in trange(1000, 1500): 77 | with open(os.path.join(pseudo_dir, f'{i}.txt'), encoding='utf-8') as f: 78 | text = f.read() 79 | 80 | candidate_entities = [] 81 | for _ent in entities: 82 | if _ent in text: 83 | candidate_entities.append(_ent) 84 | 85 | labels = [] 86 | with open(os.path.join(pseudo_dir, f'{i}.ann'), encoding='utf-8') as f: 87 | for line in f.readlines(): 88 | tmp_label = line.strip().split('\t') 89 | assert len(tmp_label) == 3 90 | tmp_mid = tmp_label[1].split() 91 | tmp_label = [tmp_label[0]] + tmp_mid + [tmp_label[2]] 92 | 93 | labels.append(tmp_label) 94 | tmp_label[2] = int(tmp_label[2]) 95 | tmp_label[3] = int(tmp_label[3]) 96 | 97 | assert text[tmp_label[2]:tmp_label[3]] == tmp_label[-1], 'label {} in doc {}: entity offsets do not match the text'.format(tmp_label, i) 98 | 99 | pseudo_examples.append({'id': i, 100 | 'text': text, 101 | 'labels': labels, 
102 | 'candidate_entities': candidate_entities, 103 | 'pseudo': 1}) 104 | 105 | # process test examples 106 | for i in trange(1000, 1500): 107 | with open(os.path.join(test_dir, f'{i}.txt'), 
encoding='utf-8') as f: 108 | text = f.read() 109 | 110 | candidate_entities = [] 111 | for _ent in entities: 112 | if _ent in text: 113 | candidate_entities.append(_ent) 114 | 115 | test_examples.append({'id': i, 116 | 'text': text, 117 | 'candidate_entities': candidate_entities}) 118 | 119 | train, dev = train_test_split(stack_examples, shuffle=True, random_state=123, test_size=0.15) 120 | 121 | if save_data: 122 | save_info(base_dir, stack_examples, 'stack') 123 | save_info(base_dir, train, 'train') 124 | save_info(base_dir, dev, 'dev') 125 | save_info(base_dir, test_examples, 'test') 126 | 127 | save_info(base_dir, pseudo_examples, 'pseudo') 128 | 129 | if save_dict: 130 | ent_types = sorted(ent_types) # sort so the saved label ids do not depend on set iteration order 131 | span_ent2id = {_type: i+1 for i, _type in enumerate(ent_types)} 132 | 133 | ent_types = ['O'] + [p + '-' + _type for p in ['B', 'I', 'E', 'S'] for _type in list(ent_types)] 134 | crf_ent2id = {ent: i for i, ent in enumerate(ent_types)} 135 | 136 | mid_data_dir = os.path.join(os.path.split(base_dir)[0], 'mid_data') 137 | if not os.path.exists(mid_data_dir): 138 | os.mkdir(mid_data_dir) 139 | 140 | save_info(mid_data_dir, span_ent2id, 'span_ent2id') 141 | save_info(mid_data_dir, crf_ent2id, 'crf_ent2id') 142 | 143 | 144 | def build_ent2query(data_dir): 145 | # build each MRC query from the competition's short description of the entity type 146 | ent2query = { 147 | # drug 148 | 'DRUG': "找出药物:用于预防、治疗、诊断疾病并具有康复与保健作用的物质。", 149 | # drug ingredient 150 | 'DRUG_INGREDIENT': "找出药物成分:中药组成成分,指中药复方中所含有的所有与该复方临床应用目的密切相关的药理活性成分。", 151 | # disease 152 | 'DISEASE': "找出疾病:指人体在一定原因的损害性作用下,因自稳调节紊乱而发生的异常生命活动过程,会影响生物体的部分或是所有器官。", 153 | # symptom 154 | 'SYMPTOM': "找出症状:指疾病过程中机体内的一系列机能、代谢和形态结构异常变化所引起的病人主观上的异常感觉或某些客观病态改变。", 155 | # syndrome 156 | 'SYNDROME': "找出症候:概括为一系列有相互关联的症状总称,是指不同症状和体征的综合表现。", 157 | # disease group 158 | 'DISEASE_GROUP': "找出疾病分组:疾病涉及有人体组织部位的疾病名称的统称概念,非某项具体医学疾病。", 159 | # food 160 | 'FOOD': 
"找出食物:指能够满足机体正常生理和生化能量需求,并能延续正常寿命的物质。", 161 | # food group 162 | 'FOOD_GROUP': "找出食物分组:中医中饮食养生中,将食物分为寒热温凉四性," 163 | "同时中医药禁忌中对于具有某类共同属性食物的统称,记为食物分组。", 164 | # person group 165 |
'PERSON_GROUP': "找出人群:中医药的适用及禁忌范围内相关特定人群。", 166 | # drug group 167 | 'DRUG_GROUP': "找出药品分组:具有某一类共同属性的药品类统称概念,非某项具体药品名。例子:止咳药、退烧药", 168 | # drug dosage 169 | 'DRUG_DOSAGE': "找出药物剂量:药物在供给临床使用前,均必须制成适合于医疗和预防应用的形式,成为药物剂型。", 170 | # drug taste (nature and flavour) 171 | 'DRUG_TASTE': "找出药物性味:药品的性质和气味。例子:味甘、酸涩、气凉。", 172 | # drug efficacy 173 | 'DRUG_EFFICACY': "找出中药功效:药品的主治功能和效果的统称。例子:滋阴补肾、去瘀生新、活血化瘀" 174 | } 175 | # note: despite its name, mrc_ent2id.json stores this entity type -> query mapping 176 | with open(os.path.join(data_dir, 'mrc_ent2id.json'), 'w', encoding='utf-8') as f: 177 | json.dump(ent2query, f, ensure_ascii=False, indent=2) 178 | 179 | 180 | if __name__ == '__main__': 181 | convert_data_to_json('../../data/raw_data', save_data=True, save_dict=True) 182 | build_ent2query('../../data/mid_data') 183 | 184 | -------------------------------------------------------------------------------- /src/preprocess/processor.py: -------------------------------------------------------------------------------- 1 | import os 2 | import re 3 | import json 4 | import logging 5 | from transformers import BertTokenizer 6 | from collections import defaultdict 7 | import random 8 | 9 | logger = logging.getLogger(__name__) 10 | 11 | ENTITY_TYPES = ['DRUG', 'DRUG_INGREDIENT', 'DISEASE', 'SYMPTOM', 'SYNDROME', 'DISEASE_GROUP', 12 | 'FOOD', 'FOOD_GROUP', 'PERSON_GROUP', 'DRUG_GROUP', 'DRUG_DOSAGE', 'DRUG_TASTE', 13 | 'DRUG_EFFICACY'] 14 | 15 | 16 | class InputExample: 17 | def __init__(self, 18 | set_type, 19 | text, 20 | labels=None, 21 | pseudo=None, 22 | distant_labels=None): 23 | self.set_type = set_type 24 | self.text = text 25 | self.labels = labels 26 | self.pseudo = pseudo 27 | self.distant_labels = distant_labels 28 | 29 | 30 | class BaseFeature: 31 | def __init__(self, 32 | token_ids, 33 | attention_masks, 34 | token_type_ids): 35 | # BERT inputs 36 | self.token_ids = token_ids 37 | self.attention_masks = attention_masks 38 | self.token_type_ids = token_type_ids 39 | 40 | 41 | 
class CRFFeature(BaseFeature): 42 | def __init__(self, 43 | token_ids, 44 | attention_masks, 45 | token_type_ids, 46 | labels=None,
47 | pseudo=None, 48 | distant_labels=None): 49 | super(CRFFeature, self).__init__(token_ids=token_ids, 50 | attention_masks=attention_masks, 51 | token_type_ids=token_type_ids) 52 | # labels 53 | self.labels = labels 54 | 55 | # pseudo 56 | self.pseudo = pseudo 57 | 58 | # distant labels 59 | self.distant_labels = distant_labels 60 | 61 | 62 | class SpanFeature(BaseFeature): 63 | def __init__(self, 64 | token_ids, 65 | attention_masks, 66 | token_type_ids, 67 | start_ids=None, 68 | end_ids=None, 69 | pseudo=None): 70 | super(SpanFeature, self).__init__(token_ids=token_ids, 71 | attention_masks=attention_masks, 72 | token_type_ids=token_type_ids) 73 | self.start_ids = start_ids 74 | self.end_ids = end_ids 75 | # pseudo 76 | self.pseudo = pseudo 77 | 78 | class MRCFeature(BaseFeature): 79 | def __init__(self, 80 | token_ids, 81 | attention_masks, 82 | token_type_ids, 83 | ent_type=None, 84 | start_ids=None, 85 | end_ids=None, 86 | pseudo=None): 87 | super(MRCFeature, self).__init__(token_ids=token_ids, 88 | attention_masks=attention_masks, 89 | token_type_ids=token_type_ids) 90 | self.ent_type = ent_type 91 | self.start_ids = start_ids 92 | self.end_ids = end_ids 93 | 94 | # pseudo 95 | self.pseudo = pseudo 96 | 97 | 98 | class NERProcessor: 99 | def __init__(self, cut_sent_len=256): 100 | self.cut_sent_len = cut_sent_len 101 | 102 | @staticmethod 103 | def read_json(file_path): 104 | with open(file_path, encoding='utf-8') as f: 105 | raw_examples = json.load(f) 106 | return raw_examples 107 | 108 | @staticmethod 109 | def _refactor_labels(sent, labels, distant_labels, start_index): 110 | """ 111 | Re-map label offsets after the document has been split into sentences 112 | :param sent: the sentence produced by cut_sent (split, then re-merged) 113 | :param labels: the original document-level labels 114 | :param distant_labels: distant-supervision labels (candidate entities) 115 | :param start_index: start offset of this sentence within the document 116 | :return: 
(new_labels, new_distant_labels), where each new label is (type, entity, offset) 117 | """ 118 | new_labels, new_distant_labels = [], [] 119 | end_index = start_index + len(sent) 120 | 121 | for _label in labels: 122 | if start_index <= 
_label[2] <= _label[3] <= end_index: 123 | new_offset = _label[2] - start_index 124 | 125 | assert sent[new_offset: new_offset + len(_label[-1])] == _label[-1] 126 | 127 | new_labels.append((_label[1], _label[-1], new_offset)) 128 | # the label is cut by the sentence boundary: treat it as an error 129 | elif _label[2] < end_index < _label[3]: 130 | raise RuntimeError(f'{sent}, {_label}') 131 | 132 | for _label in distant_labels: 133 | if _label in sent: 134 | new_distant_labels.append(_label) 135 | 136 | return new_labels, new_distant_labels 137 | 138 | def get_examples(self, raw_examples, set_type): 139 | examples = [] 140 | 141 | for i, item in enumerate(raw_examples): 142 | text = item['text'] 143 | distant_labels = item['candidate_entities'] 144 | pseudo = item['pseudo'] 145 | 146 | sentences = cut_sent(text, self.cut_sent_len) 147 | start_index = 0 148 | 149 | for sent in sentences: 150 | labels, tmp_distant_labels = self._refactor_labels(sent, item['labels'], distant_labels, start_index) 151 | 152 | start_index += len(sent) 153 | 154 | examples.append(InputExample(set_type=set_type, 155 | text=sent, 156 | labels=labels, 157 | pseudo=pseudo, 158 | distant_labels=tmp_distant_labels)) 159 | 160 | return examples 161 | 162 | 163 | def fine_grade_tokenize(raw_text, tokenizer): 164 | """ 165 | BERT's WordPiece tokenizer can shift the offsets of sequence labels, 166 | so tokenize character by character: whitespace becomes [BLANK], characters the tokenizer drops become [INV] 167 | """ 168 | tokens = [] 169 | 170 | for _ch in raw_text: 171 | if _ch in [' ', '\t', '\n']: 172 | tokens.append('[BLANK]') 173 | else: 174 | if not len(tokenizer.tokenize(_ch)): 175 | tokens.append('[INV]') 176 | else: 177 | tokens.append(_ch) 178 | 179 | return tokens 180 | 181 | 182 | def cut_sentences_v1(sent): 183 | """ 184 | first-level split: break after sentence terminators and ellipses 185 | """ 186 | sent = re.sub(r'([。!?\?])([^”’])', r"\1\n\2", sent) # single-character terminators 187 | sent = re.sub(r'(\.{6})([^”’])', 
r"\1\n\2", sent) # ellipsis written as six ASCII dots 188 | sent = re.sub(r'(\…{2})([^”’])', r"\1\n\2", sent) # Chinese ellipsis 189 | sent = re.sub(r'([。!?\?][”’])([^,。!?\?])', r"\1\n\2", sent) 190 | # if a terminator precedes a closing quote, the quote ends the sentence, so the \n separator goes after the quote 
191 | return sent.split("\n") 192 | 193 | 194 | def cut_sentences_v2(sent): 195 | """ 196 | second-level split: break after ';' | ';' (applied to first-level pieces that are still too long) 197 | """ 198 | sent = re.sub('([;;])([^”’])', r"\1\n\2", sent) 199 | return sent.split("\n") 200 | 201 | 202 | def cut_sent(text, max_seq_len): 203 | # split the text into fine-grained sentences, then greedily re-merge neighbours while each merged piece stays within max_seq_len - 2 characters (room for [CLS] and [SEP]) 204 | sentences = [] 205 | 206 | # fine-grained split 207 | sentences_v1 = cut_sentences_v1(text) 208 | for sent_v1 in sentences_v1: 209 | if len(sent_v1) > max_seq_len - 2: 210 | sentences_v2 = cut_sentences_v2(sent_v1) 211 | sentences.extend(sentences_v2) 212 | else: 213 | sentences.append(sent_v1) 214 | 215 | assert ''.join(sentences) == text 216 | 217 | # merge 218 | merged_sentences = [] 219 | start_index_ = 0 220 | 221 | while start_index_ < len(sentences): 222 | tmp_text = sentences[start_index_] 223 | 224 | end_index_ = start_index_ + 1 225 | 226 | while end_index_ < len(sentences) and \ 227 | len(tmp_text) + len(sentences[end_index_]) <= max_seq_len - 2: 228 | tmp_text += sentences[end_index_] 229 | end_index_ += 1 230 | 231 | start_index_ = end_index_ 232 | 233 | merged_sentences.append(tmp_text) 234 | 235 | return merged_sentences 236 | 237 | def sent_mask(sent, stop_mask_range_list, mask_prob=0.15): 238 | """ 239 | Randomly replace tokens with '[MASK]': each token outside the protected ranges is masked with probability mask_prob, 240 | up to int(len(sent) * mask_prob) tokens in total; a selected token always becomes '[MASK]' (there is no 80/10/10 split). 
241 | :param sent: list of tokens 242 | :param stop_mask_range_list: inclusive (start, end) index ranges that must not be masked, e.g. entity spans 243 | :param mask_prob: per-token masking probability, also used to cap the number of masked tokens 244 | :return: a new list of tokens 245 | """ 246 | max_mask_token_nums = int(len(sent) * mask_prob) 247 | mask_nums = 0 248 | mask_sent = [] 249 | 250 | for i in range(len(sent)): 251 | 252 | flag = False 253 | for _stop_range in stop_mask_range_list: 254 | if _stop_range[0] <= i <= _stop_range[1]: 255 | flag = True 256 | break 257 | 258 | if flag: 259 | mask_sent.append(sent[i]) 260 | continue 261 | 262 | if mask_nums < max_mask_token_nums: 263 | # mask with probability mask_prob; a selected token always becomes [MASK] 264 | if random.random() < mask_prob: 265 |
mask_sent.append('[MASK]') 266 | mask_nums += 1 267 | else: 268 | mask_sent.append(sent[i]) 269 | else: 270 | mask_sent.append(sent[i]) 271 | 272 | return mask_sent 273 | 274 | 275 | def convert_crf_example(ex_idx, example: InputExample, tokenizer: BertTokenizer, 276 | max_seq_len, ent2id): 277 | set_type = example.set_type 278 | raw_text = example.text 279 | entities = example.labels 280 | pseudo = example.pseudo 281 | 282 | callback_info = (raw_text,) 283 | callback_labels = {x: [] for x in ENTITY_TYPES} 284 | 285 | for _label in entities: 286 | callback_labels[_label[0]].append((_label[1], _label[2])) 287 | 288 | callback_info += (callback_labels,) 289 | 290 | tokens = fine_grade_tokenize(raw_text, tokenizer) 291 | assert len(tokens) == len(raw_text) 292 | 293 | label_ids = None 294 | 295 | if set_type == 'train': 296 | # BIOES label ids are built only for training; dev/test are evaluated against callback_info 297 | label_ids = [0] * len(tokens) 298 | 299 | # each entity is (type, text, start offset in the sentence), e.g. ('DRUG_DOSAGE', '小蜜丸', start) 300 | for ent in entities: 301 | ent_type = ent[0] 302 | 303 | ent_start = ent[-1] 304 | ent_end = ent_start + len(ent[1]) - 1 305 | 306 | if ent_start == ent_end: 307 | label_ids[ent_start] = ent2id['S-' + ent_type] 308 | else: 309 | label_ids[ent_start] = ent2id['B-' + ent_type] 310 | label_ids[ent_end] = ent2id['E-' + ent_type] 311 | for i in range(ent_start + 1, ent_end): 312 | label_ids[i] = ent2id['I-' + ent_type] 313 | 314 | if len(label_ids) > max_seq_len - 2: 315 | label_ids = label_ids[:max_seq_len - 2] 316 | 317 | label_ids = [0] + label_ids + [0] 318 | 319 | # pad 320 | if len(label_ids) < max_seq_len: 321 | pad_length = max_seq_len - len(label_ids) 322 | label_ids = label_ids + [0] * pad_length # [CLS], [SEP] and padding positions are all labelled O 323 | 324 | assert len(label_ids) == max_seq_len, f'{len(label_ids)}' 325 | 326 | encode_dict = tokenizer.encode_plus(text=tokens, 327 | max_length=max_seq_len, 328 | 
pad_to_max_length=True, 329 | is_pretokenized=True, 330 | return_token_type_ids=True, 331 | return_attention_mask=True) 332
| 333 | token_ids = encode_dict['input_ids'] 334 | attention_masks = encode_dict['attention_mask'] 335 | token_type_ids = encode_dict['token_type_ids'] 336 | 337 | # if ex_idx < 3: 338 | # logger.info(f"*** {set_type}_example-{ex_idx} ***") 339 | # logger.info(f'text: {" ".join(tokens)}') 340 | # logger.info(f"token_ids: {token_ids}") 341 | # logger.info(f"attention_masks: {attention_masks}") 342 | # logger.info(f"token_type_ids: {token_type_ids}") 343 | # logger.info(f"labels: {label_ids}") 344 | 345 | feature = CRFFeature( 346 | # bert inputs 347 | token_ids=token_ids, 348 | attention_masks=attention_masks, 349 | token_type_ids=token_type_ids, 350 | labels=label_ids, 351 | pseudo=pseudo 352 | ) 353 | 354 | return feature, callback_info 355 | 356 | 357 | def convert_span_example(ex_idx, example: InputExample, tokenizer: BertTokenizer, 358 | max_seq_len, ent2id): 359 | set_type = example.set_type 360 | raw_text = example.text 361 | entities = example.labels 362 | pseudo = example.pseudo 363 | 364 | tokens = fine_grade_tokenize(raw_text, tokenizer) 365 | assert len(tokens) == len(raw_text) 366 | 367 | callback_labels = {x: [] for x in ENTITY_TYPES} 368 | 369 | for _label in entities: 370 | callback_labels[_label[0]].append((_label[1], _label[2])) 371 | 372 | callback_info = (raw_text, callback_labels,) 373 | 374 | start_ids, end_ids = None, None 375 | 376 | if set_type == 'train': 377 | start_ids = [0] * len(tokens) 378 | end_ids = [0] * len(tokens) 379 | 380 | for _ent in entities: 381 | 382 | ent_type = ent2id[_ent[0]] 383 | ent_start = _ent[-1] 384 | ent_end = ent_start + len(_ent[1]) - 1 385 | 386 | start_ids[ent_start] = ent_type 387 | end_ids[ent_end] = ent_type 388 | 389 | if len(start_ids) > max_seq_len - 2: 390 | start_ids = start_ids[:max_seq_len - 2] 391 | end_ids = end_ids[:max_seq_len - 2] 392 | 393 | start_ids = [0] + start_ids + [0] 394 | end_ids = [0] + end_ids + [0] 395 | 396 | # pad 397 | if len(start_ids) < max_seq_len: 398 | pad_length = 
max_seq_len - len(start_ids) 399 | 400 | start_ids = start_ids + [0] * pad_length # [CLS], [SEP] and padding positions are all labelled O 401 | end_ids = end_ids + [0] * pad_length 402 | 403 | assert len(start_ids) == max_seq_len 404 | assert len(end_ids) == max_seq_len 405 | 406 | encode_dict = tokenizer.encode_plus(text=tokens, 407 | max_length=max_seq_len, 408 | pad_to_max_length=True, 409 | is_pretokenized=True, 410 | return_token_type_ids=True, 411 | return_attention_mask=True) 412 | 413 | token_ids = encode_dict['input_ids'] 414 | attention_masks = encode_dict['attention_mask'] 415 | token_type_ids = encode_dict['token_type_ids'] 416 | 417 | # if ex_idx < 3: 418 | # logger.info(f"*** {set_type}_example-{ex_idx} ***") 419 | # logger.info(f'text: {" ".join(tokens)}') 420 | # logger.info(f"token_ids: {token_ids}") 421 | # logger.info(f"attention_masks: {attention_masks}") 422 | # logger.info(f"token_type_ids: {token_type_ids}") 423 | # if start_ids and end_ids: 424 | # logger.info(f"start_ids: {start_ids}") 425 | # logger.info(f"end_ids: {end_ids}") 426 | 427 | feature = SpanFeature(token_ids=token_ids, 428 | attention_masks=attention_masks, 429 | token_type_ids=token_type_ids, 430 | start_ids=start_ids, 431 | end_ids=end_ids, 432 | pseudo=pseudo) 433 | 434 | return feature, callback_info 435 | 436 | def convert_mrc_example(ex_idx, example: InputExample, tokenizer: BertTokenizer, 437 | max_seq_len, ent2id, ent2query, mask_prob=None): 438 | set_type = example.set_type 439 | text_b = example.text 440 | entities = example.labels 441 | pseudo = example.pseudo 442 | 443 | features = [] 444 | callback_info = [] 445 | 446 | tokens_b = fine_grade_tokenize(text_b, tokenizer) 447 | assert len(tokens_b) == len(text_b) 448 | 449 | label_dict = defaultdict(list) 450 | 451 | for ent in entities: 452 | ent_type = ent[0] 453 | ent_start = ent[-1] 454 | ent_end = ent_start + len(ent[1]) - 1 455 | label_dict[ent_type].append((ent_start, ent_end, ent[1])) 456 | 457 | # training data: build labelled examples 458 | if 
set_type == 'train': 459 |
460 | # one example per entity type; types with no entity in this sentence get all-zero targets 461 | # for _type in label_dict.keys(): 462 | for _type in ENTITY_TYPES: 463 | start_ids = [0] * len(tokens_b) 464 | end_ids = [0] * len(tokens_b) 465 | 466 | stop_mask_ranges = [] 467 | 468 | text_a = ent2query[_type] 469 | tokens_a = fine_grade_tokenize(text_a, tokenizer) 470 | 471 | for _label in label_dict[_type]: 472 | start_ids[_label[0]] = 1 473 | end_ids[_label[1]] = 1 474 | 475 | stop_mask_ranges.append((_label[0], _label[1])) 476 | 477 | if len(start_ids) > max_seq_len - len(tokens_a) - 3: 478 | start_ids = start_ids[:max_seq_len - len(tokens_a) - 3] 479 | end_ids = end_ids[:max_seq_len - len(tokens_a) - 3] 480 | logger.warning('unexpected truncation of the context in an MRC training example') 481 | 482 | start_ids = [0] + [0] * len(tokens_a) + [0] + start_ids + [0] 483 | end_ids = [0] + [0] * len(tokens_a) + [0] + end_ids + [0] 484 | 485 | # pad 486 | if len(start_ids) < max_seq_len: 487 | pad_length = max_seq_len - len(start_ids) 488 | 489 | start_ids = start_ids + [0] * pad_length # [CLS], [SEP] and padding positions are all labelled O 490 | end_ids = end_ids + [0] * pad_length 491 | 492 | assert len(start_ids) == max_seq_len 493 | assert len(end_ids) == max_seq_len 494 | 495 | # random masking that protects this type's entity spans; mask a fresh copy per type so masks do not accumulate across types 496 | masked_tokens_b = (sent_mask(tokens_b, stop_mask_ranges, mask_prob=mask_prob) 497 | if mask_prob else tokens_b) 498 | 499 | encode_dict = tokenizer.encode_plus(text=tokens_a, 500 | text_pair=masked_tokens_b, 501 | max_length=max_seq_len, 502 | pad_to_max_length=True, 503 | truncation_strategy='only_second', 504 | is_pretokenized=True, 505 | return_token_type_ids=True, 506 | return_attention_mask=True) 507 | 508 | token_ids = encode_dict['input_ids'] 509 | attention_masks = encode_dict['attention_mask'] 510 | token_type_ids = encode_dict['token_type_ids'] 511 | 512 | # if ex_idx < 3: 
513 | # logger.info(f"*** {set_type}_example-{ex_idx} ***") 514 | # logger.info(f'text: {" ".join(tokens_b)}') 515 | # logger.info(f"token_ids: {token_ids}") 516 | # logger.info(f"attention_masks: {attention_masks}") 517 | # logger.info(f"token_type_ids: {token_type_ids}") 518 | #
logger.info(f'entity type: {_type}') 519 | # logger.info(f"start_ids: {start_ids}") 520 | # logger.info(f"end_ids: {end_ids}") 521 | 522 | feature = MRCFeature(token_ids=token_ids, 523 | attention_masks=attention_masks, 524 | token_type_ids=token_type_ids, 525 | ent_type=ent2id[_type], 526 | start_ids=start_ids, 527 | end_ids=end_ids, 528 | pseudo=pseudo 529 | ) 530 | 531 | features.append(feature) 532 | 533 | # dev/test data: also one example per entity type, plus callback info for evaluation 534 | else: 535 | for _type in ENTITY_TYPES: 536 | text_a = ent2query[_type] 537 | tokens_a = fine_grade_tokenize(text_a, tokenizer) 538 | 539 | encode_dict = tokenizer.encode_plus(text=tokens_a, 540 | text_pair=tokens_b, 541 | max_length=max_seq_len, 542 | pad_to_max_length=True, 543 | truncation_strategy='only_second', 544 | is_pretokenized=True, 545 | return_token_type_ids=True, 546 | return_attention_mask=True) 547 | 548 | token_ids = encode_dict['input_ids'] 549 | attention_masks = encode_dict['attention_mask'] 550 | token_type_ids = encode_dict['token_type_ids'] 551 | 552 | tmp_callback = (text_b, len(tokens_a) + 2, _type) # (text, offset of text_b in the input ids = [CLS] + query + [SEP], type); gold labels are appended below 553 | tmp_callback_labels = [] 554 | 555 | for _label in label_dict[_type]: 556 | tmp_callback_labels.append((_label[2], _label[0])) 557 | 558 | tmp_callback += (tmp_callback_labels, ) 559 | 560 | callback_info.append(tmp_callback) 561 | 562 | feature = MRCFeature(token_ids=token_ids, 563 | attention_masks=attention_masks, 564 | token_type_ids=token_type_ids, 565 | ent_type=ent2id[_type]) 566 | 567 | features.append(feature) 568 | 569 | return features, callback_info 570 | 571 | 572 | 573 | def convert_examples_to_features(task_type, examples, max_seq_len, bert_dir, ent2id): 574 | assert task_type in ['crf', 'span', 'mrc'] 575 | 576 | tokenizer = BertTokenizer(os.path.join(bert_dir, 'vocab.txt')) 577 | 578 | features = [] 579 | 580 | callback_info = [] 581 | 582 | logger.info(f'Convert 
{len(examples)} examples to features') 583 | type2id = {x: i for i, x in 
enumerate(ENTITY_TYPES)} 584 | 585 | for i, example in enumerate(examples): 586 | if task_type == 'crf': 587 | feature, tmp_callback = convert_crf_example( 588 | ex_idx=i, 589 | example=example, 590 | max_seq_len=max_seq_len, 591 | ent2id=ent2id, 592 | tokenizer=tokenizer 593 | ) 594 | elif task_type == 'mrc': 595 | feature, tmp_callback = convert_mrc_example( 596 | ex_idx=i, 597 | example=example, 598 | max_seq_len=max_seq_len, 599 | ent2id=type2id, 600 | ent2query=ent2id, # for 'mrc', ent2id holds the entity type -> query mapping stored in mrc_ent2id.json 601 | tokenizer=tokenizer 602 | ) 603 | else: 604 | feature, tmp_callback = convert_span_example( 605 | ex_idx=i, 606 | example=example, 607 | max_seq_len=max_seq_len, 608 | ent2id=ent2id, 609 | tokenizer=tokenizer 610 | ) 611 | 612 | if feature is None: 613 | continue 614 | 615 | if task_type == 'mrc': 616 | features.extend(feature) 617 | callback_info.extend(tmp_callback) 618 | else: 619 | features.append(feature) 620 | callback_info.append(tmp_callback) 621 | 622 | logger.info(f'Build {len(features)} features') 623 | 624 | out = (features, ) 625 | 626 | if not len(callback_info): 627 | return out 628 | 629 | type_weight = {} # per-type share of gold entities, used as weights when computing the overall (micro) F1 630 | for _type in ENTITY_TYPES: 631 | type_weight[_type] = 0. 632 | 633 | count = 0.
634 | 635 | if task_type == 'mrc': 636 | for _callback in callback_info: 637 | type_weight[_callback[-2]] += len(_callback[-1]) 638 | count += len(_callback[-1]) 639 | else: 640 | for _callback in callback_info: 641 | for _type in _callback[1]: 642 | type_weight[_type] += len(_callback[1][_type]) 643 | count += len(_callback[1][_type]) 644 | 645 | for key in type_weight: 646 | type_weight[key] /= count 647 | 648 | out += ((callback_info, type_weight), ) 649 | 650 | return out 651 | 652 | 653 | if __name__ == '__main__': 654 | pass 655 | -------------------------------------------------------------------------------- /src/utils/__pycache__/attack_train_utils.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chizhu/DeepNER/8c4abc21676af50ede29dce90bfac4892b36a1c5/src/utils/__pycache__/attack_train_utils.cpython-36.pyc -------------------------------------------------------------------------------- /src/utils/__pycache__/dataset_utils.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chizhu/DeepNER/8c4abc21676af50ede29dce90bfac4892b36a1c5/src/utils/__pycache__/dataset_utils.cpython-36.pyc -------------------------------------------------------------------------------- /src/utils/__pycache__/evaluator.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chizhu/DeepNER/8c4abc21676af50ede29dce90bfac4892b36a1c5/src/utils/__pycache__/evaluator.cpython-36.pyc -------------------------------------------------------------------------------- /src/utils/__pycache__/functions_utils.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chizhu/DeepNER/8c4abc21676af50ede29dce90bfac4892b36a1c5/src/utils/__pycache__/functions_utils.cpython-36.pyc 
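The BIOES tagging in `convert_crf_example` (processor.py above) and its inverse, `crf_decode` in `evaluator.py`, can be checked with a small round trip. This is a minimal standalone sketch, not the repository's code. `build_crf_ent2id`, `encode_bioes` and `decode_bioes` are hypothetical helpers, and the two-type label set is an illustrative assumption. The id layout mirrors `convert_raw_data.py`: 'O' comes first, then the B/I/E/S prefixes for each type. The [CLS]/[SEP] positions that the real code adds and strips are omitted.

```python
from typing import Dict, List, Tuple

Entity = Tuple[str, str, int]  # (type, entity text, start offset), as produced by _refactor_labels


def build_crf_ent2id(ent_types: List[str]) -> Dict[str, int]:
    # 'O' gets id 0, then B/I/E/S tags for every type (same layout as convert_raw_data.py)
    tags = ['O'] + [p + '-' + t for p in ['B', 'I', 'E', 'S'] for t in ent_types]
    return {tag: i for i, tag in enumerate(tags)}


def encode_bioes(text_len: int, entities: List[Entity], ent2id: Dict[str, int]) -> List[int]:
    label_ids = [ent2id['O']] * text_len
    for ent_type, ent_text, start in entities:
        end = start + len(ent_text) - 1
        if start == end:  # single-character entity
            label_ids[start] = ent2id['S-' + ent_type]
        else:
            label_ids[start] = ent2id['B-' + ent_type]
            label_ids[end] = ent2id['E-' + ent_type]
            for i in range(start + 1, end):
                label_ids[i] = ent2id['I-' + ent_type]
    return label_ids


def decode_bioes(label_ids: List[int], text: str,
                 id2ent: Dict[int, str]) -> Dict[str, List[Tuple[str, int]]]:
    # Returns {type: [(entity text, start offset), ...]}; a B- run that is not
    # closed by an E- of the same type is dropped, as in crf_decode.
    entities: Dict[str, List[Tuple[str, int]]] = {}
    i = 0
    while i < len(label_ids):
        tag = id2ent[label_ids[i]]
        if tag.startswith('S-'):
            entities.setdefault(tag[2:], []).append((text[i], i))
            i += 1
        elif tag.startswith('B-'):
            ent_type, start = tag[2:], i
            i += 1
            while i < len(label_ids):
                nxt = id2ent[label_ids[i]]
                if nxt == 'I-' + ent_type:
                    i += 1
                elif nxt == 'E-' + ent_type:
                    entities.setdefault(ent_type, []).append((text[start:i + 1], start))
                    i += 1
                    break
                else:
                    break  # leave i here so the outer loop re-examines this tag
        else:
            i += 1
    return entities


text = '阿司匹林可缓解头痛'
ent2id = build_crf_ent2id(['DRUG', 'SYMPTOM'])
id2ent = {i: tag for tag, i in ent2id.items()}
ids = encode_bioes(len(text), [('DRUG', '阿司匹林', 0), ('SYMPTOM', '头痛', 7)], ent2id)
print(ids)                               # [1, 3, 3, 5, 0, 0, 0, 2, 6]
print(decode_bioes(ids, text, id2ent))   # {'DRUG': [('阿司匹林', 0)], 'SYMPTOM': [('头痛', 7)]}
```

For non-overlapping entities the round trip recovers the gold labels exactly, which makes it a cheap sanity check whenever the tag layout changes. Nested or overlapping entities cannot survive it, because a later entity overwrites earlier tags, which is also what `convert_crf_example` does.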
-------------------------------------------------------------------------------- /src/utils/__pycache__/model_utils.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chizhu/DeepNER/8c4abc21676af50ede29dce90bfac4892b36a1c5/src/utils/__pycache__/model_utils.cpython-36.pyc -------------------------------------------------------------------------------- /src/utils/__pycache__/options.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chizhu/DeepNER/8c4abc21676af50ede29dce90bfac4892b36a1c5/src/utils/__pycache__/options.cpython-36.pyc -------------------------------------------------------------------------------- /src/utils/__pycache__/trainer.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chizhu/DeepNER/8c4abc21676af50ede29dce90bfac4892b36a1c5/src/utils/__pycache__/trainer.cpython-36.pyc -------------------------------------------------------------------------------- /src/utils/attack_train_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | # FGM 6 | class FGM: 7 | def __init__(self, model: nn.Module, eps=1.): 8 | self.model = ( 9 | model.module if hasattr(model, "module") else model 10 | ) 11 | self.eps = eps 12 | self.backup = {} 13 | 14 | # only attack word embedding 15 | def attack(self, emb_name='word_embeddings'): 16 | for name, param in self.model.named_parameters(): 17 | if param.requires_grad and emb_name in name: 18 | self.backup[name] = param.data.clone() 19 | norm = torch.norm(param.grad) 20 | if norm and not torch.isnan(norm): 21 | r_at = self.eps * param.grad / norm 22 | param.data.add_(r_at) 23 | 24 | def restore(self, emb_name='word_embeddings'): 25 | for name, para in self.model.named_parameters(): 26 | if 
para.requires_grad and emb_name in name: 27 | assert name in self.backup 28 | para.data = self.backup[name] 29 | 30 | self.backup = {} 31 | 32 | 33 | # PGD 34 | class PGD: 35 | def __init__(self, model, eps=1., alpha=0.3): 36 | self.model = ( 37 | model.module if hasattr(model, "module") else model 38 | ) 39 | self.eps = eps 40 | self.alpha = alpha 41 | self.emb_backup = {} 42 | self.grad_backup = {} 43 | 44 | def attack(self, emb_name='word_embeddings', is_first_attack=False): 45 | for name, param in self.model.named_parameters(): 46 | if param.requires_grad and emb_name in name: 47 | if is_first_attack: 48 | self.emb_backup[name] = param.data.clone() 49 | norm = torch.norm(param.grad) 50 | if norm != 0 and not torch.isnan(norm): 51 | r_at = self.alpha * param.grad / norm 52 | param.data.add_(r_at) 53 | param.data = self.project(name, param.data) 54 | 55 | def restore(self, emb_name='word_embeddings'): 56 | for name, param in self.model.named_parameters(): 57 | if param.requires_grad and emb_name in name: 58 | assert name in self.emb_backup 59 | param.data = self.emb_backup[name] 60 | self.emb_backup = {} 61 | 62 | def project(self, param_name, param_data): 63 | r = param_data - self.emb_backup[param_name] 64 | if torch.norm(r) > self.eps: 65 | r = self.eps * r / torch.norm(r) 66 | return self.emb_backup[param_name] + r 67 | 68 | def backup_grad(self): 69 | for name, param in self.model.named_parameters(): 70 | if param.requires_grad and param.grad is not None: 71 | self.grad_backup[name] = param.grad.clone() 72 | 73 | def restore_grad(self): 74 | for name, param in self.model.named_parameters(): 75 | if param.requires_grad and param.grad is not None: 76 | param.grad = self.grad_backup[name] 77 | -------------------------------------------------------------------------------- /src/utils/dataset_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.data import Dataset 3 | 4 | 5 | class 
NERDataset(Dataset): 6 | def __init__(self, task_type, features, mode, **kwargs): 7 | 8 | self.nums = len(features) 9 | 10 | self.token_ids = [torch.tensor(example.token_ids).long() for example in features] 11 | self.attention_masks = [torch.tensor(example.attention_masks).float() for example in features] 12 | self.token_type_ids = [torch.tensor(example.token_type_ids).long() for example in features] 13 | 14 | self.labels = None 15 | self.start_ids, self.end_ids = None, None 16 | self.ent_type = None 17 | self.pseudo = None 18 | if mode == 'train': 19 | self.pseudo = [torch.tensor(example.pseudo).long() for example in features] 20 | if task_type == 'crf': 21 | self.labels = [torch.tensor(example.labels) for example in features] 22 | else: 23 | self.start_ids = [torch.tensor(example.start_ids).long() for example in features] 24 | self.end_ids = [torch.tensor(example.end_ids).long() for example in features] 25 | 26 | if kwargs.pop('use_type_embed', False): 27 | self.ent_type = [torch.tensor(example.ent_type) for example in features] 28 | 29 | def __len__(self): 30 | return self.nums 31 | 32 | def __getitem__(self, index): 33 | data = {'token_ids': self.token_ids[index], 34 | 'attention_masks': self.attention_masks[index], 35 | 'token_type_ids': self.token_type_ids[index]} 36 | 37 | if self.ent_type is not None: 38 | data['ent_type'] = self.ent_type[index] 39 | 40 | if self.labels is not None: 41 | data['labels'] = self.labels[index] 42 | 43 | if self.pseudo is not None: 44 | data['pseudo'] = self.pseudo[index] 45 | 46 | if self.start_ids is not None: 47 | data['start_ids'] = self.start_ids[index] 48 | data['end_ids'] = self.end_ids[index] 49 | 50 | return data 51 | 52 | 53 | 54 | -------------------------------------------------------------------------------- /src/utils/evaluator.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import logging 3 | import numpy as np 4 | from collections import defaultdict 5 | 
from src.preprocess.processor import ENTITY_TYPES 6 | 7 | logger = logging.getLogger(__name__) 8 | 9 | 10 | def get_base_out(model, loader, device): 11 | """ 12 | Shared forward pass for all task types: run the model in eval mode without gradients and yield the output of each batch 13 | """ 14 | model.eval() 15 | 16 | with torch.no_grad(): 17 | for idx, _batch in enumerate(loader): 18 | 19 | for key in _batch.keys(): 20 | _batch[key] = _batch[key].to(device) 21 | 22 | tmp_out = model(**_batch) 23 | 24 | yield tmp_out 25 | 26 | 27 | def crf_decode(decode_tokens, raw_text, id2ent): 28 | """ 29 | CRF decoding: convert a predicted BIOES tag sequence into {entity_type: [(entity_text, start_index), ...]} 30 | """ 31 | predict_entities = {} 32 | 33 | decode_tokens = decode_tokens[1:-1] # drop the [CLS] and [SEP] positions 34 | 35 | index_ = 0 36 | 37 | while index_ < len(decode_tokens): 38 | 39 | token_label = id2ent[decode_tokens[index_]].split('-') 40 | 41 | if token_label[0].startswith('S'): 42 | token_type = token_label[1] 43 | tmp_ent = raw_text[index_] 44 | 45 | if token_type not in predict_entities: 46 | predict_entities[token_type] = [(tmp_ent, index_)] 47 | else: 48 | predict_entities[token_type].append((tmp_ent, int(index_))) 49 | 50 | index_ += 1 51 | 52 | elif token_label[0].startswith('B'): 53 | token_type = token_label[1] 54 | start_index = index_ 55 | 56 | index_ += 1 57 | while index_ < len(decode_tokens): 58 | temp_token_label = id2ent[decode_tokens[index_]].split('-') 59 | 60 | if temp_token_label[0].startswith('I') and token_type == temp_token_label[1]: 61 | index_ += 1 62 | elif temp_token_label[0].startswith('E') and token_type == temp_token_label[1]: 63 | end_index = index_ 64 | index_ += 1 65 | 66 | tmp_ent = raw_text[start_index: end_index + 1] 67 | 68 | if token_type not in predict_entities: 69 | predict_entities[token_type] = [(tmp_ent, start_index)] 70 | else: 71 | predict_entities[token_type].append((tmp_ent, int(start_index))) 72 | 73 | break 74 | else: 75 | break 76 | else: 77 | index_ += 1 78 | 79 | return predict_entities 80 | 81 | 82 | # strict decoding baseline: pair each start with the nearest following end of the same type 83 | def span_decode(start_logits, end_logits, raw_text, id2ent): 84 |
predict_entities = defaultdict(list) 85 | 86 | start_pred = np.argmax(start_logits, -1) 87 | end_pred = np.argmax(end_logits, -1) 88 | 89 | for i, s_type in enumerate(start_pred): 90 | if s_type == 0: 91 | continue 92 | for j, e_type in enumerate(end_pred[i:]): 93 | if s_type == e_type: 94 | tmp_ent = raw_text[i:i + j + 1] 95 | predict_entities[id2ent[s_type]].append((tmp_ent, i)) 96 | break 97 | 98 | return predict_entities 99 | 100 | # strict decoding baseline for MRC: labels are binary (1 = boundary), each start pairs with the nearest following end 101 | def mrc_decode(start_logits, end_logits, raw_text): 102 | predict_entities = [] 103 | start_pred = np.argmax(start_logits, -1) 104 | end_pred = np.argmax(end_logits, -1) 105 | 106 | for i, s_type in enumerate(start_pred): 107 | if s_type == 0: 108 | continue 109 | for j, e_type in enumerate(end_pred[i:]): 110 | if s_type == e_type: 111 | tmp_ent = raw_text[i:i+j+1] 112 | predict_entities.append((tmp_ent, i)) 113 | break 114 | 115 | return predict_entities 116 | 117 | 118 | def calculate_metric(gt, predict): 119 | """ 120 | Count true positives, false positives and false negatives; a prediction is correct when both its text and start index match a gold entity 121 | """ 122 | tp, fp, fn = 0, 0, 0 123 | for entity_predict in predict: 124 | flag = 0 125 | for entity_gt in gt: 126 | if entity_predict[0] == entity_gt[0] and entity_predict[1] == entity_gt[1]: 127 | flag = 1 128 | tp += 1 129 | break 130 | if flag == 0: 131 | fp += 1 132 | 133 | fn = len(gt) - tp 134 | 135 | return np.array([tp, fp, fn]) 136 | 137 | 138 | def get_p_r_f(tp, fp, fn): 139 | p = tp / (tp + fp) if tp + fp != 0 else 0 140 | r = tp / (tp + fn) if tp + fn != 0 else 0 141 | f1 = 2 * p * r / (p + r) if p + r != 0 else 0 142 | return np.array([p, r, f1]) 143 | 144 | 145 | def crf_evaluation(model, dev_info, device, ent2id): 146 | dev_loader, (dev_callback_info, type_weight) = dev_info 147 | 148 | pred_tokens = [] 149 | 150 | for tmp_pred in get_base_out(model, dev_loader, device): 151 | pred_tokens.extend(tmp_pred[0]) 152 | 153 | assert len(pred_tokens) == len(dev_callback_info) 154 | 155 | id2ent = {ent2id[key]: key for key in ent2id.keys()} 156 | 157 | role_metric =
np.zeros([13, 3]) 158 | 159 | mirco_metrics = np.zeros(3) 160 | 161 | for tmp_tokens, tmp_callback in zip(pred_tokens, dev_callback_info): 162 | 163 | text, gt_entities = tmp_callback 164 | 165 | tmp_metric = np.zeros([13, 3]) 166 | 167 | pred_entities = crf_decode(tmp_tokens, text, id2ent) 168 | 169 | for idx, _type in enumerate(ENTITY_TYPES): 170 | if _type not in pred_entities: 171 | pred_entities[_type] = [] 172 | 173 | tmp_metric[idx] += calculate_metric(gt_entities[_type], pred_entities[_type]) 174 | 175 | role_metric += tmp_metric 176 | 177 | for idx, _type in enumerate(ENTITY_TYPES): 178 | temp_metric = get_p_r_f(role_metric[idx][0], role_metric[idx][1], role_metric[idx][2]) 179 | 180 | mirco_metrics += temp_metric * type_weight[_type] 181 | 182 | metric_str = f'[MIRCO] precision: {mirco_metrics[0]:.4f}, ' \ 183 | f'recall: {mirco_metrics[1]:.4f}, f1: {mirco_metrics[2]:.4f}' 184 | 185 | return metric_str, mirco_metrics[2] 186 | 187 | 188 | def span_evaluation(model, dev_info, device, ent2id): 189 | dev_loader, (dev_callback_info, type_weight) = dev_info 190 | 191 | start_logits, end_logits = None, None 192 | 193 | model.eval() 194 | 195 | for tmp_pred in get_base_out(model, dev_loader, device): 196 | tmp_start_logits = tmp_pred[0].cpu().numpy() 197 | tmp_end_logits = tmp_pred[1].cpu().numpy() 198 | 199 | if start_logits is None: 200 | start_logits = tmp_start_logits 201 | end_logits = tmp_end_logits 202 | else: 203 | start_logits = np.append(start_logits, tmp_start_logits, axis=0) 204 | end_logits = np.append(end_logits, tmp_end_logits, axis=0) 205 | 206 | assert len(start_logits) == len(end_logits) == len(dev_callback_info) 207 | 208 | role_metric = np.zeros([13, 3]) 209 | 210 | mirco_metrics = np.zeros(3) 211 | 212 | id2ent = {ent2id[key]: key for key in ent2id.keys()} 213 | 214 | for tmp_start_logits, tmp_end_logits, tmp_callback \ 215 | in zip(start_logits, end_logits, dev_callback_info): 216 | 217 | text, gt_entities = tmp_callback 218 | 219 | 
tmp_start_logits = tmp_start_logits[1:1 + len(text)] 220 | tmp_end_logits = tmp_end_logits[1:1 + len(text)] 221 | 222 | pred_entities = span_decode(tmp_start_logits, tmp_end_logits, text, id2ent) 223 | 224 | for idx, _type in enumerate(ENTITY_TYPES): 225 | if _type not in pred_entities: 226 | pred_entities[_type] = [] 227 | 228 | role_metric[idx] += calculate_metric(gt_entities[_type], pred_entities[_type]) 229 | 230 | for idx, _type in enumerate(ENTITY_TYPES): 231 | temp_metric = get_p_r_f(role_metric[idx][0], role_metric[idx][1], role_metric[idx][2]) 232 | 233 | mirco_metrics += temp_metric * type_weight[_type] 234 | 235 | metric_str = f'[MIRCO] precision: {mirco_metrics[0]:.4f}, ' \ 236 | f'recall: {mirco_metrics[1]:.4f}, f1: {mirco_metrics[2]:.4f}' 237 | 238 | return metric_str, mirco_metrics[2] 239 | 240 | def mrc_evaluation(model, dev_info, device): 241 | dev_loader, (dev_callback_info, type_weight) = dev_info 242 | 243 | start_logits, end_logits = None, None 244 | 245 | model.eval() 246 | 247 | for tmp_pred in get_base_out(model, dev_loader, device): 248 | tmp_start_logits = tmp_pred[0].cpu().numpy() 249 | tmp_end_logits = tmp_pred[1].cpu().numpy() 250 | 251 | if start_logits is None: 252 | start_logits = tmp_start_logits 253 | end_logits = tmp_end_logits 254 | else: 255 | start_logits = np.append(start_logits, tmp_start_logits, axis=0) 256 | end_logits = np.append(end_logits, tmp_end_logits, axis=0) 257 | 258 | assert len(start_logits) == len(end_logits) == len(dev_callback_info) 259 | 260 | role_metric = np.zeros([13, 3]) 261 | 262 | mirco_metrics = np.zeros(3) 263 | 264 | id2ent = {x: i for i, x in enumerate(ENTITY_TYPES)} 265 | 266 | for tmp_start_logits, tmp_end_logits, tmp_callback \ 267 | in zip(start_logits, end_logits, dev_callback_info): 268 | 269 | text, text_offset, ent_type, gt_entities = tmp_callback 270 | 271 | tmp_start_logits = tmp_start_logits[text_offset:text_offset+len(text)] 272 | tmp_end_logits = 
tmp_end_logits[text_offset:text_offset+len(text)] 273 | 274 | pred_entities = mrc_decode(tmp_start_logits, tmp_end_logits, text) 275 | 276 | role_metric[id2ent[ent_type]] += calculate_metric(gt_entities, pred_entities) 277 | 278 | for idx, _type in enumerate(ENTITY_TYPES): 279 | temp_metric = get_p_r_f(role_metric[idx][0], role_metric[idx][1], role_metric[idx][2]) 280 | 281 | mirco_metrics += temp_metric * type_weight[_type] 282 | 283 | metric_str = f'[MIRCO] precision: {mirco_metrics[0]:.4f}, ' \ 284 | f'recall: {mirco_metrics[1]:.4f}, f1: {mirco_metrics[2]:.4f}' 285 | 286 | return metric_str, mirco_metrics[2] 287 | -------------------------------------------------------------------------------- /src/utils/functions_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import copy 3 | import torch 4 | import random 5 | import numpy as np 6 | from collections import defaultdict 7 | from datetime import timedelta 8 | import time 9 | import logging 10 | 11 | logger = logging.getLogger(__name__) 12 | 13 | 14 | def get_time_dif(start_time): 15 | """ 16 | Return the time elapsed since start_time, rounded to whole seconds 17 | :param start_time: timestamp taken with time.time() 18 | :return: datetime.timedelta 19 | """ 20 | end_time = time.time() 21 | time_dif = end_time - start_time 22 | return timedelta(seconds=int(round(time_dif))) 23 | 24 | 25 | def set_seed(seed): 26 | """ 27 | Seed Python, NumPy and PyTorch (CPU and all GPUs) for reproducibility 28 | :param seed: integer random seed 29 | :return: None 30 | """ 31 | random.seed(seed) 32 | torch.manual_seed(seed) 33 | np.random.seed(seed) 34 | torch.cuda.manual_seed_all(seed) 35 | 36 | 37 | def load_model_and_parallel(model, gpu_ids, ckpt_path=None, strict=True): 38 | """ 39 | Load the model (optionally from a checkpoint) and place it on the GPU: a single card, or several cards via DataParallel; gpu_ids='-1' selects the CPU 40 | """ 41 | gpu_ids = gpu_ids.split(',') 42 | 43 | # set the device to the first listed GPU (or the CPU when gpu_ids is '-1') 44 | device = torch.device("cpu" if gpu_ids[0] == '-1' else "cuda:" + gpu_ids[0]) 45 | 46 | if ckpt_path is not None: 47 | logger.info(f'Load ckpt from {ckpt_path}') 48 | model.load_state_dict(torch.load(ckpt_path, map_location=torch.device('cpu')), strict=strict) 49 | 50
| model.to(device) 51 | 52 | if len(gpu_ids) > 1: 53 | logger.info(f'Use multi gpus in: {gpu_ids}') 54 | gpu_ids = [int(x) for x in gpu_ids] 55 | model = torch.nn.DataParallel(model, device_ids=gpu_ids) 56 | else: 57 | logger.info(f'Use single gpu in: {gpu_ids}') 58 | 59 | return model, device 60 | 61 | 62 | def get_model_path_list(base_dir): 63 | """ 64 | Collect the paths of all model.pt files under base_dir, sorted by run directory and then by checkpoint step (checkpoint-<step>) 65 | """ 66 | model_lists = [] 67 | 68 | for root, dirs, files in os.walk(base_dir): 69 | for _file in files: 70 | if 'model.pt' == _file: 71 | model_lists.append(os.path.join(root, _file)) 72 | 73 | model_lists = sorted(model_lists, 74 | key=lambda x: (x.split('/')[-3], int(x.split('/')[-2].split('-')[-1]))) 75 | 76 | return model_lists 77 | 78 | 79 | def swa(model, model_dir, swa_start=1): 80 | """ 81 | Stochastic Weight Averaging (SWA): average the parameters of the checkpoints from swa_start onward; usually applied once training has stabilized 82 | """ 83 | model_path_list = get_model_path_list(model_dir) 84 | 85 | assert 1 <= swa_start < len(model_path_list) - 1, \ 86 | f'When using SWA, swa_start must be at least 1 and smaller than {len(model_path_list) - 1}' 87 | 88 | swa_model = copy.deepcopy(model) 89 | swa_n = 0. 90 | 91 | with torch.no_grad(): 92 | for _ckpt in model_path_list[swa_start:]: 93 | logger.info(f'Load model from {_ckpt}') 94 | model.load_state_dict(torch.load(_ckpt, map_location=torch.device('cpu'))) 95 | tmp_para_dict = dict(model.named_parameters()) 96 | 97 | alpha = 1. / (swa_n + 1.) 98 | 99 | for name, para in swa_model.named_parameters(): 100 | para.copy_(tmp_para_dict[name].data.clone() * alpha + para.data.clone() * (1.
- alpha)) 101 | 102 | swa_n += 1 103 | 104 | # save under step 100000 so the SWA checkpoint cannot clash with a real checkpoint directory 105 | swa_model_dir = os.path.join(model_dir, f'checkpoint-100000') 106 | if not os.path.exists(swa_model_dir): 107 | os.mkdir(swa_model_dir) 108 | 109 | logger.info(f'Save swa model in: {swa_model_dir}') 110 | 111 | swa_model_path = os.path.join(swa_model_dir, 'model.pt') 112 | 113 | torch.save(swa_model.state_dict(), swa_model_path) 114 | 115 | return swa_model 116 | 117 | 118 | def vote(entities_list, threshold=0.9): 119 | """ 120 | Entity-level voting; each entity is keyed by (entity_type, entity_text, entity_start) 121 | :param entities_list: the entities predicted for one document, one {entity_type: [(entity_text, entity_start), ...]} dict per model 122 | :param threshold: an entity is kept only if at least int(len(entities_list) * threshold) models predict it 123 | :return: {entity_type: [(entity_text, entity_start), ...]} 124 | """ 125 | threshold_nums = int(len(entities_list)*threshold) 126 | entities_dict = defaultdict(int) 127 | entities = defaultdict(list) 128 | 129 | for _entities in entities_list: 130 | for _type in _entities: 131 | for _ent in _entities[_type]: 132 | entities_dict[(_type, _ent[0], _ent[1])] += 1 133 | 134 | for key in entities_dict: 135 | if entities_dict[key] >= threshold_nums: 136 | entities[key[0]].append((key[1], key[2])) 137 | 138 | return entities 139 | 140 | def ensemble_vote(entities_list, threshold=0.9): 141 | """ 142 | Voting over the outputs of the ensemble model 143 | Entity-level voting: each key is (_id,) + the full entity tuple, kept if at least int(len(entities_list) * threshold) predictions contain it 144 | """ 145 | threshold_nums = int(len(entities_list)*threshold) 146 | entities_dict = defaultdict(int) 147 | 148 | entities = defaultdict(list) 149 | 150 | for _entities in entities_list: 151 | for _id in _entities: 152 | for _ent in _entities[_id]: 153 | entities_dict[(_id, ) + _ent] += 1 154 | 155 | for key in entities_dict: 156 | if entities_dict[key] >= threshold_nums: 157 | entities[key[0]].append(key[1:]) 158 | 159 | return entities 160 | -------------------------------------------------------------------------------- /src/utils/model_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import
math 3 | import torch 4 | import torch.nn as nn 5 | from torchcrf import CRF 6 | from itertools import repeat 7 | from transformers import BertModel 8 | from src.utils.functions_utils import vote 9 | from src.utils.evaluator import crf_decode, span_decode 10 | 11 | 12 | class LabelSmoothingCrossEntropy(nn.Module): 13 | def __init__(self, eps=0.1, reduction='mean', ignore_index=-100): 14 | super(LabelSmoothingCrossEntropy, self).__init__() 15 | self.eps = eps 16 | self.reduction = reduction 17 | self.ignore_index = ignore_index 18 | 19 | def forward(self, output, target): 20 | c = output.size()[-1] 21 | log_pred = torch.log_softmax(output, dim=-1) 22 | if self.reduction == 'sum': 23 | loss = -log_pred.sum() 24 | else: 25 | loss = -log_pred.sum(dim=-1) 26 | if self.reduction == 'mean': 27 | loss = loss.mean() 28 | 29 | 30 | return loss * self.eps / c + (1 - self.eps) * torch.nn.functional.nll_loss(log_pred, target, 31 | reduction=self.reduction, 32 | ignore_index=self.ignore_index) 33 | 34 | class FocalLoss(nn.Module): 35 | """Multi-class Focal loss implementation""" 36 | def __init__(self, gamma=2, weight=None, reduction='mean', ignore_index=-100): 37 | super(FocalLoss, self).__init__() 38 | self.gamma = gamma 39 | self.weight = weight 40 | self.ignore_index = ignore_index 41 | self.reduction = reduction 42 | 43 | def forward(self, input, target): 44 | """ 45 | input: [N, C] 46 | target: [N, ] 47 | """ 48 | log_pt = torch.log_softmax(input, dim=1) 49 | pt = torch.exp(log_pt) 50 | log_pt = (1 - pt) ** self.gamma * log_pt 51 | loss = torch.nn.functional.nll_loss(log_pt, target, self.weight, reduction=self.reduction, ignore_index=self.ignore_index) 52 | return loss 53 | 54 | class SpatialDropout(nn.Module): 55 | """ 56 | Spatial dropout for token vectors: one mask over the hidden dimension is sampled per sequence and shared by every token position 57 | """ 58 | def __init__(self, drop_prob): 59 | super(SpatialDropout, self).__init__() 60 | self.drop_prob = drop_prob 61 | 62 | @staticmethod 63 | def _make_noise(input): 64 | return input.new().resize_(input.size(0), *repeat(1,
input.dim() - 2), input.size(2)) 65 | 66 | def forward(self, inputs): 67 | output = inputs.clone() 68 | if not self.training or self.drop_prob == 0: 69 | return inputs 70 | else: 71 | noise = self._make_noise(inputs) 72 | if self.drop_prob == 1: 73 | noise.fill_(0) 74 | else: 75 | noise.bernoulli_(1 - self.drop_prob).div_(1 - self.drop_prob) 76 | noise = noise.expand_as(inputs) 77 | output.mul_(noise) 78 | return output 79 | 80 | class ConditionalLayerNorm(nn.Module): 81 | def __init__(self, 82 | normalized_shape, 83 | cond_shape, 84 | eps=1e-12): 85 | super().__init__() 86 | 87 | self.eps = eps 88 | 89 | self.weight = nn.Parameter(torch.Tensor(normalized_shape)) 90 | self.bias = nn.Parameter(torch.Tensor(normalized_shape)) 91 | 92 | self.weight_dense = nn.Linear(cond_shape, normalized_shape, bias=False) 93 | self.bias_dense = nn.Linear(cond_shape, normalized_shape, bias=False) 94 | 95 | self.reset_weight_and_bias() 96 | 97 | def reset_weight_and_bias(self): 98 | """ 99 | Initialize so that the condition has no effect at the start of training: weight=1, bias=0 and zero condition projections make this a plain LayerNorm 100 | """ 101 | nn.init.ones_(self.weight) 102 | nn.init.zeros_(self.bias) 103 | 104 | nn.init.zeros_(self.weight_dense.weight) 105 | nn.init.zeros_(self.bias_dense.weight) 106 | 107 | def forward(self, inputs, cond=None): 108 | assert cond is not None, 'A conditional tensor must be provided when using conditional layer norm' 109 | cond = torch.unsqueeze(cond, 1) # (b, 1, cond_shape) 110 | 111 | weight = self.weight_dense(cond) + self.weight # (b, 1, h) 112 | bias = self.bias_dense(cond) + self.bias # (b, 1, h) 113 | 114 | mean = torch.mean(inputs, dim=-1, keepdim=True) # (b, s, 1) 115 | outputs = inputs - mean # (b, s, h) 116 | 117 | variance = torch.mean(outputs ** 2, dim=-1, keepdim=True) 118 | std = torch.sqrt(variance + self.eps) # (b, s, 1) 119 | 120 | outputs = outputs / std # (b, s, h) 121 | 122 | outputs = outputs * weight + bias 123 | 124 | return outputs 125 | 126 | class BaseModel(nn.Module): 127 | def __init__(self, 128 | bert_dir, 129 | dropout_prob):
130 | super(BaseModel, self).__init__() 131 | config_path = os.path.join(bert_dir, 'config.json') 132 | 133 | assert os.path.exists(bert_dir) and os.path.exists(config_path), \ 134 | 'pretrained BERT directory or its config.json does not exist' 135 | 136 | self.bert_module = BertModel.from_pretrained(bert_dir, 137 | output_hidden_states=True, 138 | hidden_dropout_prob=dropout_prob) 139 | 140 | self.bert_config = self.bert_module.config 141 | 142 | @staticmethod 143 | def _init_weights(blocks, **kwargs): 144 | """ 145 | Initialize the task-specific layers following BERT: zero Linear biases, Embedding weights ~ N(0, initializer_range), LayerNorm weight=1 and bias=0 146 | """ 147 | for block in blocks: 148 | for module in block.modules(): 149 | if isinstance(module, nn.Linear): 150 | if module.bias is not None: 151 | nn.init.zeros_(module.bias) 152 | elif isinstance(module, nn.Embedding): 153 | nn.init.normal_(module.weight, mean=0, std=kwargs.get('initializer_range', 0.02)) 154 | elif isinstance(module, nn.LayerNorm): 155 | nn.init.ones_(module.weight) 156 | nn.init.zeros_(module.bias) 157 | 158 | 159 | # baseline 160 | class CRFModel(BaseModel): 161 | def __init__(self, 162 | bert_dir, 163 | num_tags, 164 | dropout_prob=0.1, 165 | **kwargs): 166 | super(CRFModel, self).__init__(bert_dir=bert_dir, dropout_prob=dropout_prob) 167 | 168 | out_dims = self.bert_config.hidden_size 169 | 170 | mid_linear_dims = kwargs.pop('mid_linear_dims', 128) 171 | 172 | self.mid_linear = nn.Sequential( 173 | nn.Linear(out_dims, mid_linear_dims), 174 | nn.ReLU(), 175 | nn.Dropout(dropout_prob) 176 | ) 177 | 178 | out_dims = mid_linear_dims 179 | 180 | self.classifier = nn.Linear(out_dims, num_tags) 181 | 182 | self.loss_weight = nn.Parameter(torch.FloatTensor(1), requires_grad=True) 183 | self.loss_weight.data.fill_(-0.2) 184 | 185 | self.crf_module = CRF(num_tags=num_tags, batch_first=True) 186 | 187 | init_blocks = [self.mid_linear, self.classifier] 188 | 189 | self._init_weights(init_blocks, initializer_range=self.bert_config.initializer_range) 190 | 191 | def forward(self, 192 | token_ids,
attention_masks, 194 | token_type_ids, 195 | labels=None, 196 | pseudo=None): 197 | 198 | bert_outputs = self.bert_module( 199 | input_ids=token_ids, 200 | attention_mask=attention_masks, 201 | token_type_ids=token_type_ids 202 | ) 203 | 204 | # standard path: last-layer hidden states, (batch, seq_len, hidden) 205 | seq_out = bert_outputs[0] 206 | 207 | seq_out = self.mid_linear(seq_out) 208 | 209 | emissions = self.classifier(seq_out) 210 | 211 | if labels is not None: 212 | if pseudo is not None: 213 | # per-sample negative log-likelihood, (batch,) 214 | tokens_loss = -1. * self.crf_module(emissions=emissions, 215 | tags=labels.long(), 216 | mask=attention_masks.byte(), 217 | reduction='none') 218 | 219 | # number of pseudo-labeled samples in the batch 220 | pseudo_nums = pseudo.sum().item() 221 | total_nums = token_ids.shape[0] 222 | 223 | # learnable weight (sigmoid of loss_weight) balancing pseudo-labeled and gold samples 224 | rate = torch.sigmoid(self.loss_weight) 225 | if pseudo_nums == 0: 226 | loss_0 = tokens_loss.mean() 227 | loss_1 = (rate*pseudo*tokens_loss).sum() 228 | else: 229 | if total_nums == pseudo_nums: 230 | loss_0 = 0 231 | else: 232 | loss_0 = ((1 - rate) * (1 - pseudo) * tokens_loss).sum() / (total_nums - pseudo_nums) 233 | loss_1 = (rate*pseudo*tokens_loss).sum() / pseudo_nums 234 | 235 | tokens_loss = loss_0 + loss_1 236 | 237 | else: 238 | tokens_loss = -1.
* self.crf_module(emissions=emissions, 239 | tags=labels.long(), 240 | mask=attention_masks.byte(), 241 | reduction='mean') 242 | 243 | out = (tokens_loss,) 244 | 245 | else: 246 | tokens_out = self.crf_module.decode(emissions=emissions, mask=attention_masks.byte()) 247 | 248 | out = (tokens_out, emissions) 249 | 250 | return out 251 | 252 | 253 | class SpanModel(BaseModel): 254 | def __init__(self, 254 | bert_dir, 256 | num_tags, 257 | dropout_prob=0.1, 258 | loss_type='ce', 259 | **kwargs): 260 | """ 261 | for every token, predict the entity type whose span starts / ends there (class 0 = no boundary) 262 | :param loss_type: train loss type in ['ce', 'ls_ce', 'focal'] 263 | """ 264 | super(SpanModel, self).__init__(bert_dir, dropout_prob=dropout_prob) 265 | 266 | out_dims = self.bert_config.hidden_size 267 | 268 | mid_linear_dims = kwargs.pop('mid_linear_dims', 128) 269 | 270 | self.num_tags = num_tags 271 | 272 | self.mid_linear = nn.Sequential( 273 | nn.Linear(out_dims, mid_linear_dims), 274 | nn.ReLU(), 275 | nn.Dropout(dropout_prob) 276 | ) 277 | 278 | out_dims = mid_linear_dims 279 | 280 | self.start_fc = nn.Linear(out_dims, num_tags) 281 | self.end_fc = nn.Linear(out_dims, num_tags) 282 | 283 | reduction = 'none' 284 | if loss_type == 'ce': 285 | self.criterion = nn.CrossEntropyLoss(reduction=reduction) 286 | elif loss_type == 'ls_ce': 287 | self.criterion = LabelSmoothingCrossEntropy(reduction=reduction) 288 | else: 289 | self.criterion = FocalLoss(reduction=reduction) 290 | 291 | self.loss_weight = nn.Parameter(torch.FloatTensor(1), requires_grad=True) 292 | self.loss_weight.data.fill_(-0.2) 293 | 294 | init_blocks = [self.mid_linear, self.start_fc, self.end_fc] 295 | 296 | self._init_weights(init_blocks) 297 | 298 | def forward(self, 299 | token_ids, 300 | attention_masks, 301 | token_type_ids, 302 | start_ids=None, 303 | end_ids=None, 304 | pseudo=None): 305 | 306 | bert_outputs = self.bert_module( 307 | input_ids=token_ids, 308 | attention_mask=attention_masks, 309 |
token_type_ids=token_type_ids 310 | ) 311 | 312 | seq_out = bert_outputs[0] 313 | 314 | seq_out = self.mid_linear(seq_out) 315 | 316 | start_logits = self.start_fc(seq_out) 317 | end_logits = self.end_fc(seq_out) 318 | 319 | out = (start_logits, end_logits, ) 320 | 321 | if start_ids is not None and end_ids is not None and self.training: 322 | 323 | start_logits = start_logits.view(-1, self.num_tags) 324 | end_logits = end_logits.view(-1, self.num_tags) 325 | 326 | # drop padding positions so the loss is computed on real tokens only 327 | active_loss = attention_masks.view(-1) == 1 328 | active_start_logits = start_logits[active_loss] 329 | active_end_logits = end_logits[active_loss] 330 | 331 | active_start_labels = start_ids.view(-1)[active_loss] 332 | active_end_labels = end_ids.view(-1)[active_loss] 333 | 334 | 335 | if pseudo is not None: 336 | # per-sample mean token loss, (batch,); reshape by batch size instead of a hardcoded max_seq_len of 512 337 | start_loss = self.criterion(start_logits, start_ids.view(-1)).view(token_ids.size(0), -1).mean(dim=-1) 338 | end_loss = self.criterion(end_logits, end_ids.view(-1)).view(token_ids.size(0), -1).mean(dim=-1) 339 | 340 | # number of pseudo-labeled samples in the batch 341 | pseudo_nums = pseudo.sum().item() 342 | total_nums = token_ids.shape[0] 343 | 344 | # learnable weight (sigmoid of loss_weight) balancing pseudo-labeled and gold samples 345 | rate = torch.sigmoid(self.loss_weight) 346 | if pseudo_nums == 0: 347 | start_loss = start_loss.mean() 348 | end_loss = end_loss.mean() 349 | else: 350 | if total_nums == pseudo_nums: 351 | start_loss = (rate*pseudo*start_loss).sum() / pseudo_nums 352 | end_loss = (rate*pseudo*end_loss).sum() / pseudo_nums 353 | else: 354 | start_loss = (rate*pseudo*start_loss).sum() / pseudo_nums \ 355 | + ((1 - rate) * (1 - pseudo) * start_loss).sum() / (total_nums - pseudo_nums) 356 | end_loss = (rate*pseudo*end_loss).sum() / pseudo_nums \ 357 | + ((1 - rate) * (1 - pseudo) * end_loss).sum() / (total_nums - pseudo_nums) 358 | else: 359 | start_loss = self.criterion(active_start_logits, active_start_labels).mean() 360 | end_loss = self.criterion(active_end_logits, active_end_labels).mean() 361 | 362 | loss = start_loss + end_loss 363 | 364 | out =
(loss, ) + out 365 | 366 | return out 367 | 368 | class MRCModel(BaseModel): 369 | def __init__(self, 370 | bert_dir, 371 | dropout_prob=0.1, 372 | use_type_embed=False, 373 | loss_type='ce', 374 | **kwargs): 375 | """ 376 | tag the start and end positions of the spans in the text that answer the query for one entity type 377 | :param use_type_embed: embed the queried entity type and inject it into the BERT output via conditional layer norm 378 | :param loss_type: train loss type in ['ce', 'ls_ce', 'focal'] 379 | """ 380 | super(MRCModel, self).__init__(bert_dir, dropout_prob=dropout_prob) 381 | 382 | self.use_type_embed = use_type_embed 383 | self.use_smooth = loss_type 384 | 385 | out_dims = self.bert_config.hidden_size 386 | 387 | if self.use_type_embed: 388 | embed_dims = kwargs.pop('predicate_embed_dims', self.bert_config.hidden_size) 389 | self.type_embedding = nn.Embedding(13, embed_dims) 390 | 391 | self.conditional_layer_norm = ConditionalLayerNorm(out_dims, embed_dims, 392 | eps=self.bert_config.layer_norm_eps) 393 | 394 | mid_linear_dims = kwargs.pop('mid_linear_dims', 128) 395 | 396 | self.mid_linear = nn.Sequential( 397 | nn.Linear(out_dims, mid_linear_dims), 398 | nn.ReLU(), 399 | nn.Dropout(dropout_prob) 400 | ) 401 | 402 | out_dims = mid_linear_dims 403 | 404 | self.start_fc = nn.Linear(out_dims, 2) 405 | self.end_fc = nn.Linear(out_dims, 2) 406 | 407 | reduction = 'none' 408 | if loss_type == 'ce': 409 | self.criterion = nn.CrossEntropyLoss(reduction=reduction) 410 | elif loss_type == 'ls_ce': 411 | self.criterion = LabelSmoothingCrossEntropy(reduction=reduction) 412 | else: 413 | self.criterion = FocalLoss(reduction=reduction) 414 | 415 | self.loss_weight = nn.Parameter(torch.FloatTensor(1), requires_grad=True) 416 | self.loss_weight.data.fill_(-0.2) 417 | 418 | init_blocks = [self.mid_linear, self.start_fc, self.end_fc] 419 | 420 | if self.use_type_embed: 421 | init_blocks.append(self.type_embedding) 422 | 423 | self._init_weights(init_blocks) 424 | 425 | def forward(self, 426 | token_ids, 427 | attention_masks, 428 | token_type_ids, 429 |
ent_type=None, 430 | start_ids=None, 431 | end_ids=None, 432 | pseudo=None): 433 | 434 | bert_outputs = self.bert_module( 435 | input_ids=token_ids, 436 | attention_mask=attention_masks, 437 | token_type_ids=token_type_ids 438 | ) 439 | 440 | seq_out = bert_outputs[0] 441 | 442 | if self.use_type_embed: 443 | assert ent_type is not None, \ 444 | 'ent_type must be provided when use_type_embed is enabled' 445 | 446 | predicate_feature = self.type_embedding(ent_type) 447 | seq_out = self.conditional_layer_norm(seq_out, predicate_feature) 448 | 449 | seq_out = self.mid_linear(seq_out) 450 | 451 | start_logits = self.start_fc(seq_out) 452 | end_logits = self.end_fc(seq_out) 453 | 454 | out = (start_logits, end_logits, ) 455 | 456 | if start_ids is not None and end_ids is not None: 457 | start_logits = start_logits.view(-1, 2) 458 | end_logits = end_logits.view(-1, 2) 459 | 460 | # keep only text_b positions (token_type_ids == 1): drops the query (text_a) and padding so the loss covers the real text 461 | active_loss = token_type_ids.view(-1) == 1 462 | active_start_logits = start_logits[active_loss] 463 | active_end_logits = end_logits[active_loss] 464 | 465 | active_start_labels = start_ids.view(-1)[active_loss] 466 | active_end_labels = end_ids.view(-1)[active_loss] 467 | 468 | if pseudo is not None: 469 | # per-sample mean token loss, (batch,); reshape by batch size instead of a hardcoded max_seq_len of 512 470 | start_loss = self.criterion(start_logits, start_ids.view(-1)).view(token_ids.size(0), -1).mean(dim=-1) 471 | end_loss = self.criterion(end_logits, end_ids.view(-1)).view(token_ids.size(0), -1).mean(dim=-1) 472 | 473 | # number of pseudo-labeled samples in the batch 474 | pseudo_nums = pseudo.sum().item() 475 | total_nums = token_ids.shape[0] 476 | 477 | # learnable weight (sigmoid of loss_weight) balancing pseudo-labeled and gold samples 478 | rate = torch.sigmoid(self.loss_weight) 479 | if pseudo_nums == 0: 480 | start_loss = start_loss.mean() 481 | end_loss = end_loss.mean() 482 | else: 483 | if total_nums == pseudo_nums: 484 | start_loss = (rate*pseudo*start_loss).sum() / pseudo_nums 485 | end_loss = (rate*pseudo*end_loss).sum() / pseudo_nums 486 | else: 487 | start_loss = (rate*pseudo*start_loss).sum() / pseudo_nums \ 488 | + ((1 - rate) * (1
- pseudo) * start_loss).sum() / (total_nums - pseudo_nums) 489 | end_loss = (rate*pseudo*end_loss).sum() / pseudo_nums \ 490 | + ((1 - rate) * (1 - pseudo) * end_loss).sum() / (total_nums - pseudo_nums) 491 | else: 492 | start_loss = self.criterion(active_start_logits, active_start_labels).mean() 493 | end_loss = self.criterion(active_end_logits, active_end_labels).mean() 494 | 495 | loss = start_loss + end_loss 496 | 497 | out = (loss, ) + out 498 | 499 | return out 500 | 501 | 502 | class EnsembleCRFModel: 503 | def __init__(self, model_path_list, bert_dir_list, num_tags, device, lamb=1/3): 504 | 505 | self.models = [] 506 | self.crf_module = CRF(num_tags=num_tags, batch_first=True) 507 | self.lamb = lamb 508 | 509 | for idx, _path in enumerate(model_path_list): 510 | print(f'Load model from {_path}') 511 | 512 | 513 | print(f'Load model type: {bert_dir_list[0]}') 514 | model = CRFModel(bert_dir=bert_dir_list[0], num_tags=num_tags) 515 | 516 | 517 | model.load_state_dict(torch.load(_path, map_location=torch.device('cpu'))) 518 | 519 | model.eval() 520 | model.to(device) 521 | 522 | self.models.append(model) 523 | if idx == 0: 524 | print(f'Load CRF weight from {_path}') 525 | self.crf_module.load_state_dict(model.crf_module.state_dict()) 526 | self.crf_module.to(device) 527 | 528 | def weight(self, t): 529 | """ 530 | Fusion weight following Newton's law of cooling: the t-th model gets exp(-lamb * t) 531 | """ 532 | return math.exp(-self.lamb*t) 533 | 534 | def predict(self, model_inputs): 535 | weight_sum = 0.
536 | logits = None 537 | attention_masks = model_inputs['attention_masks'] 538 | 539 | for idx, model in enumerate(self.models): 540 | # weight each model's emission scores following Newton's law of cooling 541 | weight = self.weight(idx) 542 | 543 | # alternative: plain average of the emission scores 544 | # weight = 1 / len(self.models) 545 | 546 | tmp_logits = model(**model_inputs)[1] * weight 547 | weight_sum += weight 548 | 549 | if logits is None: 550 | logits = tmp_logits 551 | else: 552 | logits += tmp_logits 553 | 554 | logits = logits / weight_sum 555 | 556 | tokens_out = self.crf_module.decode(emissions=logits, mask=attention_masks.byte()) 557 | 558 | return tokens_out 559 | 560 | def vote_entities(self, model_inputs, sent, id2ent, threshold): 561 | entities_ls = [] 562 | for idx, model in enumerate(self.models): 563 | tmp_tokens = model(**model_inputs)[0][0] 564 | tmp_entities = crf_decode(tmp_tokens, sent, id2ent) 565 | entities_ls.append(tmp_entities) 566 | 567 | return vote(entities_ls, threshold) 568 | 569 | 570 | class EnsembleSpanModel: 571 | def __init__(self, model_path_list, bert_dir_list, num_tags, device): 572 | 573 | self.models = [] 574 | 575 | for idx, _path in enumerate(model_path_list): 576 | print(f'Load model from {_path}') 577 | 578 | print(f'Load model type: {bert_dir_list[0]}') 579 | model = SpanModel(bert_dir=bert_dir_list[0], num_tags=num_tags) 580 | 581 | model.load_state_dict(torch.load(_path, map_location=torch.device('cpu'))) 582 | 583 | model.eval() 584 | model.to(device) 585 | 586 | self.models.append(model) 587 | 588 | def predict(self, model_inputs): 589 | start_logits, end_logits = None, None 590 | 591 | for idx, model in enumerate(self.models): 592 | 593 | # plain average of the logits over all models 594 | weight = 1 / len(self.models) 595 | 596 | tmp_start_logits, tmp_end_logits = model(**model_inputs) 597 | 598 | tmp_start_logits = tmp_start_logits * weight 599 | tmp_end_logits = tmp_end_logits * weight 600 | 601 | if start_logits is None: 602 | start_logits = tmp_start_logits 603 | end_logits = tmp_end_logits 604 | else: 605 | start_logits +=
tmp_start_logits 606 | end_logits += tmp_end_logits 607 | 608 | return start_logits, end_logits 609 | 610 | def vote_entities(self, model_inputs, sent, id2ent, threshold): 611 | entities_ls = [] 612 | 613 | for idx, model in enumerate(self.models): 614 | 615 | start_logits, end_logits = model(**model_inputs) 616 | start_logits = start_logits[0].cpu().numpy()[1:1 + len(sent)] 617 | end_logits = end_logits[0].cpu().numpy()[1:1 + len(sent)] 618 | 619 | decode_entities = span_decode(start_logits, end_logits, sent, id2ent) 620 | 621 | entities_ls.append(decode_entities) 622 | 623 | return vote(entities_ls, threshold) 624 | 625 | 626 | def build_model(task_type, bert_dir, **kwargs): 627 | assert task_type in ['crf', 'span', 'mrc'] 628 | 629 | if task_type == 'crf': 630 | model = CRFModel(bert_dir=bert_dir, 631 | num_tags=kwargs.pop('num_tags'), 632 | dropout_prob=kwargs.pop('dropout_prob', 0.1)) 633 | 634 | elif task_type == 'mrc': 635 | model = MRCModel(bert_dir=bert_dir, 636 | dropout_prob=kwargs.pop('dropout_prob', 0.1), 637 | use_type_embed=kwargs.pop('use_type_embed'), 638 | loss_type=kwargs.pop('loss_type', 'ce')) 639 | 640 | else: 641 | model = SpanModel(bert_dir=bert_dir, 642 | num_tags=kwargs.pop('num_tags'), 643 | dropout_prob=kwargs.pop('dropout_prob', 0.1), 644 | loss_type=kwargs.pop('loss_type', 'ce')) 645 | 646 | return model 647 | -------------------------------------------------------------------------------- /src/utils/options.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | class Args: 4 | @staticmethod 5 | def parse(): 6 | parser = argparse.ArgumentParser() 7 | return parser 8 | 9 | @staticmethod 10 | def initialize(parser: argparse.ArgumentParser): 11 | # args for path 12 | parser.add_argument('--raw_data_dir', default='./data/raw_data', 13 | help='the data dir of raw data') 14 | 15 | parser.add_argument('--mid_data_dir', default='./data/mid_data', 16 | help='the mid data dir') 17 | 18 | 
        parser.add_argument('--output_dir', default='./out/',
                            help='the output dir for model checkpoints')

        parser.add_argument('--bert_dir', default='../bert/torch_roberta_wwm',
                            help='bert dir for ernie / roberta-wwm / uer')

        parser.add_argument('--bert_type', default='roberta_wwm',
                            help='roberta_wwm / ernie_1 / uer_large')

        parser.add_argument('--task_type', default='crf',
                            help='crf / span / mrc')

        parser.add_argument('--loss_type', default='ls_ce',
                            help='loss type for span bert')

        parser.add_argument('--use_type_embed', default=False, action='store_true',
                            help='whether to use the type embedding in the mrc model')

        parser.add_argument('--use_fp16', default=False, action='store_true',
                            help='whether to use fp16 (mixed precision) during training')

        # other args
        parser.add_argument('--seed', type=int, default=123, help='random seed')

        parser.add_argument('--gpu_ids', type=str, default='0',
                            help='gpu ids to use, -1 for cpu, "0,1" for multi gpu')

        parser.add_argument('--mode', type=str, default='train',
                            help='train / stack')

        parser.add_argument('--max_seq_len', default=512, type=int)

        parser.add_argument('--eval_batch_size', default=64, type=int)

        parser.add_argument('--swa_start', default=3, type=int,
                            help='the epoch at which swa starts')

        # train args
        parser.add_argument('--train_epochs', default=10, type=int,
                            help='max training epochs')

        parser.add_argument('--dropout_prob', default=0.1, type=float,
                            help='dropout probability')

        parser.add_argument('--lr', default=2e-5, type=float,
                            help='learning rate for the bert module')

        parser.add_argument('--other_lr', default=2e-3, type=float,
                            help='learning rate for all modules other than bert')

        parser.add_argument('--max_grad_norm', default=1.0, type=float,
                            help='max gradient norm for clipping')

        parser.add_argument('--warmup_proportion',
                            default=0.1, type=float)

        parser.add_argument('--weight_decay', default=0.00, type=float)

        parser.add_argument('--adam_epsilon', default=1e-8, type=float)

        parser.add_argument('--train_batch_size', default=24, type=int)

        # note: with default=True and action='store_true' this flag is always True
        parser.add_argument('--eval_model', default=True, action='store_true',
                            help='whether to eval model after training')

        parser.add_argument('--attack_train', default='', type=str,
                            help='adversarial training during training: fgm / pgd (empty to disable)')

        # test args
        parser.add_argument('--version', default='v0', type=str,
                            help='submit version')

        parser.add_argument('--submit_dir', default='./submit', type=str)

        parser.add_argument('--ckpt_dir', default='', type=str)

        return parser

    def get_parser(self):
        parser = self.parse()
        parser = self.initialize(parser)
        return parser.parse_args()

--------------------------------------------------------------------------------
/src/utils/trainer.py:
--------------------------------------------------------------------------------
import os
import copy
import torch
import logging
from torch.cuda.amp import autocast as ac
from torch.utils.data import DataLoader, RandomSampler
from transformers import AdamW, get_linear_schedule_with_warmup
from src.utils.attack_train_utils import FGM, PGD
from src.utils.functions_utils import load_model_and_parallel, swa


logger = logging.getLogger(__name__)


def save_model(opt, model, global_step):
    output_dir = os.path.join(opt.output_dir, 'checkpoint-{}'.format(global_step))
    os.makedirs(output_dir, exist_ok=True)

    # take care of model distributed / parallel training
    model_to_save = (
        model.module if hasattr(model, "module") else model
    )
    logger.info(f'Saving model checkpoint to {output_dir}')
    torch.save(model_to_save.state_dict(), os.path.join(output_dir, 'model.pt'))


def build_optimizer_and_scheduler(opt, model, t_total):
    module = (
        model.module if hasattr(model, "module") else model
    )

    # differential learning rates: bert_module uses opt.lr, every other module uses opt.other_lr
    no_decay = ["bias", "LayerNorm.weight"]
    model_param = list(module.named_parameters())

    bert_param_optimizer = []
    other_param_optimizer = []

    for name, para in model_param:
        space = name.split('.')
        if space[0] == 'bert_module':
            bert_param_optimizer.append((name, para))
        else:
            other_param_optimizer.append((name, para))

    optimizer_grouped_parameters = [
        # bert module
        {"params": [p for n, p in bert_param_optimizer if not any(nd in n for nd in no_decay)],
         "weight_decay": opt.weight_decay, 'lr': opt.lr},
        {"params": [p for n, p in bert_param_optimizer if any(nd in n for nd in no_decay)],
         "weight_decay": 0.0, 'lr': opt.lr},

        # other modules, with their own (differential) learning rate
        {"params": [p for n, p in other_param_optimizer if not any(nd in n for nd in no_decay)],
         "weight_decay": opt.weight_decay, 'lr': opt.other_lr},
        {"params": [p for n, p in other_param_optimizer if any(nd in n for nd in no_decay)],
         "weight_decay": 0.0, 'lr': opt.other_lr},
    ]

    optimizer = AdamW(optimizer_grouped_parameters, lr=opt.lr, eps=opt.adam_epsilon)
    scheduler = get_linear_schedule_with_warmup(
        optimizer, num_warmup_steps=int(opt.warmup_proportion * t_total), num_training_steps=t_total
    )

    return optimizer, scheduler


def train(opt, model, train_dataset):
    swa_raw_model = copy.deepcopy(model)

    train_sampler = RandomSampler(train_dataset)

    train_loader = DataLoader(dataset=train_dataset,
                              batch_size=opt.train_batch_size,
                              sampler=train_sampler,
                              num_workers=0)

    scaler = None
    if opt.use_fp16:
        scaler = torch.cuda.amp.GradScaler()

    model, device = load_model_and_parallel(model, opt.gpu_ids)

    use_n_gpus = False
    if hasattr(model, "module"):
        use_n_gpus = True

    t_total = len(train_loader) * opt.train_epochs

    optimizer, scheduler = build_optimizer_and_scheduler(opt, model, t_total)

    # Train
    logger.info("***** Running training *****")
    logger.info(f" Num Examples = {len(train_dataset)}")
    logger.info(f" Num Epochs = {opt.train_epochs}")
    logger.info(f" Total training batch size = {opt.train_batch_size}")
    logger.info(f" Total optimization steps = {t_total}")

    global_step = 0

    model.zero_grad()

    fgm, pgd = None, None

    attack_train_mode = opt.attack_train.lower()
    if attack_train_mode == 'fgm':
        fgm = FGM(model=model)
    elif attack_train_mode == 'pgd':
        pgd = PGD(model=model)

    pgd_k = 3

    save_steps = t_total // opt.train_epochs
    eval_steps = save_steps

    logger.info(f'Save model every {save_steps} steps; eval model every {eval_steps} steps')

    log_loss_steps = 20

    avg_loss = 0.

    for epoch in range(opt.train_epochs):

        for step, batch_data in enumerate(train_loader):

            model.train()

            for key in batch_data.keys():
                batch_data[key] = batch_data[key].to(device)

            if opt.use_fp16:
                with ac():
                    loss = model(**batch_data)[0]
            else:
                loss = model(**batch_data)[0]

            if use_n_gpus:
                loss = loss.mean()

            if opt.use_fp16:
                scaler.scale(loss).backward()
            else:
                loss.backward()

            if fgm is not None:
                fgm.attack()

                if opt.use_fp16:
                    with ac():
                        loss_adv = model(**batch_data)[0]
                else:
                    loss_adv = model(**batch_data)[0]

                if use_n_gpus:
                    loss_adv = loss_adv.mean()

                if opt.use_fp16:
                    scaler.scale(loss_adv).backward()
                else:
                    loss_adv.backward()

                fgm.restore()

            elif pgd is not None:
                pgd.backup_grad()

                for _t in range(pgd_k):
                    pgd.attack(is_first_attack=(_t == 0))

                    if _t != pgd_k - 1:
                        model.zero_grad()
                    else:
                        pgd.restore_grad()

                    if opt.use_fp16:
                        with ac():
                            loss_adv = model(**batch_data)[0]
                    else:
                        loss_adv = model(**batch_data)[0]

                    if use_n_gpus:
                        loss_adv = loss_adv.mean()

                    if opt.use_fp16:
                        scaler.scale(loss_adv).backward()
                    else:
                        loss_adv.backward()

                pgd.restore()

            if opt.use_fp16:
                scaler.unscale_(optimizer)

            torch.nn.utils.clip_grad_norm_(model.parameters(), opt.max_grad_norm)

            # optimizer.step()
            if opt.use_fp16:
                scaler.step(optimizer)
                scaler.update()
            else:
                optimizer.step()

            scheduler.step()
            model.zero_grad()

            global_step += 1

            # accumulate every step's loss, then log the mean over the last log_loss_steps steps
            avg_loss += loss.item()

            if global_step % log_loss_steps == 0:
                avg_loss /= log_loss_steps
                logger.info('Step: %d / %d ----> total loss: %.5f' % (global_step, t_total, avg_loss))
                avg_loss = 0.

            if global_step % save_steps == 0:
                save_model(opt, model, global_step)

    swa(swa_raw_model, opt.output_dir, swa_start=opt.swa_start)

    # clear cuda cache to avoid OOM
    torch.cuda.empty_cache()
    logger.info('Train done')

--------------------------------------------------------------------------------
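`EnsembleCRFModel` in `model_utils.py` fuses the emission logits of several CRF models before a single CRF decode: the model at index `t` is weighted by `exp(-lamb * t)` (`lamb` defaults to 1/3), and the weighted sum is divided by the total weight. The sketch below reproduces only that arithmetic on plain Python lists instead of tensors; `cooling_weights` and `fuse_logits` are hypothetical helper names, not part of the repository.

```python
import math
from typing import List


def cooling_weights(num_models: int, lamb: float = 1 / 3) -> List[float]:
    """Unnormalized weights exp(-lamb * t) for model indices t = 0 .. num_models - 1."""
    return [math.exp(-lamb * t) for t in range(num_models)]


def fuse_logits(per_model_logits: List[List[float]], lamb: float = 1 / 3) -> List[float]:
    """Weighted average of per-model logit vectors, normalized by the weight sum.

    Follows the accumulate-then-divide pattern of EnsembleCRFModel.predict,
    but on flat lists instead of torch tensors (illustration only).
    """
    weights = cooling_weights(len(per_model_logits), lamb)
    weight_sum = sum(weights)
    fused = [0.0] * len(per_model_logits[0])
    for w, logits in zip(weights, per_model_logits):
        for i, value in enumerate(logits):
            fused[i] += w * value
    return [value / weight_sum for value in fused]
```

With the default `lamb`, three models get weights of 1, about 0.72 and about 0.51, so models listed earlier in `model_path_list` count more. Giving every model `1 / len(self.models)`, the commented-out alternative in `predict`, yields a plain average instead.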
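`build_optimizer_and_scheduler` in `trainer.py` builds four AdamW parameter groups. Parameters whose name starts with `bert_module` use `opt.lr`, and all others use `opt.other_lr`. Within each side, names containing `bias` or `LayerNorm.weight` get no weight decay. A torch-free sketch of the same partition, operating on parameter names only; `group_parameter_names` and the example names in the test are hypothetical:

```python
from typing import Dict, List, Tuple

# substrings that exempt a parameter from weight decay, as in trainer.py
NO_DECAY = ("bias", "LayerNorm.weight")


def group_parameter_names(names: List[str], lr: float, other_lr: float,
                          weight_decay: float) -> List[Dict]:
    """Return four groups mirroring optimizer_grouped_parameters in trainer.py.

    Groups hold parameter *names* instead of tensors (illustration only).
    """
    bert = [n for n in names if n.split('.')[0] == 'bert_module']
    other = [n for n in names if n.split('.')[0] != 'bert_module']

    def split(group: List[str]) -> Tuple[List[str], List[str]]:
        decay = [n for n in group if not any(nd in n for nd in NO_DECAY)]
        no_decay = [n for n in group if any(nd in n for nd in NO_DECAY)]
        return decay, no_decay

    bert_decay, bert_no_decay = split(bert)
    other_decay, other_no_decay = split(other)
    return [
        {"params": bert_decay, "weight_decay": weight_decay, "lr": lr},
        {"params": bert_no_decay, "weight_decay": 0.0, "lr": lr},
        {"params": other_decay, "weight_decay": weight_decay, "lr": other_lr},
        {"params": other_no_decay, "weight_decay": 0.0, "lr": other_lr},
    ]
```

The higher `other_lr` (2e-3 by default versus 2e-5 for BERT) lets freshly initialized heads such as the CRF layer move faster than the pretrained encoder.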
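The scheduler is `transformers.get_linear_schedule_with_warmup` with `num_warmup_steps = int(opt.warmup_proportion * t_total)`. As far as I understand it, this scales each group's base learning rate by a factor that rises linearly from 0 to 1 over the warmup steps and then falls linearly to 0 at `num_training_steps`. The sketch below is an independent reimplementation of that factor, not the library's code, and `group_lrs` is a hypothetical helper:

```python
from typing import List


def linear_warmup_factor(step: int, num_warmup_steps: int, num_training_steps: int) -> float:
    """Multiplier applied to the base learning rate at a given optimizer step."""
    if step < num_warmup_steps:
        # linear ramp from 0 up to 1 during warmup
        return step / max(1, num_warmup_steps)
    # linear decay from 1 down to 0 at num_training_steps, clipped at 0
    return max(0.0, (num_training_steps - step) / max(1, num_training_steps - num_warmup_steps))


def group_lrs(step: int, t_total: int, warmup_proportion: float,
              base_lrs: List[float]) -> List[float]:
    """Effective learning rates of several parameter groups (e.g. bert vs. other modules)."""
    num_warmup_steps = int(warmup_proportion * t_total)
    factor = linear_warmup_factor(step, num_warmup_steps, t_total)
    return [lr * factor for lr in base_lrs]
```

Because the same factor multiplies every group, the 100x ratio between `other_lr` and `lr` is preserved throughout training.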
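The visible tail of `SpanModel.forward` in `model_utils.py` mixes per-example losses from a batch that contains both pseudo-labelled and gold-labelled data. Examples flagged `pseudo = 1` are averaged among themselves and scaled by `rate`, while the rest are averaged and scaled by `1 - rate`. How `rate` is computed, and how batches with only one kind of example are handled, are not shown in this excerpt. The sketch therefore takes `rate` as a plain float and rejects single-kind batches instead of guessing; `mixed_pseudo_loss` is a hypothetical name.

```python
from typing import Sequence


def mixed_pseudo_loss(losses: Sequence[float], pseudo: Sequence[int], rate: float) -> float:
    """Combine per-example losses like the visible branch of SpanModel.forward.

    losses: one loss value per example in the batch
    pseudo: 1 for pseudo-labelled examples, 0 for gold-labelled ones
    rate:   weight of the pseudo-labelled part (the gold part gets 1 - rate)
    """
    total_nums = len(losses)
    pseudo_nums = sum(pseudo)
    if pseudo_nums == 0 or pseudo_nums == total_nums:
        # this sketch only covers batches that mix both kinds of examples
        raise ValueError("batch must contain both pseudo and gold examples")
    pseudo_part = sum(rate * p * loss for p, loss in zip(pseudo, losses)) / pseudo_nums
    gold_part = sum((1 - rate) * (1 - p) * loss
                    for p, loss in zip(pseudo, losses)) / (total_nums - pseudo_nums)
    return pseudo_part + gold_part
```

In the repository the same formula is applied separately to `start_loss` and `end_loss`, on tensors rather than lists.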
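Both ensemble classes also provide `vote_entities`, which decodes entities with each model separately and hands the per-model results to `vote(entities_ls, threshold)`. That function is defined elsewhere in the repository and not shown here. Assuming it keeps an entity when the fraction of models that predicted it reaches `threshold` (an assumption this excerpt does not confirm), a minimal sketch looks like this. It represents entities as hypothetical `(type, start, text)` tuples, and `majority_vote` is a hypothetical name.

```python
from collections import Counter
from typing import Hashable, Iterable, List, Set


def majority_vote(entities_per_model: List[Iterable[Hashable]], threshold: float) -> Set[Hashable]:
    """Keep entities predicted by at least `threshold` of the models (assumed semantics)."""
    counts: Counter = Counter()
    for entities in entities_per_model:
        # count each entity at most once per model
        counts.update(set(entities))
    num_models = len(entities_per_model)
    return {ent for ent, c in counts.items() if c / num_models >= threshold}
```

Compared with logit averaging, voting lets models with different decoders (CRF or span) be combined, since only their decoded entities need to match.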