├── .gitignore ├── README.md ├── datasets ├── __init__.py ├── collate_functions.py ├── mrc_ner_dataset.py ├── tagger_ner_dataset.py └── truncate_dataset.py ├── evaluate ├── mrc_ner_evaluate.py └── tagger_ner_evaluate.py ├── inference ├── mrc_ner_inference.py └── tagger_ner_inference.py ├── metrics ├── __init__.py ├── functional │ ├── __init__.py │ ├── query_span_f1.py │ └── tagger_span_f1.py ├── query_span_f1.py └── tagger_span_f1.py ├── models ├── __init__.py ├── bert_query_ner.py ├── bert_tagger.py ├── classifier.py └── model_config.py ├── ner2mrc ├── __init__.py ├── download.md ├── genia2mrc.py ├── msra2mrc.py └── queries │ ├── genia.json │ └── zh_msra.json ├── requirements.txt ├── scripts ├── bert_tagger │ ├── evaluate.sh │ ├── inference.sh │ └── reproduce │ │ ├── conll03.sh │ │ ├── msra.sh │ │ └── onto4.sh └── mrc_ner │ ├── evaluate.sh │ ├── flat_inference.sh │ ├── nested_inference.sh │ └── reproduce │ ├── ace04.sh │ ├── ace05.sh │ ├── conll03.sh │ ├── genia.sh │ ├── kbp17.sh │ ├── msra.sh │ ├── onto4.sh │ └── onto5.sh ├── tests ├── assert_correct_dataset.py ├── bert_tokenizer.py ├── collect_entity_labels.py ├── count_mrc_max_length.py ├── count_sequence_max_length.py ├── extract_entity_span.py └── illegal_entity_boundary.py ├── train ├── bert_tagger_trainer.py └── mrc_ner_trainer.py └── utils ├── __init__.py ├── bmes_decode.py ├── convert_tf2torch.sh ├── get_parser.py └── random_seed.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Xcode 2 | *.DS_Store 3 | 4 | # Logs 5 | logs/* 6 | 7 | # 8 | experiments/* 9 | log/* 10 | 11 | # Byte-compiled / optimized / DLL files 12 | __pycache__/ 13 | *.py[cod] 14 | *$py.class 15 | 16 | # C extensions 17 | *.so 18 | 19 | # backup files 20 | bk 21 | 22 | # Distribution / packaging 23 | .Python 24 | build/ 25 | develop-eggs/ 26 | dist/ 27 | downloads/ 28 | eggs/ 29 | .eggs/ 30 | lib/ 31 | lib64/ 32 | parts/ 33 | sdist/ 34 | var/ 35 | wheels/ 36 | *.egg-info/ 37 | .installed.cfg 38 | *.egg 39 | MANIFEST 40 | 41 | # PyInstaller 42 | # Usually these files are written by a python script from a template 43 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
44 | *.manifest 45 | *.spec 46 | 47 | # Installer logs 48 | pip-log.txt 49 | pip-delete-this-directory.txt 50 | 51 | # Unit test / coverage reports 52 | htmlcov/ 53 | .tox/ 54 | .coverage 55 | .coverage.* 56 | .cache 57 | nosetests.xml 58 | coverage.xml 59 | *.cover 60 | .hypothesis/ 61 | .pytest_cache/ 62 | 63 | # Translations 64 | *.mo 65 | *.pot 66 | 67 | # Django stuff: 68 | *.log 69 | local_settings.py 70 | db.sqlite3 71 | 72 | # Flask stuff: 73 | instance/ 74 | .webassets-cache 75 | 76 | # Scrapy stuff: 77 | .scrapy 78 | 79 | # Sphinx documentation 80 | docs/_build/ 81 | 82 | # PyBuilder 83 | target/ 84 | 85 | # Jupyter Notebook 86 | .ipynb_checkpoints 87 | 88 | # pyenv 89 | .python-version 90 | 91 | # celery beat schedule file 92 | celerybeat-schedule 93 | 94 | # SageMath parsed files 95 | *.sage.py 96 | 97 | # Environments 98 | .env 99 | .venv 100 | env/ 101 | venv/ 102 | ENV/ 103 | env.bak/ 104 | venv.bak/ 105 | 106 | # Spyder project settings 107 | .spyderproject 108 | .spyproject 109 | 110 | # Rope project settings 111 | .ropeproject 112 | 113 | # mkdocs documentation 114 | /site 115 | 116 | # mypy 117 | .mypy_cache/ 118 | 119 | # Do Not push origin intermediate logging files 120 | *.log 121 | *.out 122 | 123 | # mac book 124 | .DS_Store 125 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # A Unified MRC Framework for Named Entity Recognition 2 | The repository contains the code of the recent research advances in [Shannon.AI](http://www.shannonai.com). 3 | 4 | **A Unified MRC Framework for Named Entity Recognition**
5 | Xiaoya Li, Jingrong Feng, Yuxian Meng, Qinghong Han, Fei Wu and Jiwei Li
6 | In ACL 2020. [paper](https://arxiv.org/abs/1910.11476)
7 | If you find this repo helpful, please cite the following paper: 8 | ```latex 9 | @article{li2019unified, 10 | title={A Unified MRC Framework for Named Entity Recognition}, 11 | author={Li, Xiaoya and Feng, Jingrong and Meng, Yuxian and Han, Qinghong and Wu, Fei and Li, Jiwei}, 12 | journal={arXiv preprint arXiv:1910.11476}, 13 | year={2019} 14 | } 15 | ``` 16 | For any questions, please feel free to post GitHub issues.<br>
17 | 18 | ## Install Requirements 19 | 20 | * The code requires Python 3.6+. 21 | 22 | * If you are working on a GPU machine with CUDA 10.1, please run `pip install torch==1.7.1+cu101 torchvision==0.8.2+cu101 torchaudio==0.7.2 -f https://download.pytorch.org/whl/torch_stable.html` to install PyTorch. If not, please see the [PyTorch Official Website](https://pytorch.org/) for instructions. 23 | 24 | * Then run the following command to install the remaining dependencies: `pip install -r requirements.txt` 25 | 26 | We build our project on [pytorch-lightning](https://github.com/PyTorchLightning/pytorch-lightning). 27 | If you want to know more about the arguments used in our training scripts, please 28 | refer to the [pytorch-lightning documentation](https://pytorch-lightning.readthedocs.io/en/latest/). 29 | 30 | ### Baseline: BERT-Tagger 31 | 32 | We release code, [scripts](./scripts/bert_tagger/reproduce) and [datafiles](./ner2mrc/download.md) for fine-tuning BERT and treating NER as a sequence labeling task.<br>
33 | 34 | ### MRC-NER: Prepare Datasets 35 | 36 | You can [download](./ner2mrc/download.md) the preprocessed MRC-NER datasets used in our paper.
37 | For flat NER datasets, please use `ner2mrc/msra2mrc.py` to transform your BMES NER annotations to MRC-format.<br>
38 | For nested NER datasets, please use `ner2mrc/genia2mrc.py` to transform your start-end NER annotations to MRC-format.
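Each `mrc-ner.*` file produced by these converters is a JSON list, where every entry pairs one entity-type query with a context sentence. Below is a minimal illustrative sample (the sentence, query wording and entity are made up); the field names, the `sample_idx.label_idx` form of `qas_id`, and the word-level inclusive `start_position`/`end_position` indices for English data follow `ner2mrc/genia2mrc.py` and `datasets/mrc_ner_dataset.py`:

```json
[
  {
    "context": "Barack Obama visited Paris last week .",
    "query": "person entities are the names of people.",
    "start_position": [0],
    "end_position": [1],
    "qas_id": "0.1"
  }
]
```

Here the single gold span covers words 0 and 1 of the whitespace-tokenized context ("Barack Obama").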
39 | 40 | ### MRC-NER: Training 41 | 42 | The main training procedure is in `train/mrc_ner_trainer.py`. 43 | 44 | Scripts for reproducing our experimental results can be found in the `./scripts/mrc_ner/reproduce/` folder. 45 | Note that you need to change `DATA_DIR`, `BERT_DIR`, `OUTPUT_DIR` to your own dataset path, BERT model path and log path, respectively.<br>
46 | For example, running `./scripts/mrc_ner/reproduce/ace04.sh` will start training an MRC-NER model and save the intermediate training log to `$OUTPUT_DIR/train_log.txt`.<br>
47 | During training, the model trainer will automatically evaluate on the dev set every `val_check_interval` epochs, 48 | and save the topk checkpoints to `$OUTPUT_DIR`.
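Both of these knobs are pytorch-lightning features rather than something implemented in this repo: `val_check_interval` is a `Trainer` argument, and keeping the top-k checkpoints is handled by the `ModelCheckpoint` callback's `save_top_k`. The sketch below shows how the two fit together; argument names follow recent pytorch-lightning releases (the version used by this repo is older, so the exact wiring in `train/mrc_ner_trainer.py` and the reproduce scripts may differ), and the monitored metric name is an assumption:

```python
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint

# Keep only the k best checkpoints in $OUTPUT_DIR, ranked by a metric
# logged during dev evaluation ("span_f1" is an assumed metric name).
checkpoint_callback = ModelCheckpoint(
    dirpath="outputs/",       # i.e. $OUTPUT_DIR
    save_top_k=3,
    monitor="span_f1",
    mode="max",
)

trainer = pl.Trainer(
    max_epochs=20,
    val_check_interval=0.25,  # run dev evaluation 4 times per training epoch
    callbacks=[checkpoint_callback],
)
# trainer.fit(model) then trains, evaluates on dev, and checkpoints as described above.
```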
49 | 50 | ### MRC-NER: Evaluation 51 | 52 | After training, you can find the best checkpoint on the dev set according to the evaluation results in `$OUTPUT_DIR/train_log.txt`.
53 | Then run `python3 evaluate/mrc_ner_evaluate.py $OUTPUT_DIR/<best_ckpt_on_dev>.ckpt $OUTPUT_DIR/lightning_logs/<version>/hparams.yaml` to evaluate on the test set with the best checkpoint chosen on dev. 54 | 55 | ### MRC-NER: Inference 56 | 57 | Code for inference with a trained MRC-NER model can be found in `inference/mrc_ner_inference.py`.<br>
58 | For flat NER, we provide the inference script in [flat_inference.sh](./scripts/mrc_ner/flat_inference.sh)
59 | For nested NER, we provide the inference script in [nested_inference.sh](./scripts/mrc_ner/nested_inference.sh) -------------------------------------------------------------------------------- /datasets/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ShannonAI/mrc-for-flat-nested-ner/457b0759f7fd462d0abd0a23441726352716fff9/datasets/__init__.py -------------------------------------------------------------------------------- /datasets/collate_functions.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | 4 | # file: collate_functions.py 5 | 6 | import torch 7 | from typing import List 8 | 9 | 10 | def tagger_collate_to_max_length(batch: List[List[torch.Tensor]]) -> List[torch.Tensor]: 11 | """ 12 | pad to maximum length of this batch 13 | Args: 14 | batch: a batch of samples, each contains a list of field data(Tensor): 15 | tokens, token_type_ids, attention_mask, wordpiece_label_idx_lst 16 | Returns: 17 | output: list of field batched data, which shape is [batch, max_length] 18 | """ 19 | batch_size = len(batch) 20 | max_length = max(x[0].shape[0] for x in batch) 21 | output = [] 22 | 23 | for field_idx in range(3): 24 | # 0 -> tokens 25 | # 1 -> token_type_ids 26 | # 2 -> attention_mask 27 | pad_output = torch.full([batch_size, max_length], 0, dtype=batch[0][field_idx].dtype) 28 | for sample_idx in range(batch_size): 29 | data = batch[sample_idx][field_idx] 30 | pad_output[sample_idx][: data.shape[0]] = data 31 | output.append(pad_output) 32 | 33 | # 3 -> sequence_label 34 | # -100 is ignore_index in the cross-entropy loss function. 35 | pad_output = torch.full([batch_size, max_length], -100, dtype=batch[0][3].dtype) 36 | for sample_idx in range(batch_size): 37 | data = batch[sample_idx][3] 38 | pad_output[sample_idx][: data.shape[0]] = data 39 | output.append(pad_output) 40 | 41 | # 4 -> is word_piece_label 42 | pad_output = torch.full([batch_size, max_length], -100, dtype=batch[0][4].dtype) 43 | for sample_idx in range(batch_size): 44 | data = batch[sample_idx][4] 45 | pad_output[sample_idx][: data.shape[0]] = data 46 | output.append(pad_output) 47 | 48 | return output 49 | 50 | 51 | def collate_to_max_length(batch: List[List[torch.Tensor]]) -> List[torch.Tensor]: 52 | """ 53 | pad to maximum length of this batch 54 | Args: 55 | batch: a batch of samples, each contains a list of field data(Tensor): 56 | tokens, token_type_ids, start_labels, end_labels, start_label_mask, end_label_mask, match_labels, sample_idx, label_idx 57 | Returns: 58 | output: list of field batched data, which shape is [batch, max_length] 59 | """ 60 | batch_size = len(batch) 61 | max_length = max(x[0].shape[0] for x in batch) 62 | output = [] 63 | 64 | for field_idx in range(6): 65 | pad_output = torch.full([batch_size, max_length], 0, dtype=batch[0][field_idx].dtype) 66 | for sample_idx in range(batch_size): 67 | data = batch[sample_idx][field_idx] 68 | pad_output[sample_idx][: data.shape[0]] = data 69 | output.append(pad_output) 70 | 71 | pad_match_labels = torch.zeros([batch_size, max_length, max_length], dtype=torch.long) 72 | for sample_idx in range(batch_size): 73 | data = batch[sample_idx][6] 74 | pad_match_labels[sample_idx, : data.shape[1], : data.shape[1]] = data 75 | output.append(pad_match_labels) 76 | 77 | output.append(torch.stack([x[-2] for x in batch])) 78 | output.append(torch.stack([x[-1] for x in batch])) 79 | 80 | return 
output 81 | -------------------------------------------------------------------------------- /datasets/mrc_ner_dataset.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | 4 | # file: mrc_ner_dataset.py 5 | 6 | import json 7 | import torch 8 | from tokenizers import BertWordPieceTokenizer 9 | from torch.utils.data import Dataset 10 | 11 | 12 | class MRCNERDataset(Dataset): 13 | """ 14 | MRC NER Dataset 15 | Args: 16 | json_path: path to mrc-ner style json 17 | tokenizer: BertTokenizer 18 | max_length: int, max length of query+context 19 | possible_only: if True, only use possible samples that contain answer for the query/context 20 | is_chinese: is chinese dataset 21 | """ 22 | def __init__(self, json_path, tokenizer: BertWordPieceTokenizer, max_length: int = 512, possible_only=False, 23 | is_chinese=False, pad_to_maxlen=False): 24 | self.all_data = json.load(open(json_path, encoding="utf-8")) 25 | self.tokenizer = tokenizer 26 | self.max_length = max_length 27 | self.possible_only = possible_only 28 | if self.possible_only: 29 | self.all_data = [ 30 | x for x in self.all_data if x["start_position"] 31 | ] 32 | self.is_chinese = is_chinese 33 | self.pad_to_maxlen = pad_to_maxlen 34 | 35 | def __len__(self): 36 | return len(self.all_data) 37 | 38 | def __getitem__(self, item): 39 | """ 40 | Args: 41 | item: int, idx 42 | Returns: 43 | tokens: tokens of query + context, [seq_len] 44 | token_type_ids: token type ids, 0 for query, 1 for context, [seq_len] 45 | start_labels: start labels of NER in tokens, [seq_len] 46 | end_labels: end labelsof NER in tokens, [seq_len] 47 | label_mask: label mask, 1 for counting into loss, 0 for ignoring. [seq_len] 48 | match_labels: match labels, [seq_len, seq_len] 49 | sample_idx: sample id 50 | label_idx: label id 51 | """ 52 | data = self.all_data[item] 53 | tokenizer = self.tokenizer 54 | 55 | qas_id = data.get("qas_id", "0.0") 56 | sample_idx, label_idx = qas_id.split(".") 57 | sample_idx = torch.LongTensor([int(sample_idx)]) 58 | label_idx = torch.LongTensor([int(label_idx)]) 59 | 60 | query = data["query"] 61 | context = data["context"] 62 | start_positions = data["start_position"] 63 | end_positions = data["end_position"] 64 | 65 | if self.is_chinese: 66 | context = "".join(context.split()) 67 | end_positions = [x+1 for x in end_positions] 68 | else: 69 | # add space offsets 70 | words = context.split() 71 | start_positions = [x + sum([len(w) for w in words[:x]]) for x in start_positions] 72 | end_positions = [x + sum([len(w) for w in words[:x + 1]]) for x in end_positions] 73 | 74 | query_context_tokens = tokenizer.encode(query, context, add_special_tokens=True) 75 | tokens = query_context_tokens.ids 76 | type_ids = query_context_tokens.type_ids 77 | offsets = query_context_tokens.offsets 78 | 79 | # find new start_positions/end_positions, considering 80 | # 1. we add query tokens at the beginning 81 | # 2. 
word-piece tokenize 82 | origin_offset2token_idx_start = {} 83 | origin_offset2token_idx_end = {} 84 | for token_idx in range(len(tokens)): 85 | # skip query tokens 86 | if type_ids[token_idx] == 0: 87 | continue 88 | token_start, token_end = offsets[token_idx] 89 | # skip [CLS] or [SEP] 90 | if token_start == token_end == 0: 91 | continue 92 | origin_offset2token_idx_start[token_start] = token_idx 93 | origin_offset2token_idx_end[token_end] = token_idx 94 | 95 | new_start_positions = [origin_offset2token_idx_start[start] for start in start_positions] 96 | new_end_positions = [origin_offset2token_idx_end[end] for end in end_positions] 97 | 98 | label_mask = [ 99 | (0 if type_ids[token_idx] == 0 or offsets[token_idx] == (0, 0) else 1) 100 | for token_idx in range(len(tokens)) 101 | ] 102 | start_label_mask = label_mask.copy() 103 | end_label_mask = label_mask.copy() 104 | 105 | # the start/end position must be whole word 106 | if not self.is_chinese: 107 | for token_idx in range(len(tokens)): 108 | current_word_idx = query_context_tokens.words[token_idx] 109 | next_word_idx = query_context_tokens.words[token_idx+1] if token_idx+1 < len(tokens) else None 110 | prev_word_idx = query_context_tokens.words[token_idx-1] if token_idx-1 > 0 else None 111 | if prev_word_idx is not None and current_word_idx == prev_word_idx: 112 | start_label_mask[token_idx] = 0 113 | if next_word_idx is not None and current_word_idx == next_word_idx: 114 | end_label_mask[token_idx] = 0 115 | 116 | assert all(start_label_mask[p] != 0 for p in new_start_positions) 117 | assert all(end_label_mask[p] != 0 for p in new_end_positions) 118 | 119 | assert len(new_start_positions) == len(new_end_positions) == len(start_positions) 120 | assert len(label_mask) == len(tokens) 121 | start_labels = [(1 if idx in new_start_positions else 0) 122 | for idx in range(len(tokens))] 123 | end_labels = [(1 if idx in new_end_positions else 0) 124 | for idx in range(len(tokens))] 125 | 126 | # truncate 127 | tokens = tokens[: self.max_length] 128 | type_ids = type_ids[: self.max_length] 129 | start_labels = start_labels[: self.max_length] 130 | end_labels = end_labels[: self.max_length] 131 | start_label_mask = start_label_mask[: self.max_length] 132 | end_label_mask = end_label_mask[: self.max_length] 133 | 134 | # make sure last token is [SEP] 135 | sep_token = tokenizer.token_to_id("[SEP]") 136 | if tokens[-1] != sep_token: 137 | assert len(tokens) == self.max_length 138 | tokens = tokens[: -1] + [sep_token] 139 | start_labels[-1] = 0 140 | end_labels[-1] = 0 141 | start_label_mask[-1] = 0 142 | end_label_mask[-1] = 0 143 | 144 | if self.pad_to_maxlen: 145 | tokens = self.pad(tokens, 0) 146 | type_ids = self.pad(type_ids, 1) 147 | start_labels = self.pad(start_labels) 148 | end_labels = self.pad(end_labels) 149 | start_label_mask = self.pad(start_label_mask) 150 | end_label_mask = self.pad(end_label_mask) 151 | 152 | seq_len = len(tokens) 153 | match_labels = torch.zeros([seq_len, seq_len], dtype=torch.long) 154 | for start, end in zip(new_start_positions, new_end_positions): 155 | if start >= seq_len or end >= seq_len: 156 | continue 157 | match_labels[start, end] = 1 158 | 159 | return [ 160 | torch.LongTensor(tokens), 161 | torch.LongTensor(type_ids), 162 | torch.LongTensor(start_labels), 163 | torch.LongTensor(end_labels), 164 | torch.LongTensor(start_label_mask), 165 | torch.LongTensor(end_label_mask), 166 | match_labels, 167 | sample_idx, 168 | label_idx 169 | ] 170 | 171 | def pad(self, lst, value=0, max_length=None): 172 | 
max_length = max_length or self.max_length 173 | while len(lst) < max_length: 174 | lst.append(value) 175 | return lst 176 | 177 | 178 | def run_dataset(): 179 | """test dataset""" 180 | import os 181 | from datasets.collate_functions import collate_to_max_length 182 | from torch.utils.data import DataLoader 183 | # zh datasets 184 | bert_path = "/data/nfsdata/nlp/BERT_BASE_DIR/chinese_L-12_H-768_A-12" 185 | vocab_file = os.path.join(bert_path, "vocab.txt") 186 | # json_path = "/mnt/mrc/zh_msra/mrc-ner.test" 187 | json_path = "/data/xiaoya/datasets/mrc_ner/zh_msra/mrc-ner.train" 188 | is_chinese = True 189 | 190 | # en datasets 191 | # bert_path = "/mnt/mrc/bert-base-uncased" 192 | # json_path = "/mnt/mrc/ace2004/mrc-ner.train" 193 | # json_path = "/mnt/mrc/genia/mrc-ner.train" 194 | # is_chinese = False 195 | 196 | vocab_file = os.path.join(bert_path, "vocab.txt") 197 | tokenizer = BertWordPieceTokenizer(vocab_file) 198 | dataset = MRCNERDataset(json_path=json_path, tokenizer=tokenizer, 199 | is_chinese=is_chinese) 200 | 201 | dataloader = DataLoader(dataset, batch_size=1, 202 | collate_fn=collate_to_max_length) 203 | 204 | for batch in dataloader: 205 | for tokens, token_type_ids, start_labels, end_labels, start_label_mask, end_label_mask, match_labels, sample_idx, label_idx in zip(*batch): 206 | tokens = tokens.tolist() 207 | start_positions, end_positions = torch.where(match_labels > 0) 208 | start_positions = start_positions.tolist() 209 | end_positions = end_positions.tolist() 210 | print(start_labels.numpy().tolist()) 211 | 212 | tmp_start_position = [] 213 | for tmp_idx, tmp_label in enumerate(start_labels.numpy().tolist()): 214 | if tmp_label != 0: 215 | tmp_start_position.append(tmp_idx) 216 | 217 | tmp_end_position = [] 218 | for tmp_idx, tmp_label in enumerate(end_labels.numpy().tolist()): 219 | if tmp_label != 0: 220 | tmp_end_position.append(tmp_idx) 221 | 222 | if not start_positions: 223 | continue 224 | print("="*20) 225 | print(f"len: {len(tokens)}", tokenizer.decode(tokens, skip_special_tokens=False)) 226 | for start, end in zip(start_positions, end_positions): 227 | print(str(sample_idx.item()), str(label_idx.item()) + "\t" + tokenizer.decode(tokens[start: end+1])) 228 | 229 | print("!!!"*20) 230 | for start, end in zip(tmp_start_position, tmp_end_position): 231 | print(str(sample_idx.item()), str(label_idx.item()) + "\t" + tokenizer.decode(tokens[start: end + 1])) 232 | 233 | 234 | if __name__ == '__main__': 235 | run_dataset() 236 | -------------------------------------------------------------------------------- /datasets/tagger_ner_dataset.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | 4 | # file: tagger_ner_dataset.py 5 | 6 | import torch 7 | from transformers import AutoTokenizer 8 | from torch.utils.data import Dataset 9 | 10 | 11 | def get_labels(data_sign): 12 | """gets the list of labels for this data set.""" 13 | if data_sign == "zh_onto": 14 | return ["O", "S-GPE", "B-GPE", "M-GPE", "E-GPE", 15 | "S-LOC", "B-LOC", "M-LOC", "E-LOC", 16 | "S-PER", "B-PER", "M-PER", "E-PER", 17 | "S-ORG", "B-ORG", "M-ORG", "E-ORG",] 18 | elif data_sign == "zh_msra": 19 | return ["O", "S-NS", "B-NS", "M-NS", "E-NS", 20 | "S-NR", "B-NR", "M-NR", "E-NR", 21 | "S-NT", "B-NT", "M-NT", "E-NT"] 22 | elif data_sign == "en_onto": 23 | return ["O", "S-LAW", "B-LAW", "M-LAW", "E-LAW", 24 | "S-EVENT", "B-EVENT", "M-EVENT", "E-EVENT", 25 | "S-CARDINAL", "B-CARDINAL", "M-CARDINAL", "E-CARDINAL", 
26 | "S-FAC", "B-FAC", "M-FAC", "E-FAC", 27 | "S-TIME", "B-TIME", "M-TIME", "E-TIME", 28 | "S-DATE", "B-DATE", "M-DATE", "E-DATE", 29 | "S-ORDINAL", "B-ORDINAL", "M-ORDINAL", "E-ORDINAL", 30 | "S-ORG", "B-ORG", "M-ORG", "E-ORG", 31 | "S-QUANTITY", "B-QUANTITY", "M-QUANTITY", "E-QUANTITY", 32 | "S-PERCENT", "B-PERCENT", "M-PERCENT", "E-PERCENT", 33 | "S-WORK_OF_ART", "B-WORK_OF_ART", "M-WORK_OF_ART", "E-WORK_OF_ART", 34 | "S-LOC", "B-LOC", "M-LOC", "E-LOC", 35 | "S-LANGUAGE", "B-LANGUAGE", "M-LANGUAGE", "E-LANGUAGE", 36 | "S-NORP", "B-NORP", "M-NORP", "E-NORP", 37 | "S-MONEY", "B-MONEY", "M-MONEY", "E-MONEY", 38 | "S-PERSON", "B-PERSON", "M-PERSON", "E-PERSON", 39 | "S-GPE", "B-GPE", "M-GPE", "E-GPE", 40 | "S-PRODUCT", "B-PRODUCT", "M-PRODUCT", "E-PRODUCT"] 41 | elif data_sign == "en_conll03": 42 | return ["O", "S-ORG", "B-ORG", "M-ORG", "E-ORG", 43 | "S-PER", "B-PER", "M-PER", "E-PER", 44 | "S-LOC", "B-LOC", "M-LOC", "E-LOC", 45 | "S-MISC", "B-MISC", "M-MISC", "E-MISC"] 46 | return ["0", "1"] 47 | 48 | 49 | def load_data_in_conll(data_path): 50 | """ 51 | Desc: 52 | load data in conll format 53 | Returns: 54 | [([word1, word2, word3, word4], [label1, label2, label3, label4]), 55 | ([word5, word6, word7, wordd8], [label5, label6, label7, label8])] 56 | """ 57 | dataset = [] 58 | with open(data_path, "r", encoding="utf-8") as f: 59 | datalines = f.readlines() 60 | sentence, labels = [], [] 61 | 62 | for line in datalines: 63 | line = line.strip() 64 | if len(line) == 0: 65 | dataset.append((sentence, labels)) 66 | sentence, labels = [], [] 67 | else: 68 | word, tag = line.split(" ") 69 | sentence.append(word) 70 | labels.append(tag) 71 | return dataset 72 | 73 | 74 | class TaggerNERDataset(Dataset): 75 | """ 76 | MRC NER Dataset 77 | Args: 78 | data_path: path to Conll-style named entity dadta file. 
79 | tokenizer: BertTokenizer 80 | max_length: int, max length of query+context 81 | is_chinese: is chinese dataset 82 | Note: 83 | https://github.com/huggingface/transformers/blob/143738214cb83e471f3a43652617c8881370342c/examples/pytorch/token-classification/run_ner.py#L362 84 | https://github.com/huggingface/transformers/blob/143738214cb83e471f3a43652617c8881370342c/src/transformers/models/bert/modeling_bert.py#L1739 85 | """ 86 | def __init__(self, data_path, tokenizer: AutoTokenizer, dataset_signature, max_length: int = 512, 87 | is_chinese=False, pad_to_maxlen=False, tagging_schema="BMESO", ): 88 | self.all_data = load_data_in_conll(data_path) 89 | self.tokenizer = tokenizer 90 | self.max_length = max_length 91 | self.is_chinese = is_chinese 92 | self.pad_to_maxlen = pad_to_maxlen 93 | self.pad_idx = 0 94 | self.cls_idx = 101 95 | self.sep_idx = 102 96 | self.label2idx = {label_item: label_idx for label_idx, label_item in enumerate(get_labels(dataset_signature))} 97 | 98 | def __len__(self): 99 | return len(self.all_data) 100 | 101 | def __getitem__(self, item): 102 | data = self.all_data[item] 103 | token_lst, label_lst = tuple(data) 104 | wordpiece_token_lst, wordpiece_label_lst = [], [] 105 | 106 | for token_item, label_item in zip(token_lst, label_lst): 107 | tmp_token_lst = self.tokenizer.encode(token_item, add_special_tokens=False, return_token_type_ids=None) 108 | if len(tmp_token_lst) == 1: 109 | wordpiece_token_lst.append(tmp_token_lst[0]) 110 | wordpiece_label_lst.append(label_item) 111 | else: 112 | len_wordpiece = len(tmp_token_lst) 113 | wordpiece_token_lst.extend(tmp_token_lst) 114 | tmp_label_lst = [label_item] + [-100 for idx in range((len_wordpiece - 1))] 115 | wordpiece_label_lst.extend(tmp_label_lst) 116 | 117 | if len(wordpiece_token_lst) > self.max_length - 2: 118 | wordpiece_token_lst = wordpiece_token_lst[: self.max_length-2] 119 | wordpiece_label_lst = wordpiece_label_lst[: self.max_length-2] 120 | 121 | wordpiece_token_lst = [self.cls_idx] + wordpiece_token_lst + [self.sep_idx] 122 | wordpiece_label_lst = [-100] + wordpiece_label_lst + [-100] 123 | # token_type_ids: segment token indices to indicate first and second portions of the inputs. 124 | # - 0 corresponds to a "sentence a" token 125 | # - 1 corresponds to a "sentence b" token 126 | token_type_ids = [0] * len(wordpiece_token_lst) 127 | # attention_mask: mask to avoid performing attention on padding token indices. 128 | # - 1 for tokens that are not masked. 129 | # - 0 for tokens that are masked. 
130 | attention_mask = [1] * len(wordpiece_token_lst) 131 | is_wordpiece_mask = [1 if label_item != -100 else -100 for label_item in wordpiece_label_lst] 132 | wordpiece_label_idx_lst = [self.label2idx[label_item] if label_item != -100 else -100 for label_item in wordpiece_label_lst] 133 | 134 | return [torch.tensor(wordpiece_token_lst, dtype=torch.long), 135 | torch.tensor(token_type_ids, dtype=torch.long), 136 | torch.tensor(attention_mask, dtype=torch.long), 137 | torch.tensor(wordpiece_label_idx_lst, dtype=torch.long), 138 | torch.tensor(is_wordpiece_mask, dtype=torch.long)] 139 | 140 | 141 | -------------------------------------------------------------------------------- /datasets/truncate_dataset.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | 4 | # file: truncate_dataset.py 5 | 6 | from torch.utils.data import Dataset 7 | 8 | 9 | class TruncateDataset(Dataset): 10 | """Truncate dataset to certain num""" 11 | def __init__(self, dataset: Dataset, max_num: int = 100): 12 | self.dataset = dataset 13 | self.max_num = min(max_num, len(self.dataset)) 14 | 15 | def __len__(self): 16 | return self.max_num 17 | 18 | def __getitem__(self, item): 19 | return self.dataset[item] 20 | 21 | def __getattr__(self, item): 22 | """other dataset func""" 23 | return getattr(self.dataset, item) 24 | -------------------------------------------------------------------------------- /evaluate/mrc_ner_evaluate.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | 4 | # file: mrc_ner_evaluate.py 5 | # example command: 6 | # python3 mrc_ner_evaluate.py /data/xiaoya/outputs/mrc_ner/ace2004/debug_lr3e-5_drop0.3_norm1.0_weight0.1_warmup0_maxlen100/epoch=0.ckpt \ 7 | # /data/xiaoya/outputs/mrc_ner/ace2004/debug_lr3e-5_drop0.3_norm1.0_weight0.1_warmup0_maxlen100/lightning_logs/version_2/hparams.yaml 8 | 9 | import sys 10 | from pytorch_lightning import Trainer 11 | from train.mrc_ner_trainer import BertLabeling 12 | from utils.random_seed import set_random_seed 13 | 14 | set_random_seed(0) 15 | 16 | 17 | def evaluate(ckpt, hparams_file, gpus=[0, 1], max_length=300): 18 | trainer = Trainer(gpus=gpus, distributed_backend="dp") 19 | 20 | model = BertLabeling.load_from_checkpoint( 21 | checkpoint_path=ckpt, 22 | hparams_file=hparams_file, 23 | map_location=None, 24 | batch_size=1, 25 | max_length=max_length, 26 | workers=0 27 | ) 28 | trainer.test(model=model) 29 | 30 | 31 | if __name__ == '__main__': 32 | # example of running evaluate.py 33 | # CHECKPOINTS = "/mnt/mrc/train_logs/zh_msra/zh_msra_20200911_for_flat_debug/epoch=2_v1.ckpt" 34 | # HPARAMS = "/mnt/mrc/train_logs/zh_msra/zh_msra_20200911_for_flat_debug/lightning_logs/version_2/hparams.yaml" 35 | # GPUS="1,2,3" 36 | CHECKPOINTS = sys.argv[1] 37 | HPARAMS = sys.argv[2] 38 | try: 39 | GPUS = [int(gpu_item) for gpu_item in sys.argv[3].strip().split(",")] 40 | except: 41 | GPUS = [0] 42 | 43 | try: 44 | MAXLEN = int(sys.argv[4]) 45 | except: 46 | MAXLEN = 300 47 | 48 | evaluate(ckpt=CHECKPOINTS, hparams_file=HPARAMS, gpus=GPUS, max_length=MAXLEN) 49 | -------------------------------------------------------------------------------- /evaluate/tagger_ner_evaluate.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | 4 | # file: tagger_ner_evaluate.py 5 | # example command: 6 | 7 | 8 | import sys 9 
| from pytorch_lightning import Trainer 10 | from train.bert_tagger_trainer import BertSequenceLabeling 11 | from utils.random_seed import set_random_seed 12 | 13 | set_random_seed(0) 14 | 15 | 16 | def evaluate(ckpt, hparams_file, gpus=[0, 1], max_length=300): 17 | trainer = Trainer(gpus=gpus, distributed_backend="dp") 18 | 19 | model = BertSequenceLabeling.load_from_checkpoint( 20 | checkpoint_path=ckpt, 21 | hparams_file=hparams_file, 22 | map_location=None, 23 | batch_size=1, 24 | max_length=max_length, 25 | workers=0 26 | ) 27 | trainer.test(model=model) 28 | 29 | 30 | if __name__ == '__main__': 31 | # example of running evaluate.py 32 | # CHECKPOINTS = "/mnt/mrc/train_logs/zh_msra/zh_msra_20200911_for_flat_debug/epoch=2_v1.ckpt" 33 | # HPARAMS = "/mnt/mrc/train_logs/zh_msra/zh_msra_20200911_for_flat_debug/lightning_logs/version_2/hparams.yaml" 34 | # GPUS="1,2,3" 35 | CHECKPOINTS = sys.argv[1] 36 | HPARAMS = sys.argv[2] 37 | 38 | try: 39 | GPUS = [int(gpu_item) for gpu_item in sys.argv[3].strip().split(",")] 40 | except: 41 | GPUS = [0] 42 | 43 | try: 44 | MAXLEN = int(sys.argv[4]) 45 | except: 46 | MAXLEN = 300 47 | 48 | evaluate(ckpt=CHECKPOINTS, hparams_file=HPARAMS, gpus=GPUS, max_length=MAXLEN, ) 49 | -------------------------------------------------------------------------------- /inference/mrc_ner_inference.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | 4 | # file: mrc_ner_inference.py 5 | 6 | import os 7 | import torch 8 | import argparse 9 | from torch.utils.data import DataLoader 10 | from utils.random_seed import set_random_seed 11 | set_random_seed(0) 12 | from train.mrc_ner_trainer import BertLabeling 13 | from tokenizers import BertWordPieceTokenizer 14 | from datasets.mrc_ner_dataset import MRCNERDataset 15 | from metrics.functional.query_span_f1 import extract_flat_spans, extract_nested_spans 16 | 17 | def get_dataloader(config, data_prefix="test"): 18 | data_path = os.path.join(config.data_dir, f"mrc-ner.{data_prefix}") 19 | vocab_path = os.path.join(config.bert_dir, "vocab.txt") 20 | data_tokenizer = BertWordPieceTokenizer(vocab_path) 21 | 22 | dataset = MRCNERDataset(json_path=data_path, 23 | tokenizer=data_tokenizer, 24 | max_length=config.max_length, 25 | is_chinese=config.is_chinese, 26 | pad_to_maxlen=False) 27 | 28 | dataloader = DataLoader(dataset=dataset, batch_size=1, shuffle=False) 29 | 30 | return dataloader, data_tokenizer 31 | 32 | def get_query_index_to_label_cate(dataset_sign): 33 | # NOTICE: need change if you use other datasets. 
34 | # please notice it should in line with the mrc-ner.test/train/dev json file 35 | if dataset_sign == "conll03": 36 | return {1: "ORG", 2: "PER", 3: "LOC", 4: "MISC"} 37 | elif dataset_sign == "ace04": 38 | return {1: "GPE", 2: "ORG", 3: "PER", 4: "FAC", 5: "VEH", 6: "LOC", 7: "WEA"} 39 | 40 | 41 | def get_parser() -> argparse.ArgumentParser: 42 | parser = argparse.ArgumentParser(description="inference the model output.") 43 | parser.add_argument("--data_dir", type=str, required=True) 44 | parser.add_argument("--bert_dir", type=str, required=True) 45 | parser.add_argument("--max_length", type=int, default=256) 46 | parser.add_argument("--is_chinese", action="store_true") 47 | parser.add_argument("--model_ckpt", type=str, default="") 48 | parser.add_argument("--hparams_file", type=str, default="") 49 | parser.add_argument("--flat_ner", action="store_true",) 50 | parser.add_argument("--dataset_sign", type=str, choices=["ontonotes4", "msra", "conll03", "ace04", "ace05"], default="conll03") 51 | 52 | return parser 53 | 54 | 55 | def main(): 56 | parser = get_parser() 57 | args = parser.parse_args() 58 | trained_mrc_ner_model = BertLabeling.load_from_checkpoint( 59 | checkpoint_path=args.model_ckpt, 60 | hparams_file=args.hparams_file, 61 | map_location=None, 62 | batch_size=1, 63 | max_length=args.max_length, 64 | workers=0) 65 | 66 | data_loader, data_tokenizer = get_dataloader(args,) 67 | # load token 68 | vocab_path = os.path.join(args.bert_dir, "vocab.txt") 69 | with open(vocab_path, "r") as f: 70 | subtokens = [token.strip() for token in f.readlines()] 71 | idx2tokens = {} 72 | for token_idx, token in enumerate(subtokens): 73 | idx2tokens[token_idx] = token 74 | 75 | query2label_dict = get_query_index_to_label_cate(args.dataset_sign) 76 | 77 | for batch in data_loader: 78 | tokens, token_type_ids, start_labels, end_labels, start_label_mask, end_label_mask, match_labels, sample_idx, label_idx = batch 79 | attention_mask = (tokens != 0).long() 80 | 81 | start_logits, end_logits, span_logits = trained_mrc_ner_model.model(tokens, attention_mask=attention_mask, token_type_ids=token_type_ids) 82 | start_preds, end_preds, span_preds = start_logits > 0, end_logits > 0, span_logits > 0 83 | 84 | subtokens_idx_lst = tokens.numpy().tolist()[0] 85 | subtokens_lst = [idx2tokens[item] for item in subtokens_idx_lst] 86 | label_cate = query2label_dict[label_idx.item()] 87 | readable_input_str = data_tokenizer.decode(subtokens_idx_lst, skip_special_tokens=True) 88 | 89 | if args.flat_ner: 90 | entities_info = extract_flat_spans(torch.squeeze(start_preds), torch.squeeze(end_preds), 91 | torch.squeeze(span_preds), torch.squeeze(attention_mask), pseudo_tag=label_cate) 92 | entity_lst = [] 93 | 94 | if len(entities_info) != 0: 95 | for entity_info in entities_info: 96 | start, end = entity_info[0], entity_info[1] 97 | entity_string = " ".join(subtokens_lst[start: end]) 98 | entity_string = entity_string.replace(" ##", "") 99 | entity_lst.append((start, end, entity_string, entity_info[2])) 100 | 101 | else: 102 | match_preds = span_logits > 0 103 | entities_info = extract_nested_spans(start_preds, end_preds, match_preds, start_label_mask, end_label_mask, pseudo_tag=label_cate) 104 | 105 | entity_lst = [] 106 | 107 | if len(entities_info) != 0: 108 | for entity_info in entities_info: 109 | start, end = entity_info[0], entity_info[1] 110 | entity_string = " ".join(subtokens_lst[start: end+1 ]) 111 | entity_string = entity_string.replace(" ##", "") 112 | entity_lst.append((start, end+1, entity_string, 
entity_info[2])) 113 | 114 | print("*="*10) 115 | print(f"Given input: {readable_input_str}") 116 | print(f"Model predict: {entity_lst}") 117 | # entity_lst is a list of (subtoken_start_pos, subtoken_end_pos, substring, entity_type) 118 | 119 | if __name__ == "__main__": 120 | main() -------------------------------------------------------------------------------- /inference/tagger_ner_inference.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | 4 | # file: tagger_ner_inference.py 5 | 6 | import os 7 | import torch 8 | import argparse 9 | from torch.utils.data import DataLoader 10 | from utils.random_seed import set_random_seed 11 | set_random_seed(0) 12 | from train.bert_tagger_trainer import BertSequenceLabeling 13 | from transformers import AutoTokenizer 14 | from datasets.tagger_ner_dataset import get_labels 15 | from datasets.tagger_ner_dataset import TaggerNERDataset 16 | from metrics.functional.tagger_span_f1 import get_entity_from_bmes_lst, transform_predictions_to_labels 17 | 18 | 19 | def get_dataloader(config, data_prefix="test"): 20 | data_path = os.path.join(config.data_dir, f"{data_prefix}{config.data_file_suffix}") 21 | data_tokenizer = AutoTokenizer.from_pretrained(config.bert_dir, use_fast=False, do_lower_case=config.do_lowercase) 22 | 23 | dataset = TaggerNERDataset(data_path, data_tokenizer, config.dataset_sign, 24 | max_length=config.max_length, is_chinese=config.is_chinese, pad_to_maxlen=False) 25 | 26 | dataloader = DataLoader(dataset=dataset, batch_size=1, shuffle=False) 27 | 28 | return dataloader, data_tokenizer 29 | 30 | 31 | def get_parser() -> argparse.ArgumentParser: 32 | parser = argparse.ArgumentParser(description="inference the model output.") 33 | parser.add_argument("--data_dir", type=str, required=True) 34 | parser.add_argument("--bert_dir", type=str, required=True) 35 | parser.add_argument("--max_length", type=int, default=256) 36 | parser.add_argument("--is_chinese", action="store_true") 37 | parser.add_argument("--model_ckpt", type=str, default="") 38 | parser.add_argument("--hparams_file", type=str, default="") 39 | parser.add_argument("--do_lowercase", action="store_true") 40 | parser.add_argument("--data_file_suffix", type=str, default=".word.bmes") 41 | parser.add_argument("--dataset_sign", type=str, choices=["en_onto", "en_conll03", "zh_onto", "zh_msra" ], default="en_onto") 42 | 43 | return parser 44 | 45 | 46 | def main(): 47 | parser = get_parser() 48 | args = parser.parse_args() 49 | 50 | trained_tagger_ner_model = BertSequenceLabeling.load_from_checkpoint( 51 | checkpoint_path=args.model_ckpt, 52 | hparams_file=args.hparams_file, 53 | map_location=None, 54 | batch_size=1, 55 | max_length=args.max_length, 56 | workers=0) 57 | 58 | entity_label_lst = get_labels(args.dataset_sign) 59 | task_idx2label = {label_idx: label_item for label_idx, label_item in enumerate(entity_label_lst)} 60 | 61 | data_loader, data_tokenizer = get_dataloader(args) 62 | vocab_path = os.path.join(args.bert_dir, "vocab.txt") 63 | # load token 64 | vocab_path = os.path.join(args.bert_dir, "vocab.txt") 65 | with open(vocab_path, "r") as f: 66 | subtokens = [token.strip() for token in f.readlines()] 67 | idx2tokens = {} 68 | for token_idx, token in enumerate(subtokens): 69 | idx2tokens[token_idx] = token 70 | 71 | for batch in data_loader: 72 | token_input_ids, token_type_ids, attention_mask, sequence_labels, is_wordpiece_mask = batch 73 | batch_size = token_input_ids.shape[0] 74 | 
logits = trained_tagger_ner_model.model(token_input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask) 75 | 76 | sequence_pred_lst = transform_predictions_to_labels(logits.view(batch_size, -1, len(entity_label_lst)), 77 | is_wordpiece_mask, task_idx2label, input_type="logit") 78 | batch_subtokens_idx_lst = token_input_ids.numpy().tolist()[0] 79 | batch_subtokens_lst = [idx2tokens[item] for item in batch_subtokens_idx_lst] 80 | readable_input_str = data_tokenizer.decode(batch_subtokens_idx_lst, skip_special_tokens=True) 81 | 82 | batch_entity_lst = [get_entity_from_bmes_lst(label_lst_item) for label_lst_item in sequence_pred_lst] 83 | 84 | pred_entity_lst = [] 85 | 86 | for entity_lst, subtoken_lst in zip(batch_entity_lst, batch_subtokens_lst): 87 | if len(entity_lst) != 0: 88 | # example of entity_lst: 89 | # ['[0,3]PER', '[6,9]ORG', '[10]PER'] 90 | for entity_info in entity_lst: 91 | if "," in entity_info: 92 | inter_pos = entity_info.find(",") 93 | start_pos = 1 94 | end_pos = entity_info.find("]") 95 | start_idx = int(entity_info[start_pos: inter_pos]) 96 | end_idx = int(entity_info[inter_pos+1: end_pos]) 97 | else: 98 | start_pos = 1 99 | end_pos = entity_info.find("]") 100 | start_idx = int(entity_info[start_pos:end_pos]) 101 | end_idx = int(entity_info[start_pos:end_pos]) 102 | 103 | entity_tokens = subtoken_lst[start_idx: end_idx] 104 | entity_string = " ".join(entity_tokens) 105 | entity_string = entity_string.replace(" ##", "") 106 | # append start, end 107 | pred_entity_lst.append((entity_string, entity_info[end_pos+1:])) 108 | else: 109 | pred_entity_lst.append([]) 110 | 111 | print("*=" * 10) 112 | print(f"Given input: {readable_input_str}") 113 | print(f"Model predict: {pred_entity_lst}") 114 | 115 | 116 | 117 | if __name__ == "__main__": 118 | main() 119 | -------------------------------------------------------------------------------- /metrics/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ShannonAI/mrc-for-flat-nested-ner/457b0759f7fd462d0abd0a23441726352716fff9/metrics/__init__.py -------------------------------------------------------------------------------- /metrics/functional/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ShannonAI/mrc-for-flat-nested-ner/457b0759f7fd462d0abd0a23441726352716fff9/metrics/functional/__init__.py -------------------------------------------------------------------------------- /metrics/functional/query_span_f1.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | 4 | # file: query_span_f1.py 5 | 6 | import torch 7 | import numpy as np 8 | from utils.bmes_decode import bmes_decode 9 | 10 | 11 | def query_span_f1(start_preds, end_preds, match_logits, start_label_mask, end_label_mask, match_labels, flat=False): 12 | """ 13 | Compute span f1 according to query-based model output 14 | Args: 15 | start_preds: [bsz, seq_len] 16 | end_preds: [bsz, seq_len] 17 | match_logits: [bsz, seq_len, seq_len] 18 | start_label_mask: [bsz, seq_len] 19 | end_label_mask: [bsz, seq_len] 20 | match_labels: [bsz, seq_len, seq_len] 21 | flat: if True, decode as flat-ner 22 | Returns: 23 | span-f1 counts, tensor of shape [3]: tp, fp, fn 24 | """ 25 | start_label_mask = start_label_mask.bool() 26 | end_label_mask = end_label_mask.bool() 27 | match_labels = match_labels.bool() 28 | bsz, seq_len = 
start_label_mask.size() 29 | # [bsz, seq_len, seq_len] 30 | match_preds = match_logits > 0 31 | # [bsz, seq_len] 32 | start_preds = start_preds.bool() 33 | # [bsz, seq_len] 34 | end_preds = end_preds.bool() 35 | 36 | match_preds = (match_preds 37 | & start_preds.unsqueeze(-1).expand(-1, -1, seq_len) 38 | & end_preds.unsqueeze(1).expand(-1, seq_len, -1)) 39 | match_label_mask = (start_label_mask.unsqueeze(-1).expand(-1, -1, seq_len) 40 | & end_label_mask.unsqueeze(1).expand(-1, seq_len, -1)) 41 | match_label_mask = torch.triu(match_label_mask, 0) # start should be less or equal to end 42 | match_preds = match_label_mask & match_preds 43 | 44 | tp = (match_labels & match_preds).long().sum() 45 | fp = (~match_labels & match_preds).long().sum() 46 | fn = (match_labels & ~match_preds).long().sum() 47 | return torch.stack([tp, fp, fn]) 48 | 49 | 50 | def extract_nested_spans(start_preds, end_preds, match_preds, start_label_mask, end_label_mask, pseudo_tag="TAG"): 51 | start_label_mask = start_label_mask.bool() 52 | end_label_mask = end_label_mask.bool() 53 | bsz, seq_len = start_label_mask.size() 54 | start_preds = start_preds.bool() 55 | end_preds = end_preds.bool() 56 | 57 | match_preds = (match_preds & start_preds.unsqueeze(-1).expand(-1, -1, seq_len) & end_preds.unsqueeze(1).expand(-1, seq_len, -1)) 58 | match_label_mask = (start_label_mask.unsqueeze(-1).expand(-1, -1, seq_len) & end_label_mask.unsqueeze(1).expand(-1, seq_len, -1)) 59 | match_label_mask = torch.triu(match_label_mask, 0) # start should be less or equal to end 60 | match_preds = match_label_mask & match_preds 61 | match_pos_pairs = np.transpose(np.nonzero(match_preds.numpy())).tolist() 62 | return [(pos[0], pos[1], pseudo_tag) for pos in match_pos_pairs] 63 | 64 | 65 | def extract_flat_spans(start_pred, end_pred, match_pred, label_mask, pseudo_tag = "TAG"): 66 | """ 67 | Extract flat-ner spans from start/end/match logits 68 | Args: 69 | start_pred: [seq_len], 1/True for start, 0/False for non-start 70 | end_pred: [seq_len, 2], 1/True for end, 0/False for non-end 71 | match_pred: [seq_len, seq_len], 1/True for match, 0/False for non-match 72 | label_mask: [seq_len], 1 for valid boundary. 
73 | Returns: 74 | tags: list of tuple (start, end) 75 | Examples: 76 | >>> start_pred = [0, 1] 77 | >>> end_pred = [0, 1] 78 | >>> match_pred = [[0, 0], [0, 1]] 79 | >>> label_mask = [1, 1] 80 | >>> extract_flat_spans(start_pred, end_pred, match_pred, label_mask) 81 | [(1, 2)] 82 | """ 83 | pseudo_input = "a" 84 | 85 | bmes_labels = ["O"] * len(start_pred) 86 | start_positions = [idx for idx, tmp in enumerate(start_pred) if tmp and label_mask[idx]] 87 | end_positions = [idx for idx, tmp in enumerate(end_pred) if tmp and label_mask[idx]] 88 | 89 | for start_item in start_positions: 90 | bmes_labels[start_item] = f"B-{pseudo_tag}" 91 | for end_item in end_positions: 92 | bmes_labels[end_item] = f"E-{pseudo_tag}" 93 | 94 | for tmp_start in start_positions: 95 | tmp_end = [tmp for tmp in end_positions if tmp >= tmp_start] 96 | if len(tmp_end) == 0: 97 | continue 98 | else: 99 | tmp_end = min(tmp_end) 100 | if match_pred[tmp_start][tmp_end]: 101 | if tmp_start != tmp_end: 102 | for i in range(tmp_start+1, tmp_end): 103 | bmes_labels[i] = f"M-{pseudo_tag}" 104 | else: 105 | bmes_labels[tmp_end] = f"S-{pseudo_tag}" 106 | 107 | tags = bmes_decode([(pseudo_input, label) for label in bmes_labels]) 108 | 109 | return [(entity.begin, entity.end, entity.tag) for entity in tags] 110 | 111 | 112 | def remove_overlap(spans): 113 | """ 114 | remove overlapped spans greedily for flat-ner 115 | Args: 116 | spans: list of tuple (start, end), which means [start, end] is a ner-span 117 | Returns: 118 | spans without overlap 119 | """ 120 | output = [] 121 | occupied = set() 122 | for start, end in spans: 123 | if any(x for x in range(start, end+1)) in occupied: 124 | continue 125 | output.append((start, end)) 126 | for x in range(start, end + 1): 127 | occupied.add(x) 128 | return output 129 | -------------------------------------------------------------------------------- /metrics/functional/tagger_span_f1.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | 4 | # file: tagger_span_f1.py 5 | 6 | import torch 7 | import torch.nn.functional as F 8 | 9 | 10 | def transform_predictions_to_labels(sequence_input_lst, wordpiece_mask, idx2label_map, input_type="logit"): 11 | """ 12 | shape: 13 | sequence_input_lst: [batch_size, seq_len, num_labels] 14 | wordpiece_mask: [batch_size, seq_len, ] 15 | """ 16 | wordpiece_mask = wordpiece_mask.detach().cpu().numpy().tolist() 17 | if input_type == "logit": 18 | label_sequence = torch.argmax(F.softmax(sequence_input_lst, dim=2), dim=2).detach().cpu().numpy().tolist() 19 | elif input_type == "prob": 20 | label_sequence = torch.argmax(sequence_input_lst, dim=2).detach().cpu().numpy().tolist() 21 | elif input_type == "label": 22 | label_sequence = sequence_input_lst.detach().cpu().numpy().tolist() 23 | else: 24 | raise ValueError 25 | output_label_sequence = [] 26 | for tmp_idx_lst, tmp_label_lst in enumerate(label_sequence): 27 | tmp_wordpiece_mask = wordpiece_mask[tmp_idx_lst] 28 | tmp_label_seq = [] 29 | for tmp_idx, tmp_label in enumerate(tmp_label_lst): 30 | if tmp_wordpiece_mask[tmp_idx] != -100: 31 | tmp_label_seq.append(idx2label_map[tmp_label]) 32 | else: 33 | tmp_label_seq.append(-100) 34 | output_label_sequence.append(tmp_label_seq) 35 | return output_label_sequence 36 | 37 | 38 | def compute_tagger_span_f1(sequence_pred_lst, sequence_gold_lst): 39 | sum_true_positive, sum_false_positive, sum_false_negative = 0, 0, 0 40 | 41 | for seq_pred_item, seq_gold_item in 
zip(sequence_pred_lst, sequence_gold_lst): 42 | gold_entity_lst = get_entity_from_bmes_lst(seq_gold_item) 43 | pred_entity_lst = get_entity_from_bmes_lst(seq_pred_item) 44 | 45 | true_positive_item, false_positive_item, false_negative_item = count_confusion_matrix(pred_entity_lst, gold_entity_lst) 46 | sum_true_positive += true_positive_item 47 | sum_false_negative += false_negative_item 48 | sum_false_positive += false_positive_item 49 | 50 | batch_confusion_matrix = torch.tensor([sum_true_positive, sum_false_positive, sum_false_negative], dtype=torch.long) 51 | return batch_confusion_matrix 52 | 53 | 54 | def count_confusion_matrix(pred_entities, gold_entities): 55 | true_positive, false_positive, false_negative = 0, 0, 0 56 | for span_item in pred_entities: 57 | if span_item in gold_entities: 58 | true_positive += 1 59 | gold_entities.remove(span_item) 60 | else: 61 | false_positive += 1 62 | # these entities are not predicted. 63 | for span_item in gold_entities: 64 | false_negative += 1 65 | return true_positive, false_positive, false_negative 66 | 67 | 68 | def get_entity_from_bmes_lst(label_list): 69 | """reuse the code block from 70 | https://github.com/jiesutd/NCRFpp/blob/105a53a321eca9c1280037c473967858e01aaa43/utils/metric.py#L73 71 | Many thanks to Jie Yang. 72 | """ 73 | list_len = len(label_list) 74 | begin_label = 'B-' 75 | end_label = 'E-' 76 | single_label = 'S-' 77 | whole_tag = '' 78 | index_tag = '' 79 | tag_list = [] 80 | stand_matrix = [] 81 | for i in range(0, list_len): 82 | if label_list[i] != -100: 83 | current_label = label_list[i].upper() 84 | else: 85 | continue 86 | 87 | if begin_label in current_label: 88 | if index_tag != '': 89 | tag_list.append(whole_tag + ',' + str(i-1)) 90 | whole_tag = current_label.replace(begin_label,"",1) +'[' +str(i) 91 | index_tag = current_label.replace(begin_label,"",1) 92 | elif single_label in current_label: 93 | if index_tag != '': 94 | tag_list.append(whole_tag + ',' + str(i-1)) 95 | whole_tag = current_label.replace(single_label,"",1) +'[' +str(i) 96 | tag_list.append(whole_tag) 97 | whole_tag = "" 98 | index_tag = "" 99 | elif end_label in current_label: 100 | if index_tag != '': 101 | tag_list.append(whole_tag +',' + str(i)) 102 | whole_tag = '' 103 | index_tag = '' 104 | else: 105 | continue 106 | if (whole_tag != '')&(index_tag != ''): 107 | tag_list.append(whole_tag) 108 | tag_list_len = len(tag_list) 109 | 110 | for i in range(0, tag_list_len): 111 | if len(tag_list[i]) > 0: 112 | tag_list[i] = tag_list[i]+ ']' 113 | insert_list = reverse_style(tag_list[i]) 114 | stand_matrix.append(insert_list) 115 | return stand_matrix 116 | 117 | 118 | def reverse_style(input_string): 119 | target_position = input_string.index('[') 120 | input_len = len(input_string) 121 | output_string = input_string[target_position:input_len] + input_string[0:target_position] 122 | return output_string 123 | 124 | -------------------------------------------------------------------------------- /metrics/query_span_f1.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | 4 | # file: query_span_f1.py 5 | 6 | from pytorch_lightning.metrics.metric import TensorMetric 7 | from metrics.functional.query_span_f1 import query_span_f1 8 | 9 | 10 | class QuerySpanF1(TensorMetric): 11 | """ 12 | Query Span F1 13 | Args: 14 | flat: is flat-ner 15 | """ 16 | def __init__(self, reduce_group=None, reduce_op=None, flat=False): 17 | super(QuerySpanF1, 
self).__init__(name="query_span_f1", 18 | reduce_group=reduce_group, 19 | reduce_op=reduce_op) 20 | self.flat = flat 21 | 22 | def forward(self, start_preds, end_preds, match_logits, start_label_mask, end_label_mask, match_labels): 23 | return query_span_f1(start_preds, end_preds, match_logits, start_label_mask, end_label_mask, match_labels, 24 | flat=self.flat) 25 | -------------------------------------------------------------------------------- /metrics/tagger_span_f1.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | 4 | # file: tagger_span_f1.py 5 | 6 | from pytorch_lightning.metrics.metric import TensorMetric 7 | from metrics.functional.tagger_span_f1 import compute_tagger_span_f1 8 | 9 | 10 | class TaggerSpanF1(TensorMetric): 11 | def __init__(self, reduce_group=None, reduce_op=None): 12 | super(TaggerSpanF1, self).__init__(name="tagger_span_f1", reduce_group=reduce_group, reduce_op=reduce_op) 13 | 14 | def forward(self, sequence_pred_lst, sequence_gold_lst): 15 | return compute_tagger_span_f1(sequence_pred_lst, sequence_gold_lst) 16 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ShannonAI/mrc-for-flat-nested-ner/457b0759f7fd462d0abd0a23441726352716fff9/models/__init__.py -------------------------------------------------------------------------------- /models/bert_query_ner.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | 4 | # file: bert_query_ner.py 5 | 6 | import torch 7 | import torch.nn as nn 8 | from transformers import BertModel, BertPreTrainedModel 9 | 10 | from models.classifier import MultiNonLinearClassifier 11 | 12 | 13 | class BertQueryNER(BertPreTrainedModel): 14 | def __init__(self, config): 15 | super(BertQueryNER, self).__init__(config) 16 | self.bert = BertModel(config) 17 | 18 | self.start_outputs = nn.Linear(config.hidden_size, 1) 19 | self.end_outputs = nn.Linear(config.hidden_size, 1) 20 | self.span_embedding = MultiNonLinearClassifier(config.hidden_size * 2, 1, config.mrc_dropout, 21 | intermediate_hidden_size=config.classifier_intermediate_hidden_size) 22 | 23 | self.hidden_size = config.hidden_size 24 | 25 | self.init_weights() 26 | 27 | def forward(self, input_ids, token_type_ids=None, attention_mask=None): 28 | """ 29 | Args: 30 | input_ids: bert input tokens, tensor of shape [seq_len] 31 | token_type_ids: 0 for query, 1 for context, tensor of shape [seq_len] 32 | attention_mask: attention mask, tensor of shape [seq_len] 33 | Returns: 34 | start_logits: start/non-start probs of shape [seq_len] 35 | end_logits: end/non-end probs of shape [seq_len] 36 | match_logits: start-end-match probs of shape [seq_len, 1] 37 | """ 38 | 39 | bert_outputs = self.bert(input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask) 40 | 41 | sequence_heatmap = bert_outputs[0] # [batch, seq_len, hidden] 42 | batch_size, seq_len, hid_size = sequence_heatmap.size() 43 | 44 | start_logits = self.start_outputs(sequence_heatmap).squeeze(-1) # [batch, seq_len, 1] 45 | end_logits = self.end_outputs(sequence_heatmap).squeeze(-1) # [batch, seq_len, 1] 46 | 47 | # for every position $i$ in sequence, should concate $j$ to 48 | # predict if $i$ and $j$ are start_pos and end_pos for an entity. 
49 | # [batch, seq_len, seq_len, hidden] 50 | start_extend = sequence_heatmap.unsqueeze(2).expand(-1, -1, seq_len, -1) 51 | # [batch, seq_len, seq_len, hidden] 52 | end_extend = sequence_heatmap.unsqueeze(1).expand(-1, seq_len, -1, -1) 53 | # [batch, seq_len, seq_len, hidden*2] 54 | span_matrix = torch.cat([start_extend, end_extend], 3) 55 | # [batch, seq_len, seq_len] 56 | span_logits = self.span_embedding(span_matrix).squeeze(-1) 57 | 58 | return start_logits, end_logits, span_logits 59 | -------------------------------------------------------------------------------- /models/bert_tagger.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | 4 | # file: bert_tagger.py 5 | # 6 | 7 | import torch.nn as nn 8 | from transformers import BertModel, BertPreTrainedModel 9 | from models.classifier import BERTTaggerClassifier 10 | 11 | 12 | class BertTagger(BertPreTrainedModel): 13 | def __init__(self, config): 14 | super(BertTagger, self).__init__(config) 15 | self.bert = BertModel(config) 16 | 17 | self.num_labels = config.num_labels 18 | self.hidden_size = config.hidden_size 19 | self.dropout = nn.Dropout(config.hidden_dropout_prob) 20 | if config.classifier_sign == "multi_nonlinear": 21 | self.classifier = BERTTaggerClassifier(self.hidden_size, self.num_labels, 22 | config.classifier_dropout, 23 | act_func=config.classifier_act_func, 24 | intermediate_hidden_size=config.classifier_intermediate_hidden_size) 25 | else: 26 | self.classifier = nn.Linear(self.hidden_size, self.num_labels) 27 | 28 | self.init_weights() 29 | 30 | def forward(self, input_ids, token_type_ids=None, attention_mask=None,): 31 | last_bert_layer, pooled_output = self.bert(input_ids, token_type_ids, attention_mask) 32 | last_bert_layer = last_bert_layer.view(-1, self.hidden_size) 33 | last_bert_layer = self.dropout(last_bert_layer) 34 | logits = self.classifier(last_bert_layer) 35 | return logits -------------------------------------------------------------------------------- /models/classifier.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | 4 | # file: classifier.py 5 | 6 | import torch.nn as nn 7 | from torch.nn import functional as F 8 | 9 | 10 | class SingleLinearClassifier(nn.Module): 11 | def __init__(self, hidden_size, num_label): 12 | super(SingleLinearClassifier, self).__init__() 13 | self.num_label = num_label 14 | self.classifier = nn.Linear(hidden_size, num_label) 15 | 16 | def forward(self, input_features): 17 | features_output = self.classifier(input_features) 18 | return features_output 19 | 20 | 21 | class MultiNonLinearClassifier(nn.Module): 22 | def __init__(self, hidden_size, num_label, dropout_rate, act_func="gelu", intermediate_hidden_size=None): 23 | super(MultiNonLinearClassifier, self).__init__() 24 | self.num_label = num_label 25 | self.intermediate_hidden_size = hidden_size if intermediate_hidden_size is None else intermediate_hidden_size 26 | self.classifier1 = nn.Linear(hidden_size, self.intermediate_hidden_size) 27 | self.classifier2 = nn.Linear(self.intermediate_hidden_size, self.num_label) 28 | self.dropout = nn.Dropout(dropout_rate) 29 | self.act_func = act_func 30 | 31 | def forward(self, input_features): 32 | features_output1 = self.classifier1(input_features) 33 | if self.act_func == "gelu": 34 | features_output1 = F.gelu(features_output1) 35 | elif self.act_func == "relu": 36 | features_output1 = 
F.relu(features_output1) 37 | elif self.act_func == "tanh": 38 | features_output1 = F.tanh(features_output1) 39 | else: 40 | raise ValueError 41 | features_output1 = self.dropout(features_output1) 42 | features_output2 = self.classifier2(features_output1) 43 | return features_output2 44 | 45 | 46 | class BERTTaggerClassifier(nn.Module): 47 | def __init__(self, hidden_size, num_label, dropout_rate, act_func="gelu", intermediate_hidden_size=None): 48 | super(BERTTaggerClassifier, self).__init__() 49 | self.num_label = num_label 50 | self.intermediate_hidden_size = hidden_size if intermediate_hidden_size is None else intermediate_hidden_size 51 | self.classifier1 = nn.Linear(hidden_size, self.intermediate_hidden_size) 52 | self.classifier2 = nn.Linear(self.intermediate_hidden_size, self.num_label) 53 | self.dropout = nn.Dropout(dropout_rate) 54 | self.act_func = act_func 55 | 56 | def forward(self, input_features): 57 | features_output1 = self.classifier1(input_features) 58 | if self.act_func == "gelu": 59 | features_output1 = F.gelu(features_output1) 60 | elif self.act_func == "relu": 61 | features_output1 = F.relu(features_output1) 62 | elif self.act_func == "tanh": 63 | features_output1 = F.tanh(features_output1) 64 | else: 65 | raise ValueError 66 | features_output1 = self.dropout(features_output1) 67 | features_output2 = self.classifier2(features_output1) 68 | return features_output2 69 | -------------------------------------------------------------------------------- /models/model_config.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | 4 | # file: model_config.py 5 | 6 | from transformers import BertConfig 7 | 8 | 9 | class BertQueryNerConfig(BertConfig): 10 | def __init__(self, **kwargs): 11 | super(BertQueryNerConfig, self).__init__(**kwargs) 12 | self.mrc_dropout = kwargs.get("mrc_dropout", 0.1) 13 | self.classifier_intermediate_hidden_size = kwargs.get("classifier_intermediate_hidden_size", 1024) 14 | self.classifier_act_func = kwargs.get("classifier_act_func", "gelu") 15 | 16 | class BertTaggerConfig(BertConfig): 17 | def __init__(self, **kwargs): 18 | super(BertTaggerConfig, self).__init__(**kwargs) 19 | self.num_labels = kwargs.get("num_labels", 6) 20 | self.classifier_dropout = kwargs.get("classifier_dropout", 0.1) 21 | self.classifier_sign = kwargs.get("classifier_sign", "multi_nonlinear") 22 | self.classifier_act_func = kwargs.get("classifier_act_func", "gelu") 23 | self.classifier_intermediate_hidden_size = kwargs.get("classifier_intermediate_hidden_size", 1024) -------------------------------------------------------------------------------- /ner2mrc/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ShannonAI/mrc-for-flat-nested-ner/457b0759f7fd462d0abd0a23441726352716fff9/ner2mrc/__init__.py -------------------------------------------------------------------------------- /ner2mrc/download.md: -------------------------------------------------------------------------------- 1 | ## Download MRC-style NER Datasets 2 | ZH: 3 | - [MSRA](https://drive.google.com/file/d/1bAoSJfT1IBdpbQWSrZPjQPPbAsDGlN2D/view?usp=sharing) 4 | - [OntoNotes4](https://drive.google.com/file/d/1CRVgZJDDGuj0O1NLK5DgujQBTLKyMR-g/view?usp=sharing) 5 | 6 | EN: 7 | - [CoNLL03](https://drive.google.com/file/d/1mGO9CYkgXsV-Et-hSZpOmS0m9G8A5mau/view?usp=sharing) 8 | - 
[ACE2004](https://drive.google.com/file/d/1U-hGOgLmdqudsRdKIGles1-QrNJ7SSg6/view?usp=sharing) 9 | - [ACE2005](https://drive.google.com/file/d/1iodaJ92dTAjUWnkMyYm8aLEi5hj3cseY/view?usp=sharing) 10 | - [GENIA](https://drive.google.com/file/d/1oF1P8s-0MN9X1M1PlKB2c5aBtxhmoxXb/view?usp=sharing) 11 | 12 | ## Download Sequence Labeling NER Datasets (Based on BMES tagging schema) 13 | ZH: 14 | - [MSRA](https://drive.google.com/file/d/1ytrfNgh53la7bdSNyI1QURAyJ1PniPCe/view?usp=sharing) 15 | - [OntoNotes4](https://drive.google.com/file/d/1FVg3XcW1eaqlikU36df5xl7DC3UNBJDm/view?usp=sharing) 16 | 17 | EN: 18 | - [CoNLL03](https://drive.google.com/file/d/1PUH2uw6lkWrWGfl-9wOAG13lvPrvKO25/view?usp=sharing) 19 | -------------------------------------------------------------------------------- /ner2mrc/genia2mrc.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | 4 | # file: genia2mrc.py 5 | 6 | import os 7 | import json 8 | 9 | 10 | def convert_file(input_file, output_file, tag2query_file): 11 | """ 12 | Convert GENIA data to MRC format 13 | """ 14 | all_data = json.load(open(input_file)) 15 | tag2query = json.load(open(tag2query_file)) 16 | 17 | output = [] 18 | origin_count = 0 19 | new_count = 0 20 | 21 | for data in all_data: 22 | origin_count += 1 23 | context = data["context"] 24 | label2positions = data["label"] 25 | for tag_idx, (tag, query) in enumerate(tag2query.items()): 26 | positions = label2positions.get(tag, []) 27 | mrc_sample = { 28 | "context": context, 29 | "query": query, 30 | "start_position": [int(x.split(";")[0]) for x in positions], 31 | "end_position": [int(x.split(";")[1]) for x in positions], 32 | "qas_id": f"{origin_count}.{tag_idx}" 33 | } 34 | output.append(mrc_sample) 35 | new_count += 1 36 | 37 | json.dump(output, open(output_file, "w"), ensure_ascii=False, indent=2) 38 | print(f"Convert {origin_count} samples to {new_count} samples and save to {output_file}") 39 | 40 | 41 | def main(): 42 | genia_raw_dir = "/mnt/mrc/genia/genia_raw" 43 | genia_mrc_dir = "/mnt/mrc/genia/genia_raw/mrc_format" 44 | tag2query_file = "queries/genia.json" 45 | os.makedirs(genia_mrc_dir, exist_ok=True) 46 | for phase in ["train", "dev", "test"]: 47 | old_file = os.path.join(genia_raw_dir, f"{phase}.genia.json") 48 | new_file = os.path.join(genia_mrc_dir, f"mrc-ner.{phase}") 49 | convert_file(old_file, new_file, tag2query_file) 50 | 51 | 52 | if __name__ == '__main__': 53 | main() 54 | -------------------------------------------------------------------------------- /ner2mrc/msra2mrc.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | 4 | # file: msra2mrc.py 5 | 6 | import os 7 | from utils.bmes_decode import bmes_decode 8 | import json 9 | 10 | 11 | def convert_file(input_file, output_file, tag2query_file): 12 | """ 13 | Convert MSRA raw data to MRC format 14 | """ 15 | origin_count = 0 16 | new_count = 0 17 | tag2query = json.load(open(tag2query_file)) 18 | mrc_samples = [] 19 | with open(input_file) as fin: 20 | for line in fin: 21 | line = line.strip() 22 | if not line: 23 | continue 24 | origin_count += 1 25 | src, labels = line.split("\t") 26 | tags = bmes_decode(char_label_list=[(char, label) for char, label in zip(src.split(), labels.split())]) 27 | for label, query in tag2query.items(): 28 | mrc_samples.append( 29 | { 30 | "context": src, 31 | "start_position": [tag.begin for tag in tags if tag.tag == 
label], 32 | "end_position": [tag.end-1 for tag in tags if tag.tag == label], 33 | "query": query 34 | } 35 | ) 36 | new_count += 1 37 | 38 | json.dump(mrc_samples, open(output_file, "w"), ensure_ascii=False, sort_keys=True, indent=2) 39 | print(f"Convert {origin_count} samples to {new_count} samples and save to {output_file}") 40 | 41 | 42 | def main(): 43 | msra_raw_dir = "/mnt/mrc/zh_msra_yuxian" 44 | msra_mrc_dir = "/mnt/mrc/zh_msra_yuxian/mrc_format" 45 | tag2query_file = "queries/zh_msra.json" 46 | os.makedirs(msra_mrc_dir, exist_ok=True) 47 | for phase in ["train", "dev", "test"]: 48 | old_file = os.path.join(msra_raw_dir, f"{phase}.tsv") 49 | new_file = os.path.join(msra_mrc_dir, f"mrc-ner.{phase}") 50 | convert_file(old_file, new_file, tag2query_file) 51 | 52 | 53 | if __name__ == '__main__': 54 | main() 55 | -------------------------------------------------------------------------------- /ner2mrc/queries/genia.json: -------------------------------------------------------------------------------- 1 | { 2 | "DNA": "deoxyribonucleic acid", 3 | "RNA": "ribonucleic acid", 4 | "cell_line": "cell line", 5 | "cell_type": "cell type", 6 | "protein": "protein entities are limited to nitrogenous organic compounds and are parts of all living organisms, as structural components of body tissues such as muscle, hair, collagen and as enzymes and antibodies." 7 | } 8 | -------------------------------------------------------------------------------- /ner2mrc/queries/zh_msra.json: -------------------------------------------------------------------------------- 1 | { 2 | "NR": "人名和虚构的人物形象", 3 | "NS": "按照地理位置划分的国家,城市,乡镇,大洲", 4 | "NT": "组织包括公司,政府党派,学校,政府,新闻机构" 5 | } -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | pytorch-lightning==0.9.0 2 | tokenizers==0.9.3 3 | transformers==3.5.1 -------------------------------------------------------------------------------- /scripts/bert_tagger/evaluate.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | # -*- coding: utf-8 -*- 3 | 4 | # file: evaluate.sh 5 | 6 | 7 | REPO_PATH=/home/lixiaoya/mrc-for-flat-nested-ner 8 | export PYTHONPATH="$PYTHONPATH:$REPO_PATH" 9 | 10 | OUTPUT_DIR=/data/lixiaoya/outputs/mrc_ner_baseline/0909/onto5_bert_tagger10_lr2e-5_drop_norm1.0_weight_warmup0.01_maxlen512 11 | # find best checkpoint on dev in ${OUTPUT_DIR}/train_log.txt 12 | BEST_CKPT_DEV=${OUTPUT_DIR}/epoch=25_v1.ckpt 13 | PYTORCHLIGHT_HPARAMS=${OUTPUT_DIR}/lightning_logs/version_0/hparams.yaml 14 | GPU_ID=0,1 15 | MAX_LEN=220 16 | 17 | python3 ${REPO_PATH}/evaluate/tagger_ner_evaluate.py ${BEST_CKPT_DEV} ${PYTORCHLIGHT_HPARAMS} ${GPU_ID} ${MAX_LEN} -------------------------------------------------------------------------------- /scripts/bert_tagger/inference.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | # -*- coding: utf-8 -*- 3 | 4 | # file: inference.sh 5 | 6 | 7 | REPO_PATH=/home/lixiaoya/mrc-for-flat-nested-ner 8 | export PYTHONPATH="$PYTHONPATH:$REPO_PATH" 9 | 10 | DATA_SIGN=en_onto 11 | DATA_DIR=/data/lixiaoya/datasets/bmes_ner/en_ontonotes5 12 | BERT_DIR=/data/lixiaoya/models/bert_cased_large 13 | MAX_LEN=200 14 | OUTPUT_DIR=/data/lixiaoya/outputs/mrc_ner_baseline/0909/onto5_bert_tagger10_lr2e-5_drop_norm1.0_weight_warmup0.01_maxlen512 15 | MODEL_CKPT=${OUTPUT_DIR}/epoch=25_v1.ckpt 16 | 
HPARAMS_FILE=${OUTPUT_DIR}/lightning_logs/version_0/hparams.yaml 17 | DATA_SUFFIX=.word.bmes 18 | 19 | python3 ${REPO_PATH}/inference/tagger_ner_inference.py \ 20 | --data_dir ${DATA_DIR} \ 21 | --bert_dir ${BERT_DIR} \ 22 | --max_length ${MAX_LEN} \ 23 | --model_ckpt ${MODEL_CKPT} \ 24 | --hparams_file ${HPARAMS_FILE} \ 25 | --dataset_sign ${DATA_SIGN} \ 26 | --data_file_suffix ${DATA_SUFFIX} -------------------------------------------------------------------------------- /scripts/bert_tagger/reproduce/conll03.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | # -*- coding: utf-8 -*- 3 | 4 | # file: conll03.sh 5 | # dev Span-F1: 96.61 6 | # test Span-F1: 92.74 7 | 8 | TIME=0824 9 | FILE=conll03_bert_tagger 10 | REPO_PATH=/home/lixiaoya/mrc-for-flat-nested-ner 11 | export PYTHONPATH="$PYTHONPATH:$REPO_PATH" 12 | DATA_DIR=/data/lixiaoya/datasets/ner/en_conll03 13 | BERT_DIR=/data/lixiaoya/models/bert_cased_large 14 | 15 | BERT_DROPOUT=0.2 16 | LR=2e-5 17 | LR_SCHEDULER=polydecay 18 | MAXLEN=256 19 | MAXNORM=1.0 20 | DATA_SUFFIX=.word.bmes 21 | TRAIN_BATCH_SIZE=8 22 | GRAD_ACC=4 23 | MAX_EPOCH=10 24 | WEIGHT_DECAY=0.02 25 | OPTIM=torch.adam 26 | DATA_SIGN=en_conll03 27 | WARMUP_PROPORTION=0.01 28 | INTER_HIDDEN=1024 29 | 30 | OUTPUT_DIR=/data/lixiaoya/outputs/mrc_ner_baseline/${TIME}/${FILE}_lr${LR}_drop${MRC_DROPOUT}_norm${MAXNORM}_weight${SPAN_WEIGHT}_warmup${WARMUP_PROPORTION}_maxlen${MAXLEN} 31 | mkdir -p ${OUTPUT_DIR} 32 | 33 | 34 | CUDA_VISIBLE_DEVICES=1 python3 ${REPO_PATH}/train/bert_tagger_trainer.py \ 35 | --gpus="1" \ 36 | --progress_bar_refresh_rate 1 \ 37 | --data_dir ${DATA_DIR} \ 38 | --bert_config_dir ${BERT_DIR} \ 39 | --max_length ${MAXLEN} \ 40 | --train_batch_size ${TRAIN_BATCH_SIZE} \ 41 | --precision=16 \ 42 | --lr_scheduler ${LR_SCHEDULER} \ 43 | --lr ${LR} \ 44 | --val_check_interval 0.25 \ 45 | --accumulate_grad_batches ${GRAD_ACC} \ 46 | --output_dir ${OUTPUT_DIR} \ 47 | --max_epochs ${MAX_EPOCH} \ 48 | --warmup_proportion ${WARMUP_PROPORTION} \ 49 | --max_length ${MAXLEN} \ 50 | --gradient_clip_val ${MAXNORM} \ 51 | --weight_decay ${WEIGHT_DECAY} \ 52 | --data_file_suffix ${DATA_SUFFIX} \ 53 | --optimizer ${OPTIM} \ 54 | --data_sign ${DATA_SIGN} \ 55 | --classifier_intermediate_hidden_size ${INTER_HIDDEN} 56 | 57 | -------------------------------------------------------------------------------- /scripts/bert_tagger/reproduce/msra.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | # -*- coding: utf-8 -*- 3 | 4 | # file: msra.sh 5 | # dev Span-F1: 94.02 6 | # test Span-F1: 95.15 7 | 8 | TIME=0826 9 | FILE=msra_bert_tagger 10 | REPO_PATH=/userhome/xiaoya/mrc-for-flat-nested-ner 11 | export PYTHONPATH="$PYTHONPATH:$REPO_PATH" 12 | DATA_DIR=/userhome/xiaoya/dataset/tagger_ner_datasets/msra 13 | BERT_DIR=/userhome/xiaoya/bert/chinese_bert 14 | 15 | 16 | BERT_DROPOUT=0.2 17 | LR=2e-5 18 | LR_SCHEDULER=polydecay 19 | MAXLEN=256 20 | MAXNORM=1.0 21 | DATA_SUFFIX=.char.bmes 22 | TRAIN_BATCH_SIZE=8 23 | GRAD_ACC=4 24 | MAX_EPOCH=10 25 | WEIGHT_DECAY=0.02 26 | OPTIM=torch.adam 27 | DATA_SIGN=zh_msra 28 | WARMUP_PROPORTION=0.02 29 | INTER_HIDDEN=768 30 | 31 | OUTPUT_DIR=/userhome/xiaoya/outputs/mrc_ner_baseline/${TIME}/${FILE}_chinese_lr${LR}_drop${MRC_DROPOUT}_norm${MAXNORM}_weight${SPAN_WEIGHT}_warmup${WARMUP}_maxlen${MAXLEN} 32 | mkdir -p ${OUTPUT_DIR} 33 | 34 | 35 | CUDA_VISIBLE_DEVICES=0 python3 ${REPO_PATH}/train/bert_tagger_trainer.py \ 36 | 
--gpus="1" \ 37 | --progress_bar_refresh_rate 1 \ 38 | --data_dir ${DATA_DIR} \ 39 | --bert_config_dir ${BERT_DIR} \ 40 | --max_length ${MAXLEN} \ 41 | --train_batch_size ${TRAIN_BATCH_SIZE} \ 42 | --precision=16 \ 43 | --lr_scheduler ${LR_SCHEDULER} \ 44 | --lr ${LR} \ 45 | --val_check_interval 0.25 \ 46 | --accumulate_grad_batches ${GRAD_ACC} \ 47 | --output_dir ${OUTPUT_DIR} \ 48 | --max_epochs ${MAX_EPOCH} \ 49 | --warmup_proportion ${WARMUP_PROPORTION} \ 50 | --max_length ${MAXLEN} \ 51 | --gradient_clip_val ${MAXNORM} \ 52 | --weight_decay ${WEIGHT_DECAY} \ 53 | --data_file_suffix ${DATA_SUFFIX} \ 54 | --optimizer ${OPTIM} \ 55 | --data_sign ${DATA_SIGN} \ 56 | --classifier_intermediate_hidden_size ${INTER_HIDDEN} \ 57 | --chinese -------------------------------------------------------------------------------- /scripts/bert_tagger/reproduce/onto4.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | # -*- coding: utf-8 -*- 3 | 4 | # file: onto4.sh 5 | # dev Span-F1: 79.18 6 | # test Span-F1: 80.35 7 | 8 | TIME=0826 9 | FILE=onto_bert_tagger 10 | REPO_PATH=/userhome/xiaoya/mrc-for-flat-nested-ner 11 | export PYTHONPATH="$PYTHONPATH:$REPO_PATH" 12 | DATA_DIR=/userhome/xiaoya/dataset/tagger_ner_datasets/zhontonotes4 13 | BERT_DIR=/userhome/xiaoya/bert/chinese_bert 14 | 15 | BERT_DROPOUT=0.2 16 | LR=2e-5 17 | LR_SCHEDULER=polydecay 18 | MAXLEN=256 19 | MAXNORM=1.0 20 | DATA_SUFFIX=.char.bmes 21 | TRAIN_BATCH_SIZE=8 22 | GRAD_ACC=4 23 | MAX_EPOCH=10 24 | WEIGHT_DECAY=0.02 25 | OPTIM=torch.adam 26 | DATA_SIGN=zh_onto 27 | WARMUP_PROPORTION=0.02 28 | INTER_HIDDEN=768 29 | 30 | OUTPUT_DIR=/userhome/xiaoya/outputs/mrc_ner_baseline/${TIME}/${FILE}_chinese_lr${LR}_drop${MRC_DROPOUT}_norm${MAXNORM}_weight${SPAN_WEIGHT}_warmup${WARMUP}_maxlen${MAXLEN} 31 | mkdir -p ${OUTPUT_DIR} 32 | 33 | 34 | CUDA_VISIBLE_DEVICES=1 python3 ${REPO_PATH}/train/bert_tagger_trainer.py \ 35 | --gpus="1" \ 36 | --progress_bar_refresh_rate 1 \ 37 | --data_dir ${DATA_DIR} \ 38 | --bert_config_dir ${BERT_DIR} \ 39 | --max_length ${MAXLEN} \ 40 | --train_batch_size ${TRAIN_BATCH_SIZE} \ 41 | --precision=32 \ 42 | --lr_scheduler ${LR_SCHEDULER} \ 43 | --lr ${LR} \ 44 | --val_check_interval 0.25 \ 45 | --accumulate_grad_batches ${GRAD_ACC} \ 46 | --output_dir ${OUTPUT_DIR} \ 47 | --max_epochs ${MAX_EPOCH} \ 48 | --warmup_proportion ${WARMUP_PROPORTION} \ 49 | --max_length ${MAXLEN} \ 50 | --gradient_clip_val ${MAXNORM} \ 51 | --weight_decay ${WEIGHT_DECAY} \ 52 | --data_file_suffix ${DATA_SUFFIX} \ 53 | --optimizer ${OPTIM} \ 54 | --data_sign ${DATA_SIGN} \ 55 | --classifier_intermediate_hidden_size ${INTER_HIDDEN} \ 56 | --chinese -------------------------------------------------------------------------------- /scripts/mrc_ner/evaluate.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | # -*- coding: utf-8 -*- 3 | 4 | # file: eval.sh 5 | 6 | REPO_PATH=/data/xiaoya/workspace/mrc-for-flat-nested-ner 7 | export PYTHONPATH="$PYTHONPATH:$REPO_PATH" 8 | 9 | OUTPUT_DIR=/data/xiaoya/outputs/mrc_ner/ace2004/debug_lr3e-5_drop0.3_norm1.0_weight0.1_warmup0_maxlen100 10 | # find best checkpoint on dev in ${OUTPUT_DIR}/train_log.txt 11 | BEST_CKPT_DEV=${OUTPUT_DIR}/epoch=8.ckpt 12 | PYTORCHLIGHT_HPARAMS=${OUTPUT_DIR}/lightning_logs/version_0/hparams.yaml 13 | GPU_ID=0,1 14 | 15 | python3 ${REPO_PATH}/evaluate/mrc_ner_evaluate.py ${BEST_CKPT_DEV} ${PYTORCHLIGHT_HPARAMS} ${GPU_ID} 
-------------------------------------------------------------------------------- /scripts/mrc_ner/flat_inference.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | # -*- coding: utf-8 -*- 3 | 4 | # file: flat_inference.sh 5 | # 6 | 7 | REPO_PATH=/data/xiaoya/workspace/mrc-for-flat-nested-ner-github 8 | export PYTHONPATH="$PYTHONPATH:$REPO_PATH" 9 | 10 | DATA_SIGN=conll03 11 | DATA_DIR=/data/xiaoya/datasets/mrc_ner_datasets/en_conll03_truecase_sent 12 | BERT_DIR=/data/xiaoya/models/uncased_L-12_H-768_A-12 13 | MAX_LEN=180 14 | MODEL_CKPT=/data/xiaoya/outputs/mrc_ner/conll03/large_lr3e-5_drop0.3_norm1.0_weight0.1_warmup0_maxlen180/epoch=1_v7.ckpt 15 | HPARAMS_FILE=/data/xiaoya/outputs/mrc_ner/conll03/large_lr3e-5_drop0.3_norm1.0_weight0.1_warmup0_maxlen180/lightning_logs/version_0/hparams.yaml 16 | 17 | 18 | python3 ${REPO_PATH}/inference/mrc_ner_inference.py \ 19 | --data_dir ${DATA_DIR} \ 20 | --bert_dir ${BERT_DIR} \ 21 | --max_length ${MAX_LEN} \ 22 | --model_ckpt ${MODEL_CKPT} \ 23 | --hparams_file ${HPARAMS_FILE} \ 24 | --flat_ner \ 25 | --dataset_sign ${DATA_SIGN} -------------------------------------------------------------------------------- /scripts/mrc_ner/nested_inference.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | # -*- coding: utf-8 -*- 3 | 4 | # file: nested_inference.sh 5 | # 6 | 7 | REPO_PATH=/data/xiaoya/workspace/mrc-for-flat-nested-ner-github 8 | export PYTHONPATH="$PYTHONPATH:$REPO_PATH" 9 | 10 | DATA_SIGN=ace04 11 | DATA_DIR=/data/xiaoya/datasets/mrc-for-flat-nested-ner/ace2004 12 | BERT_DIR=/data/xiaoya/models/uncased_L-12_H-768_A-12 13 | MAX_LEN=100 14 | MODEL_CKPT=/data/xiaoya/outputs/mrc_ner/ace2004/debug_lr3e-5_drop0.3_norm1.0_weight0.1_warmup0_maxlen100/epoch=0.ckpt 15 | HPARAMS_FILE=/data/xiaoya/outputs/mrc_ner/ace2004/debug_lr3e-5_drop0.3_norm1.0_weight0.1_warmup0_maxlen100/lightning_logs/version_3/hparams.yaml 16 | 17 | 18 | python3 ${REPO_PATH}/inference/mrc_ner_inference.py \ 19 | --data_dir ${DATA_DIR} \ 20 | --bert_dir ${BERT_DIR} \ 21 | --max_length ${MAX_LEN} \ 22 | --model_ckpt ${MODEL_CKPT} \ 23 | --hparams_file ${HPARAMS_FILE} \ 24 | --dataset_sign ${DATA_SIGN} -------------------------------------------------------------------------------- /scripts/mrc_ner/reproduce/ace04.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | # -*- coding: utf-8 -*- 3 | 4 | # file: ace04.sh 5 | 6 | TIME=0910 7 | REPO_PATH=/userhome/xiaoya/mrc-for-flat-nested-ner 8 | export PYTHONPATH="$PYTHONPATH:$REPO_PATH" 9 | DATA_DIR=/userhome/xiaoya/dataset/ace2004 10 | BERT_DIR=/userhome/xiaoya/bert/bert_uncased_large 11 | 12 | BERT_DROPOUT=0.1 13 | MRC_DROPOUT=0.3 14 | LR=3e-5 15 | SPAN_WEIGHT=0.1 16 | WARMUP=0 17 | MAXLEN=128 18 | MAXNORM=1.0 19 | INTER_HIDDEN=2048 20 | 21 | BATCH_SIZE=4 22 | PREC=16 23 | VAL_CKPT=0.25 24 | ACC_GRAD=2 25 | MAX_EPOCH=20 26 | SPAN_CANDI=pred_and_gold 27 | PROGRESS_BAR=1 28 | 29 | OUTPUT_DIR=/userhome/xiaoya/outputs/mrc_ner/${TIME}/ace2004/large_lr${LR}_drop${MRC_DROPOUT}_norm${MAXNORM}_weight${SPAN_WEIGHT}_warmup${WARMUP}_maxlen${MAXLEN} 30 | mkdir -p ${OUTPUT_DIR} 31 | 32 | CUDA_VISIBLE_DEVICES=0,1 python ${REPO_PATH}/train/mrc_ner_trainer.py \ 33 | --gpus="2" \ 34 | --distributed_backend=ddp \ 35 | --workers 0 \ 36 | --data_dir ${DATA_DIR} \ 37 | --bert_config_dir ${BERT_DIR} \ 38 | --max_length ${MAXLEN} \ 39 | --batch_size ${BATCH_SIZE} \ 40 | 
--precision=${PREC} \ 41 | --progress_bar_refresh_rate ${PROGRESS_BAR} \ 42 | --lr ${LR} \ 43 | --val_check_interval ${VAL_CKPT} \ 44 | --accumulate_grad_batches ${ACC_GRAD} \ 45 | --default_root_dir ${OUTPUT_DIR} \ 46 | --mrc_dropout ${MRC_DROPOUT}\ 47 | --bert_dropout ${BERT_DROPOUT} \ 48 | --max_epochs ${MAX_EPOCH} \ 49 | --span_loss_candidates ${SPAN_CANDI} \ 50 | --weight_span ${SPAN_WEIGHT} \ 51 | --warmup_steps ${WARMUP} \ 52 | --gradient_clip_val ${MAXNORM} \ 53 | --classifier_intermediate_hidden_size ${INTER_HIDDEN} 54 | 55 | -------------------------------------------------------------------------------- /scripts/mrc_ner/reproduce/ace05.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | # -*- coding: utf-8 -*- 3 | 4 | # file: ace05.sh 5 | 6 | REPO_PATH=/userhome/xiaoya/mrc-for-flat-nested-ner 7 | export PYTHONPATH="$PYTHONPATH:$REPO_PATH" 8 | 9 | DATA_DIR=/userhome/xiaoya/dataset/ace2005 10 | BERT_DIR=/userhome/xiaoya/bert/bert_uncased_large 11 | 12 | BERT_DROPOUT=0.1 13 | MRC_DROPOUT=0.3 14 | LR=2e-5 15 | SPAN_WEIGHT=0.1 16 | WARMUP=0 17 | MAXLEN=128 18 | MAXNORM=1.0 19 | INTER_HIDDEN=2048 20 | 21 | BATCH_SIZE=8 22 | PREC=16 23 | VAL_CKPT=0.25 24 | ACC_GRAD=1 25 | MAX_EPOCH=20 26 | SPAN_CANDI=pred_and_gold 27 | PROGRESS_BAR=1 28 | OPTIM=adamw 29 | 30 | OUTPUT_DIR=/userhome/xiaoya/outputs/mrc_ner/ace2005/warmup${WARMUP}lr${LR}_drop${MRC_DROPOUT}_norm${MAXNORM}_weight${SPAN_WEIGHT}_warmup${WARMUP}_maxlen${MAXLEN} 31 | mkdir -p ${OUTPUT_DIR} 32 | 33 | CUDA_VISIBLE_DEVICES=0,1,2,3 python ${REPO_PATH}/train/mrc_ner_trainer.py \ 34 | --gpus="4" \ 35 | --distributed_backend=ddp \ 36 | --workers 0 \ 37 | --data_dir ${DATA_DIR} \ 38 | --bert_config_dir ${BERT_DIR} \ 39 | --max_length ${MAXLEN} \ 40 | --batch_size ${BATCH_SIZE} \ 41 | --precision=${PREC} \ 42 | --progress_bar_refresh_rate ${PROGRESS_BAR} \ 43 | --lr ${LR} \ 44 | --val_check_interval ${VAL_CKPT} \ 45 | --accumulate_grad_batches ${ACC_GRAD} \ 46 | --default_root_dir ${OUTPUT_DIR} \ 47 | --mrc_dropout ${MRC_DROPOUT} \ 48 | --bert_dropout ${BERT_DROPOUT} \ 49 | --max_epochs ${MAX_EPOCH} \ 50 | --span_loss_candidates ${SPAN_CANDI} \ 51 | --weight_span ${SPAN_WEIGHT} \ 52 | --warmup_steps ${WARMUP} \ 53 | --gradient_clip_val ${MAXNORM} \ 54 | --optimizer ${OPTIM} \ 55 | --classifier_intermediate_hidden_size ${INTER_HIDDEN} -------------------------------------------------------------------------------- /scripts/mrc_ner/reproduce/conll03.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | # -*- coding: utf-8 -*- 3 | 4 | 5 | TIME=0901 6 | FILE=conll03_cased_large 7 | REPO_PATH=/userhome/xiaoya/mrc-for-flat-nested-ner 8 | export PYTHONPATH="$PYTHONPATH:$REPO_PATH" 9 | 10 | DATA_DIR=/userhome/xiaoya/dataset/en_conll03 11 | BERT_DIR=/userhome/xiaoya/bert/bert_cased_large 12 | OUTPUT_BASE=/userhome/xiaoya/outputs 13 | 14 | BATCH=10 15 | GRAD_ACC=4 16 | BERT_DROPOUT=0.1 17 | MRC_DROPOUT=0.3 18 | LR=3e-5 19 | LR_MINI=3e-7 20 | LR_SCHEDULER=polydecay 21 | SPAN_WEIGHT=0.1 22 | WARMUP=0 23 | MAX_LEN=200 24 | MAX_NORM=1.0 25 | MAX_EPOCH=20 26 | INTER_HIDDEN=2048 27 | WEIGHT_DECAY=0.01 28 | OPTIM=torch.adam 29 | VAL_CHECK=0.2 30 | PREC=16 31 | SPAN_CAND=pred_and_gold 32 | 33 | 34 | OUTPUT_DIR=${OUTPUT_BASE}/mrc_ner/${TIME}/${FILE}_cased_large_lr${LR}_drop${MRC_DROPOUT}_norm${MAXNORM}_weight${SPAN_WEIGHT}_warmup${WARMUP}_maxlen${MAXLEN} 35 | mkdir -p ${OUTPUT_DIR} 36 | 37 | 38 | CUDA_VISIBLE_DEVICES=0,1 python 
${REPO_PATH}/train/mrc_ner_trainer.py \ 39 | --data_dir ${DATA_DIR} \ 40 | --bert_config_dir ${BERT_DIR} \ 41 | --max_length ${MAX_LEN} \ 42 | --batch_size ${BATCH} \ 43 | --gpus="2" \ 44 | --precision=${PREC} \ 45 | --progress_bar_refresh_rate 1 \ 46 | --lr ${LR} \ 47 | --val_check_interval ${VAL_CHECK} \ 48 | --accumulate_grad_batches ${GRAD_ACC} \ 49 | --default_root_dir ${OUTPUT_DIR} \ 50 | --mrc_dropout ${MRC_DROPOUT} \ 51 | --bert_dropout ${BERT_DROPOUT} \ 52 | --max_epochs ${MAX_EPOCH} \ 53 | --span_loss_candidates ${SPAN_CAND} \ 54 | --weight_span ${SPAN_WEIGHT} \ 55 | --warmup_steps ${WARMUP} \ 56 | --distributed_backend=ddp \ 57 | --max_length ${MAX_LEN} \ 58 | --gradient_clip_val ${MAX_NORM} \ 59 | --weight_decay ${WEIGHT_DECAY} \ 60 | --optimizer ${OPTIM} \ 61 | --lr_scheduler ${LR_SCHEDULER} \ 62 | --classifier_intermediate_hidden_size ${INTER_HIDDEN} \ 63 | --flat \ 64 | --lr_mini ${LR_MINI} 65 | 66 | -------------------------------------------------------------------------------- /scripts/mrc_ner/reproduce/genia.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | # -*- coding: utf-8 -*- 3 | 4 | # file: genia.sh 5 | 6 | REPO_PATH=/userhome/xiaoya/mrc-for-flat-nested-ner 7 | export PYTHONPATH="$PYTHONPATH:$REPO_PATH" 8 | 9 | DATA_DIR=/userhome/xiaoya/dataset/genia 10 | BERT_DIR=/userhome/xiaoya/bert/bert_cased_large 11 | BERT_DROPOUT=0.2 12 | MRC_DROPOUT=0.2 13 | LR=2e-5 14 | SPAN_WEIGHT=0.1 15 | WARMUP=0 16 | MAXLEN=180 17 | MAXNORM=1.0 18 | INTER_HIDDEN=2048 19 | 20 | BATCH_SIZE=8 21 | PREC=16 22 | VAL_CKPT=0.25 23 | ACC_GRAD=4 24 | MAX_EPOCH=20 25 | SPAN_CANDI=pred_and_gold 26 | PROGRESS_BAR=1 27 | WEIGHT_DECAY=0.002 28 | 29 | OUTPUT_DIR=/userhome/xiaoya/outputs/github_mrc/genia/large_lr${LR}_drop${MRC_DROPOUT}_norm${MAXNORM}_bsz32_hard_span_weight${SPAN_WEIGHT}_warmup${WARMUP}_maxlen${MAXLEN} 30 | mkdir -p ${OUTPUT_DIR} 31 | 32 | CUDA_VISIBLE_DEVICES=0,1,2,3 python ${REPO_PATH}/train/mrc_ner_trainer.py \ 33 | --gpus="4" \ 34 | --distributed_backend=ddp \ 35 | --workers 0 \ 36 | --data_dir ${DATA_DIR} \ 37 | --bert_config_dir ${BERT_DIR} \ 38 | --max_length ${MAXLEN} \ 39 | --batch_size ${BATCH_SIZE} \ 40 | --precision=${PREC} \ 41 | --progress_bar_refresh_rate ${PROGRESS_BAR} \ 42 | --lr ${LR} \ 43 | --val_check_interval ${VAL_CKPT} \ 44 | --accumulate_grad_batches ${ACC_GRAD} \ 45 | --default_root_dir ${OUTPUT_DIR} \ 46 | --mrc_dropout ${MRC_DROPOUT} \ 47 | --bert_dropout ${BERT_DROPOUT} \ 48 | --max_epochs ${MAX_EPOCH} \ 49 | --span_loss_candidates ${SPAN_CANDI} \ 50 | --weight_span ${SPAN_WEIGHT} \ 51 | --warmup_steps ${WARMUP} \ 52 | --gradient_clip_val ${MAXNORM} \ 53 | --weight_decay ${WEIGHT_DECAY} \ 54 | --classifier_intermediate_hidden_size ${INTER_HIDDEN} 55 | 56 | -------------------------------------------------------------------------------- /scripts/mrc_ner/reproduce/kbp17.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | # -*- coding: utf-8 -*- 3 | 4 | # file: genia.sh 5 | 6 | TIME=0901 7 | FILE=kbp17_bert_large 8 | REPO_PATH=/userhome/xiaoya/mrc-for-flat-nested-ner 9 | export PYTHONPATH="$PYTHONPATH:$REPO_PATH" 10 | 11 | DATA_DIR=/userhome/xiaoya/dataset/kbp17 12 | BERT_DIR=/userhome/xiaoya/bert/bert_cased_large 13 | BERT_DROPOUT=0.2 14 | MRC_DROPOUT=0.2 15 | LR=2e-5 16 | SPAN_WEIGHT=0.1 17 | WARMUP=0 18 | MAXLEN=180 19 | MAXNORM=1.0 20 | INTER_HIDDEN=2048 21 | 22 | BATCH_SIZE=4 23 | PREC=16 24 | VAL_CKPT=0.25 25 | ACC_GRAD=4 26 | 
MAX_EPOCH=6 27 | SPAN_CANDI=pred_and_gold 28 | PROGRESS_BAR=1 29 | WEIGHT_DECAY=0.002 30 | 31 | OUTPUT_DIR=/userhome/xiaoya/outputs/mrc_ner/${TIME}/${FILE}_lr${LR}_drop${MRC_DROPOUT}_norm${MAXNORM}_weight${SPAN_WEIGHT}_warmup${WARMUP}_maxlen${MAXLEN} 32 | mkdir -p ${OUTPUT_DIR} 33 | 34 | 35 | CUDA_VISIBLE_DEVICES=0,1,2,3 python ${REPO_PATH}/train/mrc_ner_trainer.py \ 36 | --gpus="4" \ 37 | --distributed_backend=ddp \ 38 | --workers 0 \ 39 | --data_dir ${DATA_DIR} \ 40 | --bert_config_dir ${BERT_DIR} \ 41 | --max_length ${MAXLEN} \ 42 | --batch_size ${BATCH_SIZE} \ 43 | --precision=${PREC} \ 44 | --progress_bar_refresh_rate ${PROGRESS_BAR} \ 45 | --lr ${LR} \ 46 | --val_check_interval ${VAL_CKPT} \ 47 | --accumulate_grad_batches ${ACC_GRAD} \ 48 | --default_root_dir ${OUTPUT_DIR} \ 49 | --mrc_dropout ${MRC_DROPOUT} \ 50 | --bert_dropout ${BERT_DROPOUT} \ 51 | --max_epochs ${MAX_EPOCH} \ 52 | --span_loss_candidates ${SPAN_CANDI} \ 53 | --weight_span ${SPAN_WEIGHT} \ 54 | --warmup_steps ${WARMUP} \ 55 | --gradient_clip_val ${MAXNORM} \ 56 | --weight_decay ${WEIGHT_DECAY} \ 57 | --classifier_intermediate_hidden_size ${INTER_HIDDEN} 58 | 59 | -------------------------------------------------------------------------------- /scripts/mrc_ner/reproduce/msra.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | # -*- coding: utf-8 -*- 3 | 4 | # file: msra.sh 5 | 6 | REPO_PATH=/userhome/xiaoya/mrc-for-flat-nested-ner 7 | export PYTHONPATH="$PYTHONPATH:$REPO_PATH" 8 | export TOKENIZERS_PARALLELISM=false 9 | 10 | DATA_DIR=/mnt/mrc/zh_msra 11 | BERT_DIR=/mnt/mrc/chinese_roberta_wwm_large_ext_pytorch 12 | SPAN_WEIGHT=0.1 13 | DROPOUT=0.2 14 | LR=8e-6 15 | MAXLEN=128 16 | INTER_HIDDEN=1536 17 | 18 | BATCH_SIZE=4 19 | PREC=16 20 | VAL_CKPT=0.25 21 | ACC_GRAD=1 22 | MAX_EPOCH=20 23 | SPAN_CANDI=pred_and_gold 24 | PROGRESS_BAR=1 25 | 26 | OUTPUT_DIR=/mnt/mrc/train_logs/zh_msra/zh_msra_bertlarge_lr${LR}20200913_dropout${DROPOUT}_maxlen${MAXLEN} 27 | 28 | mkdir -p ${OUTPUT_DIR} 29 | 30 | CUDA_VISIBLE_DEVICES=0,1 python ${REPO_PATH}/train/mrc_ner_trainer.py \ 31 | --gpus="2" \ 32 | --distributed_backend=ddp \ 33 | --data_dir ${DATA_DIR} \ 34 | --bert_config_dir ${BERT_DIR} \ 35 | --max_length ${MAXLEN} \ 36 | --batch_size ${BATCH_SIZE} \ 37 | --precision=${PREC} \ 38 | --progress_bar_refresh_rate ${PROGRESS_BAR} \ 39 | --lr ${LR} \ 40 | --val_check_interval ${VAL_CKPT} \ 41 | --accumulate_grad_batches ${ACC_GRAD} \ 42 | --default_root_dir ${OUTPUT_DIR} \ 43 | --mrc_dropout ${DROPOUT} \ 44 | --max_epochs ${MAX_EPOCH} \ 45 | --weight_span ${SPAN_WEIGHT} \ 46 | --span_loss_candidates ${SPAN_CANDI} \ 47 | --chinese \ 48 | --workers 0 \ 49 | --classifier_intermediate_hidden_size ${INTER_HIDDEN} 50 | 51 | -------------------------------------------------------------------------------- /scripts/mrc_ner/reproduce/onto4.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | # -*- coding: utf-8 -*- 3 | 4 | # file: onto4.sh 5 | # desc: for chinese ontonotes 04 6 | 7 | REPO_PATH=/userhome/xiaoya/mrc-for-flat-nested-ner 8 | export PYTHONPATH="$PYTHONPATH:$REPO_PATH" 9 | export TOKENIZERS_PARALLELISM=false 10 | 11 | DATA_DIR=/userhome/yuxian/data/zh_onto4 12 | BERT_DIR=/userhome/yuxian/data/chinese_roberta_wwm_large_ext_pytorch 13 | 14 | MAXLENGTH=128 15 | WEIGHT_SPAN=0.1 16 | LR=1e-5 17 | OPTIMIZER=adamw 18 | INTER_HIDDEN=1536 19 | 
OUTPUT_DIR=/userhome/yuxian/train_logs/zh_onto/zh_onto_${OPTIMIZER}_lr${lr}_maxlen${MAXLENGTH}_spanw${WEIGHT_SPAN} 20 | mkdir -p ${OUTPUT_DIR} 21 | 22 | BATCH_SIZE=8 23 | PREC=16 24 | VAL_CKPT=0.25 25 | ACC_GRAD=1 26 | MAX_EPOCH=10 27 | SPAN_CANDI=pred_and_gold 28 | PROGRESS_BAR=1 29 | 30 | CUDA_VISIBLE_DEVICES=0,1,2,3 python ${REPO_PATH}/train/mrc_ner_trainer.py \ 31 | --data_dir ${DATA_DIR} \ 32 | --bert_config_dir ${BERT_DIR} \ 33 | --max_length ${MAXLENGTH} \ 34 | --batch_size ${BATCH_SIZE} \ 35 | --gpus="4" \ 36 | --precision=${PREC} \ 37 | --progress_bar_refresh_rate ${PROGRESS_BAR} \ 38 | --lr ${LR} \ 39 | --workers 0 \ 40 | --distributed_backend=ddp \ 41 | --val_check_interval ${VAL_CKPT} \ 42 | --accumulate_grad_batches ${ACC_GRAD} \ 43 | --default_root_dir ${OUTPUT_DIR} \ 44 | --max_epochs ${MAX_EPOCH} \ 45 | --span_loss_candidates ${SPAN_CANDI} \ 46 | --weight_span ${WEIGHT_SPAN} \ 47 | --mrc_dropout 0.3 \ 48 | --chinese \ 49 | --warmup_steps 5000 \ 50 | --gradient_clip_val 5.0 \ 51 | --final_div_factor 20 \ 52 | --classifier_intermediate_hidden_size ${INTER_HIDDEN} 53 | 54 | -------------------------------------------------------------------------------- /scripts/mrc_ner/reproduce/onto5.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | # -*- coding: utf-8 -*- 3 | 4 | 5 | TIME=0901 6 | FILE=onto5_mrc_cased_large 7 | REPO_PATH=/userhome/xiaoya/mrc-for-flat-nested-ner 8 | export PYTHONPATH="$PYTHONPATH:$REPO_PATH" 9 | DATA_DIR=/userhome/xiaoya/dataset/new_mrc_ner/new_en_onto5 10 | BERT_DIR=/userhome/xiaoya/bert/bert_cased_large 11 | 12 | BERT_DROPOUT=0.2 13 | MRC_DROPOUT=0.2 14 | LR=2e-5 15 | LR_MINI=3e-7 16 | LR_SCHEDULER=polydecay 17 | SPAN_WEIGHT=0.1 18 | WARMUP=200 19 | MAXLEN=210 20 | MAXNORM=1.0 21 | INTER_HIDDEN=2048 22 | 23 | BATCH_SIZE=4 24 | PREC=16 25 | VAL_CKPT=0.2 26 | ACC_GRAD=5 27 | MAX_EPOCH=10 28 | SPAN_CANDI=pred_and_gold 29 | PROGRESS_BAR=1 30 | WEIGHT_DECAY=0.01 31 | OPTIM=torch.adam 32 | 33 | OUTPUT_DIR=/userhome/xiaoya/outputs/mrc_ner/${TIME}/${FILE}_cased_large_lr${LR}_drop${MRC_DROPOUT}_norm${MAXNORM}_weight${SPAN_WEIGHT}_warmup${WARMUP}_maxlen${MAXLEN} 34 | mkdir -p ${OUTPUT_DIR} 35 | 36 | 37 | CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python ${REPO_PATH}/train/mrc_ner_trainer.py \ 38 | --data_dir ${DATA_DIR} \ 39 | --bert_config_dir ${BERT_DIR} \ 40 | --max_length ${MAXLEN} \ 41 | --batch_size ${BATCH_SIZE} \ 42 | --gpus="4" \ 43 | --precision=${PREC} \ 44 | --progress_bar_refresh_rate ${PROGRESS_BAR} \ 45 | --lr ${LR} \ 46 | --distributed_backend=ddp \ 47 | --val_check_interval ${VAL_CKPT} \ 48 | --accumulate_grad_batches ${ACC_GRAD} \ 49 | --default_root_dir ${OUTPUT_DIR} \ 50 | --mrc_dropout ${MRC_DROPOUT} \ 51 | --bert_dropout ${BERT_DROPOUT} \ 52 | --max_epochs ${MAX_EPOCH} \ 53 | --span_loss_candidates ${SPAN_CANDI} \ 54 | --weight_span ${SPAN_WEIGHT} \ 55 | --warmup_steps ${WARMUP} \ 56 | --max_length ${MAXLEN} \ 57 | --gradient_clip_val ${MAXNORM} \ 58 | --weight_decay ${WEIGHT_DECAY} \ 59 | --flat \ 60 | --optimizer ${OPTIM} \ 61 | --lr_scheduler ${LR_SCHEDULER} \ 62 | --classifier_intermediate_hidden_size ${INTER_HIDDEN} \ 63 | --lr_mini ${LR_MINI} 64 | 65 | 66 | -------------------------------------------------------------------------------- /tests/assert_correct_dataset.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | 4 | # file: assert_correct_dataset.py 5 | 6 | import os 7 | import sys 8 | 9 
| REPO_PATH = "/".join(os.path.realpath(__file__).split("/")[:-2]) 10 | print(REPO_PATH) 11 | if REPO_PATH not in sys.path: 12 | sys.path.insert(0, REPO_PATH) 13 | 14 | import json 15 | from datasets.tagger_ner_dataset import load_data_in_conll 16 | from metrics.functional.tagger_span_f1 import get_entity_from_bmes_lst 17 | 18 | 19 | def count_entity_with_mrc_ner_format(data_path): 20 | """ 21 | mrc_data_example: 22 | { 23 | "context": "The chemiluminescent ( CL ) response of interferon - gamma - treated U937 ( IFN - U937 ) cells to sensitized target cells has been used to detect red cell , platelet and granulocyte antibodies .", 24 | "end_position": [ 25 | 16, 26 | 18 27 | ], 28 | "entity_label": "cell_line", 29 | "impossible": false, 30 | "qas_id": "532.1", 31 | "query": "cell line", 32 | "span_position": [ 33 | "14;16", 34 | "7;18" 35 | ], 36 | "start_position": [ 37 | 14, 38 | 7 39 | ] 40 | } 41 | """ 42 | entity_counter = {} 43 | with open(data_path, encoding="utf-8") as f: 44 | data_lst = json.load(f) 45 | 46 | for data_item in data_lst: 47 | tmp_entity_type = data_item["entity_label"] 48 | if len(data_item["end_position"]) != 0: 49 | if tmp_entity_type not in entity_counter.keys(): 50 | entity_counter[tmp_entity_type] = len(data_item["end_position"]) 51 | else: 52 | entity_counter[tmp_entity_type] += len(data_item["end_position"]) 53 | 54 | print(f"mrc: the number of sentences is : {len(data_lst)/len(entity_counter.keys())}") 55 | print("UNDER MRC-NER format -> ") 56 | print(entity_counter) 57 | 58 | 59 | def count_entity_with_sequence_ner_format(data_path, is_nested=False): 60 | entity_counter = {} 61 | if not is_nested: 62 | data_lst = load_data_in_conll(data_path) 63 | print(f"bmes: the number of sentences is : {len(data_lst)}") 64 | label_lst = [label_item[1] for label_item in data_lst] 65 | for label_item in label_lst: 66 | tmp_entity_lst = get_entity_from_bmes_lst(label_item) 67 | for tmp_entity in tmp_entity_lst: 68 | tmp_entity_type = tmp_entity[tmp_entity.index("]")+1:] 69 | if tmp_entity_type not in entity_counter.keys(): 70 | entity_counter[tmp_entity_type] = 1 71 | else: 72 | entity_counter[tmp_entity_type] += 1 73 | print("UNDER SEQ format ->") 74 | print(entity_counter) 75 | else: 76 | # genia, ace04, ace05 77 | pass 78 | 79 | 80 | def main(mrc_data_dir, seq_data_dir, seq_data_suffix="char.bmes", is_nested=False): 81 | for data_type in ["train", "dev", "test"]: 82 | mrc_data_path = os.path.join(mrc_data_dir, f"mrc-ner.{data_type}") 83 | seq_data_path = os.path.join(seq_data_dir, f"{data_type}.{seq_data_suffix}") 84 | 85 | print("$"*10) 86 | print(f"{data_type}") 87 | print("$"*10) 88 | count_entity_with_mrc_ner_format(mrc_data_path) 89 | count_entity_with_sequence_ner_format(seq_data_path, is_nested=is_nested) 90 | print("\n") 91 | 92 | 93 | if __name__ == "__main__": 94 | mrc_data_dir = "/data/lixiaoya/datasets/mrc_ner/en_conll03" 95 | seq_data_dir = "/data/lixiaoya/datasets/bmes_ner/en_conll03" 96 | seq_data_suffix = "word.bmes" 97 | is_nested = False 98 | main(mrc_data_dir, seq_data_dir, seq_data_suffix=seq_data_suffix, is_nested=is_nested) -------------------------------------------------------------------------------- /tests/bert_tokenizer.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | 4 | # file: bert_tokenizer.py 5 | 6 | from transformers import AutoTokenizer 7 | 8 | 9 | def tokenize_word(model_path, do_lower_case=False): 10 | tokenizer = 
AutoTokenizer.from_pretrained(model_path, use_fast=False, do_lower_case=do_lower_case) 11 | context = "EUROPEAN" 12 | cut_tokens = tokenizer.encode(context, add_special_tokens=False, return_token_type_ids=None) 13 | print(cut_tokens) 14 | print(type(cut_tokens)) 15 | print(type(cut_tokens[0])) 16 | 17 | 18 | if __name__ == "__main__": 19 | model_path = "/data/xiaoya/models/bert_cased_large" 20 | tokenize_word(model_path) 21 | 22 | -------------------------------------------------------------------------------- /tests/collect_entity_labels.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | 4 | # file: collect_entity_labels.py 5 | 6 | import os 7 | import sys 8 | 9 | REPO_PATH = "/".join(os.path.realpath(__file__).split("/")[:-2]) 10 | print(REPO_PATH) 11 | if REPO_PATH not in sys.path: 12 | sys.path.insert(0, REPO_PATH) 13 | 14 | from datasets.tagger_ner_dataset import load_data_in_conll, get_labels 15 | 16 | 17 | def main(data_dir, data_sign, datafile_suffix=".word.bmes"): 18 | label_collection_set = set() 19 | for data_type in ["train", "dev", "test"]: 20 | label_lst = [] 21 | data_path = os.path.join(data_dir, f"{data_type}{datafile_suffix}") 22 | datasets = load_data_in_conll(data_path) 23 | for data_item in datasets: 24 | label_lst.extend(data_item[1]) 25 | 26 | label_collection_set.update(set(label_lst)) 27 | 28 | print("sum the type of labels: ") 29 | print(len(label_collection_set)) 30 | print(label_collection_set) 31 | 32 | print("%"*10) 33 | set_labels = get_labels(data_sign) 34 | print(len(set_labels)) 35 | 36 | 37 | if __name__ == "__main__": 38 | data_dir = "/data/xiaoya/datasets/ner/ontonotes5" 39 | data_sign = "en_onto" 40 | main(data_dir, data_sign) 41 | -------------------------------------------------------------------------------- /tests/count_mrc_max_length.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | 4 | # file: count_mrc_max_length.py 5 | 6 | import os 7 | import sys 8 | 9 | REPO_PATH="/".join(os.path.realpath(__file__).split("/")[:-2]) 10 | print(REPO_PATH) 11 | if REPO_PATH not in sys.path: 12 | sys.path.insert(0, REPO_PATH) 13 | 14 | import os 15 | from datasets.collate_functions import collate_to_max_length 16 | from torch.utils.data import DataLoader 17 | from tokenizers import BertWordPieceTokenizer 18 | from datasets.mrc_ner_dataset import MRCNERDataset 19 | 20 | 21 | 22 | def main(): 23 | # en datasets 24 | bert_path = "/data/xiaoya/models/bert_cased_large" 25 | json_path = "/data/xiaoya/datasets/mrc_ner_datasets/en_conll03_truecase_sent/mrc-ner.train" 26 | is_chinese = False 27 | # [test] 28 | # max length is 227 29 | # min length is 12 30 | # avg length is 40.45264986967854 31 | # [dev] 32 | # max length is 212 33 | # min length is 12 34 | # avg length is 43.42584615384615 35 | # [train] 36 | # max length is 201 37 | # min length is 12 38 | # avg length is 41.733423545331526 39 | 40 | vocab_file = os.path.join(bert_path, "vocab.txt") 41 | tokenizer = BertWordPieceTokenizer(vocab_file) 42 | dataset = MRCNERDataset(json_path=json_path, tokenizer=tokenizer, 43 | is_chinese=is_chinese, max_length=10000) 44 | 45 | dataloader = DataLoader(dataset, batch_size=1, 46 | collate_fn=collate_to_max_length) 47 | 48 | length_lst = [] 49 | for batch in dataloader: 50 | for tokens, token_type_ids, start_labels, end_labels, start_label_mask, end_label_mask, match_labels, sample_idx, 
label_idx in zip(*batch): 51 | tokens = tokens.tolist() 52 | length_lst.append(len(tokens)) 53 | print(f"max length is {max(length_lst)}") 54 | print(f"min length is {min(length_lst)}") 55 | print(f"avg length is {sum(length_lst)/len(length_lst)}") 56 | 57 | 58 | if __name__ == '__main__': 59 | main() -------------------------------------------------------------------------------- /tests/count_sequence_max_length.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | 4 | # author: xiaoya li 5 | # file: count_length_autotokenizer.py 6 | 7 | import os 8 | import sys 9 | 10 | repo_path = "/".join(os.path.realpath(__file__).split("/")[:-2]) 11 | print(repo_path) 12 | if repo_path not in sys.path: 13 | sys.path.insert(0, repo_path) 14 | 15 | 16 | from transformers import AutoTokenizer 17 | from datasets.tagger_ner_dataset import TaggerNERDataset 18 | 19 | 20 | class OntoNotesDataConfig: 21 | def __init__(self): 22 | self.data_dir = "/data/xiaoya/datasets/ner/zhontonotes4" 23 | self.model_path = "/data/xiaoya/pretrain_lm/chinese_L-12_H-768_A-12" 24 | self.do_lower_case = False 25 | self.tokenizer = AutoTokenizer.from_pretrained(self.model_path, use_fast=False, tokenize_chinese_chars=True) 26 | # BertWordPieceTokenizer(os.path.join(self.model_path, "vocab.txt"), lowercase=self.do_lower_case) 27 | self.max_length = 512 28 | self.is_chinese = True 29 | self.threshold = 275 30 | self.data_sign = "zh_onto" 31 | self.data_file_suffix = "char.bmes" 32 | 33 | class ChineseMSRADataConfig: 34 | def __init__(self): 35 | self.data_dir = "/data/xiaoya/datasets/ner/msra" 36 | self.model_path = "/data/xiaoya/pretrain_lm/chinese_L-12_H-768_A-12" 37 | self.do_lower_case = False 38 | self.tokenizer = AutoTokenizer.from_pretrained(self.model_path, use_fast=False, tokenize_chinese_chars=True) 39 | self.max_length = 512 40 | self.is_chinese = True 41 | self.threshold = 275 42 | self.data_sign = "zh_msra" 43 | self.data_file_suffix = "char.bmes" 44 | 45 | 46 | class EnglishOntoDataConfig: 47 | def __init__(self): 48 | self.data_dir = "/data/xiaoya/datasets/ner/ontonotes5" 49 | self.model_path = "/data/xiaoya/pretrain_lm/cased_L-12_H-768_A-12" 50 | self.do_lower_case = False 51 | self.tokenizer = AutoTokenizer.from_pretrained(self.model_path, use_fast=False) 52 | self.max_length = 512 53 | self.is_chinese = False 54 | self.threshold = 256 55 | self.data_sign = "en_onto" 56 | self.data_file_suffix = "word.bmes" 57 | 58 | 59 | class EnglishCoNLLDataConfig: 60 | def __init__(self): 61 | self.data_dir = "/data/xiaoya/datasets/ner/conll03_truecase_bmes" 62 | self.model_path = "/data/xiaoya/pretrain_lm/cased_L-12_H-768_A-12" 63 | if "uncased" in self.model_path: 64 | self.do_lower_case = True 65 | else: 66 | self.do_lower_case = False 67 | self.tokenizer = AutoTokenizer.from_pretrained(self.model_path, use_fast=False, do_lower_case=self.do_lower_case) 68 | self.max_length = 512 69 | self.is_chinese = False 70 | self.threshold = 275 71 | self.data_sign = "en_conll03" 72 | self.data_file_suffix = "word.bmes" 73 | 74 | class EnglishCoNLL03DocDataConfig: 75 | def __init__(self): 76 | self.data_dir = "/data/xiaoya/datasets/mrc_ner/en_conll03_doc" 77 | self.model_path = "/data/xiaoya/pretrain_lm/cased_L-12_H-768_A-12" 78 | self.do_lower_case = False 79 | self.tokenizer = AutoTokenizer.from_pretrained(self.model_path, use_fast=False) 80 | self.max_length = 512 81 | self.is_chinese = False 82 | self.threshold = 384 83 | self.data_sign = 
"en_conll03" 84 | self.data_file_suffix = "word.bmes" 85 | 86 | def count_max_length(data_sign): 87 | if data_sign == "zh_onto": 88 | data_config = OntoNotesDataConfig() 89 | elif data_sign == "zh_msra": 90 | data_config = ChineseMSRADataConfig() 91 | elif data_sign == "en_onto": 92 | data_config = EnglishOntoDataConfig() 93 | elif data_sign == "en_conll03": 94 | data_config = EnglishCoNLLDataConfig() 95 | elif data_sign == "en_conll03_doc": 96 | data_config = EnglishCoNLL03DocDataConfig() 97 | else: 98 | raise ValueError 99 | for prefix in ["test", "train", "dev"]: 100 | print("=*"*15) 101 | print(f"INFO -> loading {prefix} data. ") 102 | data_file_path = os.path.join(data_config.data_dir, f"{prefix}.{data_config.data_file_suffix}") 103 | 104 | dataset = TaggerNERDataset(data_file_path, 105 | data_config.tokenizer, 106 | data_sign, 107 | max_length=data_config.max_length, 108 | is_chinese=data_config.is_chinese, 109 | pad_to_maxlen=False,) 110 | max_len = 0 111 | counter = 0 112 | for idx, data_item in enumerate(dataset): 113 | tokens = data_item[0] 114 | num_tokens = tokens.shape[0] 115 | if num_tokens >= max_len: 116 | max_len = num_tokens 117 | if num_tokens > data_config.threshold: 118 | print(num_tokens) 119 | counter += 1 120 | 121 | print(f"INFO -> Max LEN for {prefix} set is : {max_len}") 122 | print(f"INFO -> large than {data_config.threshold} is {counter}") 123 | 124 | 125 | 126 | if __name__ == '__main__': 127 | data_sign = "en_onto" 128 | # english ontonotes 5.0 129 | # test: 172 130 | # dev: 407 131 | # train: 306 132 | count_max_length(data_sign) 133 | 134 | 135 | -------------------------------------------------------------------------------- /tests/extract_entity_span.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | 4 | # file: extract_entity_span.py 5 | 6 | 7 | def get_entity_from_bmes_lst(label_list): 8 | """reuse the code block from 9 | https://github.com/jiesutd/NCRFpp/blob/105a53a321eca9c1280037c473967858e01aaa43/utils/metric.py#L73 10 | Many thanks to Jie Yang. 
11 | """ 12 | list_len = len(label_list) 13 | begin_label = 'B-' 14 | end_label = 'E-' 15 | single_label = 'S-' 16 | whole_tag = '' 17 | index_tag = '' 18 | tag_list = [] 19 | stand_matrix = [] 20 | for i in range(0, list_len): 21 | if label_list[i] != -100: 22 | current_label = label_list[i].upper() 23 | else: 24 | continue 25 | if begin_label in current_label: 26 | if index_tag != '': 27 | tag_list.append(whole_tag + ',' + str(i-1)) 28 | whole_tag = current_label.replace(begin_label,"",1) +'[' +str(i) 29 | index_tag = current_label.replace(begin_label,"",1) 30 | elif single_label in current_label: 31 | if index_tag != '': 32 | tag_list.append(whole_tag + ',' + str(i-1)) 33 | whole_tag = current_label.replace(single_label,"",1) +'[' +str(i) 34 | tag_list.append(whole_tag) 35 | whole_tag = "" 36 | index_tag = "" 37 | elif end_label in current_label: 38 | if index_tag != '': 39 | tag_list.append(whole_tag +',' + str(i)) 40 | whole_tag = '' 41 | index_tag = '' 42 | else: 43 | continue 44 | if (whole_tag != '')&(index_tag != ''): 45 | tag_list.append(whole_tag) 46 | tag_list_len = len(tag_list) 47 | 48 | for i in range(0, tag_list_len): 49 | if len(tag_list[i]) > 0: 50 | tag_list[i] = tag_list[i]+ ']' 51 | insert_list = reverse_style(tag_list[i]) 52 | stand_matrix.append(insert_list) 53 | return stand_matrix 54 | 55 | 56 | def reverse_style(input_string): 57 | target_position = input_string.index('[') 58 | input_len = len(input_string) 59 | output_string = input_string[target_position:input_len] + input_string[0:target_position] 60 | return output_string 61 | 62 | 63 | if __name__ == "__main__": 64 | label_lst = ["B-PER", "M-PER", "M-PER", "E-PER", "O", "O", "B-ORG", "M-ORG", "M-ORG", "E-ORG", "B-PER", "M-PER", "M-PER", "M-PER"] 65 | span_results = get_entity_from_bmes_lst(label_lst) 66 | print(span_results) 67 | 68 | label_lst = ["B-PER", "M-PER", -100, -100, "M-PER", "E-PER", -100, "O", "O", -100, "B-ORG", -100, "M-ORG", "M-ORG", "E-ORG", "B-PER", "M-PER", 69 | "M-PER", "M-PER"] 70 | span_results = get_entity_from_bmes_lst(label_lst) 71 | print(span_results) -------------------------------------------------------------------------------- /tests/illegal_entity_boundary.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | 4 | # file: illegal_entity_boundary.py 5 | 6 | from transformers import AutoTokenizer 7 | 8 | def load_dataexamples(file_path, ): 9 | with open(file_path, "r") as f: 10 | datalines = f.readlines() 11 | 12 | sentence_collections = [] 13 | sentence_label_collections = [] 14 | word_collections = [] 15 | word_label_collections = [] 16 | 17 | for data_item in datalines: 18 | data_item = data_item.strip() 19 | if len(data_item) != 0: 20 | word, label = tuple(data_item.split(" ")) 21 | word_collections.append(word) 22 | word_label_collections.append(label) 23 | else: 24 | sentence_collections.append(word_collections) 25 | sentence_label_collections.append(word_label_collections) 26 | word_collections = [] 27 | word_label_collections = [] 28 | 29 | return sentence_collections, sentence_label_collections 30 | 31 | 32 | def find_data_instance(file_path, search_string): 33 | sentence_collections, sentence_label_collections = load_dataexamples(file_path) 34 | 35 | for sentence_lst, label_lst in zip(sentence_collections, sentence_label_collections): 36 | sentence_str = "".join(sentence_lst) 37 | if search_string in sentence_str: 38 | print(sentence_str) 39 | print("-"*10) 40 | print(sentence_lst) 
41 | print(label_lst) 42 | print("=*"*10) 43 | 44 | 45 | def find_illegal_entity(query, context_tokens, labels, model_path, is_chinese=True, do_lower_case=True): 46 | tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False, do_lower_case=do_lower_case) 47 | if is_chinese: 48 | context = "".join(context_tokens) 49 | else: 50 | context = " ".join(context_tokens) 51 | 52 | start_positions = [] 53 | end_positions = [] 54 | origin_tokens = context_tokens 55 | print("check labels in ") 56 | print(len(origin_tokens)) 57 | print(len(labels)) 58 | 59 | for label_idx, label_item in enumerate(labels): 60 | if "B-" in label_item: 61 | start_positions.append(label_idx) 62 | if "S-" in label_item: 63 | end_positions.append(label_idx) 64 | start_positions.append(label_idx) 65 | if "E-" in label_item: 66 | end_positions.append(label_idx) 67 | 68 | print("origin entity tokens") 69 | for start_item, end_item in zip(start_positions, end_positions): 70 | print(origin_tokens[start_item: end_item + 1]) 71 | 72 | query_context_tokens = tokenizer.encode_plus(query, context, 73 | add_special_tokens=True, 74 | max_length=500000, 75 | return_overflowing_tokens=True, 76 | return_token_type_ids=True) 77 | 78 | if tokenizer.pad_token_id in query_context_tokens["input_ids"]: 79 | non_padded_ids = query_context_tokens["input_ids"][ 80 | : query_context_tokens["input_ids"].index(tokenizer.pad_token_id)] 81 | else: 82 | non_padded_ids = query_context_tokens["input_ids"] 83 | 84 | non_pad_tokens = tokenizer.convert_ids_to_tokens(non_padded_ids) 85 | first_sep_token = non_pad_tokens.index("[SEP]") 86 | end_sep_token = len(non_pad_tokens) - 1 87 | new_start_positions = [] 88 | new_end_positions = [] 89 | if len(start_positions) != 0: 90 | for start_index, end_index in zip(start_positions, end_positions): 91 | if is_chinese: 92 | answer_text_span = " ".join(context[start_index: end_index + 1]) 93 | else: 94 | answer_text_span = " ".join(context.split(" ")[start_index: end_index + 1]) 95 | new_start, new_end = _improve_answer_span(query_context_tokens["input_ids"], first_sep_token, end_sep_token, 96 | tokenizer, answer_text_span) 97 | new_start_positions.append(new_start) 98 | new_end_positions.append(new_end) 99 | else: 100 | new_start_positions = start_positions 101 | new_end_positions = end_positions 102 | 103 | # clip out-of-boundary entity positions. 
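# (500000 here mirrors the max_length passed to tokenizer.encode_plus above, so in this
#  test script the filter is only a guard against positions beyond that cap)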
104 | new_start_positions = [start_pos for start_pos in new_start_positions if start_pos < 500000] 105 | new_end_positions = [end_pos for end_pos in new_end_positions if end_pos < 500000] 106 | 107 | print("print tokens :") 108 | for start_item, end_item in zip(new_start_positions, new_end_positions): 109 | print(tokenizer.convert_ids_to_tokens(query_context_tokens["input_ids"][start_item: end_item + 1])) 110 | 111 | 112 | def _improve_answer_span(doc_tokens, input_start, input_end, tokenizer, orig_answer_text, return_subtoken_start=False): 113 | """Returns tokenized answer spans that better match the annotated answer.""" 114 | doc_tokens = [str(tmp) for tmp in doc_tokens] 115 | answer_tokens = tokenizer.encode(orig_answer_text, add_special_tokens=False) 116 | tok_answer_text = " ".join([str(tmp) for tmp in answer_tokens]) 117 | for new_start in range(input_start, input_end + 1): 118 | for new_end in range(input_end, new_start - 1, -1): 119 | text_span = " ".join(doc_tokens[new_start : (new_end+1)]) 120 | if text_span == tok_answer_text: 121 | if not return_subtoken_start: 122 | return (new_start, new_end) 123 | tokens = tokenizer.convert_ids_to_tokens(doc_tokens[new_start: (new_end + 1)]) 124 | if "##" not in tokens[-1]: 125 | return (new_start, new_end) 126 | else: 127 | for idx in range(len(tokens)-1, -1, -1): 128 | if "##" not in tokens[idx]: 129 | new_end = new_end - (len(tokens)-1 - idx) 130 | return (new_start, new_end) 131 | 132 | return (input_start, input_end) 133 | 134 | 135 | if __name__ == "__main__": 136 | # file_path = "/data/xiaoya/datasets/ner/msra/train.char.bmes" 137 | # search_string = "美亚股份" 138 | # find_data_instance(file_path, search_string) 139 | # 140 | # print("=%"*20) 141 | # print("check entity boundary") 142 | # print("=&"*20) 143 | 144 | print(">>> check for Chinese data example ... 
...") 145 | context_tokens = ['1', '美', '亚', '股', '份', '3', '2', '.', '6', '6', '2', '民', '族', '集', '团', '2', '2', '.', '3', 146 | '8', '3', '鲁', '石', '化', 'A', '1', '9', '.', '1', '1', '4', '四', '川', '湖', '山', '1', '7', '.', 147 | '0', '9', '5', '太', '原', '刚', '玉', '1', '0', '.', '5', '8', '1', '咸', '阳', '偏', '转', '1', '6', 148 | '.', '1', '1', '2', '深', '华', '发', 'A', '1', '5', '.', '6', '6', '3', '渝', '开', '发', 'A', '1', 149 | '5', '.', '5', '2', '4', '深', '发', '展', 'A', '1', '3', '.', '8', '9', '5', '深', '纺', '织', 'A', 150 | '1', '3', '.', '2', '2', '1', '太', '极', '实', '业', '2', '3', '.', '2', '2', '2', '友', '好', '集', 151 | '团', '2', '2', '.', '1', '4', '3', '双', '虎', '涂', '料', '2', '0', '.', '2', '0', '4', '新', '潮', 152 | '实', '业', '1', '5', '.', '5', '8', '5', '信', '联', '股', '份', '1', '2', '.', '5', '7', '1', '氯', 153 | '碱', '化', '工', '2', '1', '.', '1', '7', '2', '百', '隆', '股', '份', '1', '5', '.', '6', '4', '3', 154 | '贵', '华', '旅', '业', '1', '5', '.', '1', '5', '4', '南', '洋', '实', '业', '1', '4', '.', '5', '0', 155 | '5', '福', '建', '福', '联', '1', '3', '.', '8', '0'] 156 | 157 | labels = ['O', 'B-NT', 'M-NT', 'M-NT', 'E-NT', 'O', 'O', 'O', 'O', 'O', 'O', 'B-NT', 'M-NT', 'M-NT', 'E-NT', 'O', 158 | 'O', 'O', 'O', 'O', 'O', 'B-NT', 'M-NT', 'M-NT', 'E-NT', 'O', 'O', 'O', 'O', 'O', 'O', 'B-NT', 'M-NT', 159 | 'M-NT', 'E-NT', 'O', 'O', 'O', 'O', 'O', 'O', 'B-NT', 'M-NT', 'M-NT', 'E-NT', 'O', 'O', 'O', 'O', 'O', 160 | 'O', 'B-NT', 'M-NT', 161 | 'M-NT', 'E-NT', 'O', 'O', 'O', 'O', 'O', 'O', 'B-NT', 'M-NT', 'M-NT', 'E-NT', 'O', 'O', 'O', 'O', 'O', 162 | 'O', 163 | 'B-NT', 'M-NT', 'M-NT', 'E-NT', 'O', 'O', 'O', 'O', 'O', 'O', 'B-NT', 'M-NT', 'M-NT', 'E-NT', 'O', 'O', 164 | 'O', 'O', 'O', 165 | 'O', 'B-NT', 'M-NT', 'M-NT', 'E-NT', 'O', 'O', 'O', 'O', 'O', 'O', 'B-NT', 'M-NT', 'M-NT', 'E-NT', 'O', 166 | 'O', 'O', 'O', 167 | 'O', 'O', 'B-NT', 'M-NT', 'M-NT', 'E-NT', 'O', 'O', 'O', 'O', 'O', 'O', 'B-NT', 'M-NT', 'M-NT', 'E-NT', 168 | 'O', 'O', 169 | 'O', 'O', 'O', 'O', 'B-NT', 'M-NT', 'M-NT', 'E-NT', 'O', 'O', 'O', 'O', 'O', 'O', 'B-NT', 'M-NT', 'M-NT', 170 | 'E-NT', 171 | 'O', 'O', 'O', 'O', 'O', 'O', 'B-NT', 'M-NT', 'M-NT', 'E-NT', 'O', 'O', 'O', 'O', 'O', 'O', 'B-NT', 172 | 'M-NT', 'M-NT', 173 | 'E-NT', 'O', 'O', 'O', 'O', 'O', 'O', 'B-NT', 'M-NT', 'M-NT', 'E-NT', 'O', 'O', 'O', 'O', 'O', 'O', 174 | 'B-NT', 175 | 'M-NT', 'M-NT', 'E-NT', 'O', 'O', 'O', 'O', 'O', 'O', 'B-NT', 'M-NT', 'M-NT', 'E-NT', 'O', 'O', 'O', 'O', 176 | 'O'] 177 | query = "组织机构" 178 | 179 | model_path = "/data/nfsdata/nlp/BERT_BASE_DIR/chinese_L-12_H-768_A-12" 180 | find_illegal_entity(query, context_tokens, labels, model_path, is_chinese=True, do_lower_case=True) 181 | 182 | print("$$$$$"*20) 183 | print(">>> check for English data example ... 
...") 184 | query = "organization" 185 | context_tokens = ['RUGBY', 'LEAGUE', '-', 'EUROPEAN', 'SUPER', 'LEAGUE', 'RESULTS', '/', 'STANDINGS', '.', 'LONDON', 186 | '1996-08-24', 'Results', 'of', 'European', 'Super', 'League', 'rugby', 'league', 'matches', 'on', 187 | 'Saturday', ':', 'Paris', '14', 'Bradford', '27', 'Wigan', '78', 'Workington', '4', 'Standings', 188 | '(', 'tabulated', 'under', 'played', ',', 'won', ',', 'drawn', ',', 'lost', ',', 'points', 'for', 189 | ',', 'against', ',', 'total', 'points', ')', ':', 'Wigan', '22', '19', '1', '2', '902', '326', 190 | '39', 'St', 'Helens', '21', '19', '0', '2', '884', '441', '38', 'Bradford', '22', '17', '0', '5', 191 | '767', '409', '34', 'Warrington', '21', '12', '0', '9', '555', '499', '24', 'London', '21', '11', 192 | '1', '9', '555', '462', '23', 'Sheffield', '21', '10', '0', '11', '574', '696', '20', 'Halifax', 193 | '21', '9', '1', '11', '603', '552', '19', 'Castleford', '21', '9', '0', '12', '548', '543', '18', 194 | 'Oldham', '21', '8', '1', '12', '439', '656', '17', 'Leeds', '21', '6', '0', '15', '531', '681', 195 | '12', 'Paris', '22', '3', '1', '18', '398', '795', '7', 'Workington', '22', '2', '1', '19', '325', 196 | '1021', '5'] 197 | labels = ['B-MISC', 'E-MISC', 'O', 'B-MISC', 'I-MISC', 'E-MISC', 'O', 'O', 'O', 'O', 'S-LOC', 'O', 'O', 'O', 198 | 'B-MISC', 'I-MISC', 'E-MISC', 'O', 'O', 'O', 'O', 'O', 'O', 'S-ORG', 'O', 'S-ORG', 'O', 'S-ORG', 'O', 199 | 'S-ORG', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 200 | 'O', 'O', 'O', 'S-ORG', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'B-ORG', 'E-ORG', 'O', 'O', 'O', 'O', 'O', 'O', 201 | 'O', 'S-ORG', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'S-ORG', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'S-ORG', 'O', 202 | 'O', 'O', 'O', 'O', 'O', 'O', 'S-ORG', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'S-ORG', 'O', 'O', 'O', 'O', 203 | 'O', 'O', 'O', 'S-ORG', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'S-ORG', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 204 | 'S-ORG', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'S-ORG', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'S-ORG', 'O', 'O', 205 | 'O', 'O', 'O', 'O', 'O'] 206 | 207 | model_path = "/data/xiaoya/models/bert_cased_large" 208 | find_illegal_entity(query, context_tokens, labels, model_path, is_chinese=False, do_lower_case=False) 209 | 210 | -------------------------------------------------------------------------------- /train/bert_tagger_trainer.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | 4 | # file: bert_tagger_trainer.py 5 | 6 | import os 7 | import re 8 | import argparse 9 | import logging 10 | from typing import Dict 11 | from collections import namedtuple 12 | from utils.random_seed import set_random_seed 13 | set_random_seed(0) 14 | 15 | import torch 16 | import pytorch_lightning as pl 17 | from torch import Tensor 18 | from torch.utils.data import DataLoader, RandomSampler, SequentialSampler 19 | from torch.nn.modules import CrossEntropyLoss 20 | from pytorch_lightning import Trainer 21 | from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint 22 | from transformers import AutoTokenizer 23 | from transformers import AdamW, get_linear_schedule_with_warmup, get_polynomial_decay_schedule_with_warmup 24 | 25 | from utils.get_parser import get_parser 26 | from datasets.tagger_ner_dataset import get_labels, TaggerNERDataset 27 | from datasets.truncate_dataset import TruncateDataset 28 | from datasets.collate_functions import 
tagger_collate_to_max_length 29 | from metrics.tagger_span_f1 import TaggerSpanF1 30 | from metrics.functional.tagger_span_f1 import transform_predictions_to_labels 31 | from models.bert_tagger import BertTagger 32 | from models.model_config import BertTaggerConfig 33 | 34 | 35 | class BertSequenceLabeling(pl.LightningModule): 36 | def __init__( 37 | self, 38 | args: argparse.Namespace 39 | ): 40 | """Initialize a model, tokenizer and config.""" 41 | super().__init__() 42 | format = '%(asctime)s - %(name)s - %(message)s' 43 | if isinstance(args, argparse.Namespace): 44 | self.save_hyperparameters(args) 45 | self.args = args 46 | logging.basicConfig(format=format, filename=os.path.join(self.args.output_dir, "eval_result_log.txt"), level=logging.INFO) 47 | else: 48 | # eval mode 49 | TmpArgs = namedtuple("tmp_args", field_names=list(args.keys())) 50 | self.args = args = TmpArgs(**args) 51 | logging.basicConfig(format=format, filename=os.path.join(self.args.output_dir, "eval_test.txt"), level=logging.INFO) 52 | 53 | self.bert_dir = args.bert_config_dir 54 | self.data_dir = self.args.data_dir 55 | self.task_labels = get_labels(self.args.data_sign) 56 | self.num_labels = len(self.task_labels) 57 | self.task_idx2label = {label_idx : label_item for label_idx, label_item in enumerate(get_labels(self.args.data_sign))} 58 | bert_config = BertTaggerConfig.from_pretrained(args.bert_config_dir, 59 | hidden_dropout_prob=args.bert_dropout, 60 | attention_probs_dropout_prob=args.bert_dropout, 61 | num_labels=self.num_labels, 62 | classifier_dropout=args.classifier_dropout, 63 | classifier_sign=args.classifier_sign, 64 | classifier_act_func=args.classifier_act_func, 65 | classifier_intermediate_hidden_size=args.classifier_intermediate_hidden_size) 66 | 67 | self.tokenizer = AutoTokenizer.from_pretrained(args.bert_config_dir, use_fast=False, do_lower_case=args.do_lowercase) 68 | self.model = BertTagger.from_pretrained(args.bert_config_dir, config=bert_config) 69 | logging.info(str(args.__dict__ if isinstance(args, argparse.ArgumentParser) else args)) 70 | self.result_logger = logging.getLogger(__name__) 71 | self.result_logger.setLevel(logging.INFO) 72 | self.loss_func = CrossEntropyLoss() 73 | self.span_f1 = TaggerSpanF1() 74 | self.chinese = args.chinese 75 | self.optimizer = args.optimizer 76 | 77 | @staticmethod 78 | def add_model_specific_args(parent_parser): 79 | parser = argparse.ArgumentParser(parents=[parent_parser], add_help=False) 80 | parser.add_argument("--train_batch_size", type=int, default=8, help="batch size") 81 | parser.add_argument("--eval_batch_size", type=int, default=8, help="batch size") 82 | parser.add_argument("--bert_dropout", type=float, default=0.1, help="bert dropout rate") 83 | parser.add_argument("--classifier_sign", type=str, default="multi_nonlinear") 84 | parser.add_argument("--classifier_dropout", type=float, default=0.1) 85 | parser.add_argument("--classifier_act_func", type=str, default="gelu") 86 | parser.add_argument("--classifier_intermediate_hidden_size", type=int, default=1024) 87 | parser.add_argument("--chinese", action="store_true", help="is chinese dataset") 88 | parser.add_argument("--optimizer", choices=["adamw", "torch.adam"], default="adamw", help="optimizer type") 89 | parser.add_argument("--final_div_factor", type=float, default=1e4, help="final div factor of linear decay scheduler") 90 | parser.add_argument("--output_dir", type=str, default="", help="the path for saving intermediate model checkpoints.") 91 | parser.add_argument("--lr_scheduler", 
type=str, default="linear_decay", help="lr scheduler") 92 | parser.add_argument("--data_sign", type=str, default="en_conll03", help="data signature for the dataset.") 93 | parser.add_argument("--polydecay_ratio", type=float, default=4, help="ratio for polydecay learing rate scheduler.") 94 | parser.add_argument("--do_lowercase", action="store_true", ) 95 | parser.add_argument("--data_file_suffix", type=str, default=".char.bmes") 96 | parser.add_argument("--lr_scheulder", type=str, default="polydecay") 97 | parser.add_argument("--lr_mini", type=float, default=-1) 98 | parser.add_argument("--warmup_proportion", default=0.1, type=float, help="Proportion of training to perform linear learning rate warmup for.") 99 | 100 | return parser 101 | 102 | def configure_optimizers(self): 103 | """Prepare optimizer and schedule (linear warmup and decay)""" 104 | no_decay = ["bias", "LayerNorm.weight"] 105 | optimizer_grouped_parameters = [ 106 | { 107 | "params": [p for n, p in self.model.named_parameters() if not any(nd in n for nd in no_decay)], 108 | "weight_decay": self.args.weight_decay, 109 | }, 110 | { 111 | "params": [p for n, p in self.model.named_parameters() if any(nd in n for nd in no_decay)], 112 | "weight_decay": 0.0, 113 | }, 114 | ] 115 | if self.optimizer == "adamw": 116 | optimizer = AdamW(optimizer_grouped_parameters, 117 | betas=(0.9, 0.98), # according to RoBERTa paper 118 | lr=self.args.lr, 119 | eps=self.args.adam_epsilon,) 120 | elif self.optimizer == "torch.adam": 121 | optimizer = torch.optim.AdamW(optimizer_grouped_parameters, 122 | lr=self.args.lr, 123 | eps=self.args.adam_epsilon, 124 | weight_decay=self.args.weight_decay) 125 | else: 126 | raise ValueError("Optimizer type does not exist.") 127 | num_gpus = len([x for x in str(self.args.gpus).split(",") if x.strip()]) 128 | t_total = (len(self.train_dataloader()) // (self.args.accumulate_grad_batches * num_gpus) + 1) * self.args.max_epochs 129 | warmup_steps = int(self.args.warmup_proportion * t_total) 130 | if self.args.lr_scheduler == "onecycle": 131 | scheduler = torch.optim.lr_scheduler.OneCycleLR( 132 | optimizer, max_lr=self.args.lr, pct_start=float(warmup_steps/t_total), 133 | final_div_factor=self.args.final_div_factor, 134 | total_steps=t_total, anneal_strategy='linear') 135 | elif self.args.lr_scheduler == "linear": 136 | scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=warmup_steps, num_training_steps=t_total) 137 | elif self.args.lr_scheulder == "polydecay": 138 | if self.args.lr_mini == -1: 139 | lr_mini = self.args.lr / self.args.polydecay_ratio 140 | else: 141 | lr_mini = self.args.lr_mini 142 | scheduler = get_polynomial_decay_schedule_with_warmup(optimizer, warmup_steps, t_total, lr_end=lr_mini) 143 | else: 144 | raise ValueError 145 | return [optimizer], [{"scheduler": scheduler, "interval": "step"}] 146 | 147 | def forward(self, input_ids, token_type_ids, attention_mask): 148 | return self.model(input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask) 149 | 150 | def compute_loss(self, sequence_logits, sequence_labels, input_mask=None): 151 | if input_mask is not None: 152 | active_loss = input_mask.view(-1) == 1 153 | active_logits = sequence_logits.view(-1, self.num_labels) 154 | active_labels = torch.where( 155 | active_loss, sequence_labels.view(-1), torch.tensor(self.loss_func.ignore_index).type_as(sequence_labels) 156 | ) 157 | loss = self.loss_func(active_logits, active_labels) 158 | else: 159 | loss = self.loss_func(sequence_logits.view(-1, self.num_labels), 
sequence_labels.view(-1)) 160 | return loss 161 | 162 | def training_step(self, batch, batch_idx): 163 | tf_board_logs = {"lr": self.trainer.optimizers[0].param_groups[0]['lr']} 164 | token_input_ids, token_type_ids, attention_mask, sequence_labels, is_wordpiece_mask = batch 165 | 166 | logits = self.model(token_input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask) 167 | loss = self.compute_loss(logits, sequence_labels, input_mask=attention_mask) 168 | tf_board_logs[f"train_loss"] = loss 169 | 170 | return {'loss': loss, 'log': tf_board_logs} 171 | 172 | def validation_step(self, batch, batch_idx): 173 | output = {} 174 | 175 | token_input_ids, token_type_ids, attention_mask, sequence_labels, is_wordpiece_mask = batch 176 | batch_size = token_input_ids.shape[0] 177 | logits = self.model(token_input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask) 178 | loss = self.compute_loss(logits, sequence_labels, input_mask=attention_mask) 179 | output[f"val_loss"] = loss 180 | 181 | sequence_pred_lst = transform_predictions_to_labels(logits.view(batch_size, -1, len(self.task_labels)), is_wordpiece_mask, self.task_idx2label, input_type="logit") 182 | sequence_gold_lst = transform_predictions_to_labels(sequence_labels, is_wordpiece_mask, self.task_idx2label, input_type="label") 183 | span_f1_stats = self.span_f1(sequence_pred_lst, sequence_gold_lst) 184 | output["span_f1_stats"] = span_f1_stats 185 | 186 | return output 187 | 188 | def validation_epoch_end(self, outputs): 189 | avg_loss = torch.stack([x['val_loss'] for x in outputs]).mean() 190 | tensorboard_logs = {'val_loss': avg_loss} 191 | 192 | all_counts = torch.stack([x[f'span_f1_stats'] for x in outputs]).view(-1, 3).sum(0) 193 | span_tp, span_fp, span_fn = all_counts 194 | span_recall = span_tp / (span_tp + span_fn + 1e-10) 195 | span_precision = span_tp / (span_tp + span_fp + 1e-10) 196 | span_f1 = span_precision * span_recall * 2 / (span_recall + span_precision + 1e-10) 197 | tensorboard_logs[f"span_precision"] = span_precision 198 | tensorboard_logs[f"span_recall"] = span_recall 199 | tensorboard_logs[f"span_f1"] = span_f1 200 | self.result_logger.info(f"EVAL INFO -> current_epoch is: {self.trainer.current_epoch}, current_global_step is: {self.trainer.global_step} ") 201 | self.result_logger.info(f"EVAL INFO -> valid_f1 is: {span_f1}") 202 | 203 | return {'val_loss': avg_loss, 'log': tensorboard_logs} 204 | 205 | def test_step(self, batch, batch_idx): 206 | output = {} 207 | 208 | token_input_ids, token_type_ids, attention_mask, sequence_labels, is_wordpiece_mask = batch 209 | batch_size = token_input_ids.shape[0] 210 | logits = self.model(token_input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask) 211 | loss = self.compute_loss(logits, sequence_labels, input_mask=attention_mask) 212 | output[f"test_loss"] = loss 213 | 214 | sequence_pred_lst = transform_predictions_to_labels(logits.view(batch_size, -1, len(self.task_labels)), 215 | is_wordpiece_mask, self.task_idx2label, input_type="logit") 216 | sequence_gold_lst = transform_predictions_to_labels(sequence_labels, is_wordpiece_mask, self.task_idx2label, 217 | input_type="label") 218 | span_f1_stats = self.span_f1(sequence_pred_lst, sequence_gold_lst) 219 | output["span_f1_stats"] = span_f1_stats 220 | return output 221 | 222 | def test_epoch_end(self, outputs) -> Dict[str, Dict[str, Tensor]]: 223 | avg_loss = torch.stack([x['test_loss'] for x in outputs]).mean() 224 | tensorboard_logs = {'test_loss': avg_loss} 225 | 226 | all_counts = 
torch.stack([x[f'span_f1_stats'] for x in outputs]).view(-1, 3).sum(0) 227 | span_tp, span_fp, span_fn = all_counts 228 | span_recall = span_tp / (span_tp + span_fn + 1e-10) 229 | span_precision = span_tp / (span_tp + span_fp + 1e-10) 230 | span_f1 = span_precision * span_recall * 2 / (span_recall + span_precision + 1e-10) 231 | tensorboard_logs[f"span_precision"] = span_precision 232 | tensorboard_logs[f"span_recall"] = span_recall 233 | tensorboard_logs[f"span_f1"] = span_f1 234 | print(f"TEST INFO -> test_f1 is: {span_f1} precision: {span_precision}, recall: {span_recall}") 235 | self.result_logger.info(f"EVAL INFO -> test_f1 is: {span_f1}, test_precision is: {span_precision}, test_recall is: {span_recall}") 236 | 237 | return {'test_loss': avg_loss, 'log': tensorboard_logs} 238 | 239 | def train_dataloader(self) -> DataLoader: 240 | return self.get_dataloader("train") 241 | 242 | def val_dataloader(self) -> DataLoader: 243 | return self.get_dataloader("dev") 244 | 245 | def test_dataloader(self) -> DataLoader: 246 | return self.get_dataloader("test") 247 | 248 | def get_dataloader(self, prefix="train", limit: int = None) -> DataLoader: 249 | """get train/dev/test dataloader""" 250 | data_path = os.path.join(self.data_dir, f"{prefix}{self.args.data_file_suffix}") 251 | dataset = TaggerNERDataset(data_path, self.tokenizer, self.args.data_sign, 252 | max_length=self.args.max_length, is_chinese=self.args.chinese, 253 | pad_to_maxlen=False) 254 | 255 | if limit is not None: 256 | dataset = TruncateDataset(dataset, limit) 257 | 258 | if prefix == "train": 259 | batch_size = self.args.train_batch_size 260 | # define data_generator will help experiment reproducibility. 261 | # cannot use random data sampler since the gradient may explode. 262 | data_generator = torch.Generator() 263 | data_generator.manual_seed(self.args.seed) 264 | data_sampler = RandomSampler(dataset, generator=data_generator) 265 | else: 266 | data_sampler = SequentialSampler(dataset) 267 | batch_size = self.args.eval_batch_size 268 | 269 | dataloader = DataLoader( 270 | dataset=dataset, sampler=data_sampler, 271 | batch_size=batch_size, num_workers=self.args.workers, 272 | collate_fn=tagger_collate_to_max_length 273 | ) 274 | 275 | return dataloader 276 | 277 | 278 | def find_best_checkpoint_on_dev(output_dir: str, log_file: str = "eval_result_log.txt", only_keep_the_best_ckpt: bool = False): 279 | with open(os.path.join(output_dir, log_file)) as f: 280 | log_lines = f.readlines() 281 | 282 | F1_PATTERN = re.compile(r"span_f1 reached \d+\.\d* \(best") 283 | # val_f1 reached 0.00000 (best 0.00000) 284 | CKPT_PATTERN = re.compile(r"saving model to \S+ as top") 285 | checkpoint_info_lines = [] 286 | for log_line in log_lines: 287 | if "saving model to" in log_line: 288 | checkpoint_info_lines.append(log_line) 289 | # example of log line 290 | # Epoch 00000: val_f1 reached 0.00000 (best 0.00000), saving model to /data/xiaoya/outputs/0117/debug_5_12_2e-5_0.001_0.001_275_0.1_1_0.25/checkpoint/epoch=0.ckpt as top 20 291 | best_f1_on_dev = 0 292 | best_checkpoint_on_dev = "" 293 | for checkpoint_info_line in checkpoint_info_lines: 294 | current_f1 = float( 295 | re.findall(F1_PATTERN, checkpoint_info_line)[0].replace("span_f1 reached ", "").replace(" (best", "")) 296 | current_ckpt = re.findall(CKPT_PATTERN, checkpoint_info_line)[0].replace("saving model to ", "").replace( 297 | " as top", "") 298 | 299 | if current_f1 >= best_f1_on_dev: 300 | if only_keep_the_best_ckpt and len(best_checkpoint_on_dev) != 0: 301 | 
os.remove(best_checkpoint_on_dev) 302 | best_f1_on_dev = current_f1 303 | best_checkpoint_on_dev = current_ckpt 304 | 305 | return best_f1_on_dev, best_checkpoint_on_dev 306 | 307 | 308 | def main(): 309 | """main""" 310 | parser = get_parser() 311 | 312 | # add model specific args 313 | parser = BertSequenceLabeling.add_model_specific_args(parser) 314 | # add all the available trainer options to argparse 315 | # ie: now --gpus --num_nodes ... --fast_dev_run all work in the cli 316 | parser = Trainer.add_argparse_args(parser) 317 | args = parser.parse_args() 318 | model = BertSequenceLabeling(args) 319 | if args.pretrained_checkpoint: 320 | model.load_state_dict(torch.load(args.pretrained_checkpoint, 321 | map_location=torch.device('cpu'))["state_dict"]) 322 | checkpoint_callback = ModelCheckpoint( 323 | filepath=args.output_dir, 324 | save_top_k=args.max_keep_ckpt, 325 | verbose=True, 326 | monitor="span_f1", 327 | period=-1, 328 | mode="max", 329 | ) 330 | trainer = Trainer.from_argparse_args( 331 | args, 332 | checkpoint_callback=checkpoint_callback, 333 | deterministic=True, 334 | default_root_dir=args.output_dir 335 | ) 336 | trainer.fit(model) 337 | 338 | # after training, use the model checkpoint which achieves the best f1 score on dev set to compute the f1 on test set. 339 | best_f1_on_dev, path_to_best_checkpoint = find_best_checkpoint_on_dev(args.output_dir,) 340 | model.result_logger.info("=&" * 20) 341 | model.result_logger.info(f"Best F1 on DEV is {best_f1_on_dev}") 342 | model.result_logger.info(f"Best checkpoint on DEV set is {path_to_best_checkpoint}") 343 | checkpoint = torch.load(path_to_best_checkpoint) 344 | model.load_state_dict(checkpoint['state_dict']) 345 | trainer.test(model) 346 | model.result_logger.info("=&" * 20) 347 | 348 | 349 | if __name__ == '__main__': 350 | main() 351 | -------------------------------------------------------------------------------- /train/mrc_ner_trainer.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | 4 | # file: mrc_ner_trainer.py 5 | 6 | import os 7 | import re 8 | import argparse 9 | import logging 10 | from collections import namedtuple 11 | from typing import Dict 12 | 13 | import torch 14 | import pytorch_lightning as pl 15 | from pytorch_lightning import Trainer 16 | from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint 17 | from tokenizers import BertWordPieceTokenizer 18 | from torch import Tensor 19 | from torch.nn.modules import CrossEntropyLoss, BCEWithLogitsLoss 20 | from torch.utils.data import DataLoader 21 | from transformers import AdamW, get_linear_schedule_with_warmup, get_polynomial_decay_schedule_with_warmup 22 | from torch.optim import SGD 23 | 24 | from datasets.mrc_ner_dataset import MRCNERDataset 25 | from datasets.truncate_dataset import TruncateDataset 26 | from datasets.collate_functions import collate_to_max_length 27 | from metrics.query_span_f1 import QuerySpanF1 28 | from models.bert_query_ner import BertQueryNER 29 | from models.model_config import BertQueryNerConfig 30 | from utils.get_parser import get_parser 31 | from utils.random_seed import set_random_seed 32 | 33 | set_random_seed(0) 34 | 35 | 36 | class BertLabeling(pl.LightningModule): 37 | def __init__( 38 | self, 39 | args: argparse.Namespace 40 | ): 41 | """Initialize a model, tokenizer and config.""" 42 | super().__init__() 43 | format = '%(asctime)s - %(name)s - %(message)s' 44 | if isinstance(args, argparse.Namespace): 45 | 
self.save_hyperparameters(args) 46 | self.args = args 47 | logging.basicConfig(format=format, filename=os.path.join(self.args.default_root_dir, "eval_result_log.txt"), level=logging.INFO) 48 | else: 49 | # eval mode 50 | TmpArgs = namedtuple("tmp_args", field_names=list(args.keys())) 51 | self.args = args = TmpArgs(**args) 52 | logging.basicConfig(format=format, filename=os.path.join(self.args.default_root_dir, "eval_test.txt"), level=logging.INFO) 53 | 54 | self.bert_dir = args.bert_config_dir 55 | self.data_dir = self.args.data_dir 56 | 57 | bert_config = BertQueryNerConfig.from_pretrained(args.bert_config_dir, 58 | hidden_dropout_prob=args.bert_dropout, 59 | attention_probs_dropout_prob=args.bert_dropout, 60 | mrc_dropout=args.mrc_dropout, 61 | classifier_act_func = args.classifier_act_func, 62 | classifier_intermediate_hidden_size=args.classifier_intermediate_hidden_size) 63 | 64 | self.model = BertQueryNER.from_pretrained(args.bert_config_dir, 65 | config=bert_config) 66 | logging.info(str(args.__dict__ if isinstance(args, argparse.ArgumentParser) else args)) 67 | self.result_logger = logging.getLogger(__name__) 68 | self.result_logger.setLevel(logging.INFO) 69 | self.result_logger.info(str(args.__dict__ if isinstance(args, argparse.ArgumentParser) else args)) 70 | self.bce_loss = BCEWithLogitsLoss(reduction="none") 71 | 72 | weight_sum = args.weight_start + args.weight_end + args.weight_span 73 | self.weight_start = args.weight_start / weight_sum 74 | self.weight_end = args.weight_end / weight_sum 75 | self.weight_span = args.weight_span / weight_sum 76 | self.flat_ner = args.flat 77 | self.span_f1 = QuerySpanF1(flat=self.flat_ner) 78 | self.chinese = args.chinese 79 | self.optimizer = args.optimizer 80 | self.span_loss_candidates = args.span_loss_candidates 81 | 82 | @staticmethod 83 | def add_model_specific_args(parent_parser): 84 | parser = argparse.ArgumentParser(parents=[parent_parser], add_help=False) 85 | parser.add_argument("--mrc_dropout", type=float, default=0.1, 86 | help="mrc dropout rate") 87 | parser.add_argument("--bert_dropout", type=float, default=0.1, 88 | help="bert dropout rate") 89 | parser.add_argument("--classifier_act_func", type=str, default="gelu") 90 | parser.add_argument("--classifier_intermediate_hidden_size", type=int, default=1024) 91 | parser.add_argument("--weight_start", type=float, default=1.0) 92 | parser.add_argument("--weight_end", type=float, default=1.0) 93 | parser.add_argument("--weight_span", type=float, default=1.0) 94 | parser.add_argument("--flat", action="store_true", help="is flat ner") 95 | parser.add_argument("--span_loss_candidates", choices=["all", "pred_and_gold", "pred_gold_random", "gold"], 96 | default="all", help="Candidates used to compute span loss") 97 | parser.add_argument("--chinese", action="store_true", 98 | help="is chinese dataset") 99 | parser.add_argument("--optimizer", choices=["adamw", "sgd", "torch.adam"], default="adamw", 100 | help="loss type") 101 | parser.add_argument("--final_div_factor", type=float, default=1e4, 102 | help="final div factor of linear decay scheduler") 103 | parser.add_argument("--lr_scheduler", type=str, default="onecycle", ) 104 | parser.add_argument("--lr_mini", type=float, default=-1) 105 | return parser 106 | 107 | def configure_optimizers(self): 108 | """Prepare optimizer and schedule (linear warmup and decay)""" 109 | no_decay = ["bias", "LayerNorm.weight"] 110 | optimizer_grouped_parameters = [ 111 | { 112 | "params": [p for n, p in self.model.named_parameters() if not any(nd in n for 
nd in no_decay)], 113 | "weight_decay": self.args.weight_decay, 114 | }, 115 | { 116 | "params": [p for n, p in self.model.named_parameters() if any(nd in n for nd in no_decay)], 117 | "weight_decay": 0.0, 118 | }, 119 | ] 120 | if self.optimizer == "adamw": 121 | optimizer = AdamW(optimizer_grouped_parameters, 122 | betas=(0.9, 0.98), # according to RoBERTa paper 123 | lr=self.args.lr, 124 | eps=self.args.adam_epsilon,) 125 | elif self.optimizer == "torch.adam": 126 | optimizer = torch.optim.AdamW(optimizer_grouped_parameters, 127 | lr=self.args.lr, 128 | eps=self.args.adam_epsilon, 129 | weight_decay=self.args.weight_decay) 130 | else: 131 | optimizer = SGD(optimizer_grouped_parameters, lr=self.args.lr, momentum=0.9) 132 | num_gpus = len([x for x in str(self.args.gpus).split(",") if x.strip()]) 133 | t_total = (len(self.train_dataloader()) // (self.args.accumulate_grad_batches * num_gpus) + 1) * self.args.max_epochs 134 | if self.args.lr_scheduler == "onecycle": 135 | scheduler = torch.optim.lr_scheduler.OneCycleLR( 136 | optimizer, max_lr=self.args.lr, pct_start=float(self.args.warmup_steps/t_total), 137 | final_div_factor=self.args.final_div_factor, 138 | total_steps=t_total, anneal_strategy='linear' 139 | ) 140 | elif self.args.lr_scheduler == "linear": 141 | scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=self.args.warmup_steps, num_training_steps=t_total) 142 | elif self.args.lr_scheduler == "polydecay": 143 | if self.args.lr_mini == -1: 144 | lr_mini = self.args.lr / 5 145 | else: 146 | lr_mini = self.args.lr_mini 147 | scheduler = get_polynomial_decay_schedule_with_warmup(optimizer, self.args.warmup_steps, t_total, lr_end=lr_mini) 148 | else: 149 | raise ValueError 150 | return [optimizer], [{"scheduler": scheduler, "interval": "step"}] 151 | 152 | def forward(self, input_ids, attention_mask, token_type_ids): 153 | return self.model(input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids) 154 | 155 | def compute_loss(self, start_logits, end_logits, span_logits, 156 | start_labels, end_labels, match_labels, start_label_mask, end_label_mask): 157 | batch_size, seq_len = start_logits.size() 158 | 159 | start_float_label_mask = start_label_mask.view(-1).float() 160 | end_float_label_mask = end_label_mask.view(-1).float() 161 | match_label_row_mask = start_label_mask.bool().unsqueeze(-1).expand(-1, -1, seq_len) 162 | match_label_col_mask = end_label_mask.bool().unsqueeze(-2).expand(-1, seq_len, -1) 163 | match_label_mask = match_label_row_mask & match_label_col_mask 164 | match_label_mask = torch.triu(match_label_mask, 0) # start should be less equal to end 165 | 166 | if self.span_loss_candidates == "all": 167 | # naive mask 168 | float_match_label_mask = match_label_mask.view(batch_size, -1).float() 169 | else: 170 | # use only pred or golden start/end to compute match loss 171 | start_preds = start_logits > 0 172 | end_preds = end_logits > 0 173 | if self.span_loss_candidates == "gold": 174 | match_candidates = ((start_labels.unsqueeze(-1).expand(-1, -1, seq_len) > 0) 175 | & (end_labels.unsqueeze(-2).expand(-1, seq_len, -1) > 0)) 176 | elif self.span_loss_candidates == "pred_gold_random": 177 | gold_and_pred = torch.logical_or( 178 | (start_preds.unsqueeze(-1).expand(-1, -1, seq_len) 179 | & end_preds.unsqueeze(-2).expand(-1, seq_len, -1)), 180 | (start_labels.unsqueeze(-1).expand(-1, -1, seq_len) 181 | & end_labels.unsqueeze(-2).expand(-1, seq_len, -1)) 182 | ) 183 | data_generator = torch.Generator() 184 | data_generator.manual_seed(0) 185 
| random_matrix = torch.empty(batch_size, seq_len, seq_len).uniform_(0, 1) 186 | random_matrix = torch.bernoulli(random_matrix, generator=data_generator).long() 187 | random_matrix = random_matrix.cuda() 188 | match_candidates = torch.logical_or( 189 | gold_and_pred, random_matrix 190 | ) 191 | else: 192 | match_candidates = torch.logical_or( 193 | (start_preds.unsqueeze(-1).expand(-1, -1, seq_len) 194 | & end_preds.unsqueeze(-2).expand(-1, seq_len, -1)), 195 | (start_labels.unsqueeze(-1).expand(-1, -1, seq_len) 196 | & end_labels.unsqueeze(-2).expand(-1, seq_len, -1)) 197 | ) 198 | match_label_mask = match_label_mask & match_candidates 199 | float_match_label_mask = match_label_mask.view(batch_size, -1).float() 200 | 201 | start_loss = self.bce_loss(start_logits.view(-1), start_labels.view(-1).float()) 202 | start_loss = (start_loss * start_float_label_mask).sum() / start_float_label_mask.sum() 203 | end_loss = self.bce_loss(end_logits.view(-1), end_labels.view(-1).float()) 204 | end_loss = (end_loss * end_float_label_mask).sum() / end_float_label_mask.sum() 205 | match_loss = self.bce_loss(span_logits.view(batch_size, -1), match_labels.view(batch_size, -1).float()) 206 | match_loss = match_loss * float_match_label_mask 207 | match_loss = match_loss.sum() / (float_match_label_mask.sum() + 1e-10) 208 | 209 | return start_loss, end_loss, match_loss 210 | 211 | def training_step(self, batch, batch_idx): 212 | tf_board_logs = { 213 | "lr": self.trainer.optimizers[0].param_groups[0]['lr'] 214 | } 215 | tokens, token_type_ids, start_labels, end_labels, start_label_mask, end_label_mask, match_labels, sample_idx, label_idx = batch 216 | 217 | # num_tasks * [bsz, length, num_labels] 218 | attention_mask = (tokens != 0).long() 219 | start_logits, end_logits, span_logits = self(tokens, attention_mask, token_type_ids) 220 | 221 | start_loss, end_loss, match_loss = self.compute_loss(start_logits=start_logits, 222 | end_logits=end_logits, 223 | span_logits=span_logits, 224 | start_labels=start_labels, 225 | end_labels=end_labels, 226 | match_labels=match_labels, 227 | start_label_mask=start_label_mask, 228 | end_label_mask=end_label_mask 229 | ) 230 | 231 | total_loss = self.weight_start * start_loss + self.weight_end * end_loss + self.weight_span * match_loss 232 | 233 | tf_board_logs[f"train_loss"] = total_loss 234 | tf_board_logs[f"start_loss"] = start_loss 235 | tf_board_logs[f"end_loss"] = end_loss 236 | tf_board_logs[f"match_loss"] = match_loss 237 | 238 | return {'loss': total_loss, 'log': tf_board_logs} 239 | 240 | def validation_step(self, batch, batch_idx): 241 | output = {} 242 | tokens, token_type_ids, start_labels, end_labels, start_label_mask, end_label_mask, match_labels, sample_idx, label_idx = batch 243 | 244 | attention_mask = (tokens != 0).long() 245 | start_logits, end_logits, span_logits = self(tokens, attention_mask, token_type_ids) 246 | 247 | start_loss, end_loss, match_loss = self.compute_loss(start_logits=start_logits, 248 | end_logits=end_logits, 249 | span_logits=span_logits, 250 | start_labels=start_labels, 251 | end_labels=end_labels, 252 | match_labels=match_labels, 253 | start_label_mask=start_label_mask, 254 | end_label_mask=end_label_mask 255 | ) 256 | 257 | total_loss = self.weight_start * start_loss + self.weight_end * end_loss + self.weight_span * match_loss 258 | 259 | output[f"val_loss"] = total_loss 260 | output[f"start_loss"] = start_loss 261 | output[f"end_loss"] = end_loss 262 | output[f"match_loss"] = match_loss 263 | 264 | start_preds, end_preds = 
start_logits > 0, end_logits > 0 265 | span_f1_stats = self.span_f1(start_preds=start_preds, end_preds=end_preds, match_logits=span_logits, 266 | start_label_mask=start_label_mask, end_label_mask=end_label_mask, 267 | match_labels=match_labels) 268 | output["span_f1_stats"] = span_f1_stats 269 | 270 | return output 271 | 272 | def validation_epoch_end(self, outputs): 273 | avg_loss = torch.stack([x['val_loss'] for x in outputs]).mean() 274 | tensorboard_logs = {'val_loss': avg_loss} 275 | 276 | all_counts = torch.stack([x[f'span_f1_stats'] for x in outputs]).view(-1, 3).sum(0) 277 | span_tp, span_fp, span_fn = all_counts 278 | span_recall = span_tp / (span_tp + span_fn + 1e-10) 279 | span_precision = span_tp / (span_tp + span_fp + 1e-10) 280 | span_f1 = span_precision * span_recall * 2 / (span_recall + span_precision + 1e-10) 281 | tensorboard_logs[f"span_precision"] = span_precision 282 | tensorboard_logs[f"span_recall"] = span_recall 283 | tensorboard_logs[f"span_f1"] = span_f1 284 | self.result_logger.info(f"EVAL INFO -> current_epoch is: {self.trainer.current_epoch}, current_global_step is: {self.trainer.global_step} ") 285 | self.result_logger.info(f"EVAL INFO -> valid_f1 is: {span_f1}; precision: {span_precision}, recall: {span_recall}.") 286 | 287 | return {'val_loss': avg_loss, 'log': tensorboard_logs} 288 | 289 | def test_step(self, batch, batch_idx): 290 | """""" 291 | output = {} 292 | tokens, token_type_ids, start_labels, end_labels, start_label_mask, end_label_mask, match_labels, sample_idx, label_idx = batch 293 | 294 | attention_mask = (tokens != 0).long() 295 | start_logits, end_logits, span_logits = self(tokens, attention_mask, token_type_ids) 296 | 297 | start_preds, end_preds = start_logits > 0, end_logits > 0 298 | span_f1_stats = self.span_f1(start_preds=start_preds, end_preds=end_preds, match_logits=span_logits, 299 | start_label_mask=start_label_mask, end_label_mask=end_label_mask, 300 | match_labels=match_labels) 301 | output["span_f1_stats"] = span_f1_stats 302 | 303 | return output 304 | 305 | def test_epoch_end(self, outputs) -> Dict[str, Dict[str, Tensor]]: 306 | tensorboard_logs = {} 307 | 308 | all_counts = torch.stack([x[f'span_f1_stats'] for x in outputs]).view(-1, 3).sum(0) 309 | span_tp, span_fp, span_fn = all_counts 310 | span_recall = span_tp / (span_tp + span_fn + 1e-10) 311 | span_precision = span_tp / (span_tp + span_fp + 1e-10) 312 | span_f1 = span_precision * span_recall * 2 / (span_recall + span_precision + 1e-10) 313 | print(f"TEST INFO -> test_f1 is: {span_f1} precision: {span_precision}, recall: {span_recall}") 314 | self.result_logger.info(f"TEST INFO -> test_f1 is: {span_f1} precision: {span_precision}, recall: {span_recall}") 315 | return {'log': tensorboard_logs} 316 | 317 | def train_dataloader(self) -> DataLoader: 318 | return self.get_dataloader("train") 319 | 320 | def val_dataloader(self) -> DataLoader: 321 | return self.get_dataloader("dev") 322 | 323 | def test_dataloader(self) -> DataLoader: 324 | return self.get_dataloader("test") 325 | 326 | def get_dataloader(self, prefix="train", limit: int = None) -> DataLoader: 327 | """get training dataloader""" 328 | """ 329 | load_mmap_dataset 330 | """ 331 | json_path = os.path.join(self.data_dir, f"mrc-ner.{prefix}") 332 | vocab_path = os.path.join(self.bert_dir, "vocab.txt") 333 | dataset = MRCNERDataset(json_path=json_path, 334 | tokenizer=BertWordPieceTokenizer(vocab_path), 335 | max_length=self.args.max_length, 336 | is_chinese=self.chinese, 337 | pad_to_maxlen=False 338 | ) 339 | 340 
| if limit is not None: 341 | dataset = TruncateDataset(dataset, limit) 342 | 343 | dataloader = DataLoader( 344 | dataset=dataset, 345 | batch_size=self.args.batch_size, 346 | num_workers=self.args.workers, 347 | shuffle=True if prefix == "train" else False, 348 | collate_fn=collate_to_max_length 349 | ) 350 | 351 | return dataloader 352 | 353 | def find_best_checkpoint_on_dev(output_dir: str, log_file: str = "eval_result_log.txt", only_keep_the_best_ckpt: bool = False): 354 | with open(os.path.join(output_dir, log_file)) as f: 355 | log_lines = f.readlines() 356 | 357 | F1_PATTERN = re.compile(r"span_f1 reached \d+\.\d* \(best") 358 | # val_f1 reached 0.00000 (best 0.00000) 359 | CKPT_PATTERN = re.compile(r"saving model to \S+ as top") 360 | checkpoint_info_lines = [] 361 | for log_line in log_lines: 362 | if "saving model to" in log_line: 363 | checkpoint_info_lines.append(log_line) 364 | # example of log line 365 | # Epoch 00000: val_f1 reached 0.00000 (best 0.00000), saving model to /data/xiaoya/outputs/0117/debug_5_12_2e-5_0.001_0.001_275_0.1_1_0.25/checkpoint/epoch=0.ckpt as top 20 366 | best_f1_on_dev = 0 367 | best_checkpoint_on_dev = "" 368 | for checkpoint_info_line in checkpoint_info_lines: 369 | current_f1 = float( 370 | re.findall(F1_PATTERN, checkpoint_info_line)[0].replace("span_f1 reached ", "").replace(" (best", "")) 371 | current_ckpt = re.findall(CKPT_PATTERN, checkpoint_info_line)[0].replace("saving model to ", "").replace( 372 | " as top", "") 373 | 374 | if current_f1 >= best_f1_on_dev: 375 | if only_keep_the_best_ckpt and len(best_checkpoint_on_dev) != 0: 376 | os.remove(best_checkpoint_on_dev) 377 | best_f1_on_dev = current_f1 378 | best_checkpoint_on_dev = current_ckpt 379 | 380 | return best_f1_on_dev, best_checkpoint_on_dev 381 | 382 | 383 | def main(): 384 | """main""" 385 | parser = get_parser() 386 | 387 | # add model specific args 388 | parser = BertLabeling.add_model_specific_args(parser) 389 | 390 | # add all the available trainer options to argparse 391 | # ie: now --gpus --num_nodes ... --fast_dev_run all work in the cli 392 | parser = Trainer.add_argparse_args(parser) 393 | 394 | args = parser.parse_args() 395 | 396 | model = BertLabeling(args) 397 | if args.pretrained_checkpoint: 398 | model.load_state_dict(torch.load(args.pretrained_checkpoint, 399 | map_location=torch.device('cpu'))["state_dict"]) 400 | 401 | checkpoint_callback = ModelCheckpoint( 402 | filepath=args.default_root_dir, 403 | save_top_k=args.max_keep_ckpt, 404 | verbose=True, 405 | monitor="span_f1", 406 | period=-1, 407 | mode="max", 408 | ) 409 | trainer = Trainer.from_argparse_args( 410 | args, 411 | checkpoint_callback=checkpoint_callback, 412 | deterministic=True, 413 | default_root_dir=args.default_root_dir 414 | ) 415 | 416 | trainer.fit(model) 417 | 418 | # after training, use the model checkpoint which achieves the best f1 score on dev set to compute the f1 on test set. 
419 | best_f1_on_dev, path_to_best_checkpoint = find_best_checkpoint_on_dev(args.default_root_dir, ) 420 | model.result_logger.info("=&" * 20) 421 | model.result_logger.info(f"Best F1 on DEV is {best_f1_on_dev}") 422 | model.result_logger.info(f"Best checkpoint on DEV set is {path_to_best_checkpoint}") 423 | checkpoint = torch.load(path_to_best_checkpoint) 424 | model.load_state_dict(checkpoint['state_dict']) 425 | model.result_logger.info("=&" * 20) 426 | 427 | 428 | if __name__ == '__main__': 429 | main() 430 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ShannonAI/mrc-for-flat-nested-ner/457b0759f7fd462d0abd0a23441726352716fff9/utils/__init__.py -------------------------------------------------------------------------------- /utils/bmes_decode.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | 4 | # file: bmes_decode.py 5 | 6 | from typing import Tuple, List 7 | 8 | 9 | class Tag(object): 10 | def __init__(self, term, tag, begin, end): 11 | self.term = term 12 | self.tag = tag 13 | self.begin = begin 14 | self.end = end 15 | 16 | def to_tuple(self): 17 | return tuple([self.term, self.begin, self.end]) 18 | 19 | def __str__(self): 20 | return str({key: value for key, value in self.__dict__.items()}) 21 | 22 | def __repr__(self): 23 | return str({key: value for key, value in self.__dict__.items()}) 24 | 25 | 26 | def bmes_decode(char_label_list: List[Tuple[str, str]]) -> List[Tag]: 27 | """ 28 | decode inputs to tags 29 | Args: 30 | char_label_list: list of tuple (word, bmes-tag) 31 | Returns: 32 | tags 33 | Examples: 34 | >>> x = [("Hi", "O"), ("Beijing", "S-LOC")] 35 | >>> bmes_decode(x) 36 | [{'term': 'Beijing', 'tag': 'LOC', 'begin': 1, 'end': 2}] 37 | """ 38 | idx = 0 39 | length = len(char_label_list) 40 | tags = [] 41 | while idx < length: 42 | term, label = char_label_list[idx] 43 | current_label = label[0] 44 | 45 | # correct labels 46 | if idx + 1 == length and current_label == "B": 47 | current_label = "S" 48 | 49 | # merge chars 50 | if current_label == "O": 51 | idx += 1 52 | continue 53 | if current_label == "S": 54 | tags.append(Tag(term, label[2:], idx, idx + 1)) 55 | idx += 1 56 | continue 57 | if current_label == "B": 58 | end = idx + 1 59 | while end + 1 < length and char_label_list[end][1][0] == "M": 60 | end += 1 61 | if char_label_list[end][1][0] == "E": # end with E 62 | entity = "".join(char_label_list[i][0] for i in range(idx, end + 1)) 63 | tags.append(Tag(entity, label[2:], idx, end + 1)) 64 | idx = end + 1 65 | else: # end with M/B 66 | entity = "".join(char_label_list[i][0] for i in range(idx, end)) 67 | tags.append(Tag(entity, label[2:], idx, end)) 68 | idx = end 69 | continue 70 | else: 71 | raise Exception("Invalid Inputs") 72 | return tags 73 | -------------------------------------------------------------------------------- /utils/convert_tf2torch.sh: -------------------------------------------------------------------------------- 1 | # convert tf model to pytorch format 2 | 3 | export BERT_BASE_DIR=/mnt/mrc/wwm_uncased_L-24_H-1024_A-16 4 | 5 | transformers-cli convert --model_type bert \ 6 | --tf_checkpoint $BERT_BASE_DIR/model.ckpt \ 7 | --config $BERT_BASE_DIR/config.json \ 8 | --pytorch_dump_output $BERT_BASE_DIR/pytorch_model.bin 9 | 
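A hedged example of how the converted directory is consumed downstream: both trainers load the encoder with BertTagger.from_pretrained / BertQueryNER.from_pretrained(args.bert_config_dir), and MRCNERDataset builds its WordPiece tokenizer from vocab.txt in that same directory, so after conversion $BERT_BASE_DIR should hold config.json, vocab.txt and pytorch_model.bin. The sketch below wires the converted checkpoint into the MRC trainer; the data and output paths are hypothetical placeholders, while the flags come from utils/get_parser.py and the Lightning Trainer arguments.

export BERT_BASE_DIR=/mnt/mrc/wwm_uncased_L-24_H-1024_A-16
# data_dir must contain mrc-ner.train / mrc-ner.dev / mrc-ner.test; both paths below are placeholders
python train/mrc_ner_trainer.py \
  --data_dir /path/to/mrc-ner-data \
  --bert_config_dir $BERT_BASE_DIR \
  --max_length 128 \
  --batch_size 32 \
  --lr 2e-5 \
  --default_root_dir /path/to/output_dir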
-------------------------------------------------------------------------------- /utils/get_parser.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | 4 | # file: get_parser.py 5 | 6 | import argparse 7 | 8 | 9 | def get_parser() -> argparse.ArgumentParser: 10 | """ 11 | return basic arg parser 12 | """ 13 | parser = argparse.ArgumentParser(description="Training") 14 | 15 | parser.add_argument("--data_dir", type=str, required=True, help="data dir") 16 | parser.add_argument("--max_keep_ckpt", default=3, type=int, help="the number of keeping ckpt max.") 17 | parser.add_argument("--bert_config_dir", type=str, required=True, help="bert config dir") 18 | parser.add_argument("--pretrained_checkpoint", default="", type=str, help="pretrained checkpoint path") 19 | parser.add_argument("--max_length", type=int, default=128, help="max length of dataset") 20 | parser.add_argument("--batch_size", type=int, default=32, help="batch size") 21 | parser.add_argument("--lr", type=float, default=2e-5, help="learning rate") 22 | parser.add_argument("--workers", type=int, default=0, help="num workers for dataloader") 23 | parser.add_argument("--weight_decay", default=0.01, type=float, help="Weight decay if we apply some.") 24 | parser.add_argument("--warmup_steps", default=0, type=int, help="warmup steps used for scheduler.") 25 | parser.add_argument("--adam_epsilon", default=1e-8, type=float, help="Epsilon for Adam optimizer.") 26 | parser.add_argument("--seed", default=0, type=int, help="set random seed for reproducing results.") 27 | return parser 28 | -------------------------------------------------------------------------------- /utils/random_seed.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | #!/usr/bin/env python3 3 | # last update: xiaoya li 4 | # issue: https://github.com/PyTorchLightning/pytorch-lightning/issues/1868 5 | # set for trainer: https://pytorch-lightning.readthedocs.io/en/latest/trainer.html 6 | # from pytorch_lightning import Trainer, seed_everything 7 | # seed_everything(42) 8 | # sets seeds for numpy, torch, python.random and PYTHONHASHSEED. 9 | # model = Model() 10 | # trainer = Trainer(deterministic=True) 11 | 12 | import random 13 | import torch 14 | import numpy as np 15 | from pytorch_lightning import seed_everything 16 | 17 | def set_random_seed(seed: int): 18 | """set seeds for reproducibility""" 19 | random.seed(seed) 20 | np.random.seed(seed) 21 | torch.manual_seed(seed) 22 | torch.cuda.manual_seed_all(seed) 23 | seed_everything(seed=seed) 24 | torch.backends.cudnn.deterministic = True 25 | torch.backends.cudnn.benchmark = False 26 | 27 | 28 | if __name__ == '__main__': 29 | # without this line, x would be different in every execution. 30 | set_random_seed(0) 31 | 32 | x = np.random.random() 33 | print(x) 34 | --------------------------------------------------------------------------------