├── .gitignore
├── README.md
├── datasets
│   ├── __init__.py
│   ├── collate_functions.py
│   ├── mrc_ner_dataset.py
│   ├── tagger_ner_dataset.py
│   └── truncate_dataset.py
├── evaluate
│   ├── mrc_ner_evaluate.py
│   └── tagger_ner_evaluate.py
├── inference
│   ├── mrc_ner_inference.py
│   └── tagger_ner_inference.py
├── metrics
│   ├── __init__.py
│   ├── functional
│   │   ├── __init__.py
│   │   ├── query_span_f1.py
│   │   └── tagger_span_f1.py
│   ├── query_span_f1.py
│   └── tagger_span_f1.py
├── models
│   ├── __init__.py
│   ├── bert_query_ner.py
│   ├── bert_tagger.py
│   ├── classifier.py
│   └── model_config.py
├── ner2mrc
│   ├── __init__.py
│   ├── download.md
│   ├── genia2mrc.py
│   ├── msra2mrc.py
│   └── queries
│       ├── genia.json
│       └── zh_msra.json
├── requirements.txt
├── scripts
│   ├── bert_tagger
│   │   ├── evaluate.sh
│   │   ├── inference.sh
│   │   └── reproduce
│   │       ├── conll03.sh
│   │       ├── msra.sh
│   │       └── onto4.sh
│   └── mrc_ner
│       ├── evaluate.sh
│       ├── flat_inference.sh
│       ├── nested_inference.sh
│       └── reproduce
│           ├── ace04.sh
│           ├── ace05.sh
│           ├── conll03.sh
│           ├── genia.sh
│           ├── kbp17.sh
│           ├── msra.sh
│           ├── onto4.sh
│           └── onto5.sh
├── tests
│   ├── assert_correct_dataset.py
│   ├── bert_tokenizer.py
│   ├── collect_entity_labels.py
│   ├── count_mrc_max_length.py
│   ├── count_sequence_max_length.py
│   ├── extract_entity_span.py
│   └── illegal_entity_boundary.py
├── train
│   ├── bert_tagger_trainer.py
│   └── mrc_ner_trainer.py
└── utils
    ├── __init__.py
    ├── bmes_decode.py
    ├── convert_tf2torch.sh
    ├── get_parser.py
    └── random_seed.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Xcode
2 | *.DS_Store
3 |
4 | # Logs
5 | logs/*
6 |
7 | #
8 | experiments/*
9 | log/*
10 |
11 | # Byte-compiled / optimized / DLL files
12 | __pycache__/
13 | *.py[cod]
14 | *$py.class
15 |
16 | # C extensions
17 | *.so
18 |
19 | # backup files
20 | bk
21 |
22 | # Distribution / packaging
23 | .Python
24 | build/
25 | develop-eggs/
26 | dist/
27 | downloads/
28 | eggs/
29 | .eggs/
30 | lib/
31 | lib64/
32 | parts/
33 | sdist/
34 | var/
35 | wheels/
36 | *.egg-info/
37 | .installed.cfg
38 | *.egg
39 | MANIFEST
40 |
41 | # PyInstaller
42 | # Usually these files are written by a python script from a template
43 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
44 | *.manifest
45 | *.spec
46 |
47 | # Installer logs
48 | pip-log.txt
49 | pip-delete-this-directory.txt
50 |
51 | # Unit test / coverage reports
52 | htmlcov/
53 | .tox/
54 | .coverage
55 | .coverage.*
56 | .cache
57 | nosetests.xml
58 | coverage.xml
59 | *.cover
60 | .hypothesis/
61 | .pytest_cache/
62 |
63 | # Translations
64 | *.mo
65 | *.pot
66 |
67 | # Django stuff:
68 | *.log
69 | local_settings.py
70 | db.sqlite3
71 |
72 | # Flask stuff:
73 | instance/
74 | .webassets-cache
75 |
76 | # Scrapy stuff:
77 | .scrapy
78 |
79 | # Sphinx documentation
80 | docs/_build/
81 |
82 | # PyBuilder
83 | target/
84 |
85 | # Jupyter Notebook
86 | .ipynb_checkpoints
87 |
88 | # pyenv
89 | .python-version
90 |
91 | # celery beat schedule file
92 | celerybeat-schedule
93 |
94 | # SageMath parsed files
95 | *.sage.py
96 |
97 | # Environments
98 | .env
99 | .venv
100 | env/
101 | venv/
102 | ENV/
103 | env.bak/
104 | venv.bak/
105 |
106 | # Spyder project settings
107 | .spyderproject
108 | .spyproject
109 |
110 | # Rope project settings
111 | .ropeproject
112 |
113 | # mkdocs documentation
114 | /site
115 |
116 | # mypy
117 | .mypy_cache/
118 |
119 | # Do Not push origin intermediate logging files
120 | *.log
121 | *.out
122 |
123 | # mac book
124 | .DS_Store
125 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # A Unified MRC Framework for Named Entity Recognition
2 | This repository contains code for recent research advances from [Shannon.AI](http://www.shannonai.com).
3 |
4 | **A Unified MRC Framework for Named Entity Recognition**
5 | Xiaoya Li, Jingrong Feng, Yuxian Meng, Qinghong Han, Fei Wu and Jiwei Li
6 | In ACL 2020. [paper](https://arxiv.org/abs/1910.11476)
7 | If you find this repo helpful, please cite the following:
8 | ```latex
9 | @article{li2019unified,
10 | title={A Unified MRC Framework for Named Entity Recognition},
11 | author={Li, Xiaoya and Feng, Jingrong and Meng, Yuxian and Han, Qinghong and Wu, Fei and Li, Jiwei},
12 | journal={arXiv preprint arXiv:1910.11476},
13 | year={2019}
14 | }
15 | ```
16 | If you have any questions, please feel free to open a GitHub issue.
17 |
18 | ## Install Requirements
19 |
20 | * The code requires Python 3.6+.
21 |
22 | * If you are working on a GPU machine with CUDA 10.1, please run `pip install torch==1.7.1+cu101 torchvision==0.8.2+cu101 torchaudio==0.7.2 -f https://download.pytorch.org/whl/torch_stable.html` to install PyTorch. If not, please see the [PyTorch Official Website](https://pytorch.org/) for instructions.
23 |
24 | * Then run the following command to install the remaining dependencies: `pip install -r requirements.txt`
25 |
26 | We build our project on [pytorch-lightning](https://github.com/PyTorchLightning/pytorch-lightning).
27 | If you want to know more about the arguments used in our training scripts, please
28 | refer to the [pytorch-lightning documentation](https://pytorch-lightning.readthedocs.io/en/latest/).
29 |
30 | ### Baseline: BERT-Tagger
31 |
32 | We release code, [scripts](./scripts/bert_tagger/reproduce) and [data files](./ner2mrc/download.md) for fine-tuning BERT and treating NER as a sequence labeling task.
33 |
34 | ### MRC-NER: Prepare Datasets
35 |
36 | You can [download](./ner2mrc/download.md) the preprocessed MRC-NER datasets used in our paper.
37 | For flat NER datasets, please use `ner2mrc/msra2mrc.py` to transform your BMES NER annotations into MRC format.
38 | For nested NER datasets, please use `ner2mrc/genia2mrc.py` to transform your start-end NER annotations into MRC format.
39 |
40 | ### MRC-NER: Training
41 |
42 | The main training procedure is in `train/mrc_ner_trainer.py`.
43 |
44 | Scripts for reproducing our experimental results can be found in the `./scripts/mrc_ner/reproduce/` folder.
45 | Note that you need to change `DATA_DIR`, `BERT_DIR`, `OUTPUT_DIR` to your own dataset path, BERT model path and log path, respectively.
46 | For example, running `./scripts/mrc_ner/reproduce/ace04.sh` will start training MRC-NER models and save intermediate logs to `$OUTPUT_DIR/train_log.txt`.
47 | During training, the trainer will automatically evaluate on the dev set every `val_check_interval` epochs,
48 | and save the top-k checkpoints to `$OUTPUT_DIR`.
49 |
50 | ### MRC-NER: Evaluation
51 |
52 | After training, you can find the best checkpoint on the dev set according to the evaluation results in `$OUTPUT_DIR/train_log.txt`.
53 | Then run `python3 evaluate/mrc_ner_evaluate.py $OUTPUT_DIR/<best_ckpt_on_dev>.ckpt $OUTPUT_DIR/lightning_logs/<version>/hparams.yaml` to evaluate on the test set with the best checkpoint chosen on dev.
54 |
55 | ### MRC-NER: Inference
56 |
57 | Code for inference using a trained MRC-NER model can be found in `inference/mrc_ner_inference.py`.
58 | For flat NER, we provide the inference script [flat_inference.sh](./scripts/mrc_ner/flat_inference.sh).
59 | For nested NER, we provide the inference script [nested_inference.sh](./scripts/mrc_ner/nested_inference.sh).
--------------------------------------------------------------------------------
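For reference, here is a minimal sketch (the path and field values are illustrative, not shipped with the repo) of the MRC-style JSON that `datasets/mrc_ner_dataset.py` consumes; each record carries `query`, `context`, `start_position`, `end_position`, and `qas_id`:

```python
import json

# Illustrative only: "mrc-ner.train" is a placeholder path; the field names below
# are the ones MRCNERDataset reads from every record.
with open("mrc-ner.train", encoding="utf-8") as f:
    all_data = json.load(f)

example = all_data[0]
print(example["qas_id"])          # "<sample_idx>.<label_idx>", e.g. "0.1"
print(example["query"])           # natural-language description of the entity type
print(example["context"])         # whitespace-tokenized sentence
print(example["start_position"])  # word-level start indices of gold entities
print(example["end_position"])    # word-level end indices of gold entities
```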
/datasets/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ShannonAI/mrc-for-flat-nested-ner/457b0759f7fd462d0abd0a23441726352716fff9/datasets/__init__.py
--------------------------------------------------------------------------------
/datasets/collate_functions.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 |
4 | # file: collate_functions.py
5 |
6 | import torch
7 | from typing import List
8 |
9 |
10 | def tagger_collate_to_max_length(batch: List[List[torch.Tensor]]) -> List[torch.Tensor]:
11 | """
12 | pad to maximum length of this batch
13 | Args:
14 |         batch: a batch of samples, each containing a list of field data (Tensor):
15 | tokens, token_type_ids, attention_mask, wordpiece_label_idx_lst
16 | Returns:
17 |         output: list of batched field data, whose shape is [batch, max_length]
18 | """
19 | batch_size = len(batch)
20 | max_length = max(x[0].shape[0] for x in batch)
21 | output = []
22 |
23 | for field_idx in range(3):
24 | # 0 -> tokens
25 | # 1 -> token_type_ids
26 | # 2 -> attention_mask
27 | pad_output = torch.full([batch_size, max_length], 0, dtype=batch[0][field_idx].dtype)
28 | for sample_idx in range(batch_size):
29 | data = batch[sample_idx][field_idx]
30 | pad_output[sample_idx][: data.shape[0]] = data
31 | output.append(pad_output)
32 |
33 | # 3 -> sequence_label
34 | # -100 is ignore_index in the cross-entropy loss function.
35 | pad_output = torch.full([batch_size, max_length], -100, dtype=batch[0][3].dtype)
36 | for sample_idx in range(batch_size):
37 | data = batch[sample_idx][3]
38 | pad_output[sample_idx][: data.shape[0]] = data
39 | output.append(pad_output)
40 |
41 |     # 4 -> is_wordpiece_mask
42 | pad_output = torch.full([batch_size, max_length], -100, dtype=batch[0][4].dtype)
43 | for sample_idx in range(batch_size):
44 | data = batch[sample_idx][4]
45 | pad_output[sample_idx][: data.shape[0]] = data
46 | output.append(pad_output)
47 |
48 | return output
49 |
50 |
51 | def collate_to_max_length(batch: List[List[torch.Tensor]]) -> List[torch.Tensor]:
52 | """
53 | pad to maximum length of this batch
54 | Args:
55 |         batch: a batch of samples, each containing a list of field data (Tensor):
56 | tokens, token_type_ids, start_labels, end_labels, start_label_mask, end_label_mask, match_labels, sample_idx, label_idx
57 | Returns:
58 |         output: list of batched field data, whose shape is [batch, max_length]
59 | """
60 | batch_size = len(batch)
61 | max_length = max(x[0].shape[0] for x in batch)
62 | output = []
63 |
64 | for field_idx in range(6):
65 | pad_output = torch.full([batch_size, max_length], 0, dtype=batch[0][field_idx].dtype)
66 | for sample_idx in range(batch_size):
67 | data = batch[sample_idx][field_idx]
68 | pad_output[sample_idx][: data.shape[0]] = data
69 | output.append(pad_output)
70 |
71 | pad_match_labels = torch.zeros([batch_size, max_length, max_length], dtype=torch.long)
72 | for sample_idx in range(batch_size):
73 | data = batch[sample_idx][6]
74 | pad_match_labels[sample_idx, : data.shape[1], : data.shape[1]] = data
75 | output.append(pad_match_labels)
76 |
77 | output.append(torch.stack([x[-2] for x in batch]))
78 | output.append(torch.stack([x[-1] for x in batch]))
79 |
80 | return output
81 |
--------------------------------------------------------------------------------
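A small synthetic sketch (hand-built tensors, not from the repo) showing how `tagger_collate_to_max_length` pads a batch: the first three fields are padded with 0, the label and word-piece fields with -100 so the padded positions are ignored by cross-entropy:

```python
import torch
from datasets.collate_functions import tagger_collate_to_max_length

# Two synthetic samples of different lengths, in the field order the function expects:
# tokens, token_type_ids, attention_mask, wordpiece_label_idx_lst, is_wordpiece_mask.
sample_a = [torch.tensor([101, 7592, 102]), torch.zeros(3, dtype=torch.long),
            torch.ones(3, dtype=torch.long), torch.tensor([-100, 1, -100]),
            torch.tensor([-100, 1, -100])]
sample_b = [torch.tensor([101, 7592, 2088, 102]), torch.zeros(4, dtype=torch.long),
            torch.ones(4, dtype=torch.long), torch.tensor([-100, 1, 2, -100]),
            torch.tensor([-100, 1, 1, -100])]

tokens, type_ids, mask, labels, wp_mask = tagger_collate_to_max_length([sample_a, sample_b])
print(tokens.shape)  # torch.Size([2, 4]) -- padded to the longest sample in the batch
print(labels[0])     # tensor([-100,    1, -100, -100]) -- the padded position is -100
```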
/datasets/mrc_ner_dataset.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 |
4 | # file: mrc_ner_dataset.py
5 |
6 | import json
7 | import torch
8 | from tokenizers import BertWordPieceTokenizer
9 | from torch.utils.data import Dataset
10 |
11 |
12 | class MRCNERDataset(Dataset):
13 | """
14 | MRC NER Dataset
15 | Args:
16 | json_path: path to mrc-ner style json
17 | tokenizer: BertTokenizer
18 | max_length: int, max length of query+context
19 | possible_only: if True, only use possible samples that contain answer for the query/context
20 | is_chinese: is chinese dataset
21 | """
22 | def __init__(self, json_path, tokenizer: BertWordPieceTokenizer, max_length: int = 512, possible_only=False,
23 | is_chinese=False, pad_to_maxlen=False):
24 | self.all_data = json.load(open(json_path, encoding="utf-8"))
25 | self.tokenizer = tokenizer
26 | self.max_length = max_length
27 | self.possible_only = possible_only
28 | if self.possible_only:
29 | self.all_data = [
30 | x for x in self.all_data if x["start_position"]
31 | ]
32 | self.is_chinese = is_chinese
33 | self.pad_to_maxlen = pad_to_maxlen
34 |
35 | def __len__(self):
36 | return len(self.all_data)
37 |
38 | def __getitem__(self, item):
39 | """
40 | Args:
41 | item: int, idx
42 | Returns:
43 | tokens: tokens of query + context, [seq_len]
44 | token_type_ids: token type ids, 0 for query, 1 for context, [seq_len]
45 | start_labels: start labels of NER in tokens, [seq_len]
46 |             end_labels: end labels of NER in tokens, [seq_len]
47 | label_mask: label mask, 1 for counting into loss, 0 for ignoring. [seq_len]
48 | match_labels: match labels, [seq_len, seq_len]
49 | sample_idx: sample id
50 | label_idx: label id
51 | """
52 | data = self.all_data[item]
53 | tokenizer = self.tokenizer
54 |
55 | qas_id = data.get("qas_id", "0.0")
56 | sample_idx, label_idx = qas_id.split(".")
57 | sample_idx = torch.LongTensor([int(sample_idx)])
58 | label_idx = torch.LongTensor([int(label_idx)])
59 |
60 | query = data["query"]
61 | context = data["context"]
62 | start_positions = data["start_position"]
63 | end_positions = data["end_position"]
64 |
65 | if self.is_chinese:
66 | context = "".join(context.split())
67 | end_positions = [x+1 for x in end_positions]
68 | else:
69 | # add space offsets
70 | words = context.split()
71 | start_positions = [x + sum([len(w) for w in words[:x]]) for x in start_positions]
72 | end_positions = [x + sum([len(w) for w in words[:x + 1]]) for x in end_positions]
73 |
74 | query_context_tokens = tokenizer.encode(query, context, add_special_tokens=True)
75 | tokens = query_context_tokens.ids
76 | type_ids = query_context_tokens.type_ids
77 | offsets = query_context_tokens.offsets
78 |
79 | # find new start_positions/end_positions, considering
80 | # 1. we add query tokens at the beginning
81 | # 2. word-piece tokenize
82 | origin_offset2token_idx_start = {}
83 | origin_offset2token_idx_end = {}
84 | for token_idx in range(len(tokens)):
85 | # skip query tokens
86 | if type_ids[token_idx] == 0:
87 | continue
88 | token_start, token_end = offsets[token_idx]
89 | # skip [CLS] or [SEP]
90 | if token_start == token_end == 0:
91 | continue
92 | origin_offset2token_idx_start[token_start] = token_idx
93 | origin_offset2token_idx_end[token_end] = token_idx
94 |
95 | new_start_positions = [origin_offset2token_idx_start[start] for start in start_positions]
96 | new_end_positions = [origin_offset2token_idx_end[end] for end in end_positions]
97 |
98 | label_mask = [
99 | (0 if type_ids[token_idx] == 0 or offsets[token_idx] == (0, 0) else 1)
100 | for token_idx in range(len(tokens))
101 | ]
102 | start_label_mask = label_mask.copy()
103 | end_label_mask = label_mask.copy()
104 |
105 | # the start/end position must be whole word
106 | if not self.is_chinese:
107 | for token_idx in range(len(tokens)):
108 | current_word_idx = query_context_tokens.words[token_idx]
109 | next_word_idx = query_context_tokens.words[token_idx+1] if token_idx+1 < len(tokens) else None
110 | prev_word_idx = query_context_tokens.words[token_idx-1] if token_idx-1 > 0 else None
111 | if prev_word_idx is not None and current_word_idx == prev_word_idx:
112 | start_label_mask[token_idx] = 0
113 | if next_word_idx is not None and current_word_idx == next_word_idx:
114 | end_label_mask[token_idx] = 0
115 |
116 | assert all(start_label_mask[p] != 0 for p in new_start_positions)
117 | assert all(end_label_mask[p] != 0 for p in new_end_positions)
118 |
119 | assert len(new_start_positions) == len(new_end_positions) == len(start_positions)
120 | assert len(label_mask) == len(tokens)
121 | start_labels = [(1 if idx in new_start_positions else 0)
122 | for idx in range(len(tokens))]
123 | end_labels = [(1 if idx in new_end_positions else 0)
124 | for idx in range(len(tokens))]
125 |
126 | # truncate
127 | tokens = tokens[: self.max_length]
128 | type_ids = type_ids[: self.max_length]
129 | start_labels = start_labels[: self.max_length]
130 | end_labels = end_labels[: self.max_length]
131 | start_label_mask = start_label_mask[: self.max_length]
132 | end_label_mask = end_label_mask[: self.max_length]
133 |
134 | # make sure last token is [SEP]
135 | sep_token = tokenizer.token_to_id("[SEP]")
136 | if tokens[-1] != sep_token:
137 | assert len(tokens) == self.max_length
138 | tokens = tokens[: -1] + [sep_token]
139 | start_labels[-1] = 0
140 | end_labels[-1] = 0
141 | start_label_mask[-1] = 0
142 | end_label_mask[-1] = 0
143 |
144 | if self.pad_to_maxlen:
145 | tokens = self.pad(tokens, 0)
146 | type_ids = self.pad(type_ids, 1)
147 | start_labels = self.pad(start_labels)
148 | end_labels = self.pad(end_labels)
149 | start_label_mask = self.pad(start_label_mask)
150 | end_label_mask = self.pad(end_label_mask)
151 |
152 | seq_len = len(tokens)
153 | match_labels = torch.zeros([seq_len, seq_len], dtype=torch.long)
154 | for start, end in zip(new_start_positions, new_end_positions):
155 | if start >= seq_len or end >= seq_len:
156 | continue
157 | match_labels[start, end] = 1
158 |
159 | return [
160 | torch.LongTensor(tokens),
161 | torch.LongTensor(type_ids),
162 | torch.LongTensor(start_labels),
163 | torch.LongTensor(end_labels),
164 | torch.LongTensor(start_label_mask),
165 | torch.LongTensor(end_label_mask),
166 | match_labels,
167 | sample_idx,
168 | label_idx
169 | ]
170 |
171 | def pad(self, lst, value=0, max_length=None):
172 | max_length = max_length or self.max_length
173 | while len(lst) < max_length:
174 | lst.append(value)
175 | return lst
176 |
177 |
178 | def run_dataset():
179 | """test dataset"""
180 | import os
181 | from datasets.collate_functions import collate_to_max_length
182 | from torch.utils.data import DataLoader
183 | # zh datasets
184 | bert_path = "/data/nfsdata/nlp/BERT_BASE_DIR/chinese_L-12_H-768_A-12"
185 | vocab_file = os.path.join(bert_path, "vocab.txt")
186 | # json_path = "/mnt/mrc/zh_msra/mrc-ner.test"
187 | json_path = "/data/xiaoya/datasets/mrc_ner/zh_msra/mrc-ner.train"
188 | is_chinese = True
189 |
190 | # en datasets
191 | # bert_path = "/mnt/mrc/bert-base-uncased"
192 | # json_path = "/mnt/mrc/ace2004/mrc-ner.train"
193 | # json_path = "/mnt/mrc/genia/mrc-ner.train"
194 | # is_chinese = False
195 |
196 | vocab_file = os.path.join(bert_path, "vocab.txt")
197 | tokenizer = BertWordPieceTokenizer(vocab_file)
198 | dataset = MRCNERDataset(json_path=json_path, tokenizer=tokenizer,
199 | is_chinese=is_chinese)
200 |
201 | dataloader = DataLoader(dataset, batch_size=1,
202 | collate_fn=collate_to_max_length)
203 |
204 | for batch in dataloader:
205 | for tokens, token_type_ids, start_labels, end_labels, start_label_mask, end_label_mask, match_labels, sample_idx, label_idx in zip(*batch):
206 | tokens = tokens.tolist()
207 | start_positions, end_positions = torch.where(match_labels > 0)
208 | start_positions = start_positions.tolist()
209 | end_positions = end_positions.tolist()
210 | print(start_labels.numpy().tolist())
211 |
212 | tmp_start_position = []
213 | for tmp_idx, tmp_label in enumerate(start_labels.numpy().tolist()):
214 | if tmp_label != 0:
215 | tmp_start_position.append(tmp_idx)
216 |
217 | tmp_end_position = []
218 | for tmp_idx, tmp_label in enumerate(end_labels.numpy().tolist()):
219 | if tmp_label != 0:
220 | tmp_end_position.append(tmp_idx)
221 |
222 | if not start_positions:
223 | continue
224 | print("="*20)
225 | print(f"len: {len(tokens)}", tokenizer.decode(tokens, skip_special_tokens=False))
226 | for start, end in zip(start_positions, end_positions):
227 | print(str(sample_idx.item()), str(label_idx.item()) + "\t" + tokenizer.decode(tokens[start: end+1]))
228 |
229 | print("!!!"*20)
230 | for start, end in zip(tmp_start_position, tmp_end_position):
231 | print(str(sample_idx.item()), str(label_idx.item()) + "\t" + tokenizer.decode(tokens[start: end + 1]))
232 |
233 |
234 | if __name__ == '__main__':
235 | run_dataset()
236 |
--------------------------------------------------------------------------------
/datasets/tagger_ner_dataset.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 |
4 | # file: tagger_ner_dataset.py
5 |
6 | import torch
7 | from transformers import AutoTokenizer
8 | from torch.utils.data import Dataset
9 |
10 |
11 | def get_labels(data_sign):
12 | """gets the list of labels for this data set."""
13 | if data_sign == "zh_onto":
14 | return ["O", "S-GPE", "B-GPE", "M-GPE", "E-GPE",
15 | "S-LOC", "B-LOC", "M-LOC", "E-LOC",
16 | "S-PER", "B-PER", "M-PER", "E-PER",
17 | "S-ORG", "B-ORG", "M-ORG", "E-ORG",]
18 | elif data_sign == "zh_msra":
19 | return ["O", "S-NS", "B-NS", "M-NS", "E-NS",
20 | "S-NR", "B-NR", "M-NR", "E-NR",
21 | "S-NT", "B-NT", "M-NT", "E-NT"]
22 | elif data_sign == "en_onto":
23 | return ["O", "S-LAW", "B-LAW", "M-LAW", "E-LAW",
24 | "S-EVENT", "B-EVENT", "M-EVENT", "E-EVENT",
25 | "S-CARDINAL", "B-CARDINAL", "M-CARDINAL", "E-CARDINAL",
26 | "S-FAC", "B-FAC", "M-FAC", "E-FAC",
27 | "S-TIME", "B-TIME", "M-TIME", "E-TIME",
28 | "S-DATE", "B-DATE", "M-DATE", "E-DATE",
29 | "S-ORDINAL", "B-ORDINAL", "M-ORDINAL", "E-ORDINAL",
30 | "S-ORG", "B-ORG", "M-ORG", "E-ORG",
31 | "S-QUANTITY", "B-QUANTITY", "M-QUANTITY", "E-QUANTITY",
32 | "S-PERCENT", "B-PERCENT", "M-PERCENT", "E-PERCENT",
33 | "S-WORK_OF_ART", "B-WORK_OF_ART", "M-WORK_OF_ART", "E-WORK_OF_ART",
34 | "S-LOC", "B-LOC", "M-LOC", "E-LOC",
35 | "S-LANGUAGE", "B-LANGUAGE", "M-LANGUAGE", "E-LANGUAGE",
36 | "S-NORP", "B-NORP", "M-NORP", "E-NORP",
37 | "S-MONEY", "B-MONEY", "M-MONEY", "E-MONEY",
38 | "S-PERSON", "B-PERSON", "M-PERSON", "E-PERSON",
39 | "S-GPE", "B-GPE", "M-GPE", "E-GPE",
40 | "S-PRODUCT", "B-PRODUCT", "M-PRODUCT", "E-PRODUCT"]
41 | elif data_sign == "en_conll03":
42 | return ["O", "S-ORG", "B-ORG", "M-ORG", "E-ORG",
43 | "S-PER", "B-PER", "M-PER", "E-PER",
44 | "S-LOC", "B-LOC", "M-LOC", "E-LOC",
45 | "S-MISC", "B-MISC", "M-MISC", "E-MISC"]
46 | return ["0", "1"]
47 |
48 |
49 | def load_data_in_conll(data_path):
50 | """
51 | Desc:
52 | load data in conll format
53 | Returns:
54 | [([word1, word2, word3, word4], [label1, label2, label3, label4]),
55 |          ([word5, word6, word7, word8], [label5, label6, label7, label8])]
56 | """
57 | dataset = []
58 | with open(data_path, "r", encoding="utf-8") as f:
59 | datalines = f.readlines()
60 | sentence, labels = [], []
61 |
62 | for line in datalines:
63 | line = line.strip()
64 | if len(line) == 0:
65 | dataset.append((sentence, labels))
66 | sentence, labels = [], []
67 | else:
68 | word, tag = line.split(" ")
69 | sentence.append(word)
70 | labels.append(tag)
71 | return dataset
72 |
73 |
74 | class TaggerNERDataset(Dataset):
75 | """
76 |     Tagger (sequence labeling) NER Dataset
77 | Args:
78 |         data_path: path to a CoNLL-style named entity data file.
79 | tokenizer: BertTokenizer
80 |         max_length: int, max length of the input sequence
81 | is_chinese: is chinese dataset
82 | Note:
83 | https://github.com/huggingface/transformers/blob/143738214cb83e471f3a43652617c8881370342c/examples/pytorch/token-classification/run_ner.py#L362
84 | https://github.com/huggingface/transformers/blob/143738214cb83e471f3a43652617c8881370342c/src/transformers/models/bert/modeling_bert.py#L1739
85 | """
86 | def __init__(self, data_path, tokenizer: AutoTokenizer, dataset_signature, max_length: int = 512,
87 | is_chinese=False, pad_to_maxlen=False, tagging_schema="BMESO", ):
88 | self.all_data = load_data_in_conll(data_path)
89 | self.tokenizer = tokenizer
90 | self.max_length = max_length
91 | self.is_chinese = is_chinese
92 | self.pad_to_maxlen = pad_to_maxlen
93 | self.pad_idx = 0
94 | self.cls_idx = 101
95 | self.sep_idx = 102
96 | self.label2idx = {label_item: label_idx for label_idx, label_item in enumerate(get_labels(dataset_signature))}
97 |
98 | def __len__(self):
99 | return len(self.all_data)
100 |
101 | def __getitem__(self, item):
102 | data = self.all_data[item]
103 | token_lst, label_lst = tuple(data)
104 | wordpiece_token_lst, wordpiece_label_lst = [], []
105 |
106 | for token_item, label_item in zip(token_lst, label_lst):
107 | tmp_token_lst = self.tokenizer.encode(token_item, add_special_tokens=False, return_token_type_ids=None)
108 | if len(tmp_token_lst) == 1:
109 | wordpiece_token_lst.append(tmp_token_lst[0])
110 | wordpiece_label_lst.append(label_item)
111 | else:
112 | len_wordpiece = len(tmp_token_lst)
113 | wordpiece_token_lst.extend(tmp_token_lst)
114 | tmp_label_lst = [label_item] + [-100 for idx in range((len_wordpiece - 1))]
115 | wordpiece_label_lst.extend(tmp_label_lst)
116 |
117 | if len(wordpiece_token_lst) > self.max_length - 2:
118 | wordpiece_token_lst = wordpiece_token_lst[: self.max_length-2]
119 | wordpiece_label_lst = wordpiece_label_lst[: self.max_length-2]
120 |
121 | wordpiece_token_lst = [self.cls_idx] + wordpiece_token_lst + [self.sep_idx]
122 | wordpiece_label_lst = [-100] + wordpiece_label_lst + [-100]
123 | # token_type_ids: segment token indices to indicate first and second portions of the inputs.
124 | # - 0 corresponds to a "sentence a" token
125 | # - 1 corresponds to a "sentence b" token
126 | token_type_ids = [0] * len(wordpiece_token_lst)
127 | # attention_mask: mask to avoid performing attention on padding token indices.
128 | # - 1 for tokens that are not masked.
129 | # - 0 for tokens that are masked.
130 | attention_mask = [1] * len(wordpiece_token_lst)
131 | is_wordpiece_mask = [1 if label_item != -100 else -100 for label_item in wordpiece_label_lst]
132 | wordpiece_label_idx_lst = [self.label2idx[label_item] if label_item != -100 else -100 for label_item in wordpiece_label_lst]
133 |
134 | return [torch.tensor(wordpiece_token_lst, dtype=torch.long),
135 | torch.tensor(token_type_ids, dtype=torch.long),
136 | torch.tensor(attention_mask, dtype=torch.long),
137 | torch.tensor(wordpiece_label_idx_lst, dtype=torch.long),
138 | torch.tensor(is_wordpiece_mask, dtype=torch.long)]
139 |
140 |
141 |
--------------------------------------------------------------------------------
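A construction sketch for `TaggerNERDataset` (the checkpoint name and data path are placeholders): the input file contains one whitespace-separated `word tag` pair per line with blank lines between sentences, which is exactly what `load_data_in_conll` parses:

```python
from torch.utils.data import DataLoader
from transformers import AutoTokenizer

from datasets.tagger_ner_dataset import TaggerNERDataset
from datasets.collate_functions import tagger_collate_to_max_length

# Placeholders: point these at your own BERT checkpoint and BMES-tagged data file.
tokenizer = AutoTokenizer.from_pretrained("bert-base-cased", use_fast=False)
dataset = TaggerNERDataset("train.word.bmes", tokenizer, "en_conll03",
                           max_length=128, is_chinese=False)

loader = DataLoader(dataset, batch_size=4, collate_fn=tagger_collate_to_max_length)
# Each batch is [tokens, token_type_ids, attention_mask, label_ids, is_wordpiece_mask].
tokens, type_ids, attention_mask, label_ids, wp_mask = next(iter(loader))
```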
/datasets/truncate_dataset.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 |
4 | # file: truncate_dataset.py
5 |
6 | from torch.utils.data import Dataset
7 |
8 |
9 | class TruncateDataset(Dataset):
10 | """Truncate dataset to certain num"""
11 | def __init__(self, dataset: Dataset, max_num: int = 100):
12 | self.dataset = dataset
13 | self.max_num = min(max_num, len(self.dataset))
14 |
15 | def __len__(self):
16 | return self.max_num
17 |
18 | def __getitem__(self, item):
19 | return self.dataset[item]
20 |
21 | def __getattr__(self, item):
22 | """other dataset func"""
23 | return getattr(self.dataset, item)
24 |
--------------------------------------------------------------------------------
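A quick usage sketch: wrap any dataset in `TruncateDataset` to cap its length, e.g. for a fast sanity-check run:

```python
import torch
from torch.utils.data import TensorDataset

from datasets.truncate_dataset import TruncateDataset

full = TensorDataset(torch.arange(1000).unsqueeze(1))
small = TruncateDataset(full, max_num=100)   # only the first 100 items are visible
print(len(full), len(small))                 # 1000 100
print(small.tensors[0].shape)                # attribute access falls through to the wrapped dataset
```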
/evaluate/mrc_ner_evaluate.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 |
4 | # file: mrc_ner_evaluate.py
5 | # example command:
6 | # python3 mrc_ner_evaluate.py /data/xiaoya/outputs/mrc_ner/ace2004/debug_lr3e-5_drop0.3_norm1.0_weight0.1_warmup0_maxlen100/epoch=0.ckpt \
7 | # /data/xiaoya/outputs/mrc_ner/ace2004/debug_lr3e-5_drop0.3_norm1.0_weight0.1_warmup0_maxlen100/lightning_logs/version_2/hparams.yaml
8 |
9 | import sys
10 | from pytorch_lightning import Trainer
11 | from train.mrc_ner_trainer import BertLabeling
12 | from utils.random_seed import set_random_seed
13 |
14 | set_random_seed(0)
15 |
16 |
17 | def evaluate(ckpt, hparams_file, gpus=[0, 1], max_length=300):
18 | trainer = Trainer(gpus=gpus, distributed_backend="dp")
19 |
20 | model = BertLabeling.load_from_checkpoint(
21 | checkpoint_path=ckpt,
22 | hparams_file=hparams_file,
23 | map_location=None,
24 | batch_size=1,
25 | max_length=max_length,
26 | workers=0
27 | )
28 | trainer.test(model=model)
29 |
30 |
31 | if __name__ == '__main__':
32 | # example of running evaluate.py
33 | # CHECKPOINTS = "/mnt/mrc/train_logs/zh_msra/zh_msra_20200911_for_flat_debug/epoch=2_v1.ckpt"
34 | # HPARAMS = "/mnt/mrc/train_logs/zh_msra/zh_msra_20200911_for_flat_debug/lightning_logs/version_2/hparams.yaml"
35 | # GPUS="1,2,3"
36 | CHECKPOINTS = sys.argv[1]
37 | HPARAMS = sys.argv[2]
38 | try:
39 | GPUS = [int(gpu_item) for gpu_item in sys.argv[3].strip().split(",")]
40 | except:
41 | GPUS = [0]
42 |
43 | try:
44 | MAXLEN = int(sys.argv[4])
45 | except:
46 | MAXLEN = 300
47 |
48 | evaluate(ckpt=CHECKPOINTS, hparams_file=HPARAMS, gpus=GPUS, max_length=MAXLEN)
49 |
--------------------------------------------------------------------------------
/evaluate/tagger_ner_evaluate.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 |
4 | # file: tagger_ner_evaluate.py
5 | # example command:
6 |
7 |
8 | import sys
9 | from pytorch_lightning import Trainer
10 | from train.bert_tagger_trainer import BertSequenceLabeling
11 | from utils.random_seed import set_random_seed
12 |
13 | set_random_seed(0)
14 |
15 |
16 | def evaluate(ckpt, hparams_file, gpus=[0, 1], max_length=300):
17 | trainer = Trainer(gpus=gpus, distributed_backend="dp")
18 |
19 | model = BertSequenceLabeling.load_from_checkpoint(
20 | checkpoint_path=ckpt,
21 | hparams_file=hparams_file,
22 | map_location=None,
23 | batch_size=1,
24 | max_length=max_length,
25 | workers=0
26 | )
27 | trainer.test(model=model)
28 |
29 |
30 | if __name__ == '__main__':
31 | # example of running evaluate.py
32 | # CHECKPOINTS = "/mnt/mrc/train_logs/zh_msra/zh_msra_20200911_for_flat_debug/epoch=2_v1.ckpt"
33 | # HPARAMS = "/mnt/mrc/train_logs/zh_msra/zh_msra_20200911_for_flat_debug/lightning_logs/version_2/hparams.yaml"
34 | # GPUS="1,2,3"
35 | CHECKPOINTS = sys.argv[1]
36 | HPARAMS = sys.argv[2]
37 |
38 | try:
39 | GPUS = [int(gpu_item) for gpu_item in sys.argv[3].strip().split(",")]
40 | except:
41 | GPUS = [0]
42 |
43 | try:
44 | MAXLEN = int(sys.argv[4])
45 | except:
46 | MAXLEN = 300
47 |
48 | evaluate(ckpt=CHECKPOINTS, hparams_file=HPARAMS, gpus=GPUS, max_length=MAXLEN, )
49 |
--------------------------------------------------------------------------------
/inference/mrc_ner_inference.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 |
4 | # file: mrc_ner_inference.py
5 |
6 | import os
7 | import torch
8 | import argparse
9 | from torch.utils.data import DataLoader
10 | from utils.random_seed import set_random_seed
11 | set_random_seed(0)
12 | from train.mrc_ner_trainer import BertLabeling
13 | from tokenizers import BertWordPieceTokenizer
14 | from datasets.mrc_ner_dataset import MRCNERDataset
15 | from metrics.functional.query_span_f1 import extract_flat_spans, extract_nested_spans
16 |
17 | def get_dataloader(config, data_prefix="test"):
18 | data_path = os.path.join(config.data_dir, f"mrc-ner.{data_prefix}")
19 | vocab_path = os.path.join(config.bert_dir, "vocab.txt")
20 | data_tokenizer = BertWordPieceTokenizer(vocab_path)
21 |
22 | dataset = MRCNERDataset(json_path=data_path,
23 | tokenizer=data_tokenizer,
24 | max_length=config.max_length,
25 | is_chinese=config.is_chinese,
26 | pad_to_maxlen=False)
27 |
28 | dataloader = DataLoader(dataset=dataset, batch_size=1, shuffle=False)
29 |
30 | return dataloader, data_tokenizer
31 |
32 | def get_query_index_to_label_cate(dataset_sign):
33 |     # NOTICE: needs to be changed if you use other datasets.
34 |     # the mapping should be in line with the mrc-ner.train/dev/test json files
35 | if dataset_sign == "conll03":
36 | return {1: "ORG", 2: "PER", 3: "LOC", 4: "MISC"}
37 | elif dataset_sign == "ace04":
38 | return {1: "GPE", 2: "ORG", 3: "PER", 4: "FAC", 5: "VEH", 6: "LOC", 7: "WEA"}
39 |
40 |
41 | def get_parser() -> argparse.ArgumentParser:
42 | parser = argparse.ArgumentParser(description="inference the model output.")
43 | parser.add_argument("--data_dir", type=str, required=True)
44 | parser.add_argument("--bert_dir", type=str, required=True)
45 | parser.add_argument("--max_length", type=int, default=256)
46 | parser.add_argument("--is_chinese", action="store_true")
47 | parser.add_argument("--model_ckpt", type=str, default="")
48 | parser.add_argument("--hparams_file", type=str, default="")
49 | parser.add_argument("--flat_ner", action="store_true",)
50 | parser.add_argument("--dataset_sign", type=str, choices=["ontonotes4", "msra", "conll03", "ace04", "ace05"], default="conll03")
51 |
52 | return parser
53 |
54 |
55 | def main():
56 | parser = get_parser()
57 | args = parser.parse_args()
58 | trained_mrc_ner_model = BertLabeling.load_from_checkpoint(
59 | checkpoint_path=args.model_ckpt,
60 | hparams_file=args.hparams_file,
61 | map_location=None,
62 | batch_size=1,
63 | max_length=args.max_length,
64 | workers=0)
65 |
66 | data_loader, data_tokenizer = get_dataloader(args,)
67 | # load token
68 | vocab_path = os.path.join(args.bert_dir, "vocab.txt")
69 | with open(vocab_path, "r") as f:
70 | subtokens = [token.strip() for token in f.readlines()]
71 | idx2tokens = {}
72 | for token_idx, token in enumerate(subtokens):
73 | idx2tokens[token_idx] = token
74 |
75 | query2label_dict = get_query_index_to_label_cate(args.dataset_sign)
76 |
77 | for batch in data_loader:
78 | tokens, token_type_ids, start_labels, end_labels, start_label_mask, end_label_mask, match_labels, sample_idx, label_idx = batch
79 | attention_mask = (tokens != 0).long()
80 |
81 | start_logits, end_logits, span_logits = trained_mrc_ner_model.model(tokens, attention_mask=attention_mask, token_type_ids=token_type_ids)
82 | start_preds, end_preds, span_preds = start_logits > 0, end_logits > 0, span_logits > 0
83 |
84 | subtokens_idx_lst = tokens.numpy().tolist()[0]
85 | subtokens_lst = [idx2tokens[item] for item in subtokens_idx_lst]
86 | label_cate = query2label_dict[label_idx.item()]
87 | readable_input_str = data_tokenizer.decode(subtokens_idx_lst, skip_special_tokens=True)
88 |
89 | if args.flat_ner:
90 | entities_info = extract_flat_spans(torch.squeeze(start_preds), torch.squeeze(end_preds),
91 | torch.squeeze(span_preds), torch.squeeze(attention_mask), pseudo_tag=label_cate)
92 | entity_lst = []
93 |
94 | if len(entities_info) != 0:
95 | for entity_info in entities_info:
96 | start, end = entity_info[0], entity_info[1]
97 | entity_string = " ".join(subtokens_lst[start: end])
98 | entity_string = entity_string.replace(" ##", "")
99 | entity_lst.append((start, end, entity_string, entity_info[2]))
100 |
101 | else:
102 | match_preds = span_logits > 0
103 | entities_info = extract_nested_spans(start_preds, end_preds, match_preds, start_label_mask, end_label_mask, pseudo_tag=label_cate)
104 |
105 | entity_lst = []
106 |
107 | if len(entities_info) != 0:
108 | for entity_info in entities_info:
109 | start, end = entity_info[0], entity_info[1]
110 | entity_string = " ".join(subtokens_lst[start: end+1 ])
111 | entity_string = entity_string.replace(" ##", "")
112 | entity_lst.append((start, end+1, entity_string, entity_info[2]))
113 |
114 | print("*="*10)
115 | print(f"Given input: {readable_input_str}")
116 | print(f"Model predict: {entity_lst}")
117 | # entity_lst is a list of (subtoken_start_pos, subtoken_end_pos, substring, entity_type)
118 |
119 | if __name__ == "__main__":
120 | main()
--------------------------------------------------------------------------------
/inference/tagger_ner_inference.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 |
4 | # file: tagger_ner_inference.py
5 |
6 | import os
7 | import torch
8 | import argparse
9 | from torch.utils.data import DataLoader
10 | from utils.random_seed import set_random_seed
11 | set_random_seed(0)
12 | from train.bert_tagger_trainer import BertSequenceLabeling
13 | from transformers import AutoTokenizer
14 | from datasets.tagger_ner_dataset import get_labels
15 | from datasets.tagger_ner_dataset import TaggerNERDataset
16 | from metrics.functional.tagger_span_f1 import get_entity_from_bmes_lst, transform_predictions_to_labels
17 |
18 |
19 | def get_dataloader(config, data_prefix="test"):
20 | data_path = os.path.join(config.data_dir, f"{data_prefix}{config.data_file_suffix}")
21 | data_tokenizer = AutoTokenizer.from_pretrained(config.bert_dir, use_fast=False, do_lower_case=config.do_lowercase)
22 |
23 | dataset = TaggerNERDataset(data_path, data_tokenizer, config.dataset_sign,
24 | max_length=config.max_length, is_chinese=config.is_chinese, pad_to_maxlen=False)
25 |
26 | dataloader = DataLoader(dataset=dataset, batch_size=1, shuffle=False)
27 |
28 | return dataloader, data_tokenizer
29 |
30 |
31 | def get_parser() -> argparse.ArgumentParser:
32 | parser = argparse.ArgumentParser(description="inference the model output.")
33 | parser.add_argument("--data_dir", type=str, required=True)
34 | parser.add_argument("--bert_dir", type=str, required=True)
35 | parser.add_argument("--max_length", type=int, default=256)
36 | parser.add_argument("--is_chinese", action="store_true")
37 | parser.add_argument("--model_ckpt", type=str, default="")
38 | parser.add_argument("--hparams_file", type=str, default="")
39 | parser.add_argument("--do_lowercase", action="store_true")
40 | parser.add_argument("--data_file_suffix", type=str, default=".word.bmes")
41 | parser.add_argument("--dataset_sign", type=str, choices=["en_onto", "en_conll03", "zh_onto", "zh_msra" ], default="en_onto")
42 |
43 | return parser
44 |
45 |
46 | def main():
47 | parser = get_parser()
48 | args = parser.parse_args()
49 |
50 | trained_tagger_ner_model = BertSequenceLabeling.load_from_checkpoint(
51 | checkpoint_path=args.model_ckpt,
52 | hparams_file=args.hparams_file,
53 | map_location=None,
54 | batch_size=1,
55 | max_length=args.max_length,
56 | workers=0)
57 |
58 | entity_label_lst = get_labels(args.dataset_sign)
59 | task_idx2label = {label_idx: label_item for label_idx, label_item in enumerate(entity_label_lst)}
60 |
61 | data_loader, data_tokenizer = get_dataloader(args)
62 | vocab_path = os.path.join(args.bert_dir, "vocab.txt")
63 | # load token
64 | vocab_path = os.path.join(args.bert_dir, "vocab.txt")
65 | with open(vocab_path, "r") as f:
66 | subtokens = [token.strip() for token in f.readlines()]
67 | idx2tokens = {}
68 | for token_idx, token in enumerate(subtokens):
69 | idx2tokens[token_idx] = token
70 |
71 | for batch in data_loader:
72 | token_input_ids, token_type_ids, attention_mask, sequence_labels, is_wordpiece_mask = batch
73 | batch_size = token_input_ids.shape[0]
74 | logits = trained_tagger_ner_model.model(token_input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask)
75 |
76 | sequence_pred_lst = transform_predictions_to_labels(logits.view(batch_size, -1, len(entity_label_lst)),
77 | is_wordpiece_mask, task_idx2label, input_type="logit")
78 | batch_subtokens_idx_lst = token_input_ids.numpy().tolist()[0]
79 | batch_subtokens_lst = [idx2tokens[item] for item in batch_subtokens_idx_lst]
80 | readable_input_str = data_tokenizer.decode(batch_subtokens_idx_lst, skip_special_tokens=True)
81 |
82 | batch_entity_lst = [get_entity_from_bmes_lst(label_lst_item) for label_lst_item in sequence_pred_lst]
83 |
84 | pred_entity_lst = []
85 |
86 | for entity_lst, subtoken_lst in zip(batch_entity_lst, batch_subtokens_lst):
87 | if len(entity_lst) != 0:
88 | # example of entity_lst:
89 | # ['[0,3]PER', '[6,9]ORG', '[10]PER']
90 | for entity_info in entity_lst:
91 | if "," in entity_info:
92 | inter_pos = entity_info.find(",")
93 | start_pos = 1
94 | end_pos = entity_info.find("]")
95 | start_idx = int(entity_info[start_pos: inter_pos])
96 | end_idx = int(entity_info[inter_pos+1: end_pos])
97 | else:
98 | start_pos = 1
99 | end_pos = entity_info.find("]")
100 | start_idx = int(entity_info[start_pos:end_pos])
101 | end_idx = int(entity_info[start_pos:end_pos])
102 |
103 | entity_tokens = subtoken_lst[start_idx: end_idx]
104 | entity_string = " ".join(entity_tokens)
105 | entity_string = entity_string.replace(" ##", "")
106 | # append start, end
107 | pred_entity_lst.append((entity_string, entity_info[end_pos+1:]))
108 | else:
109 | pred_entity_lst.append([])
110 |
111 | print("*=" * 10)
112 | print(f"Given input: {readable_input_str}")
113 | print(f"Model predict: {pred_entity_lst}")
114 |
115 |
116 |
117 | if __name__ == "__main__":
118 | main()
119 |
--------------------------------------------------------------------------------
/metrics/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ShannonAI/mrc-for-flat-nested-ner/457b0759f7fd462d0abd0a23441726352716fff9/metrics/__init__.py
--------------------------------------------------------------------------------
/metrics/functional/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ShannonAI/mrc-for-flat-nested-ner/457b0759f7fd462d0abd0a23441726352716fff9/metrics/functional/__init__.py
--------------------------------------------------------------------------------
/metrics/functional/query_span_f1.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 |
4 | # file: query_span_f1.py
5 |
6 | import torch
7 | import numpy as np
8 | from utils.bmes_decode import bmes_decode
9 |
10 |
11 | def query_span_f1(start_preds, end_preds, match_logits, start_label_mask, end_label_mask, match_labels, flat=False):
12 | """
13 | Compute span f1 according to query-based model output
14 | Args:
15 | start_preds: [bsz, seq_len]
16 | end_preds: [bsz, seq_len]
17 | match_logits: [bsz, seq_len, seq_len]
18 | start_label_mask: [bsz, seq_len]
19 | end_label_mask: [bsz, seq_len]
20 | match_labels: [bsz, seq_len, seq_len]
21 | flat: if True, decode as flat-ner
22 | Returns:
23 | span-f1 counts, tensor of shape [3]: tp, fp, fn
24 | """
25 | start_label_mask = start_label_mask.bool()
26 | end_label_mask = end_label_mask.bool()
27 | match_labels = match_labels.bool()
28 | bsz, seq_len = start_label_mask.size()
29 | # [bsz, seq_len, seq_len]
30 | match_preds = match_logits > 0
31 | # [bsz, seq_len]
32 | start_preds = start_preds.bool()
33 | # [bsz, seq_len]
34 | end_preds = end_preds.bool()
35 |
36 | match_preds = (match_preds
37 | & start_preds.unsqueeze(-1).expand(-1, -1, seq_len)
38 | & end_preds.unsqueeze(1).expand(-1, seq_len, -1))
39 | match_label_mask = (start_label_mask.unsqueeze(-1).expand(-1, -1, seq_len)
40 | & end_label_mask.unsqueeze(1).expand(-1, seq_len, -1))
41 | match_label_mask = torch.triu(match_label_mask, 0) # start should be less or equal to end
42 | match_preds = match_label_mask & match_preds
43 |
44 | tp = (match_labels & match_preds).long().sum()
45 | fp = (~match_labels & match_preds).long().sum()
46 | fn = (match_labels & ~match_preds).long().sum()
47 | return torch.stack([tp, fp, fn])
48 |
49 |
50 | def extract_nested_spans(start_preds, end_preds, match_preds, start_label_mask, end_label_mask, pseudo_tag="TAG"):
51 | start_label_mask = start_label_mask.bool()
52 | end_label_mask = end_label_mask.bool()
53 | bsz, seq_len = start_label_mask.size()
54 | start_preds = start_preds.bool()
55 | end_preds = end_preds.bool()
56 |
57 | match_preds = (match_preds & start_preds.unsqueeze(-1).expand(-1, -1, seq_len) & end_preds.unsqueeze(1).expand(-1, seq_len, -1))
58 | match_label_mask = (start_label_mask.unsqueeze(-1).expand(-1, -1, seq_len) & end_label_mask.unsqueeze(1).expand(-1, seq_len, -1))
59 | match_label_mask = torch.triu(match_label_mask, 0) # start should be less or equal to end
60 | match_preds = match_label_mask & match_preds
61 | match_pos_pairs = np.transpose(np.nonzero(match_preds.numpy())).tolist()
62 | return [(pos[0], pos[1], pseudo_tag) for pos in match_pos_pairs]
63 |
64 |
65 | def extract_flat_spans(start_pred, end_pred, match_pred, label_mask, pseudo_tag = "TAG"):
66 | """
67 | Extract flat-ner spans from start/end/match logits
68 | Args:
69 | start_pred: [seq_len], 1/True for start, 0/False for non-start
70 |         end_pred: [seq_len], 1/True for end, 0/False for non-end
71 | match_pred: [seq_len, seq_len], 1/True for match, 0/False for non-match
72 | label_mask: [seq_len], 1 for valid boundary.
73 | Returns:
74 |         tags: list of tuple (start, end, tag), where end is exclusive
75 | Examples:
76 | >>> start_pred = [0, 1]
77 | >>> end_pred = [0, 1]
78 | >>> match_pred = [[0, 0], [0, 1]]
79 | >>> label_mask = [1, 1]
80 | >>> extract_flat_spans(start_pred, end_pred, match_pred, label_mask)
81 |         [(1, 2, 'TAG')]
82 | """
83 | pseudo_input = "a"
84 |
85 | bmes_labels = ["O"] * len(start_pred)
86 | start_positions = [idx for idx, tmp in enumerate(start_pred) if tmp and label_mask[idx]]
87 | end_positions = [idx for idx, tmp in enumerate(end_pred) if tmp and label_mask[idx]]
88 |
89 | for start_item in start_positions:
90 | bmes_labels[start_item] = f"B-{pseudo_tag}"
91 | for end_item in end_positions:
92 | bmes_labels[end_item] = f"E-{pseudo_tag}"
93 |
94 | for tmp_start in start_positions:
95 | tmp_end = [tmp for tmp in end_positions if tmp >= tmp_start]
96 | if len(tmp_end) == 0:
97 | continue
98 | else:
99 | tmp_end = min(tmp_end)
100 | if match_pred[tmp_start][tmp_end]:
101 | if tmp_start != tmp_end:
102 | for i in range(tmp_start+1, tmp_end):
103 | bmes_labels[i] = f"M-{pseudo_tag}"
104 | else:
105 | bmes_labels[tmp_end] = f"S-{pseudo_tag}"
106 |
107 | tags = bmes_decode([(pseudo_input, label) for label in bmes_labels])
108 |
109 | return [(entity.begin, entity.end, entity.tag) for entity in tags]
110 |
111 |
112 | def remove_overlap(spans):
113 | """
114 | remove overlapped spans greedily for flat-ner
115 | Args:
116 | spans: list of tuple (start, end), which means [start, end] is a ner-span
117 | Returns:
118 | spans without overlap
119 | """
120 | output = []
121 | occupied = set()
122 | for start, end in spans:
123 |         if any(x in occupied for x in range(start, end + 1)):
124 | continue
125 | output.append((start, end))
126 | for x in range(start, end + 1):
127 | occupied.add(x)
128 | return output
129 |
--------------------------------------------------------------------------------
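A synthetic sketch (hand-built tensors) that exercises `query_span_f1` and shows how the returned tp/fp/fn counts turn into precision/recall/F1; the single gold span (1, 2) is predicted correctly:

```python
import torch
from metrics.functional.query_span_f1 import query_span_f1

start_preds = torch.tensor([[0, 1, 0, 0]])       # predicted start at position 1
end_preds = torch.tensor([[0, 0, 1, 0]])         # predicted end at position 2
match_logits = torch.full((1, 4, 4), -1.0)
match_logits[0, 1, 2] = 1.0                      # predicted span (1, 2)
label_mask = torch.tensor([[0, 1, 1, 1]])        # position 0 is a special token
match_labels = torch.zeros(1, 4, 4, dtype=torch.long)
match_labels[0, 1, 2] = 1                        # gold span (1, 2)

tp, fp, fn = query_span_f1(start_preds, end_preds, match_logits,
                           label_mask, label_mask, match_labels)
precision = tp / (tp + fp + 1e-10)
recall = tp / (tp + fn + 1e-10)
f1 = 2 * precision * recall / (precision + recall + 1e-10)
print(tp.item(), fp.item(), fn.item(), round(f1.item(), 4))  # 1 0 0 1.0
```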
/metrics/functional/tagger_span_f1.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 |
4 | # file: tagger_span_f1.py
5 |
6 | import torch
7 | import torch.nn.functional as F
8 |
9 |
10 | def transform_predictions_to_labels(sequence_input_lst, wordpiece_mask, idx2label_map, input_type="logit"):
11 | """
12 | shape:
13 | sequence_input_lst: [batch_size, seq_len, num_labels]
14 | wordpiece_mask: [batch_size, seq_len, ]
15 | """
16 | wordpiece_mask = wordpiece_mask.detach().cpu().numpy().tolist()
17 | if input_type == "logit":
18 | label_sequence = torch.argmax(F.softmax(sequence_input_lst, dim=2), dim=2).detach().cpu().numpy().tolist()
19 | elif input_type == "prob":
20 | label_sequence = torch.argmax(sequence_input_lst, dim=2).detach().cpu().numpy().tolist()
21 | elif input_type == "label":
22 | label_sequence = sequence_input_lst.detach().cpu().numpy().tolist()
23 | else:
24 | raise ValueError
25 | output_label_sequence = []
26 | for tmp_idx_lst, tmp_label_lst in enumerate(label_sequence):
27 | tmp_wordpiece_mask = wordpiece_mask[tmp_idx_lst]
28 | tmp_label_seq = []
29 | for tmp_idx, tmp_label in enumerate(tmp_label_lst):
30 | if tmp_wordpiece_mask[tmp_idx] != -100:
31 | tmp_label_seq.append(idx2label_map[tmp_label])
32 | else:
33 | tmp_label_seq.append(-100)
34 | output_label_sequence.append(tmp_label_seq)
35 | return output_label_sequence
36 |
37 |
38 | def compute_tagger_span_f1(sequence_pred_lst, sequence_gold_lst):
39 | sum_true_positive, sum_false_positive, sum_false_negative = 0, 0, 0
40 |
41 | for seq_pred_item, seq_gold_item in zip(sequence_pred_lst, sequence_gold_lst):
42 | gold_entity_lst = get_entity_from_bmes_lst(seq_gold_item)
43 | pred_entity_lst = get_entity_from_bmes_lst(seq_pred_item)
44 |
45 | true_positive_item, false_positive_item, false_negative_item = count_confusion_matrix(pred_entity_lst, gold_entity_lst)
46 | sum_true_positive += true_positive_item
47 | sum_false_negative += false_negative_item
48 | sum_false_positive += false_positive_item
49 |
50 | batch_confusion_matrix = torch.tensor([sum_true_positive, sum_false_positive, sum_false_negative], dtype=torch.long)
51 | return batch_confusion_matrix
52 |
53 |
54 | def count_confusion_matrix(pred_entities, gold_entities):
55 | true_positive, false_positive, false_negative = 0, 0, 0
56 | for span_item in pred_entities:
57 | if span_item in gold_entities:
58 | true_positive += 1
59 | gold_entities.remove(span_item)
60 | else:
61 | false_positive += 1
62 | # these entities are not predicted.
63 | for span_item in gold_entities:
64 | false_negative += 1
65 | return true_positive, false_positive, false_negative
66 |
67 |
68 | def get_entity_from_bmes_lst(label_list):
69 | """reuse the code block from
70 | https://github.com/jiesutd/NCRFpp/blob/105a53a321eca9c1280037c473967858e01aaa43/utils/metric.py#L73
71 | Many thanks to Jie Yang.
72 | """
73 | list_len = len(label_list)
74 | begin_label = 'B-'
75 | end_label = 'E-'
76 | single_label = 'S-'
77 | whole_tag = ''
78 | index_tag = ''
79 | tag_list = []
80 | stand_matrix = []
81 | for i in range(0, list_len):
82 | if label_list[i] != -100:
83 | current_label = label_list[i].upper()
84 | else:
85 | continue
86 |
87 | if begin_label in current_label:
88 | if index_tag != '':
89 | tag_list.append(whole_tag + ',' + str(i-1))
90 | whole_tag = current_label.replace(begin_label,"",1) +'[' +str(i)
91 | index_tag = current_label.replace(begin_label,"",1)
92 | elif single_label in current_label:
93 | if index_tag != '':
94 | tag_list.append(whole_tag + ',' + str(i-1))
95 | whole_tag = current_label.replace(single_label,"",1) +'[' +str(i)
96 | tag_list.append(whole_tag)
97 | whole_tag = ""
98 | index_tag = ""
99 | elif end_label in current_label:
100 | if index_tag != '':
101 | tag_list.append(whole_tag +',' + str(i))
102 | whole_tag = ''
103 | index_tag = ''
104 | else:
105 | continue
106 | if (whole_tag != '')&(index_tag != ''):
107 | tag_list.append(whole_tag)
108 | tag_list_len = len(tag_list)
109 |
110 | for i in range(0, tag_list_len):
111 | if len(tag_list[i]) > 0:
112 | tag_list[i] = tag_list[i]+ ']'
113 | insert_list = reverse_style(tag_list[i])
114 | stand_matrix.append(insert_list)
115 | return stand_matrix
116 |
117 |
118 | def reverse_style(input_string):
119 | target_position = input_string.index('[')
120 | input_len = len(input_string)
121 | output_string = input_string[target_position:input_len] + input_string[0:target_position]
122 | return output_string
123 |
124 |
--------------------------------------------------------------------------------
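A short sketch with hand-written BMES sequences showing what `get_entity_from_bmes_lst` extracts and what `compute_tagger_span_f1` counts (tp, fp, fn):

```python
from metrics.functional.tagger_span_f1 import (compute_tagger_span_f1,
                                               get_entity_from_bmes_lst)

gold = ["B-PER", "E-PER", "O", "S-LOC"]
pred = ["B-PER", "E-PER", "O", "O"]          # the model misses the LOC entity

print(get_entity_from_bmes_lst(gold))        # ['[0,1]PER', '[3]LOC']
tp, fp, fn = compute_tagger_span_f1([pred], [gold])
print(tp.item(), fp.item(), fn.item())       # 1 0 1
```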
/metrics/query_span_f1.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 |
4 | # file: query_span_f1.py
5 |
6 | from pytorch_lightning.metrics.metric import TensorMetric
7 | from metrics.functional.query_span_f1 import query_span_f1
8 |
9 |
10 | class QuerySpanF1(TensorMetric):
11 | """
12 | Query Span F1
13 | Args:
14 | flat: is flat-ner
15 | """
16 | def __init__(self, reduce_group=None, reduce_op=None, flat=False):
17 | super(QuerySpanF1, self).__init__(name="query_span_f1",
18 | reduce_group=reduce_group,
19 | reduce_op=reduce_op)
20 | self.flat = flat
21 |
22 | def forward(self, start_preds, end_preds, match_logits, start_label_mask, end_label_mask, match_labels):
23 | return query_span_f1(start_preds, end_preds, match_logits, start_label_mask, end_label_mask, match_labels,
24 | flat=self.flat)
25 |
--------------------------------------------------------------------------------
/metrics/tagger_span_f1.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 |
4 | # file: tagger_span_f1.py
5 |
6 | from pytorch_lightning.metrics.metric import TensorMetric
7 | from metrics.functional.tagger_span_f1 import compute_tagger_span_f1
8 |
9 |
10 | class TaggerSpanF1(TensorMetric):
11 | def __init__(self, reduce_group=None, reduce_op=None):
12 | super(TaggerSpanF1, self).__init__(name="tagger_span_f1", reduce_group=reduce_group, reduce_op=reduce_op)
13 |
14 | def forward(self, sequence_pred_lst, sequence_gold_lst):
15 | return compute_tagger_span_f1(sequence_pred_lst, sequence_gold_lst)
16 |
--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ShannonAI/mrc-for-flat-nested-ner/457b0759f7fd462d0abd0a23441726352716fff9/models/__init__.py
--------------------------------------------------------------------------------
/models/bert_query_ner.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 |
4 | # file: bert_query_ner.py
5 |
6 | import torch
7 | import torch.nn as nn
8 | from transformers import BertModel, BertPreTrainedModel
9 |
10 | from models.classifier import MultiNonLinearClassifier
11 |
12 |
13 | class BertQueryNER(BertPreTrainedModel):
14 | def __init__(self, config):
15 | super(BertQueryNER, self).__init__(config)
16 | self.bert = BertModel(config)
17 |
18 | self.start_outputs = nn.Linear(config.hidden_size, 1)
19 | self.end_outputs = nn.Linear(config.hidden_size, 1)
20 | self.span_embedding = MultiNonLinearClassifier(config.hidden_size * 2, 1, config.mrc_dropout,
21 | intermediate_hidden_size=config.classifier_intermediate_hidden_size)
22 |
23 | self.hidden_size = config.hidden_size
24 |
25 | self.init_weights()
26 |
27 | def forward(self, input_ids, token_type_ids=None, attention_mask=None):
28 | """
29 | Args:
30 |             input_ids: bert input tokens, tensor of shape [batch, seq_len]
31 |             token_type_ids: 0 for query, 1 for context, tensor of shape [batch, seq_len]
32 |             attention_mask: attention mask, tensor of shape [batch, seq_len]
33 |         Returns:
34 |             start_logits: start/non-start logits of shape [batch, seq_len]
35 |             end_logits: end/non-end logits of shape [batch, seq_len]
36 |             span_logits: start-end match logits of shape [batch, seq_len, seq_len]
37 | """
38 |
39 | bert_outputs = self.bert(input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask)
40 |
41 | sequence_heatmap = bert_outputs[0] # [batch, seq_len, hidden]
42 | batch_size, seq_len, hid_size = sequence_heatmap.size()
43 |
44 |         start_logits = self.start_outputs(sequence_heatmap).squeeze(-1)  # [batch, seq_len]
45 |         end_logits = self.end_outputs(sequence_heatmap).squeeze(-1)  # [batch, seq_len]
46 |
47 |         # for every position $i$ in the sequence, concatenate its representation with position $j$ to
48 | # predict if $i$ and $j$ are start_pos and end_pos for an entity.
49 | # [batch, seq_len, seq_len, hidden]
50 | start_extend = sequence_heatmap.unsqueeze(2).expand(-1, -1, seq_len, -1)
51 | # [batch, seq_len, seq_len, hidden]
52 | end_extend = sequence_heatmap.unsqueeze(1).expand(-1, seq_len, -1, -1)
53 | # [batch, seq_len, seq_len, hidden*2]
54 | span_matrix = torch.cat([start_extend, end_extend], 3)
55 | # [batch, seq_len, seq_len]
56 | span_logits = self.span_embedding(span_matrix).squeeze(-1)
57 |
58 | return start_logits, end_logits, span_logits
59 |
--------------------------------------------------------------------------------
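A shape-check sketch for `BertQueryNER` using a tiny, randomly initialized config (no pretrained weights; all sizes below are arbitrary). It only illustrates the shapes returned by the forward pass:

```python
import torch
from models.model_config import BertQueryNerConfig
from models.bert_query_ner import BertQueryNER

# Tiny config for a quick shape check; in training the config is loaded from a
# pretrained BERT directory instead.
config = BertQueryNerConfig(hidden_size=32, num_hidden_layers=2, num_attention_heads=2,
                            intermediate_size=64, mrc_dropout=0.1,
                            classifier_intermediate_hidden_size=64)
model = BertQueryNER(config)

input_ids = torch.randint(0, config.vocab_size, (2, 10))
token_type_ids = torch.zeros_like(input_ids)
attention_mask = torch.ones_like(input_ids)

start_logits, end_logits, span_logits = model(input_ids, token_type_ids, attention_mask)
print(start_logits.shape, end_logits.shape, span_logits.shape)
# torch.Size([2, 10]) torch.Size([2, 10]) torch.Size([2, 10, 10])
```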
/models/bert_tagger.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 |
4 | # file: bert_tagger.py
5 | #
6 |
7 | import torch.nn as nn
8 | from transformers import BertModel, BertPreTrainedModel
9 | from models.classifier import BERTTaggerClassifier
10 |
11 |
12 | class BertTagger(BertPreTrainedModel):
13 | def __init__(self, config):
14 | super(BertTagger, self).__init__(config)
15 | self.bert = BertModel(config)
16 |
17 | self.num_labels = config.num_labels
18 | self.hidden_size = config.hidden_size
19 | self.dropout = nn.Dropout(config.hidden_dropout_prob)
20 | if config.classifier_sign == "multi_nonlinear":
21 | self.classifier = BERTTaggerClassifier(self.hidden_size, self.num_labels,
22 | config.classifier_dropout,
23 | act_func=config.classifier_act_func,
24 | intermediate_hidden_size=config.classifier_intermediate_hidden_size)
25 | else:
26 | self.classifier = nn.Linear(self.hidden_size, self.num_labels)
27 |
28 | self.init_weights()
29 |
30 | def forward(self, input_ids, token_type_ids=None, attention_mask=None,):
31 |         last_bert_layer, pooled_output = self.bert(input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask)
32 | last_bert_layer = last_bert_layer.view(-1, self.hidden_size)
33 | last_bert_layer = self.dropout(last_bert_layer)
34 | logits = self.classifier(last_bert_layer)
35 | return logits
--------------------------------------------------------------------------------
/models/classifier.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 |
4 | # file: classifier.py
5 |
6 | import torch.nn as nn
7 | from torch.nn import functional as F
8 |
9 |
10 | class SingleLinearClassifier(nn.Module):
11 | def __init__(self, hidden_size, num_label):
12 | super(SingleLinearClassifier, self).__init__()
13 | self.num_label = num_label
14 | self.classifier = nn.Linear(hidden_size, num_label)
15 |
16 | def forward(self, input_features):
17 | features_output = self.classifier(input_features)
18 | return features_output
19 |
20 |
21 | class MultiNonLinearClassifier(nn.Module):
22 | def __init__(self, hidden_size, num_label, dropout_rate, act_func="gelu", intermediate_hidden_size=None):
23 | super(MultiNonLinearClassifier, self).__init__()
24 | self.num_label = num_label
25 | self.intermediate_hidden_size = hidden_size if intermediate_hidden_size is None else intermediate_hidden_size
26 | self.classifier1 = nn.Linear(hidden_size, self.intermediate_hidden_size)
27 | self.classifier2 = nn.Linear(self.intermediate_hidden_size, self.num_label)
28 | self.dropout = nn.Dropout(dropout_rate)
29 | self.act_func = act_func
30 |
31 | def forward(self, input_features):
32 | features_output1 = self.classifier1(input_features)
33 | if self.act_func == "gelu":
34 | features_output1 = F.gelu(features_output1)
35 | elif self.act_func == "relu":
36 | features_output1 = F.relu(features_output1)
37 | elif self.act_func == "tanh":
38 | features_output1 = F.tanh(features_output1)
39 | else:
40 |             raise ValueError(f"unsupported activation function: {self.act_func}")
41 | features_output1 = self.dropout(features_output1)
42 | features_output2 = self.classifier2(features_output1)
43 | return features_output2
44 |
45 |
46 | class BERTTaggerClassifier(nn.Module):
47 | def __init__(self, hidden_size, num_label, dropout_rate, act_func="gelu", intermediate_hidden_size=None):
48 | super(BERTTaggerClassifier, self).__init__()
49 | self.num_label = num_label
50 | self.intermediate_hidden_size = hidden_size if intermediate_hidden_size is None else intermediate_hidden_size
51 | self.classifier1 = nn.Linear(hidden_size, self.intermediate_hidden_size)
52 | self.classifier2 = nn.Linear(self.intermediate_hidden_size, self.num_label)
53 | self.dropout = nn.Dropout(dropout_rate)
54 | self.act_func = act_func
55 |
56 | def forward(self, input_features):
57 | features_output1 = self.classifier1(input_features)
58 | if self.act_func == "gelu":
59 | features_output1 = F.gelu(features_output1)
60 | elif self.act_func == "relu":
61 | features_output1 = F.relu(features_output1)
62 | elif self.act_func == "tanh":
63 | features_output1 = F.tanh(features_output1)
64 | else:
65 |             raise ValueError(f"unsupported activation function: {self.act_func}")
66 | features_output1 = self.dropout(features_output1)
67 | features_output2 = self.classifier2(features_output1)
68 | return features_output2
69 |
--------------------------------------------------------------------------------
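A minimal usage sketch for MultiNonLinearClassifier as defined above (sizes are illustrative; run with the repo root on PYTHONPATH, as the scripts below do): a two-layer head projecting hidden_size -> intermediate_hidden_size -> num_label with an activation and dropout in between, applied over the last dimension of the input.

import torch
from models.classifier import MultiNonLinearClassifier

head = MultiNonLinearClassifier(hidden_size=16, num_label=1, dropout_rate=0.1,
                                act_func="gelu", intermediate_hidden_size=32)
features = torch.randn(4, 10, 16)   # e.g. [batch, seq_len, hidden]
scores = head(features)             # [4, 10, 1]; nn.Linear acts on the last dimension
print(scores.shape)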
/models/model_config.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 |
4 | # file: model_config.py
5 |
6 | from transformers import BertConfig
7 |
8 |
9 | class BertQueryNerConfig(BertConfig):
10 | def __init__(self, **kwargs):
11 | super(BertQueryNerConfig, self).__init__(**kwargs)
12 | self.mrc_dropout = kwargs.get("mrc_dropout", 0.1)
13 | self.classifier_intermediate_hidden_size = kwargs.get("classifier_intermediate_hidden_size", 1024)
14 | self.classifier_act_func = kwargs.get("classifier_act_func", "gelu")
15 |
16 | class BertTaggerConfig(BertConfig):
17 | def __init__(self, **kwargs):
18 | super(BertTaggerConfig, self).__init__(**kwargs)
19 | self.num_labels = kwargs.get("num_labels", 6)
20 | self.classifier_dropout = kwargs.get("classifier_dropout", 0.1)
21 | self.classifier_sign = kwargs.get("classifier_sign", "multi_nonlinear")
22 | self.classifier_act_func = kwargs.get("classifier_act_func", "gelu")
23 | self.classifier_intermediate_hidden_size = kwargs.get("classifier_intermediate_hidden_size", 1024)
--------------------------------------------------------------------------------
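A short sketch of how the two config classes above pick up their extra hyperparameters from kwargs, falling back to the defaults in __init__; the values below are illustrative, not recommended settings.

from models.model_config import BertQueryNerConfig, BertTaggerConfig

mrc_config = BertQueryNerConfig(mrc_dropout=0.3, classifier_intermediate_hidden_size=2048)
print(mrc_config.mrc_dropout, mrc_config.classifier_act_func)   # 0.3 gelu (act_func keeps its default)

tagger_config = BertTaggerConfig(num_labels=18)                 # other fields keep their defaults
print(tagger_config.classifier_sign)                            # multi_nonlinear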
/ner2mrc/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ShannonAI/mrc-for-flat-nested-ner/457b0759f7fd462d0abd0a23441726352716fff9/ner2mrc/__init__.py
--------------------------------------------------------------------------------
/ner2mrc/download.md:
--------------------------------------------------------------------------------
1 | ## Download MRC-style NER Datasets
2 | ZH:
3 | - [MSRA](https://drive.google.com/file/d/1bAoSJfT1IBdpbQWSrZPjQPPbAsDGlN2D/view?usp=sharing)
4 | - [OntoNotes4](https://drive.google.com/file/d/1CRVgZJDDGuj0O1NLK5DgujQBTLKyMR-g/view?usp=sharing)
5 |
6 | EN:
7 | - [CoNLL03](https://drive.google.com/file/d/1mGO9CYkgXsV-Et-hSZpOmS0m9G8A5mau/view?usp=sharing)
8 | - [ACE2004](https://drive.google.com/file/d/1U-hGOgLmdqudsRdKIGles1-QrNJ7SSg6/view?usp=sharing)
9 | - [ACE2005](https://drive.google.com/file/d/1iodaJ92dTAjUWnkMyYm8aLEi5hj3cseY/view?usp=sharing)
10 | - [GENIA](https://drive.google.com/file/d/1oF1P8s-0MN9X1M1PlKB2c5aBtxhmoxXb/view?usp=sharing)
11 |
12 | ## Download Sequence Labeling NER Datasets (Based on BMES tagging schema)
13 | ZH:
14 | - [MSRA](https://drive.google.com/file/d/1ytrfNgh53la7bdSNyI1QURAyJ1PniPCe/view?usp=sharing)
15 | - [OntoNotes4](https://drive.google.com/file/d/1FVg3XcW1eaqlikU36df5xl7DC3UNBJDm/view?usp=sharing)
16 |
17 | EN:
18 | - [CoNLL03](https://drive.google.com/file/d/1PUH2uw6lkWrWGfl-9wOAG13lvPrvKO25/view?usp=sharing)
19 |
--------------------------------------------------------------------------------
/ner2mrc/genia2mrc.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 |
4 | # file: genia2mrc.py
5 |
6 | import os
7 | import json
8 |
9 |
10 | def convert_file(input_file, output_file, tag2query_file):
11 | """
12 | Convert GENIA data to MRC format
13 | """
14 | all_data = json.load(open(input_file))
15 | tag2query = json.load(open(tag2query_file))
16 |
17 | output = []
18 | origin_count = 0
19 | new_count = 0
20 |
21 | for data in all_data:
22 | origin_count += 1
23 | context = data["context"]
24 | label2positions = data["label"]
25 | for tag_idx, (tag, query) in enumerate(tag2query.items()):
26 | positions = label2positions.get(tag, [])
27 | mrc_sample = {
28 | "context": context,
29 | "query": query,
30 | "start_position": [int(x.split(";")[0]) for x in positions],
31 | "end_position": [int(x.split(";")[1]) for x in positions],
32 | "qas_id": f"{origin_count}.{tag_idx}"
33 | }
34 | output.append(mrc_sample)
35 | new_count += 1
36 |
37 | json.dump(output, open(output_file, "w"), ensure_ascii=False, indent=2)
38 | print(f"Convert {origin_count} samples to {new_count} samples and save to {output_file}")
39 |
40 |
41 | def main():
42 | genia_raw_dir = "/mnt/mrc/genia/genia_raw"
43 | genia_mrc_dir = "/mnt/mrc/genia/genia_raw/mrc_format"
44 | tag2query_file = "queries/genia.json"
45 | os.makedirs(genia_mrc_dir, exist_ok=True)
46 | for phase in ["train", "dev", "test"]:
47 | old_file = os.path.join(genia_raw_dir, f"{phase}.genia.json")
48 | new_file = os.path.join(genia_mrc_dir, f"mrc-ner.{phase}")
49 | convert_file(old_file, new_file, tag2query_file)
50 |
51 |
52 | if __name__ == '__main__':
53 | main()
54 |
--------------------------------------------------------------------------------
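A worked example (made-up values) of the position conversion in convert_file above: GENIA stores each entity span as a single "start;end" string, which is split into parallel start_position / end_position lists for the MRC sample.

positions = ["14;16", "7;18"]
start_position = [int(x.split(";")[0]) for x in positions]   # [14, 7]
end_position = [int(x.split(";")[1]) for x in positions]     # [16, 18]
print(start_position, end_position)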
/ner2mrc/msra2mrc.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 |
4 | # file: msra2mrc.py
5 |
6 | import os
7 | from utils.bmes_decode import bmes_decode
8 | import json
9 |
10 |
11 | def convert_file(input_file, output_file, tag2query_file):
12 | """
13 | Convert MSRA raw data to MRC format
14 | """
15 | origin_count = 0
16 | new_count = 0
17 | tag2query = json.load(open(tag2query_file))
18 | mrc_samples = []
19 | with open(input_file) as fin:
20 | for line in fin:
21 | line = line.strip()
22 | if not line:
23 | continue
24 | origin_count += 1
25 | src, labels = line.split("\t")
26 | tags = bmes_decode(char_label_list=[(char, label) for char, label in zip(src.split(), labels.split())])
27 | for label, query in tag2query.items():
28 | mrc_samples.append(
29 | {
30 | "context": src,
31 | "start_position": [tag.begin for tag in tags if tag.tag == label],
32 | "end_position": [tag.end-1 for tag in tags if tag.tag == label],
33 | "query": query
34 | }
35 | )
36 | new_count += 1
37 |
38 | json.dump(mrc_samples, open(output_file, "w"), ensure_ascii=False, sort_keys=True, indent=2)
39 | print(f"Convert {origin_count} samples to {new_count} samples and save to {output_file}")
40 |
41 |
42 | def main():
43 | msra_raw_dir = "/mnt/mrc/zh_msra_yuxian"
44 | msra_mrc_dir = "/mnt/mrc/zh_msra_yuxian/mrc_format"
45 | tag2query_file = "queries/zh_msra.json"
46 | os.makedirs(msra_mrc_dir, exist_ok=True)
47 | for phase in ["train", "dev", "test"]:
48 | old_file = os.path.join(msra_raw_dir, f"{phase}.tsv")
49 | new_file = os.path.join(msra_mrc_dir, f"mrc-ner.{phase}")
50 | convert_file(old_file, new_file, tag2query_file)
51 |
52 |
53 | if __name__ == '__main__':
54 | main()
55 |
--------------------------------------------------------------------------------
/ner2mrc/queries/genia.json:
--------------------------------------------------------------------------------
1 | {
2 | "DNA": "deoxyribonucleic acid",
3 | "RNA": "ribonucleic acid",
4 | "cell_line": "cell line",
5 | "cell_type": "cell type",
6 | "protein": "protein entities are limited to nitrogenous organic compounds and are parts of all living organisms, as structural components of body tissues such as muscle, hair, collagen and as enzymes and antibodies."
7 | }
8 |
--------------------------------------------------------------------------------
/ner2mrc/queries/zh_msra.json:
--------------------------------------------------------------------------------
1 | {
2 | "NR": "人名和虚构的人物形象",
3 | "NS": "按照地理位置划分的国家,城市,乡镇,大洲",
4 | "NT": "组织包括公司,政府党派,学校,政府,新闻机构"
5 | }
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | pytorch-lightning==0.9.0
2 | tokenizers==0.9.3
3 | transformers==3.5.1
--------------------------------------------------------------------------------
/scripts/bert_tagger/evaluate.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | # -*- coding: utf-8 -*-
3 |
4 | # file: evaluate.sh
5 |
6 |
7 | REPO_PATH=/home/lixiaoya/mrc-for-flat-nested-ner
8 | export PYTHONPATH="$PYTHONPATH:$REPO_PATH"
9 |
10 | OUTPUT_DIR=/data/lixiaoya/outputs/mrc_ner_baseline/0909/onto5_bert_tagger10_lr2e-5_drop_norm1.0_weight_warmup0.01_maxlen512
11 | # find best checkpoint on dev in ${OUTPUT_DIR}/train_log.txt
12 | BEST_CKPT_DEV=${OUTPUT_DIR}/epoch=25_v1.ckpt
13 | PYTORCHLIGHT_HPARAMS=${OUTPUT_DIR}/lightning_logs/version_0/hparams.yaml
14 | GPU_ID=0,1
15 | MAX_LEN=220
16 |
17 | python3 ${REPO_PATH}/evaluate/tagger_ner_evaluate.py ${BEST_CKPT_DEV} ${PYTORCHLIGHT_HPARAMS} ${GPU_ID} ${MAX_LEN}
--------------------------------------------------------------------------------
/scripts/bert_tagger/inference.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | # -*- coding: utf-8 -*-
3 |
4 | # file: inference.sh
5 |
6 |
7 | REPO_PATH=/home/lixiaoya/mrc-for-flat-nested-ner
8 | export PYTHONPATH="$PYTHONPATH:$REPO_PATH"
9 |
10 | DATA_SIGN=en_onto
11 | DATA_DIR=/data/lixiaoya/datasets/bmes_ner/en_ontonotes5
12 | BERT_DIR=/data/lixiaoya/models/bert_cased_large
13 | MAX_LEN=200
14 | OUTPUT_DIR=/data/lixiaoya/outputs/mrc_ner_baseline/0909/onto5_bert_tagger10_lr2e-5_drop_norm1.0_weight_warmup0.01_maxlen512
15 | MODEL_CKPT=${OUTPUT_DIR}/epoch=25_v1.ckpt
16 | HPARAMS_FILE=${OUTPUT_DIR}/lightning_logs/version_0/hparams.yaml
17 | DATA_SUFFIX=.word.bmes
18 |
19 | python3 ${REPO_PATH}/inference/tagger_ner_inference.py \
20 | --data_dir ${DATA_DIR} \
21 | --bert_dir ${BERT_DIR} \
22 | --max_length ${MAX_LEN} \
23 | --model_ckpt ${MODEL_CKPT} \
24 | --hparams_file ${HPARAMS_FILE} \
25 | --dataset_sign ${DATA_SIGN} \
26 | --data_file_suffix ${DATA_SUFFIX}
--------------------------------------------------------------------------------
/scripts/bert_tagger/reproduce/conll03.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | # -*- coding: utf-8 -*-
3 |
4 | # file: conll03.sh
5 | # dev Span-F1: 96.61
6 | # test Span-F1: 92.74
7 |
8 | TIME=0824
9 | FILE=conll03_bert_tagger
10 | REPO_PATH=/home/lixiaoya/mrc-for-flat-nested-ner
11 | export PYTHONPATH="$PYTHONPATH:$REPO_PATH"
12 | DATA_DIR=/data/lixiaoya/datasets/ner/en_conll03
13 | BERT_DIR=/data/lixiaoya/models/bert_cased_large
14 |
15 | BERT_DROPOUT=0.2
16 | LR=2e-5
17 | LR_SCHEDULER=polydecay
18 | MAXLEN=256
19 | MAXNORM=1.0
20 | DATA_SUFFIX=.word.bmes
21 | TRAIN_BATCH_SIZE=8
22 | GRAD_ACC=4
23 | MAX_EPOCH=10
24 | WEIGHT_DECAY=0.02
25 | OPTIM=torch.adam
26 | DATA_SIGN=en_conll03
27 | WARMUP_PROPORTION=0.01
28 | INTER_HIDDEN=1024
29 |
30 | OUTPUT_DIR=/data/lixiaoya/outputs/mrc_ner_baseline/${TIME}/${FILE}_lr${LR}_drop${BERT_DROPOUT}_norm${MAXNORM}_weight${WEIGHT_DECAY}_warmup${WARMUP_PROPORTION}_maxlen${MAXLEN}
31 | mkdir -p ${OUTPUT_DIR}
32 |
33 |
34 | CUDA_VISIBLE_DEVICES=1 python3 ${REPO_PATH}/train/bert_tagger_trainer.py \
35 | --gpus="1" \
36 | --progress_bar_refresh_rate 1 \
37 | --data_dir ${DATA_DIR} \
38 | --bert_config_dir ${BERT_DIR} \
39 | --max_length ${MAXLEN} \
40 | --train_batch_size ${TRAIN_BATCH_SIZE} \
41 | --precision=16 \
42 | --lr_scheduler ${LR_SCHEDULER} \
43 | --lr ${LR} \
44 | --val_check_interval 0.25 \
45 | --accumulate_grad_batches ${GRAD_ACC} \
46 | --output_dir ${OUTPUT_DIR} \
47 | --max_epochs ${MAX_EPOCH} \
48 | --warmup_proportion ${WARMUP_PROPORTION} \
49 | --max_length ${MAXLEN} \
50 | --gradient_clip_val ${MAXNORM} \
51 | --weight_decay ${WEIGHT_DECAY} \
52 | --data_file_suffix ${DATA_SUFFIX} \
53 | --optimizer ${OPTIM} \
54 | --data_sign ${DATA_SIGN} \
55 | --classifier_intermediate_hidden_size ${INTER_HIDDEN}
56 |
57 |
--------------------------------------------------------------------------------
/scripts/bert_tagger/reproduce/msra.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | # -*- coding: utf-8 -*-
3 |
4 | # file: msra.sh
5 | # dev Span-F1: 94.02
6 | # test Span-F1: 95.15
7 |
8 | TIME=0826
9 | FILE=msra_bert_tagger
10 | REPO_PATH=/userhome/xiaoya/mrc-for-flat-nested-ner
11 | export PYTHONPATH="$PYTHONPATH:$REPO_PATH"
12 | DATA_DIR=/userhome/xiaoya/dataset/tagger_ner_datasets/msra
13 | BERT_DIR=/userhome/xiaoya/bert/chinese_bert
14 |
15 |
16 | BERT_DROPOUT=0.2
17 | LR=2e-5
18 | LR_SCHEDULER=polydecay
19 | MAXLEN=256
20 | MAXNORM=1.0
21 | DATA_SUFFIX=.char.bmes
22 | TRAIN_BATCH_SIZE=8
23 | GRAD_ACC=4
24 | MAX_EPOCH=10
25 | WEIGHT_DECAY=0.02
26 | OPTIM=torch.adam
27 | DATA_SIGN=zh_msra
28 | WARMUP_PROPORTION=0.02
29 | INTER_HIDDEN=768
30 |
31 | OUTPUT_DIR=/userhome/xiaoya/outputs/mrc_ner_baseline/${TIME}/${FILE}_chinese_lr${LR}_drop${BERT_DROPOUT}_norm${MAXNORM}_weight${WEIGHT_DECAY}_warmup${WARMUP_PROPORTION}_maxlen${MAXLEN}
32 | mkdir -p ${OUTPUT_DIR}
33 |
34 |
35 | CUDA_VISIBLE_DEVICES=0 python3 ${REPO_PATH}/train/bert_tagger_trainer.py \
36 | --gpus="1" \
37 | --progress_bar_refresh_rate 1 \
38 | --data_dir ${DATA_DIR} \
39 | --bert_config_dir ${BERT_DIR} \
40 | --max_length ${MAXLEN} \
41 | --train_batch_size ${TRAIN_BATCH_SIZE} \
42 | --precision=16 \
43 | --lr_scheduler ${LR_SCHEDULER} \
44 | --lr ${LR} \
45 | --val_check_interval 0.25 \
46 | --accumulate_grad_batches ${GRAD_ACC} \
47 | --output_dir ${OUTPUT_DIR} \
48 | --max_epochs ${MAX_EPOCH} \
49 | --warmup_proportion ${WARMUP_PROPORTION} \
50 | --max_length ${MAXLEN} \
51 | --gradient_clip_val ${MAXNORM} \
52 | --weight_decay ${WEIGHT_DECAY} \
53 | --data_file_suffix ${DATA_SUFFIX} \
54 | --optimizer ${OPTIM} \
55 | --data_sign ${DATA_SIGN} \
56 | --classifier_intermediate_hidden_size ${INTER_HIDDEN} \
57 | --chinese
--------------------------------------------------------------------------------
/scripts/bert_tagger/reproduce/onto4.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | # -*- coding: utf-8 -*-
3 |
4 | # file: onto4.sh
5 | # dev Span-F1: 79.18
6 | # test Span-F1: 80.35
7 |
8 | TIME=0826
9 | FILE=onto_bert_tagger
10 | REPO_PATH=/userhome/xiaoya/mrc-for-flat-nested-ner
11 | export PYTHONPATH="$PYTHONPATH:$REPO_PATH"
12 | DATA_DIR=/userhome/xiaoya/dataset/tagger_ner_datasets/zhontonotes4
13 | BERT_DIR=/userhome/xiaoya/bert/chinese_bert
14 |
15 | BERT_DROPOUT=0.2
16 | LR=2e-5
17 | LR_SCHEDULER=polydecay
18 | MAXLEN=256
19 | MAXNORM=1.0
20 | DATA_SUFFIX=.char.bmes
21 | TRAIN_BATCH_SIZE=8
22 | GRAD_ACC=4
23 | MAX_EPOCH=10
24 | WEIGHT_DECAY=0.02
25 | OPTIM=torch.adam
26 | DATA_SIGN=zh_onto
27 | WARMUP_PROPORTION=0.02
28 | INTER_HIDDEN=768
29 |
30 | OUTPUT_DIR=/userhome/xiaoya/outputs/mrc_ner_baseline/${TIME}/${FILE}_chinese_lr${LR}_drop${BERT_DROPOUT}_norm${MAXNORM}_weight${WEIGHT_DECAY}_warmup${WARMUP_PROPORTION}_maxlen${MAXLEN}
31 | mkdir -p ${OUTPUT_DIR}
32 |
33 |
34 | CUDA_VISIBLE_DEVICES=1 python3 ${REPO_PATH}/train/bert_tagger_trainer.py \
35 | --gpus="1" \
36 | --progress_bar_refresh_rate 1 \
37 | --data_dir ${DATA_DIR} \
38 | --bert_config_dir ${BERT_DIR} \
39 | --max_length ${MAXLEN} \
40 | --train_batch_size ${TRAIN_BATCH_SIZE} \
41 | --precision=32 \
42 | --lr_scheduler ${LR_SCHEDULER} \
43 | --lr ${LR} \
44 | --val_check_interval 0.25 \
45 | --accumulate_grad_batches ${GRAD_ACC} \
46 | --output_dir ${OUTPUT_DIR} \
47 | --max_epochs ${MAX_EPOCH} \
48 | --warmup_proportion ${WARMUP_PROPORTION} \
49 | --max_length ${MAXLEN} \
50 | --gradient_clip_val ${MAXNORM} \
51 | --weight_decay ${WEIGHT_DECAY} \
52 | --data_file_suffix ${DATA_SUFFIX} \
53 | --optimizer ${OPTIM} \
54 | --data_sign ${DATA_SIGN} \
55 | --classifier_intermediate_hidden_size ${INTER_HIDDEN} \
56 | --chinese
--------------------------------------------------------------------------------
/scripts/mrc_ner/evaluate.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | # -*- coding: utf-8 -*-
3 |
4 | # file: evaluate.sh
5 |
6 | REPO_PATH=/data/xiaoya/workspace/mrc-for-flat-nested-ner
7 | export PYTHONPATH="$PYTHONPATH:$REPO_PATH"
8 |
9 | OUTPUT_DIR=/data/xiaoya/outputs/mrc_ner/ace2004/debug_lr3e-5_drop0.3_norm1.0_weight0.1_warmup0_maxlen100
10 | # find best checkpoint on dev in ${OUTPUT_DIR}/train_log.txt
11 | BEST_CKPT_DEV=${OUTPUT_DIR}/epoch=8.ckpt
12 | PYTORCHLIGHT_HPARAMS=${OUTPUT_DIR}/lightning_logs/version_0/hparams.yaml
13 | GPU_ID=0,1
14 |
15 | python3 ${REPO_PATH}/evaluate/mrc_ner_evaluate.py ${BEST_CKPT_DEV} ${PYTORCHLIGHT_HPARAMS} ${GPU_ID}
--------------------------------------------------------------------------------
/scripts/mrc_ner/flat_inference.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | # -*- coding: utf-8 -*-
3 |
4 | # file: flat_inference.sh
5 | #
6 |
7 | REPO_PATH=/data/xiaoya/workspace/mrc-for-flat-nested-ner-github
8 | export PYTHONPATH="$PYTHONPATH:$REPO_PATH"
9 |
10 | DATA_SIGN=conll03
11 | DATA_DIR=/data/xiaoya/datasets/mrc_ner_datasets/en_conll03_truecase_sent
12 | BERT_DIR=/data/xiaoya/models/uncased_L-12_H-768_A-12
13 | MAX_LEN=180
14 | MODEL_CKPT=/data/xiaoya/outputs/mrc_ner/conll03/large_lr3e-5_drop0.3_norm1.0_weight0.1_warmup0_maxlen180/epoch=1_v7.ckpt
15 | HPARAMS_FILE=/data/xiaoya/outputs/mrc_ner/conll03/large_lr3e-5_drop0.3_norm1.0_weight0.1_warmup0_maxlen180/lightning_logs/version_0/hparams.yaml
16 |
17 |
18 | python3 ${REPO_PATH}/inference/mrc_ner_inference.py \
19 | --data_dir ${DATA_DIR} \
20 | --bert_dir ${BERT_DIR} \
21 | --max_length ${MAX_LEN} \
22 | --model_ckpt ${MODEL_CKPT} \
23 | --hparams_file ${HPARAMS_FILE} \
24 | --flat_ner \
25 | --dataset_sign ${DATA_SIGN}
--------------------------------------------------------------------------------
/scripts/mrc_ner/nested_inference.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | # -*- coding: utf-8 -*-
3 |
4 | # file: nested_inference.sh
5 | #
6 |
7 | REPO_PATH=/data/xiaoya/workspace/mrc-for-flat-nested-ner-github
8 | export PYTHONPATH="$PYTHONPATH:$REPO_PATH"
9 |
10 | DATA_SIGN=ace04
11 | DATA_DIR=/data/xiaoya/datasets/mrc-for-flat-nested-ner/ace2004
12 | BERT_DIR=/data/xiaoya/models/uncased_L-12_H-768_A-12
13 | MAX_LEN=100
14 | MODEL_CKPT=/data/xiaoya/outputs/mrc_ner/ace2004/debug_lr3e-5_drop0.3_norm1.0_weight0.1_warmup0_maxlen100/epoch=0.ckpt
15 | HPARAMS_FILE=/data/xiaoya/outputs/mrc_ner/ace2004/debug_lr3e-5_drop0.3_norm1.0_weight0.1_warmup0_maxlen100/lightning_logs/version_3/hparams.yaml
16 |
17 |
18 | python3 ${REPO_PATH}/inference/mrc_ner_inference.py \
19 | --data_dir ${DATA_DIR} \
20 | --bert_dir ${BERT_DIR} \
21 | --max_length ${MAX_LEN} \
22 | --model_ckpt ${MODEL_CKPT} \
23 | --hparams_file ${HPARAMS_FILE} \
24 | --dataset_sign ${DATA_SIGN}
--------------------------------------------------------------------------------
/scripts/mrc_ner/reproduce/ace04.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | # -*- coding: utf-8 -*-
3 |
4 | # file: ace04.sh
5 |
6 | TIME=0910
7 | REPO_PATH=/userhome/xiaoya/mrc-for-flat-nested-ner
8 | export PYTHONPATH="$PYTHONPATH:$REPO_PATH"
9 | DATA_DIR=/userhome/xiaoya/dataset/ace2004
10 | BERT_DIR=/userhome/xiaoya/bert/bert_uncased_large
11 |
12 | BERT_DROPOUT=0.1
13 | MRC_DROPOUT=0.3
14 | LR=3e-5
15 | SPAN_WEIGHT=0.1
16 | WARMUP=0
17 | MAXLEN=128
18 | MAXNORM=1.0
19 | INTER_HIDDEN=2048
20 |
21 | BATCH_SIZE=4
22 | PREC=16
23 | VAL_CKPT=0.25
24 | ACC_GRAD=2
25 | MAX_EPOCH=20
26 | SPAN_CANDI=pred_and_gold
27 | PROGRESS_BAR=1
28 |
29 | OUTPUT_DIR=/userhome/xiaoya/outputs/mrc_ner/${TIME}/ace2004/large_lr${LR}_drop${MRC_DROPOUT}_norm${MAXNORM}_weight${SPAN_WEIGHT}_warmup${WARMUP}_maxlen${MAXLEN}
30 | mkdir -p ${OUTPUT_DIR}
31 |
32 | CUDA_VISIBLE_DEVICES=0,1 python ${REPO_PATH}/train/mrc_ner_trainer.py \
33 | --gpus="2" \
34 | --distributed_backend=ddp \
35 | --workers 0 \
36 | --data_dir ${DATA_DIR} \
37 | --bert_config_dir ${BERT_DIR} \
38 | --max_length ${MAXLEN} \
39 | --batch_size ${BATCH_SIZE} \
40 | --precision=${PREC} \
41 | --progress_bar_refresh_rate ${PROGRESS_BAR} \
42 | --lr ${LR} \
43 | --val_check_interval ${VAL_CKPT} \
44 | --accumulate_grad_batches ${ACC_GRAD} \
45 | --default_root_dir ${OUTPUT_DIR} \
46 |     --mrc_dropout ${MRC_DROPOUT} \
47 | --bert_dropout ${BERT_DROPOUT} \
48 | --max_epochs ${MAX_EPOCH} \
49 | --span_loss_candidates ${SPAN_CANDI} \
50 | --weight_span ${SPAN_WEIGHT} \
51 | --warmup_steps ${WARMUP} \
52 | --gradient_clip_val ${MAXNORM} \
53 | --classifier_intermediate_hidden_size ${INTER_HIDDEN}
54 |
55 |
--------------------------------------------------------------------------------
/scripts/mrc_ner/reproduce/ace05.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | # -*- coding: utf-8 -*-
3 |
4 | # file: ace05.sh
5 |
6 | REPO_PATH=/userhome/xiaoya/mrc-for-flat-nested-ner
7 | export PYTHONPATH="$PYTHONPATH:$REPO_PATH"
8 |
9 | DATA_DIR=/userhome/xiaoya/dataset/ace2005
10 | BERT_DIR=/userhome/xiaoya/bert/bert_uncased_large
11 |
12 | BERT_DROPOUT=0.1
13 | MRC_DROPOUT=0.3
14 | LR=2e-5
15 | SPAN_WEIGHT=0.1
16 | WARMUP=0
17 | MAXLEN=128
18 | MAXNORM=1.0
19 | INTER_HIDDEN=2048
20 |
21 | BATCH_SIZE=8
22 | PREC=16
23 | VAL_CKPT=0.25
24 | ACC_GRAD=1
25 | MAX_EPOCH=20
26 | SPAN_CANDI=pred_and_gold
27 | PROGRESS_BAR=1
28 | OPTIM=adamw
29 |
30 | OUTPUT_DIR=/userhome/xiaoya/outputs/mrc_ner/ace2005/warmup${WARMUP}lr${LR}_drop${MRC_DROPOUT}_norm${MAXNORM}_weight${SPAN_WEIGHT}_warmup${WARMUP}_maxlen${MAXLEN}
31 | mkdir -p ${OUTPUT_DIR}
32 |
33 | CUDA_VISIBLE_DEVICES=0,1,2,3 python ${REPO_PATH}/train/mrc_ner_trainer.py \
34 | --gpus="4" \
35 | --distributed_backend=ddp \
36 | --workers 0 \
37 | --data_dir ${DATA_DIR} \
38 | --bert_config_dir ${BERT_DIR} \
39 | --max_length ${MAXLEN} \
40 | --batch_size ${BATCH_SIZE} \
41 | --precision=${PREC} \
42 | --progress_bar_refresh_rate ${PROGRESS_BAR} \
43 | --lr ${LR} \
44 | --val_check_interval ${VAL_CKPT} \
45 | --accumulate_grad_batches ${ACC_GRAD} \
46 | --default_root_dir ${OUTPUT_DIR} \
47 | --mrc_dropout ${MRC_DROPOUT} \
48 | --bert_dropout ${BERT_DROPOUT} \
49 | --max_epochs ${MAX_EPOCH} \
50 | --span_loss_candidates ${SPAN_CANDI} \
51 | --weight_span ${SPAN_WEIGHT} \
52 | --warmup_steps ${WARMUP} \
53 | --gradient_clip_val ${MAXNORM} \
54 | --optimizer ${OPTIM} \
55 | --classifier_intermediate_hidden_size ${INTER_HIDDEN}
--------------------------------------------------------------------------------
/scripts/mrc_ner/reproduce/conll03.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | # -*- coding: utf-8 -*-
3 |
4 | # file: conll03.sh
5 | TIME=0901
6 | FILE=conll03_cased_large
7 | REPO_PATH=/userhome/xiaoya/mrc-for-flat-nested-ner
8 | export PYTHONPATH="$PYTHONPATH:$REPO_PATH"
9 |
10 | DATA_DIR=/userhome/xiaoya/dataset/en_conll03
11 | BERT_DIR=/userhome/xiaoya/bert/bert_cased_large
12 | OUTPUT_BASE=/userhome/xiaoya/outputs
13 |
14 | BATCH=10
15 | GRAD_ACC=4
16 | BERT_DROPOUT=0.1
17 | MRC_DROPOUT=0.3
18 | LR=3e-5
19 | LR_MINI=3e-7
20 | LR_SCHEDULER=polydecay
21 | SPAN_WEIGHT=0.1
22 | WARMUP=0
23 | MAX_LEN=200
24 | MAX_NORM=1.0
25 | MAX_EPOCH=20
26 | INTER_HIDDEN=2048
27 | WEIGHT_DECAY=0.01
28 | OPTIM=torch.adam
29 | VAL_CHECK=0.2
30 | PREC=16
31 | SPAN_CAND=pred_and_gold
32 |
33 |
34 | OUTPUT_DIR=${OUTPUT_BASE}/mrc_ner/${TIME}/${FILE}_cased_large_lr${LR}_drop${MRC_DROPOUT}_norm${MAX_NORM}_weight${SPAN_WEIGHT}_warmup${WARMUP}_maxlen${MAX_LEN}
35 | mkdir -p ${OUTPUT_DIR}
36 |
37 |
38 | CUDA_VISIBLE_DEVICES=0,1 python ${REPO_PATH}/train/mrc_ner_trainer.py \
39 | --data_dir ${DATA_DIR} \
40 | --bert_config_dir ${BERT_DIR} \
41 | --max_length ${MAX_LEN} \
42 | --batch_size ${BATCH} \
43 | --gpus="2" \
44 | --precision=${PREC} \
45 | --progress_bar_refresh_rate 1 \
46 | --lr ${LR} \
47 | --val_check_interval ${VAL_CHECK} \
48 | --accumulate_grad_batches ${GRAD_ACC} \
49 | --default_root_dir ${OUTPUT_DIR} \
50 | --mrc_dropout ${MRC_DROPOUT} \
51 | --bert_dropout ${BERT_DROPOUT} \
52 | --max_epochs ${MAX_EPOCH} \
53 | --span_loss_candidates ${SPAN_CAND} \
54 | --weight_span ${SPAN_WEIGHT} \
55 | --warmup_steps ${WARMUP} \
56 | --distributed_backend=ddp \
57 | --max_length ${MAX_LEN} \
58 | --gradient_clip_val ${MAX_NORM} \
59 | --weight_decay ${WEIGHT_DECAY} \
60 | --optimizer ${OPTIM} \
61 | --lr_scheduler ${LR_SCHEDULER} \
62 | --classifier_intermediate_hidden_size ${INTER_HIDDEN} \
63 | --flat \
64 | --lr_mini ${LR_MINI}
65 |
66 |
--------------------------------------------------------------------------------
/scripts/mrc_ner/reproduce/genia.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | # -*- coding: utf-8 -*-
3 |
4 | # file: genia.sh
5 |
6 | REPO_PATH=/userhome/xiaoya/mrc-for-flat-nested-ner
7 | export PYTHONPATH="$PYTHONPATH:$REPO_PATH"
8 |
9 | DATA_DIR=/userhome/xiaoya/dataset/genia
10 | BERT_DIR=/userhome/xiaoya/bert/bert_cased_large
11 | BERT_DROPOUT=0.2
12 | MRC_DROPOUT=0.2
13 | LR=2e-5
14 | SPAN_WEIGHT=0.1
15 | WARMUP=0
16 | MAXLEN=180
17 | MAXNORM=1.0
18 | INTER_HIDDEN=2048
19 |
20 | BATCH_SIZE=8
21 | PREC=16
22 | VAL_CKPT=0.25
23 | ACC_GRAD=4
24 | MAX_EPOCH=20
25 | SPAN_CANDI=pred_and_gold
26 | PROGRESS_BAR=1
27 | WEIGHT_DECAY=0.002
28 |
29 | OUTPUT_DIR=/userhome/xiaoya/outputs/github_mrc/genia/large_lr${LR}_drop${MRC_DROPOUT}_norm${MAXNORM}_bsz32_hard_span_weight${SPAN_WEIGHT}_warmup${WARMUP}_maxlen${MAXLEN}
30 | mkdir -p ${OUTPUT_DIR}
31 |
32 | CUDA_VISIBLE_DEVICES=0,1,2,3 python ${REPO_PATH}/train/mrc_ner_trainer.py \
33 | --gpus="4" \
34 | --distributed_backend=ddp \
35 | --workers 0 \
36 | --data_dir ${DATA_DIR} \
37 | --bert_config_dir ${BERT_DIR} \
38 | --max_length ${MAXLEN} \
39 | --batch_size ${BATCH_SIZE} \
40 | --precision=${PREC} \
41 | --progress_bar_refresh_rate ${PROGRESS_BAR} \
42 | --lr ${LR} \
43 | --val_check_interval ${VAL_CKPT} \
44 | --accumulate_grad_batches ${ACC_GRAD} \
45 | --default_root_dir ${OUTPUT_DIR} \
46 | --mrc_dropout ${MRC_DROPOUT} \
47 | --bert_dropout ${BERT_DROPOUT} \
48 | --max_epochs ${MAX_EPOCH} \
49 | --span_loss_candidates ${SPAN_CANDI} \
50 | --weight_span ${SPAN_WEIGHT} \
51 | --warmup_steps ${WARMUP} \
52 | --gradient_clip_val ${MAXNORM} \
53 | --weight_decay ${WEIGHT_DECAY} \
54 | --classifier_intermediate_hidden_size ${INTER_HIDDEN}
55 |
56 |
--------------------------------------------------------------------------------
/scripts/mrc_ner/reproduce/kbp17.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | # -*- coding: utf-8 -*-
3 |
4 | # file: kbp17.sh
5 |
6 | TIME=0901
7 | FILE=kbp17_bert_large
8 | REPO_PATH=/userhome/xiaoya/mrc-for-flat-nested-ner
9 | export PYTHONPATH="$PYTHONPATH:$REPO_PATH"
10 |
11 | DATA_DIR=/userhome/xiaoya/dataset/kbp17
12 | BERT_DIR=/userhome/xiaoya/bert/bert_cased_large
13 | BERT_DROPOUT=0.2
14 | MRC_DROPOUT=0.2
15 | LR=2e-5
16 | SPAN_WEIGHT=0.1
17 | WARMUP=0
18 | MAXLEN=180
19 | MAXNORM=1.0
20 | INTER_HIDDEN=2048
21 |
22 | BATCH_SIZE=4
23 | PREC=16
24 | VAL_CKPT=0.25
25 | ACC_GRAD=4
26 | MAX_EPOCH=6
27 | SPAN_CANDI=pred_and_gold
28 | PROGRESS_BAR=1
29 | WEIGHT_DECAY=0.002
30 |
31 | OUTPUT_DIR=/userhome/xiaoya/outputs/mrc_ner/${TIME}/${FILE}_lr${LR}_drop${MRC_DROPOUT}_norm${MAXNORM}_weight${SPAN_WEIGHT}_warmup${WARMUP}_maxlen${MAXLEN}
32 | mkdir -p ${OUTPUT_DIR}
33 |
34 |
35 | CUDA_VISIBLE_DEVICES=0,1,2,3 python ${REPO_PATH}/train/mrc_ner_trainer.py \
36 | --gpus="4" \
37 | --distributed_backend=ddp \
38 | --workers 0 \
39 | --data_dir ${DATA_DIR} \
40 | --bert_config_dir ${BERT_DIR} \
41 | --max_length ${MAXLEN} \
42 | --batch_size ${BATCH_SIZE} \
43 | --precision=${PREC} \
44 | --progress_bar_refresh_rate ${PROGRESS_BAR} \
45 | --lr ${LR} \
46 | --val_check_interval ${VAL_CKPT} \
47 | --accumulate_grad_batches ${ACC_GRAD} \
48 | --default_root_dir ${OUTPUT_DIR} \
49 | --mrc_dropout ${MRC_DROPOUT} \
50 | --bert_dropout ${BERT_DROPOUT} \
51 | --max_epochs ${MAX_EPOCH} \
52 | --span_loss_candidates ${SPAN_CANDI} \
53 | --weight_span ${SPAN_WEIGHT} \
54 | --warmup_steps ${WARMUP} \
55 | --gradient_clip_val ${MAXNORM} \
56 | --weight_decay ${WEIGHT_DECAY} \
57 | --classifier_intermediate_hidden_size ${INTER_HIDDEN}
58 |
59 |
--------------------------------------------------------------------------------
/scripts/mrc_ner/reproduce/msra.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | # -*- coding: utf-8 -*-
3 |
4 | # file: msra.sh
5 |
6 | REPO_PATH=/userhome/xiaoya/mrc-for-flat-nested-ner
7 | export PYTHONPATH="$PYTHONPATH:$REPO_PATH"
8 | export TOKENIZERS_PARALLELISM=false
9 |
10 | DATA_DIR=/mnt/mrc/zh_msra
11 | BERT_DIR=/mnt/mrc/chinese_roberta_wwm_large_ext_pytorch
12 | SPAN_WEIGHT=0.1
13 | DROPOUT=0.2
14 | LR=8e-6
15 | MAXLEN=128
16 | INTER_HIDDEN=1536
17 |
18 | BATCH_SIZE=4
19 | PREC=16
20 | VAL_CKPT=0.25
21 | ACC_GRAD=1
22 | MAX_EPOCH=20
23 | SPAN_CANDI=pred_and_gold
24 | PROGRESS_BAR=1
25 |
26 | OUTPUT_DIR=/mnt/mrc/train_logs/zh_msra/zh_msra_bertlarge_lr${LR}20200913_dropout${DROPOUT}_maxlen${MAXLEN}
27 |
28 | mkdir -p ${OUTPUT_DIR}
29 |
30 | CUDA_VISIBLE_DEVICES=0,1 python ${REPO_PATH}/train/mrc_ner_trainer.py \
31 | --gpus="2" \
32 | --distributed_backend=ddp \
33 | --data_dir ${DATA_DIR} \
34 | --bert_config_dir ${BERT_DIR} \
35 | --max_length ${MAXLEN} \
36 | --batch_size ${BATCH_SIZE} \
37 | --precision=${PREC} \
38 | --progress_bar_refresh_rate ${PROGRESS_BAR} \
39 | --lr ${LR} \
40 | --val_check_interval ${VAL_CKPT} \
41 | --accumulate_grad_batches ${ACC_GRAD} \
42 | --default_root_dir ${OUTPUT_DIR} \
43 | --mrc_dropout ${DROPOUT} \
44 | --max_epochs ${MAX_EPOCH} \
45 | --weight_span ${SPAN_WEIGHT} \
46 | --span_loss_candidates ${SPAN_CANDI} \
47 | --chinese \
48 | --workers 0 \
49 | --classifier_intermediate_hidden_size ${INTER_HIDDEN}
50 |
51 |
--------------------------------------------------------------------------------
/scripts/mrc_ner/reproduce/onto4.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | # -*- coding: utf-8 -*-
3 |
4 | # file: onto4.sh
5 | # desc: for chinese ontonotes 04
6 |
7 | REPO_PATH=/userhome/xiaoya/mrc-for-flat-nested-ner
8 | export PYTHONPATH="$PYTHONPATH:$REPO_PATH"
9 | export TOKENIZERS_PARALLELISM=false
10 |
11 | DATA_DIR=/userhome/yuxian/data/zh_onto4
12 | BERT_DIR=/userhome/yuxian/data/chinese_roberta_wwm_large_ext_pytorch
13 |
14 | MAXLENGTH=128
15 | WEIGHT_SPAN=0.1
16 | LR=1e-5
17 | OPTIMIZER=adamw
18 | INTER_HIDDEN=1536
19 | OUTPUT_DIR=/userhome/yuxian/train_logs/zh_onto/zh_onto_${OPTIMIZER}_lr${LR}_maxlen${MAXLENGTH}_spanw${WEIGHT_SPAN}
20 | mkdir -p ${OUTPUT_DIR}
21 |
22 | BATCH_SIZE=8
23 | PREC=16
24 | VAL_CKPT=0.25
25 | ACC_GRAD=1
26 | MAX_EPOCH=10
27 | SPAN_CANDI=pred_and_gold
28 | PROGRESS_BAR=1
29 |
30 | CUDA_VISIBLE_DEVICES=0,1,2,3 python ${REPO_PATH}/train/mrc_ner_trainer.py \
31 | --data_dir ${DATA_DIR} \
32 | --bert_config_dir ${BERT_DIR} \
33 | --max_length ${MAXLENGTH} \
34 | --batch_size ${BATCH_SIZE} \
35 | --gpus="4" \
36 | --precision=${PREC} \
37 | --progress_bar_refresh_rate ${PROGRESS_BAR} \
38 | --lr ${LR} \
39 | --workers 0 \
40 | --distributed_backend=ddp \
41 | --val_check_interval ${VAL_CKPT} \
42 | --accumulate_grad_batches ${ACC_GRAD} \
43 | --default_root_dir ${OUTPUT_DIR} \
44 | --max_epochs ${MAX_EPOCH} \
45 | --span_loss_candidates ${SPAN_CANDI} \
46 | --weight_span ${WEIGHT_SPAN} \
47 | --mrc_dropout 0.3 \
48 | --chinese \
49 | --warmup_steps 5000 \
50 | --gradient_clip_val 5.0 \
51 | --final_div_factor 20 \
52 | --classifier_intermediate_hidden_size ${INTER_HIDDEN}
53 |
54 |
--------------------------------------------------------------------------------
/scripts/mrc_ner/reproduce/onto5.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | # -*- coding: utf-8 -*-
3 |
4 | # file: onto5.sh
5 | TIME=0901
6 | FILE=onto5_mrc_cased_large
7 | REPO_PATH=/userhome/xiaoya/mrc-for-flat-nested-ner
8 | export PYTHONPATH="$PYTHONPATH:$REPO_PATH"
9 | DATA_DIR=/userhome/xiaoya/dataset/new_mrc_ner/new_en_onto5
10 | BERT_DIR=/userhome/xiaoya/bert/bert_cased_large
11 |
12 | BERT_DROPOUT=0.2
13 | MRC_DROPOUT=0.2
14 | LR=2e-5
15 | LR_MINI=3e-7
16 | LR_SCHEDULER=polydecay
17 | SPAN_WEIGHT=0.1
18 | WARMUP=200
19 | MAXLEN=210
20 | MAXNORM=1.0
21 | INTER_HIDDEN=2048
22 |
23 | BATCH_SIZE=4
24 | PREC=16
25 | VAL_CKPT=0.2
26 | ACC_GRAD=5
27 | MAX_EPOCH=10
28 | SPAN_CANDI=pred_and_gold
29 | PROGRESS_BAR=1
30 | WEIGHT_DECAY=0.01
31 | OPTIM=torch.adam
32 |
33 | OUTPUT_DIR=/userhome/xiaoya/outputs/mrc_ner/${TIME}/${FILE}_cased_large_lr${LR}_drop${MRC_DROPOUT}_norm${MAXNORM}_weight${SPAN_WEIGHT}_warmup${WARMUP}_maxlen${MAXLEN}
34 | mkdir -p ${OUTPUT_DIR}
35 |
36 |
37 | CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python ${REPO_PATH}/train/mrc_ner_trainer.py \
38 | --data_dir ${DATA_DIR} \
39 | --bert_config_dir ${BERT_DIR} \
40 | --max_length ${MAXLEN} \
41 | --batch_size ${BATCH_SIZE} \
42 | --gpus="4" \
43 | --precision=${PREC} \
44 | --progress_bar_refresh_rate ${PROGRESS_BAR} \
45 | --lr ${LR} \
46 | --distributed_backend=ddp \
47 | --val_check_interval ${VAL_CKPT} \
48 | --accumulate_grad_batches ${ACC_GRAD} \
49 | --default_root_dir ${OUTPUT_DIR} \
50 | --mrc_dropout ${MRC_DROPOUT} \
51 | --bert_dropout ${BERT_DROPOUT} \
52 | --max_epochs ${MAX_EPOCH} \
53 | --span_loss_candidates ${SPAN_CANDI} \
54 | --weight_span ${SPAN_WEIGHT} \
55 | --warmup_steps ${WARMUP} \
56 | --max_length ${MAXLEN} \
57 | --gradient_clip_val ${MAXNORM} \
58 | --weight_decay ${WEIGHT_DECAY} \
59 | --flat \
60 | --optimizer ${OPTIM} \
61 | --lr_scheduler ${LR_SCHEDULER} \
62 | --classifier_intermediate_hidden_size ${INTER_HIDDEN} \
63 | --lr_mini ${LR_MINI}
64 |
65 |
66 |
--------------------------------------------------------------------------------
/tests/assert_correct_dataset.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 |
4 | # file: assert_correct_dataset.py
5 |
6 | import os
7 | import sys
8 |
9 | REPO_PATH = "/".join(os.path.realpath(__file__).split("/")[:-2])
10 | print(REPO_PATH)
11 | if REPO_PATH not in sys.path:
12 | sys.path.insert(0, REPO_PATH)
13 |
14 | import json
15 | from datasets.tagger_ner_dataset import load_data_in_conll
16 | from metrics.functional.tagger_span_f1 import get_entity_from_bmes_lst
17 |
18 |
19 | def count_entity_with_mrc_ner_format(data_path):
20 | """
21 | mrc_data_example:
22 | {
23 | "context": "The chemiluminescent ( CL ) response of interferon - gamma - treated U937 ( IFN - U937 ) cells to sensitized target cells has been used to detect red cell , platelet and granulocyte antibodies .",
24 | "end_position": [
25 | 16,
26 | 18
27 | ],
28 | "entity_label": "cell_line",
29 | "impossible": false,
30 | "qas_id": "532.1",
31 | "query": "cell line",
32 | "span_position": [
33 | "14;16",
34 | "7;18"
35 | ],
36 | "start_position": [
37 | 14,
38 | 7
39 | ]
40 | }
41 | """
42 | entity_counter = {}
43 | with open(data_path, encoding="utf-8") as f:
44 | data_lst = json.load(f)
45 |
46 | for data_item in data_lst:
47 | tmp_entity_type = data_item["entity_label"]
48 | if len(data_item["end_position"]) != 0:
49 | if tmp_entity_type not in entity_counter.keys():
50 | entity_counter[tmp_entity_type] = len(data_item["end_position"])
51 | else:
52 | entity_counter[tmp_entity_type] += len(data_item["end_position"])
53 |
54 | print(f"mrc: the number of sentences is : {len(data_lst)/len(entity_counter.keys())}")
55 | print("UNDER MRC-NER format -> ")
56 | print(entity_counter)
57 |
58 |
59 | def count_entity_with_sequence_ner_format(data_path, is_nested=False):
60 | entity_counter = {}
61 | if not is_nested:
62 | data_lst = load_data_in_conll(data_path)
63 | print(f"bmes: the number of sentences is : {len(data_lst)}")
64 | label_lst = [label_item[1] for label_item in data_lst]
65 | for label_item in label_lst:
66 | tmp_entity_lst = get_entity_from_bmes_lst(label_item)
67 | for tmp_entity in tmp_entity_lst:
68 | tmp_entity_type = tmp_entity[tmp_entity.index("]")+1:]
69 | if tmp_entity_type not in entity_counter.keys():
70 | entity_counter[tmp_entity_type] = 1
71 | else:
72 | entity_counter[tmp_entity_type] += 1
73 | print("UNDER SEQ format ->")
74 | print(entity_counter)
75 | else:
76 | # genia, ace04, ace05
77 | pass
78 |
79 |
80 | def main(mrc_data_dir, seq_data_dir, seq_data_suffix="char.bmes", is_nested=False):
81 | for data_type in ["train", "dev", "test"]:
82 | mrc_data_path = os.path.join(mrc_data_dir, f"mrc-ner.{data_type}")
83 | seq_data_path = os.path.join(seq_data_dir, f"{data_type}.{seq_data_suffix}")
84 |
85 | print("$"*10)
86 | print(f"{data_type}")
87 | print("$"*10)
88 | count_entity_with_mrc_ner_format(mrc_data_path)
89 | count_entity_with_sequence_ner_format(seq_data_path, is_nested=is_nested)
90 | print("\n")
91 |
92 |
93 | if __name__ == "__main__":
94 | mrc_data_dir = "/data/lixiaoya/datasets/mrc_ner/en_conll03"
95 | seq_data_dir = "/data/lixiaoya/datasets/bmes_ner/en_conll03"
96 | seq_data_suffix = "word.bmes"
97 | is_nested = False
98 | main(mrc_data_dir, seq_data_dir, seq_data_suffix=seq_data_suffix, is_nested=is_nested)
--------------------------------------------------------------------------------
/tests/bert_tokenizer.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 |
4 | # file: bert_tokenizer.py
5 |
6 | from transformers import AutoTokenizer
7 |
8 |
9 | def tokenize_word(model_path, do_lower_case=False):
10 | tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False, do_lower_case=do_lower_case)
11 | context = "EUROPEAN"
12 | cut_tokens = tokenizer.encode(context, add_special_tokens=False, return_token_type_ids=None)
13 | print(cut_tokens)
14 | print(type(cut_tokens))
15 | print(type(cut_tokens[0]))
16 |
17 |
18 | if __name__ == "__main__":
19 | model_path = "/data/xiaoya/models/bert_cased_large"
20 | tokenize_word(model_path)
21 |
22 |
--------------------------------------------------------------------------------
/tests/collect_entity_labels.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 |
4 | # file: collect_entity_labels.py
5 |
6 | import os
7 | import sys
8 |
9 | REPO_PATH = "/".join(os.path.realpath(__file__).split("/")[:-2])
10 | print(REPO_PATH)
11 | if REPO_PATH not in sys.path:
12 | sys.path.insert(0, REPO_PATH)
13 |
14 | from datasets.tagger_ner_dataset import load_data_in_conll, get_labels
15 |
16 |
17 | def main(data_dir, data_sign, datafile_suffix=".word.bmes"):
18 | label_collection_set = set()
19 | for data_type in ["train", "dev", "test"]:
20 | label_lst = []
21 | data_path = os.path.join(data_dir, f"{data_type}{datafile_suffix}")
22 | datasets = load_data_in_conll(data_path)
23 | for data_item in datasets:
24 | label_lst.extend(data_item[1])
25 |
26 | label_collection_set.update(set(label_lst))
27 |
28 |     print("total number of distinct labels: ")
29 | print(len(label_collection_set))
30 | print(label_collection_set)
31 |
32 | print("%"*10)
33 | set_labels = get_labels(data_sign)
34 | print(len(set_labels))
35 |
36 |
37 | if __name__ == "__main__":
38 | data_dir = "/data/xiaoya/datasets/ner/ontonotes5"
39 | data_sign = "en_onto"
40 | main(data_dir, data_sign)
41 |
--------------------------------------------------------------------------------
/tests/count_mrc_max_length.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 |
4 | # file: count_mrc_max_length.py
5 |
6 | import os
7 | import sys
8 |
9 | REPO_PATH="/".join(os.path.realpath(__file__).split("/")[:-2])
10 | print(REPO_PATH)
11 | if REPO_PATH not in sys.path:
12 | sys.path.insert(0, REPO_PATH)
13 |
14 | import os
15 | from datasets.collate_functions import collate_to_max_length
16 | from torch.utils.data import DataLoader
17 | from tokenizers import BertWordPieceTokenizer
18 | from datasets.mrc_ner_dataset import MRCNERDataset
19 |
20 |
21 |
22 | def main():
23 | # en datasets
24 | bert_path = "/data/xiaoya/models/bert_cased_large"
25 | json_path = "/data/xiaoya/datasets/mrc_ner_datasets/en_conll03_truecase_sent/mrc-ner.train"
26 | is_chinese = False
27 | # [test]
28 | # max length is 227
29 | # min length is 12
30 | # avg length is 40.45264986967854
31 | # [dev]
32 | # max length is 212
33 | # min length is 12
34 | # avg length is 43.42584615384615
35 | # [train]
36 | # max length is 201
37 | # min length is 12
38 | # avg length is 41.733423545331526
39 |
40 | vocab_file = os.path.join(bert_path, "vocab.txt")
41 | tokenizer = BertWordPieceTokenizer(vocab_file)
42 | dataset = MRCNERDataset(json_path=json_path, tokenizer=tokenizer,
43 | is_chinese=is_chinese, max_length=10000)
44 |
45 | dataloader = DataLoader(dataset, batch_size=1,
46 | collate_fn=collate_to_max_length)
47 |
48 | length_lst = []
49 | for batch in dataloader:
50 | for tokens, token_type_ids, start_labels, end_labels, start_label_mask, end_label_mask, match_labels, sample_idx, label_idx in zip(*batch):
51 | tokens = tokens.tolist()
52 | length_lst.append(len(tokens))
53 | print(f"max length is {max(length_lst)}")
54 | print(f"min length is {min(length_lst)}")
55 | print(f"avg length is {sum(length_lst)/len(length_lst)}")
56 |
57 |
58 | if __name__ == '__main__':
59 | main()
--------------------------------------------------------------------------------
/tests/count_sequence_max_length.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 |
4 | # author: xiaoya li
5 | # file: count_sequence_max_length.py
6 |
7 | import os
8 | import sys
9 |
10 | repo_path = "/".join(os.path.realpath(__file__).split("/")[:-2])
11 | print(repo_path)
12 | if repo_path not in sys.path:
13 | sys.path.insert(0, repo_path)
14 |
15 |
16 | from transformers import AutoTokenizer
17 | from datasets.tagger_ner_dataset import TaggerNERDataset
18 |
19 |
20 | class OntoNotesDataConfig:
21 | def __init__(self):
22 | self.data_dir = "/data/xiaoya/datasets/ner/zhontonotes4"
23 | self.model_path = "/data/xiaoya/pretrain_lm/chinese_L-12_H-768_A-12"
24 | self.do_lower_case = False
25 | self.tokenizer = AutoTokenizer.from_pretrained(self.model_path, use_fast=False, tokenize_chinese_chars=True)
26 | # BertWordPieceTokenizer(os.path.join(self.model_path, "vocab.txt"), lowercase=self.do_lower_case)
27 | self.max_length = 512
28 | self.is_chinese = True
29 | self.threshold = 275
30 | self.data_sign = "zh_onto"
31 | self.data_file_suffix = "char.bmes"
32 |
33 | class ChineseMSRADataConfig:
34 | def __init__(self):
35 | self.data_dir = "/data/xiaoya/datasets/ner/msra"
36 | self.model_path = "/data/xiaoya/pretrain_lm/chinese_L-12_H-768_A-12"
37 | self.do_lower_case = False
38 | self.tokenizer = AutoTokenizer.from_pretrained(self.model_path, use_fast=False, tokenize_chinese_chars=True)
39 | self.max_length = 512
40 | self.is_chinese = True
41 | self.threshold = 275
42 | self.data_sign = "zh_msra"
43 | self.data_file_suffix = "char.bmes"
44 |
45 |
46 | class EnglishOntoDataConfig:
47 | def __init__(self):
48 | self.data_dir = "/data/xiaoya/datasets/ner/ontonotes5"
49 | self.model_path = "/data/xiaoya/pretrain_lm/cased_L-12_H-768_A-12"
50 | self.do_lower_case = False
51 | self.tokenizer = AutoTokenizer.from_pretrained(self.model_path, use_fast=False)
52 | self.max_length = 512
53 | self.is_chinese = False
54 | self.threshold = 256
55 | self.data_sign = "en_onto"
56 | self.data_file_suffix = "word.bmes"
57 |
58 |
59 | class EnglishCoNLLDataConfig:
60 | def __init__(self):
61 | self.data_dir = "/data/xiaoya/datasets/ner/conll03_truecase_bmes"
62 | self.model_path = "/data/xiaoya/pretrain_lm/cased_L-12_H-768_A-12"
63 | if "uncased" in self.model_path:
64 | self.do_lower_case = True
65 | else:
66 | self.do_lower_case = False
67 | self.tokenizer = AutoTokenizer.from_pretrained(self.model_path, use_fast=False, do_lower_case=self.do_lower_case)
68 | self.max_length = 512
69 | self.is_chinese = False
70 | self.threshold = 275
71 | self.data_sign = "en_conll03"
72 | self.data_file_suffix = "word.bmes"
73 |
74 | class EnglishCoNLL03DocDataConfig:
75 | def __init__(self):
76 | self.data_dir = "/data/xiaoya/datasets/mrc_ner/en_conll03_doc"
77 | self.model_path = "/data/xiaoya/pretrain_lm/cased_L-12_H-768_A-12"
78 | self.do_lower_case = False
79 | self.tokenizer = AutoTokenizer.from_pretrained(self.model_path, use_fast=False)
80 | self.max_length = 512
81 | self.is_chinese = False
82 | self.threshold = 384
83 | self.data_sign = "en_conll03"
84 | self.data_file_suffix = "word.bmes"
85 |
86 | def count_max_length(data_sign):
87 | if data_sign == "zh_onto":
88 | data_config = OntoNotesDataConfig()
89 | elif data_sign == "zh_msra":
90 | data_config = ChineseMSRADataConfig()
91 | elif data_sign == "en_onto":
92 | data_config = EnglishOntoDataConfig()
93 | elif data_sign == "en_conll03":
94 | data_config = EnglishCoNLLDataConfig()
95 | elif data_sign == "en_conll03_doc":
96 | data_config = EnglishCoNLL03DocDataConfig()
97 | else:
98 | raise ValueError
99 | for prefix in ["test", "train", "dev"]:
100 | print("=*"*15)
101 | print(f"INFO -> loading {prefix} data. ")
102 | data_file_path = os.path.join(data_config.data_dir, f"{prefix}.{data_config.data_file_suffix}")
103 |
104 | dataset = TaggerNERDataset(data_file_path,
105 | data_config.tokenizer,
106 | data_sign,
107 | max_length=data_config.max_length,
108 | is_chinese=data_config.is_chinese,
109 | pad_to_maxlen=False,)
110 | max_len = 0
111 | counter = 0
112 | for idx, data_item in enumerate(dataset):
113 | tokens = data_item[0]
114 | num_tokens = tokens.shape[0]
115 | if num_tokens >= max_len:
116 | max_len = num_tokens
117 | if num_tokens > data_config.threshold:
118 | print(num_tokens)
119 | counter += 1
120 |
121 |         print(f"INFO -> Max LEN for {prefix} set is : {max_len}")
122 |         print(f"INFO -> number of sequences longer than {data_config.threshold} is {counter}")
123 |
124 |
125 |
126 | if __name__ == '__main__':
127 | data_sign = "en_onto"
128 | # english ontonotes 5.0
129 | # test: 172
130 | # dev: 407
131 | # train: 306
132 | count_max_length(data_sign)
133 |
134 |
135 |
--------------------------------------------------------------------------------
/tests/extract_entity_span.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 |
4 | # file: extract_entity_span.py
5 |
6 |
7 | def get_entity_from_bmes_lst(label_list):
8 | """reuse the code block from
9 | https://github.com/jiesutd/NCRFpp/blob/105a53a321eca9c1280037c473967858e01aaa43/utils/metric.py#L73
10 | Many thanks to Jie Yang.
11 | """
12 | list_len = len(label_list)
13 | begin_label = 'B-'
14 | end_label = 'E-'
15 | single_label = 'S-'
16 | whole_tag = ''
17 | index_tag = ''
18 | tag_list = []
19 | stand_matrix = []
20 | for i in range(0, list_len):
21 | if label_list[i] != -100:
22 | current_label = label_list[i].upper()
23 | else:
24 | continue
25 | if begin_label in current_label:
26 | if index_tag != '':
27 | tag_list.append(whole_tag + ',' + str(i-1))
28 | whole_tag = current_label.replace(begin_label,"",1) +'[' +str(i)
29 | index_tag = current_label.replace(begin_label,"",1)
30 | elif single_label in current_label:
31 | if index_tag != '':
32 | tag_list.append(whole_tag + ',' + str(i-1))
33 | whole_tag = current_label.replace(single_label,"",1) +'[' +str(i)
34 | tag_list.append(whole_tag)
35 | whole_tag = ""
36 | index_tag = ""
37 | elif end_label in current_label:
38 | if index_tag != '':
39 | tag_list.append(whole_tag +',' + str(i))
40 | whole_tag = ''
41 | index_tag = ''
42 | else:
43 | continue
44 | if (whole_tag != '')&(index_tag != ''):
45 | tag_list.append(whole_tag)
46 | tag_list_len = len(tag_list)
47 |
48 | for i in range(0, tag_list_len):
49 | if len(tag_list[i]) > 0:
50 | tag_list[i] = tag_list[i]+ ']'
51 | insert_list = reverse_style(tag_list[i])
52 | stand_matrix.append(insert_list)
53 | return stand_matrix
54 |
55 |
56 | def reverse_style(input_string):
57 | target_position = input_string.index('[')
58 | input_len = len(input_string)
59 | output_string = input_string[target_position:input_len] + input_string[0:target_position]
60 | return output_string
61 |
62 |
63 | if __name__ == "__main__":
64 | label_lst = ["B-PER", "M-PER", "M-PER", "E-PER", "O", "O", "B-ORG", "M-ORG", "M-ORG", "E-ORG", "B-PER", "M-PER", "M-PER", "M-PER"]
65 | span_results = get_entity_from_bmes_lst(label_lst)
66 | print(span_results)
67 |
68 | label_lst = ["B-PER", "M-PER", -100, -100, "M-PER", "E-PER", -100, "O", "O", -100, "B-ORG", -100, "M-ORG", "M-ORG", "E-ORG", "B-PER", "M-PER",
69 | "M-PER", "M-PER"]
70 | span_results = get_entity_from_bmes_lst(label_lst)
71 | print(span_results)
--------------------------------------------------------------------------------
/tests/illegal_entity_boundary.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 |
4 | # file: illegal_entity_boundary.py
5 |
6 | from transformers import AutoTokenizer
7 |
8 | def load_dataexamples(file_path, ):
9 | with open(file_path, "r") as f:
10 | datalines = f.readlines()
11 |
12 | sentence_collections = []
13 | sentence_label_collections = []
14 | word_collections = []
15 | word_label_collections = []
16 |
17 | for data_item in datalines:
18 | data_item = data_item.strip()
19 | if len(data_item) != 0:
20 | word, label = tuple(data_item.split(" "))
21 | word_collections.append(word)
22 | word_label_collections.append(label)
23 | else:
24 | sentence_collections.append(word_collections)
25 | sentence_label_collections.append(word_label_collections)
26 | word_collections = []
27 | word_label_collections = []
28 |
29 | return sentence_collections, sentence_label_collections
30 |
31 |
32 | def find_data_instance(file_path, search_string):
33 | sentence_collections, sentence_label_collections = load_dataexamples(file_path)
34 |
35 | for sentence_lst, label_lst in zip(sentence_collections, sentence_label_collections):
36 | sentence_str = "".join(sentence_lst)
37 | if search_string in sentence_str:
38 | print(sentence_str)
39 | print("-"*10)
40 | print(sentence_lst)
41 | print(label_lst)
42 | print("=*"*10)
43 |
44 |
45 | def find_illegal_entity(query, context_tokens, labels, model_path, is_chinese=True, do_lower_case=True):
46 | tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False, do_lower_case=do_lower_case)
47 | if is_chinese:
48 | context = "".join(context_tokens)
49 | else:
50 | context = " ".join(context_tokens)
51 |
52 | start_positions = []
53 | end_positions = []
54 | origin_tokens = context_tokens
55 |     print("check that the numbers of tokens and labels match:")
56 | print(len(origin_tokens))
57 | print(len(labels))
58 |
59 | for label_idx, label_item in enumerate(labels):
60 | if "B-" in label_item:
61 | start_positions.append(label_idx)
62 | if "S-" in label_item:
63 | end_positions.append(label_idx)
64 | start_positions.append(label_idx)
65 | if "E-" in label_item:
66 | end_positions.append(label_idx)
67 |
68 | print("origin entity tokens")
69 | for start_item, end_item in zip(start_positions, end_positions):
70 | print(origin_tokens[start_item: end_item + 1])
71 |
72 | query_context_tokens = tokenizer.encode_plus(query, context,
73 | add_special_tokens=True,
74 | max_length=500000,
75 | return_overflowing_tokens=True,
76 | return_token_type_ids=True)
77 |
78 | if tokenizer.pad_token_id in query_context_tokens["input_ids"]:
79 | non_padded_ids = query_context_tokens["input_ids"][
80 | : query_context_tokens["input_ids"].index(tokenizer.pad_token_id)]
81 | else:
82 | non_padded_ids = query_context_tokens["input_ids"]
83 |
84 | non_pad_tokens = tokenizer.convert_ids_to_tokens(non_padded_ids)
85 | first_sep_token = non_pad_tokens.index("[SEP]")
86 | end_sep_token = len(non_pad_tokens) - 1
87 | new_start_positions = []
88 | new_end_positions = []
89 | if len(start_positions) != 0:
90 | for start_index, end_index in zip(start_positions, end_positions):
91 | if is_chinese:
92 | answer_text_span = " ".join(context[start_index: end_index + 1])
93 | else:
94 | answer_text_span = " ".join(context.split(" ")[start_index: end_index + 1])
95 | new_start, new_end = _improve_answer_span(query_context_tokens["input_ids"], first_sep_token, end_sep_token,
96 | tokenizer, answer_text_span)
97 | new_start_positions.append(new_start)
98 | new_end_positions.append(new_end)
99 | else:
100 | new_start_positions = start_positions
101 | new_end_positions = end_positions
102 |
103 | # clip out-of-boundary entity positions.
104 | new_start_positions = [start_pos for start_pos in new_start_positions if start_pos < 500000]
105 | new_end_positions = [end_pos for end_pos in new_end_positions if end_pos < 500000]
106 |
107 | print("print tokens :")
108 | for start_item, end_item in zip(new_start_positions, new_end_positions):
109 | print(tokenizer.convert_ids_to_tokens(query_context_tokens["input_ids"][start_item: end_item + 1]))
110 |
111 |
112 | def _improve_answer_span(doc_tokens, input_start, input_end, tokenizer, orig_answer_text, return_subtoken_start=False):
113 | """Returns tokenized answer spans that better match the annotated answer."""
114 | doc_tokens = [str(tmp) for tmp in doc_tokens]
115 | answer_tokens = tokenizer.encode(orig_answer_text, add_special_tokens=False)
116 | tok_answer_text = " ".join([str(tmp) for tmp in answer_tokens])
117 | for new_start in range(input_start, input_end + 1):
118 | for new_end in range(input_end, new_start - 1, -1):
119 | text_span = " ".join(doc_tokens[new_start : (new_end+1)])
120 | if text_span == tok_answer_text:
121 | if not return_subtoken_start:
122 | return (new_start, new_end)
123 | tokens = tokenizer.convert_ids_to_tokens(doc_tokens[new_start: (new_end + 1)])
124 | if "##" not in tokens[-1]:
125 | return (new_start, new_end)
126 | else:
127 | for idx in range(len(tokens)-1, -1, -1):
128 | if "##" not in tokens[idx]:
129 | new_end = new_end - (len(tokens)-1 - idx)
130 | return (new_start, new_end)
131 |
132 | return (input_start, input_end)
133 |
134 |
135 | if __name__ == "__main__":
136 | # file_path = "/data/xiaoya/datasets/ner/msra/train.char.bmes"
137 | # search_string = "美亚股份"
138 | # find_data_instance(file_path, search_string)
139 | #
140 | # print("=%"*20)
141 | # print("check entity boundary")
142 | # print("=&"*20)
143 |
144 | print(">>> check for Chinese data example ... ...")
145 | context_tokens = ['1', '美', '亚', '股', '份', '3', '2', '.', '6', '6', '2', '民', '族', '集', '团', '2', '2', '.', '3',
146 | '8', '3', '鲁', '石', '化', 'A', '1', '9', '.', '1', '1', '4', '四', '川', '湖', '山', '1', '7', '.',
147 | '0', '9', '5', '太', '原', '刚', '玉', '1', '0', '.', '5', '8', '1', '咸', '阳', '偏', '转', '1', '6',
148 | '.', '1', '1', '2', '深', '华', '发', 'A', '1', '5', '.', '6', '6', '3', '渝', '开', '发', 'A', '1',
149 | '5', '.', '5', '2', '4', '深', '发', '展', 'A', '1', '3', '.', '8', '9', '5', '深', '纺', '织', 'A',
150 | '1', '3', '.', '2', '2', '1', '太', '极', '实', '业', '2', '3', '.', '2', '2', '2', '友', '好', '集',
151 | '团', '2', '2', '.', '1', '4', '3', '双', '虎', '涂', '料', '2', '0', '.', '2', '0', '4', '新', '潮',
152 | '实', '业', '1', '5', '.', '5', '8', '5', '信', '联', '股', '份', '1', '2', '.', '5', '7', '1', '氯',
153 | '碱', '化', '工', '2', '1', '.', '1', '7', '2', '百', '隆', '股', '份', '1', '5', '.', '6', '4', '3',
154 | '贵', '华', '旅', '业', '1', '5', '.', '1', '5', '4', '南', '洋', '实', '业', '1', '4', '.', '5', '0',
155 | '5', '福', '建', '福', '联', '1', '3', '.', '8', '0']
156 |
157 | labels = ['O', 'B-NT', 'M-NT', 'M-NT', 'E-NT', 'O', 'O', 'O', 'O', 'O', 'O', 'B-NT', 'M-NT', 'M-NT', 'E-NT', 'O',
158 | 'O', 'O', 'O', 'O', 'O', 'B-NT', 'M-NT', 'M-NT', 'E-NT', 'O', 'O', 'O', 'O', 'O', 'O', 'B-NT', 'M-NT',
159 | 'M-NT', 'E-NT', 'O', 'O', 'O', 'O', 'O', 'O', 'B-NT', 'M-NT', 'M-NT', 'E-NT', 'O', 'O', 'O', 'O', 'O',
160 | 'O', 'B-NT', 'M-NT',
161 | 'M-NT', 'E-NT', 'O', 'O', 'O', 'O', 'O', 'O', 'B-NT', 'M-NT', 'M-NT', 'E-NT', 'O', 'O', 'O', 'O', 'O',
162 | 'O',
163 | 'B-NT', 'M-NT', 'M-NT', 'E-NT', 'O', 'O', 'O', 'O', 'O', 'O', 'B-NT', 'M-NT', 'M-NT', 'E-NT', 'O', 'O',
164 | 'O', 'O', 'O',
165 | 'O', 'B-NT', 'M-NT', 'M-NT', 'E-NT', 'O', 'O', 'O', 'O', 'O', 'O', 'B-NT', 'M-NT', 'M-NT', 'E-NT', 'O',
166 | 'O', 'O', 'O',
167 | 'O', 'O', 'B-NT', 'M-NT', 'M-NT', 'E-NT', 'O', 'O', 'O', 'O', 'O', 'O', 'B-NT', 'M-NT', 'M-NT', 'E-NT',
168 | 'O', 'O',
169 | 'O', 'O', 'O', 'O', 'B-NT', 'M-NT', 'M-NT', 'E-NT', 'O', 'O', 'O', 'O', 'O', 'O', 'B-NT', 'M-NT', 'M-NT',
170 | 'E-NT',
171 | 'O', 'O', 'O', 'O', 'O', 'O', 'B-NT', 'M-NT', 'M-NT', 'E-NT', 'O', 'O', 'O', 'O', 'O', 'O', 'B-NT',
172 | 'M-NT', 'M-NT',
173 | 'E-NT', 'O', 'O', 'O', 'O', 'O', 'O', 'B-NT', 'M-NT', 'M-NT', 'E-NT', 'O', 'O', 'O', 'O', 'O', 'O',
174 | 'B-NT',
175 | 'M-NT', 'M-NT', 'E-NT', 'O', 'O', 'O', 'O', 'O', 'O', 'B-NT', 'M-NT', 'M-NT', 'E-NT', 'O', 'O', 'O', 'O',
176 | 'O']
177 | query = "组织机构"
178 |
179 | model_path = "/data/nfsdata/nlp/BERT_BASE_DIR/chinese_L-12_H-768_A-12"
180 | find_illegal_entity(query, context_tokens, labels, model_path, is_chinese=True, do_lower_case=True)
181 |
182 | print("$$$$$"*20)
183 | print(">>> check for English data example ... ...")
184 | query = "organization"
185 | context_tokens = ['RUGBY', 'LEAGUE', '-', 'EUROPEAN', 'SUPER', 'LEAGUE', 'RESULTS', '/', 'STANDINGS', '.', 'LONDON',
186 | '1996-08-24', 'Results', 'of', 'European', 'Super', 'League', 'rugby', 'league', 'matches', 'on',
187 | 'Saturday', ':', 'Paris', '14', 'Bradford', '27', 'Wigan', '78', 'Workington', '4', 'Standings',
188 | '(', 'tabulated', 'under', 'played', ',', 'won', ',', 'drawn', ',', 'lost', ',', 'points', 'for',
189 | ',', 'against', ',', 'total', 'points', ')', ':', 'Wigan', '22', '19', '1', '2', '902', '326',
190 | '39', 'St', 'Helens', '21', '19', '0', '2', '884', '441', '38', 'Bradford', '22', '17', '0', '5',
191 | '767', '409', '34', 'Warrington', '21', '12', '0', '9', '555', '499', '24', 'London', '21', '11',
192 | '1', '9', '555', '462', '23', 'Sheffield', '21', '10', '0', '11', '574', '696', '20', 'Halifax',
193 | '21', '9', '1', '11', '603', '552', '19', 'Castleford', '21', '9', '0', '12', '548', '543', '18',
194 | 'Oldham', '21', '8', '1', '12', '439', '656', '17', 'Leeds', '21', '6', '0', '15', '531', '681',
195 | '12', 'Paris', '22', '3', '1', '18', '398', '795', '7', 'Workington', '22', '2', '1', '19', '325',
196 | '1021', '5']
197 | labels = ['B-MISC', 'E-MISC', 'O', 'B-MISC', 'I-MISC', 'E-MISC', 'O', 'O', 'O', 'O', 'S-LOC', 'O', 'O', 'O',
198 | 'B-MISC', 'I-MISC', 'E-MISC', 'O', 'O', 'O', 'O', 'O', 'O', 'S-ORG', 'O', 'S-ORG', 'O', 'S-ORG', 'O',
199 | 'S-ORG', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O',
200 | 'O', 'O', 'O', 'S-ORG', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'B-ORG', 'E-ORG', 'O', 'O', 'O', 'O', 'O', 'O',
201 | 'O', 'S-ORG', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'S-ORG', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'S-ORG', 'O',
202 | 'O', 'O', 'O', 'O', 'O', 'O', 'S-ORG', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'S-ORG', 'O', 'O', 'O', 'O',
203 | 'O', 'O', 'O', 'S-ORG', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'S-ORG', 'O', 'O', 'O', 'O', 'O', 'O', 'O',
204 | 'S-ORG', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'S-ORG', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'S-ORG', 'O', 'O',
205 | 'O', 'O', 'O', 'O', 'O']
206 |
207 | model_path = "/data/xiaoya/models/bert_cased_large"
208 | find_illegal_entity(query, context_tokens, labels, model_path, is_chinese=False, do_lower_case=False)
209 |
210 |
--------------------------------------------------------------------------------
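The realignment step above, _improve_answer_span, is what this check exercises: given the sub-token ids of the encoded query+context pair and the entity's surface text, it scans candidate (start, end) windows and keeps the first window whose ids reproduce the entity exactly. Below is a minimal, self-contained sketch of that scan; the helper name realign_span and the token ids are made up for illustration, so no tokenizer or model download is needed.

    from typing import List, Tuple

    def realign_span(doc_ids: List[int], search_start: int, search_end: int,
                     answer_ids: List[int]) -> Tuple[int, int]:
        """Return the (start, end) window in doc_ids whose ids equal answer_ids.

        Mirrors the nested scan in _improve_answer_span: earliest start first,
        widest end first; falls back to the original search bounds."""
        target = " ".join(str(i) for i in answer_ids)
        doc = [str(i) for i in doc_ids]
        for new_start in range(search_start, search_end + 1):
            for new_end in range(search_end, new_start - 1, -1):
                if " ".join(doc[new_start: new_end + 1]) == target:
                    return new_start, new_end
        return search_start, search_end

    # made-up ids: [CLS]=101, query id, [SEP]=102, context ids, [SEP]=102
    doc_ids = [101, 7592, 102, 2023, 2003, 1037, 3231, 102]
    print(realign_span(doc_ids, search_start=3, search_end=7, answer_ids=[1037, 3231]))  # (5, 6)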
/train/bert_tagger_trainer.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 |
4 | # file: bert_tagger_trainer.py
5 |
6 | import os
7 | import re
8 | import argparse
9 | import logging
10 | from typing import Dict
11 | from collections import namedtuple
12 | from utils.random_seed import set_random_seed
13 | set_random_seed(0)
14 |
15 | import torch
16 | import pytorch_lightning as pl
17 | from torch import Tensor
18 | from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
19 | from torch.nn.modules import CrossEntropyLoss
20 | from pytorch_lightning import Trainer
21 | from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint
22 | from transformers import AutoTokenizer
23 | from transformers import AdamW, get_linear_schedule_with_warmup, get_polynomial_decay_schedule_with_warmup
24 |
25 | from utils.get_parser import get_parser
26 | from datasets.tagger_ner_dataset import get_labels, TaggerNERDataset
27 | from datasets.truncate_dataset import TruncateDataset
28 | from datasets.collate_functions import tagger_collate_to_max_length
29 | from metrics.tagger_span_f1 import TaggerSpanF1
30 | from metrics.functional.tagger_span_f1 import transform_predictions_to_labels
31 | from models.bert_tagger import BertTagger
32 | from models.model_config import BertTaggerConfig
33 |
34 |
35 | class BertSequenceLabeling(pl.LightningModule):
36 | def __init__(
37 | self,
38 | args: argparse.Namespace
39 | ):
40 | """Initialize a model, tokenizer and config."""
41 | super().__init__()
42 | format = '%(asctime)s - %(name)s - %(message)s'
43 | if isinstance(args, argparse.Namespace):
44 | self.save_hyperparameters(args)
45 | self.args = args
46 | logging.basicConfig(format=format, filename=os.path.join(self.args.output_dir, "eval_result_log.txt"), level=logging.INFO)
47 | else:
48 | # eval mode
49 | TmpArgs = namedtuple("tmp_args", field_names=list(args.keys()))
50 | self.args = args = TmpArgs(**args)
51 | logging.basicConfig(format=format, filename=os.path.join(self.args.output_dir, "eval_test.txt"), level=logging.INFO)
52 |
53 | self.bert_dir = args.bert_config_dir
54 | self.data_dir = self.args.data_dir
55 | self.task_labels = get_labels(self.args.data_sign)
56 | self.num_labels = len(self.task_labels)
57 | self.task_idx2label = {label_idx : label_item for label_idx, label_item in enumerate(get_labels(self.args.data_sign))}
58 | bert_config = BertTaggerConfig.from_pretrained(args.bert_config_dir,
59 | hidden_dropout_prob=args.bert_dropout,
60 | attention_probs_dropout_prob=args.bert_dropout,
61 | num_labels=self.num_labels,
62 | classifier_dropout=args.classifier_dropout,
63 | classifier_sign=args.classifier_sign,
64 | classifier_act_func=args.classifier_act_func,
65 | classifier_intermediate_hidden_size=args.classifier_intermediate_hidden_size)
66 |
67 | self.tokenizer = AutoTokenizer.from_pretrained(args.bert_config_dir, use_fast=False, do_lower_case=args.do_lowercase)
68 | self.model = BertTagger.from_pretrained(args.bert_config_dir, config=bert_config)
69 | logging.info(str(args.__dict__ if isinstance(args, argparse.ArgumentParser) else args))
70 | self.result_logger = logging.getLogger(__name__)
71 | self.result_logger.setLevel(logging.INFO)
72 | self.loss_func = CrossEntropyLoss()
73 | self.span_f1 = TaggerSpanF1()
74 | self.chinese = args.chinese
75 | self.optimizer = args.optimizer
76 |
77 | @staticmethod
78 | def add_model_specific_args(parent_parser):
79 | parser = argparse.ArgumentParser(parents=[parent_parser], add_help=False)
80 | parser.add_argument("--train_batch_size", type=int, default=8, help="batch size")
81 | parser.add_argument("--eval_batch_size", type=int, default=8, help="batch size")
82 | parser.add_argument("--bert_dropout", type=float, default=0.1, help="bert dropout rate")
83 | parser.add_argument("--classifier_sign", type=str, default="multi_nonlinear")
84 | parser.add_argument("--classifier_dropout", type=float, default=0.1)
85 | parser.add_argument("--classifier_act_func", type=str, default="gelu")
86 | parser.add_argument("--classifier_intermediate_hidden_size", type=int, default=1024)
87 | parser.add_argument("--chinese", action="store_true", help="is chinese dataset")
88 | parser.add_argument("--optimizer", choices=["adamw", "torch.adam"], default="adamw", help="optimizer type")
89 | parser.add_argument("--final_div_factor", type=float, default=1e4, help="final div factor of linear decay scheduler")
90 | parser.add_argument("--output_dir", type=str, default="", help="the path for saving intermediate model checkpoints.")
91 | parser.add_argument("--lr_scheduler", type=str, default="linear_decay", help="lr scheduler")
92 | parser.add_argument("--data_sign", type=str, default="en_conll03", help="data signature for the dataset.")
93 |         parser.add_argument("--polydecay_ratio", type=float, default=4, help="ratio for polydecay learning rate scheduler.")
94 | parser.add_argument("--do_lowercase", action="store_true", )
95 | parser.add_argument("--data_file_suffix", type=str, default=".char.bmes")
96 | parser.add_argument("--lr_scheulder", type=str, default="polydecay")
97 | parser.add_argument("--lr_mini", type=float, default=-1)
98 | parser.add_argument("--warmup_proportion", default=0.1, type=float, help="Proportion of training to perform linear learning rate warmup for.")
99 |
100 | return parser
101 |
102 | def configure_optimizers(self):
103 | """Prepare optimizer and schedule (linear warmup and decay)"""
104 | no_decay = ["bias", "LayerNorm.weight"]
105 | optimizer_grouped_parameters = [
106 | {
107 | "params": [p for n, p in self.model.named_parameters() if not any(nd in n for nd in no_decay)],
108 | "weight_decay": self.args.weight_decay,
109 | },
110 | {
111 | "params": [p for n, p in self.model.named_parameters() if any(nd in n for nd in no_decay)],
112 | "weight_decay": 0.0,
113 | },
114 | ]
115 | if self.optimizer == "adamw":
116 | optimizer = AdamW(optimizer_grouped_parameters,
117 | betas=(0.9, 0.98), # according to RoBERTa paper
118 | lr=self.args.lr,
119 | eps=self.args.adam_epsilon,)
120 | elif self.optimizer == "torch.adam":
121 | optimizer = torch.optim.AdamW(optimizer_grouped_parameters,
122 | lr=self.args.lr,
123 | eps=self.args.adam_epsilon,
124 | weight_decay=self.args.weight_decay)
125 | else:
126 | raise ValueError("Optimizer type does not exist.")
127 | num_gpus = len([x for x in str(self.args.gpus).split(",") if x.strip()])
128 | t_total = (len(self.train_dataloader()) // (self.args.accumulate_grad_batches * num_gpus) + 1) * self.args.max_epochs
129 | warmup_steps = int(self.args.warmup_proportion * t_total)
130 | if self.args.lr_scheduler == "onecycle":
131 | scheduler = torch.optim.lr_scheduler.OneCycleLR(
132 | optimizer, max_lr=self.args.lr, pct_start=float(warmup_steps/t_total),
133 | final_div_factor=self.args.final_div_factor,
134 | total_steps=t_total, anneal_strategy='linear')
135 | elif self.args.lr_scheduler == "linear":
136 | scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=warmup_steps, num_training_steps=t_total)
137 | elif self.args.lr_scheulder == "polydecay":
138 | if self.args.lr_mini == -1:
139 | lr_mini = self.args.lr / self.args.polydecay_ratio
140 | else:
141 | lr_mini = self.args.lr_mini
142 | scheduler = get_polynomial_decay_schedule_with_warmup(optimizer, warmup_steps, t_total, lr_end=lr_mini)
143 | else:
144 | raise ValueError
145 | return [optimizer], [{"scheduler": scheduler, "interval": "step"}]
146 |
147 | def forward(self, input_ids, token_type_ids, attention_mask):
148 | return self.model(input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask)
149 |
150 | def compute_loss(self, sequence_logits, sequence_labels, input_mask=None):
151 | if input_mask is not None:
152 | active_loss = input_mask.view(-1) == 1
153 | active_logits = sequence_logits.view(-1, self.num_labels)
154 | active_labels = torch.where(
155 | active_loss, sequence_labels.view(-1), torch.tensor(self.loss_func.ignore_index).type_as(sequence_labels)
156 | )
157 | loss = self.loss_func(active_logits, active_labels)
158 | else:
159 | loss = self.loss_func(sequence_logits.view(-1, self.num_labels), sequence_labels.view(-1))
160 | return loss
161 |
162 | def training_step(self, batch, batch_idx):
163 | tf_board_logs = {"lr": self.trainer.optimizers[0].param_groups[0]['lr']}
164 | token_input_ids, token_type_ids, attention_mask, sequence_labels, is_wordpiece_mask = batch
165 |
166 | logits = self.model(token_input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask)
167 | loss = self.compute_loss(logits, sequence_labels, input_mask=attention_mask)
168 | tf_board_logs[f"train_loss"] = loss
169 |
170 | return {'loss': loss, 'log': tf_board_logs}
171 |
172 | def validation_step(self, batch, batch_idx):
173 | output = {}
174 |
175 | token_input_ids, token_type_ids, attention_mask, sequence_labels, is_wordpiece_mask = batch
176 | batch_size = token_input_ids.shape[0]
177 | logits = self.model(token_input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask)
178 | loss = self.compute_loss(logits, sequence_labels, input_mask=attention_mask)
179 | output[f"val_loss"] = loss
180 |
181 | sequence_pred_lst = transform_predictions_to_labels(logits.view(batch_size, -1, len(self.task_labels)), is_wordpiece_mask, self.task_idx2label, input_type="logit")
182 | sequence_gold_lst = transform_predictions_to_labels(sequence_labels, is_wordpiece_mask, self.task_idx2label, input_type="label")
183 | span_f1_stats = self.span_f1(sequence_pred_lst, sequence_gold_lst)
184 | output["span_f1_stats"] = span_f1_stats
185 |
186 | return output
187 |
188 | def validation_epoch_end(self, outputs):
189 | avg_loss = torch.stack([x['val_loss'] for x in outputs]).mean()
190 | tensorboard_logs = {'val_loss': avg_loss}
191 |
192 | all_counts = torch.stack([x[f'span_f1_stats'] for x in outputs]).view(-1, 3).sum(0)
193 | span_tp, span_fp, span_fn = all_counts
194 | span_recall = span_tp / (span_tp + span_fn + 1e-10)
195 | span_precision = span_tp / (span_tp + span_fp + 1e-10)
196 | span_f1 = span_precision * span_recall * 2 / (span_recall + span_precision + 1e-10)
197 | tensorboard_logs[f"span_precision"] = span_precision
198 | tensorboard_logs[f"span_recall"] = span_recall
199 | tensorboard_logs[f"span_f1"] = span_f1
200 | self.result_logger.info(f"EVAL INFO -> current_epoch is: {self.trainer.current_epoch}, current_global_step is: {self.trainer.global_step} ")
201 | self.result_logger.info(f"EVAL INFO -> valid_f1 is: {span_f1}")
202 |
203 | return {'val_loss': avg_loss, 'log': tensorboard_logs}
204 |
205 | def test_step(self, batch, batch_idx):
206 | output = {}
207 |
208 | token_input_ids, token_type_ids, attention_mask, sequence_labels, is_wordpiece_mask = batch
209 | batch_size = token_input_ids.shape[0]
210 | logits = self.model(token_input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask)
211 | loss = self.compute_loss(logits, sequence_labels, input_mask=attention_mask)
212 | output[f"test_loss"] = loss
213 |
214 | sequence_pred_lst = transform_predictions_to_labels(logits.view(batch_size, -1, len(self.task_labels)),
215 | is_wordpiece_mask, self.task_idx2label, input_type="logit")
216 | sequence_gold_lst = transform_predictions_to_labels(sequence_labels, is_wordpiece_mask, self.task_idx2label,
217 | input_type="label")
218 | span_f1_stats = self.span_f1(sequence_pred_lst, sequence_gold_lst)
219 | output["span_f1_stats"] = span_f1_stats
220 | return output
221 |
222 | def test_epoch_end(self, outputs) -> Dict[str, Dict[str, Tensor]]:
223 | avg_loss = torch.stack([x['test_loss'] for x in outputs]).mean()
224 | tensorboard_logs = {'test_loss': avg_loss}
225 |
226 | all_counts = torch.stack([x[f'span_f1_stats'] for x in outputs]).view(-1, 3).sum(0)
227 | span_tp, span_fp, span_fn = all_counts
228 | span_recall = span_tp / (span_tp + span_fn + 1e-10)
229 | span_precision = span_tp / (span_tp + span_fp + 1e-10)
230 | span_f1 = span_precision * span_recall * 2 / (span_recall + span_precision + 1e-10)
231 | tensorboard_logs[f"span_precision"] = span_precision
232 | tensorboard_logs[f"span_recall"] = span_recall
233 | tensorboard_logs[f"span_f1"] = span_f1
234 | print(f"TEST INFO -> test_f1 is: {span_f1} precision: {span_precision}, recall: {span_recall}")
235 | self.result_logger.info(f"EVAL INFO -> test_f1 is: {span_f1}, test_precision is: {span_precision}, test_recall is: {span_recall}")
236 |
237 | return {'test_loss': avg_loss, 'log': tensorboard_logs}
238 |
239 | def train_dataloader(self) -> DataLoader:
240 | return self.get_dataloader("train")
241 |
242 | def val_dataloader(self) -> DataLoader:
243 | return self.get_dataloader("dev")
244 |
245 | def test_dataloader(self) -> DataLoader:
246 | return self.get_dataloader("test")
247 |
248 | def get_dataloader(self, prefix="train", limit: int = None) -> DataLoader:
249 | """get train/dev/test dataloader"""
250 | data_path = os.path.join(self.data_dir, f"{prefix}{self.args.data_file_suffix}")
251 | dataset = TaggerNERDataset(data_path, self.tokenizer, self.args.data_sign,
252 | max_length=self.args.max_length, is_chinese=self.args.chinese,
253 | pad_to_maxlen=False)
254 |
255 | if limit is not None:
256 | dataset = TruncateDataset(dataset, limit)
257 |
258 | if prefix == "train":
259 | batch_size = self.args.train_batch_size
260 |             # using a seeded data_generator helps experiment reproducibility.
261 |             # do not use an unseeded random sampler, since the gradient may explode.
262 | data_generator = torch.Generator()
263 | data_generator.manual_seed(self.args.seed)
264 | data_sampler = RandomSampler(dataset, generator=data_generator)
265 | else:
266 | data_sampler = SequentialSampler(dataset)
267 | batch_size = self.args.eval_batch_size
268 |
269 | dataloader = DataLoader(
270 | dataset=dataset, sampler=data_sampler,
271 | batch_size=batch_size, num_workers=self.args.workers,
272 | collate_fn=tagger_collate_to_max_length
273 | )
274 |
275 | return dataloader
276 |
277 |
278 | def find_best_checkpoint_on_dev(output_dir: str, log_file: str = "eval_result_log.txt", only_keep_the_best_ckpt: bool = False):
279 | with open(os.path.join(output_dir, log_file)) as f:
280 | log_lines = f.readlines()
281 |
282 | F1_PATTERN = re.compile(r"span_f1 reached \d+\.\d* \(best")
283 |     # matches log lines like: span_f1 reached 0.00000 (best 0.00000)
284 | CKPT_PATTERN = re.compile(r"saving model to \S+ as top")
285 | checkpoint_info_lines = []
286 | for log_line in log_lines:
287 | if "saving model to" in log_line:
288 | checkpoint_info_lines.append(log_line)
289 | # example of log line
290 | # Epoch 00000: val_f1 reached 0.00000 (best 0.00000), saving model to /data/xiaoya/outputs/0117/debug_5_12_2e-5_0.001_0.001_275_0.1_1_0.25/checkpoint/epoch=0.ckpt as top 20
291 | best_f1_on_dev = 0
292 | best_checkpoint_on_dev = ""
293 | for checkpoint_info_line in checkpoint_info_lines:
294 | current_f1 = float(
295 | re.findall(F1_PATTERN, checkpoint_info_line)[0].replace("span_f1 reached ", "").replace(" (best", ""))
296 | current_ckpt = re.findall(CKPT_PATTERN, checkpoint_info_line)[0].replace("saving model to ", "").replace(
297 | " as top", "")
298 |
299 | if current_f1 >= best_f1_on_dev:
300 | if only_keep_the_best_ckpt and len(best_checkpoint_on_dev) != 0:
301 | os.remove(best_checkpoint_on_dev)
302 | best_f1_on_dev = current_f1
303 | best_checkpoint_on_dev = current_ckpt
304 |
305 | return best_f1_on_dev, best_checkpoint_on_dev
306 |
307 |
308 | def main():
309 | """main"""
310 | parser = get_parser()
311 |
312 | # add model specific args
313 | parser = BertSequenceLabeling.add_model_specific_args(parser)
314 | # add all the available trainer options to argparse
315 | # ie: now --gpus --num_nodes ... --fast_dev_run all work in the cli
316 | parser = Trainer.add_argparse_args(parser)
317 | args = parser.parse_args()
318 | model = BertSequenceLabeling(args)
319 | if args.pretrained_checkpoint:
320 | model.load_state_dict(torch.load(args.pretrained_checkpoint,
321 | map_location=torch.device('cpu'))["state_dict"])
322 | checkpoint_callback = ModelCheckpoint(
323 | filepath=args.output_dir,
324 | save_top_k=args.max_keep_ckpt,
325 | verbose=True,
326 | monitor="span_f1",
327 | period=-1,
328 | mode="max",
329 | )
330 | trainer = Trainer.from_argparse_args(
331 | args,
332 | checkpoint_callback=checkpoint_callback,
333 | deterministic=True,
334 | default_root_dir=args.output_dir
335 | )
336 | trainer.fit(model)
337 |
338 | # after training, use the model checkpoint which achieves the best f1 score on dev set to compute the f1 on test set.
339 | best_f1_on_dev, path_to_best_checkpoint = find_best_checkpoint_on_dev(args.output_dir,)
340 | model.result_logger.info("=&" * 20)
341 | model.result_logger.info(f"Best F1 on DEV is {best_f1_on_dev}")
342 | model.result_logger.info(f"Best checkpoint on DEV set is {path_to_best_checkpoint}")
343 | checkpoint = torch.load(path_to_best_checkpoint)
344 | model.load_state_dict(checkpoint['state_dict'])
345 | trainer.test(model)
346 | model.result_logger.info("=&" * 20)
347 |
348 |
349 | if __name__ == '__main__':
350 | main()
351 |
--------------------------------------------------------------------------------
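find_best_checkpoint_on_dev above recovers the best dev checkpoint by scanning pytorch-lightning's verbose ModelCheckpoint log lines with two regular expressions. A minimal sketch of that extraction on a single made-up log line (the checkpoint path is a placeholder):

    import re

    line = ("Epoch 00003: span_f1 reached 0.91234 (best 0.91234), "
            "saving model to /tmp/outputs/checkpoint/epoch=3.ckpt as top 3")

    F1_PATTERN = re.compile(r"span_f1 reached \d+\.\d* \(best")
    CKPT_PATTERN = re.compile(r"saving model to \S+ as top")

    f1 = float(re.findall(F1_PATTERN, line)[0]
               .replace("span_f1 reached ", "").replace(" (best", ""))
    ckpt = (re.findall(CKPT_PATTERN, line)[0]
            .replace("saving model to ", "").replace(" as top", ""))
    print(f1, ckpt)  # 0.91234 /tmp/outputs/checkpoint/epoch=3.ckpt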
/train/mrc_ner_trainer.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 |
4 | # file: mrc_ner_trainer.py
5 |
6 | import os
7 | import re
8 | import argparse
9 | import logging
10 | from collections import namedtuple
11 | from typing import Dict
12 |
13 | import torch
14 | import pytorch_lightning as pl
15 | from pytorch_lightning import Trainer
16 | from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint
17 | from tokenizers import BertWordPieceTokenizer
18 | from torch import Tensor
19 | from torch.nn.modules import CrossEntropyLoss, BCEWithLogitsLoss
20 | from torch.utils.data import DataLoader
21 | from transformers import AdamW, get_linear_schedule_with_warmup, get_polynomial_decay_schedule_with_warmup
22 | from torch.optim import SGD
23 |
24 | from datasets.mrc_ner_dataset import MRCNERDataset
25 | from datasets.truncate_dataset import TruncateDataset
26 | from datasets.collate_functions import collate_to_max_length
27 | from metrics.query_span_f1 import QuerySpanF1
28 | from models.bert_query_ner import BertQueryNER
29 | from models.model_config import BertQueryNerConfig
30 | from utils.get_parser import get_parser
31 | from utils.random_seed import set_random_seed
32 |
33 | set_random_seed(0)
34 |
35 |
36 | class BertLabeling(pl.LightningModule):
37 | def __init__(
38 | self,
39 | args: argparse.Namespace
40 | ):
41 | """Initialize a model, tokenizer and config."""
42 | super().__init__()
43 | format = '%(asctime)s - %(name)s - %(message)s'
44 | if isinstance(args, argparse.Namespace):
45 | self.save_hyperparameters(args)
46 | self.args = args
47 | logging.basicConfig(format=format, filename=os.path.join(self.args.default_root_dir, "eval_result_log.txt"), level=logging.INFO)
48 | else:
49 | # eval mode
50 | TmpArgs = namedtuple("tmp_args", field_names=list(args.keys()))
51 | self.args = args = TmpArgs(**args)
52 | logging.basicConfig(format=format, filename=os.path.join(self.args.default_root_dir, "eval_test.txt"), level=logging.INFO)
53 |
54 | self.bert_dir = args.bert_config_dir
55 | self.data_dir = self.args.data_dir
56 |
57 | bert_config = BertQueryNerConfig.from_pretrained(args.bert_config_dir,
58 | hidden_dropout_prob=args.bert_dropout,
59 | attention_probs_dropout_prob=args.bert_dropout,
60 | mrc_dropout=args.mrc_dropout,
61 | classifier_act_func = args.classifier_act_func,
62 | classifier_intermediate_hidden_size=args.classifier_intermediate_hidden_size)
63 |
64 | self.model = BertQueryNER.from_pretrained(args.bert_config_dir,
65 | config=bert_config)
66 | logging.info(str(args.__dict__ if isinstance(args, argparse.ArgumentParser) else args))
67 | self.result_logger = logging.getLogger(__name__)
68 | self.result_logger.setLevel(logging.INFO)
69 | self.result_logger.info(str(args.__dict__ if isinstance(args, argparse.ArgumentParser) else args))
70 | self.bce_loss = BCEWithLogitsLoss(reduction="none")
71 |
72 | weight_sum = args.weight_start + args.weight_end + args.weight_span
73 | self.weight_start = args.weight_start / weight_sum
74 | self.weight_end = args.weight_end / weight_sum
75 | self.weight_span = args.weight_span / weight_sum
76 | self.flat_ner = args.flat
77 | self.span_f1 = QuerySpanF1(flat=self.flat_ner)
78 | self.chinese = args.chinese
79 | self.optimizer = args.optimizer
80 | self.span_loss_candidates = args.span_loss_candidates
81 |
82 | @staticmethod
83 | def add_model_specific_args(parent_parser):
84 | parser = argparse.ArgumentParser(parents=[parent_parser], add_help=False)
85 | parser.add_argument("--mrc_dropout", type=float, default=0.1,
86 | help="mrc dropout rate")
87 | parser.add_argument("--bert_dropout", type=float, default=0.1,
88 | help="bert dropout rate")
89 | parser.add_argument("--classifier_act_func", type=str, default="gelu")
90 | parser.add_argument("--classifier_intermediate_hidden_size", type=int, default=1024)
91 | parser.add_argument("--weight_start", type=float, default=1.0)
92 | parser.add_argument("--weight_end", type=float, default=1.0)
93 | parser.add_argument("--weight_span", type=float, default=1.0)
94 | parser.add_argument("--flat", action="store_true", help="is flat ner")
95 | parser.add_argument("--span_loss_candidates", choices=["all", "pred_and_gold", "pred_gold_random", "gold"],
96 | default="all", help="Candidates used to compute span loss")
97 | parser.add_argument("--chinese", action="store_true",
98 | help="is chinese dataset")
99 | parser.add_argument("--optimizer", choices=["adamw", "sgd", "torch.adam"], default="adamw",
100 |                             help="optimizer type")
101 | parser.add_argument("--final_div_factor", type=float, default=1e4,
102 | help="final div factor of linear decay scheduler")
103 | parser.add_argument("--lr_scheduler", type=str, default="onecycle", )
104 | parser.add_argument("--lr_mini", type=float, default=-1)
105 | return parser
106 |
107 | def configure_optimizers(self):
108 | """Prepare optimizer and schedule (linear warmup and decay)"""
109 | no_decay = ["bias", "LayerNorm.weight"]
110 | optimizer_grouped_parameters = [
111 | {
112 | "params": [p for n, p in self.model.named_parameters() if not any(nd in n for nd in no_decay)],
113 | "weight_decay": self.args.weight_decay,
114 | },
115 | {
116 | "params": [p for n, p in self.model.named_parameters() if any(nd in n for nd in no_decay)],
117 | "weight_decay": 0.0,
118 | },
119 | ]
120 | if self.optimizer == "adamw":
121 | optimizer = AdamW(optimizer_grouped_parameters,
122 | betas=(0.9, 0.98), # according to RoBERTa paper
123 | lr=self.args.lr,
124 | eps=self.args.adam_epsilon,)
125 | elif self.optimizer == "torch.adam":
126 | optimizer = torch.optim.AdamW(optimizer_grouped_parameters,
127 | lr=self.args.lr,
128 | eps=self.args.adam_epsilon,
129 | weight_decay=self.args.weight_decay)
130 | else:
131 | optimizer = SGD(optimizer_grouped_parameters, lr=self.args.lr, momentum=0.9)
132 | num_gpus = len([x for x in str(self.args.gpus).split(",") if x.strip()])
133 | t_total = (len(self.train_dataloader()) // (self.args.accumulate_grad_batches * num_gpus) + 1) * self.args.max_epochs
134 | if self.args.lr_scheduler == "onecycle":
135 | scheduler = torch.optim.lr_scheduler.OneCycleLR(
136 | optimizer, max_lr=self.args.lr, pct_start=float(self.args.warmup_steps/t_total),
137 | final_div_factor=self.args.final_div_factor,
138 | total_steps=t_total, anneal_strategy='linear'
139 | )
140 | elif self.args.lr_scheduler == "linear":
141 | scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=self.args.warmup_steps, num_training_steps=t_total)
142 | elif self.args.lr_scheduler == "polydecay":
143 | if self.args.lr_mini == -1:
144 | lr_mini = self.args.lr / 5
145 | else:
146 | lr_mini = self.args.lr_mini
147 | scheduler = get_polynomial_decay_schedule_with_warmup(optimizer, self.args.warmup_steps, t_total, lr_end=lr_mini)
148 | else:
149 | raise ValueError
150 | return [optimizer], [{"scheduler": scheduler, "interval": "step"}]
151 |
152 | def forward(self, input_ids, attention_mask, token_type_ids):
153 | return self.model(input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)
154 |
155 | def compute_loss(self, start_logits, end_logits, span_logits,
156 | start_labels, end_labels, match_labels, start_label_mask, end_label_mask):
157 | batch_size, seq_len = start_logits.size()
158 |
159 | start_float_label_mask = start_label_mask.view(-1).float()
160 | end_float_label_mask = end_label_mask.view(-1).float()
161 | match_label_row_mask = start_label_mask.bool().unsqueeze(-1).expand(-1, -1, seq_len)
162 | match_label_col_mask = end_label_mask.bool().unsqueeze(-2).expand(-1, seq_len, -1)
163 | match_label_mask = match_label_row_mask & match_label_col_mask
164 |         match_label_mask = torch.triu(match_label_mask, 0)  # start must be less than or equal to end
165 |
166 | if self.span_loss_candidates == "all":
167 | # naive mask
168 | float_match_label_mask = match_label_mask.view(batch_size, -1).float()
169 | else:
170 | # use only pred or golden start/end to compute match loss
171 | start_preds = start_logits > 0
172 | end_preds = end_logits > 0
173 | if self.span_loss_candidates == "gold":
174 | match_candidates = ((start_labels.unsqueeze(-1).expand(-1, -1, seq_len) > 0)
175 | & (end_labels.unsqueeze(-2).expand(-1, seq_len, -1) > 0))
176 | elif self.span_loss_candidates == "pred_gold_random":
177 | gold_and_pred = torch.logical_or(
178 | (start_preds.unsqueeze(-1).expand(-1, -1, seq_len)
179 | & end_preds.unsqueeze(-2).expand(-1, seq_len, -1)),
180 | (start_labels.unsqueeze(-1).expand(-1, -1, seq_len)
181 | & end_labels.unsqueeze(-2).expand(-1, seq_len, -1))
182 | )
183 | data_generator = torch.Generator()
184 | data_generator.manual_seed(0)
185 | random_matrix = torch.empty(batch_size, seq_len, seq_len).uniform_(0, 1)
186 | random_matrix = torch.bernoulli(random_matrix, generator=data_generator).long()
187 | random_matrix = random_matrix.cuda()
188 | match_candidates = torch.logical_or(
189 | gold_and_pred, random_matrix
190 | )
191 | else:
192 | match_candidates = torch.logical_or(
193 | (start_preds.unsqueeze(-1).expand(-1, -1, seq_len)
194 | & end_preds.unsqueeze(-2).expand(-1, seq_len, -1)),
195 | (start_labels.unsqueeze(-1).expand(-1, -1, seq_len)
196 | & end_labels.unsqueeze(-2).expand(-1, seq_len, -1))
197 | )
198 | match_label_mask = match_label_mask & match_candidates
199 | float_match_label_mask = match_label_mask.view(batch_size, -1).float()
200 |
201 | start_loss = self.bce_loss(start_logits.view(-1), start_labels.view(-1).float())
202 | start_loss = (start_loss * start_float_label_mask).sum() / start_float_label_mask.sum()
203 | end_loss = self.bce_loss(end_logits.view(-1), end_labels.view(-1).float())
204 | end_loss = (end_loss * end_float_label_mask).sum() / end_float_label_mask.sum()
205 | match_loss = self.bce_loss(span_logits.view(batch_size, -1), match_labels.view(batch_size, -1).float())
206 | match_loss = match_loss * float_match_label_mask
207 | match_loss = match_loss.sum() / (float_match_label_mask.sum() + 1e-10)
208 |
209 | return start_loss, end_loss, match_loss
210 |
211 | def training_step(self, batch, batch_idx):
212 | tf_board_logs = {
213 | "lr": self.trainer.optimizers[0].param_groups[0]['lr']
214 | }
215 | tokens, token_type_ids, start_labels, end_labels, start_label_mask, end_label_mask, match_labels, sample_idx, label_idx = batch
216 |
217 | # num_tasks * [bsz, length, num_labels]
218 | attention_mask = (tokens != 0).long()
219 | start_logits, end_logits, span_logits = self(tokens, attention_mask, token_type_ids)
220 |
221 | start_loss, end_loss, match_loss = self.compute_loss(start_logits=start_logits,
222 | end_logits=end_logits,
223 | span_logits=span_logits,
224 | start_labels=start_labels,
225 | end_labels=end_labels,
226 | match_labels=match_labels,
227 | start_label_mask=start_label_mask,
228 | end_label_mask=end_label_mask
229 | )
230 |
231 | total_loss = self.weight_start * start_loss + self.weight_end * end_loss + self.weight_span * match_loss
232 |
233 | tf_board_logs[f"train_loss"] = total_loss
234 | tf_board_logs[f"start_loss"] = start_loss
235 | tf_board_logs[f"end_loss"] = end_loss
236 | tf_board_logs[f"match_loss"] = match_loss
237 |
238 | return {'loss': total_loss, 'log': tf_board_logs}
239 |
240 | def validation_step(self, batch, batch_idx):
241 | output = {}
242 | tokens, token_type_ids, start_labels, end_labels, start_label_mask, end_label_mask, match_labels, sample_idx, label_idx = batch
243 |
244 | attention_mask = (tokens != 0).long()
245 | start_logits, end_logits, span_logits = self(tokens, attention_mask, token_type_ids)
246 |
247 | start_loss, end_loss, match_loss = self.compute_loss(start_logits=start_logits,
248 | end_logits=end_logits,
249 | span_logits=span_logits,
250 | start_labels=start_labels,
251 | end_labels=end_labels,
252 | match_labels=match_labels,
253 | start_label_mask=start_label_mask,
254 | end_label_mask=end_label_mask
255 | )
256 |
257 | total_loss = self.weight_start * start_loss + self.weight_end * end_loss + self.weight_span * match_loss
258 |
259 | output[f"val_loss"] = total_loss
260 | output[f"start_loss"] = start_loss
261 | output[f"end_loss"] = end_loss
262 | output[f"match_loss"] = match_loss
263 |
264 | start_preds, end_preds = start_logits > 0, end_logits > 0
265 | span_f1_stats = self.span_f1(start_preds=start_preds, end_preds=end_preds, match_logits=span_logits,
266 | start_label_mask=start_label_mask, end_label_mask=end_label_mask,
267 | match_labels=match_labels)
268 | output["span_f1_stats"] = span_f1_stats
269 |
270 | return output
271 |
272 | def validation_epoch_end(self, outputs):
273 | avg_loss = torch.stack([x['val_loss'] for x in outputs]).mean()
274 | tensorboard_logs = {'val_loss': avg_loss}
275 |
276 | all_counts = torch.stack([x[f'span_f1_stats'] for x in outputs]).view(-1, 3).sum(0)
277 | span_tp, span_fp, span_fn = all_counts
278 | span_recall = span_tp / (span_tp + span_fn + 1e-10)
279 | span_precision = span_tp / (span_tp + span_fp + 1e-10)
280 | span_f1 = span_precision * span_recall * 2 / (span_recall + span_precision + 1e-10)
281 | tensorboard_logs[f"span_precision"] = span_precision
282 | tensorboard_logs[f"span_recall"] = span_recall
283 | tensorboard_logs[f"span_f1"] = span_f1
284 | self.result_logger.info(f"EVAL INFO -> current_epoch is: {self.trainer.current_epoch}, current_global_step is: {self.trainer.global_step} ")
285 | self.result_logger.info(f"EVAL INFO -> valid_f1 is: {span_f1}; precision: {span_precision}, recall: {span_recall}.")
286 |
287 | return {'val_loss': avg_loss, 'log': tensorboard_logs}
288 |
289 | def test_step(self, batch, batch_idx):
290 | """"""
291 | output = {}
292 | tokens, token_type_ids, start_labels, end_labels, start_label_mask, end_label_mask, match_labels, sample_idx, label_idx = batch
293 |
294 | attention_mask = (tokens != 0).long()
295 | start_logits, end_logits, span_logits = self(tokens, attention_mask, token_type_ids)
296 |
297 | start_preds, end_preds = start_logits > 0, end_logits > 0
298 | span_f1_stats = self.span_f1(start_preds=start_preds, end_preds=end_preds, match_logits=span_logits,
299 | start_label_mask=start_label_mask, end_label_mask=end_label_mask,
300 | match_labels=match_labels)
301 | output["span_f1_stats"] = span_f1_stats
302 |
303 | return output
304 |
305 | def test_epoch_end(self, outputs) -> Dict[str, Dict[str, Tensor]]:
306 | tensorboard_logs = {}
307 |
308 | all_counts = torch.stack([x[f'span_f1_stats'] for x in outputs]).view(-1, 3).sum(0)
309 | span_tp, span_fp, span_fn = all_counts
310 | span_recall = span_tp / (span_tp + span_fn + 1e-10)
311 | span_precision = span_tp / (span_tp + span_fp + 1e-10)
312 | span_f1 = span_precision * span_recall * 2 / (span_recall + span_precision + 1e-10)
313 | print(f"TEST INFO -> test_f1 is: {span_f1} precision: {span_precision}, recall: {span_recall}")
314 | self.result_logger.info(f"TEST INFO -> test_f1 is: {span_f1} precision: {span_precision}, recall: {span_recall}")
315 | return {'log': tensorboard_logs}
316 |
317 | def train_dataloader(self) -> DataLoader:
318 | return self.get_dataloader("train")
319 |
320 | def val_dataloader(self) -> DataLoader:
321 | return self.get_dataloader("dev")
322 |
323 | def test_dataloader(self) -> DataLoader:
324 | return self.get_dataloader("test")
325 |
326 | def get_dataloader(self, prefix="train", limit: int = None) -> DataLoader:
327 |         """get train/dev/test dataloader.
328 |
329 |         Loads mrc-ner.{prefix} from data_dir and collates each batch to its max length.
330 |         """
331 | json_path = os.path.join(self.data_dir, f"mrc-ner.{prefix}")
332 | vocab_path = os.path.join(self.bert_dir, "vocab.txt")
333 | dataset = MRCNERDataset(json_path=json_path,
334 | tokenizer=BertWordPieceTokenizer(vocab_path),
335 | max_length=self.args.max_length,
336 | is_chinese=self.chinese,
337 | pad_to_maxlen=False
338 | )
339 |
340 | if limit is not None:
341 | dataset = TruncateDataset(dataset, limit)
342 |
343 | dataloader = DataLoader(
344 | dataset=dataset,
345 | batch_size=self.args.batch_size,
346 | num_workers=self.args.workers,
347 | shuffle=True if prefix == "train" else False,
348 | collate_fn=collate_to_max_length
349 | )
350 |
351 | return dataloader
352 |
353 | def find_best_checkpoint_on_dev(output_dir: str, log_file: str = "eval_result_log.txt", only_keep_the_best_ckpt: bool = False):
354 | with open(os.path.join(output_dir, log_file)) as f:
355 | log_lines = f.readlines()
356 |
357 | F1_PATTERN = re.compile(r"span_f1 reached \d+\.\d* \(best")
358 |     # matches log lines like: span_f1 reached 0.00000 (best 0.00000)
359 | CKPT_PATTERN = re.compile(r"saving model to \S+ as top")
360 | checkpoint_info_lines = []
361 | for log_line in log_lines:
362 | if "saving model to" in log_line:
363 | checkpoint_info_lines.append(log_line)
364 | # example of log line
365 | # Epoch 00000: val_f1 reached 0.00000 (best 0.00000), saving model to /data/xiaoya/outputs/0117/debug_5_12_2e-5_0.001_0.001_275_0.1_1_0.25/checkpoint/epoch=0.ckpt as top 20
366 | best_f1_on_dev = 0
367 | best_checkpoint_on_dev = ""
368 | for checkpoint_info_line in checkpoint_info_lines:
369 | current_f1 = float(
370 | re.findall(F1_PATTERN, checkpoint_info_line)[0].replace("span_f1 reached ", "").replace(" (best", ""))
371 | current_ckpt = re.findall(CKPT_PATTERN, checkpoint_info_line)[0].replace("saving model to ", "").replace(
372 | " as top", "")
373 |
374 | if current_f1 >= best_f1_on_dev:
375 | if only_keep_the_best_ckpt and len(best_checkpoint_on_dev) != 0:
376 | os.remove(best_checkpoint_on_dev)
377 | best_f1_on_dev = current_f1
378 | best_checkpoint_on_dev = current_ckpt
379 |
380 | return best_f1_on_dev, best_checkpoint_on_dev
381 |
382 |
383 | def main():
384 | """main"""
385 | parser = get_parser()
386 |
387 | # add model specific args
388 | parser = BertLabeling.add_model_specific_args(parser)
389 |
390 | # add all the available trainer options to argparse
391 | # ie: now --gpus --num_nodes ... --fast_dev_run all work in the cli
392 | parser = Trainer.add_argparse_args(parser)
393 |
394 | args = parser.parse_args()
395 |
396 | model = BertLabeling(args)
397 | if args.pretrained_checkpoint:
398 | model.load_state_dict(torch.load(args.pretrained_checkpoint,
399 | map_location=torch.device('cpu'))["state_dict"])
400 |
401 | checkpoint_callback = ModelCheckpoint(
402 | filepath=args.default_root_dir,
403 | save_top_k=args.max_keep_ckpt,
404 | verbose=True,
405 | monitor="span_f1",
406 | period=-1,
407 | mode="max",
408 | )
409 | trainer = Trainer.from_argparse_args(
410 | args,
411 | checkpoint_callback=checkpoint_callback,
412 | deterministic=True,
413 | default_root_dir=args.default_root_dir
414 | )
415 |
416 | trainer.fit(model)
417 |
418 | # after training, use the model checkpoint which achieves the best f1 score on dev set to compute the f1 on test set.
419 | best_f1_on_dev, path_to_best_checkpoint = find_best_checkpoint_on_dev(args.default_root_dir, )
420 | model.result_logger.info("=&" * 20)
421 | model.result_logger.info(f"Best F1 on DEV is {best_f1_on_dev}")
422 | model.result_logger.info(f"Best checkpoint on DEV set is {path_to_best_checkpoint}")
423 | checkpoint = torch.load(path_to_best_checkpoint)
424 | model.load_state_dict(checkpoint['state_dict'])
425 |     trainer.test(model)
426 |     model.result_logger.info("=&" * 20)
427 |
428 |
429 | if __name__ == '__main__':
430 |     main()
431 |
--------------------------------------------------------------------------------
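The masking in compute_loss above decides which cells of the seq_len x seq_len span matrix contribute to the match loss: the start-label mask is broadcast across rows, the end-label mask across columns, and torch.triu drops cells where start > end. A toy sketch with a length-4 sequence whose last position is padding:

    import torch

    seq_len = 4
    start_label_mask = torch.tensor([[1, 1, 1, 0]])  # [bsz=1, seq_len], last token padded
    end_label_mask = torch.tensor([[1, 1, 1, 0]])

    row_mask = start_label_mask.bool().unsqueeze(-1).expand(-1, -1, seq_len)
    col_mask = end_label_mask.bool().unsqueeze(-2).expand(-1, seq_len, -1)
    match_label_mask = torch.triu(row_mask & col_mask, 0)  # keep start <= end only

    print(match_label_mask[0].int())
    # tensor([[1, 1, 1, 0],
    #         [0, 1, 1, 0],
    #         [0, 0, 1, 0],
    #         [0, 0, 0, 0]], dtype=torch.int32)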
/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ShannonAI/mrc-for-flat-nested-ner/457b0759f7fd462d0abd0a23441726352716fff9/utils/__init__.py
--------------------------------------------------------------------------------
/utils/bmes_decode.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 |
4 | # file: bmes_decode.py
5 |
6 | from typing import Tuple, List
7 |
8 |
9 | class Tag(object):
10 | def __init__(self, term, tag, begin, end):
11 | self.term = term
12 | self.tag = tag
13 | self.begin = begin
14 | self.end = end
15 |
16 | def to_tuple(self):
17 | return tuple([self.term, self.begin, self.end])
18 |
19 | def __str__(self):
20 | return str({key: value for key, value in self.__dict__.items()})
21 |
22 | def __repr__(self):
23 | return str({key: value for key, value in self.__dict__.items()})
24 |
25 |
26 | def bmes_decode(char_label_list: List[Tuple[str, str]]) -> List[Tag]:
27 | """
28 | decode inputs to tags
29 | Args:
30 | char_label_list: list of tuple (word, bmes-tag)
31 | Returns:
32 | tags
33 | Examples:
34 | >>> x = [("Hi", "O"), ("Beijing", "S-LOC")]
35 | >>> bmes_decode(x)
36 | [{'term': 'Beijing', 'tag': 'LOC', 'begin': 1, 'end': 2}]
37 | """
38 | idx = 0
39 | length = len(char_label_list)
40 | tags = []
41 | while idx < length:
42 | term, label = char_label_list[idx]
43 | current_label = label[0]
44 |
45 | # correct labels
46 | if idx + 1 == length and current_label == "B":
47 | current_label = "S"
48 |
49 | # merge chars
50 | if current_label == "O":
51 | idx += 1
52 | continue
53 | if current_label == "S":
54 | tags.append(Tag(term, label[2:], idx, idx + 1))
55 | idx += 1
56 | continue
57 | if current_label == "B":
58 | end = idx + 1
59 | while end + 1 < length and char_label_list[end][1][0] == "M":
60 | end += 1
61 | if char_label_list[end][1][0] == "E": # end with E
62 | entity = "".join(char_label_list[i][0] for i in range(idx, end + 1))
63 | tags.append(Tag(entity, label[2:], idx, end + 1))
64 | idx = end + 1
65 | else: # end with M/B
66 | entity = "".join(char_label_list[i][0] for i in range(idx, end))
67 | tags.append(Tag(entity, label[2:], idx, end))
68 | idx = end
69 | continue
70 | else:
71 | raise Exception("Invalid Inputs")
72 | return tags
73 |
--------------------------------------------------------------------------------
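A small usage sketch for the decoder above, assuming the repository root is on PYTHONPATH so utils.bmes_decode is importable; the two-character example is made up:

    from utils.bmes_decode import bmes_decode

    # B-/E- pairs merge into one span; O characters are skipped.
    chars = [("上", "B-LOC"), ("海", "E-LOC"), ("很", "O"), ("好", "O")]
    for tag in bmes_decode(chars):
        print(tag.term, tag.tag, tag.begin, tag.end)
    # 上海 LOC 0 2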
/utils/convert_tf2torch.sh:
--------------------------------------------------------------------------------
1 | # convert tf model to pytorch format
2 |
3 | export BERT_BASE_DIR=/mnt/mrc/wwm_uncased_L-24_H-1024_A-16
4 |
5 | transformers-cli convert --model_type bert \
6 | --tf_checkpoint $BERT_BASE_DIR/model.ckpt \
7 | --config $BERT_BASE_DIR/config.json \
8 | --pytorch_dump_output $BERT_BASE_DIR/pytorch_model.bin
9 |
--------------------------------------------------------------------------------
/utils/get_parser.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 |
4 | # file: get_parser.py
5 |
6 | import argparse
7 |
8 |
9 | def get_parser() -> argparse.ArgumentParser:
10 | """
11 | return basic arg parser
12 | """
13 | parser = argparse.ArgumentParser(description="Training")
14 |
15 | parser.add_argument("--data_dir", type=str, required=True, help="data dir")
16 | parser.add_argument("--max_keep_ckpt", default=3, type=int, help="the number of keeping ckpt max.")
17 | parser.add_argument("--bert_config_dir", type=str, required=True, help="bert config dir")
18 | parser.add_argument("--pretrained_checkpoint", default="", type=str, help="pretrained checkpoint path")
19 | parser.add_argument("--max_length", type=int, default=128, help="max length of dataset")
20 | parser.add_argument("--batch_size", type=int, default=32, help="batch size")
21 | parser.add_argument("--lr", type=float, default=2e-5, help="learning rate")
22 | parser.add_argument("--workers", type=int, default=0, help="num workers for dataloader")
23 | parser.add_argument("--weight_decay", default=0.01, type=float, help="Weight decay if we apply some.")
24 | parser.add_argument("--warmup_steps", default=0, type=int, help="warmup steps used for scheduler.")
25 | parser.add_argument("--adam_epsilon", default=1e-8, type=float, help="Epsilon for Adam optimizer.")
26 | parser.add_argument("--seed", default=0, type=int, help="set random seed for reproducing results.")
27 | return parser
28 |
--------------------------------------------------------------------------------
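Both trainers extend this base parser through argparse's parents mechanism (see add_model_specific_args in the trainer modules above) before handing the result to pytorch-lightning. A minimal sketch of that pattern; the paths and the extra flag are placeholders:

    import argparse
    from utils.get_parser import get_parser

    base = get_parser()
    child = argparse.ArgumentParser(parents=[base], add_help=False)
    child.add_argument("--train_batch_size", type=int, default=8)

    args = child.parse_args(["--data_dir", "/path/to/data",
                             "--bert_config_dir", "/path/to/bert"])
    print(args.lr, args.train_batch_size)  # 2e-05 8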
/utils/random_seed.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 | # last update: xiaoya li
4 | # issue: https://github.com/PyTorchLightning/pytorch-lightning/issues/1868
5 | # set for trainer: https://pytorch-lightning.readthedocs.io/en/latest/trainer.html
6 | # from pytorch_lightning import Trainer, seed_everything
7 | # seed_everything(42)
8 | # sets seeds for numpy, torch, python.random and PYTHONHASHSEED.
9 | # model = Model()
10 | # trainer = Trainer(deterministic=True)
11 |
12 | import random
13 | import torch
14 | import numpy as np
15 | from pytorch_lightning import seed_everything
16 |
17 | def set_random_seed(seed: int):
18 | """set seeds for reproducibility"""
19 | random.seed(seed)
20 | np.random.seed(seed)
21 | torch.manual_seed(seed)
22 | torch.cuda.manual_seed_all(seed)
23 | seed_everything(seed=seed)
24 | torch.backends.cudnn.deterministic = True
25 | torch.backends.cudnn.benchmark = False
26 |
27 |
28 | if __name__ == '__main__':
29 | # without this line, x would be different in every execution.
30 | set_random_seed(0)
31 |
32 | x = np.random.random()
33 | print(x)
34 |
--------------------------------------------------------------------------------
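A quick sanity check of set_random_seed: re-seeding with the same value reproduces identical numpy and torch draws, which is the property the seeded samplers in the trainers rely on.

    import numpy as np
    import torch
    from utils.random_seed import set_random_seed

    set_random_seed(0)
    a_np, a_t = np.random.random(), torch.rand(3)
    set_random_seed(0)
    b_np, b_t = np.random.random(), torch.rand(3)
    print(a_np == b_np, torch.equal(a_t, b_t))  # True True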