├── README.md
├── competition_predict.py
├── convert_test_data.py
├── data
│   ├── mid_data
│   │   ├── crf_ent2id.json
│   │   ├── mrc_ent2id.json
│   │   └── span_ent2id.json
│   └── raw_data
│       ├── dev.json
│       ├── pseudo.json
│       ├── stack.json
│       ├── test.json
│       └── train.json
├── main.py
├── md_files
│   ├── 1.png
│   ├── 10.png
│   ├── 11.png
│   ├── 12.png
│   ├── 13.png
│   ├── 2.png
│   ├── 3.png
│   ├── 4.png
│   ├── 5.png
│   ├── 6.png
│   ├── 7.png
│   ├── 8.png
│   └── 9.png
├── run.sh
└── src
    ├── preprocess
    │   ├── __pycache__
    │   │   └── processor.cpython-36.pyc
    │   ├── convert_raw_data.py
    │   └── processor.py
    └── utils
        ├── __pycache__
        │   ├── attack_train_utils.cpython-36.pyc
        │   ├── dataset_utils.cpython-36.pyc
        │   ├── evaluator.cpython-36.pyc
        │   ├── functions_utils.cpython-36.pyc
        │   ├── model_utils.cpython-36.pyc
        │   ├── options.cpython-36.pyc
        │   └── trainer.cpython-36.pyc
        ├── attack_train_utils.py
        ├── dataset_utils.py
        ├── evaluator.py
        ├── functions_utils.py
        ├── model_utils.py
        ├── options.py
        └── trainer.py
/README.md:
--------------------------------------------------------------------------------
1 | # Chinese-DeepNER-Pytorch
2 |
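3 | ## Open-source release of the winning solution for the Tianchi Traditional Chinese Medicine Instructions Entity Recognition Challenge
4 |
5 | ### Contributors: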
6 | zxx飞翔的鱼: https://github.com/z814081807
7 |
8 | 我是蛋糕王:https://github.com/WuHuRestaurant
9 |
10 | 数青峰:https://github.com/zchaizju
11 |
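12 | ### Once the organizers release the dataset, the DeepNER project will be upgraded to cover the full data processing, training, validation, testing, and deployment pipeline, with detailed code comments, model descriptions, and experimental results, providing a more general, out-of-the-box Chinese NER solution based on pretrained models. Stars are welcome!
13 |
14 | (The code is built on **pytorch and transformers**. The framework is highly **reusable, decoupled, and readable**, and is easy to adapt to other NLP tasks.)
15 |
16 | ## Environment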
17 |
18 | ```python
19 | python3.7
20 | pytorch>=1.6.0
21 | transformers==2.10.0
22 | pytorch-crf==0.7.2
23 | ```
24 |
25 | ## Project directory layout
26 |
27 | ```shell
28 | DeepNER
29 | │
30 | ├── data                              # data folder
31 | │   ├── mid_data                      # intermediate data
32 | │   │   ├── crf_ent2id.json           # schema for the crf model
33 | │   │   ├── span_ent2id.json          # schema for the span model
34 | │   │   └── mrc_ent2id.json           # schema for the mrc model
35 | │   │
36 | │   └── raw_data                      # converted data
37 | │       ├── dev.json                  # converted validation set
38 | │       ├── test.json                 # converted preliminary-round test set
39 | │       ├── pseudo.json               # converted semi-supervised (pseudo-labeled) data
40 | │       ├── stack.json                # converted full dataset
41 | │       └── train.json                # converted training set
42 | │
43 | ├── out                               # trained models
44 | │   ├── ...
45 | │   └── ...
46 | │
47 | ├── src
48 | │   ├── preprocess
49 | │   │   ├── convert_raw_data.py       # processes and converts the raw data
50 | │   │   └── processor.py              # converts data into BERT model inputs
51 | │   └── utils
52 | │       ├── attack_train_utils.py     # adversarial training: FGM / PGD
53 | │       ├── dataset_utils.py          # torch Dataset
54 | │       ├── evaluator.py              # model evaluation
55 | │       ├── functions_utils.py        # helper functions shared across files
56 | │       ├── model_utils.py            # Span & CRF & MRC models (pytorch)
57 | │       ├── options.py                # command-line arguments
58 | │       └── trainer.py                # trainer
59 | │
60 | ├── competition_predict.py            # inference and submission for the semi-final data
61 | ├── README.md                         # ...
62 | ├── convert_test_data.py              # converts the semi-final test set to JSON
63 | ├── run.sh                            # run script
64 | └── main.py                           # main function (mainly for training / evaluation)
65 | ```
66 |
67 | ## Usage
68 |
69 | ### Pretrained models
70 |
71 | * Tencent pretrained model UER-large (24 layers): https://github.com/dbiir/UER-py/wiki/Modelzoo
72 |
73 | * HIT (Harbin Institute of Technology) pretrained models: https://github.com/ymcui/Chinese-BERT-wwm
74 |
75 | Baidu Cloud download link:
76 |
77 | Link: https://pan.baidu.com/s/1axdkovbzGaszl8bXIn4sPw
78 | Extraction code: jjba
79 |
80 | (Note: two of the [unused] tokens in vocab.txt must be manually changed to [INV] and [BLANK].)
81 |
82 | Tips: uer, roberta-wwm, and roberta-wwm-large are recommended.
83 |
84 | ### Data conversion
85 |
86 | **Note: the converted data is already provided, so this step does not need to be run.**
87 |
88 | ```shell
89 | python src/preprocess/convert_raw_data.py
90 | ```
91 |
92 | ### Training
93 |
94 | ```shell
95 | bash run.sh
96 | ```
97 |
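98 | **Note: BERT_DIR in the script is the folder that contains BERT; download BERT into that folder first.**
99 |
100 | ##### Training the BERT-CRF model
101 |
102 | ```python
103 | task_type='crf'
104 | mode='train' or 'stack'  # train: single-model training and validation; stack: 5-fold training and validation
105 |
106 | swa_start: the epoch at which SWA (model weight averaging) starts
107 | attack_train: 'pgd' / 'fgm' / ''  # adversarial training; training takes about twice as long with fgm and about three times as long with pgd; pgd gave a clear gain on this dataset
108 | ```
109 |
110 | ##### Training the BERT-SPAN model
111 |
112 | ```python
113 | task_type='span'
114 | mode: same as above
115 | attack_train: same as above
116 | loss_type: 'ce': cross entropy; 'ls_ce': label smoothing; 'focal': focal loss
117 | ```
118 |
119 | ##### Training the BERT-MRC model
120 |
121 | ```python
122 | task_type='mrc'
123 | mode: same as above
124 | attack_train: same as above
125 | loss_type: same as above
126 | ```
127 |
128 | ### Predicting on the semi-final test files (after the models above have been trained)
129 |
130 | **Note: this cannot be run yet because the data is unavailable; it will work once the organizers release the data.**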
131 |
132 | ```shell
133 | # convert_test_data
134 | python convert_test_data.py
135 | # predict
136 | python competition_predict.py
137 | ```
138 |
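139 | # Competition background
140 | ## Task description
141 | Artificial intelligence has accelerated the inheritance, innovation, and development of traditional Chinese medicine (TCM). Information extraction from TCM text is a core part of building a TCM knowledge graph and lays the foundation for downstream applications such as clinical decision support systems (CDSS). This NER challenge requires extracting key information from TCM drug instructions, covering 13 entity types including drugs, drug ingredients, diseases, symptoms, and syndromes, in order to build a knowledge base of TCM drugs.
142 |
143 | ## Exploratory data analysis
144 | The training data in this competition has three characteristics:
145 |
146 | - TCM drug instructions are mostly long texts
147 |
148 | - Labeled samples are scarce in the medical domain
149 |
150 | - The label distribution is imbalanced
151 |
152 |
153 | # Core approach
154 | ## Data preprocessing
155 | The instruction text is first pre-cleaned and then split into shorter pieces. Pre-cleaning filters out invalid characters. To handle long texts, a two-level splitting strategy is used. Because the resulting sentences may be too short, short pieces are merged back together so that each merged text stays within the configured maximum length. In addition, all labeled data is used to build an entity knowledge base that serves as a domain prior dictionary.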
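The split-then-merge idea described above can be sketched as follows. This is a minimal illustration, not the repository's `cut_sent`: the names `split_text` and `cut_and_merge` and the two punctuation sets are assumptions made for the example.

```python
import re


def split_text(text, pattern):
    """Split text after every delimiter matched by pattern, keeping the delimiter."""
    pieces, start = [], 0
    for match in re.finditer(pattern, text):
        pieces.append(text[start:match.end()])
        start = match.end()
    if start < len(text):
        pieces.append(text[start:])
    return pieces


def cut_and_merge(text, max_len):
    """Two-level split (sentence marks, then commas), followed by greedy merging."""
    pieces = []
    for sent in split_text(text, r'[。！？!?]'):          # level 1: sentence-ending marks
        if len(sent) <= max_len:
            pieces.append(sent)
        else:                                             # level 2: finer punctuation
            pieces.extend(split_text(sent, r'[，,；;]'))

    merged, buf = [], ''
    for piece in pieces:
        # Flush the buffer when adding the next piece would exceed max_len.
        if buf and len(buf) + len(piece) > max_len:
            merged.append(buf)
            buf = ''
        buf += piece
    if buf:
        merged.append(buf)
    return merged
```

Because no characters are dropped, concatenating the chunks reproduces the input, which keeps character offsets easy to map back to the original text. A piece with no punctuation that is longer than `max_len` is left over-long in this sketch and would need truncation or a hard split.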
156 |
157 | ## Baseline: BERT-CRF
158 |
159 |
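160 | - Baseline details
161 |   - Pretrained model: UER-large (24 layers) [1]. UER continues training under the RoBERTa-wwm framework on a large, high-quality Chinese corpus, and ranked first among single models on the CLUE tasks
162 |   - Differential learning rates: 2e-5 for the BERT layers; 2e-3 for the other layers
163 |   - Parameter initialization: the other modules of the model are initialized the same way as BERT
164 |   - Weight averaging (SWA): the weights of the models from the last few epochs are averaged, giving a smoother and better-performing model
165 | - Baseline bad-case analysis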
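Two of the baseline details above, differential learning rates and weight averaging, can be illustrated without any framework. This is a sketch under assumptions: the `bert.` name prefix, the plain-dict checkpoint layout, and both function names are hypothetical, and the averaging uses equal weights.

```python
def group_learning_rates(param_names, bert_lr=2e-5, other_lr=2e-3):
    """Give BERT parameters the small rate and every other parameter the larger one."""
    return {name: (bert_lr if name.startswith('bert.') else other_lr)
            for name in param_names}


def average_checkpoints(checkpoints):
    """Average each parameter element-wise over a list of checkpoints (SWA-style)."""
    n = len(checkpoints)
    return {name: [sum(vals) / n for vals in zip(*(ckpt[name] for ckpt in checkpoints))]
            for name in checkpoints[0]}
```

In a real training loop the first mapping would typically become optimizer parameter groups, and the second would be applied to the checkpoints saved from the `swa_start` epoch onward.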
166 |
167 |
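168 | ## Optimization 1: adversarial training
169 | - Motivation: adversarial training is used to mitigate the model's poor robustness and improve its generalization
170 | - Adversarial training is a training method that introduces noise; it regularizes the parameters and improves robustness and generalization
171 |   - Fast Gradient Method (FGM): adds a perturbation to the embedding layer along the gradient direction
172 |   - Projected Gradient Descent (PGD) [2]: perturbs iteratively, projecting each perturbation back into a specified range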
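The two perturbation rules can be written down for a single embedding vector. This is a simplified sketch on plain Python lists, not the code in `attack_train_utils.py`; the function names and the default `epsilon` and `alpha` values are assumptions.

```python
import math


def fgm_perturbation(grad, epsilon=1.0):
    """FGM: r = epsilon * g / ||g||_2, a single step along the gradient direction."""
    norm = math.sqrt(sum(g * g for g in grad))
    if norm == 0.0:
        return [0.0 for _ in grad]
    return [epsilon * g / norm for g in grad]


def pgd_step(delta, grad, alpha=0.3, epsilon=1.0):
    """PGD: take a small step of size alpha, then project delta back into the epsilon-ball."""
    step = fgm_perturbation(grad, alpha)
    delta = [d + s for d, s in zip(delta, step)]
    norm = math.sqrt(sum(d * d for d in delta))
    if norm > epsilon:
        delta = [epsilon * d / norm for d in delta]
    return delta
```

In training, the perturbation is usually added to the embedding weights after the normal backward pass, a second forward and backward pass accumulates the adversarial gradients, and the embeddings are restored before the optimizer step. PGD repeats the perturb-and-project step several times, which is why it costs more than FGM.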
173 |
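174 | ## Optimization 2: mixed-precision training (FP16)
175 | - Motivation: adversarial training lowers computational efficiency, so mixed-precision training is used to reduce training time
176 | - Mixed-precision training
177 |   - Store values and perform multiplications in FP16 to speed things up
178 |   - Accumulate in FP32 to avoid rounding errors
179 | - Loss scaling
180 |   - Multiply the loss by 2^k before backpropagation so that small gradient values do not underflow in FP16
181 |   - After backpropagation, scale the weight gradients back down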
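The effect of loss scaling can be reproduced with the standard library alone, since `struct` supports the IEEE 754 half-precision format (`'e'`). The gradient value and the exponent k = 10 below are hypothetical numbers chosen for illustration.

```python
import struct


def to_fp16(x):
    """Round a Python float to the nearest IEEE 754 half-precision value."""
    return struct.unpack('e', struct.pack('e', x))[0]


grad = 1e-8          # a small gradient value (hypothetical)
scale = 2 ** 10      # loss scale 2^k with k = 10

# Without scaling, the value is below the FP16 range and rounds to zero.
unscaled = to_fp16(grad)

# Scaling the loss scales every gradient by the same factor (chain rule),
# so the value survives FP16 and is divided back in higher precision.
recovered = to_fp16(grad * scale) / scale
```

Here `unscaled` is exactly zero while `recovered` stays within about 1% of the true value.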
182 |
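183 | ## Optimization 3: multi-model fusion
184 | - Motivation: the baseline's errors are concentrated in ambiguity errors, so a multi-level medical NER system is used to resolve the ambiguity
185 | - Method: a differentiated multi-level model fusion system
186 |   - Different model architectures: BERT-CRF & BERT-SPAN & BERT-MRC
187 |   - Different training data: changing the random seed and the sentence split length (256, 512)
188 |   - Multi-level model fusion strategy
189 |
190 | - Fusion model 1: BERT-SPAN
191 |   - Replaces the CRF module with SPAN pointers, which speeds up training
192 |   - Predicts the start and end positions of entities with a half-pointer, half-tagging structure, and assigns the entity type during tagging
193 |   - Uses strict decoding: among overlapping entities, only the one with the largest logits is kept, to preserve precision
194 |   - Uses label smoothing to mitigate overfitting
195 |
196 |
197 | - Fusion model 2: BERT-MRC
198 |   - Handles the NER task as machine reading comprehension
199 |     - query: a description of the entity type
200 |     - doc: the original text after sentence splitting
201 |   - One sample is constructed per entity type, which produces many negative samples during training; 30% of them can be randomly selected for training and the rest discarded, to keep training efficient
202 |   - At prediction time, one sample must be constructed for every type; the decoded output is left unconstrained to preserve recall
203 |   - Uses label smoothing to mitigate overfitting
204 |   - MRC had poor precision on this dataset and is inefficient to train and run, so it is used only to boost recall. The code is provided for study only and is not recommended for everyday use
205 |
206 |
207 | - Multi-level fusion strategy
208 |   - Level 1, probability fusion: for each of CRF/SPAN/MRC, the models obtained from 5-fold cross-validation are fused by averaging their logits and then decoding entities
209 |   - Level 2, voting fusion: the probability-fused CRF/SPAN/MRC models are combined by voting to obtain the final result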
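The two fusion levels can be sketched on toy data. This is an illustration under assumptions, not the repository's `ensemble_vote`: the function names, the per-token argmax decoding (a CRF would use Viterbi and a span model would decode start/end pointers), and the default threshold are all hypothetical.

```python
from collections import Counter


def average_logits(fold_logits):
    """Level 1: average per-token logits over the folds, then pick the best tag per token."""
    n = len(fold_logits)
    avg = [[sum(col) / n for col in zip(*token)] for token in zip(*fold_logits)]
    return [max(range(len(t)), key=t.__getitem__) for t in avg]


def vote_entities(predictions, threshold=0.5):
    """Level 2: keep an entity if at least `threshold` of the models predicted it."""
    counts = Counter(ent for pred in predictions for ent in set(pred))
    return sorted(ent for ent, c in counts.items() if c / len(predictions) >= threshold)
```

`fold_logits` is indexed as `[fold][token][tag]`, and each prediction passed to `vote_entities` is a list of `(type, start, end)` tuples. Raising the threshold trades recall for precision.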
210 |
211 |
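212 | ## Optimization 4: semi-supervised learning
213 | - Motivation: to alleviate the scarcity of labeled corpora in the medical domain, we use semi-supervised learning (pseudo-labeling) to make full use of the 500 unlabeled preliminary-round test samples
214 | - Strategy: dynamic pseudo-labels
215 |   - First train a base model M on the original labeled data
216 |   - Use M to predict on the preliminary-round test set to obtain pseudo-labels
217 |   - Add the pseudo-labeled data to the training set, give the pseudo-labels a dynamic, learnable weight (alpha in the figure), and train them together with the truly labeled data to obtain model M'
218 |
219 | - Tips: use a multi-model fused base model to reduce the noise in the pseudo-labels. The weight can also be fixed, in which case its value has to be tuned by experiment. In essence this lowers the loss weight of the pseudo-labels, which is one way to mitigate pseudo-label noise.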
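The fixed-weight variant mentioned in the tips amounts to down-weighting the pseudo-label loss. A minimal sketch, assuming per-example losses are already computed; the function name and the default `alpha` are hypothetical, and the learnable-weight variant would instead make `alpha` a trainable parameter.

```python
def combined_loss(true_losses, pseudo_losses, alpha=0.5):
    """Mean loss on labeled data plus alpha times the mean loss on pseudo-labeled data."""
    loss = sum(true_losses) / len(true_losses)
    if pseudo_losses:
        loss += alpha * sum(pseudo_losses) / len(pseudo_losses)
    return loss
```

With `alpha` below 1, a wrong pseudo-label contributes less to the gradient than a wrong gold label would.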
220 |
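221 | ## Other attempts without clear improvement
222 | - Dynamically weighted output of the last four BERT layers: no clear improvement
223 | - Adding a BiLSTM / IDCNN module after the BERT output: severe overfitting and much slower training
224 | - Data augmentation: randomly replacing entity words with others of the same type to expand the training data
225 | - Using focal loss / dice loss etc. in the BERT-SPAN / MRC models to mitigate label imbalance
226 | - Using the constructed domain dictionary to correct model outputs
227 |
228 | ## Final online score 72.90%; Rank 1 in the semi-final and Rank 1 in the final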
229 |
230 |
231 | # Ref
232 | [1] Zhao et al., UER: An Open-Source Toolkit for Pre-training Models, EMNLP-IJCNLP, 2019.
233 | [2] Madry et al., Towards Deep Learning Models Resistant to Adversarial Attacks, ICLR, 2018.
234 |
--------------------------------------------------------------------------------
/competition_predict.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | import torch
4 | from collections import defaultdict
5 | from transformers import BertTokenizer
6 | from src.utils.model_utils import CRFModel, SpanModel, EnsembleCRFModel, EnsembleSpanModel
7 | from src.utils.evaluator import crf_decode, span_decode
8 | from src.utils.functions_utils import load_model_and_parallel, ensemble_vote
9 | from src.preprocess.processor import cut_sent, fine_grade_tokenize
10 |
11 | MID_DATA_DIR = "./data/mid_data"
12 | RAW_DATA_DIR = "./data/raw_data_random"
13 | SUBMIT_DIR = "./result"
14 | GPU_IDS = "0"
15 |
16 | LAMBDA = 0.3
17 | THRESHOLD = 0.9
18 | MAX_SEQ_LEN = 512
19 |
20 | TASK_TYPE = "crf" # choose crf or span
21 | VOTE = True # choose True or False
22 | VERSION = "mixed" # choose single or ensemble or mixed; if mixed, VOTE and TASK_TYPE are ignored.
23 |
24 | # single_predict
25 | BERT_TYPE = "uer_large" # roberta_wwm / ernie_1 / uer_large
26 |
27 | BERT_DIR = f"./bert/torch_{BERT_TYPE}"
28 | with open('./best_ckpt_path.txt', 'r', encoding='utf-8') as f:
29 |     CKPT_PATH = f.read().strip()
30 |
31 | # ensemble_predict
32 | BERT_DIR_LIST = ["./bert/torch_uer_large", "./bert/torch_roberta_wwm"]
33 |
34 | with open('./best_ckpt_path.txt', 'r', encoding='utf-8') as f:
35 |     ENSEMBLE_DIR_LIST = f.readlines()
36 | print('ENSEMBLE_DIR_LIST:{}'.format(ENSEMBLE_DIR_LIST))
37 |
38 |
39 | # mixed_predict
40 | MIX_BERT_DIR = "./bert/torch_uer_large"
41 |
42 | with open('./best_ckpt_path.txt', 'r', encoding='utf-8') as f:
43 |     MIX_DIR_LIST = f.readlines()
44 | print('MIX_DIR_LIST:{}'.format(MIX_DIR_LIST))
45 |
46 |
47 | def prepare_info():
48 |     info_dict = {}
49 |     with open(os.path.join(MID_DATA_DIR, f'{TASK_TYPE}_ent2id.json'), encoding='utf-8') as f:
50 |         ent2id = json.load(f)
51 |
52 |     with open(os.path.join(RAW_DATA_DIR, 'test.json'), encoding='utf-8') as f:
53 |         info_dict['examples'] = json.load(f)
54 |
55 |     info_dict['id2ent'] = {ent2id[key]: key for key in ent2id.keys()}
56 |
57 |     info_dict['tokenizer'] = BertTokenizer(os.path.join(BERT_DIR, 'vocab.txt'))
58 |
59 |     return info_dict
60 |
61 |
62 | def mixed_prepare_info(mixed='crf'):
63 |     info_dict = {}
64 |     with open(os.path.join(MID_DATA_DIR, f'{mixed}_ent2id.json'), encoding='utf-8') as f:
65 |         ent2id = json.load(f)
66 |
67 |     with open(os.path.join(RAW_DATA_DIR, 'test.json'), encoding='utf-8') as f:
68 |         info_dict['examples'] = json.load(f)
69 |
70 |     info_dict['id2ent'] = {ent2id[key]: key for key in ent2id.keys()}
71 |
72 |     info_dict['tokenizer'] = BertTokenizer(os.path.join(MIX_BERT_DIR, 'vocab.txt'))  # use the vocab of the BERT that the mixed models are built on
73 |
74 |     return info_dict
75 |
76 |
77 | def base_predict(model, device, info_dict, ensemble=False, mixed=''):
78 |     labels = defaultdict(list)
79 |
80 |     tokenizer = info_dict['tokenizer']
81 |     id2ent = info_dict['id2ent']
82 |
83 |     with torch.no_grad():
84 |         for _ex in info_dict['examples']:
85 |             ex_idx = _ex['id']
86 |             raw_text = _ex['text']
87 |
88 |             if not len(raw_text):
89 |                 labels[ex_idx] = []
90 |                 print('{} is empty'.format(ex_idx))
91 |                 continue
92 |
93 |             sentences = cut_sent(raw_text, MAX_SEQ_LEN)
94 |
95 |             start_index = 0
96 |
97 |             for sent in sentences:
98 |
99 |                 sent_tokens = fine_grade_tokenize(sent, tokenizer)
100 |
101 |                 encode_dict = tokenizer.encode_plus(text=sent_tokens,
102 |                                                     max_length=MAX_SEQ_LEN,
103 |                                                     is_pretokenized=True,
104 |                                                     pad_to_max_length=False,
105 |                                                     return_tensors='pt',
106 |                                                     return_token_type_ids=True,
107 |                                                     return_attention_mask=True)
108 |
109 |                 model_inputs = {'token_ids': encode_dict['input_ids'],
110 |                                 'attention_masks': encode_dict['attention_mask'],
111 |                                 'token_type_ids': encode_dict['token_type_ids']}
112 |
113 |                 for key in model_inputs:
114 |                     model_inputs[key] = model_inputs[key].to(device)
115 |
116 |                 if ensemble:
117 |                     if TASK_TYPE == 'crf':
118 |                         if VOTE:
119 |                             decode_entities = model.vote_entities(model_inputs, sent, id2ent, THRESHOLD)
120 |                         else:
121 |                             pred_tokens = model.predict(model_inputs)[0]
122 |                             decode_entities = crf_decode(pred_tokens, sent, id2ent)
123 |                     else:
124 |                         if VOTE:
125 |                             decode_entities = model.vote_entities(model_inputs, sent, id2ent, THRESHOLD)
126 |                         else:
127 |                             start_logits, end_logits = model.predict(model_inputs)
128 |                             start_logits = start_logits[0].cpu().numpy()[1:1 + len(sent)]
129 |                             end_logits = end_logits[0].cpu().numpy()[1:1 + len(sent)]
130 |
131 |                             decode_entities = span_decode(start_logits, end_logits, sent, id2ent)
132 |
133 |                 else:
134 |
135 |                     if mixed:
136 |                         if mixed == 'crf':
137 |                             pred_tokens = model(**model_inputs)[0][0]
138 |                             decode_entities = crf_decode(pred_tokens, sent, id2ent)
139 |                         else:
140 |                             start_logits, end_logits = model(**model_inputs)
141 |
142 |                             start_logits = start_logits[0].cpu().numpy()[1:1 + len(sent)]
143 |                             end_logits = end_logits[0].cpu().numpy()[1:1 + len(sent)]
144 |
145 |                             decode_entities = span_decode(start_logits, end_logits, sent, id2ent)
146 |
147 |                     else:
148 |                         if TASK_TYPE == 'crf':
149 |                             pred_tokens = model(**model_inputs)[0][0]
150 |                             decode_entities = crf_decode(pred_tokens, sent, id2ent)
151 |                         else:
152 |                             start_logits, end_logits = model(**model_inputs)
153 |
154 |                             start_logits = start_logits[0].cpu().numpy()[1:1 + len(sent)]
155 |                             end_logits = end_logits[0].cpu().numpy()[1:1 + len(sent)]
156 |
157 |                             decode_entities = span_decode(start_logits, end_logits, sent, id2ent)
158 |
159 |
160 |                 for _ent_type in decode_entities:
161 |                     for _ent in decode_entities[_ent_type]:
162 |                         tmp_start = _ent[1] + start_index
163 |                         tmp_end = tmp_start + len(_ent[0])
164 |
165 |                         assert raw_text[tmp_start: tmp_end] == _ent[0]
166 |
167 |                         labels[ex_idx].append((_ent_type, tmp_start, tmp_end, _ent[0]))
168 |
169 |                 start_index += len(sent)
170 |
171 |             if not len(labels[ex_idx]):
172 |                 labels[ex_idx] = []
173 |
174 |     return labels
175 |
176 |
177 | def single_predict():
178 |     save_dir = os.path.join(SUBMIT_DIR, VERSION)
179 |
180 |     if not os.path.exists(save_dir):
181 |         os.makedirs(save_dir, exist_ok=True)
182 |
183 |     info_dict = prepare_info()
184 |
185 |     if TASK_TYPE == 'crf':
186 |         model = CRFModel(bert_dir=BERT_DIR, num_tags=len(info_dict['id2ent']))
187 |     else:
188 |         model = SpanModel(bert_dir=BERT_DIR, num_tags=len(info_dict['id2ent'])+1)
189 |
190 |     print(f'Load model from {CKPT_PATH}')
191 |     model, device = load_model_and_parallel(model, GPU_IDS, CKPT_PATH)
192 |     model.eval()
193 |
194 |     labels = base_predict(model, device, info_dict)
195 |
196 |     for key in labels.keys():
197 |         with open(os.path.join(save_dir, f'{key}.ann'), 'w', encoding='utf-8') as f:
198 |             if not len(labels[key]):
199 |                 print(key)
200 |                 f.write("")
201 |             else:
202 |                 for idx, _label in enumerate(labels[key]):
203 |                     f.write(f'T{idx + 1}\t{_label[0]} {_label[1]} {_label[2]}\t{_label[3]}\n')
204 |
205 |
206 | def ensemble_predict():
207 |     save_dir = os.path.join(SUBMIT_DIR, VERSION)
208 |     if not os.path.exists(save_dir):
209 |         os.makedirs(save_dir, exist_ok=True)
210 |
211 |     info_dict = prepare_info()
212 |
213 |     model_path_list = [x.strip() for x in ENSEMBLE_DIR_LIST]
214 |     print('model_path_list:{}'.format(model_path_list))
215 |
216 |     device = torch.device(f'cuda:{GPU_IDS[0]}')
217 |
218 |
219 |     if TASK_TYPE == 'crf':
220 |         model = EnsembleCRFModel(model_path_list=model_path_list,
221 |                                  bert_dir_list=BERT_DIR_LIST,
222 |                                  num_tags=len(info_dict['id2ent']),
223 |                                  device=device,
224 |                                  lamb=LAMBDA)
225 |     else:
226 |         model = EnsembleSpanModel(model_path_list=model_path_list,
227 |                                   bert_dir_list=BERT_DIR_LIST,
228 |                                   num_tags=len(info_dict['id2ent'])+1,
229 |                                   device=device)
230 |
231 |
232 |     labels = base_predict(model, device, info_dict, ensemble=True)
233 |
234 |
235 |     for key in labels.keys():
236 |         with open(os.path.join(save_dir, f'{key}.ann'), 'w', encoding='utf-8') as f:
237 |             if not len(labels[key]):
238 |                 print(key)
239 |                 f.write("")
240 |             else:
241 |                 for idx, _label in enumerate(labels[key]):
242 |                     f.write(f'T{idx + 1}\t{_label[0]} {_label[1]} {_label[2]}\t{_label[3]}\n')
243 |
244 | def mixed_predict():
245 |     save_dir = os.path.join(SUBMIT_DIR, VERSION)
246 |
247 |     if not os.path.exists(save_dir):
248 |         os.makedirs(save_dir, exist_ok=True)
249 |
250 |     model_path_list = [x.strip() for x in MIX_DIR_LIST]
251 |     print('model_path_list:{}'.format(model_path_list))
252 |
253 |     all_labels = []
254 |
255 |     for i, model_path in enumerate(model_path_list):
256 |         if i <= 4:
257 |             info_dict = mixed_prepare_info(mixed='span')
258 |
259 |             model = SpanModel(bert_dir=MIX_BERT_DIR, num_tags=len(info_dict['id2ent']) + 1)
260 |             print(f'Load model from {model_path}')
261 |             model, device = load_model_and_parallel(model, GPU_IDS, model_path)
262 |             model.eval()
263 |             labels = base_predict(model, device, info_dict, ensemble=False, mixed='span')
264 |
265 |         else:
266 |             info_dict = mixed_prepare_info(mixed='crf')
267 |
268 |             model = CRFModel(bert_dir=MIX_BERT_DIR, num_tags=len(info_dict['id2ent']))
269 |             print(f'Load model from {model_path}')
270 |             model, device = load_model_and_parallel(model, GPU_IDS, model_path)
271 |             model.eval()
272 |             labels = base_predict(model, device, info_dict, ensemble=False, mixed='crf')
273 |
274 |         all_labels.append(labels)
275 |
276 |     labels = ensemble_vote(all_labels, THRESHOLD)
277 |
278 |     # for key in labels.keys():
279 |
280 |     for key in range(1500, 1997):
281 |         with open(os.path.join(save_dir, f'{key}.ann'), 'w', encoding='utf-8') as f:
282 |             if not len(labels[key]):
283 |                 print(key)
284 |                 f.write("")
285 |             else:
286 |                 for idx, _label in enumerate(labels[key]):
287 |                     f.write(f'T{idx + 1}\t{_label[0]} {_label[1]} {_label[2]}\t{_label[3]}\n')
288 |
289 |
290 |
291 | if __name__ == '__main__':
292 |     assert VERSION in ['single', 'ensemble', 'mixed'], 'VERSION mismatch'
293 |
294 |     if VERSION == 'single':
295 |         single_predict()
296 |     elif VERSION == 'ensemble':
297 |         if VOTE:
298 |             print("-------- Start voting --------")
299 |         ensemble_predict()
300 |
301 |     elif VERSION == 'mixed':
302 |         print("-------- Start mixed voting --------")
303 |         mixed_predict()
304 |
305 |     # Compress the results into result.zip
306 |     import zipfile
307 |
308 |     def zip_file(src_dir):
309 |         zip_name = src_dir + '.zip'
310 |         z = zipfile.ZipFile(zip_name, 'w', zipfile.ZIP_DEFLATED)
311 |         for dirpath, dirnames, filenames in os.walk(src_dir):
312 |             fpath = dirpath.replace(src_dir, '')
313 |             fpath = fpath and fpath + os.sep or ''
314 |             for filename in filenames:
315 |                 z.write(os.path.join(dirpath, filename), fpath + filename)
316 |         print('==Compression succeeded==')
317 |         z.close()
318 |
319 |     zip_file('./result')
320 |
321 |
--------------------------------------------------------------------------------
/convert_test_data.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | from tqdm import trange
4 |
5 |
6 | def save_info(data_dir, data, desc):
7 |     with open(os.path.join(data_dir, f'{desc}.json'), 'w', encoding='utf-8') as f:
8 |         json.dump(data, f, ensure_ascii=False, indent=2)
9 |
10 |
11 | def convert_test_data_to_json(test_dir, save_dir):
12 |
13 |     test_examples = []
14 |
15 |
16 |     # process test examples
17 |     for i in trange(1500, 1997):
18 |         with open(os.path.join(test_dir, f'{i}.txt'), encoding='utf-8') as f:
19 |             text = f.read()
20 |
21 |         test_examples.append({'id': i,
22 |                               'text': text})
23 |
24 |     save_info(save_dir, test_examples, 'test')
25 |
26 |
27 | if __name__ == '__main__':
28 |     test_dir = './tcdata/juesai'
29 |     save_dir = './data/raw_data_random'
30 |     convert_test_data_to_json(test_dir, save_dir)
31 |     print('Test data conversion finished')
32 |
33 |
--------------------------------------------------------------------------------
/data/mid_data/crf_ent2id.json:
--------------------------------------------------------------------------------
1 | {
2 | "O": 0,
3 | "B-DRUG_GROUP": 1,
4 | "B-DRUG_DOSAGE": 2,
5 | "B-FOOD": 3,
6 | "B-DRUG_EFFICACY": 4,
7 | "B-FOOD_GROUP": 5,
8 | "B-SYMPTOM": 6,
9 | "B-DISEASE_GROUP": 7,
10 | "B-SYNDROME": 8,
11 | "B-PERSON_GROUP": 9,
12 | "B-DRUG_TASTE": 10,
13 | "B-DRUG_INGREDIENT": 11,
14 | "B-DRUG": 12,
15 | "B-DISEASE": 13,
16 | "I-DRUG_GROUP": 14,
17 | "I-DRUG_DOSAGE": 15,
18 | "I-FOOD": 16,
19 | "I-DRUG_EFFICACY": 17,
20 | "I-FOOD_GROUP": 18,
21 | "I-SYMPTOM": 19,
22 | "I-DISEASE_GROUP": 20,
23 | "I-SYNDROME": 21,
24 | "I-PERSON_GROUP": 22,
25 | "I-DRUG_TASTE": 23,
26 | "I-DRUG_INGREDIENT": 24,
27 | "I-DRUG": 25,
28 | "I-DISEASE": 26,
29 | "E-DRUG_GROUP": 27,
30 | "E-DRUG_DOSAGE": 28,
31 | "E-FOOD": 29,
32 | "E-DRUG_EFFICACY": 30,
33 | "E-FOOD_GROUP": 31,
34 | "E-SYMPTOM": 32,
35 | "E-DISEASE_GROUP": 33,
36 | "E-SYNDROME": 34,
37 | "E-PERSON_GROUP": 35,
38 | "E-DRUG_TASTE": 36,
39 | "E-DRUG_INGREDIENT": 37,
40 | "E-DRUG": 38,
41 | "E-DISEASE": 39,
42 | "S-DRUG_GROUP": 40,
43 | "S-DRUG_DOSAGE": 41,
44 | "S-FOOD": 42,
45 | "S-DRUG_EFFICACY": 43,
46 | "S-FOOD_GROUP": 44,
47 | "S-SYMPTOM": 45,
48 | "S-DISEASE_GROUP": 46,
49 | "S-SYNDROME": 47,
50 | "S-PERSON_GROUP": 48,
51 | "S-DRUG_TASTE": 49,
52 | "S-DRUG_INGREDIENT": 50,
53 | "S-DRUG": 51,
54 | "S-DISEASE": 52
55 | }
--------------------------------------------------------------------------------
/data/mid_data/mrc_ent2id.json:
--------------------------------------------------------------------------------
1 | {
2 | "DRUG": "找出药物:用于预防、治疗、诊断疾病并具有康复与保健作用的物质。",
3 | "DRUG_INGREDIENT": "找出药物成分:中药组成成分,指中药复方中所含有的所有与该复方临床应用目的密切相关的药理活性成分。",
4 | "DISEASE": "找出疾病:指人体在一定原因的损害性作用下,因自稳调节紊乱而发生的异常生命活动过程,会影响生物体的部分或是所有器官。",
5 | "SYMPTOM": "找出症状:指疾病过程中机体内的一系列机能、代谢和形态结构异常变化所引起的病人主观上的异常感觉或某些客观病态改变。",
6 | "SYNDROME": "找出症候:概括为一系列有相互关联的症状总称,是指不同症状和体征的综合表现。",
7 | "DISEASE_GROUP": "找出疾病分组:疾病涉及有人体组织部位的疾病名称的统称概念,非某项具体医学疾病。",
8 | "FOOD": "找出食物:指能够满足机体正常生理和生化能量需求,并能延续正常寿命的物质。",
9 | "FOOD_GROUP": "找出食物分组:中医中饮食养生中,将食物分为寒热温凉四性,同时中医药禁忌中对于具有某类共同属性食物的统称,记为食物分组。",
10 | "PERSON_GROUP": "找出人群:中医药的适用及禁忌范围内相关特定人群。",
11 | "DRUG_GROUP": "找出药品分组:具有某一类共同属性的药品类统称概念,非某项具体药品名。例子:止咳药、退烧药",
12 | "DRUG_DOSAGE": "找出药物剂量:药物在供给临床使用前,均必须制成适合于医疗和预防应用的形式,成为药物剂型。",
13 | "DRUG_TASTE": "找出药物性味:药品的性质和气味。例子:味甘、酸涩、气凉。",
14 | "DRUG_EFFICACY": "找出中药功效:药品的主治功能和效果的统称。例子:滋阴补肾、去瘀生新、活血化瘀"
15 | }
--------------------------------------------------------------------------------
/data/mid_data/span_ent2id.json:
--------------------------------------------------------------------------------
1 | {
2 | "DRUG_GROUP": 1,
3 | "DRUG_DOSAGE": 2,
4 | "FOOD": 3,
5 | "DRUG_EFFICACY": 4,
6 | "FOOD_GROUP": 5,
7 | "SYMPTOM": 6,
8 | "DISEASE_GROUP": 7,
9 | "SYNDROME": 8,
10 | "PERSON_GROUP": 9,
11 | "DRUG_TASTE": 10,
12 | "DRUG_INGREDIENT": 11,
13 | "DRUG": 12,
14 | "DISEASE": 13
15 | }
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | import time
2 | import os
3 | import json
4 | import logging
5 | from torch.utils.data import DataLoader
6 | from sklearn.model_selection import KFold
7 | from src.utils.trainer import train
8 | from src.utils.options import Args
9 | from src.utils.model_utils import build_model
10 | from src.utils.dataset_utils import NERDataset
11 | from src.utils.evaluator import crf_evaluation, span_evaluation, mrc_evaluation
12 | from src.utils.functions_utils import set_seed, get_model_path_list, load_model_and_parallel, get_time_dif
13 | from src.preprocess.processor import NERProcessor, convert_examples_to_features
14 |
15 | logger = logging.getLogger(__name__)
16 | logging.basicConfig(
17 |     format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
18 |     datefmt="%m/%d/%Y %H:%M:%S",
19 |     level=logging.INFO
20 | )
21 |
22 | def train_base(opt, train_examples, dev_examples=None):
23 |     with open(os.path.join(opt.mid_data_dir, f'{opt.task_type}_ent2id.json'), encoding='utf-8') as f:
24 |         ent2id = json.load(f)
25 |
26 |     train_features = convert_examples_to_features(opt.task_type, train_examples,
27 |                                                   opt.max_seq_len, opt.bert_dir, ent2id)[0]
28 |
29 |     train_dataset = NERDataset(opt.task_type, train_features, 'train', use_type_embed=opt.use_type_embed)
30 |
31 |     if opt.task_type == 'crf':
32 |         model = build_model('crf', opt.bert_dir, num_tags=len(ent2id),
33 |                             dropout_prob=opt.dropout_prob)
34 |     elif opt.task_type == 'mrc':
35 |         model = build_model('mrc', opt.bert_dir,
36 |                             dropout_prob=opt.dropout_prob,
37 |                             use_type_embed=opt.use_type_embed,
38 |                             loss_type=opt.loss_type)
39 |     else:
40 |         model = build_model('span', opt.bert_dir, num_tags=len(ent2id)+1,
41 |                             dropout_prob=opt.dropout_prob,
42 |                             loss_type=opt.loss_type)
43 |
44 |     train(opt, model, train_dataset)
45 |
46 |     if dev_examples is not None:
47 |
48 |         dev_features, dev_callback_info = convert_examples_to_features(opt.task_type, dev_examples,
49 |                                                                        opt.max_seq_len, opt.bert_dir, ent2id)
50 |
51 |         dev_dataset = NERDataset(opt.task_type, dev_features, 'dev', use_type_embed=opt.use_type_embed)
52 |
53 |         dev_loader = DataLoader(dev_dataset, batch_size=opt.eval_batch_size,
54 |                                 shuffle=False, num_workers=0)
55 |
56 |         dev_info = (dev_loader, dev_callback_info)
57 |
58 |         model_path_list = get_model_path_list(opt.output_dir)
59 |
60 |         metric_str = ''
61 |
62 |         max_f1 = 0.
63 |         max_f1_step = 0
64 |
65 |         max_f1_path = ''
66 |
67 |         for idx, model_path in enumerate(model_path_list):
68 |
69 |             tmp_step = model_path.split('/')[-2].split('-')[-1]
70 |
71 |
72 |             model, device = load_model_and_parallel(model, opt.gpu_ids[0],
73 |                                                     ckpt_path=model_path)
74 |
75 |             if opt.task_type == 'crf':
76 |                 tmp_metric_str, tmp_f1 = crf_evaluation(model, dev_info, device, ent2id)
77 |             elif opt.task_type == 'mrc':
78 |                 tmp_metric_str, tmp_f1 = mrc_evaluation(model, dev_info, device)
79 |             else:
80 |                 tmp_metric_str, tmp_f1 = span_evaluation(model, dev_info, device, ent2id)
81 |
82 |             logger.info(f'In step {tmp_step}:\n {tmp_metric_str}')
83 |
84 |             metric_str += f'In step {tmp_step}:\n {tmp_metric_str}' + '\n\n'
85 |
86 |             if tmp_f1 > max_f1:
87 |                 max_f1 = tmp_f1
88 |                 max_f1_step = tmp_step
89 |                 max_f1_path = model_path
90 |
91 |         max_metric_str = f'Max f1 is: {max_f1}, in step {max_f1_step}'
92 |
93 |         logger.info(max_metric_str)
94 |
95 |         metric_str += max_metric_str + '\n'
96 |
97 |         eval_save_path = os.path.join(opt.output_dir, 'eval_metric.txt')
98 |
99 |         with open(eval_save_path, 'a', encoding='utf-8') as f1:
100 |             f1.write(metric_str)
101 |
102 |         with open('./best_ckpt_path.txt', 'a', encoding='utf-8') as f2:
103 |             f2.write(max_f1_path + '\n')
104 |
105 |         del_dir_list = [os.path.join(opt.output_dir, path.split('/')[-2])
106 |                         for path in model_path_list if path != max_f1_path]
107 |
108 |         import shutil
109 |         for x in del_dir_list:
110 |             shutil.rmtree(x)
111 |             logger.info('{} deleted'.format(x))
112 |
113 |
114 | def training(opt):
115 |     if opt.task_type == 'mrc':
116 |         # 62 for mrc query
117 |         processor = NERProcessor(opt.max_seq_len-62)
118 |     else:
119 |         processor = NERProcessor(opt.max_seq_len)
120 |
121 |     train_raw_examples = processor.read_json(os.path.join(opt.raw_data_dir, 'train.json'))
122 |
123 |     # add pseudo data to train data
124 |     pseudo_raw_examples = processor.read_json(os.path.join(opt.raw_data_dir, 'pseudo.json'))
125 |     train_raw_examples = train_raw_examples + pseudo_raw_examples
126 |
127 |     train_examples = processor.get_examples(train_raw_examples, 'train')
128 |
129 |     dev_examples = None
130 |     if opt.eval_model:
131 |         dev_raw_examples = processor.read_json(os.path.join(opt.raw_data_dir, 'dev.json'))
132 |         dev_examples = processor.get_examples(dev_raw_examples, 'dev')
133 |
134 |     train_base(opt, train_examples, dev_examples)
135 |
136 |
137 | def stacking(opt):
138 |     logger.info('Start to KFold stack attribution model')
139 |
140 |     if opt.task_type == 'mrc':
141 |         # 62 for mrc query
142 |         processor = NERProcessor(opt.max_seq_len-62)
143 |     else:
144 |         processor = NERProcessor(opt.max_seq_len)
145 |
146 |     kf = KFold(5, shuffle=True, random_state=42)
147 |
148 |     stack_raw_examples = processor.read_json(os.path.join(opt.raw_data_dir, 'stack.json'))
149 |
150 |     pseudo_raw_examples = processor.read_json(os.path.join(opt.raw_data_dir, 'pseudo.json'))
151 |
152 |     base_output_dir = opt.output_dir
153 |
154 |     for i, (train_ids, dev_ids) in enumerate(kf.split(stack_raw_examples)):
155 |         logger.info(f'Start to train the {i} fold')
156 |         train_raw_examples = [stack_raw_examples[_idx] for _idx in train_ids]
157 |
158 |         # add pseudo data to train data
159 |         train_raw_examples = train_raw_examples + pseudo_raw_examples
160 |         train_examples = processor.get_examples(train_raw_examples, 'train')
161 |
162 |         dev_raw_examples = [stack_raw_examples[_idx] for _idx in dev_ids]
163 |         dev_info = processor.get_examples(dev_raw_examples, 'dev')
164 |
165 |         tmp_output_dir = os.path.join(base_output_dir, f'v{i}')
166 |
167 |         opt.output_dir = tmp_output_dir
168 |
169 |         train_base(opt, train_examples, dev_info)
170 |
171 | if __name__ == '__main__':
172 |     start_time = time.time()
173 |     logging.info('----------------Timer started----------------')
174 |     logging.info('----------------------------------------')
175 |
176 |     args = Args().get_parser()
177 |
178 |     assert args.mode in ['train', 'stack'], 'mode mismatch'
179 |     assert args.task_type in ['crf', 'span', 'mrc']
180 |
181 |     args.output_dir = os.path.join(args.output_dir, args.bert_type)
182 |
183 |     set_seed(args.seed)
184 |
185 |     if args.attack_train != '':
186 |         args.output_dir += f'_{args.attack_train}'
187 |
188 |     if args.weight_decay:
189 |         args.output_dir += '_wd'
190 |
191 |     if args.use_fp16:
192 |         args.output_dir += '_fp16'
193 |
194 |     if args.task_type == 'span':
195 |         args.output_dir += f'_{args.loss_type}'
196 |
197 |     if args.task_type == 'mrc':
198 |         if args.use_type_embed:
199 |             args.output_dir += '_embed'
200 |         args.output_dir += f'_{args.loss_type}'
201 |
202 |     args.output_dir += f'_{args.task_type}'
203 |
204 |     if args.mode == 'stack':
205 |         args.output_dir += '_stack'
206 |
207 |     if not os.path.exists(args.output_dir):
208 |         os.makedirs(args.output_dir, exist_ok=True)
209 |
210 |     logger.info(f'{args.mode} {args.task_type} in max_seq_len {args.max_seq_len}')
211 |
212 |     if args.mode == 'train':
213 |         training(args)
214 |     else:
215 |         stacking(args)
216 |
217 |     time_dif = get_time_dif(start_time)
218 |     logging.info("----------Container run time: {}-----------".format(time_dif))
219 |
--------------------------------------------------------------------------------
/md_files/1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chizhu/DeepNER/8c4abc21676af50ede29dce90bfac4892b36a1c5/md_files/1.png
--------------------------------------------------------------------------------
/md_files/10.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chizhu/DeepNER/8c4abc21676af50ede29dce90bfac4892b36a1c5/md_files/10.png
--------------------------------------------------------------------------------
/md_files/11.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chizhu/DeepNER/8c4abc21676af50ede29dce90bfac4892b36a1c5/md_files/11.png
--------------------------------------------------------------------------------
/md_files/12.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chizhu/DeepNER/8c4abc21676af50ede29dce90bfac4892b36a1c5/md_files/12.png
--------------------------------------------------------------------------------
/md_files/13.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chizhu/DeepNER/8c4abc21676af50ede29dce90bfac4892b36a1c5/md_files/13.png
--------------------------------------------------------------------------------
/md_files/2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chizhu/DeepNER/8c4abc21676af50ede29dce90bfac4892b36a1c5/md_files/2.png
--------------------------------------------------------------------------------
/md_files/3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chizhu/DeepNER/8c4abc21676af50ede29dce90bfac4892b36a1c5/md_files/3.png
--------------------------------------------------------------------------------
/md_files/4.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chizhu/DeepNER/8c4abc21676af50ede29dce90bfac4892b36a1c5/md_files/4.png
--------------------------------------------------------------------------------
/md_files/5.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chizhu/DeepNER/8c4abc21676af50ede29dce90bfac4892b36a1c5/md_files/5.png
--------------------------------------------------------------------------------
/md_files/6.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chizhu/DeepNER/8c4abc21676af50ede29dce90bfac4892b36a1c5/md_files/6.png
--------------------------------------------------------------------------------
/md_files/7.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chizhu/DeepNER/8c4abc21676af50ede29dce90bfac4892b36a1c5/md_files/7.png
--------------------------------------------------------------------------------
/md_files/8.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chizhu/DeepNER/8c4abc21676af50ede29dce90bfac4892b36a1c5/md_files/8.png
--------------------------------------------------------------------------------
/md_files/9.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chizhu/DeepNER/8c4abc21676af50ede29dce90bfac4892b36a1c5/md_files/9.png
--------------------------------------------------------------------------------
/run.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 |
3 | export MID_DATA_DIR="./data/mid_data"
4 | export RAW_DATA_DIR="./data/raw_data"
5 | export OUTPUT_DIR="./out"
6 |
7 | export GPU_IDS="0"
8 | export BERT_TYPE="roberta_wwm" # roberta_wwm / roberta_wwm_large / uer_large
9 | export BERT_DIR="../bert/torch_$BERT_TYPE"
10 |
11 | export MODE="train"
12 | export TASK_TYPE="crf"
13 |
14 | python main.py \
15 | --gpu_ids=$GPU_IDS \
16 | --output_dir=$OUTPUT_DIR \
17 | --mid_data_dir=$MID_DATA_DIR \
18 | --mode=$MODE \
19 | --task_type=$TASK_TYPE \
20 | --raw_data_dir=$RAW_DATA_DIR \
21 | --bert_dir=$BERT_DIR \
22 | --bert_type=$BERT_TYPE \
23 | --train_epochs=10 \
24 | --swa_start=5 \
25 | --attack_train="" \
26 | --train_batch_size=24 \
27 | --dropout_prob=0.1 \
28 | --max_seq_len=512 \
29 | --lr=2e-5 \
30 | --other_lr=2e-3 \
31 | --seed=123 \
32 | --weight_decay=0.01 \
33 | --loss_type='ls_ce' \
34 | --eval_model \
35 | #--use_fp16
--------------------------------------------------------------------------------
/src/preprocess/__pycache__/processor.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chizhu/DeepNER/8c4abc21676af50ede29dce90bfac4892b36a1c5/src/preprocess/__pycache__/processor.cpython-36.pyc
--------------------------------------------------------------------------------
/src/preprocess/convert_raw_data.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | from tqdm import trange
4 | from sklearn.model_selection import train_test_split, KFold
5 |
6 |
7 | def save_info(data_dir, data, desc):
8 | with open(os.path.join(data_dir, f'{desc}.json'), 'w', encoding='utf-8') as f:
9 | json.dump(data, f, ensure_ascii=False, indent=2)
10 |
11 |
12 | def convert_data_to_json(base_dir, save_data=False, save_dict=False):
13 | stack_examples = []
14 | pseudo_examples = []
15 | test_examples = []
16 |
17 | stack_dir = os.path.join(base_dir, 'train')
18 | pseudo_dir = os.path.join(base_dir, 'pseudo')
19 | test_dir = os.path.join(base_dir, 'test')
20 |
21 |
22 | # process train examples
23 | for i in trange(1000):
24 | with open(os.path.join(stack_dir, f'{i}.txt'), encoding='utf-8') as f:
25 | text = f.read()
26 |
27 | labels = []
28 | with open(os.path.join(stack_dir, f'{i}.ann'), encoding='utf-8') as f:
29 | for line in f.readlines():
30 | tmp_label = line.strip().split('\t')
31 | assert len(tmp_label) == 3
32 | tmp_mid = tmp_label[1].split()
33 | tmp_label = [tmp_label[0]] + tmp_mid + [tmp_label[2]]
34 |
35 | labels.append(tmp_label)
36 | tmp_label[2] = int(tmp_label[2])
37 | tmp_label[3] = int(tmp_label[3])
38 |
39 |                 assert text[tmp_label[2]:tmp_label[3]] == tmp_label[-1], '{},{} offset extraction error'.format(tmp_label, i)
40 |
41 | stack_examples.append({'id': i,
42 | 'text': text,
43 | 'labels': labels,
44 | 'pseudo': 0})
45 |
46 |
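47 |     # Build the entity knowledge base: for each of the 10 folds, entities seen in the other nine folds become candidate entities for the held-out fold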
48 | kf = KFold(10)
49 | entities = set()
50 | ent_types = set()
51 | for _now_id, _candidate_id in kf.split(stack_examples):
52 | now = [stack_examples[_id] for _id in _now_id]
53 | candidate = [stack_examples[_id] for _id in _candidate_id]
54 | now_entities = set()
55 |
56 | for _ex in now:
57 | for _label in _ex['labels']:
58 | ent_types.add(_label[1])
59 |
60 | if len(_label[-1]) > 1:
61 | now_entities.add(_label[-1])
62 | entities.add(_label[-1])
63 | # print(len(now_entities))
64 | for _ex in candidate:
65 | text = _ex['text']
66 | candidate_entities = []
67 |
68 | for _ent in now_entities:
69 | if _ent in text:
70 | candidate_entities.append(_ent)
71 |
72 | _ex['candidate_entities'] = candidate_entities
73 | assert len(ent_types) == 13
74 |
75 | # process test examples predicted by the preliminary model
76 | for i in trange(1000, 1500):
77 | with open(os.path.join(pseudo_dir, f'{i}.txt'), encoding='utf-8') as f:
78 | text = f.read()
79 |
80 | candidate_entities = []
81 | for _ent in entities:
82 | if _ent in text:
83 | candidate_entities.append(_ent)
84 |
85 | labels = []
86 | with open(os.path.join(pseudo_dir, f'{i}.ann'), encoding='utf-8') as f:
87 | for line in f.readlines():
88 | tmp_label = line.strip().split('\t')
89 | assert len(tmp_label) == 3
90 | tmp_mid = tmp_label[1].split()
91 | tmp_label = [tmp_label[0]] + tmp_mid + [tmp_label[2]]
92 |
93 | labels.append(tmp_label)
94 | tmp_label[2] = int(tmp_label[2])
95 | tmp_label[3] = int(tmp_label[3])
96 |
97 |                 assert text[tmp_label[2]:tmp_label[3]] == tmp_label[-1], '{},{} offset extraction error'.format(tmp_label, i)
98 |
99 | pseudo_examples.append({'id': i,
100 | 'text': text,
101 | 'labels': labels,
102 | 'candidate_entities': candidate_entities,
103 | 'pseudo': 1})
104 |
105 | # process test examples
106 | for i in trange(1000, 1500):
107 | with open(os.path.join(test_dir, f'{i}.txt'), encoding='utf-8') as f:
108 | text = f.read()
109 |
110 | candidate_entities = []
111 | for _ent in entities:
112 | if _ent in text:
113 | candidate_entities.append(_ent)
114 |
115 | test_examples.append({'id': i,
116 | 'text': text,
117 | 'candidate_entities': candidate_entities})
118 |
119 | train, dev = train_test_split(stack_examples, shuffle=True, random_state=123, test_size=0.15)
120 |
121 | if save_data:
122 | save_info(base_dir, stack_examples, 'stack')
123 | save_info(base_dir, train, 'train')
124 | save_info(base_dir, dev, 'dev')
125 | save_info(base_dir, test_examples, 'test')
126 |
127 | save_info(base_dir, pseudo_examples, 'pseudo')
128 |
129 | if save_dict:
130 | ent_types = list(ent_types)
131 | span_ent2id = {_type: i+1 for i, _type in enumerate(ent_types)}
132 |
133 | ent_types = ['O'] + [p + '-' + _type for p in ['B', 'I', 'E', 'S'] for _type in list(ent_types)]
134 | crf_ent2id = {ent: i for i, ent in enumerate(ent_types)}
135 |
136 | mid_data_dir = os.path.join(os.path.split(base_dir)[0], 'mid_data')
137 | if not os.path.exists(mid_data_dir):
138 | os.mkdir(mid_data_dir)
139 |
140 | save_info(mid_data_dir, span_ent2id, 'span_ent2id')
141 | save_info(mid_data_dir, crf_ent2id, 'crf_ent2id')
142 |
143 |
144 | def build_ent2query(data_dir):
145 |     # Use the competition's entity-type descriptions as the MRC queries
146 |     ent2query = {
147 |         # Drug
148 |         'DRUG': "找出药物:用于预防、治疗、诊断疾病并具有康复与保健作用的物质。",
149 |         # Drug ingredient
150 |         'DRUG_INGREDIENT': "找出药物成分:中药组成成分,指中药复方中所含有的所有与该复方临床应用目的密切相关的药理活性成分。",
151 |         # Disease
152 |         'DISEASE': "找出疾病:指人体在一定原因的损害性作用下,因自稳调节紊乱而发生的异常生命活动过程,会影响生物体的部分或是所有器官。",
153 |         # Symptom
154 |         'SYMPTOM': "找出症状:指疾病过程中机体内的一系列机能、代谢和形态结构异常变化所引起的病人主观上的异常感觉或某些客观病态改变。",
155 |         # Syndrome
156 |         'SYNDROME': "找出症候:概括为一系列有相互关联的症状总称,是指不同症状和体征的综合表现。",
157 |         # Disease group
158 |         'DISEASE_GROUP': "找出疾病分组:疾病涉及有人体组织部位的疾病名称的统称概念,非某项具体医学疾病。",
159 |         # Food
160 |         'FOOD': "找出食物:指能够满足机体正常生理和生化能量需求,并能延续正常寿命的物质。",
161 |         # Food group
162 |         'FOOD_GROUP': "找出食物分组:中医中饮食养生中,将食物分为寒热温凉四性,"
163 |                       "同时中医药禁忌中对于具有某类共同属性食物的统称,记为食物分组。",
164 |         # Person group
165 |         'PERSON_GROUP': "找出人群:中医药的适用及禁忌范围内相关特定人群。",
166 |         # Drug group
167 |         'DRUG_GROUP': "找出药品分组:具有某一类共同属性的药品类统称概念,非某项具体药品名。例子:止咳药、退烧药",
168 |         # Drug dosage (the query text describes the dosage form)
169 |         'DRUG_DOSAGE': "找出药物剂量:药物在供给临床使用前,均必须制成适合于医疗和预防应用的形式,成为药物剂型。",
170 |         # Drug taste (nature and flavor)
171 |         'DRUG_TASTE': "找出药物性味:药品的性质和气味。例子:味甘、酸涩、气凉。",
172 |         # Efficacy of traditional Chinese medicine
173 |         'DRUG_EFFICACY': "找出中药功效:药品的主治功能和效果的统称。例子:滋阴补肾、去瘀生新、活血化瘀"
174 | }
175 |
176 | with open(os.path.join(data_dir, 'mrc_ent2id.json'), 'w', encoding='utf-8') as f:
177 | json.dump(ent2query, f, ensure_ascii=False, indent=2)
178 |
179 |
180 | if __name__ == '__main__':
181 | convert_data_to_json('../../data/raw_data', save_data=True, save_dict=True)
182 | build_ent2query('../../data/mid_data')
183 |
184 |
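The annotation parsing in `convert_data_to_json` can be illustrated on a single line. This is a minimal sketch, assuming the brat-style layout that the loops above expect (`ID<TAB>TYPE START END<TAB>TEXT`). `parse_ann_line` is a hypothetical helper that is not part of the repository, and the sample text is made up for illustration.

```python
def parse_ann_line(line, text):
    """Parse one annotation line into [id, type, start, end, entity]."""
    fields = line.strip().split('\t')
    if len(fields) != 3:
        raise ValueError(f'expected 3 tab-separated fields, got {len(fields)}')
    ent_type, start, end = fields[1].split()
    label = [fields[0], ent_type, int(start), int(end), fields[2]]
    # The offsets must slice the entity text out of the document exactly.
    if text[label[2]:label[3]] != label[4]:
        raise ValueError(f'offset mismatch for {label}')
    return label


# Hypothetical sample document and annotation line.
sample_text = '本品为小蜜丸。'
sample_label = parse_ann_line('T1\tDRUG_DOSAGE 3 6\t小蜜丸\n', sample_text)
print(sample_label)
```

Raising `ValueError` instead of using `assert` is a design choice of this sketch: the check still runs when Python is started with `-O`.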
--------------------------------------------------------------------------------
/src/preprocess/processor.py:
--------------------------------------------------------------------------------
1 | import os
2 | import re
3 | import json
4 | import logging
5 | from transformers import BertTokenizer
6 | from collections import defaultdict
7 | import random
8 |
9 | logger = logging.getLogger(__name__)
10 |
11 | ENTITY_TYPES = ['DRUG', 'DRUG_INGREDIENT', 'DISEASE', 'SYMPTOM', 'SYNDROME', 'DISEASE_GROUP',
12 | 'FOOD', 'FOOD_GROUP', 'PERSON_GROUP', 'DRUG_GROUP', 'DRUG_DOSAGE', 'DRUG_TASTE',
13 | 'DRUG_EFFICACY']
14 |
15 |
16 | class InputExample:
17 | def __init__(self,
18 | set_type,
19 | text,
20 | labels=None,
21 | pseudo=None,
22 | distant_labels=None):
23 | self.set_type = set_type
24 | self.text = text
25 | self.labels = labels
26 | self.pseudo = pseudo
27 | self.distant_labels = distant_labels
28 |
29 |
30 | class BaseFeature:
31 | def __init__(self,
32 | token_ids,
33 | attention_masks,
34 | token_type_ids):
35 |         # BERT inputs
36 | self.token_ids = token_ids
37 | self.attention_masks = attention_masks
38 | self.token_type_ids = token_type_ids
39 |
40 |
41 | class CRFFeature(BaseFeature):
42 | def __init__(self,
43 | token_ids,
44 | attention_masks,
45 | token_type_ids,
46 | labels=None,
47 | pseudo=None,
48 | distant_labels=None):
49 | super(CRFFeature, self).__init__(token_ids=token_ids,
50 | attention_masks=attention_masks,
51 | token_type_ids=token_type_ids)
52 | # labels
53 | self.labels = labels
54 |
55 | # pseudo
56 | self.pseudo = pseudo
57 |
58 | # distant labels
59 | self.distant_labels = distant_labels
60 |
61 |
62 | class SpanFeature(BaseFeature):
63 | def __init__(self,
64 | token_ids,
65 | attention_masks,
66 | token_type_ids,
67 | start_ids=None,
68 | end_ids=None,
69 | pseudo=None):
70 | super(SpanFeature, self).__init__(token_ids=token_ids,
71 | attention_masks=attention_masks,
72 | token_type_ids=token_type_ids)
73 | self.start_ids = start_ids
74 | self.end_ids = end_ids
75 | # pseudo
76 | self.pseudo = pseudo
77 |
78 | class MRCFeature(BaseFeature):
79 | def __init__(self,
80 | token_ids,
81 | attention_masks,
82 | token_type_ids,
83 | ent_type=None,
84 | start_ids=None,
85 | end_ids=None,
86 | pseudo=None):
87 | super(MRCFeature, self).__init__(token_ids=token_ids,
88 | attention_masks=attention_masks,
89 | token_type_ids=token_type_ids)
90 | self.ent_type = ent_type
91 | self.start_ids = start_ids
92 | self.end_ids = end_ids
93 |
94 | # pseudo
95 | self.pseudo = pseudo
96 |
97 |
98 | class NERProcessor:
99 | def __init__(self, cut_sent_len=256):
100 | self.cut_sent_len = cut_sent_len
101 |
102 | @staticmethod
103 | def read_json(file_path):
104 | with open(file_path, encoding='utf-8') as f:
105 | raw_examples = json.load(f)
106 | return raw_examples
107 |
108 | @staticmethod
109 | def _refactor_labels(sent, labels, distant_labels, start_index):
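110 |         """
111 |         Recompute label offsets after the document is split into sentences.
112 |         :param sent: the sentence after splitting and re-merging
113 |         :param labels: the original document-level labels
114 |         :param distant_labels: distant-supervision labels (candidate entities)
115 |         :param start_index: offset of this sentence within the document
116 |         :return: (new_labels, new_distant_labels); each new label is (type, entity, offset)
117 |         """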
118 | new_labels, new_distant_labels = [], []
119 | end_index = start_index + len(sent)
120 |
121 | for _label in labels:
122 | if start_index <= _label[2] <= _label[3] <= end_index:
123 | new_offset = _label[2] - start_index
124 |
125 | assert sent[new_offset: new_offset + len(_label[-1])] == _label[-1]
126 |
127 | new_labels.append((_label[1], _label[-1], new_offset))
128 |             # the label straddles the sentence boundary (it was cut by the split)
129 | elif _label[2] < end_index < _label[3]:
130 | raise RuntimeError(f'{sent}, {_label}')
131 |
132 | for _label in distant_labels:
133 | if _label in sent:
134 | new_distant_labels.append(_label)
135 |
136 | return new_labels, new_distant_labels
137 |
138 | def get_examples(self, raw_examples, set_type):
139 | examples = []
140 |
141 | for i, item in enumerate(raw_examples):
142 | text = item['text']
143 | distant_labels = item['candidate_entities']
144 | pseudo = item['pseudo']
145 |
146 | sentences = cut_sent(text, self.cut_sent_len)
147 | start_index = 0
148 |
149 | for sent in sentences:
150 | labels, tmp_distant_labels = self._refactor_labels(sent, item['labels'], distant_labels, start_index)
151 |
152 | start_index += len(sent)
153 |
154 | examples.append(InputExample(set_type=set_type,
155 | text=sent,
156 | labels=labels,
157 | pseudo=pseudo,
158 | distant_labels=tmp_distant_labels))
159 |
160 | return examples
161 |
162 |
163 | def fine_grade_tokenize(raw_text, tokenizer):
164 |     """
165 |     In sequence labeling, BERT's tokenizer can shift label offsets,
166 |     so tokenize at the character level instead (one token per character).
167 |     """
168 | tokens = []
169 |
170 | for _ch in raw_text:
171 | if _ch in [' ', '\t', '\n']:
172 | tokens.append('[BLANK]')
173 | else:
174 | if not len(tokenizer.tokenize(_ch)):
175 | tokens.append('[INV]')
176 | else:
177 | tokens.append(_ch)
178 |
179 | return tokens
180 |
181 |
182 | def cut_sentences_v1(sent):
183 |     """
184 |     First-level sentence splitting: break after sentence-final punctuation.
185 |     """
186 |     sent = re.sub(r'([。!?\?])([^”’])', r"\1\n\2", sent)  # single-character sentence terminators
187 |     sent = re.sub(r'(\.{6})([^”’])', r"\1\n\2", sent)  # English ellipsis
188 |     sent = re.sub(r'(…{2})([^”’])', r"\1\n\2", sent)  # Chinese ellipsis
189 |     sent = re.sub(r'([。!?\?][”’])([^,。!?\?])', r"\1\n\2", sent)
190 |     # If a terminator precedes a closing quote, the quote ends the sentence, so the \n goes after the quote
191 |     return sent.split("\n")
192 |
193 |
194 | def cut_sentences_v2(sent):
195 |     """
196 |     Second-level sentence splitting: break after ';' or ';'.
197 |     """
198 |     sent = re.sub(r'([;;])([^”’])', r"\1\n\2", sent)
199 |     return sent.split("\n")
200 |
201 |
202 | def cut_sent(text, max_seq_len):
203 |     # Split the text into fine-grained sentences, then merge them back up to the length limit
204 |     sentences = []
205 |
206 |     # Fine-grained split
207 | sentences_v1 = cut_sentences_v1(text)
208 | for sent_v1 in sentences_v1:
209 | if len(sent_v1) > max_seq_len - 2:
210 | sentences_v2 = cut_sentences_v2(sent_v1)
211 | sentences.extend(sentences_v2)
212 | else:
213 | sentences.append(sent_v1)
214 |
215 | assert ''.join(sentences) == text
216 |
217 |     # Merge
218 | merged_sentences = []
219 | start_index_ = 0
220 |
221 | while start_index_ < len(sentences):
222 | tmp_text = sentences[start_index_]
223 |
224 | end_index_ = start_index_ + 1
225 |
226 | while end_index_ < len(sentences) and \
227 | len(tmp_text) + len(sentences[end_index_]) <= max_seq_len - 2:
228 | tmp_text += sentences[end_index_]
229 | end_index_ += 1
230 |
231 | start_index_ = end_index_
232 |
233 | merged_sentences.append(tmp_text)
234 |
235 | return merged_sentences
236 |
237 | def sent_mask(sent, stop_mask_range_list, mask_prob=0.15):
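238 |     """
239 |     Randomly replace tokens with [MASK]: each token outside the protected ranges
240 |     is masked with probability mask_prob, up to int(len(sent) * mask_prob) masks in total.
241 |     :param sent: list of tokens
242 |     :param stop_mask_range_list: (start, end) index ranges, inclusive, that must not be masked
243 |     :param mask_prob: per-token mask probability; also caps the number of masks at len(sent) * mask_prob
244 |     :return: the token list with some tokens replaced by '[MASK]'
245 |     """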
246 | max_mask_token_nums = int(len(sent) * mask_prob)
247 | mask_nums = 0
248 | mask_sent = []
249 |
250 | for i in range(len(sent)):
251 |
252 | flag = False
253 | for _stop_range in stop_mask_range_list:
254 | if _stop_range[0] <= i <= _stop_range[1]:
255 | flag = True
256 | break
257 |
258 | if flag:
259 | mask_sent.append(sent[i])
260 | continue
261 |
262 | if mask_nums < max_mask_token_nums:
263 |             # mask with probability mask_prob; a masked token is always replaced by [MASK]
264 | if random.random() < mask_prob:
265 | mask_sent.append('[MASK]')
266 | mask_nums += 1
267 | else:
268 | mask_sent.append(sent[i])
269 | else:
270 | mask_sent.append(sent[i])
271 |
272 | return mask_sent
273 |
274 |
275 | def convert_crf_example(ex_idx, example: InputExample, tokenizer: BertTokenizer,
276 | max_seq_len, ent2id):
277 | set_type = example.set_type
278 | raw_text = example.text
279 | entities = example.labels
280 | pseudo = example.pseudo
281 |
282 | callback_info = (raw_text,)
283 | callback_labels = {x: [] for x in ENTITY_TYPES}
284 |
285 | for _label in entities:
286 | callback_labels[_label[0]].append((_label[1], _label[2]))
287 |
288 | callback_info += (callback_labels,)
289 |
290 | tokens = fine_grade_tokenize(raw_text, tokenizer)
291 | assert len(tokens) == len(raw_text)
292 |
293 | label_ids = None
294 |
295 | if set_type == 'train':
296 | # information for dev callback
297 | label_ids = [0] * len(tokens)
298 |
299 |         # each entity is (type, entity_text, start_offset), e.g. ('DRUG_DOSAGE', '小蜜丸', 447)
300 | for ent in entities:
301 | ent_type = ent[0]
302 |
303 | ent_start = ent[-1]
304 | ent_end = ent_start + len(ent[1]) - 1
305 |
306 | if ent_start == ent_end:
307 | label_ids[ent_start] = ent2id['S-' + ent_type]
308 | else:
309 | label_ids[ent_start] = ent2id['B-' + ent_type]
310 | label_ids[ent_end] = ent2id['E-' + ent_type]
311 | for i in range(ent_start + 1, ent_end):
312 | label_ids[i] = ent2id['I-' + ent_type]
313 |
314 | if len(label_ids) > max_seq_len - 2:
315 | label_ids = label_ids[:max_seq_len - 2]
316 |
317 | label_ids = [0] + label_ids + [0]
318 |
319 | # pad
320 | if len(label_ids) < max_seq_len:
321 | pad_length = max_seq_len - len(label_ids)
322 |             label_ids = label_ids + [0] * pad_length  # [CLS], [SEP] and [PAD] are all labeled O
323 |
324 | assert len(label_ids) == max_seq_len, f'{len(label_ids)}'
325 |
326 | encode_dict = tokenizer.encode_plus(text=tokens,
327 | max_length=max_seq_len,
328 | pad_to_max_length=True,
329 | is_pretokenized=True,
330 | return_token_type_ids=True,
331 | return_attention_mask=True)
332 |
333 | token_ids = encode_dict['input_ids']
334 | attention_masks = encode_dict['attention_mask']
335 | token_type_ids = encode_dict['token_type_ids']
336 |
337 | # if ex_idx < 3:
338 | # logger.info(f"*** {set_type}_example-{ex_idx} ***")
339 | # logger.info(f'text: {" ".join(tokens)}')
340 | # logger.info(f"token_ids: {token_ids}")
341 | # logger.info(f"attention_masks: {attention_masks}")
342 | # logger.info(f"token_type_ids: {token_type_ids}")
343 | # logger.info(f"labels: {label_ids}")
344 |
345 | feature = CRFFeature(
346 | # bert inputs
347 | token_ids=token_ids,
348 | attention_masks=attention_masks,
349 | token_type_ids=token_type_ids,
350 | labels=label_ids,
351 | pseudo=pseudo
352 | )
353 |
354 | return feature, callback_info
355 |
356 |
357 | def convert_span_example(ex_idx, example: InputExample, tokenizer: BertTokenizer,
358 | max_seq_len, ent2id):
359 | set_type = example.set_type
360 | raw_text = example.text
361 | entities = example.labels
362 | pseudo = example.pseudo
363 |
364 | tokens = fine_grade_tokenize(raw_text, tokenizer)
365 | assert len(tokens) == len(raw_text)
366 |
367 | callback_labels = {x: [] for x in ENTITY_TYPES}
368 |
369 | for _label in entities:
370 | callback_labels[_label[0]].append((_label[1], _label[2]))
371 |
372 | callback_info = (raw_text, callback_labels,)
373 |
374 | start_ids, end_ids = None, None
375 |
376 | if set_type == 'train':
377 | start_ids = [0] * len(tokens)
378 | end_ids = [0] * len(tokens)
379 |
380 | for _ent in entities:
381 |
382 | ent_type = ent2id[_ent[0]]
383 | ent_start = _ent[-1]
384 | ent_end = ent_start + len(_ent[1]) - 1
385 |
386 | start_ids[ent_start] = ent_type
387 | end_ids[ent_end] = ent_type
388 |
389 | if len(start_ids) > max_seq_len - 2:
390 | start_ids = start_ids[:max_seq_len - 2]
391 | end_ids = end_ids[:max_seq_len - 2]
392 |
393 | start_ids = [0] + start_ids + [0]
394 | end_ids = [0] + end_ids + [0]
395 |
396 | # pad
397 | if len(start_ids) < max_seq_len:
398 | pad_length = max_seq_len - len(start_ids)
399 |
400 |             start_ids = start_ids + [0] * pad_length  # [CLS], [SEP] and [PAD] positions are all 0
401 | end_ids = end_ids + [0] * pad_length
402 |
403 | assert len(start_ids) == max_seq_len
404 | assert len(end_ids) == max_seq_len
405 |
406 | encode_dict = tokenizer.encode_plus(text=tokens,
407 | max_length=max_seq_len,
408 | pad_to_max_length=True,
409 | is_pretokenized=True,
410 | return_token_type_ids=True,
411 | return_attention_mask=True)
412 |
413 | token_ids = encode_dict['input_ids']
414 | attention_masks = encode_dict['attention_mask']
415 | token_type_ids = encode_dict['token_type_ids']
416 |
417 | # if ex_idx < 3:
418 | # logger.info(f"*** {set_type}_example-{ex_idx} ***")
419 | # logger.info(f'text: {" ".join(tokens)}')
420 | # logger.info(f"token_ids: {token_ids}")
421 | # logger.info(f"attention_masks: {attention_masks}")
422 | # logger.info(f"token_type_ids: {token_type_ids}")
423 | # if start_ids and end_ids:
424 | # logger.info(f"start_ids: {start_ids}")
425 | # logger.info(f"end_ids: {end_ids}")
426 |
427 | feature = SpanFeature(token_ids=token_ids,
428 | attention_masks=attention_masks,
429 | token_type_ids=token_type_ids,
430 | start_ids=start_ids,
431 | end_ids=end_ids,
432 | pseudo=pseudo)
433 |
434 | return feature, callback_info
435 |
436 | def convert_mrc_example(ex_idx, example: InputExample, tokenizer: BertTokenizer,
437 | max_seq_len, ent2id, ent2query, mask_prob=None):
438 | set_type = example.set_type
439 | text_b = example.text
440 | entities = example.labels
441 | pseudo = example.pseudo
442 |
443 | features = []
444 | callback_info = []
445 |
446 | tokens_b = fine_grade_tokenize(text_b, tokenizer)
447 | assert len(tokens_b) == len(text_b)
448 |
449 | label_dict = defaultdict(list)
450 |
451 | for ent in entities:
452 | ent_type = ent[0]
453 | ent_start = ent[-1]
454 | ent_end = ent_start + len(ent[1]) - 1
455 | label_dict[ent_type].append((ent_start, ent_end, ent[1]))
456 |
457 |     # Build training features
458 |     if set_type == 'train':
459 |
460 |         # One example per entity type
461 |         # for _type in label_dict.keys():
462 | for _type in ENTITY_TYPES:
463 | start_ids = [0] * len(tokens_b)
464 | end_ids = [0] * len(tokens_b)
465 |
466 | stop_mask_ranges = []
467 |
468 | text_a = ent2query[_type]
469 | tokens_a = fine_grade_tokenize(text_a, tokenizer)
470 |
471 | for _label in label_dict[_type]:
472 | start_ids[_label[0]] = 1
473 | end_ids[_label[1]] = 1
474 |
475 | stop_mask_ranges.append((_label[0], _label[1]))
476 |
477 | if len(start_ids) > max_seq_len - len(tokens_a) - 3:
478 | start_ids = start_ids[:max_seq_len - len(tokens_a) - 3]
479 | end_ids = end_ids[:max_seq_len - len(tokens_a) - 3]
480 |                 print('Unexpected truncation occurred')
481 |
482 | start_ids = [0] + [0] * len(tokens_a) + [0] + start_ids + [0]
483 | end_ids = [0] + [0] * len(tokens_a) + [0] + end_ids + [0]
484 |
485 | # pad
486 | if len(start_ids) < max_seq_len:
487 | pad_length = max_seq_len - len(start_ids)
488 |
489 |                 start_ids = start_ids + [0] * pad_length  # [CLS], [SEP] and [PAD] positions are all 0
490 | end_ids = end_ids + [0] * pad_length
491 |
492 | assert len(start_ids) == max_seq_len
493 | assert len(end_ids) == max_seq_len
494 |
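495 |             # Random masking: mask a per-type copy so masks do not accumulate in tokens_b across entity types
496 |             masked_tokens_b = tokens_b
497 |             if mask_prob:
498 |                 masked_tokens_b = sent_mask(tokens_b, stop_mask_ranges, mask_prob=mask_prob)
499 |             encode_dict = tokenizer.encode_plus(text=tokens_a,
500 |                                                 text_pair=masked_tokens_b,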
501 | max_length=max_seq_len,
502 | pad_to_max_length=True,
503 | truncation_strategy='only_second',
504 | is_pretokenized=True,
505 | return_token_type_ids=True,
506 | return_attention_mask=True)
507 |
508 | token_ids = encode_dict['input_ids']
509 | attention_masks = encode_dict['attention_mask']
510 | token_type_ids = encode_dict['token_type_ids']
511 |
512 | # if ex_idx < 3:
513 | # logger.info(f"*** {set_type}_example-{ex_idx} ***")
514 | # logger.info(f'text: {" ".join(tokens_b)}')
515 | # logger.info(f"token_ids: {token_ids}")
516 | # logger.info(f"attention_masks: {attention_masks}")
517 | # logger.info(f"token_type_ids: {token_type_ids}")
518 | # logger.info(f'entity type: {_type}')
519 | # logger.info(f"start_ids: {start_ids}")
520 | # logger.info(f"end_ids: {end_ids}")
521 |
522 | feature = MRCFeature(token_ids=token_ids,
523 | attention_masks=attention_masks,
524 | token_type_ids=token_type_ids,
525 | ent_type=ent2id[_type],
526 | start_ids=start_ids,
527 | end_ids=end_ids,
528 | pseudo=pseudo
529 | )
530 |
531 | features.append(feature)
532 |
533 |     # Build test features: one example per entity type
534 | else:
535 | for _type in ENTITY_TYPES:
536 | text_a = ent2query[_type]
537 | tokens_a = fine_grade_tokenize(text_a, tokenizer)
538 |
539 | encode_dict = tokenizer.encode_plus(text=tokens_a,
540 | text_pair=tokens_b,
541 | max_length=max_seq_len,
542 | pad_to_max_length=True,
543 | truncation_strategy='only_second',
544 | is_pretokenized=True,
545 | return_token_type_ids=True,
546 | return_attention_mask=True)
547 |
548 | token_ids = encode_dict['input_ids']
549 | attention_masks = encode_dict['attention_mask']
550 | token_type_ids = encode_dict['token_type_ids']
551 |
552 | tmp_callback = (text_b, len(tokens_a) + 2, _type) # (text, text_offset, type, labels)
553 | tmp_callback_labels = []
554 |
555 | for _label in label_dict[_type]:
556 | tmp_callback_labels.append((_label[2], _label[0]))
557 |
558 | tmp_callback += (tmp_callback_labels, )
559 |
560 | callback_info.append(tmp_callback)
561 |
562 | feature = MRCFeature(token_ids=token_ids,
563 | attention_masks=attention_masks,
564 | token_type_ids=token_type_ids,
565 | ent_type=ent2id[_type])
566 |
567 | features.append(feature)
568 |
569 | return features, callback_info
570 |
571 |
572 |
573 | def convert_examples_to_features(task_type, examples, max_seq_len, bert_dir, ent2id):
574 | assert task_type in ['crf', 'span', 'mrc']
575 |
576 | tokenizer = BertTokenizer(os.path.join(bert_dir, 'vocab.txt'))
577 |
578 | features = []
579 |
580 | callback_info = []
581 |
582 | logger.info(f'Convert {len(examples)} examples to features')
583 | type2id = {x: i for i, x in enumerate(ENTITY_TYPES)}
584 |
585 | for i, example in enumerate(examples):
586 | if task_type == 'crf':
587 | feature, tmp_callback = convert_crf_example(
588 | ex_idx=i,
589 | example=example,
590 | max_seq_len=max_seq_len,
591 | ent2id=ent2id,
592 | tokenizer=tokenizer
593 | )
594 | elif task_type == 'mrc':
595 | feature, tmp_callback = convert_mrc_example(
596 | ex_idx=i,
597 | example=example,
598 | max_seq_len=max_seq_len,
599 | ent2id=type2id,
600 | ent2query=ent2id,
601 | tokenizer=tokenizer
602 | )
603 | else:
604 | feature, tmp_callback = convert_span_example(
605 | ex_idx=i,
606 | example=example,
607 | max_seq_len=max_seq_len,
608 | ent2id=ent2id,
609 | tokenizer=tokenizer
610 | )
611 |
612 | if feature is None:
613 | continue
614 |
615 | if task_type == 'mrc':
616 | features.extend(feature)
617 | callback_info.extend(tmp_callback)
618 | else:
619 | features.append(feature)
620 | callback_info.append(tmp_callback)
621 |
622 | logger.info(f'Build {len(features)} features')
623 |
624 | out = (features, )
625 |
626 | if not len(callback_info):
627 | return out
628 |
629 |     type_weight = {}  # share of each entity type among all labels, used to compute micro-F1
630 | for _type in ENTITY_TYPES:
631 | type_weight[_type] = 0.
632 |
633 | count = 0.
634 |
635 | if task_type == 'mrc':
636 | for _callback in callback_info:
637 | type_weight[_callback[-2]] += len(_callback[-1])
638 | count += len(_callback[-1])
639 | else:
640 | for _callback in callback_info:
641 | for _type in _callback[1]:
642 | type_weight[_type] += len(_callback[1][_type])
643 | count += len(_callback[1][_type])
644 |
645 | for key in type_weight:
646 | type_weight[key] /= count
647 |
648 | out += ((callback_info, type_weight), )
649 |
650 | return out
651 |
652 |
653 | if __name__ == '__main__':
654 | pass
655 |
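The BIOES tagging performed by `convert_crf_example` can be sketched in isolation. This is a minimal illustration, assuming entities arrive as `(type, entity_text, start_offset)` tuples as produced by `_refactor_labels`. `bioes_tags` is a hypothetical helper that returns tag strings rather than the ids looked up in `ent2id`, and the sample sentence is made up.

```python
def bioes_tags(text, entities):
    """Return one BIOES tag per character of text."""
    tags = ['O'] * len(text)
    for ent_type, ent_text, start in entities:
        end = start + len(ent_text) - 1  # inclusive end index
        if start == end:
            # A single-character entity gets the S- tag.
            tags[start] = 'S-' + ent_type
        else:
            tags[start] = 'B-' + ent_type
            tags[end] = 'E-' + ent_type
            for i in range(start + 1, end):
                tags[i] = 'I-' + ent_type
    return tags


# Hypothetical sample sentence with one three-character entity.
sample_tags = bioes_tags('本品为小蜜丸。', [('DRUG_DOSAGE', '小蜜丸', 3)])
print(sample_tags)
```

Because tokenization is character-level, tag positions line up one-to-one with character offsets, which is why no subword alignment step appears here.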
--------------------------------------------------------------------------------
/src/utils/__pycache__/attack_train_utils.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chizhu/DeepNER/8c4abc21676af50ede29dce90bfac4892b36a1c5/src/utils/__pycache__/attack_train_utils.cpython-36.pyc
--------------------------------------------------------------------------------
/src/utils/__pycache__/dataset_utils.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chizhu/DeepNER/8c4abc21676af50ede29dce90bfac4892b36a1c5/src/utils/__pycache__/dataset_utils.cpython-36.pyc
--------------------------------------------------------------------------------
/src/utils/__pycache__/evaluator.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chizhu/DeepNER/8c4abc21676af50ede29dce90bfac4892b36a1c5/src/utils/__pycache__/evaluator.cpython-36.pyc
--------------------------------------------------------------------------------
/src/utils/__pycache__/functions_utils.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chizhu/DeepNER/8c4abc21676af50ede29dce90bfac4892b36a1c5/src/utils/__pycache__/functions_utils.cpython-36.pyc
--------------------------------------------------------------------------------
/src/utils/__pycache__/model_utils.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chizhu/DeepNER/8c4abc21676af50ede29dce90bfac4892b36a1c5/src/utils/__pycache__/model_utils.cpython-36.pyc
--------------------------------------------------------------------------------
/src/utils/__pycache__/options.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chizhu/DeepNER/8c4abc21676af50ede29dce90bfac4892b36a1c5/src/utils/__pycache__/options.cpython-36.pyc
--------------------------------------------------------------------------------
/src/utils/__pycache__/trainer.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chizhu/DeepNER/8c4abc21676af50ede29dce90bfac4892b36a1c5/src/utils/__pycache__/trainer.cpython-36.pyc
--------------------------------------------------------------------------------
/src/utils/attack_train_utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 |
5 | # FGM
6 | class FGM:
7 | def __init__(self, model: nn.Module, eps=1.):
8 | self.model = (
9 | model.module if hasattr(model, "module") else model
10 | )
11 | self.eps = eps
12 | self.backup = {}
13 |
14 | # only attack word embedding
15 | def attack(self, emb_name='word_embeddings'):
16 | for name, param in self.model.named_parameters():
17 | if param.requires_grad and emb_name in name:
18 | self.backup[name] = param.data.clone()
19 | norm = torch.norm(param.grad)
20 | if norm and not torch.isnan(norm):
21 | r_at = self.eps * param.grad / norm
22 | param.data.add_(r_at)
23 |
24 | def restore(self, emb_name='word_embeddings'):
25 | for name, para in self.model.named_parameters():
26 | if para.requires_grad and emb_name in name:
27 | assert name in self.backup
28 | para.data = self.backup[name]
29 |
30 | self.backup = {}
31 |
32 |
33 | # PGD
34 | class PGD:
35 | def __init__(self, model, eps=1., alpha=0.3):
36 | self.model = (
37 | model.module if hasattr(model, "module") else model
38 | )
39 | self.eps = eps
40 | self.alpha = alpha
41 | self.emb_backup = {}
42 | self.grad_backup = {}
43 |
44 | def attack(self, emb_name='word_embeddings', is_first_attack=False):
45 | for name, param in self.model.named_parameters():
46 | if param.requires_grad and emb_name in name:
47 | if is_first_attack:
48 | self.emb_backup[name] = param.data.clone()
49 | norm = torch.norm(param.grad)
50 | if norm != 0 and not torch.isnan(norm):
51 | r_at = self.alpha * param.grad / norm
52 | param.data.add_(r_at)
53 | param.data = self.project(name, param.data)
54 |
55 | def restore(self, emb_name='word_embeddings'):
56 | for name, param in self.model.named_parameters():
57 | if param.requires_grad and emb_name in name:
58 | assert name in self.emb_backup
59 | param.data = self.emb_backup[name]
60 | self.emb_backup = {}
61 |
62 | def project(self, param_name, param_data):
63 | r = param_data - self.emb_backup[param_name]
64 | if torch.norm(r) > self.eps:
65 | r = self.eps * r / torch.norm(r)
66 | return self.emb_backup[param_name] + r
67 |
68 | def backup_grad(self):
69 | for name, param in self.model.named_parameters():
70 | if param.requires_grad and param.grad is not None:
71 | self.grad_backup[name] = param.grad.clone()
72 |
73 | def restore_grad(self):
74 | for name, param in self.model.named_parameters():
75 | if param.requires_grad and param.grad is not None:
76 | param.grad = self.grad_backup[name]
77 |
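The perturbation arithmetic in `FGM.attack`, `PGD.attack` and `PGD.project` above can be illustrated without PyTorch. The block below is a minimal plain-Python sketch on lists, not the repository's code: the helper names `l2_norm`, `fgm_step` and `pgd_step` are hypothetical, and only the formulas (`eps * grad / norm`, `alpha * grad / norm`, and the projection back onto the `eps` ball around the backed-up weights) are taken from the classes above.

```python
import math


def l2_norm(vec):
    """Euclidean norm of a flat list of floats."""
    return math.sqrt(sum(x * x for x in vec))


def fgm_step(weights, grad, eps=1.0):
    """One FGM perturbation: move eps along the normalized gradient."""
    norm = l2_norm(grad)
    if norm == 0 or math.isnan(norm):
        return list(weights)  # no usable gradient, leave the weights alone
    return [w + eps * g / norm for w, g in zip(weights, grad)]


def pgd_step(weights, grad, original, eps=1.0, alpha=0.3):
    """One PGD step: move alpha along the gradient, then project the total
    perturbation back into the eps ball around the original weights."""
    norm = l2_norm(grad)
    if norm == 0 or math.isnan(norm):
        return list(weights)
    moved = [w + alpha * g / norm for w, g in zip(weights, grad)]
    r = [m - o for m, o in zip(moved, original)]
    r_norm = l2_norm(r)
    if r_norm > eps:
        r = [eps * x / r_norm for x in r]
    return [o + x for o, x in zip(original, r)]


# Toy usage: the gradient (3, 4) has norm 5, so FGM moves by (0.6, 0.8).
perturbed = fgm_step([0.0, 0.0], [3.0, 4.0], eps=1.0)

# Repeated PGD steps grow the perturbation by 0.3 each time until it hits eps.
original = [0.0, 0.0]
current = list(original)
for _ in range(5):
    current = pgd_step(current, [3.0, 4.0], original, eps=1.0, alpha=0.3)
```

In this toy run the PGD perturbation would exceed `eps` from the fourth step onward, so the projection keeps its norm at `eps`. This is why `PGD` needs `emb_backup` to be filled on the first attack.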
--------------------------------------------------------------------------------
/src/utils/dataset_utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.utils.data import Dataset
3 |
4 |
5 | class NERDataset(Dataset):
6 | def __init__(self, task_type, features, mode, **kwargs):
7 |
8 | self.nums = len(features)
9 |
10 | self.token_ids = [torch.tensor(example.token_ids).long() for example in features]
11 | self.attention_masks = [torch.tensor(example.attention_masks).float() for example in features]
12 | self.token_type_ids = [torch.tensor(example.token_type_ids).long() for example in features]
13 |
14 | self.labels = None
15 | self.start_ids, self.end_ids = None, None
16 | self.ent_type = None
17 | self.pseudo = None
18 | if mode == 'train':
19 | self.pseudo = [torch.tensor(example.pseudo).long() for example in features]
20 | if task_type == 'crf':
21 | self.labels = [torch.tensor(example.labels) for example in features]
22 | else:
23 | self.start_ids = [torch.tensor(example.start_ids).long() for example in features]
24 | self.end_ids = [torch.tensor(example.end_ids).long() for example in features]
25 |
26 | if kwargs.pop('use_type_embed', False):
27 | self.ent_type = [torch.tensor(example.ent_type) for example in features]
28 |
29 | def __len__(self):
30 | return self.nums
31 |
32 | def __getitem__(self, index):
33 | data = {'token_ids': self.token_ids[index],
34 | 'attention_masks': self.attention_masks[index],
35 | 'token_type_ids': self.token_type_ids[index]}
36 |
37 | if self.ent_type is not None:
38 | data['ent_type'] = self.ent_type[index]
39 |
40 | if self.labels is not None:
41 | data['labels'] = self.labels[index]
42 |
43 | if self.pseudo is not None:
44 | data['pseudo'] = self.pseudo[index]
45 |
46 | if self.start_ids is not None:
47 | data['start_ids'] = self.start_ids[index]
48 | data['end_ids'] = self.end_ids[index]
49 |
50 | return data
51 |
52 |
53 |
54 |
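Because `get_base_out` calls `model(**_batch)`, the keys returned by `NERDataset.__getitem__` have to match the keyword arguments of the model's `forward`. The sketch below mirrors the key-selection logic in plain Python so it can be checked without PyTorch. The function name `item_keys` is hypothetical, and it assumes that the `use_type_embed` check sits outside the `mode == 'train'` branch (indentation is lost in this listing).

```python
def item_keys(task_type, mode, use_type_embed=False):
    """Return the keys an item would carry, in the order __getitem__ adds them."""
    keys = ['token_ids', 'attention_masks', 'token_type_ids']

    # Assumption: ent_type is attached in every mode when use_type_embed is set.
    if use_type_embed:
        keys.append('ent_type')

    if mode == 'train':
        if task_type == 'crf':
            keys.append('labels')
        keys.append('pseudo')
        if task_type != 'crf':
            keys += ['start_ids', 'end_ids']

    return keys


crf_train = item_keys('crf', 'train')
span_dev = item_keys('span', 'dev')
mrc_train = item_keys('mrc', 'train', use_type_embed=True)
```

Outside training, only the three encoder inputs (plus `ent_type` when enabled) are present, so the models return predictions rather than a loss.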
--------------------------------------------------------------------------------
/src/utils/evaluator.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import logging
3 | import numpy as np
4 | from collections import defaultdict
5 | from src.preprocess.processor import ENTITY_TYPES
6 |
7 | logger = logging.getLogger(__name__)
8 |
9 |
10 | def get_base_out(model, loader, device):
11 | """
12 | The forward pass is identical for every task, so it is wrapped here.
13 | """
14 | model.eval()
15 |
16 | with torch.no_grad():
17 | for idx, _batch in enumerate(loader):
18 |
19 | for key in _batch.keys():
20 | _batch[key] = _batch[key].to(device)
21 |
22 | tmp_out = model(**_batch)
23 |
24 | yield tmp_out
25 |
26 |
27 | def crf_decode(decode_tokens, raw_text, id2ent):
28 | """
29 | CRF decoding: extract entities (e.g. time / loc) from the predicted tag ids.
30 | """
31 | predict_entities = {}
32 |
33 | decode_tokens = decode_tokens[1:-1] # strip the CLS and SEP tokens
34 |
35 | index_ = 0
36 |
37 | while index_ < len(decode_tokens):
38 |
39 | token_label = id2ent[decode_tokens[index_]].split('-')
40 |
41 | if token_label[0].startswith('S'):
42 | token_type = token_label[1]
43 | tmp_ent = raw_text[index_]
44 |
45 | if token_type not in predict_entities:
46 | predict_entities[token_type] = [(tmp_ent, index_)]
47 | else:
48 | predict_entities[token_type].append((tmp_ent, int(index_)))
49 |
50 | index_ += 1
51 |
52 | elif token_label[0].startswith('B'):
53 | token_type = token_label[1]
54 | start_index = index_
55 |
56 | index_ += 1
57 | while index_ < len(decode_tokens):
58 | temp_token_label = id2ent[decode_tokens[index_]].split('-')
59 |
60 | if temp_token_label[0].startswith('I') and token_type == temp_token_label[1]:
61 | index_ += 1
62 | elif temp_token_label[0].startswith('E') and token_type == temp_token_label[1]:
63 | end_index = index_
64 | index_ += 1
65 |
66 | tmp_ent = raw_text[start_index: end_index + 1]
67 |
68 | if token_type not in predict_entities:
69 | predict_entities[token_type] = [(tmp_ent, start_index)]
70 | else:
71 | predict_entities[token_type].append((tmp_ent, int(start_index)))
72 |
73 | break
74 | else:
75 | break
76 | else:
77 | index_ += 1
78 |
79 | return predict_entities
80 |
81 |
82 | # strict decoding baseline
83 | def span_decode(start_logits, end_logits, raw_text, id2ent):
84 | predict_entities = defaultdict(list)
85 |
86 | start_pred = np.argmax(start_logits, -1)
87 | end_pred = np.argmax(end_logits, -1)
88 |
89 | for i, s_type in enumerate(start_pred):
90 | if s_type == 0:
91 | continue
92 | for j, e_type in enumerate(end_pred[i:]):
93 | if s_type == e_type:
94 | tmp_ent = raw_text[i:i + j + 1]
95 | predict_entities[id2ent[s_type]].append((tmp_ent, i))
96 | break
97 |
98 | return predict_entities
99 |
100 | # strict decoding baseline
101 | def mrc_decode(start_logits, end_logits, raw_text):
102 | predict_entities = []
103 | start_pred = np.argmax(start_logits, -1)
104 | end_pred = np.argmax(end_logits, -1)
105 |
106 | for i, s_type in enumerate(start_pred):
107 | if s_type == 0:
108 | continue
109 | for j, e_type in enumerate(end_pred[i:]):
110 | if s_type == e_type:
111 | tmp_ent = raw_text[i:i+j+1]
112 | predict_entities.append((tmp_ent, i))
113 | break
114 |
115 | return predict_entities
116 |
117 |
118 | def calculate_metric(gt, predict):
119 | """
120 | Compute tp, fp and fn.
121 | """
122 | tp, fp, fn = 0, 0, 0
123 | for entity_predict in predict:
124 | flag = 0
125 | for entity_gt in gt:
126 | if entity_predict[0] == entity_gt[0] and entity_predict[1] == entity_gt[1]:
127 | flag = 1
128 | tp += 1
129 | break
130 | if flag == 0:
131 | fp += 1
132 |
133 | fn = len(gt) - tp
134 |
135 | return np.array([tp, fp, fn])
136 |
137 |
138 | def get_p_r_f(tp, fp, fn):
139 | p = tp / (tp + fp) if tp + fp != 0 else 0
140 | r = tp / (tp + fn) if tp + fn != 0 else 0
141 | f1 = 2 * p * r / (p + r) if p + r != 0 else 0
142 | return np.array([p, r, f1])
143 |
144 |
145 | def crf_evaluation(model, dev_info, device, ent2id):
146 | dev_loader, (dev_callback_info, type_weight) = dev_info
147 |
148 | pred_tokens = []
149 |
150 | for tmp_pred in get_base_out(model, dev_loader, device):
151 | pred_tokens.extend(tmp_pred[0])
152 |
153 | assert len(pred_tokens) == len(dev_callback_info)
154 |
155 | id2ent = {ent2id[key]: key for key in ent2id.keys()}
156 |
157 | role_metric = np.zeros([13, 3])
158 |
159 | mirco_metrics = np.zeros(3)
160 |
161 | for tmp_tokens, tmp_callback in zip(pred_tokens, dev_callback_info):
162 |
163 | text, gt_entities = tmp_callback
164 |
165 | tmp_metric = np.zeros([13, 3])
166 |
167 | pred_entities = crf_decode(tmp_tokens, text, id2ent)
168 |
169 | for idx, _type in enumerate(ENTITY_TYPES):
170 | if _type not in pred_entities:
171 | pred_entities[_type] = []
172 |
173 | tmp_metric[idx] += calculate_metric(gt_entities[_type], pred_entities[_type])
174 |
175 | role_metric += tmp_metric
176 |
177 | for idx, _type in enumerate(ENTITY_TYPES):
178 | temp_metric = get_p_r_f(role_metric[idx][0], role_metric[idx][1], role_metric[idx][2])
179 |
180 | mirco_metrics += temp_metric * type_weight[_type]
181 |
182 | metric_str = f'[MICRO] precision: {mirco_metrics[0]:.4f}, ' \
183 | f'recall: {mirco_metrics[1]:.4f}, f1: {mirco_metrics[2]:.4f}'
184 |
185 | return metric_str, mirco_metrics[2]
186 |
187 |
188 | def span_evaluation(model, dev_info, device, ent2id):
189 | dev_loader, (dev_callback_info, type_weight) = dev_info
190 |
191 | start_logits, end_logits = None, None
192 |
193 | model.eval()
194 |
195 | for tmp_pred in get_base_out(model, dev_loader, device):
196 | tmp_start_logits = tmp_pred[0].cpu().numpy()
197 | tmp_end_logits = tmp_pred[1].cpu().numpy()
198 |
199 | if start_logits is None:
200 | start_logits = tmp_start_logits
201 | end_logits = tmp_end_logits
202 | else:
203 | start_logits = np.append(start_logits, tmp_start_logits, axis=0)
204 | end_logits = np.append(end_logits, tmp_end_logits, axis=0)
205 |
206 | assert len(start_logits) == len(end_logits) == len(dev_callback_info)
207 |
208 | role_metric = np.zeros([13, 3])
209 |
210 | mirco_metrics = np.zeros(3)
211 |
212 | id2ent = {ent2id[key]: key for key in ent2id.keys()}
213 |
214 | for tmp_start_logits, tmp_end_logits, tmp_callback \
215 | in zip(start_logits, end_logits, dev_callback_info):
216 |
217 | text, gt_entities = tmp_callback
218 |
219 | tmp_start_logits = tmp_start_logits[1:1 + len(text)]
220 | tmp_end_logits = tmp_end_logits[1:1 + len(text)]
221 |
222 | pred_entities = span_decode(tmp_start_logits, tmp_end_logits, text, id2ent)
223 |
224 | for idx, _type in enumerate(ENTITY_TYPES):
225 | if _type not in pred_entities:
226 | pred_entities[_type] = []
227 |
228 | role_metric[idx] += calculate_metric(gt_entities[_type], pred_entities[_type])
229 |
230 | for idx, _type in enumerate(ENTITY_TYPES):
231 | temp_metric = get_p_r_f(role_metric[idx][0], role_metric[idx][1], role_metric[idx][2])
232 |
233 | mirco_metrics += temp_metric * type_weight[_type]
234 |
235 | metric_str = f'[MICRO] precision: {mirco_metrics[0]:.4f}, ' \
236 | f'recall: {mirco_metrics[1]:.4f}, f1: {mirco_metrics[2]:.4f}'
237 |
238 | return metric_str, mirco_metrics[2]
239 |
240 | def mrc_evaluation(model, dev_info, device):
241 | dev_loader, (dev_callback_info, type_weight) = dev_info
242 |
243 | start_logits, end_logits = None, None
244 |
245 | model.eval()
246 |
247 | for tmp_pred in get_base_out(model, dev_loader, device):
248 | tmp_start_logits = tmp_pred[0].cpu().numpy()
249 | tmp_end_logits = tmp_pred[1].cpu().numpy()
250 |
251 | if start_logits is None:
252 | start_logits = tmp_start_logits
253 | end_logits = tmp_end_logits
254 | else:
255 | start_logits = np.append(start_logits, tmp_start_logits, axis=0)
256 | end_logits = np.append(end_logits, tmp_end_logits, axis=0)
257 |
258 | assert len(start_logits) == len(end_logits) == len(dev_callback_info)
259 |
260 | role_metric = np.zeros([13, 3])
261 |
262 | mirco_metrics = np.zeros(3)
263 |
264 | id2ent = {x: i for i, x in enumerate(ENTITY_TYPES)}
265 |
266 | for tmp_start_logits, tmp_end_logits, tmp_callback \
267 | in zip(start_logits, end_logits, dev_callback_info):
268 |
269 | text, text_offset, ent_type, gt_entities = tmp_callback
270 |
271 | tmp_start_logits = tmp_start_logits[text_offset:text_offset+len(text)]
272 | tmp_end_logits = tmp_end_logits[text_offset:text_offset+len(text)]
273 |
274 | pred_entities = mrc_decode(tmp_start_logits, tmp_end_logits, text)
275 |
276 | role_metric[id2ent[ent_type]] += calculate_metric(gt_entities, pred_entities)
277 |
278 | for idx, _type in enumerate(ENTITY_TYPES):
279 | temp_metric = get_p_r_f(role_metric[idx][0], role_metric[idx][1], role_metric[idx][2])
280 |
281 | mirco_metrics += temp_metric * type_weight[_type]
282 |
283 | metric_str = f'[MICRO] precision: {mirco_metrics[0]:.4f}, ' \
284 | f'recall: {mirco_metrics[1]:.4f}, f1: {mirco_metrics[2]:.4f}'
285 |
286 | return metric_str, mirco_metrics[2]
287 |
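The strict span decoding and the per-type scoring above can be traced on a toy example. The sketch below is a NumPy-free illustration that starts from already-argmaxed start/end labels. The names `span_decode_labels`, `count_tp_fp_fn` and `p_r_f`, the entity types `DRUG` / `SYMPTOM`, and the weights are all hypothetical; only the matching rule (nearest end label of the same type, exact match on text and start index) follows `span_decode`, `calculate_metric` and `get_p_r_f`.

```python
from collections import defaultdict


def span_decode_labels(start_pred, end_pred, raw_text, id2ent):
    """Pair each non-zero start label with the nearest end label of the same type."""
    entities = defaultdict(list)
    for i, s_type in enumerate(start_pred):
        if s_type == 0:  # label 0 means no entity starts here
            continue
        for j, e_type in enumerate(end_pred[i:]):
            if s_type == e_type:
                entities[id2ent[s_type]].append((raw_text[i:i + j + 1], i))
                break
    return entities


def count_tp_fp_fn(gt, predict):
    """A prediction is a true positive only if text and start index both match."""
    tp = sum(1 for p in predict
             if any(p[0] == g[0] and p[1] == g[1] for g in gt))
    return tp, len(predict) - tp, len(gt) - tp


def p_r_f(tp, fp, fn):
    p = tp / (tp + fp) if tp + fp else 0
    r = tp / (tp + fn) if tp + fn else 0
    f1 = 2 * p * r / (p + r) if p + r else 0
    return p, r, f1


id2ent = {1: 'DRUG', 2: 'SYMPTOM'}
pred = span_decode_labels([1, 0, 0, 2, 0, 0], [0, 1, 0, 0, 0, 2], 'abcdef', id2ent)
gt = {'DRUG': [('ab', 0)], 'SYMPTOM': [('ef', 4)]}
type_weight = {'DRUG': 0.5, 'SYMPTOM': 0.5}

# Final score: per-type P/R/F1 combined with the per-type weights.
weighted = [0.0, 0.0, 0.0]
for _type, weight in type_weight.items():
    scores = p_r_f(*count_tp_fp_fn(gt[_type], pred[_type]))
    weighted = [acc + s * weight for acc, s in zip(weighted, scores)]
```

Here the `SYMPTOM` prediction `('def', 3)` overlaps the gold span `('ef', 4)` but is still counted as both a false positive and a false negative, because matching is exact.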
--------------------------------------------------------------------------------
/src/utils/functions_utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import copy
3 | import torch
4 | import random
5 | import numpy as np
6 | from collections import defaultdict
7 | from datetime import timedelta
8 | import time
9 | import logging
10 |
11 | logger = logging.getLogger(__name__)
12 |
13 |
14 | def get_time_dif(start_time):
15 | """
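16 | Return the time elapsed since start_time.
17 | :param start_time: start timestamp in seconds, as returned by time.time()
18 | :return: elapsed time as a timedelta, rounded to whole seconds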
19 | """
20 | end_time = time.time()
21 | time_dif = end_time - start_time
22 | return timedelta(seconds=int(round(time_dif)))
23 |
24 |
25 | def set_seed(seed):
26 | """
27 | Set the random seed for random, NumPy and PyTorch (CPU and all GPUs).
28 | :param seed: integer seed
29 | :return: None
30 | """
31 | random.seed(seed)
32 | torch.manual_seed(seed)
33 | np.random.seed(seed)
34 | torch.cuda.manual_seed_all(seed)
35 |
36 |
37 | def load_model_and_parallel(model, gpu_ids, ckpt_path=None, strict=True):
38 | """
39 | Load the model and move it to the target device (single GPU / multiple GPUs).
40 | """
41 | gpu_ids = gpu_ids.split(',')
42 |
43 | # use the first listed GPU as the main device, or the CPU when gpu_ids is '-1'
44 | device = torch.device("cpu" if gpu_ids[0] == '-1' else "cuda:" + gpu_ids[0])
45 |
46 | if ckpt_path is not None:
47 | logger.info(f'Load ckpt from {ckpt_path}')
48 | model.load_state_dict(torch.load(ckpt_path, map_location=torch.device('cpu')), strict=strict)
49 |
50 | model.to(device)
51 |
52 | if len(gpu_ids) > 1:
53 | logger.info(f'Use multi gpus in: {gpu_ids}')
54 | gpu_ids = [int(x) for x in gpu_ids]
55 | model = torch.nn.DataParallel(model, device_ids=gpu_ids)
56 | else:
57 | logger.info(f'Use single gpu in: {gpu_ids}')
58 |
59 | return model, device
60 |
61 |
62 | def get_model_path_list(base_dir):
63 | """
64 | Collect the paths of all model.pt files under base_dir.
65 | """
66 | model_lists = []
67 |
68 | for root, dirs, files in os.walk(base_dir):
69 | for _file in files:
70 | if 'model.pt' == _file:
71 | model_lists.append(os.path.join(root, _file))
72 |
73 | model_lists = sorted(model_lists,
74 | key=lambda x: (x.split('/')[-3], int(x.split('/')[-2].split('-')[-1])))
75 |
76 | return model_lists
77 |
78 |
79 | def swa(model, model_dir, swa_start=1):
80 | """
81 | SWA: average the weights of saved checkpoints; usually applied only once training has stabilized.
82 | """
83 | model_path_list = get_model_path_list(model_dir)
84 |
85 | assert 1 <= swa_start < len(model_path_list) - 1, \
86 | f'When using SWA, swa_start must be at least 1 and smaller than {len(model_path_list) - 1}'
87 |
88 | swa_model = copy.deepcopy(model)
89 | swa_n = 0.
90 |
91 | with torch.no_grad():
92 | for _ckpt in model_path_list[swa_start:]:
93 | logger.info(f'Load model from {_ckpt}')
94 | model.load_state_dict(torch.load(_ckpt, map_location=torch.device('cpu')))
95 | tmp_para_dict = dict(model.named_parameters())
96 |
97 | alpha = 1. / (swa_n + 1.)
98 |
99 | for name, para in swa_model.named_parameters():
100 | para.copy_(tmp_para_dict[name].data.clone() * alpha + para.data.clone() * (1. - alpha))
101 |
102 | swa_n += 1
103 |
104 | # use step 100000 for the SWA checkpoint to avoid clashing with real checkpoint steps
105 | swa_model_dir = os.path.join(model_dir, 'checkpoint-100000')
106 | if not os.path.exists(swa_model_dir):
107 | os.mkdir(swa_model_dir)
108 |
109 | logger.info(f'Save swa model in: {swa_model_dir}')
110 |
111 | swa_model_path = os.path.join(swa_model_dir, 'model.pt')
112 |
113 | torch.save(swa_model.state_dict(), swa_model_path)
114 |
115 | return swa_model
116 |
117 |
118 | def vote(entities_list, threshold=0.9):
119 | """
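120 | Entity-level voting; each entity is counted under the key (entity_type, entity[0], entity[1]).
121 | :param entities_list: entities predicted by every model for one file
122 | :param threshold: fraction of models that must predict an entity for it to be kept (the count is truncated with int())
123 | :return: dict mapping each entity type to the list of kept (entity[0], entity[1]) tuples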
124 | """
125 | threshold_nums = int(len(entities_list)*threshold)
126 | entities_dict = defaultdict(int)
127 | entities = defaultdict(list)
128 |
129 | for _entities in entities_list:
130 | for _type in _entities:
131 | for _ent in _entities[_type]:
132 | entities_dict[(_type, _ent[0], _ent[1])] += 1
133 |
134 | for key in entities_dict:
135 | if entities_dict[key] >= threshold_nums:
136 | entities[key[0]].append((key[1], key[2]))
137 |
138 | return entities
139 |
140 | def ensemble_vote(entities_list, threshold=0.9):
141 | """
142 | Voting over the outputs of ensemble models.
143 | Entity-level voting over (entity_type, entity_start, entity_end, entity_text).
144 | """
145 | threshold_nums = int(len(entities_list)*threshold)
146 | entities_dict = defaultdict(int)
147 |
148 | entities = defaultdict(list)
149 |
150 | for _entities in entities_list:
151 | for _id in _entities:
152 | for _ent in _entities[_id]:
153 | entities_dict[(_id, ) + _ent] += 1
154 |
155 | for key in entities_dict:
156 | if entities_dict[key] >= threshold_nums:
157 | entities[key[0]].append(key[1:])
158 |
159 | return entities
160 |
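Two details of `swa` and `vote` above are easy to miss: the update with `alpha = 1 / (swa_n + 1)` is an incremental arithmetic mean, and `int(len(entities_list) * threshold)` truncates, so with three models a threshold of 0.9 requires only two votes. The sketch below shows both in plain Python. The names `running_mean` and `vote_sketch` and the toy entities are hypothetical; the arithmetic follows the functions above.

```python
from collections import defaultdict


def running_mean(values):
    """Incremental mean using the same alpha schedule as swa()."""
    avg = 0.0
    for n, value in enumerate(values):
        alpha = 1.0 / (n + 1.0)
        avg = value * alpha + avg * (1.0 - alpha)
    return avg


def vote_sketch(entities_list, threshold=0.9):
    """Keep entities predicted by at least int(n_models * threshold) models."""
    threshold_nums = int(len(entities_list) * threshold)
    counts = defaultdict(int)
    for entities in entities_list:
        for _type, items in entities.items():
            for text, start in items:
                counts[(_type, text, start)] += 1

    kept = defaultdict(list)
    for (_type, text, start), n in counts.items():
        if n >= threshold_nums:
            kept[_type].append((text, start))
    return dict(kept)


# Three toy "checkpoints" holding a single scalar weight each.
averaged = running_mean([1.0, 2.0, 6.0])

# Three toy model outputs; the second model misses the SYMPTOM entity.
predictions = [
    {'DRUG': [('aspirin', 0)], 'SYMPTOM': [('headache', 8)]},
    {'DRUG': [('aspirin', 0)]},
    {'DRUG': [('aspirin', 0)], 'SYMPTOM': [('headache', 8)]},
]
loose = vote_sketch(predictions, threshold=0.9)   # needs int(2.7) = 2 votes
strict = vote_sketch(predictions, threshold=1.0)  # needs all 3 votes
```

With few models the truncation makes the effective threshold noticeably lower than the nominal one, which is worth keeping in mind when choosing `threshold`.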
--------------------------------------------------------------------------------
/src/utils/model_utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import math
3 | import torch
4 | import torch.nn as nn
5 | from torchcrf import CRF
6 | from itertools import repeat
7 | from transformers import BertModel
8 | from src.utils.functions_utils import vote
9 | from src.utils.evaluator import crf_decode, span_decode
10 |
11 |
12 | class LabelSmoothingCrossEntropy(nn.Module):
13 | def __init__(self, eps=0.1, reduction='mean', ignore_index=-100):
14 | super(LabelSmoothingCrossEntropy, self).__init__()
15 | self.eps = eps
16 | self.reduction = reduction
17 | self.ignore_index = ignore_index
18 |
19 | def forward(self, output, target):
20 | c = output.size()[-1]
21 | log_pred = torch.log_softmax(output, dim=-1)
22 | if self.reduction == 'sum':
23 | loss = -log_pred.sum()
24 | else:
25 | loss = -log_pred.sum(dim=-1)
26 | if self.reduction == 'mean':
27 | loss = loss.mean()
28 |
29 |
30 | return loss * self.eps / c + (1 - self.eps) * torch.nn.functional.nll_loss(log_pred, target,
31 | reduction=self.reduction,
32 | ignore_index=self.ignore_index)
33 |
34 | class FocalLoss(nn.Module):
35 | """Multi-class Focal loss implementation"""
36 | def __init__(self, gamma=2, weight=None, reduction='mean', ignore_index=-100):
37 | super(FocalLoss, self).__init__()
38 | self.gamma = gamma
39 | self.weight = weight
40 | self.ignore_index = ignore_index
41 | self.reduction = reduction
42 |
43 | def forward(self, input, target):
44 | """
45 | input: [N, C]
46 | target: [N, ]
47 | """
48 | log_pt = torch.log_softmax(input, dim=1)
49 | pt = torch.exp(log_pt)
50 | log_pt = (1 - pt) ** self.gamma * log_pt
51 | loss = torch.nn.functional.nll_loss(log_pt, target, self.weight, reduction=self.reduction, ignore_index=self.ignore_index)
52 | return loss
53 |
54 | class SpatialDropout(nn.Module):
55 | """
56 | Spatial dropout on character-level vectors: the same feature channels are dropped at every sequence position.
57 | """
58 | def __init__(self, drop_prob):
59 | super(SpatialDropout, self).__init__()
60 | self.drop_prob = drop_prob
61 |
62 | @staticmethod
63 | def _make_noise(input):
64 | return input.new().resize_(input.size(0), *repeat(1, input.dim() - 2), input.size(2))
65 |
66 | def forward(self, inputs):
67 | output = inputs.clone()
68 | if not self.training or self.drop_prob == 0:
69 | return inputs
70 | else:
71 | noise = self._make_noise(inputs)
72 | if self.drop_prob == 1:
73 | noise.fill_(0)
74 | else:
75 | noise.bernoulli_(1 - self.drop_prob).div_(1 - self.drop_prob)
76 | noise = noise.expand_as(inputs)
77 | output.mul_(noise)
78 | return output
79 |
80 | class ConditionalLayerNorm(nn.Module):
81 | def __init__(self,
82 | normalized_shape,
83 | cond_shape,
84 | eps=1e-12):
85 | super().__init__()
86 |
87 | self.eps = eps
88 |
89 | self.weight = nn.Parameter(torch.Tensor(normalized_shape))
90 | self.bias = nn.Parameter(torch.Tensor(normalized_shape))
91 |
92 | self.weight_dense = nn.Linear(cond_shape, normalized_shape, bias=False)
93 | self.bias_dense = nn.Linear(cond_shape, normalized_shape, bias=False)
94 |
95 | self.reset_weight_and_bias()
96 |
97 | def reset_weight_and_bias(self):
98 | """
99 | This initialization keeps the conditional layer norm from having any effect at the start of training.
100 | """
101 | nn.init.ones_(self.weight)
102 | nn.init.zeros_(self.bias)
103 |
104 | nn.init.zeros_(self.weight_dense.weight)
105 | nn.init.zeros_(self.bias_dense.weight)
106 |
107 | def forward(self, inputs, cond=None):
108 | assert cond is not None, 'A conditional tensor is required when using conditional layer norm'
109 | cond = torch.unsqueeze(cond, 1) # (b, 1, cond_shape)
110 |
111 | weight = self.weight_dense(cond) + self.weight # (b, 1, h)
112 | bias = self.bias_dense(cond) + self.bias # (b, 1, h)
113 |
114 | mean = torch.mean(inputs, dim=-1, keepdim=True) # (b, s, 1)
115 | outputs = inputs - mean # (b, s, h)
116 |
117 | variance = torch.mean(outputs ** 2, dim=-1, keepdim=True)
118 | std = torch.sqrt(variance + self.eps) # (b, s, 1)
119 |
120 | outputs = outputs / std # (b, s, h)
121 |
122 | outputs = outputs * weight + bias
123 |
124 | return outputs
125 |
126 | class BaseModel(nn.Module):
127 | def __init__(self,
128 | bert_dir,
129 | dropout_prob):
130 | super(BaseModel, self).__init__()
131 | config_path = os.path.join(bert_dir, 'config.json')
132 |
133 | assert os.path.exists(bert_dir) and os.path.exists(config_path), \
134 | 'pretrained bert file does not exist'
135 |
136 | self.bert_module = BertModel.from_pretrained(bert_dir,
137 | output_hidden_states=True,
138 | hidden_dropout_prob=dropout_prob)
139 |
140 | self.bert_config = self.bert_module.config
141 |
142 | @staticmethod
143 | def _init_weights(blocks, **kwargs):
144 | """
145 | Parameter initialization: initialize Linear / Embedding / LayerNorm the same way as BERT.
146 | """
147 | for block in blocks:
148 | for module in block.modules():
149 | if isinstance(module, nn.Linear):
150 | if module.bias is not None:
151 | nn.init.zeros_(module.bias)
152 | elif isinstance(module, nn.Embedding):
153 | nn.init.normal_(module.weight, mean=0, std=kwargs.pop('initializer_range', 0.02))
154 | elif isinstance(module, nn.LayerNorm):
155 | nn.init.ones_(module.weight)
156 | nn.init.zeros_(module.bias)
157 |
158 |
159 | # baseline
160 | class CRFModel(BaseModel):
161 | def __init__(self,
162 | bert_dir,
163 | num_tags,
164 | dropout_prob=0.1,
165 | **kwargs):
166 | super(CRFModel, self).__init__(bert_dir=bert_dir, dropout_prob=dropout_prob)
167 |
168 | out_dims = self.bert_config.hidden_size
169 |
170 | mid_linear_dims = kwargs.pop('mid_linear_dims', 128)
171 |
172 | self.mid_linear = nn.Sequential(
173 | nn.Linear(out_dims, mid_linear_dims),
174 | nn.ReLU(),
175 | nn.Dropout(dropout_prob)
176 | )
177 |
178 | out_dims = mid_linear_dims
179 |
180 | self.classifier = nn.Linear(out_dims, num_tags)
181 |
182 | self.loss_weight = nn.Parameter(torch.FloatTensor(1), requires_grad=True)
183 | self.loss_weight.data.fill_(-0.2)
184 |
185 | self.crf_module = CRF(num_tags=num_tags, batch_first=True)
186 |
187 | init_blocks = [self.mid_linear, self.classifier]
188 |
189 | self._init_weights(init_blocks, initializer_range=self.bert_config.initializer_range)
190 |
191 | def forward(self,
192 | token_ids,
193 | attention_masks,
194 | token_type_ids,
195 | labels=None,
196 | pseudo=None):
197 |
198 | bert_outputs = self.bert_module(
199 | input_ids=token_ids,
200 | attention_mask=attention_masks,
201 | token_type_ids=token_type_ids
202 | )
203 |
204 | # standard: take the sequence output (first element of the BERT outputs)
205 | seq_out = bert_outputs[0]
206 |
207 | seq_out = self.mid_linear(seq_out)
208 |
209 | emissions = self.classifier(seq_out)
210 |
211 | if labels is not None:
212 | if pseudo is not None:
213 | # (batch,)
214 | tokens_loss = -1. * self.crf_module(emissions=emissions,
215 | tags=labels.long(),
216 | mask=attention_masks.byte(),
217 | reduction='none')
218 |
219 | # nums of pseudo data
220 | pseudo_nums = pseudo.sum().item()
221 | total_nums = token_ids.shape[0]
222 |
223 | # learning parameter
224 | rate = torch.sigmoid(self.loss_weight)
225 | if pseudo_nums == 0:
226 | loss_0 = tokens_loss.mean()
227 | loss_1 = (rate*pseudo*tokens_loss).sum()
228 | else:
229 | if total_nums == pseudo_nums:
230 | loss_0 = 0
231 | else:
232 | loss_0 = ((1 - rate) * (1 - pseudo) * tokens_loss).sum() / (total_nums - pseudo_nums)
233 | loss_1 = (rate*pseudo*tokens_loss).sum() / pseudo_nums
234 |
235 | tokens_loss = loss_0 + loss_1
236 |
237 | else:
238 | tokens_loss = -1. * self.crf_module(emissions=emissions,
239 | tags=labels.long(),
240 | mask=attention_masks.byte(),
241 | reduction='mean')
242 |
243 | out = (tokens_loss,)
244 |
245 | else:
246 | tokens_out = self.crf_module.decode(emissions=emissions, mask=attention_masks.byte())
247 |
248 | out = (tokens_out, emissions)
249 |
250 | return out
251 |
252 |
253 | class SpanModel(BaseModel):
254 | def __init__(self,
255 | bert_dir,
256 | num_tags,
257 | dropout_prob=0.1,
258 | loss_type='ce',
259 | **kwargs):
260 | """
261 | tag the start and end position of each entity with its entity type
262 | :param loss_type: train loss type in ['ce', 'ls_ce', 'focal']
263 | """
264 | super(SpanModel, self).__init__(bert_dir, dropout_prob=dropout_prob)
265 |
266 | out_dims = self.bert_config.hidden_size
267 |
268 | mid_linear_dims = kwargs.pop('mid_linear_dims', 128)
269 |
270 | self.num_tags = num_tags
271 |
272 | self.mid_linear = nn.Sequential(
273 | nn.Linear(out_dims, mid_linear_dims),
274 | nn.ReLU(),
275 | nn.Dropout(dropout_prob)
276 | )
277 |
278 | out_dims = mid_linear_dims
279 |
280 | self.start_fc = nn.Linear(out_dims, num_tags)
281 | self.end_fc = nn.Linear(out_dims, num_tags)
282 |
283 | reduction = 'none'
284 | if loss_type == 'ce':
285 | self.criterion = nn.CrossEntropyLoss(reduction=reduction)
286 | elif loss_type == 'ls_ce':
287 | self.criterion = LabelSmoothingCrossEntropy(reduction=reduction)
288 | else:
289 | self.criterion = FocalLoss(reduction=reduction)
290 |
291 | self.loss_weight = nn.Parameter(torch.FloatTensor(1), requires_grad=True)
292 | self.loss_weight.data.fill_(-0.2)
293 |
294 | init_blocks = [self.mid_linear, self.start_fc, self.end_fc]
295 |
296 | self._init_weights(init_blocks)
297 |
298 | def forward(self,
299 | token_ids,
300 | attention_masks,
301 | token_type_ids,
302 | start_ids=None,
303 | end_ids=None,
304 | pseudo=None):
305 |
306 | bert_outputs = self.bert_module(
307 | input_ids=token_ids,
308 | attention_mask=attention_masks,
309 | token_type_ids=token_type_ids
310 | )
311 |
312 | seq_out = bert_outputs[0]
313 |
314 | seq_out = self.mid_linear(seq_out)
315 |
316 | start_logits = self.start_fc(seq_out)
317 | end_logits = self.end_fc(seq_out)
318 |
319 | out = (start_logits, end_logits, )
320 |
321 | if start_ids is not None and end_ids is not None and self.training:
322 |
323 | start_logits = start_logits.view(-1, self.num_tags)
324 | end_logits = end_logits.view(-1, self.num_tags)
325 |
326 | # drop the labels on padding positions so the loss covers real tokens only
327 | active_loss = attention_masks.view(-1) == 1
328 | active_start_logits = start_logits[active_loss]
329 | active_end_logits = end_logits[active_loss]
330 |
331 | active_start_labels = start_ids.view(-1)[active_loss]
332 | active_end_labels = end_ids.view(-1)[active_loss]
333 |
334 |
335 | if pseudo is not None:
336 | # (batch,)
337 | start_loss = self.criterion(start_logits, start_ids.view(-1)).view(-1, 512).mean(dim=-1)
338 | end_loss = self.criterion(end_logits, end_ids.view(-1)).view(-1, 512).mean(dim=-1)
339 |
340 | # nums of pseudo data
341 | pseudo_nums = pseudo.sum().item()
342 | total_nums = token_ids.shape[0]
343 |
344 | # learning parameter
345 | rate = torch.sigmoid(self.loss_weight)
346 | if pseudo_nums == 0:
347 | start_loss = start_loss.mean()
348 | end_loss = end_loss.mean()
349 | else:
350 | if total_nums == pseudo_nums:
351 | start_loss = (rate*pseudo*start_loss).sum() / pseudo_nums
352 | end_loss = (rate*pseudo*end_loss).sum() / pseudo_nums
353 | else:
354 | start_loss = (rate*pseudo*start_loss).sum() / pseudo_nums \
355 | + ((1 - rate) * (1 - pseudo) * start_loss).sum() / (total_nums - pseudo_nums)
356 | end_loss = (rate*pseudo*end_loss).sum() / pseudo_nums \
357 | + ((1 - rate) * (1 - pseudo) * end_loss).sum() / (total_nums - pseudo_nums)
358 | else:
359 | start_loss = self.criterion(active_start_logits, active_start_labels)
360 | end_loss = self.criterion(active_end_logits, active_end_labels)
361 |
362 | loss = start_loss + end_loss
363 |
364 | out = (loss, ) + out
365 |
366 | return out
367 |
368 | class MRCModel(BaseModel):
369 | def __init__(self,
370 | bert_dir,
371 | dropout_prob=0.1,
372 | use_type_embed=False,
373 | loss_type='ce',
374 | **kwargs):
375 | """
376 | tag the start and end positions of the entities of the given type
377 | :param use_type_embed: type embedding for the sentence
378 | :param loss_type: train loss type in ['ce', 'ls_ce', 'focal']
379 | """
380 | super(MRCModel, self).__init__(bert_dir, dropout_prob=dropout_prob)
381 |
382 | self.use_type_embed = use_type_embed
383 | self.use_smooth = loss_type
384 |
385 | out_dims = self.bert_config.hidden_size
386 |
387 | if self.use_type_embed:
388 | embed_dims = kwargs.pop('predicate_embed_dims', self.bert_config.hidden_size)
389 | self.type_embedding = nn.Embedding(13, embed_dims)
390 |
391 | self.conditional_layer_norm = ConditionalLayerNorm(out_dims, embed_dims,
392 | eps=self.bert_config.layer_norm_eps)
393 |
394 | mid_linear_dims = kwargs.pop('mid_linear_dims', 128)
395 |
396 | self.mid_linear = nn.Sequential(
397 | nn.Linear(out_dims, mid_linear_dims),
398 | nn.ReLU(),
399 | nn.Dropout(dropout_prob)
400 | )
401 |
402 | out_dims = mid_linear_dims
403 |
404 | self.start_fc = nn.Linear(out_dims, 2)
405 | self.end_fc = nn.Linear(out_dims, 2)
406 |
407 | reduction = 'none'
408 | if loss_type == 'ce':
409 | self.criterion = nn.CrossEntropyLoss(reduction=reduction)
410 | elif loss_type == 'ls_ce':
411 | self.criterion = LabelSmoothingCrossEntropy(reduction=reduction)
412 | else:
413 | self.criterion = FocalLoss(reduction=reduction)
414 |
415 | self.loss_weight = nn.Parameter(torch.FloatTensor(1), requires_grad=True)
416 | self.loss_weight.data.fill_(-0.2)
417 |
418 | init_blocks = [self.mid_linear, self.start_fc, self.end_fc]
419 |
420 | if self.use_type_embed:
421 | init_blocks.append(self.type_embedding)
422 |
423 | self._init_weights(init_blocks)
424 |
425 | def forward(self,
426 | token_ids,
427 | attention_masks,
428 | token_type_ids,
429 | ent_type=None,
430 | start_ids=None,
431 | end_ids=None,
432 | pseudo=None):
433 |
434 | bert_outputs = self.bert_module(
435 | input_ids=token_ids,
436 | attention_mask=attention_masks,
437 | token_type_ids=token_type_ids
438 | )
439 |
440 | seq_out = bert_outputs[0]
441 |
442 | if self.use_type_embed:
443 | assert ent_type is not None, \
444 | 'ent_type must be provided when use_type_embed is enabled'
445 |
446 | predicate_feature = self.type_embedding(ent_type)
447 | seq_out = self.conditional_layer_norm(seq_out, predicate_feature)
448 |
449 | seq_out = self.mid_linear(seq_out)
450 |
451 | start_logits = self.start_fc(seq_out)
452 | end_logits = self.end_fc(seq_out)
453 |
454 | out = (start_logits, end_logits, )
455 |
456 | if start_ids is not None and end_ids is not None:
457 | start_logits = start_logits.view(-1, 2)
458 | end_logits = end_logits.view(-1, 2)
459 |
460 | # drop the labels on text_a and padding positions so the loss covers the real text only
461 | active_loss = token_type_ids.view(-1) == 1
462 | active_start_logits = start_logits[active_loss]
463 | active_end_logits = end_logits[active_loss]
464 |
465 | active_start_labels = start_ids.view(-1)[active_loss]
466 | active_end_labels = end_ids.view(-1)[active_loss]
467 |
468 | if pseudo is not None:
469 | # (batch,)
470 | start_loss = self.criterion(start_logits, start_ids.view(-1)).view(-1, 512).mean(dim=-1)
471 | end_loss = self.criterion(end_logits, end_ids.view(-1)).view(-1, 512).mean(dim=-1)
472 |
473 | # nums of pseudo data
474 | pseudo_nums = pseudo.sum().item()
475 | total_nums = token_ids.shape[0]
476 |
477 | # learning parameter
478 | rate = torch.sigmoid(self.loss_weight)
479 | if pseudo_nums == 0:
480 | start_loss = start_loss.mean()
481 | end_loss = end_loss.mean()
482 | else:
483 | if total_nums == pseudo_nums:
484 | start_loss = (rate*pseudo*start_loss).sum() / pseudo_nums
485 | end_loss = (rate*pseudo*end_loss).sum() / pseudo_nums
486 | else:
487 | start_loss = (rate*pseudo*start_loss).sum() / pseudo_nums \
488 | + ((1 - rate) * (1 - pseudo) * start_loss).sum() / (total_nums - pseudo_nums)
489 | end_loss = (rate*pseudo*end_loss).sum() / pseudo_nums \
490 | + ((1 - rate) * (1 - pseudo) * end_loss).sum() / (total_nums - pseudo_nums)
491 | else:
492 | start_loss = self.criterion(active_start_logits, active_start_labels)
493 | end_loss = self.criterion(active_end_logits, active_end_labels)
494 |
495 | loss = start_loss + end_loss
496 |
497 | out = (loss, ) + out
498 |
499 | return out
500 |
501 |
502 | class EnsembleCRFModel:
503 |     def __init__(self, model_path_list, bert_dir_list, num_tags, device, lamb=1/3):
504 |
505 |         self.models = []
506 |         self.crf_module = CRF(num_tags=num_tags, batch_first=True)
507 |         self.lamb = lamb
508 |
509 |         for idx, _path in enumerate(model_path_list):
510 |             print(f'Load model from {_path}')
511 |
512 |
513 |             print(f'Load model type: {bert_dir_list[0]}')
514 |             model = CRFModel(bert_dir=bert_dir_list[0], num_tags=num_tags)
515 |
516 |
517 |             model.load_state_dict(torch.load(_path, map_location=torch.device('cpu')))
518 |
519 |             model.eval()
520 |             model.to(device)
521 |
522 |             self.models.append(model)
523 |             if idx == 0:
524 |                 print(f'Load CRF weight from {_path}')
525 |                 self.crf_module.load_state_dict(model.crf_module.state_dict())
526 |                 self.crf_module.to(device)
527 |
528 |     def weight(self, t):
529 |         """
530 |         Fusion weight from Newton's law of cooling: exp(-lamb * t)
531 |         """
532 |         return math.exp(-self.lamb*t)
533 |
534 |     def predict(self, model_inputs):
535 |         weight_sum = 0.
536 |         logits = None
537 |         attention_masks = model_inputs['attention_masks']
538 |
539 |         for idx, model in enumerate(self.models):
540 |             # fuse with Newton's-law-of-cooling weights (earlier models weigh more)
541 |             weight = self.weight(idx)
542 |
543 |             # alternative: fuse with a plain average
544 |             # weight = 1 / len(self.models)
545 |
546 |             tmp_logits = model(**model_inputs)[1] * weight
547 |             weight_sum += weight
548 |
549 |             if logits is None:
550 |                 logits = tmp_logits
551 |             else:
552 |                 logits += tmp_logits
553 |
554 |         logits = logits / weight_sum
555 |
556 |         tokens_out = self.crf_module.decode(emissions=logits, mask=attention_masks.byte())
557 |
558 |         return tokens_out
559 |
560 |     def vote_entities(self, model_inputs, sent, id2ent, threshold):
561 |         entities_ls = []
562 |         for idx, model in enumerate(self.models):
563 |             tmp_tokens = model(**model_inputs)[0][0]
564 |             tmp_entities = crf_decode(tmp_tokens, sent, id2ent)
565 |             entities_ls.append(tmp_entities)
566 |
567 |         return vote(entities_ls, threshold)
568 |
569 |
570 | class EnsembleSpanModel:
571 |     def __init__(self, model_path_list, bert_dir_list, num_tags, device):
572 |
573 |         self.models = []
574 |
575 |         for idx, _path in enumerate(model_path_list):
576 |             print(f'Load model from {_path}')
577 |
578 |             print(f'Load model type: {bert_dir_list[0]}')
579 |             model = SpanModel(bert_dir=bert_dir_list[0], num_tags=num_tags)
580 |
581 |             model.load_state_dict(torch.load(_path, map_location=torch.device('cpu')))
582 |
583 |             model.eval()
584 |             model.to(device)
585 |
586 |             self.models.append(model)
587 |
588 |     def predict(self, model_inputs):
589 |         start_logits, end_logits = None, None
590 |
591 |         for idx, model in enumerate(self.models):
592 |
593 |             # fuse with a plain average over all models
594 |             weight = 1 / len(self.models)
595 |
596 |             tmp_start_logits, tmp_end_logits = model(**model_inputs)
597 |
598 |             tmp_start_logits = tmp_start_logits * weight
599 |             tmp_end_logits = tmp_end_logits * weight
600 |
601 |             if start_logits is None:
602 |                 start_logits = tmp_start_logits
603 |                 end_logits = tmp_end_logits
604 |             else:
605 |                 start_logits += tmp_start_logits
606 |                 end_logits += tmp_end_logits
607 |
608 |         return start_logits, end_logits
609 |
610 |     def vote_entities(self, model_inputs, sent, id2ent, threshold):
611 |         entities_ls = []
612 |
613 |         for idx, model in enumerate(self.models):
614 |
615 |             start_logits, end_logits = model(**model_inputs)
616 |             start_logits = start_logits[0].cpu().numpy()[1:1 + len(sent)]
617 |             end_logits = end_logits[0].cpu().numpy()[1:1 + len(sent)]
618 |
619 |             decode_entities = span_decode(start_logits, end_logits, sent, id2ent)
620 |
621 |             entities_ls.append(decode_entities)
622 |
623 |         return vote(entities_ls, threshold)
624 |
625 |
626 | def build_model(task_type, bert_dir, **kwargs):
627 |     assert task_type in ['crf', 'span', 'mrc']
628 |
629 |     if task_type == 'crf':
630 |         model = CRFModel(bert_dir=bert_dir,
631 |                          num_tags=kwargs.pop('num_tags'),
632 |                          dropout_prob=kwargs.pop('dropout_prob', 0.1))
633 |
634 |     elif task_type == 'mrc':
635 |         model = MRCModel(bert_dir=bert_dir,
636 |                          dropout_prob=kwargs.pop('dropout_prob', 0.1),
637 |                          use_type_embed=kwargs.pop('use_type_embed'),
638 |                          loss_type=kwargs.pop('loss_type', 'ce'))
639 |
640 |     else:
641 |         model = SpanModel(bert_dir=bert_dir,
642 |                           num_tags=kwargs.pop('num_tags'),
643 |                           dropout_prob=kwargs.pop('dropout_prob', 0.1),
644 |                           loss_type=kwargs.pop('loss_type', 'ce'))
645 |
646 |     return model
647 |
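The Newton's-law-of-cooling fusion in `EnsembleCRFModel.predict` can be illustrated without any model. The sketch below is a minimal stand-alone version that works on plain Python lists; `cooling_weights` and `fuse` are hypothetical helper names that do not exist in this repository, and the default `lamb=1/3` is copied from the constructor above.

```python
import math


def cooling_weights(n_models, lamb=1/3):
    """Normalized weights w_t = exp(-lamb * t) / sum_k exp(-lamb * k)."""
    raw = [math.exp(-lamb * t) for t in range(n_models)]
    total = sum(raw)
    return [w / total for w in raw]


def fuse(per_model_scores, lamb=1/3):
    """Weighted average of equally long score lists, one list per model."""
    weights = cooling_weights(len(per_model_scores), lamb)
    length = len(per_model_scores[0])
    return [
        sum(w * scores[i] for w, scores in zip(weights, per_model_scores))
        for i in range(length)
    ]


# The first model in the list gets the largest weight, later ones decay exponentially.
weights = cooling_weights(3)
fused = fuse([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
```

Because the weights decay with the list index, the order of `model_path_list` matters: the checkpoint placed first dominates the fused emissions.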
--------------------------------------------------------------------------------
/src/utils/options.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 | class Args:
4 |     @staticmethod
5 |     def parse():
6 |         parser = argparse.ArgumentParser()
7 |         return parser
8 |
9 |     @staticmethod
10 |     def initialize(parser: argparse.ArgumentParser):
11 |         # args for path
12 |         parser.add_argument('--raw_data_dir', default='./data/raw_data',
13 |                             help='the data dir of raw data')
14 |
15 |         parser.add_argument('--mid_data_dir', default='./data/mid_data',
16 |                             help='the mid data dir')
17 |
18 |         parser.add_argument('--output_dir', default='./out/',
19 |                             help='the output dir for model checkpoints')
20 |
21 |         parser.add_argument('--bert_dir', default='../bert/torch_roberta_wwm',
22 |                             help='bert dir for ernie / roberta-wwm / uer')
23 |
24 |         parser.add_argument('--bert_type', default='roberta_wwm',
25 |                             help='roberta_wwm / ernie_1 / uer_large')
26 |
27 |         parser.add_argument('--task_type', default='crf',
28 |                             help='crf / span / mrc')
29 |
30 |         parser.add_argument('--loss_type', default='ls_ce',
31 |                             help='loss type for span bert')
32 |
33 |         parser.add_argument('--use_type_embed', default=False, action='store_true',
34 |                             help='whether to use entity type embedding in the mrc model')
35 |
36 |         parser.add_argument('--use_fp16', default=False, action='store_true',
37 |                             help='whether to use fp16 during training')
38 |
39 |         # other args
40 |         parser.add_argument('--seed', type=int, default=123, help='random seed')
41 |
42 |         parser.add_argument('--gpu_ids', type=str, default='0',
43 |                             help='gpu ids to use, -1 for cpu, "0,1" for multi gpu')
44 |
45 |         parser.add_argument('--mode', type=str, default='train',
46 |                             help='train / stack')
47 |
48 |         parser.add_argument('--max_seq_len', default=512, type=int)
49 |
50 |         parser.add_argument('--eval_batch_size', default=64, type=int)
51 |
52 |         parser.add_argument('--swa_start', default=3, type=int,
53 |                             help='the epoch when swa start')
54 |
55 |         # train args
56 |         parser.add_argument('--train_epochs', default=10, type=int,
57 |                             help='Max training epoch')
58 |
59 |         parser.add_argument('--dropout_prob', default=0.1, type=float,
60 |                             help='drop out probability')
61 |
62 |         parser.add_argument('--lr', default=2e-5, type=float,
63 |                             help='learning rate for the bert module')
64 |
65 |         parser.add_argument('--other_lr', default=2e-3, type=float,
66 |                             help='learning rate for the module except bert')
67 |
68 |         parser.add_argument('--max_grad_norm', default=1.0, type=float,
69 |                             help='max grad clip')
70 |
71 |         parser.add_argument('--warmup_proportion', default=0.1, type=float)
72 |
73 |         parser.add_argument('--weight_decay', default=0.00, type=float)
74 |
75 |         parser.add_argument('--adam_epsilon', default=1e-8, type=float)
76 |
77 |         parser.add_argument('--train_batch_size', default=24, type=int)
78 |
79 |         parser.add_argument('--eval_model', default=True, action='store_true',
80 |                             help='whether to eval model after training')
81 |
82 |         parser.add_argument('--attack_train', default='', type=str,
83 |                             help='fgm / pgd attack train when training')
84 |
85 |         # test args
86 |         parser.add_argument('--version', default='v0', type=str,
87 |                             help='submit version')
88 |
89 |         parser.add_argument('--submit_dir', default='./submit', type=str)
90 |
91 |         parser.add_argument('--ckpt_dir', default='', type=str)
92 |
93 |         return parser
94 |
95 |     def get_parser(self):
96 |         parser = self.parse()
97 |         parser = self.initialize(parser)
98 |         return parser.parse_args()
99 |
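One detail of the options above is worth a quick check: `--eval_model` combines `default=True` with `action='store_true'`, so passing the flag cannot change its value. The sketch below reproduces that behaviour with a throwaway parser. The `--no_eval_model` flag is a hypothetical addition for illustration only and is not part of `options.py`.

```python
import argparse

parser = argparse.ArgumentParser()

# Same definition as in options.py: already True, and the flag can only set True again.
parser.add_argument('--eval_model', default=True, action='store_true',
                    help='whether to eval model after training')

# Hypothetical companion flag (not in the repository) that writes False
# into the same destination, giving users a way to switch evaluation off.
parser.add_argument('--no_eval_model', dest='eval_model', action='store_false',
                    help='skip evaluation after training')

default_args = parser.parse_args([])
explicit_args = parser.parse_args(['--eval_model'])
disabled_args = parser.parse_args(['--no_eval_model'])

print(default_args.eval_model, explicit_args.eval_model, disabled_args.eval_model)
```

With only the original definition, evaluation after training is effectively always on; a paired `store_false` flag is one common way to make it optional.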
--------------------------------------------------------------------------------
/src/utils/trainer.py:
--------------------------------------------------------------------------------
1 | import os
2 | import copy
3 | import torch
4 | import logging
5 | from torch.cuda.amp import autocast as ac
6 | from torch.utils.data import DataLoader, RandomSampler
7 | from transformers import AdamW, get_linear_schedule_with_warmup
8 | from src.utils.attack_train_utils import FGM, PGD
9 | from src.utils.functions_utils import load_model_and_parallel, swa
10 |
11 |
12 | logger = logging.getLogger(__name__)
13 |
14 |
15 | def save_model(opt, model, global_step):
16 |     output_dir = os.path.join(opt.output_dir, 'checkpoint-{}'.format(global_step))
17 |     if not os.path.exists(output_dir):
18 |         os.makedirs(output_dir, exist_ok=True)
19 |
20 |     # take care of model distributed / parallel training
21 |     model_to_save = (
22 |         model.module if hasattr(model, "module") else model
23 |     )
24 |     logger.info(f'Saving model checkpoint to {output_dir}')
25 |     torch.save(model_to_save.state_dict(), os.path.join(output_dir, 'model.pt'))
26 |
27 |
28 | def build_optimizer_and_scheduler(opt, model, t_total):
29 |     module = (
30 |         model.module if hasattr(model, "module") else model
31 |     )
32 |
33 |     # differential learning rates: one for bert, one for the remaining modules
34 |     no_decay = ["bias", "LayerNorm.weight"]
35 |     model_param = list(module.named_parameters())
36 |
37 |     bert_param_optimizer = []
38 |     other_param_optimizer = []
39 |
40 |     for name, para in model_param:
41 |         space = name.split('.')
42 |         if space[0] == 'bert_module':
43 |             bert_param_optimizer.append((name, para))
44 |         else:
45 |             other_param_optimizer.append((name, para))
46 |
47 |     optimizer_grouped_parameters = [
48 |         # bert module
49 |         {"params": [p for n, p in bert_param_optimizer if not any(nd in n for nd in no_decay)],
50 |          "weight_decay": opt.weight_decay, 'lr': opt.lr},
51 |         {"params": [p for n, p in bert_param_optimizer if any(nd in n for nd in no_decay)],
52 |          "weight_decay": 0.0, 'lr': opt.lr},
53 |
54 |         # other modules, trained with their own (larger) learning rate
55 |         {"params": [p for n, p in other_param_optimizer if not any(nd in n for nd in no_decay)],
56 |          "weight_decay": opt.weight_decay, 'lr': opt.other_lr},
57 |         {"params": [p for n, p in other_param_optimizer if any(nd in n for nd in no_decay)],
58 |          "weight_decay": 0.0, 'lr': opt.other_lr},
59 |     ]
60 |
61 |     optimizer = AdamW(optimizer_grouped_parameters, lr=opt.lr, eps=opt.adam_epsilon)
62 |     scheduler = get_linear_schedule_with_warmup(
63 |         optimizer, num_warmup_steps=int(opt.warmup_proportion * t_total), num_training_steps=t_total
64 |     )
65 |
66 |     return optimizer, scheduler
67 |
68 |
69 | def train(opt, model, train_dataset):
70 |     swa_raw_model = copy.deepcopy(model)
71 |
72 |     train_sampler = RandomSampler(train_dataset)
73 |
74 |     train_loader = DataLoader(dataset=train_dataset,
75 |                               batch_size=opt.train_batch_size,
76 |                               sampler=train_sampler,
77 |                               num_workers=0)
78 |
79 |     scaler = None
80 |     if opt.use_fp16:
81 |         scaler = torch.cuda.amp.GradScaler()
82 |
83 |     model, device = load_model_and_parallel(model, opt.gpu_ids)
84 |
85 |     use_n_gpus = False
86 |     if hasattr(model, "module"):
87 |         use_n_gpus = True
88 |
89 |     t_total = len(train_loader) * opt.train_epochs
90 |
91 |     optimizer, scheduler = build_optimizer_and_scheduler(opt, model, t_total)
92 |
93 |     # Train
94 |     logger.info("***** Running training *****")
95 |     logger.info(f" Num Examples = {len(train_dataset)}")
96 |     logger.info(f" Num Epochs = {opt.train_epochs}")
97 |     logger.info(f" Total training batch size = {opt.train_batch_size}")
98 |     logger.info(f" Total optimization steps = {t_total}")
99 |
100 |     global_step = 0
101 |
102 |     model.zero_grad()
103 |
104 |     fgm, pgd = None, None
105 |
106 |     attack_train_mode = opt.attack_train.lower()
107 |     if attack_train_mode == 'fgm':
108 |         fgm = FGM(model=model)
109 |     elif attack_train_mode == 'pgd':
110 |         pgd = PGD(model=model)
111 |
112 |     pgd_k = 3
113 |
114 |     save_steps = t_total // opt.train_epochs
115 |     eval_steps = save_steps
116 |
117 |     logger.info(f'Save model in {save_steps} steps; Eval model in {eval_steps} steps')
118 |
119 |     log_loss_steps = 20
120 |
121 |     avg_loss = 0.
122 |
123 |     for epoch in range(opt.train_epochs):
124 |
125 |         for step, batch_data in enumerate(train_loader):
126 |
127 |             model.train()
128 |
129 |             for key in batch_data.keys():
130 |                 batch_data[key] = batch_data[key].to(device)
131 |
132 |             if opt.use_fp16:
133 |                 with ac():
134 |                     loss = model(**batch_data)[0]
135 |             else:
136 |                 loss = model(**batch_data)[0]
137 |
138 |             if use_n_gpus:
139 |                 loss = loss.mean()
140 |
141 |             if opt.use_fp16:
142 |                 scaler.scale(loss).backward()
143 |             else:
144 |                 loss.backward()
145 |
146 |             if fgm is not None:
147 |                 fgm.attack()
148 |
149 |                 if opt.use_fp16:
150 |                     with ac():
151 |                         loss_adv = model(**batch_data)[0]
152 |                 else:
153 |                     loss_adv = model(**batch_data)[0]
154 |
155 |                 if use_n_gpus:
156 |                     loss_adv = loss_adv.mean()
157 |
158 |                 if opt.use_fp16:
159 |                     scaler.scale(loss_adv).backward()
160 |                 else:
161 |                     loss_adv.backward()
162 |
163 |                 fgm.restore()
164 |
165 |             elif pgd is not None:
166 |                 pgd.backup_grad()
167 |
168 |                 for _t in range(pgd_k):
169 |                     pgd.attack(is_first_attack=(_t == 0))
170 |
171 |                     if _t != pgd_k - 1:
172 |                         model.zero_grad()
173 |                     else:
174 |                         pgd.restore_grad()
175 |
176 |                     if opt.use_fp16:
177 |                         with ac():
178 |                             loss_adv = model(**batch_data)[0]
179 |                     else:
180 |                         loss_adv = model(**batch_data)[0]
181 |
182 |                     if use_n_gpus:
183 |                         loss_adv = loss_adv.mean()
184 |
185 |                     if opt.use_fp16:
186 |                         scaler.scale(loss_adv).backward()
187 |                     else:
188 |                         loss_adv.backward()
189 |
190 |                 pgd.restore()
191 |
192 |             if opt.use_fp16:
193 |                 scaler.unscale_(optimizer)
194 |
195 |             torch.nn.utils.clip_grad_norm_(model.parameters(), opt.max_grad_norm)
196 |
197 |             # optimizer.step()
198 |             if opt.use_fp16:
199 |                 scaler.step(optimizer)
200 |                 scaler.update()
201 |             else:
202 |                 optimizer.step()
203 |
204 |             scheduler.step()
205 |             model.zero_grad()
206 |
207 |             global_step += 1
208 |
209 |             # accumulate on every step so the logged value is a true mean over log_loss_steps
210 |             avg_loss += loss.item()
211 |             if global_step % log_loss_steps == 0:
212 |                 avg_loss /= log_loss_steps
213 |                 logger.info('Step: %d / %d ----> average loss: %.5f' % (global_step, t_total, avg_loss))
214 |                 avg_loss = 0.
215 |
216 |             if global_step % save_steps == 0:
217 |                 save_model(opt, model, global_step)
218 |
219 |     swa(swa_raw_model, opt.output_dir, swa_start=opt.swa_start)
220 |
221 |     # clear cuda cache to avoid OOM
222 |     torch.cuda.empty_cache()
223 |     logger.info('Train done')
224 |
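`build_optimizer_and_scheduler` pairs two base learning rates (`opt.lr` for BERT, `opt.other_lr` for everything else) with a linear warmup and decay schedule. The sketch below is a pure-Python approximation of the multiplier such a schedule applies; it reflects the usual behaviour of `get_linear_schedule_with_warmup` as I understand it and should be treated as an illustration rather than the library's code. `linear_warmup_factor` is a hypothetical name, and the step counts are made-up example values.

```python
def linear_warmup_factor(step, warmup_steps, total_steps):
    """Multiplier applied to each parameter group's base learning rate."""
    if step < warmup_steps:
        # linear ramp from 0 to 1 during warmup
        return step / max(1, warmup_steps)
    # linear decay from 1 to 0 over the remaining steps
    return max(0.0, (total_steps - step) / max(1, total_steps - warmup_steps))


total_steps = 1000                       # example value, not taken from the repo
warmup_steps = int(0.1 * total_steps)    # default --warmup_proportion is 0.1
bert_lr, other_lr = 2e-5, 2e-3           # defaults of --lr and --other_lr

for step in (0, 50, 100, 550, 1000):
    factor = linear_warmup_factor(step, warmup_steps, total_steps)
    print(step, bert_lr * factor, other_lr * factor)
```

Both parameter groups share the same multiplier, so the 100x ratio between `other_lr` and `lr` is preserved throughout training.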
--------------------------------------------------------------------------------