├── .vscode ├── settings.json ├── tasks.json └── launch.json ├── test.sh ├── test.py ├── run_predict.sh ├── run_predict2.sh ├── run_eval.sh ├── test.json ├── run_ner.sh ├── run_joint.sh ├── run_role_segment_bin.sh ├── run_role_bin.sh ├── run_role_segment.sh ├── sh ├── run_role_bin0.sh ├── run_role_bin1.sh ├── run_role_bin2.sh ├── run_role_bin3.sh ├── run_role_bin4.sh ├── run_role_bin5.sh └── run_role_bin6.sh ├── run_classify.sh ├── run_distributed.sh ├── run_multi_class.sh ├── run_role_bin_event_type.sh ├── run_pl.sh ├── .gitignore ├── README.md ├── loss.py ├── utils_classify.py ├── main.py ├── utils_ner.py ├── utils_ner_segment.py ├── run_pl_ner.py ├── utils.py ├── transformer_base.py ├── utils_bi_ner.py ├── utils_bi_ner_segment.py ├── utils_bi_ner_segment_event_type.py ├── postprocess.py ├── utils_bi_ner_joint.py └── preprocess.py /.vscode/settings.json: -------------------------------------------------------------------------------- 1 | { 2 | "python.pythonPath": "/home/mhxia/anaconda3/envs/whou-transformers/bin/python" 3 | } -------------------------------------------------------------------------------- /test.sh: -------------------------------------------------------------------------------- 1 | 2 | 3 | zsh sh/run_role_bin0.sh ;zsh sh/run_role_bin1.sh; zsh sh/run_role_bin2.sh; zsh sh/run_role_bin3.sh 4 | zsh sh/run_role_bin4.sh; zsh sh/run_role_bin5.sh; zsh sh/run_role_bin6.sh -------------------------------------------------------------------------------- /.vscode/tasks.json: -------------------------------------------------------------------------------- 1 | { 2 | // See https://go.microsoft.com/fwlink/?LinkId=733558 3 | // for the documentation about the tasks.json format 4 | "version": "2.0.0", 5 | "tasks": [ 6 | { 7 | "label": "echo", 8 | "type": "shell", 9 | "command": "echo Hello" 10 | } 11 | ] 12 | } -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | #coding=utf-8 2 | import sys 3 | #str = input() 4 | def helper(aList, left, right): 5 | pivot = aList[start] 6 | while left < right: 7 | while left < right and nums[right]>= pivot: 8 | right -= 1 9 | aList[left]= aList[right] 10 | while left < right and nums[right]<= pivot: 11 | left += 1 12 | aList[right] = aList[left] 13 | aList[left] = pivot 14 | return left 15 | 16 | def quickSort(nums, left , right): 17 | if left train.txt.tmp 7 | curl -L 'https://sites.google.com/site/germeval2014ner/data/NER-de-dev.tsv?attredirects=0&d=1' \ 8 | | grep -v "^#" | cut -f 2,3 | tr '\t' ' ' > dev.txt.tmp 9 | curl -L 'https://sites.google.com/site/germeval2014ner/data/NER-de-test.tsv?attredirects=0&d=1' \ 10 | | grep -v "^#" | cut -f 2,3 | tr '\t' ' ' > test.txt.tmp 11 | wget "https://raw.githubusercontent.com/stefan-it/fine-tuned-berts-seq/master/scripts/preprocess.py" 12 | export MAX_LENGTH=128 13 | export BERT_MODEL=bert-base-multilingual-cased 14 | python3 preprocess.py train.txt.tmp $BERT_MODEL $MAX_LENGTH > train.txt 15 | python3 preprocess.py dev.txt.tmp $BERT_MODEL $MAX_LENGTH > dev.txt 16 | python3 preprocess.py test.txt.tmp $BERT_MODEL $MAX_LENGTH > test.txt 17 | cat train.txt dev.txt test.txt | cut -d " " -f 2 | grep -v "^$"| sort | uniq > labels.txt 18 | export OUTPUT_DIR=germeval-model 19 | export BATCH_SIZE=32 20 | export NUM_EPOCHS=3 21 | export SAVE_STEPS=750 22 | export SEED=1 23 | 24 | python3 run_pl_ner.py --data_dir ./ \ 25 | --model_type bert \ 26 | --labels ./labels.txt \ 27 | --model_name_or_path 
$BERT_MODEL \
28 | --output_dir $OUTPUT_DIR \
29 | --max_seq_length $MAX_LENGTH \
30 | --num_train_epochs $NUM_EPOCHS \
31 | --train_batch_size 32 \
32 | --save_steps $SAVE_STEPS \
33 | --seed $SEED \
34 | --do_train \
35 | --do_predict
36 | 
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 | 
6 | # C extensions
7 | *.so
8 | 
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | pip-wheel-metadata/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 | 
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 | 
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 | 
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | *.py,cover
51 | .hypothesis/
52 | .pytest_cache/
53 | 
54 | # Translations
55 | *.mo
56 | *.pot
57 | 
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 | 
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 | 
68 | # Scrapy stuff:
69 | .scrapy
70 | 
71 | # Sphinx documentation
72 | docs/_build/
73 | 
74 | # PyBuilder
75 | target/
76 | 
77 | # Jupyter Notebook
78 | .ipynb_checkpoints
79 | 
80 | # IPython
81 | profile_default/
82 | ipython_config.py
83 | 
84 | # pyenv
85 | .python-version
86 | 
87 | # pipenv
88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
91 | # install all needed dependencies.
92 | #Pipfile.lock
93 | 
94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
95 | __pypackages__/
96 | 
97 | # Celery stuff
98 | celerybeat-schedule
99 | celerybeat.pid
100 | 
101 | # SageMath parsed files
102 | *.sage.py
103 | 
104 | # Environments
105 | .env
106 | .venv
107 | env/
108 | venv/
109 | ENV/
110 | env.bak/
111 | venv.bak/
112 | 
113 | # Spyder project settings
114 | .spyderproject
115 | .spyproject
116 | 
117 | # Rope project settings
118 | .ropeproject
119 | 
120 | # mkdocs documentation
121 | /site
122 | 
123 | # mypy
124 | .mypy_cache/
125 | .dmypy.json
126 | dmypy.json
127 | 
128 | # Pyre type checker
129 | .pyre/
130 | 
131 | 
132 | # data model
133 | 
134 | output/
135 | data/
136 | results/
137 | runs/
138 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # lic2020-ee
2 | 
3 | ## Task overview
4 | Event extraction is split into two steps:
5 | 1. Trigger extraction / event type detection: triggers are usually identified with NER, which also determines the event type
6 | 2. Argument extraction: extract the event arguments associated with each trigger
7 | 
8 | In general, triggers do not overlap, but one text may contain several triggers (events) of the same type; arguments within the same event may overlap.
9 | 
10 | ## Competition approach
11 | 
12 | ### Baseline
13 | Pipeline: extract all triggers with NER, extract all arguments with NER, then match triggers and arguments according to event_schema.
14 | 
15 | Results:
16 | dev-trigger: f1=0.86427 precision=0.85655 recall=0.87921
17 | dev-role: f1=0.62321 precision=0.58821 recall=0.67439
18 | test: f1=0.808 precision=0.836 recall=0.782
19 | 
20 | ### Event type classification
21 | Since this competition only scores event types, and merging events of the same type within one text does not affect the evaluation, task 1 does not actually require extracting trigger words. Besides NER, event types can therefore also be predicted with **multi-label classification** (sigmoid activation).
22 | 
23 | Dev-set experiments show that NER and multi-label classification perform about equally well for event type classification.
24 | | | Precision-macro | Recall-macro | F1-macro |
25 | | -------------------- | ------------------ | ------------------ | ------------------ |
26 | | NER | 0.9418052256532067 | 0.9571514785757392 | 0.9494163424124513 |
27 | | multi-label-classify | 0.9548229548229549 | 0.943874471937236 | 0.9493171471927163 |
28 | | multi-label-classify + trick | 0.9515445184736523 | 0.9480989740494871 | 0.9498186215235792 |
29 | 
30 | One small trick: when the multi-label prediction is empty, take the label with the highest score; this yields a small improvement.
31 | 
32 | We also tried resampling with WeightedRandomSampler, but recall was very low, so we did not keep it.
33 | macro_F1 = 0.930539826349566
34 | micro_f1 = 0.928857823783912
35 | precision = 1.0
36 | recall = 0.8823529411764706
37 | 
38 | 
39 | For model ensembling we tested two methods, averaging logits and voting over labels; the results are very close.
40 | 
41 | Ensemble results based on multi-label classification:
42 | 5-merge-labels: {'precision': 0.9555555555555556, 'recall': 0.96016898008449, 'f1': 0.9578567128236003}
43 | 5-merge-logits: {'precision': 0.9577804583835947, 'recall': 0.9583584791792396, 'f1': 0.9580693815987934}
44 | 
45 | 
46 | ### Argument extraction
47 | There are two ways to extract event arguments:
48 | 1. Extract only the arguments associated with a given trigger
49 | 2. Extract all arguments first, then align the arguments with the triggers
50 | 
51 | #### First approach
52 | We use start-end tagging here in order to handle overlapping spans.
53 | 
54 | There are several ways to feed in the extra features:
55 | 1. Trigger feature: change the segment-id of the trigger tokens, mainly following PLMEE [2].
56 | 2. Trigger feature: average the trigger token embeddings and add them to the sentence embeddings, mainly following CASREL [1].
57 | 3. Combine 1 and 2.
58 | 4. Add an event-type feature on top of 2.
59 | 5. Add an event-type feature on top of 3.
60 | 
61 | Dev-set results (we only ran settings 3 and 4):
62 | | | Precision | Recall | F1 |
63 | | -------------------- | ------------------ | ------------------ | ------------------ |
64 | | 3 | 0.7114285714285714 | 0.7410714285714286 | 0.7259475218658891 |
65 | | 4 | 0.7352192362093353 | 0.7031926406926406 | 0.7188493984234545 |
66 | 
67 | Result of setting 3 on test1: f1=0.789 precision=0.828 recall=0.754
68 | 
69 | Recall was low, so we eventually dropped the first approach.
70 | 
71 | #### Second approach
72 | 
73 | We compared BIO tagging with start-end tagging and found that start-end tagging works better (a simplified decoding sketch is appended after the references).
74 | 
75 | Dev-set results:
76 | | | Precision-span | Recall-span | F1-span |
77 | | -------------------- | ------------------ | ------------------ | ------------------ |
78 | | BIO | 0.6721880844242586 | 0.7517179563788468 | 0.7097320169252468 |
79 | | start-end | 0.7690982194141298 | 0.7147050974112623 | 0.7409046894452898 |
80 | | start-end + 5-merge | 0.8174037089871612 | 0.7751623376623377 | 0.7957228162755173 |
81 | 
82 | Notably, model ensembling brings a sizeable improvement.
83 | 
84 | ### Final solution
85 | 
86 | Pipeline: multi-label classification predicts the event types, start-end tagging extracts all argument spans, and triggers and arguments are matched according to event_schema.
87 | 
88 | test1 results: precision=0.847 recall=0.842 f1=0.844
89 | test1 rank: 28
90 | test2 results:
91 | test2 rank: 20
92 | 
93 | ## References
94 | 
95 | [1] Wei Z, Su J, Wang Y, et al. A Novel Hierarchical Binary Tagging Framework for Joint Extraction of Entities and Relations[J]. arXiv: Computation and Language, 2019.
96 | [2] Yang S, Feng D, Qiao L, et al. Exploring Pre-trained Language Models for Event Extraction and Generation[C]. Meeting of the Association for Computational Linguistics, 2019: 5284-5294.
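### Appendix: decoding start-end tags (illustrative sketch)

A minimal sketch, not the project's actual `postprocess.py` logic, of why start-end tagging can recover overlapping argument spans: every token carries an independent start tag and end tag, and each start is paired with the nearest following end of the same role. The per-token string-tag format and the role names below are assumed purely for illustration.

```python
def decode_start_end(start_tags, end_tags):
    """Pair every start tag with the nearest following end tag of the same role."""
    spans = []
    for i, role in enumerate(start_tags):
        if role == "O":
            continue
        for j in range(i, len(end_tags)):
            if end_tags[j] == role:
                spans.append((role, i, j))  # inclusive token span [i, j]
                break
    return spans

# Two role spans that share tokens are both recovered,
# which a single BIO sequence could not represent.
start_tags = ["time", "victim", "O", "O"]
end_tags = ["O", "O", "time", "victim"]
print(decode_start_end(start_tags, end_tags))
# [('time', 0, 2), ('victim', 1, 3)]
```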
97 | -------------------------------------------------------------------------------- /.vscode/launch.json: -------------------------------------------------------------------------------- 1 | { 2 | // Use IntelliSense to learn about possible attributes. 3 | // Hover to view descriptions of existing attributes. 4 | // For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387 5 | "version": "0.2.0", 6 | "configurations": [ 7 | { 8 | "name": "Predict", 9 | "type": "python", 10 | "request": "launch", 11 | "program": "${file}", 12 | "cwd": "${fileDirname}", 13 | "console": "integratedTerminal", 14 | "env": { 15 | "CUDA_VISIBLE_DEVICES": "2" 16 | }, 17 | "args": [ 18 | "--task", 19 | "role", 20 | "--model_type", 21 | "bert", 22 | "--model_name_or_path", 23 | "/home/mhxia/whou/workspace/pretrained_models/chinese_roberta_wwm_large_ext_pytorch", 24 | "--do_predict", 25 | "--data_dir", 26 | "./data/role_bin/test/", 27 | "--do_lower_case", 28 | "--keep_accents", 29 | "--schema", 30 | "./data/event_schema/event_schema.json", 31 | "--output_dir", 32 | "./output/role_bin_train_dev/0/", 33 | "--max_seq_length", 34 | "384", 35 | "--per_gpu_eval_batch_size", 36 | "32", 37 | "--seed", 38 | "1" 39 | ] 40 | }, 41 | { 42 | "name": "Python: Current File", 43 | "type": "python", 44 | "request": "launch", 45 | "program": "${file}", 46 | "cwd": "${fileDirname}", 47 | "console": "integratedTerminal" 48 | }, 49 | { 50 | "name": "Train", 51 | "type": "python", 52 | "request": "launch", 53 | "program": "${file}", 54 | "cwd": "${fileDirname}", 55 | "console": "integratedTerminal", 56 | "env": { 57 | "CUDA_VISIBLE_DEVICES": "2" 58 | }, 59 | "args": [ 60 | "--task", 61 | "trigger", 62 | "--model_type", 63 | "bert", 64 | "--model_name_or_path", 65 | "/home/mhxia/whou/workspace/pretrained_models/chinese_roberta_wwm_large_ext_pytorch", 66 | "--do_train", 67 | "--do_eval", 68 | "--evaluate_during_training", 69 | "--eval_all_checkpoints", 70 | "--data_dir", 71 | "./data/trigger_classify_weighted/0/", 72 | // "--overwrite_cache", 73 | "--do_lower_case", 74 | "--keep_accents", 75 | "--schema", 76 | "./data/event_schema/event_schema.json", 77 | "--output_dir", 78 | "./output/trigger_classify_weighted/0/", 79 | "--overwrite_output_dir", 80 | "--max_seq_length", 81 | "256", 82 | "--per_gpu_train_batch_size", 83 | "2", 84 | "--per_gpu_eval_batch_size", 85 | "64", 86 | "--gradient_accumulation_steps", 87 | "1", 88 | "--save_steps", 89 | "4", 90 | "--logging_steps", 91 | "4", 92 | "--num_train_epochs", 93 | "7", 94 | "--early_stop", 95 | "4", 96 | "--learning_rate", 97 | "3e-5", 98 | "--weight_decay", 99 | "0", 100 | "--warmup_steps", 101 | "1000", 102 | "--seed", 103 | "1" 104 | ] 105 | } 106 | ] 107 | } -------------------------------------------------------------------------------- /loss.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | 4 | 5 | import torch 6 | import torch.nn as nn 7 | from torch.autograd import Variable 8 | import torch.nn.functional as F 9 | 10 | class FocalLoss(nn.Module): 11 | '''Multi-class Focal loss implementation''' 12 | def __init__(self, gamma=2, weight=None, ignore_index=-100): 13 | super(FocalLoss, self).__init__() 14 | self.gamma = gamma 15 | self.weight = weight 16 | self.ignore_index=ignore_index 17 | 18 | def forward(self, input, target): 19 | """ 20 | input: [N, C] 21 | target: [N, ] 22 | """ 23 | 24 | not_ignore = target.data.cpu() != self.ignore_index 25 | target= target[not_ignore] 
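        # drop positions whose label equals ignore_index (padding); the same mask is
        # applied to the logits on the next line, so padded tokens never reach the loss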
26 | input = input[not_ignore] 27 | 28 | logpt = F.log_softmax(input, dim=1) 29 | pt = torch.exp(logpt) 30 | logpt = (1-pt)**self.gamma * logpt 31 | loss = F.nll_loss(logpt, target, self.weight, ignore_index=self.ignore_index) 32 | return loss 33 | 34 | 35 | class LabelSmoothingCrossEntropy(nn.Module): 36 | def __init__(self, eps=0.1, reduction='mean',ignore_index=-100): 37 | super(LabelSmoothingCrossEntropy, self).__init__() 38 | self.eps = eps 39 | self.reduction = reduction 40 | self.ignore_index = ignore_index 41 | 42 | def forward(self, output, target): 43 | c = output.size(-1) 44 | log_preds = F.log_softmax(output, dim=-1) 45 | if self.reduction=='sum': 46 | loss = -log_preds.sum() 47 | else: 48 | loss = -log_preds.sum(dim=-1) 49 | if self.reduction=='mean': 50 | loss = loss.mean() 51 | return loss*self.eps/c + (1-self.eps) * F.nll_loss(log_preds, target, reduction=self.reduction, 52 | ignore_index=self.ignore_index) 53 | 54 | 55 | class DSCLoss(nn.Module): 56 | """DSCLoss: Dice Loss for Data-imbalanced NLP Tasks (Multi-Classification) 57 | Args: 58 | smooth: A float number to smooth loss, and avoid NaN error, default: 1 59 | prediction: Output of Network, a tensor of shape [batch, class_num] 60 | target: Label of classification, a tensor of shape [batch, ] 61 | Returns: 62 | Loss tensor according to args reduction 63 | Comments: 64 | Suitable for imbalanced data. 65 | """ 66 | def __init__(self, smooth=0, ignore_index=-100): 67 | super(DSCLoss, self).__init__() 68 | self.smooth = smooth 69 | self.ignore_index = ignore_index 70 | 71 | 72 | def forward(self, prediction, target): 73 | # add by houwei 74 | prediction = F.softmax(prediction, dim=1) 75 | 76 | not_ignore = target.data.cpu() != self.ignore_index 77 | target= target[not_ignore] 78 | prediction = prediction[not_ignore] 79 | target = F.one_hot(target, num_classes=prediction.size(1)) 80 | 81 | num = (1.0 - prediction) * prediction * target + self.smooth 82 | den = (1.0 - prediction) * prediction + target + self.smooth 83 | dice = 1.0 - num / den 84 | # print(dice) 85 | 86 | loss = torch.mean(dice, dim=0) 87 | # print(loss) 88 | # 去掉 标签 O 的loss 89 | loss = loss[1:] 90 | return loss.mean() 91 | 92 | 93 | class DiceLoss(nn.Module): 94 | """DiceLoss: A kind of Dice Loss (Multi-Classification) 95 | Args: 96 | smooth: A float number to smooth loss, and avoid NaN error, default: 1 97 | prediction: Output of Network, a tensor of shape [batch, class_num] 98 | target: Label of classification, a tensor of shape [batch, ] 99 | Returns: 100 | Loss tensor according to args reduction 101 | """ 102 | def __init__(self, smooth=1 , ignore_index=-100): 103 | super(DiceLoss, self).__init__() 104 | self.smooth = smooth 105 | self.ignore_index = ignore_index 106 | 107 | def forward(self, prediction, target): 108 | prediction = F.softmax(prediction, dim=1) 109 | 110 | not_ignore = target.data.cpu() != self.ignore_index 111 | target= target[not_ignore] 112 | prediction = prediction[not_ignore] 113 | target = F.one_hot(target, num_classes=prediction.size(1)) 114 | 115 | num = 2 * prediction * target + self.smooth 116 | den = prediction.pow(2) + target.pow(2) + self.smooth 117 | loss = torch.mean(1.0 - num / den, dim=0) 118 | # print(loss) 119 | # 去掉 标签 O 的loss 120 | loss = loss[1:] 121 | return loss.mean() 122 | 123 | ''' 124 | 第一轮 验证 125 | loss* 1: 0.001245 126 | loss* 10000: 0.01597 127 | ''' 128 | -------------------------------------------------------------------------------- /utils_classify.py: 
-------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. 3 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | """ GLUE processors and helpers """ 17 | 18 | import logging 19 | import os 20 | import json 21 | 22 | from utils import get_labels 23 | 24 | logger = logging.getLogger(__name__) 25 | 26 | 27 | class InputExample(object): 28 | """A single training/test example for token classification.""" 29 | 30 | def __init__(self, guid, text_a, text_b=None, label=None): 31 | """Constructs a InputExample. 32 | 33 | Args: 34 | guid: Unique id for the example. 35 | words: list. The words of the sequence. 36 | labels: (Optional) list. The labels for each word of the sequence. This should be 37 | specified for train and dev examples, but not for test examples. 38 | """ 39 | self.guid = guid 40 | self.text_a = text_a 41 | self.text_b = text_b 42 | self.label = label 43 | 44 | 45 | class InputFeatures(object): 46 | """A single set of features of data.""" 47 | 48 | def __init__(self, input_ids, attention_mask, token_type_ids, label): 49 | self.input_ids = input_ids 50 | self.input_mask = attention_mask 51 | self.segment_ids = token_type_ids 52 | self.label = label 53 | 54 | 55 | def read_examples_from_file(data_dir, mode): 56 | label_list= get_labels(task='trigger', mode="classificstion") 57 | 58 | file_path = os.path.join(data_dir, "{}.json".format(mode)) 59 | guid_index = 1 60 | examples = [] 61 | with open(file_path, encoding="utf-8") as f: 62 | for line in f: 63 | if line=='\n' or line=='': 64 | continue 65 | line_json = json.loads(line) 66 | # print(line_json) 67 | examples.append(InputExample(guid="{}-{}".format(mode, guid_index), text_a= line_json['text'], label= line_json['labels'])) 68 | guid_index += 1 69 | 70 | return examples 71 | 72 | 73 | def convert_examples_to_features(examples, tokenizer, 74 | max_length=512, 75 | label_list=None, 76 | pad_on_left=False, 77 | pad_token=0, 78 | pad_token_segment_id=0, 79 | mask_padding_with_zero=True): 80 | """ 81 | Loads a data file into a list of ``InputFeatures`` 82 | Args: 83 | examples: List of ``InputExamples`` or ``tf.data.Dataset`` containing the examples. 84 | tokenizer: Instance of a tokenizer that will tokenize the examples 85 | max_length: Maximum example length 86 | task: GLUE task 87 | label_list: List of labels. Can be obtained from the processor using the ``processor.get_labels()`` method 88 | output_mode: String indicating the output mode. 
Either ``regression`` or ``classification`` 89 | pad_on_left: If set to ``True``, the examples will be padded on the left rather than on the right (default) 90 | pad_token: Padding token 91 | pad_token_segment_id: The segment ID for the padding token (It is usually 0, but can vary such as for XLNet where it is 4) 92 | mask_padding_with_zero: If set to ``True``, the attention mask will be filled by ``1`` for actual values 93 | and by ``0`` for padded values. If set to ``False``, inverts it (``1`` for padded values, ``0`` for 94 | actual values) 95 | Returns: 96 | If the ``examples`` input is a ``tf.data.Dataset``, will return a ``tf.data.Dataset`` 97 | containing the task-specific features. If the input is a list of ``InputExamples``, will return 98 | a list of task-specific ``InputFeatures`` which can be fed to the model. 99 | """ 100 | 101 | label_map = {label: i for i, label in enumerate(label_list)} 102 | 103 | features = [] 104 | for (ex_index, example) in enumerate(examples): 105 | if ex_index % 10000 == 0: 106 | logger.info("Writing example %d" % (ex_index)) 107 | 108 | inputs = tokenizer.encode_plus( 109 | example.text_a, 110 | example.text_b, 111 | add_special_tokens=True, 112 | max_length=max_length, 113 | ) 114 | input_ids, token_type_ids = inputs["input_ids"], inputs["token_type_ids"] 115 | 116 | # The mask has 1 for real tokens and 0 for padding tokens. Only real 117 | # tokens are attended to. 118 | attention_mask = [1 if mask_padding_with_zero else 0] * len(input_ids) 119 | 120 | # Zero-pad up to the sequence length. 121 | padding_length = max_length - len(input_ids) 122 | if pad_on_left: 123 | input_ids = ([pad_token] * padding_length) + input_ids 124 | attention_mask = ([0 if mask_padding_with_zero else 1] * padding_length) + attention_mask 125 | token_type_ids = ([pad_token_segment_id] * padding_length) + token_type_ids 126 | else: 127 | input_ids = input_ids + ([pad_token] * padding_length) 128 | attention_mask = attention_mask + ([0 if mask_padding_with_zero else 1] * padding_length) 129 | token_type_ids = token_type_ids + ([pad_token_segment_id] * padding_length) 130 | 131 | assert len(input_ids) == max_length, "Error with input length {} vs {}".format(len(input_ids), max_length) 132 | assert len(attention_mask) == max_length, "Error with input length {} vs {}".format(len(attention_mask), max_length) 133 | assert len(token_type_ids) == max_length, "Error with input length {} vs {}".format(len(token_type_ids), max_length) 134 | 135 | label = [label_map[sub_label] for sub_label in example.label] 136 | 137 | if ex_index < 5: 138 | logger.info("*** Example ***") 139 | logger.info("guid: %s" % (example.guid)) 140 | logger.info("input_ids: %s" % " ".join([str(x) for x in input_ids])) 141 | logger.info("attention_mask: %s" % " ".join([str(x) for x in attention_mask])) 142 | logger.info("token_type_ids: %s" % " ".join([str(x) for x in token_type_ids])) 143 | logger.info("label: %s (id = %s)" % (str(example.label), str(label))) 144 | 145 | features.append( 146 | InputFeatures(input_ids=input_ids, 147 | attention_mask=attention_mask, 148 | token_type_ids=token_type_ids, 149 | label=label)) 150 | 151 | return features 152 | 153 | 154 | def convert_label_ids_to_onehot(label_ids, label_list): 155 | one_hot_labels= [0]*len(label_list) 156 | label_map = {label: i for i, label in enumerate(label_list)} 157 | ignore_index= -100 158 | non_index= -1 159 | for label_id in label_ids: 160 | if label_id not in [ignore_index, non_index]: 161 | one_hot_labels[label_id]= 1 162 | return 
one_hot_labels 163 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import json 2 | import tqdm 3 | from preprocess import write_file 4 | import collections 5 | import numpy as np 6 | from utils import get_labels 7 | 8 | def role_bin_ner_labels_merge(input_file_list, output_file): 9 | num_fold = len(input_file_list) 10 | labels_file_list= [] 11 | for input_file in input_file_list: 12 | rows = open(input_file, encoding='utf-8').read().splitlines() 13 | labels_file = [json.loads(row)["labels"] for row in rows] 14 | labels_file_list.append(labels_file) 15 | 16 | num_samples = len(labels_file_list[0]) 17 | labels = [] 18 | for i in range(num_samples): 19 | cur_all_labels = [] 20 | for labels_file in labels_file_list: 21 | cur_labels = [] 22 | for label in labels_file[i]: 23 | cur_labels.append(" ".join(list(map(str, label)))) 24 | cur_all_labels.extend(cur_labels) 25 | 26 | cur_labels =[] 27 | obj = collections.Counter(cur_all_labels) 28 | for k,v in obj.items(): 29 | if v>= 3: 30 | cur_labels.append(list(map(int, k.split()))) 31 | if cur_labels==[] and len(obj)!=0: 32 | k = obj.most_common(1)[0][0] 33 | cur_labels.append(list(map(int, k.split()))) 34 | labels.append({"labels": cur_labels}) 35 | 36 | write_file(labels, output_file) 37 | 38 | 39 | def trigger_classify_labels_merge(input_file_list, output_file): 40 | num_fold = len(input_file_list) 41 | labels_file_list= [] 42 | for input_file in input_file_list: 43 | rows = open(input_file, encoding='utf-8').read().splitlines() 44 | labels_file = [json.loads(row)["labels"] for row in rows] 45 | labels_file_list.append(labels_file) 46 | 47 | num_samples = len(labels_file_list[0]) 48 | labels = [] 49 | for i in range(num_samples): 50 | cur_all_labels = [] 51 | for labels_file in labels_file_list: 52 | cur_all_labels.extend(labels_file[i]) 53 | 54 | cur_labels =[] 55 | obj = collections.Counter(cur_all_labels) 56 | for k,v in obj.items(): 57 | # (num_fold+1)//2: 58 | if v>=2: 59 | cur_labels.append(k) 60 | if cur_labels ==[]: 61 | # print(obj.most_common(1)[0][0]) 62 | cur_labels.append(obj.most_common(1)[0][0]) 63 | labels.append({"labels": cur_labels}) 64 | 65 | write_file(labels, output_file) 66 | 67 | def trigger_classify_logits_merge_and_eval(input_file_list, output_file, label_file): 68 | labels = get_labels(task="trigger", mode="classify") 69 | label_map = {i: label for i, label in enumerate(labels)} 70 | 71 | num_fold = len(input_file_list) 72 | logits_file_list= [] 73 | for input_file in input_file_list: 74 | rows = open(input_file, encoding='utf-8').read().splitlines() 75 | logits_file = [json.loads(row)["logits"] for row in rows] 76 | logits_file_list.append(logits_file) 77 | 78 | threshold = 0.5 79 | logits = np.array(logits_file_list).mean(axis=0) 80 | preds = logits > threshold # 1498*65 81 | 82 | # 若所有类别对应的 logit 都没有超过阈值,则将 logit 最大的类别作为 label 83 | for i in range(preds.shape[0]): 84 | if sum(preds[i])==0: 85 | preds[i][np.argmax(logits[i])]=True 86 | 87 | preds_list = [] 88 | batch_preds_list = [[] for _ in range(preds.shape[0])] 89 | 90 | for i in range(logits.shape[0]): 91 | for j in range(logits.shape[1]): 92 | if preds[i, j]: 93 | preds_list.append([i, label_map[j]]) 94 | batch_preds_list[i].append(label_map[j]) 95 | 96 | ############ labels 97 | label_rows = open(label_file, encoding='utf-8').read().splitlines() 98 | out_labels = [json.loads(row)["labels"] for row in label_rows] 99 | 100 | 
out_label_list = [] 101 | 102 | for i, row in enumerate(label_rows): 103 | json_line = json.loads(row) 104 | for label in json_line["labels"]: 105 | out_label_list.append([i, label]) 106 | 107 | print(compute_f1(preds_list, out_label_list)) 108 | 109 | write_file(labels, output_file) 110 | 111 | def eval_trigger(pred_file, label_file): 112 | preds_list= [] 113 | labels_list = [] 114 | pred_rows = open(pred_file, encoding='utf-8').read().splitlines() 115 | label_rows = open(label_file, encoding='utf-8').read().splitlines() 116 | 117 | for i, row in enumerate(pred_rows): 118 | json_line = json.loads(row) 119 | for label in json_line["labels"]: 120 | preds_list.append([i, label]) 121 | 122 | for i, row in enumerate(label_rows): 123 | json_line = json.loads(row) 124 | for label in json_line["labels"]: 125 | labels_list.append([i, label]) 126 | 127 | print(compute_f1(preds_list, labels_list )) 128 | 129 | def eval_trigger(pred_file, label_file): 130 | preds_list= [] 131 | labels_list = [] 132 | pred_rows = open(pred_file, encoding='utf-8').read().splitlines() 133 | label_rows = open(label_file, encoding='utf-8').read().splitlines() 134 | 135 | for i, row in enumerate(pred_rows): 136 | json_line = json.loads(row) 137 | arguments = json_line["arguments"] 138 | preds_list.extend(arguments) 139 | 140 | for i, row in enumerate(label_rows): 141 | json_line = json.loads(row) 142 | arguments = json_line["arguments"] 143 | labels_list.extend(arguments) 144 | 145 | print(compute_f1(preds_list, labels_list )) 146 | 147 | def compute_f1(preds_list, labels_list): 148 | nb_correct = 0 149 | for out_label in labels_list: 150 | # for pred in preds_list: 151 | # if out_label==pred: 152 | # nb_correct+=1 153 | if out_label in preds_list: 154 | nb_correct += 1 155 | continue 156 | nb_pred = len(preds_list) 157 | nb_true = len(labels_list) 158 | # print(nb_correct, nb_pred, nb_true) 159 | 160 | p = nb_correct / nb_pred if nb_pred > 0 else 0 161 | r = nb_correct / nb_true if nb_true > 0 else 0 162 | f1 = 2 * p * r / (p + r) if p + r > 0 else 0 163 | 164 | results = { 165 | "precision": p, 166 | "recall": r, 167 | "f1": f1, 168 | } 169 | return results 170 | 171 | if __name__ == "__main__": 172 | trigger_classify_labels_merge(input_file_list=[ 173 | "./output/trigger_classify/0/checkpoint-best/test2_predictions.json", 174 | "./output/trigger_classify/1/checkpoint-best/test2_predictions.json", 175 | "./output/trigger_classify/2/checkpoint-best/test2_predictions.json", 176 | "./output/trigger_classify/3/checkpoint-best/test2_predictions.json"], 177 | output_file="./output/trigger_classify/merge/test2_predictions_labels.json" 178 | ) 179 | # eval_trigger(pred_file="./output/trigger_classify/merge/eval_predictions_labels.json",\ 180 | # label_file="./data/trigger_classify/dev.json", 181 | # ) 182 | 183 | # trigger_classify_logits_merge_and_eval(input_file_list=[ 184 | # "./output/trigger_classify/0/checkpoint-best/eval_logits.json", 185 | # "./output/trigger_classify/1/checkpoint-best/eval_logits.json", 186 | # "./output/trigger_classify/2/checkpoint-best/eval_logits.json", 187 | # "./output/trigger_classify/3/checkpoint-best/eval_logits.json", 188 | # "./output/trigger_classify/4/checkpoint-best/eval_logits.json"], 189 | # output_file="./output/trigger_classify/merge/eval_predictions_logits.json", 190 | # label_file="./data/trigger_classify/dev.json" 191 | # ) 192 | 193 | role_bin_ner_labels_merge(input_file_list=[ 194 | "./output/role_bin_train_dev/0/checkpoint-best/test_predictions2.json", 195 | 
"./output/role_bin_train_dev/2/checkpoint-best/test_predictions2.json", 196 | "./output/role_bin_train_dev/3/checkpoint-best/test_predictions2.json", 197 | "./output/role_bin_train_dev/4/checkpoint-best/test_predictions2.json", 198 | "./output/role_bin_train_dev/5/checkpoint-best/test_predictions2.json", 199 | "./output/role_bin_train_dev/6/checkpoint-best/test_predictions2.json"], 200 | output_file="./output/role_bin/merge/test2_predictions_labels.json" 201 | ) 202 | # eval_trigger(pred_file="./output/role_bin/merge/eval_predictions_indexed_labels.json",\ 203 | # label_file="./data/role_bin/dev.json", 204 | # ) 205 | 206 | 207 | 208 | 209 | 210 | -------------------------------------------------------------------------------- /utils_ner.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. 3 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | """ Named entity recognition fine-tuning: utilities to work with CoNLL-2003 task. """ 17 | 18 | 19 | import logging 20 | import os 21 | import json 22 | 23 | logger = logging.getLogger(__name__) 24 | 25 | 26 | class InputExample(object): 27 | """A single training/test example for token classification.""" 28 | 29 | def __init__(self, guid, words, labels): 30 | """Constructs a InputExample. 31 | 32 | Args: 33 | guid: Unique id for the example. 34 | words: list. The words of the sequence. 35 | labels: (Optional) list. The labels for each word of the sequence. This should be 36 | specified for train and dev examples, but not for test examples. 
37 | """ 38 | self.guid = guid 39 | self.words = words 40 | self.labels = labels 41 | 42 | 43 | class InputFeatures(object): 44 | """A single set of features of data.""" 45 | 46 | def __init__(self, input_ids, input_mask, segment_ids, label_ids): 47 | self.input_ids = input_ids 48 | self.input_mask = input_mask 49 | self.segment_ids = segment_ids 50 | self.label_ids = label_ids 51 | 52 | 53 | def read_examples_from_file(data_dir, mode): 54 | file_path = os.path.join(data_dir, "{}.json".format(mode)) 55 | guid_index = 1 56 | examples = [] 57 | with open(file_path, encoding="utf-8") as f: 58 | words = [] 59 | labels = [] 60 | for line in f: 61 | if line=='\n' or line=='': 62 | continue 63 | line_json = json.loads(line) 64 | words = line_json['tokens'] 65 | if mode=='test': labels=['O']*len(words) 66 | else: labels = line_json['labels'] 67 | if len(words)!= len(labels): 68 | print(words, labels," length misMatch") 69 | continue 70 | examples.append(InputExample(guid="{}-{}".format(mode, guid_index), words=words, labels=labels)) 71 | guid_index += 1 72 | 73 | return examples 74 | 75 | 76 | def convert_examples_to_features( 77 | examples, 78 | label_list, 79 | max_seq_length, 80 | tokenizer, 81 | cls_token_at_end=False, 82 | cls_token="[CLS]", 83 | cls_token_segment_id=1, 84 | sep_token="[SEP]", 85 | sep_token_extra=False, 86 | pad_on_left=False, 87 | pad_token=0, 88 | pad_token_segment_id=0, 89 | pad_token_label_id=-100, 90 | sequence_a_segment_id=0, 91 | mask_padding_with_zero=True, 92 | ): 93 | """ Loads a data file into a list of `InputBatch`s 94 | `cls_token_at_end` define the location of the CLS token: 95 | - False (Default, BERT/XLM pattern): [CLS] + A + [SEP] + B + [SEP] 96 | - True (XLNet/GPT pattern): A + [SEP] + B + [SEP] + [CLS] 97 | `cls_token_segment_id` define the segment id associated to the CLS token (0 for BERT, 2 for XLNet) 98 | """ 99 | 100 | label_map = {label: i for i, label in enumerate(label_list)} 101 | # print(label_map) 102 | 103 | features = [] 104 | for (ex_index, example) in enumerate(examples): 105 | if ex_index % 10000 == 0: 106 | logger.info("Writing example %d of %d", ex_index, len(examples)) 107 | # print(example.words, example.labels) 108 | # print(len(example.words), len(example.labels)) 109 | tokens = [] 110 | label_ids = [] 111 | for word, label in zip(example.words, example.labels): 112 | word_tokens = tokenizer.tokenize(word) 113 | tokens.extend(word_tokens) 114 | 115 | if len(word_tokens)>1: 116 | # print(word,">1") # 没有 117 | pass 118 | if len(word_tokens)<1: 119 | # print(word,"<1") 基本都是空格 120 | tokens.extend(["[unused1]"]) 121 | # continue 122 | # Use the real label id for the first token of the word, and padding ids for the remaining tokens 123 | label_ids.extend([label_map[label]] + [pad_token_label_id] * (len(word_tokens) - 1)) 124 | # if len(tokens)!= len(label_ids): 125 | # print(word, word_tokens, tokens, label_ids) 126 | assert len(tokens) == len(label_ids) 127 | # print(len(tokens),len(label_ids)) 128 | # Account for [CLS] and [SEP] with "- 2" and with "- 3" for RoBERTa. 129 | special_tokens_count = 3 if sep_token_extra else 2 130 | if len(tokens) > max_seq_length - special_tokens_count: 131 | tokens = tokens[: (max_seq_length - special_tokens_count)] 132 | label_ids = label_ids[: (max_seq_length - special_tokens_count)] 133 | 134 | # The convention in BERT is: 135 | # (a) For sequence pairs: 136 | # tokens: [CLS] is this jack ##son ##ville ? [SEP] no it is not . 
[SEP] 137 | # type_ids: 0 0 0 0 0 0 0 0 1 1 1 1 1 1 138 | # (b) For single sequences: 139 | # tokens: [CLS] the dog is hairy . [SEP] 140 | # type_ids: 0 0 0 0 0 0 0 141 | # 142 | # Where "type_ids" are used to indicate whether this is the first 143 | # sequence or the second sequence. The embedding vectors for `type=0` and 144 | # `type=1` were learned during pre-training and are added to the wordpiece 145 | # embedding vector (and position vector). This is not *strictly* necessary 146 | # since the [SEP] token unambiguously separates the sequences, but it makes 147 | # it easier for the model to learn the concept of sequences. 148 | # 149 | # For classification tasks, the first vector (corresponding to [CLS]) is 150 | # used as as the "sentence vector". Note that this only makes sense because 151 | # the entire model is fine-tuned. 152 | tokens += [sep_token] 153 | label_ids += [pad_token_label_id] 154 | if sep_token_extra: 155 | # roberta uses an extra separator b/w pairs of sentences 156 | tokens += [sep_token] 157 | label_ids += [pad_token_label_id] 158 | segment_ids = [sequence_a_segment_id] * len(tokens) 159 | 160 | if cls_token_at_end: 161 | tokens += [cls_token] 162 | label_ids += [pad_token_label_id] 163 | segment_ids += [cls_token_segment_id] 164 | else: 165 | tokens = [cls_token] + tokens 166 | label_ids = [pad_token_label_id] + label_ids 167 | segment_ids = [cls_token_segment_id] + segment_ids 168 | 169 | input_ids = tokenizer.convert_tokens_to_ids(tokens) 170 | # print(len(tokens), len(input_ids), len(label_ids)) 171 | 172 | # The mask has 1 for real tokens and 0 for padding tokens. Only real 173 | # tokens are attended to. 174 | input_mask = [1 if mask_padding_with_zero else 0] * len(input_ids) 175 | 176 | # Zero-pad up to the sequence length. 
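        # e.g. with max_seq_length=8 and 5 real wordpieces: input_mask becomes [1, 1, 1, 1, 1, 0, 0, 0]
        # and the padded label positions get pad_token_label_id (-100 by default), which the loss ignores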
177 | padding_length = max_seq_length - len(input_ids) 178 | if pad_on_left: 179 | input_ids = ([pad_token] * padding_length) + input_ids 180 | input_mask = ([0 if mask_padding_with_zero else 1] * padding_length) + input_mask 181 | segment_ids = ([pad_token_segment_id] * padding_length) + segment_ids 182 | label_ids = ([pad_token_label_id] * padding_length) + label_ids 183 | else: 184 | input_ids += [pad_token] * padding_length 185 | input_mask += [0 if mask_padding_with_zero else 1] * padding_length 186 | segment_ids += [pad_token_segment_id] * padding_length 187 | label_ids += [pad_token_label_id] * padding_length 188 | 189 | # print(len(label_ids), max_seq_length) 190 | 191 | assert len(input_ids) == max_seq_length 192 | assert len(input_mask) == max_seq_length 193 | assert len(segment_ids) == max_seq_length 194 | assert len(label_ids) == max_seq_length 195 | 196 | if ex_index < 5: 197 | logger.info("*** Example ***") 198 | logger.info("guid: %s", example.guid) 199 | logger.info("tokens: %s", " ".join([str(x) for x in tokens])) 200 | logger.info("input_ids: %s", " ".join([str(x) for x in input_ids])) 201 | logger.info("input_mask: %s", " ".join([str(x) for x in input_mask])) 202 | logger.info("segment_ids: %s", " ".join([str(x) for x in segment_ids])) 203 | logger.info("label_ids: %s", " ".join([str(x) for x in label_ids])) 204 | 205 | features.append( 206 | InputFeatures(input_ids=input_ids, input_mask=input_mask, segment_ids=segment_ids, label_ids=label_ids) 207 | ) 208 | return features 209 | -------------------------------------------------------------------------------- /utils_ner_segment.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. 3 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | """ Named entity recognition fine-tuning: utilities to work with CoNLL-2003 task. """ 17 | 18 | 19 | import logging 20 | import os 21 | import json 22 | 23 | logger = logging.getLogger(__name__) 24 | 25 | 26 | class InputExample(object): 27 | """A single training/test example for token classification.""" 28 | 29 | def __init__(self, guid, words, labels, segment_ids): 30 | """Constructs a InputExample. 31 | 32 | Args: 33 | guid: Unique id for the example. 34 | words: list. The words of the sequence. 35 | labels: (Optional) list. The labels for each word of the sequence. This should be 36 | specified for train and dev examples, but not for test examples. 
37 | """ 38 | self.guid = guid 39 | self.words = words 40 | self.labels = labels 41 | self.segment_ids = segment_ids 42 | 43 | 44 | class InputFeatures(object): 45 | """A single set of features of data.""" 46 | 47 | def __init__(self, input_ids, input_mask, segment_ids, label_ids): 48 | self.input_ids = input_ids 49 | self.input_mask = input_mask 50 | self.segment_ids = segment_ids 51 | self.label_ids = label_ids 52 | 53 | 54 | def read_examples_from_file(data_dir, mode): 55 | file_path = os.path.join(data_dir, "{}.json".format(mode)) 56 | guid_index = 1 57 | examples = [] 58 | with open(file_path, encoding="utf-8") as f: 59 | for line in f: 60 | if line=='\n' or line=='': 61 | continue 62 | line_json = json.loads(line) 63 | # id = line_json['id'] 64 | words = line_json['tokens'] 65 | if mode=='test': labels=['O']*len(words) 66 | else: labels = line_json['labels'] 67 | segment_ids= line_json["segment_ids"] 68 | 69 | if len(words)!= len(labels): 70 | print(words, labels," length misMatch") 71 | continue 72 | examples.append(InputExample(guid="{}-{}".format(mode, guid_index), words=words, labels=labels, \ 73 | segment_ids= segment_ids)) 74 | guid_index += 1 75 | 76 | return examples 77 | 78 | 79 | def convert_examples_to_features( 80 | examples, 81 | label_list, 82 | max_seq_length, 83 | tokenizer, 84 | trigger_token_segment_id = 1, 85 | cls_token_at_end=False, 86 | cls_token="[CLS]", 87 | cls_token_segment_id=1, 88 | sep_token="[SEP]", 89 | sep_token_extra=False, 90 | pad_on_left=False, 91 | pad_token=0, 92 | pad_token_segment_id=0, 93 | pad_token_label_id=-100, 94 | sequence_a_segment_id=0, 95 | mask_padding_with_zero=True, 96 | ): 97 | """ Loads a data file into a list of `InputBatch`s 98 | `cls_token_at_end` define the location of the CLS token: 99 | - False (Default, BERT/XLM pattern): [CLS] + A + [SEP] + B + [SEP] 100 | - True (XLNet/GPT pattern): A + [SEP] + B + [SEP] + [CLS] 101 | `cls_token_segment_id` define the segment id associated to the CLS token (0 for BERT, 2 for XLNet) 102 | """ 103 | 104 | label_map = {label: i for i, label in enumerate(label_list)} 105 | # print(label_map) 106 | 107 | features = [] 108 | for (ex_index, example) in enumerate(examples): 109 | if ex_index % 10000 == 0: 110 | logger.info("Writing example %d of %d", ex_index, len(examples)) 111 | # print(example.words, example.labels) 112 | # print(len(example.words), len(example.labels)) 113 | tokens = [] 114 | label_ids = [] 115 | segment_ids = [] 116 | for word, label, segment_id in zip(example.words, example.labels, example.segment_ids): 117 | word_tokens = tokenizer.tokenize(word) 118 | tokens.extend(word_tokens) 119 | if len(word_tokens)>1: 120 | # print(word,">1") # 没有 121 | pass 122 | if len(word_tokens)<1: 123 | # print(word,"<1") 基本都是空格 124 | tokens.extend(["[unused1]"]) 125 | # continue 126 | # Use the real label id for the first token of the word, and padding ids for the remaining tokens 127 | label_ids.extend([label_map[label]] ) 128 | segment_ids.extend( [sequence_a_segment_id if not segment_id else trigger_token_segment_id]) 129 | # if len(tokens)!= len(label_ids): 130 | # print(word, word_tokens, tokens, label_ids) 131 | # print(len(tokens),len(label_ids)) 132 | 133 | # Account for [CLS] and [SEP] with "- 2" and with "- 3" for RoBERTa. 
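        # sep_token_extra is True for RoBERTa-style models, which use an extra separator token,
        # so one more position is reserved before truncating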
134 | special_tokens_count = 3 if sep_token_extra else 2 135 | if len(tokens) > max_seq_length - special_tokens_count: 136 | tokens = tokens[: (max_seq_length - special_tokens_count)] 137 | label_ids = label_ids[: (max_seq_length - special_tokens_count)] 138 | segment_ids = segment_ids[: (max_seq_length - special_tokens_count)] 139 | 140 | # The convention in BERT is: 141 | # (a) For sequence pairs: 142 | # tokens: [CLS] is this jack ##son ##ville ? [SEP] no it is not . [SEP] 143 | # type_ids: 0 0 0 0 0 0 0 0 1 1 1 1 1 1 144 | # (b) For single sequences: 145 | # tokens: [CLS] the dog is hairy . [SEP] 146 | # type_ids: 0 0 0 0 0 0 0 147 | # 148 | # Where "type_ids" are used to indicate whether this is the first 149 | # sequence or the second sequence. The embedding vectors for `type=0` and 150 | # `type=1` were learned during pre-training and are added to the wordpiece 151 | # embedding vector (and position vector). This is not *strictly* necessary 152 | # since the [SEP] token unambiguously separates the sequences, but it makes 153 | # it easier for the model to learn the concept of sequences. 154 | # 155 | # For classification tasks, the first vector (corresponding to [CLS]) is 156 | # used as as the "sentence vector". Note that this only makes sense because 157 | # the entire model is fine-tuned. 158 | tokens += [sep_token] 159 | label_ids += [pad_token_label_id] 160 | segment_ids += [sequence_a_segment_id] 161 | if sep_token_extra: 162 | # roberta uses an extra separator b/w pairs of sentences 163 | tokens += [sep_token] 164 | label_ids += [pad_token_label_id] 165 | segment_ids += [sequence_a_segment_id] 166 | 167 | # segment_ids = [sequence_a_segment_id] * len(tokens) 168 | # # 改变 trigger_token_segment_id 169 | # id = example.id 170 | # trigger = example.trigger 171 | # trigger_start_index = example.trigger_start_index 172 | # print(id ,len(segment_ids), trigger_start_index, trigger) 173 | # for i in range(trigger_start_index-1, trigger_start_index-1 + len(trigger) ): 174 | # segment_ids[i] = trigger_token_segment_id 175 | 176 | if cls_token_at_end: 177 | tokens += [cls_token] 178 | label_ids += [pad_token_label_id] 179 | segment_ids += [cls_token_segment_id] 180 | else: 181 | tokens = [cls_token] + tokens 182 | label_ids = [pad_token_label_id] + label_ids 183 | segment_ids = [cls_token_segment_id] + segment_ids 184 | 185 | input_ids = tokenizer.convert_tokens_to_ids(tokens) 186 | # print(len(tokens), len(input_ids), len(label_ids)) 187 | 188 | # The mask has 1 for real tokens and 0 for padding tokens. Only real 189 | # tokens are attended to. 190 | input_mask = [1 if mask_padding_with_zero else 0] * len(input_ids) 191 | 192 | # Zero-pad up to the sequence length. 
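        # unlike utils_ner.py, segment_ids here mark the trigger span (trigger_token_segment_id
        # for trigger tokens) instead of sentence A/B, and they are padded with pad_token_segment_id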
193 | padding_length = max_seq_length - len(input_ids) 194 | if pad_on_left: 195 | input_ids = ([pad_token] * padding_length) + input_ids 196 | input_mask = ([0 if mask_padding_with_zero else 1] * padding_length) + input_mask 197 | segment_ids = ([pad_token_segment_id] * padding_length) + segment_ids 198 | label_ids = ([pad_token_label_id] * padding_length) + label_ids 199 | else: 200 | input_ids += [pad_token] * padding_length 201 | input_mask += [0 if mask_padding_with_zero else 1] * padding_length 202 | segment_ids += [pad_token_segment_id] * padding_length 203 | label_ids += [pad_token_label_id] * padding_length 204 | 205 | # print(len(label_ids), max_seq_length) 206 | 207 | assert len(input_ids) == max_seq_length 208 | assert len(input_mask) == max_seq_length 209 | assert len(segment_ids) == max_seq_length 210 | assert len(label_ids) == max_seq_length 211 | 212 | if ex_index < 5: 213 | logger.info("*** Example ***") 214 | logger.info("guid: %s", example.guid) 215 | logger.info("tokens: %s", " ".join([str(x) for x in tokens])) 216 | logger.info("input_ids: %s", " ".join([str(x) for x in input_ids])) 217 | logger.info("input_mask: %s", " ".join([str(x) for x in input_mask])) 218 | logger.info("segment_ids: %s", " ".join([str(x) for x in segment_ids])) 219 | logger.info("label_ids: %s", " ".join([str(x) for x in label_ids])) 220 | 221 | features.append( 222 | InputFeatures(input_ids=input_ids, input_mask=input_mask, segment_ids=segment_ids, label_ids=label_ids) 223 | ) 224 | return features 225 | -------------------------------------------------------------------------------- /run_pl_ner.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import glob 3 | import logging 4 | import os 5 | 6 | import numpy as np 7 | import torch 8 | from seqeval.metrics import f1_score, precision_score, recall_score 9 | from torch.nn import CrossEntropyLoss 10 | from torch.utils.data import DataLoader, TensorDataset 11 | 12 | from transformer_base import BaseTransformer, add_generic_args, generic_train 13 | from utils_ner import convert_examples_to_features, get_labels, read_examples_from_file 14 | 15 | 16 | logger = logging.getLogger(__name__) 17 | 18 | 19 | class NERTransformer(BaseTransformer): 20 | """ 21 | A training module for NER. See BaseTransformer for the core options. 22 | """ 23 | 24 | def __init__(self, hparams): 25 | self.labels = get_labels(hparams.labels) 26 | num_labels = len(self.labels) 27 | self.pad_token_label_id = CrossEntropyLoss().ignore_index 28 | super(NERTransformer, self).__init__(hparams, num_labels) 29 | 30 | def forward(self, **inputs): 31 | return self.model(**inputs) 32 | 33 | def training_step(self, batch, batch_num): 34 | "Compute loss and log." 
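        # batch comes from the TensorDataset built in load_dataset:
        # batch[0]=input_ids, batch[1]=input_mask, batch[2]=segment_ids, batch[3]=label_ids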
35 | inputs = {"input_ids": batch[0], "attention_mask": batch[1], "labels": batch[3]} 36 | if self.hparams.model_type != "distilbert": 37 | inputs["token_type_ids"] = ( 38 | batch[2] if self.hparams.model_type in ["bert", "xlnet"] else None 39 | ) # XLM and RoBERTa don"t use segment_ids 40 | 41 | outputs = self.forward(**inputs) 42 | loss = outputs[0] 43 | tensorboard_logs = {"loss": loss, "rate": self.lr_scheduler.get_last_lr()[-1]} 44 | return {"loss": loss, "log": tensorboard_logs} 45 | 46 | def _feature_file(self, mode): 47 | return os.path.join( 48 | self.hparams.data_dir, 49 | "cached_{}_{}_{}".format( 50 | mode, 51 | list(filter(None, self.hparams.model_name_or_path.split("/"))).pop(), 52 | str(self.hparams.max_seq_length), 53 | ), 54 | ) 55 | 56 | def prepare_data(self): 57 | "Called to initialize data. Use the call to construct features" 58 | args = self.hparams 59 | for mode in ["train", "dev", "test"]: 60 | cached_features_file = self._feature_file(mode) 61 | if not os.path.exists(cached_features_file): 62 | logger.info("Creating features from dataset file at %s", args.data_dir) 63 | examples = read_examples_from_file(args.data_dir, mode) 64 | features = convert_examples_to_features( 65 | examples, 66 | self.labels, 67 | args.max_seq_length, 68 | self.tokenizer, 69 | cls_token_at_end=bool(args.model_type in ["xlnet"]), 70 | cls_token=self.tokenizer.cls_token, 71 | cls_token_segment_id=2 if args.model_type in ["xlnet"] else 0, 72 | sep_token=self.tokenizer.sep_token, 73 | sep_token_extra=bool(args.model_type in ["roberta"]), 74 | pad_on_left=bool(args.model_type in ["xlnet"]), 75 | pad_token=self.tokenizer.convert_tokens_to_ids([self.tokenizer.pad_token])[0], 76 | pad_token_segment_id=4 if args.model_type in ["xlnet"] else 0, 77 | pad_token_label_id=self.pad_token_label_id, 78 | ) 79 | logger.info("Saving features into cached file %s", cached_features_file) 80 | torch.save(features, cached_features_file) 81 | 82 | def load_dataset(self, mode, batch_size): 83 | "Load datasets. Called after prepare data." 
84 | cached_features_file = self._feature_file(mode) 85 | logger.info("Loading features from cached file %s", cached_features_file) 86 | features = torch.load(cached_features_file) 87 | all_input_ids = torch.tensor([f.input_ids for f in features], dtype=torch.long) 88 | all_input_mask = torch.tensor([f.input_mask for f in features], dtype=torch.long) 89 | all_segment_ids = torch.tensor([f.segment_ids for f in features], dtype=torch.long) 90 | all_label_ids = torch.tensor([f.label_ids for f in features], dtype=torch.long) 91 | return DataLoader( 92 | TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_label_ids), batch_size=batch_size 93 | ) 94 | 95 | def validation_step(self, batch, batch_nb): 96 | "Compute validation" 97 | 98 | inputs = {"input_ids": batch[0], "attention_mask": batch[1], "labels": batch[3]} 99 | if self.hparams.model_type != "distilbert": 100 | inputs["token_type_ids"] = ( 101 | batch[2] if self.hparams.model_type in ["bert", "xlnet"] else None 102 | ) # XLM and RoBERTa don"t use segment_ids 103 | outputs = self.forward(**inputs) 104 | tmp_eval_loss, logits = outputs[:2] 105 | preds = logits.detach().cpu().numpy() 106 | out_label_ids = inputs["labels"].detach().cpu().numpy() 107 | return {"val_loss": tmp_eval_loss.detach().cpu(), "pred": preds, "target": out_label_ids} 108 | 109 | def _eval_end(self, outputs): 110 | "Evaluation called for both Val and Test" 111 | val_loss_mean = torch.stack([x["val_loss"] for x in outputs]).mean() 112 | preds = np.concatenate([x["pred"] for x in outputs], axis=0) 113 | preds = np.argmax(preds, axis=2) 114 | out_label_ids = np.concatenate([x["target"] for x in outputs], axis=0) 115 | 116 | label_map = {i: label for i, label in enumerate(self.labels)} 117 | out_label_list = [[] for _ in range(out_label_ids.shape[0])] 118 | preds_list = [[] for _ in range(out_label_ids.shape[0])] 119 | 120 | for i in range(out_label_ids.shape[0]): 121 | for j in range(out_label_ids.shape[1]): 122 | if out_label_ids[i, j] != self.pad_token_label_id: 123 | out_label_list[i].append(label_map[out_label_ids[i][j]]) 124 | preds_list[i].append(label_map[preds[i][j]]) 125 | 126 | results = { 127 | "val_loss": val_loss_mean, 128 | "precision": precision_score(out_label_list, preds_list), 129 | "recall": recall_score(out_label_list, preds_list), 130 | "f1": f1_score(out_label_list, preds_list), 131 | } 132 | 133 | if self.is_logger(): 134 | logger.info("***** Eval results *****") 135 | for key in sorted(results.keys()): 136 | logger.info(" %s = %s", key, str(results[key])) 137 | 138 | tensorboard_logs = results 139 | ret = {k: v for k, v in results.items()} 140 | ret["log"] = tensorboard_logs 141 | return ret, preds_list, out_label_list 142 | 143 | def validation_end(self, outputs): 144 | ret, preds, targets = self._eval_end(outputs) 145 | return ret 146 | 147 | def test_end(self, outputs): 148 | ret, predictions, targets = self._eval_end(outputs) 149 | 150 | if self.is_logger(): 151 | # Write output to a file: 152 | # Save results 153 | output_test_results_file = os.path.join(self.hparams.output_dir, "test_results.txt") 154 | with open(output_test_results_file, "w") as writer: 155 | for key in sorted(ret.keys()): 156 | if key != "log": 157 | writer.write("{} = {}\n".format(key, str(ret[key]))) 158 | # Save predictions 159 | output_test_predictions_file = os.path.join(self.hparams.output_dir, "test_predictions.txt") 160 | with open(output_test_predictions_file, "w") as writer: 161 | with open(os.path.join(self.hparams.data_dir, "test.txt"), "r") as f: 
162 | example_id = 0 163 | for line in f: 164 | if line.startswith("-DOCSTART-") or line == "" or line == "\n": 165 | writer.write(line) 166 | if not predictions[example_id]: 167 | example_id += 1 168 | elif predictions[example_id]: 169 | output_line = line.split()[0] + " " + predictions[example_id].pop(0) + "\n" 170 | writer.write(output_line) 171 | else: 172 | logger.warning( 173 | "Maximum sequence length exceeded: No prediction for '%s'.", line.split()[0] 174 | ) 175 | return ret 176 | 177 | @staticmethod 178 | def add_model_specific_args(parser, root_dir): 179 | # Add NER specific options 180 | BaseTransformer.add_model_specific_args(parser, root_dir) 181 | parser.add_argument( 182 | "--max_seq_length", 183 | default=128, 184 | type=int, 185 | help="The maximum total input sequence length after tokenization. Sequences longer " 186 | "than this will be truncated, sequences shorter will be padded.", 187 | ) 188 | 189 | parser.add_argument( 190 | "--labels", 191 | default="", 192 | type=str, 193 | help="Path to a file containing all labels. If not specified, CoNLL-2003 labels are used.", 194 | ) 195 | 196 | parser.add_argument( 197 | "--data_dir", 198 | default=None, 199 | type=str, 200 | required=True, 201 | help="The input data dir. Should contain the training files for the CoNLL-2003 NER task.", 202 | ) 203 | 204 | parser.add_argument( 205 | "--overwrite_cache", action="store_true", help="Overwrite the cached training and evaluation sets" 206 | ) 207 | 208 | return parser 209 | 210 | 211 | if __name__ == "__main__": 212 | parser = argparse.ArgumentParser() 213 | add_generic_args(parser, os.getcwd()) 214 | parser = NERTransformer.add_model_specific_args(parser, os.getcwd()) 215 | args = parser.parse_args() 216 | model = NERTransformer(args) 217 | trainer = generic_train(model, args) 218 | 219 | if args.do_predict: 220 | checkpoints = list(sorted(glob.glob(args.output_dir + "/checkpoint_*.ckpt", recursive=True))) 221 | NERTransformer.load_from_checkpoint(checkpoints[-1]) 222 | trainer.test(model) 223 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import json 2 | 3 | def write_file(datas, output_file): 4 | with open(output_file, 'w', encoding='utf-8') as f: 5 | for obj in datas: 6 | json.dump(obj, f, ensure_ascii=False, sort_keys=True) 7 | f.write("\n") 8 | 9 | def remove_duplication(alist): 10 | res = [] 11 | for item in alist: 12 | if item not in res: 13 | res.append(item) 14 | return res 15 | 16 | 17 | def get_labels(path="./data/event_schema/event_schema.json", task='trigger', mode="ner"): 18 | if not path: 19 | if mode=='ner': 20 | return ["O", "B-ENTITY", "I-ENTITY"] 21 | else: 22 | return ["O"] 23 | 24 | elif task=='trigger': 25 | labels = [] 26 | rows = open(path, encoding='utf-8').read().splitlines() 27 | if mode == "ner": labels.append('O') 28 | for row in rows: 29 | row = json.loads(row) 30 | event_type = row["event_type"] 31 | if mode == "ner": 32 | labels.append("B-{}".format(event_type)) 33 | labels.append("I-{}".format(event_type)) 34 | else: 35 | labels.append(event_type) 36 | return remove_duplication(labels) 37 | 38 | elif task=='role': 39 | labels = [] 40 | rows = open(path, encoding='utf-8').read().splitlines() 41 | if mode == "ner": labels.append('O') 42 | for row in rows: 43 | row = json.loads(row) 44 | for role in row["role_list"]: 45 | role_type = role['role'] 46 | if mode == "ner": 47 | labels.append("B-{}".format(role_type)) 48 | 
labels.append("I-{}".format(role_type)) 49 | else: 50 | labels.append(role_type) 51 | return remove_duplication(labels) 52 | 53 | else: 54 | labels = [] 55 | rows = open(path, encoding='utf-8').read().splitlines() 56 | if mode == "ner": labels.append('O') 57 | for row in rows: 58 | row = json.loads(row) 59 | if row['class']!=task: 60 | continue 61 | for role in row["role_list"]: 62 | role_type = role['role'] 63 | if mode == "ner": 64 | labels.append("B-{}".format(role_type)) 65 | labels.append("I-{}".format(role_type)) 66 | else: 67 | labels.append(role_type) 68 | return remove_duplication(labels) 69 | 70 | def data_val(input_file): 71 | rows = open(input_file, encoding='utf-8').read().splitlines() 72 | 73 | event_class_count = 0 74 | role_count = 0 75 | arg_count = 0 76 | arg_role_count = 0 77 | arg_role_one_event_count = 0 78 | trigger_count = 0 79 | argument_len_list =[] 80 | 81 | for row in rows: 82 | if len(row)==1: print(row) 83 | row = json.loads(row) 84 | 85 | arg_start_index_list=[] 86 | arg_start_index_map={} 87 | event_class_list = [] 88 | trigger_start_index_list = [] 89 | 90 | event_class_flag = False 91 | arg_start_index_flag= False 92 | role_flag = False 93 | arg_role_flag= False 94 | arg_role_one_event_flag= False 95 | trigger_flag = False 96 | 97 | for event in row["event_list"]: 98 | event_class = event["class"] 99 | if event_class_list==[]: 100 | event_class_list.append(event_class) 101 | elif event_class not in event_class_list: 102 | # event_class_count += 1 103 | event_class_flag = True 104 | # print(row) 105 | 106 | trigger_start_index= event["trigger_start_index"] 107 | if trigger_start_index not in trigger_start_index_list: 108 | trigger_start_index_list.append(trigger_start_index) 109 | else: 110 | trigger_flag = True 111 | print(row) 112 | 113 | role_list = [] 114 | arg_start_index_map_in_one_event = {} 115 | for arg in event["arguments"]: 116 | role = arg['role'] 117 | argument = arg['argument'] 118 | argument_start_index = arg["argument_start_index"] 119 | argument_len_list.append([len(argument),argument]) 120 | if role not in role_list: 121 | role_list.append(role) 122 | else: 123 | # role_count += 1 124 | arg_start_index_flag = True 125 | # print(row) 126 | 127 | if argument_start_index not in arg_start_index_map_in_one_event: 128 | arg_start_index_map_in_one_event[argument_start_index]= role 129 | else: 130 | if role!= arg_start_index_map_in_one_event[argument_start_index]: 131 | arg_role_one_event_flag = True 132 | # print(row) 133 | 134 | 135 | if argument_start_index not in arg_start_index_list: 136 | arg_start_index_list.append(argument_start_index) 137 | arg_start_index_map[argument_start_index]= role 138 | else: 139 | # arg_count+= 1 140 | role_flag = True 141 | if role!= arg_start_index_map[argument_start_index]: 142 | arg_role_flag = True 143 | # print(row) 144 | 145 | if role_flag: 146 | role_count += 1 147 | # print(row) 148 | if event_class_flag: 149 | event_class_count += 1 150 | # print(row) 151 | if arg_start_index_flag: 152 | arg_count += 1 153 | # print(row) 154 | if arg_role_flag: 155 | arg_role_count += 1 156 | if arg_role_one_event_flag: 157 | arg_role_one_event_count += 1 158 | if trigger_flag: 159 | trigger_count += 1 160 | 161 | print(event_class_count, role_count, arg_count, arg_role_count, arg_role_one_event_count, trigger_count) 162 | argument_len_list.sort(key=lambda x:x[0], reverse= True) 163 | print(argument_len_list[:10]) 164 | 165 | def position_val(input_file): 166 | rows = open(input_file, encoding='utf-8').read().splitlines() 
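# Position sanity check: for every annotated trigger/argument, text[start : start + len(span)]
# should reproduce the annotated string exactly; mismatches are counted and printed below.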
167 | trigger_count = 0 168 | arg_count = 0 169 | 170 | for row in rows: 171 | # position_flag = False 172 | 173 | if len(row)==1: print(row) 174 | row = json.loads(row) 175 | text = row['text'] 176 | for event in row["event_list"]: 177 | event_class = event["class"] 178 | trigger = event["trigger"] 179 | event_type = event["event_type"] 180 | trigger_start_index = event["trigger_start_index"] 181 | 182 | if text[trigger_start_index: trigger_start_index+len(trigger)]!= trigger: 183 | print("trigger position mismatch") 184 | trigger_count += 1 185 | 186 | for arg in event["arguments"]: 187 | role = arg['role'] 188 | argument = arg['argument'] 189 | argument_start_index = arg["argument_start_index"] 190 | 191 | if text[argument_start_index: argument_start_index+len(argument)]!= argument: 192 | print("argument position mismatch") 193 | arg_count+=1 194 | 195 | print(trigger_count, arg_count) 196 | 197 | 198 | # 统计 event_type 分布 199 | def data_analysis(input_file): 200 | rows = open(input_file, encoding='utf-8').read().splitlines() 201 | label_list= get_labels(task='trigger', mode="classification") 202 | label_map = {label: i for i, label in enumerate(label_list)} 203 | label_count = [0 for i in range(len(label_list))] 204 | for row in rows: 205 | row = json.loads(row) 206 | for event in row["event_list"]: 207 | event_type = event["event_type"] 208 | label_count[label_map[event_type]] += 1 209 | print(label_count) 210 | 211 | def get_num_of_arguments(input_file): 212 | lines = open(input_file, encoding='utf-8').read().splitlines() 213 | arg_count = 0 214 | for line in lines: 215 | line = json.loads(line) 216 | for event in line["event_list"]: 217 | arg_count += len(event["arguments"]) 218 | print(arg_count) 219 | 220 | def read_write(input_file, output_file): 221 | rows = open(input_file, encoding='utf-8').read().splitlines() 222 | results = [] 223 | for row in rows: 224 | row = json.loads(row) 225 | id = row.pop('id') 226 | text = row.pop('text') 227 | # labels = row.pop('labels') 228 | event_list = row.pop('event_list') 229 | row['text'] = text 230 | row['id'] = id 231 | # row['labels'] = labels 232 | row['event_list'] = event_list 233 | results.append(row) 234 | write_file(results, output_file) 235 | 236 | def schema_analysis(path="./data/event_schema/event_schema.json"): 237 | rows = open(path, encoding='utf-8').read().splitlines() 238 | argument_map = {} 239 | for row in rows: 240 | d_json = json.loads(row) 241 | event_type = d_json["event_type"] 242 | for r in d_json["role_list"]: 243 | role = r["role"] 244 | if role in argument_map: 245 | argument_map[role].append(event_type) 246 | else: 247 | argument_map[role]= [event_type] 248 | argument_unique = [] 249 | argument_duplicate = [] 250 | for argument, event_type_list in argument_map.items(): 251 | if len(event_type_list)==1: 252 | argument_unique.append(argument) 253 | else: 254 | argument_duplicate.append(argument) 255 | 256 | print(argument_unique, argument_duplicate) 257 | for argument in argument_duplicate: 258 | print(argument_map[argument]) 259 | 260 | return argument_map 261 | 262 | 263 | if __name__ == '__main__': 264 | # labels = get_labels(path="./data/event_schema/event_schema.json", task='trigger', mode="classification") 265 | # print(len(labels), labels[50:60]) 266 | 267 | # data_val("./data/train_data/train.json") 268 | # data_val("./data/dev_data/dev.json") 269 | 270 | # data_analysis("./data/train_data/train.json") 271 | 272 | # 无异常 273 | # position_val("./data/train_data/train.json") 274 | # 
position_val("./data/dev_data/dev.json") 275 | 276 | # get_num_of_arguments("./results/test_pred_bin_segment.json") 277 | 278 | # read_write("./output/eval_pred.json", "./results/eval_pred.json") 279 | # read_write("./results/test1.trigger.pred.json", "./results/paddle.trigger.json") 280 | 281 | schema_analysis() 282 | 283 | 284 | -------------------------------------------------------------------------------- /transformer_base.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import random 4 | 5 | import numpy as np 6 | import pytorch_lightning as pl 7 | import torch 8 | 9 | from transformers import ( 10 | AdamW, 11 | BertConfig, 12 | BertForTokenClassification, 13 | BertTokenizer, 14 | CamembertConfig, 15 | CamembertForTokenClassification, 16 | CamembertTokenizer, 17 | DistilBertConfig, 18 | DistilBertForTokenClassification, 19 | DistilBertTokenizer, 20 | RobertaConfig, 21 | RobertaForTokenClassification, 22 | RobertaTokenizer, 23 | XLMRobertaConfig, 24 | XLMRobertaForTokenClassification, 25 | XLMRobertaTokenizer, 26 | get_linear_schedule_with_warmup, 27 | ) 28 | 29 | 30 | logger = logging.getLogger(__name__) 31 | 32 | 33 | ALL_MODELS = sum( 34 | ( 35 | tuple(conf.pretrained_config_archive_map.keys()) 36 | for conf in (BertConfig, RobertaConfig, DistilBertConfig, CamembertConfig, XLMRobertaConfig) 37 | ), 38 | (), 39 | ) 40 | 41 | MODEL_CLASSES = { 42 | "bert": (BertConfig, BertForTokenClassification, BertTokenizer), 43 | "roberta": (RobertaConfig, RobertaForTokenClassification, RobertaTokenizer), 44 | "distilbert": (DistilBertConfig, DistilBertForTokenClassification, DistilBertTokenizer), 45 | "camembert": (CamembertConfig, CamembertForTokenClassification, CamembertTokenizer), 46 | "xlmroberta": (XLMRobertaConfig, XLMRobertaForTokenClassification, XLMRobertaTokenizer), 47 | } 48 | 49 | 50 | def set_seed(args): 51 | random.seed(args.seed) 52 | np.random.seed(args.seed) 53 | torch.manual_seed(args.seed) 54 | if args.n_gpu > 0: 55 | torch.cuda.manual_seed_all(args.seed) 56 | 57 | 58 | class BaseTransformer(pl.LightningModule): 59 | def __init__(self, hparams, num_labels=None): 60 | "Initialize a model." 
61 | 62 | super(BaseTransformer, self).__init__() 63 | self.hparams = hparams 64 | self.hparams.model_type = self.hparams.model_type.lower() 65 | 66 | config_class, model_class, tokenizer_class = MODEL_CLASSES[self.hparams.model_type] 67 | config = config_class.from_pretrained( 68 | self.hparams.config_name if self.hparams.config_name else self.hparams.model_name_or_path, 69 | num_labels=num_labels, 70 | cache_dir=self.hparams.cache_dir if self.hparams.cache_dir else None, 71 | ) 72 | tokenizer = tokenizer_class.from_pretrained( 73 | self.hparams.tokenizer_name if self.hparams.tokenizer_name else self.hparams.model_name_or_path, 74 | do_lower_case=self.hparams.do_lower_case, 75 | cache_dir=self.hparams.cache_dir if self.hparams.cache_dir else None, 76 | ) 77 | model = model_class.from_pretrained( 78 | self.hparams.model_name_or_path, 79 | from_tf=bool(".ckpt" in self.hparams.model_name_or_path), 80 | config=config, 81 | cache_dir=self.hparams.cache_dir if self.hparams.cache_dir else None, 82 | ) 83 | self.config, self.tokenizer, self.model = config, tokenizer, model 84 | 85 | def is_logger(self): 86 | return self.trainer.proc_rank <= 0 87 | 88 | def configure_optimizers(self): 89 | "Prepare optimizer and schedule (linear warmup and decay)" 90 | 91 | model = self.model 92 | no_decay = ["bias", "LayerNorm.weight"] 93 | optimizer_grouped_parameters = [ 94 | { 95 | "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)], 96 | "weight_decay": self.hparams.weight_decay, 97 | }, 98 | { 99 | "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], 100 | "weight_decay": 0.0, 101 | }, 102 | ] 103 | optimizer = AdamW(optimizer_grouped_parameters, lr=self.hparams.learning_rate, eps=self.hparams.adam_epsilon) 104 | self.opt = optimizer 105 | return [optimizer] 106 | 107 | def optimizer_step(self, epoch, batch_idx, optimizer, optimizer_idx, second_order_closure=None): 108 | if self.trainer.use_tpu: 109 | xm.optimizer_step(optimizer) 110 | else: 111 | optimizer.step() 112 | optimizer.zero_grad() 113 | self.lr_scheduler.step() 114 | 115 | def get_tqdm_dict(self): 116 | tqdm_dict = {"loss": "{:.3f}".format(self.trainer.avg_loss), "lr": self.lr_scheduler.get_last_lr()[-1]} 117 | 118 | return tqdm_dict 119 | 120 | def test_step(self, batch, batch_nb): 121 | return self.validation_step(batch, batch_nb) 122 | 123 | def test_end(self, outputs): 124 | return self.validation_end(outputs) 125 | 126 | def train_dataloader(self): 127 | train_batch_size = self.hparams.train_batch_size 128 | dataloader = self.load_dataset("train", train_batch_size) 129 | 130 | t_total = ( 131 | (len(dataloader.dataset) // (train_batch_size * max(1, self.hparams.n_gpu))) 132 | // self.hparams.gradient_accumulation_steps 133 | * float(self.hparams.num_train_epochs) 134 | ) 135 | scheduler = get_linear_schedule_with_warmup( 136 | self.opt, num_warmup_steps=self.hparams.warmup_steps, num_training_steps=t_total 137 | ) 138 | self.lr_scheduler = scheduler 139 | return dataloader 140 | 141 | def val_dataloader(self): 142 | return self.load_dataset("dev", self.hparams.eval_batch_size) 143 | 144 | def test_dataloader(self): 145 | return self.load_dataset("test", self.hparams.eval_batch_size) 146 | 147 | @staticmethod 148 | def add_model_specific_args(parser, root_dir): 149 | parser.add_argument( 150 | "--model_type", 151 | default=None, 152 | type=str, 153 | required=True, 154 | help="Model type selected in the list: " + ", ".join(MODEL_CLASSES.keys()), 155 | ) 156 | 
parser.add_argument( 157 | "--model_name_or_path", 158 | default=None, 159 | type=str, 160 | required=True, 161 | help="Path to pre-trained model or shortcut name selected in the list: " + ", ".join(ALL_MODELS), 162 | ) 163 | parser.add_argument( 164 | "--config_name", default="", type=str, help="Pretrained config name or path if not the same as model_name" 165 | ) 166 | parser.add_argument( 167 | "--tokenizer_name", 168 | default="", 169 | type=str, 170 | help="Pretrained tokenizer name or path if not the same as model_name", 171 | ) 172 | parser.add_argument( 173 | "--cache_dir", 174 | default="", 175 | type=str, 176 | help="Where do you want to store the pre-trained models downloaded from s3", 177 | ) 178 | parser.add_argument( 179 | "--do_lower_case", action="store_true", help="Set this flag if you are using an uncased model." 180 | ) 181 | parser.add_argument("--learning_rate", default=5e-5, type=float, help="The initial learning rate for Adam.") 182 | parser.add_argument("--weight_decay", default=0.0, type=float, help="Weight decay if we apply some.") 183 | parser.add_argument("--adam_epsilon", default=1e-8, type=float, help="Epsilon for Adam optimizer.") 184 | parser.add_argument("--warmup_steps", default=0, type=int, help="Linear warmup over warmup_steps.") 185 | parser.add_argument( 186 | "--num_train_epochs", default=3, type=int, help="Total number of training epochs to perform." 187 | ) 188 | 189 | parser.add_argument("--train_batch_size", default=32, type=int) 190 | parser.add_argument("--eval_batch_size", default=32, type=int) 191 | 192 | 193 | def add_generic_args(parser, root_dir): 194 | parser.add_argument( 195 | "--output_dir", 196 | default=None, 197 | type=str, 198 | required=True, 199 | help="The output directory where the model predictions and checkpoints will be written.", 200 | ) 201 | 202 | parser.add_argument( 203 | "--fp16", 204 | action="store_true", 205 | help="Whether to use 16-bit (mixed) precision (through NVIDIA apex) instead of 32-bit", 206 | ) 207 | 208 | parser.add_argument( 209 | "--fp16_opt_level", 210 | type=str, 211 | default="O1", 212 | help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']." 
213 | "See details at https://nvidia.github.io/apex/amp.html", 214 | ) 215 | 216 | parser.add_argument("--n_gpu", type=int, default=1) 217 | parser.add_argument("--n_tpu_cores", type=int, default=0) 218 | parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.") 219 | parser.add_argument("--do_train", action="store_true", help="Whether to run training.") 220 | parser.add_argument("--do_predict", action="store_true", help="Whether to run predictions on the test set.") 221 | parser.add_argument( 222 | "--gradient_accumulation_steps", 223 | type=int, 224 | default=1, 225 | help="Number of updates steps to accumulate before performing a backward/update pass.", 226 | ) 227 | 228 | parser.add_argument("--server_ip", type=str, default="", help="For distant debugging.") 229 | parser.add_argument("--server_port", type=str, default="", help="For distant debugging.") 230 | parser.add_argument("--seed", type=int, default=42, help="random seed for initialization") 231 | 232 | 233 | def generic_train(model, args): 234 | # init model 235 | set_seed(args) 236 | 237 | # Setup distant debugging if needed 238 | if args.server_ip and args.server_port: 239 | # Distant debugging - see https://code.visualstudio.com/docs/python/debugging#_attach-to-a-local-script 240 | import ptvsd 241 | 242 | print("Waiting for debugger attach") 243 | ptvsd.enable_attach(address=(args.server_ip, args.server_port), redirect_output=True) 244 | ptvsd.wait_for_attach() 245 | 246 | if os.path.exists(args.output_dir) and os.listdir(args.output_dir) and args.do_train: 247 | raise ValueError("Output directory ({}) already exists and is not empty.".format(args.output_dir)) 248 | 249 | checkpoint_callback = pl.callbacks.ModelCheckpoint( 250 | filepath=args.output_dir, prefix="checkpoint", monitor="val_loss", mode="min", save_top_k=5 251 | ) 252 | 253 | train_params = dict( 254 | accumulate_grad_batches=args.gradient_accumulation_steps, 255 | gpus=args.n_gpu, 256 | max_epochs=args.num_train_epochs, 257 | early_stop_callback=False, 258 | gradient_clip_val=args.max_grad_norm, 259 | checkpoint_callback=checkpoint_callback, 260 | ) 261 | 262 | if args.fp16: 263 | train_params["use_amp"] = args.fp16 264 | train_params["amp_level"] = args.fp16_opt_level 265 | 266 | if args.n_tpu_cores > 0: 267 | global xm 268 | import torch_xla.core.xla_model as xm 269 | 270 | train_params["num_tpu_cores"] = args.n_tpu_cores 271 | train_params["gpus"] = 0 272 | 273 | if args.n_gpu > 1: 274 | train_params["distributed_backend"] = "ddp" 275 | 276 | trainer = pl.Trainer(**train_params) 277 | 278 | if args.do_train: 279 | trainer.fit(model) 280 | 281 | return trainer 282 | -------------------------------------------------------------------------------- /utils_bi_ner.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. 3 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | """ Named entity recognition fine-tuning: utilities to work with CoNLL-2003 task. """ 17 | 18 | 19 | import logging 20 | import os 21 | import json 22 | 23 | logger = logging.getLogger(__name__) 24 | 25 | 26 | class InputExample(object): 27 | """A single training/test example for token classification.""" 28 | 29 | def __init__(self, guid, words, start_labels, end_labels): 30 | """Constructs a InputExample. 31 | 32 | Args: 33 | guid: Unique id for the example. 34 | words: list. The words of the sequence. 35 | labels: (Optional) list. The labels for each word of the sequence. This should be 36 | specified for train and dev examples, but not for test examples. 37 | """ 38 | self.guid = guid 39 | self.words = words 40 | self.start_labels = start_labels 41 | self.end_labels = end_labels 42 | 43 | 44 | class InputFeatures(object): 45 | """A single set of features of data.""" 46 | 47 | def __init__(self, input_ids, input_mask, segment_ids, start_label_ids, end_label_ids): 48 | self.input_ids = input_ids 49 | self.input_mask = input_mask 50 | self.segment_ids = segment_ids 51 | self.start_label_ids = start_label_ids 52 | self.end_label_ids = end_label_ids 53 | 54 | 55 | def read_examples_from_file(data_dir, mode): 56 | file_path = os.path.join(data_dir, "{}.json".format(mode)) 57 | guid_index = 1 58 | examples = [] 59 | with open(file_path, encoding="utf-8") as f: 60 | words = [] 61 | labels = [] 62 | for line in f: 63 | if line=='\n' or line=='': 64 | continue 65 | line_json = json.loads(line) 66 | words = line_json['tokens'] 67 | if mode=='test': 68 | start_labels=['O']*len(words) 69 | end_labels = ['O']*len(words) 70 | else: 71 | start_labels = line_json['start_labels'] 72 | end_labels = line_json['end_labels'] 73 | if len(words)!= len(start_labels) : 74 | print(words, start_labels," length misMatch") 75 | continue 76 | if len(words)!= len(end_labels) : 77 | print(words, end_labels," length misMatch") 78 | continue 79 | 80 | examples.append(InputExample(guid="{}-{}".format(mode, guid_index), words=words,\ 81 | start_labels=start_labels, end_labels = end_labels)) 82 | guid_index += 1 83 | 84 | return examples 85 | 86 | 87 | def convert_examples_to_features( 88 | examples, 89 | label_list, 90 | max_seq_length, 91 | tokenizer, 92 | cls_token_at_end=False, 93 | cls_token="[CLS]", 94 | cls_token_segment_id=1, 95 | sep_token="[SEP]", 96 | sep_token_extra=False, 97 | pad_on_left=False, 98 | pad_token=0, 99 | pad_token_segment_id=0, 100 | pad_token_label_id=-100, 101 | sequence_a_segment_id=0, 102 | mask_padding_with_zero=True, 103 | ): 104 | """ Loads a data file into a list of `InputBatch`s 105 | `cls_token_at_end` define the location of the CLS token: 106 | - False (Default, BERT/XLM pattern): [CLS] + A + [SEP] + B + [SEP] 107 | - True (XLNet/GPT pattern): A + [SEP] + B + [SEP] + [CLS] 108 | `cls_token_segment_id` define the segment id associated to the CLS token (0 for BERT, 2 for XLNet) 109 | """ 110 | 111 | label_map = {label: i for i, label in enumerate(label_list)} 112 | label_map['O'] = -1 113 | # print(label_map) 114 | 115 | features = [] 116 | for (ex_index, example) in enumerate(examples): 117 | if ex_index % 10000 == 0: 118 | logger.info("Writing example %d of %d", ex_index, len(examples)) 119 | # print(example.words, example.labels) 120 | # print(len(example.words), len(example.labels)) 121 | tokens = [] 122 | start_label_ids = [] 123 | end_label_ids = [] 124 | for word, 
start_label, end_label in zip(example.words, example.start_labels, example.end_labels): 125 | word_tokens = tokenizer.tokenize(word) 126 | if len(word_tokens)==1: 127 | tokens.extend(word_tokens) 128 | if len(word_tokens)>1: 129 | print(word,">1") 130 | tokens.extend(word_tokens[:1]) 131 | pass 132 | if len(word_tokens)<1: 133 | # print(word,"<1") 基本都是空格 134 | tokens.extend(["[unused1]"]) 135 | # continue 136 | # Use the real label id for the first token of the word, and padding ids for the remaining tokens 137 | cur_start_labels = start_label.split() 138 | cur_start_label_ids = [] 139 | for cur_start_label in cur_start_labels: 140 | cur_start_label_ids.append(label_map[cur_start_label]) 141 | start_label_ids.append(cur_start_label_ids) 142 | 143 | cur_end_labels = end_label.split() 144 | cur_end_label_ids = [] 145 | for cur_end_label in cur_end_labels: 146 | cur_end_label_ids.append(label_map[cur_end_label]) 147 | end_label_ids.append(cur_end_label_ids) 148 | 149 | # if len(tokens)!= len(label_ids): 150 | # print(word, word_tokens, tokens, label_ids) 151 | # print(len(tokens),len(label_ids)) 152 | 153 | # Account for [CLS] and [SEP] with "- 2" and with "- 3" for RoBERTa. 154 | special_tokens_count = 3 if sep_token_extra else 2 155 | if len(tokens) > max_seq_length - special_tokens_count: 156 | tokens = tokens[: (max_seq_length - special_tokens_count)] 157 | start_label_ids = start_label_ids[: (max_seq_length - special_tokens_count)] 158 | end_label_ids = end_label_ids[: (max_seq_length - special_tokens_count)] 159 | 160 | 161 | # The convention in BERT is: 162 | # (a) For sequence pairs: 163 | # tokens: [CLS] is this jack ##son ##ville ? [SEP] no it is not . [SEP] 164 | # type_ids: 0 0 0 0 0 0 0 0 1 1 1 1 1 1 165 | # (b) For single sequences: 166 | # tokens: [CLS] the dog is hairy . [SEP] 167 | # type_ids: 0 0 0 0 0 0 0 168 | # 169 | # Where "type_ids" are used to indicate whether this is the first 170 | # sequence or the second sequence. The embedding vectors for `type=0` and 171 | # `type=1` were learned during pre-training and are added to the wordpiece 172 | # embedding vector (and position vector). This is not *strictly* necessary 173 | # since the [SEP] token unambiguously separates the sequences, but it makes 174 | # it easier for the model to learn the concept of sequences. 175 | # 176 | # For classification tasks, the first vector (corresponding to [CLS]) is 177 | # used as as the "sentence vector". Note that this only makes sense because 178 | # the entire model is fine-tuned. 179 | tokens += [sep_token] 180 | start_label_ids += [[pad_token_label_id]] 181 | end_label_ids += [[pad_token_label_id]] 182 | 183 | if sep_token_extra: 184 | # roberta uses an extra separator b/w pairs of sentences 185 | tokens += [sep_token] 186 | start_label_ids += [[pad_token_label_id]] 187 | end_label_ids += [[pad_token_label_id]] 188 | segment_ids = [sequence_a_segment_id] * len(tokens) 189 | 190 | if cls_token_at_end: 191 | tokens += [cls_token] 192 | start_label_ids += [[pad_token_label_id]] 193 | end_label_ids += [[pad_token_label_id]] 194 | segment_ids += [cls_token_segment_id] 195 | else: 196 | tokens = [cls_token] + tokens 197 | start_label_ids = [[pad_token_label_id]] + start_label_ids 198 | end_label_ids = [[pad_token_label_id]] + end_label_ids 199 | segment_ids = [cls_token_segment_id] + segment_ids 200 | 201 | input_ids = tokenizer.convert_tokens_to_ids(tokens) 202 | # print(len(tokens), len(input_ids), len(label_ids)) 203 | 204 | # The mask has 1 for real tokens and 0 for padding tokens. 
Only real 205 | # tokens are attended to. 206 | input_mask = [1 if mask_padding_with_zero else 0] * len(input_ids) 207 | 208 | # Zero-pad up to the sequence length. 209 | padding_length = max_seq_length - len(input_ids) 210 | if pad_on_left: 211 | input_ids = ([pad_token] * padding_length) + input_ids 212 | input_mask = ([0 if mask_padding_with_zero else 1] * padding_length) + input_mask 213 | segment_ids = ([pad_token_segment_id] * padding_length) + segment_ids 214 | start_label_ids = ([[pad_token_label_id]] * padding_length) + start_label_ids 215 | end_label_ids = ([[pad_token_label_id]] * padding_length) + end_label_ids 216 | else: 217 | input_ids += [pad_token] * padding_length 218 | input_mask += [0 if mask_padding_with_zero else 1] * padding_length 219 | segment_ids += [pad_token_segment_id] * padding_length 220 | start_label_ids += [[pad_token_label_id]] * padding_length 221 | end_label_ids += [[pad_token_label_id]] * padding_length 222 | 223 | # print(len(label_ids), max_seq_length) 224 | 225 | assert len(input_ids) == max_seq_length 226 | assert len(input_mask) == max_seq_length 227 | assert len(segment_ids) == max_seq_length 228 | assert len(start_label_ids) == max_seq_length 229 | assert len(end_label_ids) == max_seq_length 230 | 231 | if ex_index < 5: 232 | logger.info("*** Example ***") 233 | logger.info("guid: %s", example.guid) 234 | logger.info("tokens: %s", " ".join([str(x) for x in tokens])) 235 | logger.info("input_ids: %s", " ".join([str(x) for x in input_ids])) 236 | logger.info("input_mask: %s", " ".join([str(x) for x in input_mask])) 237 | logger.info("segment_ids: %s", " ".join([str(x) for x in segment_ids])) 238 | logger.info("start_label_ids: %s", " ".join([str(x) for x in start_label_ids])) 239 | logger.info("end_label_ids: %s", " ".join([str(x) for x in end_label_ids])) 240 | 241 | features.append( 242 | InputFeatures(input_ids=input_ids, input_mask=input_mask, segment_ids=segment_ids, \ 243 | start_label_ids=start_label_ids, end_label_ids= end_label_ids) 244 | ) 245 | return features 246 | 247 | def convert_label_ids_to_onehot(label_ids,label_list): 248 | one_hot_labels= [[0]*len(label_list) for _ in range(len(label_ids))] 249 | label_map = {label: i for i, label in enumerate(label_list)} 250 | ignore_index= -100 251 | non_index= -1 252 | for i, label_id in enumerate(label_ids): 253 | for sub_label_id in label_id: 254 | if sub_label_id not in [ignore_index, non_index]: 255 | one_hot_labels[i][sub_label_id]= 1 256 | return one_hot_labels 257 | 258 | 259 | 260 | -------------------------------------------------------------------------------- /utils_bi_ner_segment.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. 3 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 
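# Usage sketch (an illustration, not part of the original pipeline): each token in this module
# carries a list of start/end label ids, and convert_label_ids_to_onehot() at the bottom of the
# file turns them into a [seq_len, num_labels] multi-hot matrix that fits a multi-label
# objective, for example:
#
#     import torch
#     targets = torch.tensor(convert_label_ids_to_onehot(start_label_ids, label_list),
#                            dtype=torch.float)
#     loss = torch.nn.BCEWithLogitsLoss()(start_logits, targets)
#
# start_logits is a placeholder name; the loss actually used by the project lives in loss.py
# and may differ.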
16 | """ Named entity recognition fine-tuning: utilities to work with CoNLL-2003 task. """ 17 | 18 | 19 | import logging 20 | import os 21 | import json 22 | 23 | logger = logging.getLogger(__name__) 24 | 25 | 26 | class InputExample(object): 27 | """A single training/test example for token classification.""" 28 | 29 | def __init__(self, guid, words, segment_ids, start_labels, end_labels): 30 | """Constructs a InputExample. 31 | 32 | Args: 33 | guid: Unique id for the example. 34 | words: list. The words of the sequence. 35 | labels: (Optional) list. The labels for each word of the sequence. This should be 36 | specified for train and dev examples, but not for test examples. 37 | """ 38 | self.guid = guid 39 | self.words = words 40 | self.segment_ids = segment_ids 41 | self.start_labels = start_labels 42 | self.end_labels = end_labels 43 | 44 | 45 | class InputFeatures(object): 46 | """A single set of features of data.""" 47 | 48 | def __init__(self, input_ids, input_mask, segment_ids, start_label_ids, end_label_ids): 49 | self.input_ids = input_ids 50 | self.input_mask = input_mask 51 | self.segment_ids = segment_ids 52 | self.start_label_ids = start_label_ids 53 | self.end_label_ids = end_label_ids 54 | 55 | 56 | def read_examples_from_file(data_dir, mode): 57 | file_path = os.path.join(data_dir, "{}.json".format(mode)) 58 | guid_index = 1 59 | examples = [] 60 | with open(file_path, encoding="utf-8") as f: 61 | words = [] 62 | labels = [] 63 | for line in f: 64 | if line=='\n' or line=='': 65 | continue 66 | line_json = json.loads(line) 67 | words = line_json['tokens'] 68 | if mode=='test': 69 | start_labels=['O']*len(words) 70 | end_labels = ['O']*len(words) 71 | else: 72 | start_labels = line_json['start_labels'] 73 | end_labels = line_json['end_labels'] 74 | if len(words)!= len(start_labels) : 75 | print(words, start_labels," length misMatch") 76 | continue 77 | if len(words)!= len(end_labels) : 78 | print(words, end_labels," length misMatch") 79 | continue 80 | segment_ids= line_json["segment_ids"] 81 | 82 | examples.append(InputExample(guid="{}-{}".format(mode, guid_index), words=words,\ 83 | start_labels=start_labels, end_labels = end_labels, segment_ids= segment_ids)) 84 | guid_index += 1 85 | 86 | return examples 87 | 88 | 89 | def convert_examples_to_features( 90 | examples, 91 | label_list, 92 | max_seq_length, 93 | tokenizer, 94 | trigger_token_segment_id =1, 95 | cls_token_at_end=False, 96 | cls_token="[CLS]", 97 | cls_token_segment_id=1, 98 | sep_token="[SEP]", 99 | sep_token_extra=False, 100 | pad_on_left=False, 101 | pad_token=0, 102 | pad_token_segment_id=0, 103 | pad_token_label_id=-100, 104 | sequence_a_segment_id=0, 105 | mask_padding_with_zero=True, 106 | ): 107 | """ Loads a data file into a list of `InputBatch`s 108 | `cls_token_at_end` define the location of the CLS token: 109 | - False (Default, BERT/XLM pattern): [CLS] + A + [SEP] + B + [SEP] 110 | - True (XLNet/GPT pattern): A + [SEP] + B + [SEP] + [CLS] 111 | `cls_token_segment_id` define the segment id associated to the CLS token (0 for BERT, 2 for XLNet) 112 | """ 113 | 114 | label_map = {label: i for i, label in enumerate(label_list)} 115 | label_map['O'] = -1 116 | # print(label_map) 117 | 118 | features = [] 119 | for (ex_index, example) in enumerate(examples): 120 | if ex_index % 10000 == 0: 121 | logger.info("Writing example %d of %d", ex_index, len(examples)) 122 | # print(example.words, example.labels) 123 | # print(len(example.words), len(example.labels)) 124 | tokens = [] 125 | start_label_ids = [] 
126 | end_label_ids = [] 127 | segment_ids = [] 128 | for word, start_label, end_label, segment_id in zip(example.words, example.start_labels, example.end_labels, example.segment_ids): 129 | word_tokens = tokenizer.tokenize(word) 130 | # each word must map to exactly one token below so tokens stay aligned with the start/end labels and segment_ids 131 | if len(word_tokens)==1: 132 | tokens.extend(word_tokens) 133 | if len(word_tokens)>1: 134 | print(word,">1") 135 | tokens.extend(word_tokens[:1]) 136 | pass 137 | if len(word_tokens)<1: 138 | # print(word,"<1") almost always whitespace characters 139 | tokens.extend(["[unused1]"]) 140 | # continue 141 | # Use the real label id for the first token of the word, and padding ids for the remaining tokens 142 | cur_start_labels = start_label.split() 143 | cur_start_label_ids = [] 144 | for cur_start_label in cur_start_labels: 145 | cur_start_label_ids.append(label_map[cur_start_label]) 146 | start_label_ids.append(cur_start_label_ids) 147 | 148 | cur_end_labels = end_label.split() 149 | cur_end_label_ids = [] 150 | for cur_end_label in cur_end_labels: 151 | cur_end_label_ids.append(label_map[cur_end_label]) 152 | end_label_ids.append(cur_end_label_ids) 153 | 154 | segment_ids.extend( [sequence_a_segment_id if not segment_id else trigger_token_segment_id] * 1) 155 | 156 | # if len(tokens)!= len(label_ids): 157 | # print(word, word_tokens, tokens, label_ids) 158 | # print(len(tokens),len(label_ids)) 159 | 160 | # Account for [CLS] and [SEP] with "- 2" and with "- 3" for RoBERTa. 161 | special_tokens_count = 3 if sep_token_extra else 2 162 | if len(tokens) > max_seq_length - special_tokens_count: 163 | tokens = tokens[: (max_seq_length - special_tokens_count)] 164 | start_label_ids = start_label_ids[: (max_seq_length - special_tokens_count)] 165 | end_label_ids = end_label_ids[: (max_seq_length - special_tokens_count)] 166 | segment_ids = segment_ids[: (max_seq_length - special_tokens_count)] 167 | 168 | 169 | # The convention in BERT is: 170 | # (a) For sequence pairs: 171 | # tokens: [CLS] is this jack ##son ##ville ? [SEP] no it is not . [SEP] 172 | # type_ids: 0 0 0 0 0 0 0 0 1 1 1 1 1 1 173 | # (b) For single sequences: 174 | # tokens: [CLS] the dog is hairy . [SEP] 175 | # type_ids: 0 0 0 0 0 0 0 176 | # 177 | # Where "type_ids" are used to indicate whether this is the first 178 | # sequence or the second sequence. The embedding vectors for `type=0` and 179 | # `type=1` were learned during pre-training and are added to the wordpiece 180 | # embedding vector (and position vector). This is not *strictly* necessary 181 | # since the [SEP] token unambiguously separates the sequences, but it makes 182 | # it easier for the model to learn the concept of sequences. 183 | # 184 | # For classification tasks, the first vector (corresponding to [CLS]) is 185 | # used as the "sentence vector". Note that this only makes sense because 186 | # the entire model is fine-tuned.
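# In this segment variant the token_type_ids are reused to mark the trigger span: ordinary
# text tokens get sequence_a_segment_id (0) and trigger tokens get trigger_token_segment_id (1).
# A hypothetical example with trigger "收购":
#     tokens:      [CLS] 雷  军  宣  布  收  购  [SEP]
#     segment_ids:  cls   0   0   0   0   1   1    0
# where "cls" is whatever cls_token_segment_id the caller passes in.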
187 | tokens += [sep_token] 188 | start_label_ids += [[pad_token_label_id]] 189 | end_label_ids += [[pad_token_label_id]] 190 | segment_ids += [sequence_a_segment_id] 191 | 192 | if sep_token_extra: 193 | # roberta uses an extra separator b/w pairs of sentences 194 | tokens += [sep_token] 195 | start_label_ids += [[pad_token_label_id]] 196 | end_label_ids += [[pad_token_label_id]] 197 | segment_ids += [sequence_a_segment_id] 198 | # segment_ids = [sequence_a_segment_id] * len(tokens) 199 | 200 | if cls_token_at_end: 201 | tokens += [cls_token] 202 | start_label_ids += [[pad_token_label_id]] 203 | end_label_ids += [[pad_token_label_id]] 204 | segment_ids += [cls_token_segment_id] 205 | else: 206 | tokens = [cls_token] + tokens 207 | start_label_ids = [[pad_token_label_id]] + start_label_ids 208 | end_label_ids = [[pad_token_label_id]] + end_label_ids 209 | segment_ids = [cls_token_segment_id] + segment_ids 210 | 211 | input_ids = tokenizer.convert_tokens_to_ids(tokens) 212 | # print(len(tokens), len(input_ids), len(label_ids)) 213 | 214 | # The mask has 1 for real tokens and 0 for padding tokens. Only real 215 | # tokens are attended to. 216 | input_mask = [1 if mask_padding_with_zero else 0] * len(input_ids) 217 | 218 | # Zero-pad up to the sequence length. 219 | padding_length = max_seq_length - len(input_ids) 220 | if pad_on_left: 221 | input_ids = ([pad_token] * padding_length) + input_ids 222 | input_mask = ([0 if mask_padding_with_zero else 1] * padding_length) + input_mask 223 | segment_ids = ([pad_token_segment_id] * padding_length) + segment_ids 224 | start_label_ids = ([[pad_token_label_id]] * padding_length) + start_label_ids 225 | end_label_ids = ([[pad_token_label_id]] * padding_length) + end_label_ids 226 | else: 227 | input_ids += [pad_token] * padding_length 228 | input_mask += [0 if mask_padding_with_zero else 1] * padding_length 229 | segment_ids += [pad_token_segment_id] * padding_length 230 | start_label_ids += [[pad_token_label_id]] * padding_length 231 | end_label_ids += [[pad_token_label_id]] * padding_length 232 | 233 | # print(len(label_ids), max_seq_length) 234 | 235 | assert len(input_ids) == max_seq_length 236 | assert len(input_mask) == max_seq_length 237 | assert len(segment_ids) == max_seq_length 238 | assert len(start_label_ids) == max_seq_length 239 | assert len(end_label_ids) == max_seq_length 240 | 241 | if ex_index < 5: 242 | logger.info("*** Example ***") 243 | logger.info("guid: %s", example.guid) 244 | logger.info("tokens: %s", " ".join([str(x) for x in tokens])) 245 | logger.info("input_ids: %s", " ".join([str(x) for x in input_ids])) 246 | logger.info("input_mask: %s", " ".join([str(x) for x in input_mask])) 247 | logger.info("segment_ids: %s", " ".join([str(x) for x in segment_ids])) 248 | logger.info("start_label_ids: %s", " ".join([str(x) for x in start_label_ids])) 249 | logger.info("end_label_ids: %s", " ".join([str(x) for x in end_label_ids])) 250 | 251 | if sum(segment_ids)==0: 252 | print(ex_index, "segment_id == None") 253 | continue 254 | features.append( 255 | InputFeatures(input_ids=input_ids, input_mask=input_mask, segment_ids=segment_ids, \ 256 | start_label_ids=start_label_ids, end_label_ids= end_label_ids) 257 | ) 258 | return features 259 | 260 | def convert_label_ids_to_onehot(label_ids, label_list): 261 | one_hot_labels= [[False]*len(label_list) for _ in range(len(label_ids))] 262 | label_map = {label: i for i, label in enumerate(label_list)} 263 | ignore_index= -100 264 | non_index= -1 265 | for i, label_id in 
enumerate(label_ids): 266 | for sub_label_id in label_id: 267 | if sub_label_id not in [ignore_index, non_index]: 268 | one_hot_labels[i][sub_label_id]= 1 269 | return one_hot_labels 270 | 271 | 272 | -------------------------------------------------------------------------------- /utils_bi_ner_segment_event_type.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. 3 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | """ Named entity recognition fine-tuning: utilities to work with CoNLL-2003 task. """ 17 | 18 | 19 | import logging 20 | import os 21 | import json 22 | from utils import get_labels 23 | 24 | logger = logging.getLogger(__name__) 25 | 26 | 27 | class InputExample(object): 28 | """A single training/test example for token classification.""" 29 | 30 | def __init__(self, guid, words, segment_ids, start_labels, end_labels, event_type_id): 31 | """Constructs a InputExample. 32 | 33 | Args: 34 | guid: Unique id for the example. 35 | words: list. The words of the sequence. 36 | labels: (Optional) list. The labels for each word of the sequence. This should be 37 | specified for train and dev examples, but not for test examples. 
38 | """ 39 | self.guid = guid 40 | self.words = words 41 | self.segment_ids = segment_ids 42 | self.start_labels = start_labels 43 | self.end_labels = end_labels 44 | self.event_type_id = event_type_id 45 | 46 | 47 | class InputFeatures(object): 48 | """A single set of features of data.""" 49 | 50 | def __init__(self, input_ids, input_mask, segment_ids, start_label_ids, end_label_ids, event_type_id): 51 | self.input_ids = input_ids 52 | self.input_mask = input_mask 53 | self.segment_ids = segment_ids 54 | self.start_label_ids = start_label_ids 55 | self.end_label_ids = end_label_ids 56 | self.event_type_id = event_type_id 57 | 58 | 59 | def read_examples_from_file(data_dir, mode): 60 | file_path = os.path.join(data_dir, "{}.json".format(mode)) 61 | label_list = get_labels(task="trigger", mode="classification") 62 | label_map = {label: i for i, label in enumerate(label_list)} 63 | 64 | guid_index = 1 65 | examples = [] 66 | with open(file_path, encoding="utf-8") as f: 67 | words = [] 68 | labels = [] 69 | for line in f: 70 | if line=='\n' or line=='': 71 | continue 72 | line_json = json.loads(line) 73 | words = line_json['tokens'] 74 | if mode=='test': 75 | start_labels=['O']*len(words) 76 | end_labels = ['O']*len(words) 77 | event_type = 0 78 | else: 79 | start_labels = line_json['start_labels'] 80 | end_labels = line_json['end_labels'] 81 | event_type = label_map[line_json["event_type"]] 82 | if len(words)!= len(start_labels) : 83 | print(words, start_labels," length mismatch") 84 | continue 85 | if len(words)!= len(end_labels) : 86 | print(words, end_labels," length mismatch") 87 | continue 88 | segment_ids= line_json["segment_ids"] 89 | 90 | examples.append(InputExample(guid="{}-{}".format(mode, guid_index), words=words,\ 91 | start_labels=start_labels, end_labels = end_labels, segment_ids= segment_ids, event_type_id=event_type)) 92 | guid_index += 1 93 | 94 | return examples 95 | 96 | 97 | def convert_examples_to_features( 98 | examples, 99 | label_list, 100 | max_seq_length, 101 | tokenizer, 102 | trigger_token_segment_id =1, 103 | cls_token_at_end=False, 104 | cls_token="[CLS]", 105 | cls_token_segment_id=1, 106 | sep_token="[SEP]", 107 | sep_token_extra=False, 108 | pad_on_left=False, 109 | pad_token=0, 110 | pad_token_segment_id=0, 111 | pad_token_label_id=-100, 112 | sequence_a_segment_id=0, 113 | mask_padding_with_zero=True, 114 | ): 115 | """ Loads a data file into a list of `InputBatch`s 116 | `cls_token_at_end` define the location of the CLS token: 117 | - False (Default, BERT/XLM pattern): [CLS] + A + [SEP] + B + [SEP] 118 | - True (XLNet/GPT pattern): A + [SEP] + B + [SEP] + [CLS] 119 | `cls_token_segment_id` define the segment id associated to the CLS token (0 for BERT, 2 for XLNet) 120 | """ 121 | 122 | label_map = {label: i for i, label in enumerate(label_list)} 123 | label_map['O'] = -1 124 | # print(label_map) 125 | 126 | features = [] 127 | for (ex_index, example) in enumerate(examples): 128 | if ex_index % 10000 == 0: 129 | logger.info("Writing example %d of %d", ex_index, len(examples)) 130 | # print(example.words, example.labels) 131 | # print(len(example.words), len(example.labels)) 132 | tokens = [] 133 | start_label_ids = [] 134 | end_label_ids = [] 135 | segment_ids = [] 136 | for word, start_label, end_label, segment_id in zip(example.words, example.start_labels, example.end_labels, example.segment_ids): 137 | word_tokens = tokenizer.tokenize(word) 138 | # each word must map to exactly one token below so tokens stay aligned with the start/end labels and segment_ids 139 | if len(word_tokens)==1: 140 | tokens.extend(word_tokens) 141 | if 
len(word_tokens)>1: 142 | print(word,">1") 143 | tokens.extend(word_tokens[:1]) 144 | pass 145 | if len(word_tokens)<1: 146 | # print(word,"<1") 基本都是空格 147 | tokens.extend(["[unused1]"]) 148 | # continue 149 | # Use the real label id for the first token of the word, and padding ids for the remaining tokens 150 | cur_start_labels = start_label.split() 151 | cur_start_label_ids = [] 152 | for cur_start_label in cur_start_labels: 153 | cur_start_label_ids.append(label_map[cur_start_label]) 154 | start_label_ids.append(cur_start_label_ids) 155 | 156 | cur_end_labels = end_label.split() 157 | cur_end_label_ids = [] 158 | for cur_end_label in cur_end_labels: 159 | cur_end_label_ids.append(label_map[cur_end_label]) 160 | end_label_ids.append(cur_end_label_ids) 161 | 162 | segment_ids.extend( [sequence_a_segment_id if not segment_id else trigger_token_segment_id] * 1) 163 | 164 | # if len(tokens)!= len(label_ids): 165 | # print(word, word_tokens, tokens, label_ids) 166 | # print(len(tokens),len(label_ids)) 167 | 168 | # Account for [CLS] and [SEP] with "- 2" and with "- 3" for RoBERTa. 169 | special_tokens_count = 3 if sep_token_extra else 2 170 | if len(tokens) > max_seq_length - special_tokens_count: 171 | tokens = tokens[: (max_seq_length - special_tokens_count)] 172 | start_label_ids = start_label_ids[: (max_seq_length - special_tokens_count)] 173 | end_label_ids = end_label_ids[: (max_seq_length - special_tokens_count)] 174 | segment_ids = segment_ids[: (max_seq_length - special_tokens_count)] 175 | 176 | 177 | # The convention in BERT is: 178 | # (a) For sequence pairs: 179 | # tokens: [CLS] is this jack ##son ##ville ? [SEP] no it is not . [SEP] 180 | # type_ids: 0 0 0 0 0 0 0 0 1 1 1 1 1 1 181 | # (b) For single sequences: 182 | # tokens: [CLS] the dog is hairy . [SEP] 183 | # type_ids: 0 0 0 0 0 0 0 184 | # 185 | # Where "type_ids" are used to indicate whether this is the first 186 | # sequence or the second sequence. The embedding vectors for `type=0` and 187 | # `type=1` were learned during pre-training and are added to the wordpiece 188 | # embedding vector (and position vector). This is not *strictly* necessary 189 | # since the [SEP] token unambiguously separates the sequences, but it makes 190 | # it easier for the model to learn the concept of sequences. 191 | # 192 | # For classification tasks, the first vector (corresponding to [CLS]) is 193 | # used as as the "sentence vector". Note that this only makes sense because 194 | # the entire model is fine-tuned. 
195 | tokens += [sep_token] 196 | start_label_ids += [[pad_token_label_id]] 197 | end_label_ids += [[pad_token_label_id]] 198 | segment_ids += [sequence_a_segment_id] 199 | 200 | if sep_token_extra: 201 | # roberta uses an extra separator b/w pairs of sentences 202 | tokens += [sep_token] 203 | start_label_ids += [[pad_token_label_id]] 204 | end_label_ids += [[pad_token_label_id]] 205 | segment_ids += [sequence_a_segment_id] 206 | # segment_ids = [sequence_a_segment_id] * len(tokens) 207 | 208 | if cls_token_at_end: 209 | tokens += [cls_token] 210 | start_label_ids += [[pad_token_label_id]] 211 | end_label_ids += [[pad_token_label_id]] 212 | segment_ids += [cls_token_segment_id] 213 | else: 214 | tokens = [cls_token] + tokens 215 | start_label_ids = [[pad_token_label_id]] + start_label_ids 216 | end_label_ids = [[pad_token_label_id]] + end_label_ids 217 | segment_ids = [cls_token_segment_id] + segment_ids 218 | 219 | input_ids = tokenizer.convert_tokens_to_ids(tokens) 220 | # print(len(tokens), len(input_ids), len(label_ids)) 221 | 222 | # The mask has 1 for real tokens and 0 for padding tokens. Only real 223 | # tokens are attended to. 224 | input_mask = [1 if mask_padding_with_zero else 0] * len(input_ids) 225 | 226 | # Zero-pad up to the sequence length. 227 | padding_length = max_seq_length - len(input_ids) 228 | if pad_on_left: 229 | input_ids = ([pad_token] * padding_length) + input_ids 230 | input_mask = ([0 if mask_padding_with_zero else 1] * padding_length) + input_mask 231 | segment_ids = ([pad_token_segment_id] * padding_length) + segment_ids 232 | start_label_ids = ([[pad_token_label_id]] * padding_length) + start_label_ids 233 | end_label_ids = ([[pad_token_label_id]] * padding_length) + end_label_ids 234 | else: 235 | input_ids += [pad_token] * padding_length 236 | input_mask += [0 if mask_padding_with_zero else 1] * padding_length 237 | segment_ids += [pad_token_segment_id] * padding_length 238 | start_label_ids += [[pad_token_label_id]] * padding_length 239 | end_label_ids += [[pad_token_label_id]] * padding_length 240 | 241 | # print(len(label_ids), max_seq_length) 242 | 243 | assert len(input_ids) == max_seq_length 244 | assert len(input_mask) == max_seq_length 245 | assert len(segment_ids) == max_seq_length 246 | assert len(start_label_ids) == max_seq_length 247 | assert len(end_label_ids) == max_seq_length 248 | 249 | event_type_id = example.event_type_id 250 | 251 | if ex_index < 5: 252 | logger.info("*** Example ***") 253 | logger.info("guid: %s", example.guid) 254 | logger.info("tokens: %s", " ".join([str(x) for x in tokens])) 255 | logger.info("input_ids: %s", " ".join([str(x) for x in input_ids])) 256 | logger.info("input_mask: %s", " ".join([str(x) for x in input_mask])) 257 | logger.info("segment_ids: %s", " ".join([str(x) for x in segment_ids])) 258 | logger.info("start_label_ids: %s", " ".join([str(x) for x in start_label_ids])) 259 | logger.info("end_label_ids: %s", " ".join([str(x) for x in end_label_ids])) 260 | logger.info("event_type_id: %s", str(event_type_id)) 261 | 262 | if sum(segment_ids)==0: 263 | print(ex_index, "segment_id == None") 264 | continue 265 | features.append( 266 | InputFeatures(input_ids=input_ids, input_mask=input_mask, segment_ids=segment_ids, \ 267 | start_label_ids=start_label_ids, end_label_ids= end_label_ids, event_type_id=event_type_id) 268 | ) 269 | return features 270 | 271 | def convert_label_ids_to_onehot(label_ids, label_list): 272 | one_hot_labels= [[False]*len(label_list) for _ in range(len(label_ids))] 273 | label_map = 
{label: i for i, label in enumerate(label_list)} 274 | ignore_index= -100 275 | non_index= -1 276 | for i, label_id in enumerate(label_ids): 277 | for sub_label_id in label_id: 278 | if sub_label_id not in [ignore_index, non_index]: 279 | one_hot_labels[i][sub_label_id]= 1 280 | return one_hot_labels 281 | 282 | 283 | -------------------------------------------------------------------------------- /postprocess.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | """hello world""" 17 | import os 18 | import sys 19 | import json 20 | import argparse 21 | 22 | 23 | def read_by_lines(path, encoding="utf-8"): 24 | """read the data by line""" 25 | result = list() 26 | with open(path, "r") as infile: 27 | for line in infile: 28 | result.append(line.strip()) 29 | return result 30 | 31 | 32 | def write_by_lines(path, data, t_code="utf-8"): 33 | """write the data""" 34 | with open(path, "w") as outfile: 35 | [outfile.write(d + "\n") for d in data] 36 | 37 | 38 | def data_process(path, model="trigger", is_predict=False): 39 | """data_process""" 40 | 41 | def label_data(data, start, l, _type): 42 | """label_data""" 43 | for i in range(start, start + l): 44 | suffix = u"B-" if i == start else u"I-" 45 | data[i] = u"{}{}".format(suffix, _type) 46 | return data 47 | 48 | sentences = [] 49 | output = [u"text_a"] if is_predict else [u"text_a\tlabel"] 50 | with open(path) as f: 51 | for line in f: 52 | d_json = json.loads(line.strip().decode("utf-8")) 53 | _id = d_json["id"] 54 | text_a = [ 55 | u"," if t == u" " or t == u"\n" or t == u"\t" else t 56 | for t in list(d_json["text"].lower()) 57 | ] 58 | if is_predict: 59 | sentences.append({"text": d_json["text"], "id": _id}) 60 | output.append(u'\002'.join(text_a)) 61 | else: 62 | if model == u"trigger": 63 | labels = [u"O"] * len(text_a) 64 | for event in d_json["event_list"]: 65 | event_type = event["event_type"] 66 | start = event["trigger_start_index"] 67 | trigger = event["trigger"] 68 | labels = label_data(labels, start, 69 | len(trigger), event_type) 70 | output.append(u"{}\t{}".format(u'\002'.join(text_a), 71 | u'\002'.join(labels))) 72 | elif model == u"role": 73 | for event in d_json["event_list"]: 74 | labels = [u"O"] * len(text_a) 75 | for arg in event["arguments"]: 76 | role_type = arg["role"] 77 | argument = arg["argument"] 78 | start = arg["argument_start_index"] 79 | labels = label_data(labels, start, 80 | len(argument), role_type) 81 | output.append(u"{}\t{}".format(u'\002'.join(text_a), 82 | u'\002'.join(labels))) 83 | if is_predict: 84 | return sentences, output 85 | else: 86 | return output 87 | 88 | 89 | def schema_process(path, model="trigger"): 90 | """schema_process""" 91 | 92 | def label_add(labels, _type): 93 | """label_add""" 94 | if u"B-{}".format(_type) not in labels: 95 | 
labels.extend([u"B-{}".format(_type), u"I-{}".format(_type)]) 96 | return labels 97 | 98 | labels = [] 99 | with open(path) as f: 100 | for line in f: 101 | d_json = json.loads(line.strip().decode("utf-8")) 102 | if model == u"trigger": 103 | labels = label_add(labels, d_json["event_type"]) 104 | elif model == u"role": 105 | for role in d_json["role_list"]: 106 | labels = label_add(labels, role["role"]) 107 | labels.append(u"O") 108 | return labels 109 | 110 | 111 | def extract_result(text, labels): 112 | """extract_result""" 113 | ret, is_start, cur_type = [], False, None 114 | for i, label in enumerate(labels): 115 | if label != u"O": 116 | _type = label[2:] 117 | if label.startswith(u"B-"): 118 | is_start = True 119 | cur_type = _type 120 | ret.append({"start": i, "text": [text[i]], "type": _type}) 121 | elif _type != cur_type: 122 | """ 123 | # 如果是没有B-开头的,则不要这部分数据 124 | cur_type = None 125 | is_start = False 126 | """ 127 | cur_type = _type 128 | is_start = True 129 | ret.append({"start": i, "text": [text[i]], "type": _type}) 130 | elif is_start: 131 | ret[-1]["text"].append(text[i]) 132 | else: 133 | cur_type = None 134 | is_start = False 135 | else: 136 | cur_type = None 137 | is_start = False 138 | return ret 139 | 140 | ## trigger-ner + role-ner 141 | def predict_data_process_ner(trigger_file, role_file, schema_file, save_path): 142 | """predict_data_process""" 143 | pred_ret = [] 144 | trigger_datas = read_by_lines(trigger_file) 145 | role_datas = read_by_lines(role_file) 146 | schema_datas = read_by_lines(schema_file) 147 | schema = {} 148 | for s in schema_datas: 149 | d_json = json.loads(s) 150 | schema[d_json["event_type"]] = [r["role"] for r in d_json["role_list"]] 151 | # 将role数据进行处理 152 | sent_role_mapping = {} 153 | for d in role_datas: 154 | d_json = json.loads(d) 155 | r_ret = extract_result(d_json["text"], d_json["labels"]) 156 | role_ret = {} 157 | for r in r_ret: 158 | role_type = r["type"] 159 | if role_type not in role_ret: 160 | role_ret[role_type] = [] 161 | role_ret[role_type].append(u"".join(r["text"])) 162 | sent_role_mapping[d_json["id"]] = role_ret 163 | 164 | for d in trigger_datas: 165 | d_json = json.loads(d) 166 | t_ret = extract_result(d_json["text"], d_json["labels"]) 167 | pred_event_types = list(set([t["type"] for t in t_ret])) 168 | event_list = [] 169 | for event_type in pred_event_types: 170 | role_list = schema[event_type] 171 | arguments = [] 172 | for role_type, ags in sent_role_mapping[d_json["id"]].items(): 173 | if role_type not in role_list: 174 | continue 175 | for arg in ags: 176 | if len(arg) == 1: 177 | # 一点小trick 178 | continue 179 | arguments.append({"role": role_type, "argument": arg}) 180 | event = {"event_type": event_type, "arguments": arguments} 181 | event_list.append(event) 182 | pred_ret.append({ 183 | "id": d_json["id"], 184 | "text": d_json["text"], 185 | "event_list": event_list 186 | }) 187 | pred_ret = [json.dumps(r, ensure_ascii=False) for r in pred_ret] 188 | write_by_lines(save_path, pred_ret) 189 | 190 | ## trigger-ner + role-bin 191 | def predict_data_process_ner_bin(trigger_file, role_file, schema_file, save_path): 192 | """predict_data_process""" 193 | pred_ret = [] 194 | trigger_datas = read_by_lines(trigger_file) 195 | role_datas = read_by_lines(role_file) 196 | schema_datas = read_by_lines(schema_file) 197 | schema = {} 198 | for s in schema_datas: 199 | d_json = json.loads(s) 200 | schema[d_json["event_type"]] = [r["role"] for r in d_json["role_list"]] 201 | # 将role数据进行处理 202 | sent_role_mapping = {} 203 | for d 
in role_datas:
204 | d_json = json.loads(d)
205 | arguments = d_json["arguments"]
206 | role_ret = {}
207 | for r in arguments:
208 | role_type = r["role"]
209 | if role_type not in role_ret:
210 | role_ret[role_type] = []
211 | role_ret[role_type].append(u"".join(r["argument"]))
212 | sent_role_mapping[d_json["id"]] = role_ret
213 | 
214 | for d in trigger_datas:
215 | d_json = json.loads(d)
216 | t_ret = extract_result(d_json["text"], d_json["labels"])
217 | pred_event_types = list(set([t["type"] for t in t_ret]))
218 | event_list = []
219 | for event_type in pred_event_types:
220 | role_list = schema[event_type]
221 | arguments = []
222 | for role_type, ags in sent_role_mapping[d_json["id"]].items():
223 | if role_type not in role_list:
224 | continue
225 | for arg in ags:
226 | if len(arg) == 1:
227 | # small trick: drop single-character arguments
228 | continue
229 | arguments.append({"role": role_type, "argument": arg})
230 | event = {"event_type": event_type, "arguments": arguments}
231 | event_list.append(event)
232 | pred_ret.append({
233 | "id": d_json["id"],
234 | "text": d_json["text"],
235 | "event_list": event_list
236 | })
237 | pred_ret = [json.dumps(r, ensure_ascii=False) for r in pred_ret]
238 | write_by_lines(save_path, pred_ret)
239 | 
240 | ## trigger-bin + role-bin
241 | def predict_data_process_bin(trigger_file, role_file, schema_file, save_path):
242 | """predict_data_process"""
243 | pred_ret = []
244 | trigger_datas = read_by_lines(trigger_file)
245 | role_datas = read_by_lines(role_file)
246 | schema_datas = read_by_lines(schema_file)
247 | schema = {}
248 | schema_reverse = {}  # never used below
249 | for s in schema_datas:
250 | d_json = json.loads(s)
251 | schema[d_json["event_type"]] = [r["role"] for r in d_json["role_list"]]
252 | schema_reverse=[]
253 | # process the role predictions
254 | sent_role_mapping = {}
255 | for d in role_datas:
256 | d_json = json.loads(d)
257 | arguments = d_json["arguments"]
258 | role_ret = {}
259 | for r in arguments:
260 | role_type = r["role"]
261 | if role_type not in role_ret:
262 | role_ret[role_type] = []
263 | role_ret[role_type].append(u"".join(r["argument"]))
264 | sent_role_mapping[d_json["id"]] = role_ret
265 | 
266 | for d in trigger_datas:
267 | d_json = json.loads(d)
268 | pred_event_types = d_json["labels"]
269 | event_list = []
270 | for event_type in pred_event_types:
271 | role_list = schema[event_type]
272 | arguments = []
273 | for role_type, ags in sent_role_mapping[d_json["id"]].items():
274 | if role_type not in role_list:
275 | continue
276 | for arg in ags:
277 | if len(arg) == 1:
278 | # small trick: drop single-character arguments
279 | continue
280 | arguments.append({"role": role_type, "argument": arg})
281 | event = {"event_type": event_type, "arguments": arguments}
282 | event_list.append(event)
283 | pred_ret.append({
284 | "id": d_json["id"],
285 | "text": d_json["text"],
286 | "event_list": event_list
287 | })
288 | pred_ret = [json.dumps(r, ensure_ascii=False) for r in pred_ret]
289 | write_by_lines(save_path, pred_ret)
290 | 
291 | ## trigger-bin + role-bin + roles (arguments) take precedence
292 | def predict_data_process_bin2(trigger_file, role_file, schema_file, save_path):
293 | """predict_data_process"""
294 | from utils import schema_analysis
295 | argument_map= schema_analysis()
296 | 
297 | pred_ret = []
298 | trigger_datas = read_by_lines(trigger_file)
299 | role_datas = read_by_lines(role_file)
300 | schema_datas = read_by_lines(schema_file)
301 | schema = {}
302 | schema_reverse = {}
303 | for s in schema_datas:
304 | d_json = json.loads(s)
305 | schema[d_json["event_type"]] = [r["role"] for r in d_json["role_list"]]
306 | # NOTE: best-effort reconstruction of the "roles take precedence" variant;
307 | # it assumes schema_analysis() returns a role -> event_type mapping (dict).
308 | sent_role_mapping = {}
309 | for d in role_datas:
310 | d_json = json.loads(d)
311 | arguments = d_json["arguments"]
312 | role_ret = {}
313 | for r in arguments:
314 | role_type = r["role"]
315 | if role_type not in role_ret:
316 | role_ret[role_type] = []
317 | role_ret[role_type].append(u"".join(r["argument"]))
318 | sent_role_mapping[d_json["id"]] = role_ret
319 | 
320 | for d in trigger_datas:
321 | d_json = json.loads(d)
322 | pred_event_types = d_json["labels"]
323 | event_list = []
324 | # a predicted role that fits none of the predicted event types pulls in
325 | # the event type that argument_map associates with that role
326 | for role_type in sent_role_mapping[d_json["id"]]:
327 | if any(role_type in schema[t] for t in pred_event_types):
328 | continue
329 | if role_type in argument_map and argument_map[role_type] not in pred_event_types:
330 | pred_event_types.append(argument_map[role_type])
331 | for event_type in pred_event_types:
332 | role_list = schema[event_type]
333 | arguments = []
334 | for role_type, ags in sent_role_mapping[d_json["id"]].items():
335 | if role_type not in role_list:
336 | continue
337 | for arg in ags:
338 | # small trick: drop single-character arguments
339 | if len(arg) == 1: continue
340 | arguments.append({"role": role_type, "argument": arg})
341 | event_list.append({"event_type": event_type, "arguments": arguments})
342 | pred_ret.append({
343 | "id": d_json["id"],
344 | "text": d_json["text"],
345 | "event_list": event_list
346 | })
347 | pred_ret = [json.dumps(r, ensure_ascii=False) for r in pred_ret]
348 | write_by_lines(save_path, pred_ret)
349 | 
350 | def merge(input_file, output_file):
351 | # fold consecutive rows that share an id into a single record by
352 | # concatenating their event_list entries; the input is assumed to be
353 | # grouped by id, as in the segment-level prediction files
354 | lines = open(input_file, encoding='utf-8').read().splitlines()
355 | res = []
356 | pre_line = None
357 | for line in lines:
358 | json_line = json.loads(line)
359 | if pre_line is None:
360 | pre_line = json_line
361 | elif json_line["id"] != pre_line["id"]:
362 | res.append(pre_line)
363 | pre_line = json_line
364 | else:
365 | pre_line["event_list"].extend(json_line["event_list"])
366 | if pre_line is not None:
367 | res.append(pre_line)
368 | 
369 | # write_file is imported inside the function to avoid a circular import,
370 | # since preprocess.py itself imports extract_result from this module
371 | from preprocess import write_file
372 | write_file(res, output_file)
373 | 
374 | 
375 | if __name__ == "__main__":
376 | 
377 | # predict_data_process_ner(
378 | # trigger_file= "./output/trigger/checkpoint-best/test_predictions_indexed.json", \
379 | # role_file = "./output/role2/checkpoint-best/test_predictions_indexed.json", \
380 | # schema_file = "./data/event_schema/event_schema.json", \
381 | # save_path = "./results/test_pred2.json")
382 | 
383 | # predict_data_process_ner_bin(
384 | # trigger_file= "./output/trigger/checkpoint-best/test_predictions_indexed.json", \
385 | # role_file = "./output/role_bin/merge/test_predictions_indexed_labels.json", \
386 | # schema_file = "./data/event_schema/event_schema.json", \
387 | # save_path = "./results/test_pred_role_bin_merge.json")
388 | 
389 | predict_data_process_bin(
390 | trigger_file= "./output/trigger_classify/merge/test2_predictions_indexed_labels.json", \
391 | role_file = "./output/role_bin/merge/test2_predictions_indexed_labels.json", \
392 | schema_file = "./data/event_schema/event_schema.json", \
393 | save_path = "./results/test2_pred_trigger_bin_role_bin_merge_2.json")
394 | 
395 | # merge("./output/role_segment_bin/checkpoint-best/eval_predictions_indexed.json",\
396 | # "./results/eval_pred_bi_segment.json")
397 | # merge("./output/role_segment_bin/checkpoint-best/test_predictions_indexed.json",\
398 | # "./results/test_pred_bi_segment.json")
399 | 
400 | 
-------------------------------------------------------------------------------- /utils_bi_ner_joint.py: --------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright
2018 The Google AI Language Team Authors and The HuggingFace Inc. team. 3 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | """ Named entity recognition fine-tuning: utilities to work with CoNLL-2003 task. """ 17 | 18 | 19 | import logging 20 | import os 21 | import json 22 | 23 | logger = logging.getLogger(__name__) 24 | 25 | 26 | class InputExample(object): 27 | """A single training/test example for token classification.""" 28 | 29 | def __init__(self, guid, words, segment_ids, \ 30 | trigger_start_labels,trigger_end_labels,role_start_labels,role_end_labels): 31 | """Constructs a InputExample. 32 | 33 | Args: 34 | guid: Unique id for the example. 35 | words: list. The words of the sequence. 36 | labels: (Optional) list. The labels for each word of the sequence. This should be 37 | specified for train and dev examples, but not for test examples. 38 | """ 39 | self.guid = guid 40 | self.words = words 41 | self.segment_ids = segment_ids 42 | self.trigger_start_labels = trigger_start_labels 43 | self.trigger_end_labels = trigger_end_labels 44 | self.role_start_labels = role_start_labels 45 | self.role_end_labels = role_end_labels 46 | 47 | 48 | 49 | class InputFeatures(object): 50 | """A single set of features of data.""" 51 | 52 | def __init__(self, input_ids, input_mask, segment_ids, \ 53 | trigger_start_label_ids, trigger_end_label_ids,\ 54 | role_start_label_ids, role_end_label_ids): 55 | self.input_ids = input_ids 56 | self.input_mask = input_mask 57 | self.segment_ids = segment_ids 58 | self.trigger_start_label_ids = trigger_start_label_ids 59 | self.trigger_end_label_ids = trigger_end_label_ids 60 | self.role_start_label_ids = role_start_label_ids 61 | self.role_end_label_ids = role_end_label_ids 62 | 63 | 64 | def read_examples_from_file(data_dir, mode): 65 | file_path = os.path.join(data_dir, "{}.json".format(mode)) 66 | guid_index = 1 67 | examples = [] 68 | with open(file_path, encoding="utf-8") as f: 69 | words = [] 70 | labels = [] 71 | for line in f: 72 | if line=='\n' or line=='': 73 | continue 74 | line_json = json.loads(line) 75 | # id = line_json['id'] 76 | words = line_json['tokens'] 77 | segment_ids= line_json["segment_ids"] 78 | 79 | if mode=='test': 80 | trigger_start_labels=['O']*len(words) 81 | trigger_end_labels = ['O']*len(words) 82 | role_start_labels=['O']*len(words) 83 | role_end_labels = ['O']*len(words) 84 | 85 | else: 86 | trigger_start_labels = line_json['trigger_start_labels'] 87 | trigger_end_labels = line_json['trigger_end_labels'] 88 | role_start_labels = line_json['role_start_labels'] 89 | role_end_labels = line_json['role_end_labels'] 90 | 91 | if len(words)!= len(trigger_start_labels) : 92 | print(words, trigger_start_labels," length misMatch") 93 | continue 94 | if len(words)!= len(trigger_end_labels) : 95 | print(words, trigger_end_labels," length misMatch") 96 | continue 97 | if len(words)!= len(role_start_labels) : 98 | print(words, 
role_start_labels," length misMatch") 99 | continue 100 | if len(words)!= len(role_end_labels) : 101 | print(words, role_end_labels," length misMatch") 102 | continue 103 | 104 | examples.append(InputExample(guid="{}-{}".format(mode, guid_index), \ 105 | words=words, segment_ids= segment_ids, \ 106 | trigger_start_labels=trigger_start_labels, trigger_end_labels = trigger_end_labels, \ 107 | role_start_labels=role_start_labels, role_end_labels = role_end_labels )) 108 | guid_index += 1 109 | 110 | return examples 111 | 112 | 113 | def convert_examples_to_features( 114 | examples, 115 | trigger_label_list, 116 | role_label_list, 117 | max_seq_length, 118 | tokenizer, 119 | trigger_token_segment_id = 1, 120 | cls_token_at_end=False, 121 | cls_token="[CLS]", 122 | cls_token_segment_id=1, 123 | sep_token="[SEP]", 124 | sep_token_extra=False, 125 | pad_on_left=False, 126 | pad_token=0, 127 | pad_token_segment_id=0, 128 | pad_token_label_id=-100, 129 | sequence_a_segment_id=0, 130 | mask_padding_with_zero=True, 131 | ): 132 | """ Loads a data file into a list of `InputBatch`s 133 | `cls_token_at_end` define the location of the CLS token: 134 | - False (Default, BERT/XLM pattern): [CLS] + A + [SEP] + B + [SEP] 135 | - True (XLNet/GPT pattern): A + [SEP] + B + [SEP] + [CLS] 136 | `cls_token_segment_id` define the segment id associated to the CLS token (0 for BERT, 2 for XLNet) 137 | """ 138 | 139 | trigger_label_map = {label: i for i, label in enumerate(trigger_label_list)} 140 | trigger_label_map['O'] = -1 141 | role_label_map = {label: i for i, label in enumerate(role_label_list)} 142 | role_label_map['O'] = -1 143 | # print(label_map) 144 | 145 | features = [] 146 | for (ex_index, example) in enumerate(examples): 147 | if ex_index % 10000 == 0: 148 | logger.info("Writing example %d of %d", ex_index, len(examples)) 149 | # print(example.words, example.labels) 150 | # print(len(example.words), len(example.labels)) 151 | tokens = [] 152 | trigger_start_label_ids = [] 153 | trigger_end_label_ids = [] 154 | role_start_label_ids = [] 155 | role_end_label_ids = [] 156 | segment_ids = [] 157 | for word, segment_id, trigger_start_label, trigger_end_label, role_start_label, role_end_label \ 158 | in zip(example.words, example.segment_ids, \ 159 | example.trigger_start_labels, example.trigger_end_labels, \ 160 | example.role_start_labels, example.role_end_labels): 161 | 162 | word_tokens = tokenizer.tokenize(word) 163 | 164 | tokens.extend(word_tokens) 165 | if len(word_tokens)==1: 166 | tokens.extend(word_tokens) 167 | if len(word_tokens)>1: 168 | print(word,">1") 169 | tokens.extend(word_tokens[:1]) 170 | pass 171 | if len(word_tokens)<1: 172 | # print(word,"<1") 基本都是空格 173 | tokens.extend(["[unused1]"]) 174 | # continue 175 | # Use the real label id for the first token of the word, and padding ids for the remaining tokens 176 | ################################################## 177 | # trigger 178 | cur_start_labels = trigger_start_label.split() 179 | cur_start_label_ids = [] 180 | for cur_start_label in cur_start_labels: 181 | cur_start_label_ids.append(trigger_label_map[cur_start_label]) 182 | trigger_start_label_ids.append(cur_start_label_ids) 183 | 184 | cur_end_labels = trigger_end_label.split() 185 | cur_end_label_ids = [] 186 | for cur_end_label in cur_end_labels: 187 | cur_end_label_ids.append(trigger_label_map[cur_end_label]) 188 | trigger_end_label_ids.append(cur_end_label_ids) 189 | 190 | ################################################## 191 | # role 192 | cur_start_labels = 
role_start_label.split() 193 | cur_start_label_ids = [] 194 | for cur_start_label in cur_start_labels: 195 | cur_start_label_ids.append(role_label_map[cur_start_label]) 196 | role_start_label_ids.append(cur_start_label_ids) 197 | 198 | cur_end_labels = role_end_label.split() 199 | cur_end_label_ids = [] 200 | for cur_end_label in cur_end_labels: 201 | cur_end_label_ids.append(role_label_map[cur_end_label]) 202 | role_end_label_ids.append(cur_end_label_ids) 203 | 204 | 205 | segment_ids.extend( [sequence_a_segment_id if not segment_id else trigger_token_segment_id] ) 206 | 207 | 208 | # if len(tokens)!= len(label_ids): 209 | # print(word, word_tokens, tokens, label_ids) 210 | # print(len(tokens),len(label_ids)) 211 | 212 | # Account for [CLS] and [SEP] with "- 2" and with "- 3" for RoBERTa. 213 | special_tokens_count = 3 if sep_token_extra else 2 214 | if len(tokens) > max_seq_length - special_tokens_count: 215 | tokens = tokens[: (max_seq_length - special_tokens_count)] 216 | trigger_start_label_ids = trigger_start_label_ids[: (max_seq_length - special_tokens_count)] 217 | trigger_end_label_ids = trigger_end_label_ids[: (max_seq_length - special_tokens_count)] 218 | role_start_label_ids = role_start_label_ids[: (max_seq_length - special_tokens_count)] 219 | role_end_label_ids = role_end_label_ids[: (max_seq_length - special_tokens_count)] 220 | segment_ids = segment_ids[: (max_seq_length - special_tokens_count)] 221 | 222 | 223 | # The convention in BERT is: 224 | # (a) For sequence pairs: 225 | # tokens: [CLS] is this jack ##son ##ville ? [SEP] no it is not . [SEP] 226 | # type_ids: 0 0 0 0 0 0 0 0 1 1 1 1 1 1 227 | # (b) For single sequences: 228 | # tokens: [CLS] the dog is hairy . [SEP] 229 | # type_ids: 0 0 0 0 0 0 0 230 | # 231 | # Where "type_ids" are used to indicate whether this is the first 232 | # sequence or the second sequence. The embedding vectors for `type=0` and 233 | # `type=1` were learned during pre-training and are added to the wordpiece 234 | # embedding vector (and position vector). This is not *strictly* necessary 235 | # since the [SEP] token unambiguously separates the sequences, but it makes 236 | # it easier for the model to learn the concept of sequences. 237 | # 238 | # For classification tasks, the first vector (corresponding to [CLS]) is 239 | # used as as the "sentence vector". Note that this only makes sense because 240 | # the entire model is fine-tuned. 
241 | tokens += [sep_token] 242 | trigger_start_label_ids += [[pad_token_label_id]] 243 | trigger_end_label_ids += [[pad_token_label_id]] 244 | role_start_label_ids += [[pad_token_label_id]] 245 | role_end_label_ids += [[pad_token_label_id]] 246 | segment_ids += [sequence_a_segment_id] 247 | 248 | if sep_token_extra: 249 | # roberta uses an extra separator b/w pairs of sentences 250 | tokens += [sep_token] 251 | trigger_start_label_ids += [[pad_token_label_id]] 252 | trigger_end_label_ids += [[pad_token_label_id]] 253 | role_start_label_ids += [[pad_token_label_id]] 254 | role_end_label_ids += [[pad_token_label_id]] 255 | segment_ids += [sequence_a_segment_id] 256 | 257 | if cls_token_at_end: 258 | tokens += [cls_token] 259 | trigger_start_label_ids += [[pad_token_label_id]] 260 | trigger_end_label_ids += [[pad_token_label_id]] 261 | role_start_label_ids += [[pad_token_label_id]] 262 | role_end_label_ids += [[pad_token_label_id]] 263 | segment_ids += [cls_token_segment_id] 264 | else: 265 | tokens = [cls_token] + tokens 266 | trigger_start_label_ids = [[pad_token_label_id]] + trigger_start_label_ids 267 | trigger_end_label_ids = [[pad_token_label_id]] + trigger_end_label_ids 268 | role_start_label_ids = [[pad_token_label_id]] + role_start_label_ids 269 | role_end_label_ids = [[pad_token_label_id]] + role_end_label_ids 270 | segment_ids = [cls_token_segment_id] + segment_ids 271 | 272 | input_ids = tokenizer.convert_tokens_to_ids(tokens) 273 | # print(len(tokens), len(input_ids), len(label_ids)) 274 | 275 | # The mask has 1 for real tokens and 0 for padding tokens. Only real 276 | # tokens are attended to. 277 | input_mask = [1 if mask_padding_with_zero else 0] * len(input_ids) 278 | 279 | # Zero-pad up to the sequence length. 280 | padding_length = max_seq_length - len(input_ids) 281 | if pad_on_left: 282 | input_ids = ([pad_token] * padding_length) + input_ids 283 | input_mask = ([0 if mask_padding_with_zero else 1] * padding_length) + input_mask 284 | segment_ids = ([pad_token_segment_id] * padding_length) + segment_ids 285 | trigger_start_label_ids = ([[pad_token_label_id]] * padding_length) + trigger_start_label_ids 286 | trigger_end_label_ids = ([[pad_token_label_id]] * padding_length) + trigger_end_label_ids 287 | role_start_label_ids = ([[pad_token_label_id]] * padding_length) + role_start_label_ids 288 | role_end_label_ids = ([[pad_token_label_id]] * padding_length) + role_end_label_ids 289 | else: 290 | input_ids += [pad_token] * padding_length 291 | input_mask += [0 if mask_padding_with_zero else 1] * padding_length 292 | segment_ids += [pad_token_segment_id] * padding_length 293 | trigger_start_label_ids += [[pad_token_label_id]] * padding_length 294 | trigger_end_label_ids += [[pad_token_label_id]] * padding_length 295 | role_start_label_ids += [[pad_token_label_id]] * padding_length 296 | role_end_label_ids += [[pad_token_label_id]] * padding_length 297 | 298 | # print(len(label_ids), max_seq_length) 299 | 300 | assert len(input_ids) == max_seq_length 301 | assert len(input_mask) == max_seq_length 302 | assert len(segment_ids) == max_seq_length 303 | assert len(trigger_start_label_ids) == max_seq_length 304 | assert len(trigger_end_label_ids) == max_seq_length 305 | assert len(role_start_label_ids) == max_seq_length 306 | assert len(role_end_label_ids) == max_seq_length 307 | 308 | if ex_index < 5: 309 | logger.info("*** Example ***") 310 | logger.info("guid: %s", example.guid) 311 | logger.info("tokens: %s", " ".join([str(x) for x in tokens])) 312 | logger.info("input_ids: 
%s", " ".join([str(x) for x in input_ids])) 313 | logger.info("input_mask: %s", " ".join([str(x) for x in input_mask])) 314 | logger.info("segment_ids: %s", " ".join([str(x) for x in segment_ids])) 315 | logger.info("trigger_start_label_ids: %s", " ".join([str(x) for x in trigger_start_label_ids])) 316 | logger.info("trigger_end_label_ids: %s", " ".join([str(x) for x in trigger_end_label_ids])) 317 | logger.info("role_start_label_ids: %s", " ".join([str(x) for x in role_start_label_ids])) 318 | logger.info("role_end_label_ids: %s", " ".join([str(x) for x in role_end_label_ids])) 319 | 320 | features.append( 321 | InputFeatures(input_ids=input_ids, input_mask=input_mask, segment_ids=segment_ids, \ 322 | trigger_start_label_ids=trigger_start_label_ids, trigger_end_label_ids= trigger_end_label_ids, \ 323 | role_start_label_ids=role_start_label_ids, role_end_label_ids= role_end_label_ids ) 324 | ) 325 | return features 326 | 327 | def convert_label_ids_to_onehot(label_ids, label_list): 328 | one_hot_labels= [[False]*len(label_list) for _ in range(len(label_ids))] 329 | label_map = {label: i for i, label in enumerate(label_list)} 330 | ignore_index= -100 331 | non_index= -1 332 | for i, label_id in enumerate(label_ids): 333 | for sub_label_id in label_id: 334 | if sub_label_id not in [ignore_index, non_index]: 335 | one_hot_labels[i][sub_label_id]= 1 336 | return one_hot_labels 337 | 338 | 339 | 340 | 341 | -------------------------------------------------------------------------------- /preprocess.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | from utils import get_labels, write_file 4 | from postprocess import extract_result 5 | 6 | 7 | def trigger_classify_file_remove_id(input_file, output_file): 8 | rows = open(input_file, encoding='utf-8').read().splitlines() 9 | results = [] 10 | for row in rows: 11 | if len(row)==1: print(row) 12 | row = json.loads(row) 13 | row.pop("id") 14 | row.pop("text") 15 | results.append(row) 16 | write_file(results,output_file) 17 | 18 | def trigger_classify_process(input_file, output_file, is_predict=False): 19 | rows = open(input_file, encoding='utf-8').read().splitlines() 20 | results = [] 21 | count = 0 22 | for row in rows: 23 | if len(row)==1: print(row) 24 | row = json.loads(row) 25 | count += 1 26 | if "id" not in row: 27 | row["id"]=count 28 | labels = [] 29 | if is_predict: 30 | results.append({"id":row["id"], "text":row["text"], "labels":labels}) 31 | continue 32 | for event in row["event_list"]: 33 | event_type = event["event_type"] 34 | labels.append(event_type) 35 | labels = list(set(labels)) 36 | results.append({"id":row["id"], "text":row["text"], "labels":labels}) 37 | write_file(results,output_file) 38 | 39 | def trigger_process_bio(input_file, output_file, is_predict=False): 40 | rows = open(input_file, encoding='utf-8').read().splitlines() 41 | results = [] 42 | for row in rows: 43 | if len(row)==1: print(row) 44 | row = json.loads(row) 45 | labels = ['O']*len(row["text"]) 46 | if is_predict: 47 | results.append({"id":row["id"], "tokens":list(row["text"]), "labels":labels}) 48 | continue 49 | for event in row["event_list"]: 50 | trigger = event["trigger"] 51 | event_type = event["event_type"] 52 | trigger_start_index = event["trigger_start_index"] 53 | labels[trigger_start_index]= "B-{}".format(event_type) 54 | for i in range(1, len(trigger)): 55 | labels[trigger_start_index+i]= "I-{}".format(event_type) 56 | results.append({"id":row["id"], "tokens":list(row["text"]), 
"labels":labels}) 57 | write_file(results,output_file) 58 | 59 | def trigger_process_binary(input_file, output_file, is_predict=False): 60 | rows = open(input_file, encoding='utf-8').read().splitlines() 61 | results = [] 62 | for row in rows: 63 | if len(row)==1: print(row) 64 | row = json.loads(row) 65 | start_labels = ['O']*len(row["text"]) 66 | end_labels = ['O']*len(row["text"]) 67 | if is_predict: 68 | results.append({"id":row["id"], "tokens":list(row["text"]), "start_labels":start_labels, "end_labels":end_labels}) 69 | continue 70 | for event in row["event_list"]: 71 | trigger = event["trigger"] 72 | event_type = event["event_type"] 73 | trigger_start_index = event["trigger_start_index"] 74 | trigger_end_index = trigger_start_index + len(trigger) - 1 75 | start_labels[trigger_start_index]= event_type 76 | end_labels[trigger_end_index]= event_type 77 | results.append({"id":row["id"], "tokens":list(row["text"]), "start_labels":start_labels, "end_labels":end_labels}) 78 | write_file(results,output_file) 79 | 80 | def role_process(input_file, output_file, is_predict=False): 81 | rows = open(input_file, encoding='utf-8').read().splitlines() 82 | results = [] 83 | for row in rows: 84 | if len(row)==1: print(row) 85 | row = json.loads(row) 86 | labels = ['O']*len(row["text"]) 87 | if is_predict: 88 | results.append({"id":row["id"], "tokens":list(row["text"]), "labels":labels}) 89 | continue 90 | for event in row["event_list"]: 91 | event_type = event["event_type"] 92 | for arg in event["arguments"]: 93 | role = arg['role'] 94 | argument = arg['argument'] 95 | argument_start_index = arg["argument_start_index"] 96 | labels[argument_start_index]= "B-{}".format(role) 97 | for i in range(1, len(argument)): 98 | labels[argument_start_index+i]= "I-{}".format(role) 99 | if arg['alias']!=[]: print(arg['alias']) 100 | results.append({"id":row["id"], "tokens":list(row["text"]), "labels":labels}) 101 | write_file(results,output_file) 102 | 103 | def role_process_segment(input_file, output_file, is_predict=False): 104 | rows = open(input_file, encoding='utf-8').read().splitlines() 105 | results = [] 106 | len_text = [] 107 | for row in rows: 108 | if len(row)==1: print(row) 109 | row = json.loads(row) 110 | len_text.append(len(row["text"])) 111 | if len(list(row["text"]))!= len(row["text"]): 112 | print("list and text mismatched") 113 | labels = ['O']*len(row["text"]) 114 | if is_predict: 115 | results.append({"id":row["id"], "tokens":list(row["text"]), "labels":labels}) 116 | continue 117 | for event in row["event_list"]: 118 | event_type = event["event_type"] 119 | trigger = event["trigger"] 120 | trigger_start_index = event["trigger_start_index"] 121 | segment_ids= [0] * len(row["text"]) 122 | for i in range(trigger_start_index, trigger_start_index+ len(trigger) ): 123 | segment_ids[i] = 1 124 | 125 | for arg in event["arguments"]: 126 | role = arg['role'] 127 | argument = arg['argument'] 128 | argument_start_index = arg["argument_start_index"] 129 | labels[argument_start_index]= "B-{}".format(role) 130 | for i in range(1, len(argument)): 131 | labels[argument_start_index+i]= "I-{}".format(role) 132 | if arg['alias']!=[]: print(arg['alias']) 133 | 134 | results.append({"id":row["id"], "event_type":event_type, "segment_ids":segment_ids,\ 135 | "tokens":list(row["text"]), "labels":labels}) 136 | write_file(results,output_file) 137 | print(min(len_text), max(len_text), mean(len_text)) 138 | 139 | def role_process_binary(input_file, output_file, is_predict=False): 140 | label_list = get_labels(task= 
"role", mode="classification") 141 | label_map = {label: i for i, label in enumerate(label_list)} 142 | rows = open(input_file, encoding='utf-8').read().splitlines() 143 | results = [] 144 | count = 0 145 | for row in rows: 146 | if len(row)==1: print(row) 147 | row = json.loads(row) 148 | count += 1 149 | if "id" not in row: 150 | row["id"]=count 151 | start_labels = ['O']*len(row["text"]) 152 | end_labels = ['O']*len(row["text"]) 153 | arguments = [] 154 | if is_predict: 155 | results.append({"id":row["id"], "tokens":list(row["text"]), "start_labels":start_labels, "end_labels":end_labels, "arguments":arguments}) 156 | continue 157 | for event in row["event_list"]: 158 | event_type = event["event_type"] 159 | for arg in event["arguments"]: 160 | role = arg['role'] 161 | role_id = label_map[role] 162 | argument = arg['argument'] 163 | argument_start_index = arg["argument_start_index"] 164 | argument_end_index = argument_start_index + len(argument) -1 165 | 166 | if start_labels[argument_start_index]=="O": 167 | start_labels[argument_start_index] = role 168 | else: 169 | start_labels[argument_start_index] += (" "+ role) 170 | if end_labels[argument_end_index]=="O": 171 | end_labels[argument_end_index] = role 172 | else: 173 | end_labels[argument_end_index] += (" "+ role) 174 | 175 | if arg['alias']!=[]: print(arg['alias']) 176 | 177 | arg.pop('alias') 178 | arguments.append(arg) 179 | 180 | results.append({"id":row["id"], "tokens":list(row["text"]), "start_labels":start_labels, "end_labels":end_labels, "arguments":arguments}) 181 | write_file(results,output_file) 182 | 183 | def role_process_segment_binary(input_file, output_file, is_predict=False): 184 | label_list = get_labels(task= "role", mode="classification") 185 | label_map = {label: i for i, label in enumerate(label_list)} 186 | rows = open(input_file, encoding='utf-8').read().splitlines() 187 | results = [] 188 | for row in rows: 189 | if len(row)==1: print(row) 190 | row = json.loads(row) 191 | if is_predict: 192 | results.append({"id":row["id"], "tokens":list(row["text"]), \ 193 | "start_labels":['O']*len(row["text"]), "end_labels":['O']*len(row["text"])}) 194 | continue 195 | for event in row["event_list"]: 196 | event_type = event["event_type"] 197 | trigger = event["trigger"] 198 | trigger_start_index = event["trigger_start_index"] 199 | segment_ids= [0] * len(row["text"]) 200 | for i in range(trigger_start_index, trigger_start_index+ len(trigger) ): 201 | segment_ids[i] = 1 202 | start_labels = ['O']*len(row["text"]) 203 | end_labels = ['O']*len(row["text"]) 204 | 205 | for arg in event["arguments"]: 206 | role = arg['role'] 207 | role_id = label_map[role] 208 | argument = arg['argument'] 209 | argument_start_index = arg["argument_start_index"] 210 | argument_end_index = argument_start_index + len(argument) -1 211 | 212 | if start_labels[argument_start_index]=="O": 213 | start_labels[argument_start_index] = role 214 | else: 215 | start_labels[argument_start_index] += (" "+ role) 216 | if end_labels[argument_end_index]=="O": 217 | end_labels[argument_end_index] = role 218 | else: 219 | end_labels[argument_end_index] += (" "+ role) 220 | 221 | if arg['alias']!=[]: print(arg['alias']) 222 | results.append({"id":row["id"], "tokens":list(row["text"]), "event_type":event_type, \ 223 | "segment_ids":segment_ids,"start_labels":start_labels, "end_labels":end_labels}) 224 | write_file(results,output_file) 225 | 226 | 227 | 228 | def joint_process_binary(input_file, output_file, is_predict=False): 229 | label_list = get_labels(task= 
"role", mode="classification") 230 | label_map = {label: i for i, label in enumerate(label_list)} 231 | rows = open(input_file, encoding='utf-8').read().splitlines() 232 | results = [] 233 | for row in rows: 234 | if len(row)==1: print(row) 235 | row = json.loads(row) 236 | 237 | if is_predict: 238 | results.append({"id":row["id"], "tokens":list(row["text"]), \ 239 | "trigger_start_labels":['O']*len(row["text"]), "role_end_labels":['O']*len(row["text"]), \ 240 | "role_start_labels":['O']*len(row["text"]), "role_end_labels":['O']*len(row["text"])}) 241 | continue 242 | 243 | trigger_start_labels = ['O']*len(row["text"]) 244 | trigger_end_labels = ['O']*len(row["text"]) 245 | 246 | # trigger 247 | for event in row["event_list"]: 248 | event_type = event["event_type"] 249 | trigger = event["trigger"] 250 | trigger_start_index = event['trigger_start_index'] 251 | trigger_end_index = trigger_start_index + len(trigger) -1 252 | if trigger_start_labels[trigger_start_index]=="O": 253 | trigger_start_labels[trigger_start_index] = event_type 254 | else: 255 | trigger_start_labels[trigger_start_index] += (" "+ event_type) 256 | if trigger_end_labels[trigger_end_index]=="O": 257 | trigger_end_labels[trigger_end_index] = event_type 258 | else: 259 | trigger_end_labels[trigger_end_index] += (" "+ event_type) 260 | 261 | # role 262 | for event in row["event_list"]: 263 | event_type = event["event_type"] 264 | trigger = event["trigger"] 265 | trigger_start_index = event['trigger_start_index'] 266 | segment_ids= [0] * len(row["text"]) 267 | for i in range(trigger_start_index, trigger_start_index+ len(trigger) ): 268 | segment_ids[i] = 1 269 | 270 | role_start_labels = ['O']*len(row["text"]) 271 | role_end_labels = ['O']*len(row["text"]) 272 | 273 | for arg in event["arguments"]: 274 | role = arg['role'] 275 | role_id = label_map[role] 276 | argument = arg['argument'] 277 | argument_start_index = arg["argument_start_index"] 278 | argument_end_index = argument_start_index + len(argument) -1 279 | 280 | if role_start_labels[argument_start_index]=="O": 281 | role_start_labels[argument_start_index] = role 282 | else: 283 | role_start_labels[argument_start_index] += (" "+ role) 284 | 285 | if role_end_labels[argument_end_index]=="O": 286 | role_end_labels[argument_end_index] = role 287 | else: 288 | role_end_labels[argument_end_index] += (" "+ role) 289 | 290 | if arg['alias']!=[]: print(arg['alias']) 291 | results.append({"id":row["id"], "tokens":list(row["text"]), "segment_ids":segment_ids, \ 292 | "trigger_start_labels":trigger_start_labels, "trigger_end_labels":trigger_end_labels, \ 293 | "role_start_labels":role_start_labels, "role_end_labels":role_end_labels}) 294 | 295 | write_file(results,output_file) 296 | 297 | 298 | 299 | def role_process_filter(event_class, input_file, output_file, is_predict=False): 300 | rows = open(input_file, encoding='utf-8').read().splitlines() 301 | results = [] 302 | for row in rows: 303 | if len(row)==1: print(row) 304 | row = json.loads(row) 305 | labels = ['O']*len(row["text"]) 306 | if is_predict: continue 307 | flag = False 308 | for event in row["event_list"]: 309 | event_type = event["event_type"] 310 | if event_class != event["class"]: 311 | continue 312 | flag = True 313 | for arg in event["arguments"]: 314 | role = arg['role'] 315 | argument = arg['argument'] 316 | argument_start_index = arg["argument_start_index"] 317 | labels[argument_start_index]= "B-{}".format(role) 318 | for i in range(1, len(argument)): 319 | labels[argument_start_index+i]= "I-{}".format(role) 
320 | if not flag: continue 321 | results.append({"id":row["id"], "tokens":list(row["text"]), "labels":labels}) 322 | write_file(results,output_file) 323 | 324 | def get_event_class(schema_file): 325 | rows = open(schema_file, encoding='utf-8').read().splitlines() 326 | labels=[] 327 | for row in rows: 328 | row = json.loads(row) 329 | event_class = row["class"] 330 | if event_class in labels: 331 | continue 332 | labels.append(event_class) 333 | return labels 334 | 335 | def index_output_bio_trigger(test_file, prediction_file, output_file): 336 | tests = open(test_file, encoding='utf-8').read().splitlines() 337 | predictions = open(prediction_file, encoding='utf-8').read().splitlines() 338 | results = [] 339 | index = 0 340 | max_length = 256-2 341 | for test, prediction in zip(tests, predictions): 342 | index += 1 343 | test = json.loads(test) 344 | tokens = test.pop('tokens') 345 | test['text'] = ''.join(tokens) 346 | 347 | prediction = json.loads(prediction) 348 | labels = prediction["labels"] 349 | if len(labels)!=len(tokens) and len(labels) != max_length: 350 | print(labels, tokens) 351 | print(len(labels), len(tokens), index) 352 | break 353 | test["labels"] = labels 354 | 355 | results.append(test) 356 | write_file(results, output_file) 357 | 358 | def index_output_bin_trigger(test_file, prediction_file, output_file): 359 | tests = open(test_file, encoding='utf-8').read().splitlines() 360 | predictions = open(prediction_file, encoding='utf-8').read().splitlines() 361 | results = [] 362 | index = 0 363 | for test, prediction in zip(tests, predictions): 364 | index += 1 365 | test = json.loads(test) 366 | 367 | prediction = json.loads(prediction) 368 | labels = prediction["labels"] 369 | test["labels"] = labels 370 | 371 | results.append(test) 372 | write_file(results, output_file) 373 | 374 | def index_output_bio_arg(test_file, prediction_file, output_file): 375 | tests = open(test_file, encoding='utf-8').read().splitlines() 376 | predictions = open(prediction_file, encoding='utf-8').read().splitlines() 377 | results = [] 378 | index = 0 379 | max_length = 256-2 380 | for test, prediction in zip(tests, predictions): 381 | index += 1 382 | test = json.loads(test) 383 | tokens = test.pop('tokens') 384 | test['text'] = ''.join(tokens) 385 | 386 | prediction = json.loads(prediction) 387 | labels = prediction["labels"] 388 | if len(labels)!=len(tokens) and len(labels) != max_length: 389 | print(labels, tokens) 390 | print(len(labels), len(tokens), index) 391 | break 392 | 393 | args = extract_result(test["text"], labels) 394 | arguments = [] 395 | for arg in args: 396 | argument = {} 397 | argument["role"] = arg["type"] 398 | argument["argument_start_index"] = arg['start'] 399 | argument["argument"] =''.join(arg['text']) 400 | arguments.append(argument) 401 | 402 | test.pop("labels") 403 | test["arguments"] = arguments 404 | results.append(test) 405 | write_file(results, output_file) 406 | 407 | 408 | def index_output_segment_bin(test_file, prediction_file, output_file): 409 | label_list = get_labels(task='role', mode="classification") 410 | label_map = {i: label for i, label in enumerate(label_list)} 411 | 412 | tests = open(test_file, encoding='utf-8').read().splitlines() 413 | predictions = open(prediction_file, encoding='utf-8').read().splitlines() 414 | results = [] 415 | index = 0 416 | max_length = 256-2 417 | for test, prediction in zip(tests, predictions): 418 | index += 1 419 | test = json.loads(test) 420 | start_labels = test.pop('start_labels') 421 | end_labels = 
test.pop('end_labels')
422 | 
423 | tokens = test.pop('tokens')
424 | text = ''.join(tokens)
425 | test['text'] = text
426 | 
427 | segment_ids = test.pop('segment_ids')
428 | trigger = ''.join([tokens[i] for i in range(len(tokens)) if segment_ids[i]])
429 | for i in range(len(tokens)):
430 | if segment_ids[i]:
431 | trigger_start_index = i
432 | break
433 | 
434 | event = {}
435 | # event['trigger'] = trigger
436 | # event['trigger_start_index']= trigger_start_index
437 | event_type = test.pop("event_type")
438 | event["event_type"]=event_type
439 | 
440 | prediction = json.loads(prediction)
441 | arg_list = prediction["labels"]
442 | arguments =[]
443 | for arg in arg_list:
444 | sub_dict = {}
445 | argument_start_index = arg[1] -1
446 | argument_end_index = arg[2] -1
447 | argument = text[argument_start_index:argument_end_index+1]
448 | role = label_map[arg[3]]
449 | sub_dict["role"]=role
450 | sub_dict["argument"]=argument
451 | # sub_dict["argument_start_index"] = argument_start_index
452 | arguments.append(sub_dict)
453 | 
454 | event["arguments"]= arguments
455 | 
456 | test['event_list']= [event]
457 | results.append(test)
458 | write_file(results, output_file)
459 | 
460 | def index_output_bin_arg(test_file, prediction_file, output_file):
461 | label_list = get_labels(task='role', mode="classification")
462 | label_map = {i: label for i, label in enumerate(label_list)}
463 | 
464 | tests = open(test_file, encoding='utf-8').read().splitlines()
465 | predictions = open(prediction_file, encoding='utf-8').read().splitlines()
466 | results = []
467 | index = 0
468 | max_length = 256-2
469 | for test, prediction in zip(tests, predictions):
470 | index += 1
471 | test = json.loads(test)
472 | start_labels = test.pop('start_labels')
473 | end_labels = test.pop('end_labels')
474 | 
475 | tokens = test.pop('tokens')
476 | text = ''.join(tokens)
477 | test['text'] = text
478 | 
479 | prediction = json.loads(prediction)
480 | arg_list = prediction["labels"]
481 | arguments =[]
482 | for arg in arg_list:
483 | sub_dict = {}
484 | argument_start_index = arg[1] -1
485 | argument_end_index = arg[2] -1
486 | argument = text[argument_start_index:argument_end_index+1]
487 | role = label_map[arg[3]]
488 | sub_dict["role"]=role
489 | sub_dict["argument"]=argument
490 | sub_dict["argument_start_index"] = argument_start_index
491 | arguments.append(sub_dict)
492 | 
493 | test["arguments"]= arguments
494 | results.append(test)
495 | write_file(results, output_file)
496 | 
497 | 
498 | # un-finished
499 | # def binary_to_bio(test_file, prediction_file, output_file):
500 | # tests = open(test_file, encoding='utf-8').read().splitlines()
501 | # predictions = open(prediction_file, encoding='utf-8').read().splitlines()
502 | # results = []
503 | # for test,prediction in zip(tests, predictions):
504 | # test = json.loads(test)
505 | # tokens = test.pop('tokens')
506 | # test['text'] = ''.join(tokens)
507 | 
508 | # row_preds_list = json.loads(prediction)
509 | 
510 | # labels= ['O']*len(tokens)
511 | # # for pred in row_preds_list:
512 | 
513 | # test.update(prediction)
514 | 
515 | # results.append(test)
516 | # write_file(results, output_file)
517 | 
518 | # preprocessing helper that builds the ner_segment_bi input
519 | 
520 | def convert_bio_to_segment(input_file, output_file):
521 | lines = open(input_file, encoding='utf-8').read().splitlines()
522 | res = []
523 | for line in lines:
524 | line = json.loads(line)
525 | text = line["text"]
526 | labels = line["labels"]
527 | tokens = list(text)
528 | if len(labels)!=len(tokens):
529 | print(len(labels),
len(tokens)) 530 | 531 | triggers = extract_result(text, labels) 532 | if len(triggers)==0: 533 | print("detect no trigger") 534 | for trigger in triggers: 535 | event_type= trigger["type"] 536 | segment_ids = [0]*(len(tokens)) 537 | trigger_start_index = trigger['start'] 538 | trigger_end_index = trigger['start'] + len(trigger['text']) 539 | for i in range(trigger_start_index, trigger_end_index): 540 | segment_ids[i] = 1 541 | start_labels = ['O']*(len(tokens)) 542 | end_labels = ['O']*(len(tokens)) 543 | 544 | cur_line = {} 545 | cur_line["id"] = line["id"] 546 | cur_line["tokens"] = tokens 547 | cur_line["event_type"] = event_type 548 | cur_line["segment_ids"] = segment_ids 549 | cur_line["start_labels"] = start_labels 550 | cur_line["end_labels"] = end_labels 551 | res.append(cur_line) 552 | write_file(res, output_file) 553 | 554 | 555 | def convert_bio_to_label(input_file, output_file): 556 | lines = open(input_file, encoding='utf-8').read().splitlines() 557 | res = [] 558 | for line in lines: 559 | line_json = json.loads(line) 560 | labels = [] 561 | for label in line_json["labels"]: 562 | if label.startswith("B-") and label[2:] not in labels: 563 | labels.append(label[2:]) 564 | res.append({"labels":labels}) 565 | write_file(res, output_file) 566 | 567 | def compute_matric(label_file, pred_file): 568 | label_lines = open(label_file, encoding='utf-8').read().splitlines() 569 | pred_lines = open(pred_file, encoding='utf-8').read().splitlines() 570 | 571 | labels = [] 572 | for i, line in enumerate(label_lines): 573 | json_line = json.loads(line) 574 | for label in json_line['labels']: 575 | labels.append([i, label]) 576 | 577 | preds = [] 578 | for i, line in enumerate(pred_lines): 579 | json_line = json.loads(line) 580 | for label in json_line['labels']: 581 | preds.append([i, label]) 582 | 583 | nb_correct = 0 584 | for out_label in labels: 585 | if out_label in preds: 586 | nb_correct += 1 587 | continue 588 | nb_pred = len(preds) 589 | nb_true = len(labels) 590 | # print(nb_correct, nb_pred, nb_true) 591 | 592 | p = nb_correct / nb_pred if nb_pred > 0 else 0 593 | r = nb_correct / nb_true if nb_true > 0 else 0 594 | f1 = 2 * p * r / (p + r) if p + r > 0 else 0 595 | 596 | print(p, r, f1) 597 | 598 | 599 | 600 | def split_data(input_file, output_dir, num_split=5): 601 | datas = open(input_file, encoding='utf-8').read().splitlines() 602 | for i in range(num_split): 603 | globals()["train_data"+str(i+1)] = [] 604 | globals()["dev_data"+str(i+1)] = [] 605 | for i, data in enumerate(datas): 606 | cur = i % num_split + 1 607 | for j in range(num_split): 608 | if cur == j+1: 609 | globals()["dev_data" + str(j + 1)].append(json.loads(data)) 610 | else: 611 | globals()["train_data"+str(j + 1)].append(json.loads(data)) 612 | for i in range(num_split): 613 | cur_dir = os.path.join(output_dir, str(i)) 614 | if not os.path.exists(cur_dir): 615 | os.makedirs(cur_dir) 616 | write_file(globals()["train_data"+str(i + 1)], os.path.join(cur_dir, "train.json")) 617 | write_file(globals()["dev_data"+str(i + 1)], os.path.join(cur_dir, "dev.json")) 618 | 619 | 620 | if __name__ == '__main__': 621 | 622 | # trigger_classify_file_remove_id("./data/trigger_classify/dev.json", "./data/trigger_classify/dev_without_id.json") 623 | 624 | # split_data("./data/trigger_classify/train.json", "./data/trigger_classify", num_split=5) 625 | # split_data("./data/role_bin_train_dev/train_dev.json", "./data/role_bin_train_dev", num_split=7) 626 | 627 | # trigger_process_bio("./data/train_data/train.json", 
"./data/trigger/train.json") 628 | 629 | # trigger_classify_process("./data/train_data/train.json", "./data/trigger_classify/train.json") 630 | # trigger_classify_process("./data/dev_data/dev.json", "./data/trigger_classify/dev.json") 631 | # trigger_classify_process("./data/test2_data/test2.json", "./data/trigger_classify/test/test.json", is_predict=True) 632 | 633 | # trigger_process_binary("./data/train_data/train.json", "./data/trigger_bin/train.json") 634 | # trigger_process_binary("./data/dev_data/dev.json","./data/trigger_bin/dev.json") 635 | # trigger_process_binary("./data/test1_data/test1.json", "./data/trigger_bin/test.json",is_predict=True) 636 | 637 | # role_process_binary("./data/train_data/train.json", "./data/role_bin/train.json") 638 | # role_process_binary("./data/dev_data/dev.json","./data/role_bin/dev.json") 639 | # role_process_binary("./data/test2_data/test2.json", "./data/role_bin/test/test.json",is_predict=True) 640 | 641 | # role_process_segment("./data/train_data/train.json", "./data/role_segment/train.json") 642 | # role_process_segment("./data/dev_data/dev.json","./data/role_segment/dev.json") 643 | # role_process_segment("./data/test1_data/test1.json", "./data/role_segment/test.json",is_predict=True) 644 | 645 | # role_process_segment_binary("./data/train_data/train.json", "./data/role_segment_bin/train.json") 646 | # role_process_segment_binary("./data/dev_data/dev.json","./data/role_segment_bin/dev.json") 647 | 648 | # joint_process_binary("./data/train_data/train.json", "./data/joint_bin/train.json") 649 | # joint_process_binary("./data/dev_data/dev.json","./data/joint_bin/dev.json") 650 | # joint_process_binary("./data/test1_data/test1.json", "./data/joint_bin/test.json",is_predict=True) 651 | 652 | # event_class_list = get_event_class("./data/event_schema/event_schema.json") 653 | # for event_class in event_class_list: 654 | # if not os.path.exists("./data/role/{}".format(event_class)): 655 | # os.makedirs("./data/role/{}".format(event_class)) 656 | # role_process_filter(event_class, "./data/train_data/train.json", "./data/role/{}/train.json".format(event_class)) 657 | # role_process_filter(event_class, "./data/dev_data/dev.json","./data/role/{}/dev.json".format(event_class)) 658 | 659 | # index_output_bio_trigger("./data/trigger/dev.json" , "./output/trigger/checkpoint-best/eval_predictions.json","./output/trigger/checkpoint-best/eval_predictions_indexed.json" ) 660 | # index_output_bio_trigger("./data/trigger/test.json" , "./output/trigger/checkpoint-best/test_predictions.json","./output/trigger/checkpoint-best/test_predictions_indexed.json" ) 661 | 662 | # index_output_bin_trigger("./data/trigger_classify/dev.json" , "./output/trigger_classify/merge/eval_predictions_labels.json","./output/trigger_classify/merge/eval_predictions_indexed_labels.json" ) 663 | index_output_bin_trigger("./data/trigger_classify/test/test.json" , "./output/trigger_classify/merge/test2_predictions_labels.json","./output/trigger_classify/merge/test2_predictions_indexed_labels.json" ) 664 | 665 | # index_output_bio_arg("./data/role/dev.json" , "./output/role/checkpoint-best/eval_predictions.json","./output/role/checkpoint-best/eval_predictions_labels.json" ) 666 | # index_output_bio_arg("./data/role/test.json" , "./output/role/checkpoint-best/test_predictions.json","./output/role/checkpoint-best/test_predictions_indexed.json" ) 667 | 668 | # index_output_segment_bin("./data/role_segment_bin/dev.json" , 
"./output/role_segment_bin/checkpoint-best/eval_predictions.json","./output/role_segment_bin/checkpoint-best/eval_predictions_indexed.json" ) 669 | # index_output_segment_bin("./data/role_segment_bin/test.json" , "./output/role_segment_bin/checkpoint-best/test_predictions.json","./output/role_segment_bin/checkpoint-best/test_predictions_indexed.json" ) 670 | 671 | # index_output_bin_arg("./data/role_bin/dev.json" , "./output/role_bin/merge/eval_predictions_labels.json","./output/role_bin/merge/eval_predictions_indexed_labels.json" ) 672 | index_output_bin_arg("./data/role_bin/test/test.json" , "./output/role_bin/merge/test2_predictions_labels.json","./output/role_bin/merge/test2_predictions_indexed_labels.json" ) 673 | 674 | # convert_bio_to_segment("./output/trigger/checkpoint-best/test_predictions_indexed.json",\ 675 | # "./output/trigger/checkpoint-best/test_predictions_indexed_semgent_id.json") 676 | 677 | # convert_bio_to_label("./output/trigger/checkpoint-best/eval_predictions.json",\ 678 | # "./output/trigger/checkpoint-best/eval_predictions_labels.json") 679 | # compute_matric("./data/trigger_classify/dev.json", "./output/trigger/checkpoint-best/eval_predictions_labels.json") 680 | 681 | pass 682 | --------------------------------------------------------------------------------