├── Data.zip
├── README.md
├── dee
│   ├── __init__.py
│   ├── __pycache__
│   │   ├── __init__.cpython-36.pyc
│   │   ├── __init__.cpython-37.pyc
│   │   ├── base_task.cpython-36.pyc
│   │   ├── dee_helper.cpython-36.pyc
│   │   ├── dee_metric.cpython-36.pyc
│   │   ├── dee_model.cpython-36.pyc
│   │   ├── dee_task.cpython-36.pyc
│   │   ├── event_type.cpython-36.pyc
│   │   ├── ner_model.cpython-36.pyc
│   │   ├── ner_task.cpython-36.pyc
│   │   ├── transformer.cpython-36.pyc
│   │   ├── utils.cpython-36.pyc
│   │   └── utils.cpython-37.pyc
│   ├── base_task.py
│   ├── dee_helper.py
│   ├── dee_metric.py
│   ├── dee_model.py
│   ├── dee_task.py
│   ├── event_type.py
│   ├── ner_model.py
│   ├── ner_task.py
│   ├── transformer.py
│   └── utils.py
├── figs
│   ├── model.png
│   └── result.png
├── run_dee_task.py
├── run_eval.sh
├── run_train.sh
└── train_multi.sh
/Data.zip:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RunxinXu/GIT/3f91743656ae65c49bbfbe11a7ed8152a8b0bc20/Data.zip
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Document-level Event Extraction via Heterogeneous Graph-based Interaction Model with a Tracker
2 |
3 | Source code for ACL-IJCNLP 2021 Long paper: [Document-level Event Extraction via Heterogeneous Graph-based Interaction Model with a Tracker](https://arxiv.org/abs/2105.14924).
4 |
5 | Our code is based on [Doc2EDAG](https://github.com/dolphin-zs/Doc2EDAG).
6 |
7 |
8 | ## 0. Introduction
9 |
10 | > Document-level event extraction aims to extract events within a document. Different from sentence-level event extraction, the arguments of an event record may scatter across sentences, which requires a comprehensive understanding of the cross-sentence context. Besides, a document may express several correlated events simultaneously, and recognizing the interdependency among them is fundamental to successful extraction. To tackle the aforementioned two challenges, we propose a novel heterogeneous Graph-based Interaction Model with a Tracker (GIT). A graph-based interaction network is introduced to capture the global context for the scattered event arguments across sentences with different heterogeneous edges. We also decode event records with a Tracker module, which tracks the extracted event records, so that the interdependency among events is taken into consideration. Our approach delivers better results than the state-of-the-art methods, especially in cross-sentence event and multiple-event scenarios.
11 |
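For readers new to heterogeneous graphs, below is a minimal sketch of building a toy document graph with [DGL](https://github.com/dmlc/dgl), the graph library this project depends on. The node and edge types here are illustrative only, not GIT's actual schema; the real graph construction lives in `dee/dee_model.py`.

```python
import dgl
import torch

# Toy document: 3 sentences and 4 entity mentions; mention i occurs in
# sentence sent_of[i]. The types below are made up for illustration.
sent_of = torch.tensor([0, 0, 1, 2])
num_sents, num_ments = 3, 4

g = dgl.heterograph({
    ('mention', 'in', 'sentence'): (torch.arange(num_ments), sent_of),
    ('sentence', 'contains', 'mention'): (sent_of, torch.arange(num_ments)),
    ('sentence', 'next', 'sentence'): (torch.tensor([0, 1]), torch.tensor([1, 2])),
})
g.nodes['sentence'].data['h'] = torch.randn(num_sents, 16)
g.nodes['mention'].data['h'] = torch.randn(num_ments, 16)
print(g)  # node/edge counts per type
```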
12 |
13 | + Architecture
14 | 
15 |
16 | + Overall Results
17 |
18 |

19 |
20 | ## 1. Package Description
21 | ```
22 | GIT/
23 | ├─ dee/
24 | ├── __init__.py
25 | ├── base_task.py
26 | ├── dee_task.py
27 | ├── ner_task.py
28 | ├── dee_helper.py: data feature construction and evaluation utils
29 | ├── dee_metric.py: data evaluation utils
30 | ├── config.py: process command arguments
31 | ├── dee_model.py: GIT model
32 | ├── ner_model.py
33 | ├── transformer.py: transformer module
34 | ├── utils.py: utils
35 | ├─ run_dee_task.py: the main entry
36 | ├─ train_multi.sh
37 | ├─ run_train.sh: script for training (including evaluation)
38 | ├─ run_eval.sh: script for evaluation
39 | ├─ Exps/: experiment outputs
40 | ├─ Data.zip
41 | ├─ Data: unzip Data.zip
42 | ├─ LICENSE
43 | ├─ README.md
44 | ```
45 |
46 | ## 2. Environments
47 |
48 | - python (3.6.9)
49 | - cuda (11.1)
50 | - Ubuntu-18.0.4 (5.4.0-73-generic)
51 |
52 | ## 3. Dependencies
53 |
54 | - numpy (1.19.5)
55 | - torch (1.8.1+cu111)
56 | - pytorch-pretrained-bert (0.4.0)
57 | - dgl-cu111 (0.6.1)
58 | - tensorboardX (2.2)
59 |
60 | PS: The environment and dependencies listed here differ from those used in our paper, so the results may differ slightly.
61 |
62 | ## 4. Preparation
63 |
64 | - Unzip Data.zip to obtain a Data folder, which contains the training/dev/test data (see the sanity-check sketch below).
65 |
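A quick sanity check of the unzipped data (a sketch only: the file name follows the defaults in `dee/base_task.py`, and the `(guid, detail_align_info)` pair layout and keys follow `DEEExample` in `dee/dee_helper.py`):

```python
import json

# Load the training split and peek at the first document.
with open('Data/train.json', encoding='utf-8') as f:
    docs = json.load(f)  # a list of (annguid, detail_align_info) pairs

guid, info = docs[0]
print(guid, len(info['sentences']), len(info['ann_valid_mspans']))
```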
66 | ## 5. Training
67 |
68 | ```bash
69 | >> bash run_train.sh
70 | ```
71 |
72 | ## 6. Evaluation
73 |
74 | ```bash
75 | >> bash run_eval.sh
76 | ```
77 |
78 | (Evaluation is also conducted automatically after training.)
79 |
80 | ## 7. License
81 |
82 | This project is licensed under the MIT License - see the [LICENSE](LICENSE) file for details.
83 |
84 | ## 8. Citation
85 |
86 | If you use this work or code, please kindly cite the following paper:
87 |
88 | ```bib
89 | @inproceedings{xu-etal-2021-git,
90 | title = "Document-level Event Extraction via Heterogeneous Graph-based Interaction Model with a Tracker",
91 | author = "Runxin Xu and
92 | Tianyu Liu and
93 | Lei Li and
94 | Baobao Chang",
95 | booktitle = "The Joint Conference of the 59th Annual Meeting of the Association for Computational Linguistics and the 11th International Joint Conference on Natural Language Processing (ACL-IJCNLP 2021)",
96 | year = "2021",
97 | publisher = "Association for Computational Linguistics",
98 | }
99 | ```
100 |
--------------------------------------------------------------------------------
/dee/__init__.py:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/dee/__pycache__/__init__.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RunxinXu/GIT/3f91743656ae65c49bbfbe11a7ed8152a8b0bc20/dee/__pycache__/__init__.cpython-36.pyc
--------------------------------------------------------------------------------
/dee/__pycache__/__init__.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RunxinXu/GIT/3f91743656ae65c49bbfbe11a7ed8152a8b0bc20/dee/__pycache__/__init__.cpython-37.pyc
--------------------------------------------------------------------------------
/dee/__pycache__/base_task.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RunxinXu/GIT/3f91743656ae65c49bbfbe11a7ed8152a8b0bc20/dee/__pycache__/base_task.cpython-36.pyc
--------------------------------------------------------------------------------
/dee/__pycache__/dee_helper.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RunxinXu/GIT/3f91743656ae65c49bbfbe11a7ed8152a8b0bc20/dee/__pycache__/dee_helper.cpython-36.pyc
--------------------------------------------------------------------------------
/dee/__pycache__/dee_metric.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RunxinXu/GIT/3f91743656ae65c49bbfbe11a7ed8152a8b0bc20/dee/__pycache__/dee_metric.cpython-36.pyc
--------------------------------------------------------------------------------
/dee/__pycache__/dee_model.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RunxinXu/GIT/3f91743656ae65c49bbfbe11a7ed8152a8b0bc20/dee/__pycache__/dee_model.cpython-36.pyc
--------------------------------------------------------------------------------
/dee/__pycache__/dee_task.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RunxinXu/GIT/3f91743656ae65c49bbfbe11a7ed8152a8b0bc20/dee/__pycache__/dee_task.cpython-36.pyc
--------------------------------------------------------------------------------
/dee/__pycache__/event_type.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RunxinXu/GIT/3f91743656ae65c49bbfbe11a7ed8152a8b0bc20/dee/__pycache__/event_type.cpython-36.pyc
--------------------------------------------------------------------------------
/dee/__pycache__/ner_model.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RunxinXu/GIT/3f91743656ae65c49bbfbe11a7ed8152a8b0bc20/dee/__pycache__/ner_model.cpython-36.pyc
--------------------------------------------------------------------------------
/dee/__pycache__/ner_task.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RunxinXu/GIT/3f91743656ae65c49bbfbe11a7ed8152a8b0bc20/dee/__pycache__/ner_task.cpython-36.pyc
--------------------------------------------------------------------------------
/dee/__pycache__/transformer.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RunxinXu/GIT/3f91743656ae65c49bbfbe11a7ed8152a8b0bc20/dee/__pycache__/transformer.cpython-36.pyc
--------------------------------------------------------------------------------
/dee/__pycache__/utils.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RunxinXu/GIT/3f91743656ae65c49bbfbe11a7ed8152a8b0bc20/dee/__pycache__/utils.cpython-36.pyc
--------------------------------------------------------------------------------
/dee/__pycache__/utils.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RunxinXu/GIT/3f91743656ae65c49bbfbe11a7ed8152a8b0bc20/dee/__pycache__/utils.cpython-37.pyc
--------------------------------------------------------------------------------
/dee/base_task.py:
--------------------------------------------------------------------------------
1 | # Code Reference: (https://github.com/dolphin-zs/Doc2EDAG)
2 |
3 | import logging
4 | import random
5 | import os
6 | import json
7 | import sys
8 | import numpy as np
9 | from datetime import datetime
10 | import torch
11 | from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
12 | from torch.utils.data.distributed import DistributedSampler
13 | import torch.distributed as dist
14 | import torch.nn.parallel as para
15 | from pytorch_pretrained_bert.optimization import BertAdam
16 | from tqdm import trange, tqdm
17 | from tensorboardX import SummaryWriter
18 |
19 | from .utils import default_dump_pkl, default_dump_json
20 |
21 | PY2 = sys.version_info[0] == 2
22 | PY3 = sys.version_info[0] == 3
23 | if PY2:
24 | import collections
25 | container_abcs = collections
26 | elif PY3:
27 | import collections.abc
28 | container_abcs = collections.abc
29 |
30 | logger = logging.getLogger(__name__)
31 |
32 | class TaskSetting(object):
33 | """Base task setting that can be initialized with a dictionary"""
34 | base_key_attrs = ['data_dir', 'model_dir', 'output_dir']
35 | base_attr_default_pairs = [
36 | ('bert_model', 'bert-base-chinese'),
37 | ('train_file_name', 'train.json'),
38 | ('dev_file_name', 'dev.json'),
39 | ('test_file_name', 'test.json'),
40 | ('max_seq_len', 128),
41 | ('train_batch_size', 32),
42 | ('eval_batch_size', 256),
43 | ('learning_rate', 1e-4),
44 | ('num_train_epochs', 3.0),
45 | ('warmup_proportion', 0.1),
46 | ('no_cuda', False),
47 | ('local_rank', -1),
48 | ('seed', 99),
49 | ('gradient_accumulation_steps', 1),
50 | ('optimize_on_cpu', False),
51 | ('fp16', False),
52 | ('loss_scale', 128),
53 | ('cpt_file_name', 'task.cpt'),
54 | ('summary_dir_name', '/root/summary'),
55 | ]
56 |
57 | def __init__(self, key_attrs, attr_default_pairs, **kwargs):
58 | for key_attr in TaskSetting.base_key_attrs:
59 | setattr(self, key_attr, kwargs[key_attr])
60 |
61 | for attr, val in TaskSetting.base_attr_default_pairs:
62 | setattr(self, attr, val)
63 |
64 | for key_attr in key_attrs:
65 | setattr(self, key_attr, kwargs[key_attr])
66 |
67 | for attr, val in attr_default_pairs:
68 | if attr in kwargs:
69 | setattr(self, attr, kwargs[attr])
70 | else:
71 | setattr(self, attr, val)
72 |
73 | def update_by_dict(self, config_dict):
74 | for key, val in config_dict.items():
75 | setattr(self, key, val)
76 |
77 | def dump_to(self, dir_path, file_name='task_setting.json'):
78 | dump_fp = os.path.join(dir_path, file_name)
79 | default_dump_json(self.__dict__, dump_fp)
80 |
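# Hypothetical usage sketch (illustration only, not part of the original file).
# Settings resolve in three layers: the base defaults above, then the required
# key_attrs read from kwargs, then the subclass attr_default_pairs (a kwargs
# value wins over the pair's default when both are given):
#   setting = TaskSetting(
#       key_attrs=['event_type_num'],            # 'event_type_num' is made up
#       attr_default_pairs=[('dropout', 0.1)],
#       data_dir='./Data', model_dir='./Exps/model', output_dir='./Exps/output',
#       event_type_num=5,
#   )
#   setting.dropout  # -> 0.1, the pair's default, since 'dropout' not in kwargs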
81 | def set_optimizer_params_grad(named_params_optimizer, named_params_model, test_nan=False):
82 | """
83 | Utility function for optimize_on_cpu and 16-bits training.
84 | Copy the gradients of the GPU parameters to the CPU/RAM copy of the model
85 | """
86 | is_nan = False
87 | for (name_opti, param_opti), (name_model, param_model) in zip(named_params_optimizer, named_params_model):
88 | if name_opti != name_model:
89 | logger.error("name_opti != name_model: {} {}".format(name_opti, name_model))
90 | raise ValueError
91 | if param_model.grad is not None:
92 | if test_nan and torch.isnan(param_model.grad).sum() > 0:
93 | is_nan = True
94 | if param_opti.grad is None:
95 | param_opti.grad = torch.nn.Parameter(param_opti.data.new().resize_(*param_opti.data.size()))
96 | param_opti.grad.data.copy_(param_model.grad.data)
97 | else:
98 | param_opti.grad = None
99 | return is_nan
100 |
101 |
102 | def copy_optimizer_params_to_model(named_params_model, named_params_optimizer):
103 | """
104 | Utility function for optimize_on_cpu and 16-bits training.
105 | Copy the parameters optimized on CPU/RAM back to the model on GPU
106 | """
107 | for (name_opti, param_opti), (name_model, param_model) in zip(named_params_optimizer, named_params_model):
108 | if name_opti != name_model:
109 | logger.error("name_opti != name_model: {} {}".format(name_opti, name_model))
110 | raise ValueError
111 | param_model.data.copy_(param_opti.data)
112 |
113 | class BasePytorchTask(object):
114 | """Basic task to support deep learning models on Pytorch"""
115 |
116 | def __init__(self, setting, only_master_logging=False):
117 | self.setting = setting
118 | self.logger = logging.getLogger(self.__class__.__name__)
119 | self.only_master_logging = only_master_logging
120 |
121 | if self.in_distributed_mode() and not dist.is_initialized():
122 | # Initializes the distributed backend which will take care of synchronizing nodes/GPUs
123 | dist.init_process_group(backend='nccl')
124 | # dist.init_process_group(backend='gloo') # 3 times slower than nccl for gpu training
125 | torch.cuda.set_device(self.setting.local_rank)
126 | self.logging('World Size {} Rank {}, Local Rank {}, Device Num {}, Device {}'.format(
127 | dist.get_world_size(), dist.get_rank(), self.setting.local_rank,
128 | torch.cuda.device_count(), torch.cuda.current_device()
129 | ))
130 | dist.barrier()
131 |
132 | self._check_setting_validity()
133 | self._init_device()
134 | self.reset_random_seed()
135 | self.summary_writer = None
136 |
137 | # ==> task-specific initialization
138 | # The following functions should be called specifically in inherited classes
139 |
140 | self.custom_collate_fn = None
141 | self.train_examples = None
142 | self.train_features = None
143 | self.train_dataset = None
144 | self.dev_examples = None
145 | self.dev_features = None
146 | self.dev_dataset = None
147 | self.test_examples = None
148 | self.test_features = None
149 | self.test_dataset = None
150 | # self._load_data()
151 |
152 | self.model = None
153 | # self._decorate_model()
154 |
155 | self.optimizer = None
156 | self.num_train_steps = None
157 | self.model_named_parameters = None
158 | # self._init_bert_optimizer()
159 | # (option) self.resume_checkpoint()
160 |
161 | def logging(self, msg, level=logging.INFO):
162 | if self.in_distributed_mode():
163 | msg = 'Rank {} {}'.format(dist.get_rank(), msg)
164 | if self.only_master_logging:
165 | if self.is_master_node():
166 | self.logger.log(level, msg)
167 | else:
168 | self.logger.log(level, msg)
169 |
170 | def _check_setting_validity(self):
171 | self.logging('='*20 + 'Check Setting Validity' + '='*20)
172 | self.logging('Setting: {}'.format(
173 | json.dumps(self.setting.__dict__, ensure_ascii=False, indent=2)
174 | ))
175 |
176 | # check valid grad accumulate step
177 | if self.setting.gradient_accumulation_steps < 1:
178 | raise ValueError("Invalid gradient_accumulation_steps parameter: {}, should be >= 1".format(
179 | self.setting.gradient_accumulation_steps))
180 | # reset train batch size
181 | self.setting.train_batch_size = int(self.setting.train_batch_size
182 | / self.setting.gradient_accumulation_steps)
183 |
184 | # check output dir
185 | if os.path.exists(self.setting.output_dir) and os.listdir(self.setting.output_dir):
186 | self.logging("Output directory ({}) already exists and is not empty.".format(self.setting.output_dir),
187 | level=logging.WARNING)
188 | os.makedirs(self.setting.output_dir, exist_ok=True)
189 |
190 | # check model dir
191 | if os.path.exists(self.setting.model_dir) and os.listdir(self.setting.model_dir):
192 | self.logging("Model directory ({}) already exists and is not empty.".format(self.setting.model_dir),
193 | level=logging.WARNING)
194 | os.makedirs(self.setting.model_dir, exist_ok=True)
195 |
196 | def _init_device(self):
197 | self.logging('='*20 + 'Init Device' + '='*20)
198 |
199 | # set device
200 | if self.setting.local_rank == -1 or self.setting.no_cuda:
201 | self.device = torch.device("cuda" if torch.cuda.is_available() and not self.setting.no_cuda else "cpu")
202 | self.n_gpu = torch.cuda.device_count()
203 | else:
204 | self.device = torch.device("cuda", self.setting.local_rank)
205 | self.n_gpu = 1
206 | if self.setting.fp16:
207 | self.logging("16-bits training currently not supported in distributed training")
208 | self.setting.fp16 = False # (see https://github.com/pytorch/pytorch/pull/13496)
209 | self.logging("device {} n_gpu {} distributed training {}".format(
210 | self.device, self.n_gpu, self.in_distributed_mode()
211 | ))
212 |
213 | def reset_random_seed(self, seed=None):
214 | if seed is None:
215 | seed = self.setting.seed
216 | self.logging('='*20 + 'Reset Random Seed to {}'.format(seed) + '='*20)
217 |
218 | # set random seeds
219 | random.seed(seed)
220 | np.random.seed(seed)
221 | torch.manual_seed(seed)
222 | if self.n_gpu > 0:
223 | torch.cuda.manual_seed_all(seed)
224 |
225 | def is_master_node(self):
226 | if self.in_distributed_mode():
227 | if dist.get_rank() == 0:
228 | return True
229 | else:
230 | return False
231 | else:
232 | return True
233 |
234 | def in_distributed_mode(self):
235 | return self.setting.local_rank >= 0
236 |
237 | def _init_summary_writer(self):
238 | if self.is_master_node():
239 | self.logging('Init Summary Writer')
240 | current_time = datetime.now().strftime('%b%d_%H-%M-%S')
241 | sum_dir = '{}-{}'.format(self.setting.summary_dir_name, current_time)
242 | self.summary_writer = SummaryWriter(sum_dir)
243 | self.logging('Writing summary into {}'.format(sum_dir))
244 |
245 | if self.in_distributed_mode():
246 | # TODO: maybe this can be removed
247 | dist.barrier()
248 |
249 | def load_example_feature_dataset(self, load_example_func, convert_to_feature_func, convert_to_dataset_func,
250 | file_name=None, file_path=None):
251 | if file_name is None and file_path is None:
252 | raise Exception('Either file name or file path should be provided')
253 |
254 | if file_path is None:
255 | file_path = os.path.join(self.setting.data_dir, file_name)
256 |
257 | if os.path.exists(file_path):
258 | self.logging('Load example feature dataset from {}'.format(file_path))
259 | examples = load_example_func(file_path)
260 | features = convert_to_feature_func(examples)
261 | dataset = convert_to_dataset_func(features)
262 | else:
263 | self.logging('Warning: file does not exist, {}'.format(file_path))
264 | examples = None
265 | features = None
266 | dataset = None
267 |
268 | return examples, features, dataset
269 |
270 | def _load_data(self, load_example_func, convert_to_feature_func, convert_to_dataset_func,
271 | load_train=True, load_dev=True, load_test=True):
272 | self.logging('='*20 + 'Load Task Data' + '='*20)
273 | # prepare data
274 | if load_train:
275 | self.logging('Load train portion')
276 | self.train_examples, self.train_features, self.train_dataset = self.load_example_feature_dataset(
277 | load_example_func, convert_to_feature_func, convert_to_dataset_func,
278 | file_name=self.setting.train_file_name
279 | )
280 | else:
281 | self.logging('Do not load train portion')
282 |
283 | if load_dev:
284 | self.logging('Load dev portion')
285 | self.dev_examples, self.dev_features, self.dev_dataset = self.load_example_feature_dataset(
286 | load_example_func, convert_to_feature_func, convert_to_dataset_func,
287 | file_name=self.setting.dev_file_name
288 | )
289 | else:
290 | self.logging('Do not load dev portion')
291 |
292 | if load_test:
293 | self.logging('Load test portion')
294 | self.test_examples, self.test_features, self.test_dataset = self.load_example_feature_dataset(
295 | load_example_func, convert_to_feature_func, convert_to_dataset_func,
296 | file_name=self.setting.test_file_name
297 | )
298 | else:
299 | self.logging('Do not load test portion')
300 |
301 | def reload_data(self, load_example_func, convert_to_feature_func, convert_to_dataset_func,
302 | data_type='return', file_name=None, file_path=None):
303 | """Subclass should inherit this function to omit function arguments"""
304 | if data_type.lower() == 'train':
305 | self.train_examples, self.train_features, self.train_dataset = \
306 | self.load_example_feature_dataset(
307 | load_example_func, convert_to_feature_func, convert_to_dataset_func,
308 | file_name=file_name, file_path=file_path
309 | )
310 | elif data_type.lower() == 'dev':
311 | self.dev_examples, self.dev_features, self.dev_dataset = \
312 | self.load_example_feature_dataset(
313 | load_example_func, convert_to_feature_func, convert_to_dataset_func,
314 | file_name=file_name, file_path=file_path
315 | )
316 | elif data_type.lower() == 'test':
317 | self.test_examples, self.test_features, self.test_dataset = \
318 | self.load_example_feature_dataset(
319 | load_example_func, convert_to_feature_func, convert_to_dataset_func,
320 | file_name=file_name, file_path=file_path
321 | )
322 | elif data_type.lower() == 'return':
323 | examples, features, dataset = self.load_example_feature_dataset(
324 | load_example_func, convert_to_feature_func, convert_to_dataset_func,
325 | file_name=file_name, file_path=file_path,
326 | )
327 |
328 | return examples, features, dataset
329 | else:
330 | raise Exception('Unexpected data type {}'.format(data_type))
331 |
332 | def _decorate_model(self, parallel_decorate=True):
333 | self.logging('='*20 + 'Decorate Model' + '='*20)
334 |
335 | if self.setting.fp16:
336 | self.model.half()
337 |
338 | self.model.to(self.device)
339 | self.logging('Set model device to {}'.format(str(self.device)))
340 |
341 | if parallel_decorate:
342 | if self.in_distributed_mode():
343 | self.model = para.DistributedDataParallel(self.model,
344 | device_ids=[self.setting.local_rank],
345 | output_device=self.setting.local_rank)
346 | self.logging('Wrap distributed data parallel')
347 | # self.logging('In Distributed Mode, but do not use DistributedDataParallel Wrapper')
348 | elif self.n_gpu > 1:
349 | self.model = para.DataParallel(self.model)
350 | self.logging('Wrap data parallel')
351 | else:
352 | self.logging('Do not wrap parallel layers')
353 |
354 | def _init_bert_optimizer(self):
355 | self.logging('='*20 + 'Init Bert Optimizer' + '='*20)
356 | self.optimizer, self.num_train_steps, self.model_named_parameters = \
357 | self.reset_bert_optimizer()
358 |
359 | def reset_bert_optimizer(self):
360 | # Prepare optimizer
361 | if self.setting.fp16:
362 | model_named_parameters = [(n, param.clone().detach().to('cpu').float().requires_grad_())
363 | for n, param in self.model.named_parameters()]
364 | elif self.setting.optimize_on_cpu:
365 | model_named_parameters = [(n, param.clone().detach().to('cpu').requires_grad_())
366 | for n, param in self.model.named_parameters()]
367 | else:
368 | model_named_parameters = list(self.model.named_parameters())
369 |
370 | no_decay = ['bias', 'gamma', 'beta']
371 | optimizer_grouped_parameters = [
372 | {
'params': [p for n, p in model_named_parameters if not any(nd in n for nd in no_decay)],  # substring match against parameter names
374 | 'weight_decay_rate': 0.01
375 | },
376 | {
'params': [p for n, p in model_named_parameters if any(nd in n for nd in no_decay)],
378 | 'weight_decay_rate': 0.0
379 | }
380 | ]
381 |
382 | num_train_steps = int(len(self.train_examples)
383 | / self.setting.train_batch_size
384 | / self.setting.gradient_accumulation_steps
385 | * self.setting.num_train_epochs)
386 |
387 | optimizer = BertAdam(optimizer_grouped_parameters,
388 | lr=self.setting.learning_rate,
389 | warmup=self.setting.warmup_proportion,
390 | t_total=num_train_steps)
391 |
392 | return optimizer, num_train_steps, model_named_parameters
393 |
394 | def prepare_data_loader(self, dataset, batch_size, rand_flag=True):
395 | # prepare data loader
396 | if rand_flag:
397 | data_sampler = RandomSampler(dataset)
398 | else:
399 | data_sampler = SequentialSampler(dataset)
400 |
401 | if self.custom_collate_fn is None:
402 | dataloader = DataLoader(dataset,
403 | batch_size=batch_size,
404 | sampler=data_sampler)
405 | else:
406 | dataloader = DataLoader(dataset,
407 | batch_size=batch_size,
408 | sampler=data_sampler,
409 | collate_fn=self.custom_collate_fn)
410 |
411 | return dataloader
412 |
413 | def prepare_dist_data_loader(self, dataset, batch_size, epoch=0):
414 | # prepare distributed data loader
415 | data_sampler = DistributedSampler(dataset)
416 | data_sampler.set_epoch(epoch)
417 |
418 | if self.custom_collate_fn is None:
419 | dataloader = DataLoader(dataset,
420 | batch_size=batch_size,
421 | sampler=data_sampler)
422 | else:
423 | dataloader = DataLoader(dataset,
424 | batch_size=batch_size,
425 | sampler=data_sampler,
426 | collate_fn=self.custom_collate_fn)
427 | return dataloader
428 |
429 | def get_current_train_batch_size(self):
430 | if self.in_distributed_mode():
431 | train_batch_size = max(self.setting.train_batch_size // dist.get_world_size(), 1)
432 | else:
433 | train_batch_size = self.setting.train_batch_size
434 |
435 | return train_batch_size
436 |
437 | def set_batch_to_device(self, batch):
438 | # move mini-batch data to the proper device
439 | if isinstance(batch, torch.Tensor):
440 | batch = batch.to(self.device)
441 |
442 | return batch
443 | elif isinstance(batch, dict):
444 | for key, value in batch.items():
445 | if isinstance(value, torch.Tensor):
446 | batch[key] = value.to(self.device)
447 | elif isinstance(value, dict) or isinstance(value, container_abcs.Sequence):
448 | batch[key] = self.set_batch_to_device(value)
449 |
450 | return batch
451 | elif isinstance(batch, container_abcs.Sequence):
452 | # batch = [
453 | # t.to(self.device) if isinstance(t, torch.Tensor) else t for t in batch
454 | # ]
455 | new_batch = []
456 | for value in batch:
457 | if isinstance(value, torch.Tensor):
458 | new_batch.append(value.to(self.device))
459 | elif isinstance(value, dict) or isinstance(value, container_abcs.Sequence):
460 | new_batch.append(self.set_batch_to_device(value))
461 | else:
462 | new_batch.append(value)
463 |
464 | return new_batch
465 | else:
466 | raise Exception('Unsupported batch type {}'.format(type(batch)))
467 |
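# Illustrative note (not part of the original file): set_batch_to_device()
# recurses through nested dicts and sequences and moves only tensors, e.g. for
#   batch = {'ids': torch.zeros(2, 8), 'spans': [(0, 2), (1, 3)]}
# 'ids' is moved to self.device, while the span tuples come back as plain
# lists ([0, 2], [1, 3]) because the sequence branch re-collects into lists.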
468 | def base_train(self, get_loss_func, kwargs_dict1={},
469 | epoch_eval_func=None, kwargs_dict2={}, base_epoch_idx=0):
470 | assert self.model is not None
471 |
472 | if self.num_train_steps is None:
473 | self.num_train_steps = round(
474 | self.setting.num_train_epochs * len(self.train_examples) / self.setting.train_batch_size
475 | )
476 |
477 | train_batch_size = self.get_current_train_batch_size()
478 |
479 | self.logging('='*20 + 'Start Base Training' + '='*20)
480 | self.logging("\tTotal examples Num = {}".format(len(self.train_examples)))
481 | self.logging("\tBatch size = {}".format(self.setting.train_batch_size))
482 | self.logging("\tNum steps = {}".format(self.num_train_steps))
483 | if self.in_distributed_mode():
484 | self.logging("\tWorker Batch Size = {}".format(train_batch_size))
485 | self._init_summary_writer()
486 |
487 | # prepare data loader
488 | train_dataloader = self.prepare_data_loader(
489 | self.train_dataset, self.setting.train_batch_size, rand_flag=True
490 | )
491 |
492 | # enter train mode
493 | global_step = 0
494 | self.model.train()
495 |
496 | self.logging('Reach the epoch beginning')
497 | for epoch_idx in trange(base_epoch_idx, int(self.setting.num_train_epochs), desc="Epoch"):
498 | iter_desc = 'Iteration'
499 | self.model.train()
500 | if self.in_distributed_mode():
501 | train_dataloader = self.prepare_dist_data_loader(
502 | self.train_dataset, train_batch_size, epoch=epoch_idx
503 | )
504 | iter_desc = 'Rank {} {}'.format(dist.get_rank(), iter_desc)
505 |
506 | tr_loss = 0
507 | nb_tr_examples, nb_tr_steps = 0, 0
508 |
509 | if self.only_master_logging:
510 | if self.is_master_node():
511 | step_batch_iter = enumerate(tqdm(train_dataloader, desc=iter_desc))
512 | else:
513 | step_batch_iter = enumerate(train_dataloader)
514 | else:
515 | step_batch_iter = enumerate(tqdm(train_dataloader, desc=iter_desc))
516 |
517 | for step, batch in step_batch_iter:
518 | batch = self.set_batch_to_device(batch)
519 |
520 | # forward
521 | loss = get_loss_func(self, batch, **kwargs_dict1)
522 |
523 | if self.n_gpu > 1:
524 | loss = loss.mean() # mean() to average on multi-gpu.
525 | if self.setting.fp16 and self.setting.loss_scale != 1.0:
526 | # rescale loss for fp16 training
527 | # see https://docs.nvidia.com/deeplearning/sdk/mixed-precision-training/index.html
528 | loss = loss * self.setting.loss_scale
529 | if self.setting.gradient_accumulation_steps > 1:
530 | loss = loss / self.setting.gradient_accumulation_steps
531 |
532 | # backward
533 | loss.backward()
534 |
535 | loss_scalar = loss.item()
536 | tr_loss += loss_scalar
537 | if self.is_master_node():
538 | self.summary_writer.add_scalar('Loss', loss_scalar, global_step=global_step)
539 | nb_tr_examples += self.setting.train_batch_size # may not be very accurate due to incomplete batch
540 | nb_tr_steps += 1
541 | if (step + 1) % self.setting.gradient_accumulation_steps == 0:
542 | if self.setting.fp16 or self.setting.optimize_on_cpu:
543 | if self.setting.fp16 and self.setting.loss_scale != 1.0:
544 | # scale down gradients for fp16 training
545 | for param in self.model.parameters():
546 | param.grad.data = param.grad.data / self.setting.loss_scale
547 | is_nan = set_optimizer_params_grad(
548 | self.model_named_parameters, self.model.named_parameters(), test_nan=True
549 | )
550 | if is_nan:
551 | self.logging("FP16 TRAINING: Nan in gradients, reducing loss scaling")
552 | self.setting.loss_scale = self.setting.loss_scale / 2
553 | self.model.zero_grad()
554 | continue
555 | self.optimizer.step()
556 | copy_optimizer_params_to_model(
557 | self.model.named_parameters(), self.model_named_parameters
558 | )
559 | else:
560 | self.optimizer.step()
561 |
562 | self.model.zero_grad()
563 | global_step += 1
564 |
565 | if epoch_eval_func is not None:
566 | epoch_eval_func(self, epoch_idx + 1, **kwargs_dict2)
567 |
568 | def base_eval(self, eval_dataset, get_info_on_batch, reduce_info_type='mean', dump_pkl_path=None, **func_kwargs):
569 | self.logging('='*20 + 'Start Base Evaluation' + '='*20)
570 | self.logging("\tNum examples = {}".format(len(eval_dataset)))
571 | self.logging("\tBatch size = {}".format(self.setting.eval_batch_size))
572 | self.logging("\tReduce type = {}".format(reduce_info_type))
573 |
574 | # prepare data loader
575 | eval_dataloader = self.prepare_data_loader(
576 | eval_dataset, self.setting.eval_batch_size, rand_flag=False
577 | )
578 |
579 | # enter eval mode
580 | total_info = []
581 | if self.model is not None:
582 | self.model.eval()
583 |
584 | iter_desc = 'Iteration'
585 | if self.in_distributed_mode():
586 | iter_desc = 'Rank {} {}'.format(dist.get_rank(), iter_desc)
587 |
588 | for step, batch in enumerate(tqdm(eval_dataloader, desc=iter_desc)):
589 | batch = self.set_batch_to_device(batch)
590 |
591 | with torch.no_grad():
592 | # this func must run batch_info = model(batch_input)
593 | # and metrics is an instance of torch.Tensor with Size([batch_size, ...])
594 | # to fit the DataParallel and DistributedParallel functionality
595 | batch_info = get_info_on_batch(self, batch, **func_kwargs)
596 | # append metrics from this batch to event_info
597 | if isinstance(batch_info, torch.Tensor):
598 | total_info.append(
599 | batch_info.to(torch.device('cpu')) # collect results in cpu memory
600 | )
601 | else:
602 | # batch_info is a list of some info on each example
603 | total_info.extend(batch_info)
604 |
605 | if isinstance(total_info[0], torch.Tensor):
606 | # transform event_info to torch.Tensor
607 | total_info = torch.cat(total_info, dim=0)
608 |
609 | # [batch_size, ...] -> [...]
610 | if reduce_info_type.lower() == 'sum':
611 | reduced_info = total_info.sum(dim=0)
612 | elif reduce_info_type.lower() == 'mean':
613 | reduced_info = total_info.mean(dim=0)
614 | elif reduce_info_type.lower() == 'none':
615 | reduced_info = total_info
616 | else:
617 | raise Exception('Unsupported reduce metric type {}'.format(reduce_info_type))
618 |
619 | if dump_pkl_path is not None:
620 | default_dump_pkl(reduced_info, dump_pkl_path)
621 |
622 | return reduced_info
623 |
624 | def save_checkpoint(self, cpt_file_name=None, epoch=None):
625 | self.logging('='*20 + 'Dump Checkpoint' + '='*20)
626 | if cpt_file_name is None:
627 | cpt_file_name = self.setting.cpt_file_name
628 | cpt_file_path = os.path.join(self.setting.model_dir, cpt_file_name)
629 | self.logging('Dump checkpoint into {}'.format(cpt_file_path))
630 |
631 | store_dict = {
632 | 'setting': self.setting.__dict__,
633 | }
634 |
635 | if self.model:
636 | if isinstance(self.model, para.DataParallel) or \
637 | isinstance(self.model, para.DistributedDataParallel):
638 | model_state = self.model.module.state_dict()
639 | else:
640 | model_state = self.model.state_dict()
641 | store_dict['model_state'] = model_state
642 | else:
643 | self.logging('No model state is dumped', level=logging.WARNING)
644 |
645 | if self.optimizer:
646 | store_dict['optimizer_state'] = self.optimizer.state_dict()
647 | else:
648 | self.logging('No optimizer state is dumped', level=logging.WARNING)
649 |
650 | if epoch:
651 | store_dict['epoch'] = epoch
652 |
653 | torch.save(store_dict, cpt_file_path)
654 |
655 | def resume_checkpoint(self, cpt_file_path=None, cpt_file_name=None,
656 | resume_model=True, resume_optimizer=False, strict=False):
657 | self.logging('='*20 + 'Resume Checkpoint' + '='*20)
658 | # decide cpt_file_path to resume
659 | if cpt_file_path is None: # use provided path with highest priority
660 | if cpt_file_name is None: # no path and no name will resort to the default cpt name
661 | cpt_file_name = self.setting.cpt_file_name
662 | cpt_file_path = os.path.join(self.setting.model_dir, cpt_file_name)
663 | elif cpt_file_name is not None: # error when path and name are both provided
664 | raise Exception('Confused about path {} or file name {} to resume'.format(
665 | cpt_file_path, cpt_file_name
666 | ))
667 |
668 | if os.path.exists(cpt_file_path):
669 | self.logging('Resume checkpoint from {}'.format(cpt_file_path))
670 | elif strict:
671 | raise Exception('Checkpoint does not exist, {}'.format(cpt_file_path))
672 | else:
673 | self.logging('Checkpoint does not exist, {}'.format(cpt_file_path), level=logging.WARNING)
674 | return
675 |
676 | if torch.cuda.device_count() == 0:
677 | store_dict = torch.load(cpt_file_path, map_location='cpu')
678 | else:
679 | store_dict = torch.load(cpt_file_path, map_location=self.device)
680 |
681 | self.logging('Setting: {}'.format(
682 | json.dumps(store_dict['setting'], ensure_ascii=False, indent=2)
683 | ))
684 |
685 | if resume_model:
686 | if self.model and 'model_state' in store_dict:
687 | if isinstance(self.model, para.DataParallel) or \
688 | isinstance(self.model, para.DistributedDataParallel):
689 | self.model.module.load_state_dict(store_dict['model_state'])
690 | else:
691 | self.model.load_state_dict(store_dict['model_state'])
692 | self.logging('Resume model successfully')
693 | elif strict:
694 | raise Exception('Resume model failed, dict.keys = {}'.format(store_dict.keys()))
695 | else:
696 | self.logging('Do not resume model')
697 |
698 | if resume_optimizer:
699 | if self.optimizer and 'optimizer_state' in store_dict:
700 | self.optimizer.load_state_dict(store_dict['optimizer_state'])
701 | self.logging('Resume optimizer successfully')
702 | elif strict:
703 | raise Exception('Resume optimizer failed, dict.keys = {}'.format(store_dict.keys()))
704 | else:
705 | self.logging('Do not resume optimizer')
706 |
707 |
708 | def average_gradients(model):
709 | """ Gradient averaging. """
710 | size = float(dist.get_world_size())
711 | for name, param in model.named_parameters():
712 | try:
713 | dist.all_reduce(param.grad.data, op=dist.reduce_op.SUM)
714 | param.grad.data /= size
715 | except Exception as e:
716 | logger.error('Error when all_reduce parameter {}, size={}, grad_type={}, error message {}'.format(
717 | name, param.size(), param.grad.data.dtype, repr(e)
718 | ))
719 |
--------------------------------------------------------------------------------
/dee/dee_helper.py:
--------------------------------------------------------------------------------
1 | # Code Reference: (https://github.com/dolphin-zs/Doc2EDAG)
2 |
3 | import logging
4 | import os
5 | import re
6 | from collections import defaultdict, Counter
7 | import numpy as np
8 | import torch
9 |
10 | from .dee_metric import measure_event_table_filling
11 | from .event_type import event_type2event_class, BaseEvent, event_type_fields_list, common_fields
12 | from .ner_task import NERExample, NERFeatureConverter
13 | from .utils import default_load_json, default_dump_json, default_dump_pkl, default_load_pkl
14 |
15 |
16 | logger = logging.getLogger(__name__)
17 |
18 |
19 | class DEEExample(object):
20 | def __init__(self, annguid, detail_align_dict, only_inference=False):
21 | self.guid = annguid
22 | # [sent_text, ...]
23 | self.sentences = detail_align_dict['sentences']
24 | self.num_sentences = len(self.sentences)
25 |
26 | if only_inference:
27 | # set empty entity/event information
28 | self.only_inference = True
29 | self.ann_valid_mspans = []
30 | self.ann_mspan2dranges = {}
31 | self.ann_mspan2guess_field = {}
32 | self.recguid_eventname_eventdict_list = []
33 | self.num_events = 0
34 | self.sent_idx2srange_mspan_mtype_tuples = {}
35 | self.event_type2event_objs = {}
36 | else:
37 | # set event information accordingly
38 | self.only_inference = False
39 |
40 | # [span_text, ...]
41 | self.ann_valid_mspans = detail_align_dict['ann_valid_mspans']
42 | # span_text -> [drange_tuple, ...]
43 | self.ann_mspan2dranges = detail_align_dict['ann_mspan2dranges']
44 | # span_text -> guessed_field_name
45 | self.ann_mspan2guess_field = detail_align_dict['ann_mspan2guess_field']
46 | # [(recguid, event_name, event_dict), ...]
47 | self.recguid_eventname_eventdict_list = detail_align_dict['recguid_eventname_eventdict_list']
48 | self.num_events = len(self.recguid_eventname_eventdict_list)
49 |
50 | # for creating NER examples
51 | # sentence_index -> [(sent_match_range, match_span, match_type), ...]
52 | self.sent_idx2srange_mspan_mtype_tuples = {}
53 | for sent_idx in range(self.num_sentences):
54 | self.sent_idx2srange_mspan_mtype_tuples[sent_idx] = []
55 |
56 | for mspan in self.ann_valid_mspans:
57 | for drange in self.ann_mspan2dranges[mspan]:
58 | sent_idx, char_s, char_e = drange
59 | sent_mrange = (char_s, char_e)
60 |
61 | sent_text = self.sentences[sent_idx]
62 | if sent_text[char_s: char_e] != mspan:
63 | raise Exception('GUID: {} span range is not correct, span={}, range={}, sent={}'.format(
64 | annguid, mspan, str(sent_mrange), sent_text
65 | ))
66 |
67 | guess_field = self.ann_mspan2guess_field[mspan]
68 |
69 | self.sent_idx2srange_mspan_mtype_tuples[sent_idx].append(
70 | (sent_mrange, mspan, guess_field)
71 | )
72 |
73 | # for creating event objects
74 | # the length of event_objs should be >= 1
75 | self.event_type2event_objs = {}
76 | for mrecguid, event_name, event_dict in self.recguid_eventname_eventdict_list:
77 | event_class = event_type2event_class[event_name]
78 | event_obj = event_class()
79 | assert isinstance(event_obj, BaseEvent)
80 | event_obj.update_by_dict(event_dict, recguid=mrecguid)
81 |
82 | if event_obj.name in self.event_type2event_objs:
83 | self.event_type2event_objs[event_obj.name].append(event_obj)
84 | else:
85 | self.event_type2event_objs[event_name] = [event_obj]
86 |
87 | def __repr__(self):
88 | dee_str = 'DEEExample (\n'
89 | dee_str += ' guid: {},\n'.format(repr(self.guid))
90 |
91 | if not self.only_inference:
92 | dee_str += ' span info: (\n'
93 | for span_idx, span in enumerate(self.ann_valid_mspans):
94 | gfield = self.ann_mspan2guess_field[span]
95 | dranges = self.ann_mspan2dranges[span]
96 | dee_str += ' {:2} {:20} {:30} {}\n'.format(span_idx, span, gfield, str(dranges))
97 | dee_str += ' ),\n'
98 |
99 | dee_str += ' event info: (\n'
100 | event_str_list = repr(self.event_type2event_objs).split('\n')
101 | for event_str in event_str_list:
102 | dee_str += ' {}\n'.format(event_str)
103 | dee_str += ' ),\n'
104 |
105 | dee_str += ' sentences: (\n'
106 | for sent_idx, sent in enumerate(self.sentences):
107 | dee_str += ' {:2} {}\n'.format(sent_idx, sent)
108 | dee_str += ' ),\n'
109 |
110 | dee_str += ')\n'
111 |
112 | return dee_str
113 |
114 | @staticmethod
115 | def get_event_type_fields_pairs():
116 | return list(event_type_fields_list)
117 |
118 | @staticmethod
119 | def get_entity_label_list():
120 | visit_set = set()
121 | entity_label_list = [NERExample.basic_entity_label]
122 |
123 | for field in common_fields:
124 | if field not in visit_set:
125 | visit_set.add(field)
126 | entity_label_list.extend(['B-' + field, 'I-' + field])
127 |
128 | for event_name, fields in event_type_fields_list:
129 | for field in fields:
130 | if field not in visit_set:
131 | visit_set.add(field)
132 | entity_label_list.extend(['B-' + field, 'I-' + field])
133 |
134 | return entity_label_list
135 |
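# Illustrative sketch (not part of the original file; real field names come
# from dee/event_type.py): assuming NERExample.basic_entity_label == 'O',
# common_fields == ['OtherType'] and a single event type
# ('EquityPledge', ['Pledger', 'PledgedShares']), get_entity_label_list()
# returns ['O', 'B-OtherType', 'I-OtherType', 'B-Pledger', 'I-Pledger',
#          'B-PledgedShares', 'I-PledgedShares'].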
136 | class DEEExampleLoader(object):
137 | def __init__(self, rearrange_sent_flag, max_sent_len):
138 | self.rearrange_sent_flag = rearrange_sent_flag
139 | self.max_sent_len = max_sent_len
140 |
141 | def rearrange_sent_info(self, detail_align_info):
142 | if 'ann_valid_dranges' not in detail_align_info:
143 | detail_align_info['ann_valid_dranges'] = []
144 | if 'ann_mspan2dranges' not in detail_align_info:
145 | detail_align_info['ann_mspan2dranges'] = {}
146 |
147 | detail_align_info = dict(detail_align_info)
148 | split_rgx = re.compile('[,::;;))]')
149 |
150 | raw_sents = detail_align_info['sentences']
151 | doc_text = ''.join(raw_sents)
152 | raw_dranges = detail_align_info['ann_valid_dranges']
153 | raw_sid2span_char_set = defaultdict(lambda: set())
154 | for raw_sid, char_s, char_e in raw_dranges:
155 | span_char_set = raw_sid2span_char_set[raw_sid]
156 | span_char_set.update(range(char_s, char_e))
157 |
158 | # try to split long sentences into short ones at commas, colons, semicolons, and brackets (a mention must never be split!)
159 | short_sents = []
160 | for raw_sid, sent in enumerate(raw_sents):
161 | span_char_set = raw_sid2span_char_set[raw_sid]
162 | if len(sent) > self.max_sent_len:
163 | cur_char_s = 0
164 | for mobj in split_rgx.finditer(sent):
165 | m_char_s, m_char_e = mobj.span()
166 | if m_char_s in span_char_set:
167 | continue
168 | short_sents.append(sent[cur_char_s:m_char_e])
169 | cur_char_s = m_char_e
170 | short_sents.append(sent[cur_char_s:])
171 | else:
172 | short_sents.append(sent)
173 |
174 | # merge adjacent short sentences to compact ones that match max_sent_len
175 | comp_sents = ['']
176 | for sent in short_sents:
177 | prev_sent = comp_sents[-1]
178 | if len(prev_sent + sent) <= self.max_sent_len:
179 | comp_sents[-1] = prev_sent + sent
180 | else:
181 | comp_sents.append(sent)
182 |
183 | # get global sentence character base indexes
184 | raw_char_bases = [0]
185 | for sent in raw_sents:
186 | raw_char_bases.append(raw_char_bases[-1] + len(sent))
187 | comp_char_bases = [0]
188 | for sent in comp_sents:
189 | comp_char_bases.append(comp_char_bases[-1] + len(sent))
190 |
191 | assert raw_char_bases[-1] == comp_char_bases[-1] == len(doc_text)
192 |
193 | # calculate compact doc ranges
194 | raw_dranges.sort()
195 | raw_drange2comp_drange = {}
196 | prev_comp_sid = 0
197 | for raw_drange in raw_dranges:
198 | raw_drange = tuple(raw_drange) # important: json serialization turns tuples into lists
199 | raw_sid, raw_char_s, raw_char_e = raw_drange
200 | raw_char_base = raw_char_bases[raw_sid]
201 | doc_char_s = raw_char_base + raw_char_s
202 | doc_char_e = raw_char_base + raw_char_e
203 | assert doc_char_s >= comp_char_bases[prev_comp_sid]
204 |
205 | cur_comp_sid = prev_comp_sid
206 | for cur_comp_sid in range(prev_comp_sid, len(comp_sents)):
207 | if doc_char_e <= comp_char_bases[cur_comp_sid+1]:
208 | prev_comp_sid = cur_comp_sid
209 | break
210 | comp_char_base = comp_char_bases[cur_comp_sid]
211 | assert comp_char_base <= doc_char_s < doc_char_e <= comp_char_bases[cur_comp_sid+1]
212 | comp_char_s = doc_char_s - comp_char_base
213 | comp_char_e = doc_char_e - comp_char_base
214 | comp_drange = (cur_comp_sid, comp_char_s, comp_char_e)
215 |
216 | raw_drange2comp_drange[raw_drange] = comp_drange
217 | assert raw_sents[raw_drange[0]][raw_drange[1]:raw_drange[2]] == \
218 | comp_sents[comp_drange[0]][comp_drange[1]:comp_drange[2]]
219 |
220 | # update detailed align info with rearranged sentences
221 | detail_align_info['sentences'] = comp_sents
222 | detail_align_info['ann_valid_dranges'] = [
223 | raw_drange2comp_drange[tuple(raw_drange)] for raw_drange in detail_align_info['ann_valid_dranges']
224 | ]
225 | ann_mspan2comp_dranges = {}
226 | for ann_mspan, mspan_raw_dranges in detail_align_info['ann_mspan2dranges'].items():
227 | comp_dranges = [
228 | raw_drange2comp_drange[tuple(raw_drange)] for raw_drange in mspan_raw_dranges
229 | ]
230 | ann_mspan2comp_dranges[ann_mspan] = comp_dranges
231 | detail_align_info['ann_mspan2dranges'] = ann_mspan2comp_dranges
232 |
233 | return detail_align_info
234 |
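# Illustrative walk-through (not part of the original file), with
# max_sent_len=10 and raw_sents = ['AAAA,BBBB;CCCC'] (one 14-char sentence):
# split_rgx cuts after each ',' / ';' that does not fall inside an annotated
# mention, giving ['AAAA,', 'BBBB;', 'CCCC']; adjacent pieces are then merged
# back while staying within max_sent_len, giving ['AAAA,BBBB;', 'CCCC'], and
# every (sent_idx, char_s, char_e) drange is remapped onto the new sentences.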
235 | def convert_dict_to_example(self, annguid, detail_align_info, only_inference=False):
236 | if self.rearrange_sent_flag:
237 | detail_align_info = self.rearrange_sent_info(detail_align_info)
238 | dee_example = DEEExample(annguid, detail_align_info, only_inference=only_inference)
239 |
240 | return dee_example
241 |
242 | def __call__(self, dataset_json_path):
243 | total_dee_examples = []
244 | annguid_aligninfo_list = default_load_json(dataset_json_path)
245 | for annguid, detail_align_info in annguid_aligninfo_list:
246 | # if self.rearrange_sent_flag:
247 | # detail_align_info = self.rearrange_sent_info(detail_align_info)
248 | # dee_example = DEEExample(annguid, detail_align_info)
249 | dee_example = self.convert_dict_to_example(annguid, detail_align_info)
250 | total_dee_examples.append(dee_example)
251 |
252 | return total_dee_examples
253 |
254 | class DEEFeature(object):
255 | def __init__(self, guid, ex_idx, doc_token_id_mat, doc_token_mask_mat, doc_token_label_mat,
256 | span_token_ids_list, span_dranges_list, event_type_labels, event_arg_idxs_objs_list,
257 | valid_sent_num=None):
258 | self.guid = guid
259 | self.ex_idx = ex_idx # example row index, used for backtracking
260 | self.valid_sent_num = valid_sent_num
261 |
262 | # directly set tensor for dee feature to save memory
263 | # self.doc_token_id_mat = doc_token_id_mat
264 | # self.doc_token_mask_mat = doc_token_mask_mat
265 | # self.doc_token_label_mat = doc_token_label_mat
266 | self.doc_token_ids = torch.tensor(doc_token_id_mat, dtype=torch.long)
267 | self.doc_token_masks = torch.tensor(doc_token_mask_mat, dtype=torch.uint8) # uint8 for mask
268 | self.doc_token_labels = torch.tensor(doc_token_label_mat, dtype=torch.long)
269 |
270 | # sorted by the first drange tuple
271 | # [(token_id, ...), ...]
272 | # span_idx -> span_token_id tuple (the token ids of the span)
273 | self.span_token_ids_list = span_token_ids_list
274 | # [[(sent_idx, char_s, char_e), ...], ...]
275 | # span_idx -> [drange tuple, ...]
276 | # self.span_dranges_list[i] contains all the mention spans of self.span_token_ids_list[i] entity
277 | self.span_dranges_list = span_dranges_list
278 |
279 | # [event_type_label, ...]
280 | # length = the total number of events to be considered
281 | # event_type_label \in {0, 1}, 0: no 1: yes
282 | self.event_type_labels = event_type_labels # length = 5; 1: has this event type, 0: does not
283 | # event_type is denoted by the index of event_type_labels
284 | # event_type_idx -> event_obj_idx -> event_arg_idx -> span_idx
285 | # if no event objects, event_type_idx -> None
286 | self.event_arg_idxs_objs_list = event_arg_idxs_objs_list
287 |
288 | # event_type_idx -> event_field_idx -> pre_path -> {span_idx, ...}
289 | # pre_path is tuple of span_idx
290 | self.event_idx2field_idx2pre_path2cur_span_idx_set = self.build_dag_info(self.event_arg_idxs_objs_list)
291 |
292 | # event_type_idx -> key_sent_idx_set, used for key-event sentence detection
293 | self.event_idx2key_sent_idx_set, self.doc_sent_labels = self.build_key_event_sent_info()
294 |
295 | def generate_dag_info_for(self, pred_span_token_tup_list, return_miss=False):
296 | token_tup2pred_span_idx = {
297 | token_tup: pred_span_idx for pred_span_idx, token_tup in enumerate(pred_span_token_tup_list)
298 | }
299 | gold_span_idx2pred_span_idx = {}
300 | # pred_span_idx2gold_span_idx = {}
301 | missed_span_idx_list = [] # in terms of self
302 | missed_sent_idx_list = [] # in terms of self
303 | for gold_span_idx, token_tup in enumerate(self.span_token_ids_list):
304 | if token_tup in token_tup2pred_span_idx:
305 | pred_span_idx = token_tup2pred_span_idx[token_tup]
306 | gold_span_idx2pred_span_idx[gold_span_idx] = pred_span_idx
307 | # pred_span_idx2gold_span_idx[pred_span_idx] = gold_span_idx
308 | else:
309 | missed_span_idx_list.append(gold_span_idx)
310 | for gold_drange in self.span_dranges_list[gold_span_idx]:
311 | missed_sent_idx_list.append(gold_drange[0])
312 | missed_sent_idx_list = list(set(missed_sent_idx_list))
313 |
314 | pred_event_arg_idxs_objs_list = []
315 | for event_arg_idxs_objs in self.event_arg_idxs_objs_list:
316 | if event_arg_idxs_objs is None:
317 | pred_event_arg_idxs_objs_list.append(None)
318 | else:
319 | pred_event_arg_idxs_objs = []
320 | for event_arg_idxs in event_arg_idxs_objs:
321 | pred_event_arg_idxs = []
322 | for gold_span_idx in event_arg_idxs:
323 | if gold_span_idx in gold_span_idx2pred_span_idx:
324 | pred_event_arg_idxs.append(
325 | gold_span_idx2pred_span_idx[gold_span_idx]
326 | )
327 | else:
328 | pred_event_arg_idxs.append(None)
329 |
330 | pred_event_arg_idxs_objs.append(tuple(pred_event_arg_idxs))
331 | pred_event_arg_idxs_objs_list.append(pred_event_arg_idxs_objs)
332 |
333 | # event_idx -> field_idx -> pre_path -> cur_span_idx_set
334 | pred_dag_info = self.build_dag_info(pred_event_arg_idxs_objs_list)
335 |
336 | if return_miss:
337 | return pred_dag_info, missed_span_idx_list, missed_sent_idx_list
338 | else:
339 | return pred_dag_info
340 |
341 | def get_event_args_objs_list(self):
342 | event_args_objs_list = []
343 | for event_arg_idxs_objs in self.event_arg_idxs_objs_list:
344 | if event_arg_idxs_objs is None:
345 | event_args_objs_list.append(None)
346 | else:
347 | event_args_objs = []
348 | for event_arg_idxs in event_arg_idxs_objs:
349 | event_args = []
350 | for arg_idx in event_arg_idxs:
351 | if arg_idx is None:
352 | token_tup = None
353 | else:
354 | token_tup = self.span_token_ids_list[arg_idx]
355 | event_args.append(token_tup)
356 | event_args_objs.append(event_args)
357 | event_args_objs_list.append(event_args_objs)
358 |
359 | return event_args_objs_list
360 |
361 | def build_key_event_sent_info(self):
362 | assert len(self.event_type_labels) == len(self.event_arg_idxs_objs_list)
363 | # event_idx -> key_event_sent_index_set
364 | event_idx2key_sent_idx_set = [set() for _ in self.event_type_labels]
365 | for key_sent_idx_set, event_label, event_arg_idxs_objs in zip(
366 | event_idx2key_sent_idx_set, self.event_type_labels, self.event_arg_idxs_objs_list
367 | ):
368 | if event_label == 0:
369 | assert event_arg_idxs_objs is None
370 | else:
371 | for event_arg_idxs_obj in event_arg_idxs_objs:
372 | sent_idx_cands = []
373 | for span_idx in event_arg_idxs_obj:
374 | if span_idx is None:
375 | continue
376 | span_dranges = self.span_dranges_list[span_idx]
377 | for sent_idx, _, _ in span_dranges:
378 | sent_idx_cands.append(sent_idx)
379 | if len(sent_idx_cands) == 0:
380 | raise Exception('Event {} has no valid spans'.format(str(event_arg_idxs_obj)))
381 | sent_idx_cnter = Counter(sent_idx_cands)
382 | key_sent_idx = sent_idx_cnter.most_common()[0][0]
383 | key_sent_idx_set.add(key_sent_idx)
384 |
385 | doc_sent_labels = [] # 1: key event sentence, 0: otherwise
386 | for sent_idx in range(self.valid_sent_num): # masked sents will be truncated at the model part
387 | sent_labels = []
388 | for key_sent_idx_set in event_idx2key_sent_idx_set: # this mapping is a list
389 | if sent_idx in key_sent_idx_set:
390 | sent_labels.append(1)
391 | else:
392 | sent_labels.append(0)
393 | doc_sent_labels.append(sent_labels)
394 |
395 | return event_idx2key_sent_idx_set, doc_sent_labels
396 |
397 | @staticmethod
398 | def build_dag_info(event_arg_idxs_objs_list):
399 | # event_idx -> field_idx -> pre_path -> {span_idx, ...}
400 | # pre_path is tuple of span_idx
401 | event_idx2field_idx2pre_path2cur_span_idx_set = []
402 | for event_idx, event_arg_idxs_list in enumerate(event_arg_idxs_objs_list):
403 | if event_arg_idxs_list is None:
404 | event_idx2field_idx2pre_path2cur_span_idx_set.append(None)
405 | else:
406 | num_fields = len(event_arg_idxs_list[0])
407 | # field_idx -> pre_path -> {span_idx, ...}
408 | field_idx2pre_path2cur_span_idx_set = []
409 | for field_idx in range(num_fields):
410 | pre_path2cur_span_idx_set = {}
411 | for event_arg_idxs in event_arg_idxs_list:
412 | pre_path = event_arg_idxs[:field_idx]
413 | span_idx = event_arg_idxs[field_idx]
414 | if pre_path not in pre_path2cur_span_idx_set:
415 | pre_path2cur_span_idx_set[pre_path] = set()
416 | pre_path2cur_span_idx_set[pre_path].add(span_idx)
417 | field_idx2pre_path2cur_span_idx_set.append(pre_path2cur_span_idx_set)
418 | event_idx2field_idx2pre_path2cur_span_idx_set.append(field_idx2pre_path2cur_span_idx_set)
419 |
420 | return event_idx2field_idx2pre_path2cur_span_idx_set
421 |
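# Illustrative sketch (not part of the original file): for one event type with
# two records over span indices, (a, b, c) and (a, b, d), build_dag_info gives
#   field 0: {(): {a}}
#   field 1: {(a,): {b}}
#   field 2: {(a, b): {c, d}}
# i.e. records sharing a prefix of arguments share a path in the EDAG.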
422 | def is_multi_event(self):
423 | event_cnt = 0
424 | for event_objs in self.event_arg_idxs_objs_list:
425 | if event_objs is not None:
426 | event_cnt += len(event_objs)
427 | if event_cnt > 1:
428 | return True
429 |
430 | return False
431 |
432 | class DEEFeatureConverter(object):
433 | def __init__(self, entity_label_list, event_type_fields_pairs,
434 | max_sent_len, max_sent_num, tokenizer,
435 | ner_fea_converter=None, include_cls=True, include_sep=True):
436 | self.entity_label_list = entity_label_list
437 | self.event_type_fields_pairs = event_type_fields_pairs
438 | self.max_sent_len = max_sent_len
439 | self.max_sent_num = max_sent_num
440 | self.tokenizer = tokenizer
441 | self.truncate_doc_count = 0 # track how many docs have been truncated due to max_sent_num
442 | self.truncate_span_count = 0 # track how many spans have been truncated
443 |
444 | # labels not in entity_label_list default to 'O'
445 | # sentences longer than max_sent_len are truncated, incrementing ner_fea_converter.truncate_freq
446 | if ner_fea_converter is None:
447 | self.ner_fea_converter = NERFeatureConverter(entity_label_list, self.max_sent_len, tokenizer,
448 | include_cls=include_cls, include_sep=include_sep)
449 | else:
450 | self.ner_fea_converter = ner_fea_converter
451 |
452 | self.include_cls = include_cls
453 | self.include_sep = include_sep
454 |
455 | # prepare entity_label -> entity_index mapping
456 | self.entity_label2index = {}
457 | for entity_idx, entity_label in enumerate(self.entity_label_list):
458 | self.entity_label2index[entity_label] = entity_idx
459 |
460 | # prepare event_type -> event_index and event_index -> event_fields mapping
461 | self.event_type2index = {}
462 | self.event_type_list = []
463 | self.event_fields_list = []
464 | for event_idx, (event_type, event_fields) in enumerate(self.event_type_fields_pairs):
465 | self.event_type2index[event_type] = event_idx
466 | self.event_type_list.append(event_type)
467 | self.event_fields_list.append(event_fields)
468 |
469 | def convert_example_to_feature(self, ex_idx, dee_example, log_flag=False):
470 | annguid = dee_example.guid
471 | assert isinstance(dee_example, DEEExample)
472 |
473 | # 1. prepare doc token-level feature
474 |
475 | # Size(valid_sent_num, max_sent_len)
476 | doc_token_id_mat = [] # [[token_idx, ...], ...]
477 | doc_token_mask_mat = [] # [[token_mask, ...], ...]
478 | doc_token_label_mat = [] # [[token_label_id, ...], ...]
479 |
480 | for sent_idx, sent_text in enumerate(dee_example.sentences):
481 | if sent_idx >= self.max_sent_num:
482 | # truncate docs with more sentences than self.max_sent_num
483 | self.truncate_doc_count += 1
484 | break
485 |
486 | if sent_idx in dee_example.sent_idx2srange_mspan_mtype_tuples:
487 | srange_mspan_mtype_tuples = dee_example.sent_idx2srange_mspan_mtype_tuples[sent_idx]
488 | else:
489 | srange_mspan_mtype_tuples = []
490 |
491 | # srange_mspan_mtype_tuples for this sentence: (span position, span text, span type)
492 |
493 | ner_example = NERExample(
494 | '{}-{}'.format(annguid, sent_idx), sent_text, srange_mspan_mtype_tuples
495 | )
496 | # the sentence truncation count accumulates in ner_fea_converter.truncate_count
497 | ner_feature = self.ner_fea_converter.convert_example_to_feature(ner_example, log_flag=log_flag)
498 |
499 | doc_token_id_mat.append(ner_feature.input_ids)
500 | doc_token_mask_mat.append(ner_feature.input_masks)
501 | doc_token_label_mat.append(ner_feature.label_ids)
502 |
503 | # already padded to max_sent_len
504 |
505 | assert len(doc_token_id_mat) == len(doc_token_mask_mat) == len(doc_token_label_mat) <= self.max_sent_num
506 | valid_sent_num = len(doc_token_id_mat)
507 |
508 | # 2. prepare span feature
509 | # spans are sorted by the first drange
510 | span_token_ids_list = []
511 | span_dranges_list = []
512 | mspan2span_idx = {}
513 | for mspan in dee_example.ann_valid_mspans:
514 | if mspan in mspan2span_idx:
515 | continue
516 |
517 | raw_dranges = dee_example.ann_mspan2dranges[mspan]
518 | char_base_s = 1 if self.include_cls else 0
519 | char_max_end = self.max_sent_len - 1 if self.include_sep else self.max_sent_len
520 | span_dranges = []
521 | for sent_idx, char_s, char_e in raw_dranges:
522 | if char_base_s + char_e <= char_max_end and sent_idx < self.max_sent_num:
523 | span_dranges.append((sent_idx, char_base_s + char_s, char_base_s + char_e))
524 | else:
525 | self.truncate_span_count += 1
526 | if len(span_dranges) == 0:
527 | # span does not have any valid location in truncated sequences
528 | continue
529 |
530 | span_tokens = self.tokenizer.char_tokenize(mspan)
531 | span_token_ids = tuple(self.tokenizer.convert_tokens_to_ids(span_tokens))
532 |
533 | mspan2span_idx[mspan] = len(span_token_ids_list)
534 | span_token_ids_list.append(span_token_ids)
535 | span_dranges_list.append(span_dranges)
536 | assert len(span_token_ids_list) == len(span_dranges_list) == len(mspan2span_idx)
537 |
538 | if len(span_token_ids_list) == 0 and not dee_example.only_inference:
539 | logger.warning('Neglect example {} without valid spans'.format(ex_idx))
540 | return None
541 |
542 | # 3. prepare doc-level event feature
543 | # event_type_labels: event_type_index -> event_type_exist_sign (1: exist, 0: no)
544 | # event_arg_idxs_objs_list: event_type_index -> event_obj_index -> event_arg_index -> span_idx
545 |
546 | event_type_labels = [] # event_type_idx -> event_type_exist_sign (1 or 0)
547 | event_arg_idxs_objs_list = [] # event_type_idx -> event_obj_idx -> event_arg_idx -> span_idx
548 | for event_idx, event_type in enumerate(self.event_type_list):
549 | event_fields = self.event_fields_list[event_idx]
550 |
551 | if event_type not in dee_example.event_type2event_objs:
552 | event_type_labels.append(0)
553 | event_arg_idxs_objs_list.append(None)
554 | else:
555 | event_objs = dee_example.event_type2event_objs[event_type]
556 |
557 | event_arg_idxs_objs = []
558 | for event_obj in event_objs:
559 | assert isinstance(event_obj, BaseEvent)
560 |
561 | event_arg_idxs = []
562 | any_valid_flag = False
563 | for field in event_fields:
564 | arg_span = event_obj.field2content[field]
565 |
566 | if arg_span is None or arg_span not in mspan2span_idx:
567 | # arg_span can be None, or a valid span that was truncated away
568 | arg_span_idx = None
569 | else:
570 | # when constructing data files,
571 | # the pipeline must ensure every event argument span is covered by the total span collection
572 | arg_span_idx = mspan2span_idx[arg_span]
573 | any_valid_flag = True
574 |
575 | event_arg_idxs.append(arg_span_idx)
576 |
577 | if any_valid_flag:
578 | event_arg_idxs_objs.append(tuple(event_arg_idxs))
579 |
580 | if event_arg_idxs_objs:
581 | event_type_labels.append(1)
582 | event_arg_idxs_objs_list.append(event_arg_idxs_objs)
583 | else:
584 | event_type_labels.append(0)
585 | event_arg_idxs_objs_list.append(None)
586 |
587 | dee_feature = DEEFeature(
588 | annguid, ex_idx, doc_token_id_mat, doc_token_mask_mat, doc_token_label_mat,
589 | span_token_ids_list, span_dranges_list, event_type_labels, event_arg_idxs_objs_list,
590 | valid_sent_num=valid_sent_num
591 | )
592 | return dee_feature
593 |
594 | def __call__(self, dee_examples, log_example_num=0):
595 | """Convert examples to features suitable for document-level event extraction"""
596 | dee_features = []
597 | self.truncate_doc_count = 0
598 | self.truncate_span_count = 0
599 | self.ner_fea_converter.truncate_count = 0
600 |
601 | remove_ex_cnt = 0
602 | for ex_idx, dee_example in enumerate(dee_examples):
603 | log_flag = ex_idx < log_example_num
604 | dee_feature = self.convert_example_to_feature(
605 | ex_idx - remove_ex_cnt, dee_example, log_flag=log_flag
606 | )
607 |
608 | if dee_feature is None:
609 | remove_ex_cnt += 1
610 | continue
611 | dee_features.append(dee_feature)
612 |
613 | logger.info('{} documents, ignored {} examples, truncated {} docs, {} sents, {} spans'.format(
614 | len(dee_examples), remove_ex_cnt,
615 | self.truncate_doc_count, self.ner_fea_converter.truncate_count, self.truncate_span_count
616 | ))
617 |
618 | return dee_features
619 |
620 | def convert_dee_features_to_dataset(dee_features):
621 | # treat the list of doc_fea as the dataset; it only needs __len__ and __getitem__
622 | assert len(dee_features) > 0 and isinstance(dee_features[0], DEEFeature)
623 |
624 | return dee_features
625 |
626 | def prepare_doc_batch_dict(doc_fea_list):
627 | doc_batch_keys = ['ex_idx', 'doc_token_ids', 'doc_token_masks', 'doc_token_labels', 'valid_sent_num']
628 | doc_batch_dict = {}
629 | for key in doc_batch_keys:
630 | doc_batch_dict[key] = [getattr(doc_fea, key) for doc_fea in doc_fea_list]
631 |
632 | return doc_batch_dict
633 |
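# --- Illustrative sketch (not part of the original source): the collate
# function above just gathers per-document attributes into parallel lists.
# SimpleNamespace stands in for a real DEEFeature purely for demonstration.
def _example_prepare_doc_batch_dict():
    from types import SimpleNamespace
    fea = SimpleNamespace(ex_idx=0, doc_token_ids=[[101, 102]],
                          doc_token_masks=[[1, 1]], doc_token_labels=[[2, 2]],
                          valid_sent_num=1)
    batch = prepare_doc_batch_dict([fea])
    # batch == {'ex_idx': [0], 'doc_token_ids': [[[101, 102]]],
    #           'doc_token_masks': [[[1, 1]]], 'doc_token_labels': [[[2, 2]]],
    #           'valid_sent_num': [1]}
    return batch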
634 | def measure_dee_prediction(event_type_fields_pairs, features, event_decode_results,
635 | dump_json_path=None, writer=None, epoch=None):
636 | pred_record_mat_list = []
637 | gold_record_mat_list = []
638 | for term in event_decode_results:
639 | ex_idx, pred_event_type_labels, pred_record_mat, doc_span_info = term[:4]
640 | pred_record_mat = [
641 | [
642 | [
643 | tuple(arg_tup) if arg_tup is not None else None
644 | for arg_tup in pred_record
645 | ] for pred_record in pred_records
646 | ] if pred_records is not None else None
647 | for pred_records in pred_record_mat
648 | ]
649 | doc_fea = features[ex_idx]
650 | assert isinstance(doc_fea, DEEFeature)
651 | gold_record_mat = [
652 | [
653 | [
654 | tuple(doc_fea.span_token_ids_list[arg_idx]) if arg_idx is not None else None
655 | for arg_idx in event_arg_idxs
656 | ] for event_arg_idxs in event_arg_idxs_objs
657 | ] if event_arg_idxs_objs is not None else None
658 | for event_arg_idxs_objs in doc_fea.event_arg_idxs_objs_list
659 | ]
660 |
661 | pred_record_mat_list.append(pred_record_mat)
662 | gold_record_mat_list.append(gold_record_mat)
663 |
664 | g_eval_res = measure_event_table_filling(
665 | pred_record_mat_list, gold_record_mat_list, event_type_fields_pairs, dict_return=True
666 | )
667 |
668 | if writer is not None and dump_json_path is not None:
669 | if 'dev' in dump_json_path:
670 | prefix = 'Dev-Pred-' if 'pred' in dump_json_path else 'Dev-Gold-'
671 | else:
672 | prefix = 'Test-Pred-' if 'pred' in dump_json_path else 'Test-Gold-'
673 | writer.add_scalar(prefix+'MicroF1', g_eval_res[-1]['MicroF1'], global_step=epoch)
674 | writer.add_scalar(prefix+'MacroF1', g_eval_res[-1]['MacroF1'], global_step=epoch)
675 | writer.add_scalar(prefix+'MicroPrecision', g_eval_res[-1]['MicroPrecision'], global_step=epoch)
676 | writer.add_scalar(prefix+'MicroRecall', g_eval_res[-1]['MicroRecall'], global_step=epoch)
677 |
678 | event_triggering_tp = [0 for _ in range(5)] # one TP counter per event type (5 types)
679 | event_triggering_fp = [0 for _ in range(5)]
680 | event_triggering_fn = [0 for _ in range(5)]
681 | for term in event_decode_results:
682 | ex_idx, pred_event_type_labels, pred_record_mat, doc_span_info = term[:4]
683 | event_triggering_golden = features[ex_idx].event_type_labels
684 | for et_idx, et in enumerate(pred_event_type_labels):
685 | if pred_event_type_labels[et_idx] == 1:
686 | if event_triggering_golden[et_idx] == 1:
687 | event_triggering_tp[et_idx] += 1
688 | else:
689 | event_triggering_fp[et_idx] += 1
690 | else:
691 | if event_triggering_golden[et_idx] == 1:
692 | event_triggering_fn[et_idx] += 1
693 |
694 | for eidx in range(5):
695 | if event_triggering_tp[eidx]+event_triggering_fp[eidx] != 0:
696 | event_p = event_triggering_tp[eidx] / (event_triggering_tp[eidx]+event_triggering_fp[eidx])
697 | else:
698 | event_p = 0
699 | if event_triggering_tp[eidx]+event_triggering_fn[eidx] != 0:
700 | event_r = event_triggering_tp[eidx] / (event_triggering_tp[eidx]+event_triggering_fn[eidx])
701 | else:
702 | event_r = 0
703 | if event_p != 0 and event_r != 0:
704 | event_f1 = 2 * event_p * event_r / (event_p + event_r)
705 | else:
706 | event_f1 = 0
707 | g_eval_res[-1]['event_{}_p'.format(eidx+1)] = event_p
708 | g_eval_res[-1]['event_{}_r'.format(eidx+1)] = event_r
709 | g_eval_res[-1]['event_{}_f1'.format(eidx+1)] = event_f1
710 |
711 | if dump_json_path is not None:
712 | default_dump_json(g_eval_res, dump_json_path)
713 |
714 | return g_eval_res
715 |
716 | def aggregate_task_eval_info(eval_dir_path, target_file_pre='dee_eval', target_file_suffix='.json',
717 | dump_name='total_task_eval.pkl', dump_flag=False):
718 | """Enumerate the evaluation directory to collect all dumped evaluation results"""
719 | logger.info('Aggregate task evaluation info from {}'.format(eval_dir_path))
720 | data_span_type2model_str2epoch_res_list = {}
721 | for fn in os.listdir(eval_dir_path):
722 | fn_splits = fn.split('.')
723 | if fn.startswith(target_file_pre) and fn.endswith(target_file_suffix) and len(fn_splits) == 6:
724 | _, data_type, span_type, model_str, epoch, _ = fn_splits
725 |
726 | data_span_type = (data_type, span_type)
727 | if data_span_type not in data_span_type2model_str2epoch_res_list:
728 | data_span_type2model_str2epoch_res_list[data_span_type] = {}
729 | model_str2epoch_res_list = data_span_type2model_str2epoch_res_list[data_span_type]
730 |
731 | if model_str not in model_str2epoch_res_list:
732 | model_str2epoch_res_list[model_str] = []
733 | epoch_res_list = model_str2epoch_res_list[model_str]
734 |
735 | epoch = int(epoch)
736 | fp = os.path.join(eval_dir_path, fn)
737 | eval_res = default_load_json(fp)
738 |
739 | epoch_res_list.append((epoch, eval_res))
740 |
741 | for data_span_type, model_str2epoch_res_list in data_span_type2model_str2epoch_res_list.items():
742 | for model_str, epoch_res_list in model_str2epoch_res_list.items():
743 | epoch_res_list.sort(key=lambda x: x[0])
744 |
745 | if dump_flag:
746 | dump_fp = os.path.join(eval_dir_path, dump_name)
747 | logger.info('Dumping {} into {}'.format(dump_name, eval_dir_path))
748 | default_dump_pkl(data_span_type2model_str2epoch_res_list, dump_fp)
749 |
750 | return data_span_type2model_str2epoch_res_list
751 |
752 | def print_total_eval_info(data_span_type2model_str2epoch_res_list,
753 | metric_type='micro',
754 | span_type='pred_span',
755 | model_str='GIT',
756 | target_set='test'):
757 | """Print the final performance by selecting the best epoch on dev set and emitting performance on test set"""
758 | dev_type = 'dev'
759 | test_type = 'test'
760 | avg_type2prf1_keys = {
761 | 'macro': ('MacroPrecision', 'MacroRecall', 'MacroF1'),
762 | 'micro': ('MicroPrecision', 'MicroRecall', 'MicroF1'),
763 | }
764 |
765 | name_key = 'EventType'
766 | p_key, r_key, f_key = avg_type2prf1_keys[metric_type]
767 |
768 | def get_avg_event_score(epoch_res):
769 | eval_res = epoch_res[1]
770 | avg_event_score = eval_res[-1][f_key]
771 |
772 | return avg_event_score
773 |
774 | dev_model_str2epoch_res_list = data_span_type2model_str2epoch_res_list[(dev_type, span_type)]
775 | test_model_str2epoch_res_list = data_span_type2model_str2epoch_res_list[(test_type, span_type)]
776 |
777 | has_header = False
778 | mstr_bepoch_list = []
779 | print('=' * 15, 'Final Performance (%) (avg_type={})'.format(metric_type), '=' * 15)
780 |
781 | if model_str not in dev_model_str2epoch_res_list or model_str not in test_model_str2epoch_res_list:
782 | pass
783 | else:
784 | # get the best epoch on dev set
785 | dev_epoch_res_list = dev_model_str2epoch_res_list[model_str]
786 | best_dev_epoch, best_dev_res = max(dev_epoch_res_list, key=get_avg_event_score)
787 |
788 | test_epoch_res_list = test_model_str2epoch_res_list[model_str]
789 | best_test_epoch = None
790 | best_test_res = None
791 | for test_epoch, test_res in test_epoch_res_list:
792 | if test_epoch == best_dev_epoch:
793 | best_test_epoch = test_epoch
794 | best_test_res = test_res
795 | assert best_test_epoch is not None
796 | mstr_bepoch_list.append((model_str, best_test_epoch))
797 |
798 | if target_set == 'test':
799 | target_eval_res = best_test_res
800 | else:
801 | target_eval_res = best_dev_res
802 |
803 | align_temp = '{:20}'
804 | head_str = align_temp.format('ModelType')
805 | eval_str = align_temp.format(model_str)
806 | head_temp = ' \t {}'
807 | eval_temp = ' \t & {:.1f} & {:.1f} & {:.1f}'
808 | ps = []
809 | rs = []
810 | fs = []
811 | for tgt_event_res in target_eval_res[:-1]:
812 | head_str += align_temp.format(head_temp.format(tgt_event_res[0][name_key]))
813 | p, r, f1 = (100 * tgt_event_res[0][key] for key in [p_key, r_key, f_key])
814 | eval_str += align_temp.format(eval_temp.format(p, r, f1))
815 | ps.append(p)
816 | rs.append(r)
817 | fs.append(f1)
818 |
819 | head_str += align_temp.format(head_temp.format('Average'))
820 | ap, ar, af1 = np.mean(ps), np.mean(rs), np.mean(fs)
821 | eval_str += align_temp.format(eval_temp.format(ap, ar, af1))
822 |
823 | head_str += align_temp.format(head_temp.format('Total ({})'.format(metric_type)))
824 | g_avg_res = target_eval_res[-1]
825 | ap, ar, af1 = (100 * g_avg_res[key] for key in [p_key, r_key, f_key])
826 | eval_str += align_temp.format(eval_temp.format(ap, ar, af1))
827 |
828 | if not has_header:
829 | print(head_str)
830 | has_header = True
831 | print(eval_str)
832 |
833 | print(mstr_bepoch_list)
834 | return mstr_bepoch_list
835 |
836 | # evaluation dump file name template
837 | # dee_eval.[DataType].[SpanType].[ModelStr].[Epoch].(pkl|json)
838 | decode_dump_template = 'dee_eval.{}.{}.{}.{}.pkl'
839 | eval_dump_template = 'dee_eval.{}.{}.{}.{}.json'
840 |
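# --- Illustrative sketch (not part of the original source): how the templates
# expand; the six dot-separated parts are exactly what aggregate_task_eval_info
# checks with len(fn.split('.')) == 6 when parsing dump file names.
def _example_dump_names():
    decode_fn = decode_dump_template.format('test', 'pred_span', 'GIT', 30)
    # -> 'dee_eval.test.pred_span.GIT.30.pkl'
    eval_fn = eval_dump_template.format('dev', 'gold_span', 'GIT', 30)
    # -> 'dee_eval.dev.gold_span.GIT.30.json'
    return decode_fn, eval_fn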
841 | def resume_decode_results(base_dir, data_type, span_type, model_str, epoch):
842 | decode_fn = decode_dump_template.format(data_type, span_type, model_str, epoch)
843 | decode_fp = os.path.join(base_dir, decode_fn)
844 | logger.info('Resume decoded results from {}'.format(decode_fp))
845 | decode_results = default_load_pkl(decode_fp)
846 |
847 | return decode_results
848 |
849 | def resume_eval_results(base_dir, data_type, span_type, model_str, epoch):
850 | eval_fn = eval_dump_template.format(data_type, span_type, model_str, epoch)
851 | eval_fp = os.path.join(base_dir, eval_fn)
852 | logger.info('Resume eval results from {}'.format(eval_fp))
853 | eval_results = default_load_json(eval_fp)
854 |
855 | return eval_results
856 |
857 | def print_single_vs_multi_performance(mstr_bepoch_list, base_dir, features,
858 | metric_type='micro', data_type='test', span_type='pred_span'):
859 | model_str2decode_results = {}
860 | for model_str, best_epoch in mstr_bepoch_list:
861 | model_str2decode_results[model_str] = resume_decode_results(
862 | base_dir, data_type, span_type, model_str, best_epoch
863 | )
864 |
865 | single_eid_set = set([doc_fea.ex_idx for doc_fea in features if not doc_fea.is_multi_event()])
866 | multi_eid_set = set([doc_fea.ex_idx for doc_fea in features if doc_fea.is_multi_event()])
867 | event_type_fields_pairs = DEEExample.get_event_type_fields_pairs()
868 | event_type_list = [x for x, y in event_type_fields_pairs]
869 |
870 | name_key = 'EventType'
871 | avg_type2f1_key = {
872 | 'micro': 'MicroF1',
873 | 'macro': 'MacroF1',
874 | }
875 | f1_key = avg_type2f1_key[metric_type]
876 |
877 | model_str2etype_sf1_mf1_list = {}
878 | for model_str, _ in mstr_bepoch_list:
879 | total_decode_results = model_str2decode_results[model_str]
880 |
881 | single_decode_results = [dec_res for dec_res in total_decode_results if dec_res[0] in single_eid_set]
882 | assert len(single_decode_results) == len(single_eid_set)
883 | single_eval_res = measure_dee_prediction(
884 | event_type_fields_pairs, features, single_decode_results
885 | )
886 |
887 | multi_decode_results = [dec_res for dec_res in total_decode_results if dec_res[0] in multi_eid_set]
888 | assert len(multi_decode_results) == len(multi_eid_set)
889 | multi_eval_res = measure_dee_prediction(
890 | event_type_fields_pairs, features, multi_decode_results
891 | )
892 |
893 | etype_sf1_mf1_list = []
894 | for event_idx, (se_res, me_res) in enumerate(zip(single_eval_res[:-1], multi_eval_res[:-1])):
895 | assert se_res[0][name_key] == me_res[0][name_key] == event_type_list[event_idx]
896 | event_type = event_type_list[event_idx]
897 | single_f1 = se_res[0][f1_key]
898 | multi_f1 = me_res[0][f1_key]
899 |
900 | etype_sf1_mf1_list.append((event_type, single_f1, multi_f1))
901 | g_avg_se_res = single_eval_res[-1]
902 | g_avg_me_res = multi_eval_res[-1]
903 | etype_sf1_mf1_list.append(
904 | ('Total ({})'.format(metric_type), g_avg_se_res[f1_key], g_avg_me_res[f1_key])
905 | )
906 | model_str2etype_sf1_mf1_list[model_str] = etype_sf1_mf1_list
907 |
908 | print('=' * 15, 'Single vs. Multi (%) (avg_type={})'.format(metric_type), '=' * 15)
909 | align_temp = '{:20}'
910 | head_str = align_temp.format('ModelType')
911 | head_temp = ' \t {}'
912 | eval_temp = ' \t & {:.1f} & {:.1f} '
913 | for event_type in event_type_list:
914 | head_str += align_temp.format(head_temp.format(event_type))
915 | head_str += align_temp.format(head_temp.format('Total ({})'.format(metric_type)))
916 | head_str += align_temp.format(head_temp.format('Average'))
917 | print(head_str)
918 |
919 | for model_str, _ in mstr_bepoch_list:
920 | eval_str = align_temp.format(model_str)
921 | sf1s = []
922 | mf1s = []
923 | for _, single_f1, multi_f1 in model_str2etype_sf1_mf1_list[model_str]:
924 | eval_str += align_temp.format(eval_temp.format(single_f1*100, multi_f1*100))
925 | sf1s.append(single_f1)
926 | mf1s.append(multi_f1)
927 | avg_sf1 = np.mean(sf1s[:-1])
928 | avg_mf1 = np.mean(mf1s[:-1])
929 | eval_str += align_temp.format(eval_temp.format(avg_sf1*100, avg_mf1*100))
930 | print(eval_str)
--------------------------------------------------------------------------------
/dee/dee_metric.py:
--------------------------------------------------------------------------------
1 | # Code Reference: (https://github.com/dolphin-zs/Doc2EDAG)
2 |
3 | import numpy as np
4 |
5 |
6 | def agg_event_role_tpfpfn_stats(pred_records, gold_records, role_num):
7 | """
8 | Aggregate TP,FP,FN statistics for a single event prediction of one instance.
9 | pred_records should be formatted as
10 | [(Record Index)
11 | ((Role Index)
12 | argument 1, ...
13 | ), ...
14 | ], where argument 1 should support the '=' operation and the empty argument is None.
15 | """
16 | role_tpfpfn_stats = [[0] * 3 for _ in range(role_num)]
17 |
18 | if gold_records is None:
19 | if pred_records is not None: # FP
20 | for pred_record in pred_records:
21 | assert len(pred_record) == role_num
22 | for role_idx, arg_tup in enumerate(pred_record):
23 | if arg_tup is not None:
24 | role_tpfpfn_stats[role_idx][1] += 1
25 | else: # ignore TN
26 | pass
27 | else:
28 | if pred_records is None: # FN
29 | for gold_record in gold_records:
30 | assert len(gold_record) == role_num
31 | for role_idx, arg_tup in enumerate(gold_record):
32 | if arg_tup is not None:
33 | role_tpfpfn_stats[role_idx][2] += 1
34 | else: # True Positive at the event level
35 | # sort predicted event records by the non-empty count
36 | # to remove the impact of the record order on evaluation
37 | pred_records = sorted(pred_records,
38 | key=lambda x: sum(1 for a in x if a is not None),
39 | reverse=True)
40 | gold_records = list(gold_records)
41 |
42 | while len(pred_records) > 0 and len(gold_records) > 0:
43 | pred_record = pred_records[0]
44 | assert len(pred_record) == role_num
45 |
46 | # pick the most similar gold record
47 | _tmp_key = lambda gr: sum([1 for pa, ga in zip(pred_record, gr) if pa == ga])
48 | best_gr_idx = gold_records.index(max(gold_records, key=_tmp_key))
49 | gold_record = gold_records[best_gr_idx]
50 |
51 | for role_idx, (pred_arg, gold_arg) in enumerate(zip(pred_record, gold_record)):
52 | if gold_arg is None:
53 | if pred_arg is not None: # FP at the role level
54 | role_tpfpfn_stats[role_idx][1] += 1
55 | else: # ignore TN
56 | pass
57 | else:
58 | if pred_arg is None: # FN
59 | role_tpfpfn_stats[role_idx][2] += 1
60 | else:
61 | if pred_arg == gold_arg: # TP
62 | role_tpfpfn_stats[role_idx][0] += 1
63 | else:
64 | role_tpfpfn_stats[role_idx][1] += 1
65 | role_tpfpfn_stats[role_idx][2] += 1
66 |
67 | del pred_records[0]
68 | del gold_records[best_gr_idx]
69 |
70 | # remaining FP
71 | for pred_record in pred_records:
72 | assert len(pred_record) == role_num
73 | for role_idx, arg_tup in enumerate(pred_record):
74 | if arg_tup is not None:
75 | role_tpfpfn_stats[role_idx][1] += 1
76 | # remaining FN
77 | for gold_record in gold_records:
78 | assert len(gold_record) == role_num
79 | for role_idx, arg_tup in enumerate(gold_record):
80 | if arg_tup is not None:
81 | role_tpfpfn_stats[role_idx][2] += 1
82 |
83 | return role_tpfpfn_stats
84 |
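# --- Illustrative sketch (not part of the original source): greedy matching on
# a two-role event. The single prediction pairs with the most similar gold
# record; role 0 agrees (TP) while role 1 is missed (FN).
def _example_role_stats():
    stats = agg_event_role_tpfpfn_stats(
        pred_records=[('A', None)], gold_records=[('A', 'B')], role_num=2
    )
    # stats == [[1, 0, 0], [0, 0, 1]]  # per role: [TP, FP, FN]
    return stats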
85 | def agg_event_level_tpfpfn_stats(pred_records, gold_records, role_num):
86 | """
87 | Get event-level TP,FP,FN
88 | """
89 | # add role-level statistics as the event-level ones
90 | role_tpfpfn_stats = agg_event_role_tpfpfn_stats(
91 | pred_records, gold_records, role_num
92 | )
93 |
94 | return list(np.sum(role_tpfpfn_stats, axis=0))
95 |
96 | def agg_ins_event_role_tpfpfn_stats(pred_record_mat, gold_record_mat, event_role_num_list):
97 | """
98 | Aggregate TP,FP,FN statistics for a single instance.
99 | A record_mat should be formatted as
100 | [(Event Index)
101 | [(Record Index)
102 | ((Role Index)
103 | argument 1, ...
104 | ), ...
105 | ], ...
106 | ], where each argument should support the '==' operation and empty arguments are None.
107 | """
108 | assert len(pred_record_mat) == len(gold_record_mat)
109 | # tpfpfn_stat: TP, FP, FN
110 | event_role_tpfpfn_stats = []
111 | for event_idx, (pred_records, gold_records) in enumerate(zip(pred_record_mat, gold_record_mat)):
112 | role_num = event_role_num_list[event_idx]
113 | role_tpfpfn_stats = agg_event_role_tpfpfn_stats(pred_records, gold_records, role_num)
114 | event_role_tpfpfn_stats.append(role_tpfpfn_stats)
115 |
116 | return event_role_tpfpfn_stats
117 |
118 | def agg_ins_event_level_tpfpfn_stats(pred_record_mat, gold_record_mat, event_role_num_list):
119 | assert len(pred_record_mat) == len(gold_record_mat)
120 | # tpfpfn_stat: TP, FP, FN
121 | event_tpfpfn_stats = []
122 | for event_idx, (pred_records, gold_records, role_num) in enumerate(zip(
123 | pred_record_mat, gold_record_mat, event_role_num_list)):
124 | event_tpfpfn = agg_event_level_tpfpfn_stats(pred_records, gold_records, role_num)
125 | event_tpfpfn_stats.append(event_tpfpfn)
126 |
127 | return event_tpfpfn_stats
128 |
129 | def get_prec_recall_f1(tp, fp, fn):
130 | a = tp + fp
131 | prec = tp / a if a > 0 else 0
132 | b = tp + fn
133 | rec = tp / b if b > 0 else 0
134 | if prec > 0 and rec > 0:
135 | f1 = 2.0 / (1 / prec + 1 / rec)
136 | else:
137 | f1 = 0
138 | return prec, rec, f1
139 |
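# --- Illustrative note (not part of the original source): the F1 above is the
# harmonic mean of precision and recall. For example, with TP=3, FP=1, FN=2:
#   prec = 3/4 = 0.75, rec = 3/5 = 0.6, f1 = 2 / (1/0.75 + 1/0.6) = 2/3
# i.e. get_prec_recall_f1(3, 1, 2) -> (0.75, 0.6, 0.666...)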
140 | def measure_event_table_filling(pred_record_mat_list, gold_record_mat_list, event_type_roles_list, avg_type='micro',
141 | dict_return=False):
142 | """
143 | The record_mat_list is formatted as
144 | [(Document Index)
145 | [(Event Index)
146 | [(Record Index)
147 | ((Role Index)
148 | argument 1, ...
149 | ), ...
150 | ], ...
151 | ], ...
152 | ]
153 | The argument type should support the '==' operation.
154 | Empty arguments and records are set as None.
155 | """
156 | event_role_num_list = [len(roles) for _, roles in event_type_roles_list]
157 | # to store total statistics of TP, FP, FN
158 | total_event_role_stats = [
159 | [
160 | [0]*3 for _ in range(role_num)
161 | ] for event_idx, role_num in enumerate(event_role_num_list)
162 | ]
163 |
164 | assert len(pred_record_mat_list) == len(gold_record_mat_list)
165 | for pred_record_mat, gold_record_mat in zip(pred_record_mat_list, gold_record_mat_list):
166 | event_role_tpfpfn_stats = agg_ins_event_role_tpfpfn_stats(
167 | pred_record_mat, gold_record_mat, event_role_num_list
168 | )
169 | for event_idx, role_num in enumerate(event_role_num_list):
170 | for role_idx in range(role_num):
171 | for sid in range(3):
172 | total_event_role_stats[event_idx][role_idx][sid] += \
173 | event_role_tpfpfn_stats[event_idx][role_idx][sid]
174 |
175 | per_role_metric = []
176 | per_event_metric = []
177 |
178 | num_events = len(event_role_num_list)
179 | g_tpfpfn_stat = [0] * 3
180 | g_prf1_stat = [0] * 3
181 | event_role_eval_dicts = []
182 | for event_idx, role_num in enumerate(event_role_num_list):
183 | event_tpfpfn = [0] * 3 # tp, fp, fn
184 | event_prf1_stat = [0] * 3
185 | per_role_metric.append([])
186 | role_eval_dicts = []
187 | for role_idx in range(role_num):
188 | role_tpfpfn_stat = total_event_role_stats[event_idx][role_idx][:3]
189 | role_prf1_stat = get_prec_recall_f1(*role_tpfpfn_stat)
190 | per_role_metric[event_idx].append(role_prf1_stat)
191 | for mid in range(3):
192 | event_tpfpfn[mid] += role_tpfpfn_stat[mid]
193 | event_prf1_stat[mid] += role_prf1_stat[mid]
194 |
195 | role_eval_dict = {
196 | 'RoleType': event_type_roles_list[event_idx][1][role_idx],
197 | 'Precision': role_prf1_stat[0],
198 | 'Recall': role_prf1_stat[1],
199 | 'F1': role_prf1_stat[2],
200 | 'TP': role_tpfpfn_stat[0],
201 | 'FP': role_tpfpfn_stat[1],
202 | 'FN': role_tpfpfn_stat[2]
203 | }
204 | role_eval_dicts.append(role_eval_dict)
205 |
206 | for mid in range(3):
207 | event_prf1_stat[mid] /= role_num
208 | g_tpfpfn_stat[mid] += event_tpfpfn[mid]
209 | g_prf1_stat[mid] += event_prf1_stat[mid]
210 |
211 | micro_event_prf1 = get_prec_recall_f1(*event_tpfpfn)
212 | macro_event_prf1 = tuple(event_prf1_stat)
213 | if avg_type.lower() == 'micro':
214 | event_prf1_stat = micro_event_prf1
215 | elif avg_type.lower() == 'macro':
216 | event_prf1_stat = macro_event_prf1
217 | else:
218 | raise Exception('Unsupported average type {}'.format(avg_type))
219 |
220 | per_event_metric.append(event_prf1_stat)
221 |
222 | event_eval_dict = {
223 | 'EventType': event_type_roles_list[event_idx][0],
224 | 'MacroPrecision': macro_event_prf1[0],
225 | 'MacroRecall': macro_event_prf1[1],
226 | 'MacroF1': macro_event_prf1[2],
227 | 'MicroPrecision': micro_event_prf1[0],
228 | 'MicroRecall': micro_event_prf1[1],
229 | 'MicroF1': micro_event_prf1[2],
230 | 'TP': event_tpfpfn[0],
231 | 'FP': event_tpfpfn[1],
232 | 'FN': event_tpfpfn[2],
233 | }
234 | event_role_eval_dicts.append((event_eval_dict, role_eval_dicts))
235 |
236 | micro_g_prf1 = get_prec_recall_f1(*g_tpfpfn_stat)
237 | macro_g_prf1 = tuple(s / num_events for s in g_prf1_stat)
238 | if avg_type.lower() == 'micro':
239 | g_metric = micro_g_prf1
240 | else:
241 | g_metric = macro_g_prf1
242 |
243 | g_eval_dict = {
244 | 'MacroPrecision': macro_g_prf1[0],
245 | 'MacroRecall': macro_g_prf1[1],
246 | 'MacroF1': macro_g_prf1[2],
247 | 'MicroPrecision': micro_g_prf1[0],
248 | 'MicroRecall': micro_g_prf1[1],
249 | 'MicroF1': micro_g_prf1[2],
250 | 'TP': g_tpfpfn_stat[0],
251 | 'FP': g_tpfpfn_stat[1],
252 | 'FN': g_tpfpfn_stat[2],
253 | }
254 | event_role_eval_dicts.append(g_eval_dict)
255 |
256 | if not dict_return:
257 | return g_metric, per_event_metric, per_role_metric
258 | else:
259 | return event_role_eval_dicts
260 |
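# --- Illustrative sketch (not part of the original source): a minimal call with
# one document and a single two-role event type. Role 0 is correct (TP), role 1
# is spurious (FP), so the micro-averaged result is P=0.5, R=1.0, F1=2/3.
def _example_measure_event_table_filling():
    event_type_roles = [('EquityFreeze', ['EquityHolder', 'FrozeShares'])]
    pred = [[[(('a',), ('b',))]]]  # doc -> event type -> record -> role args
    gold = [[[(('a',), None)]]]
    return measure_event_table_filling(pred, gold, event_type_roles)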
--------------------------------------------------------------------------------
/dee/dee_task.py:
--------------------------------------------------------------------------------
1 | # Code Reference: (https://github.com/dolphin-zs/Doc2EDAG)
2 |
3 | import logging
4 | import os
5 | import torch.optim as optim
6 | import torch.distributed as dist
7 | from itertools import product
8 |
9 | from .dee_helper import logger, DEEExample, DEEExampleLoader, DEEFeatureConverter, \
10 | convert_dee_features_to_dataset, prepare_doc_batch_dict, measure_dee_prediction, \
11 | decode_dump_template, eval_dump_template
12 | from .utils import BERTChineseCharacterTokenizer, default_dump_json, default_load_pkl
13 | from .ner_model import BertForBasicNER
14 | from .base_task import TaskSetting, BasePytorchTask
15 | from .event_type import event_type_fields_list
16 | from .dee_model import GITModel
17 |
18 |
19 | class DEETaskSetting(TaskSetting):
20 | base_key_attrs = TaskSetting.base_key_attrs
21 | base_attr_default_pairs = [
22 | ('train_file_name', 'train.json'),
23 | ('dev_file_name', 'dev.json'),
24 | ('test_file_name', 'test.json'),
25 | ('summary_dir_name', './tmp/Summary'),
26 | ('max_sent_len', 128),
27 | ('max_sent_num', 64),
28 | ('train_batch_size', 64),
29 | ('gradient_accumulation_steps', 8),
30 | ('eval_batch_size', 2),
31 | ('learning_rate', 1e-4),
32 | ('num_train_epochs', 100),
33 | ('no_cuda', False),
34 | ('local_rank', -1),
35 | ('seed', 99),
36 | ('optimize_on_cpu', False),
37 | ('fp16', False),
38 | ('bert_model', 'bert-base-chinese'), # use which pretrained bert model
39 | ('only_master_logging', True), # whether to log only from the master process
40 | ('resume_latest_cpt', True), # whether to resume latest checkpoints when training for fault tolerance
41 | ('cpt_file_name', 'GIT'), # decide the identity of checkpoints, evaluation results, etc.
42 | ('rearrange_sent', False), # whether to rearrange sentences
43 | ('use_crf_layer', True), # whether to use CRF Layer
44 | ('min_teacher_prob', 0.1), # the minimum prob to use gold spans
45 | ('schedule_epoch_start', 10), # from which epoch the scheduled sampling starts
46 | ('schedule_epoch_length', 10), # the number of epochs to linearly transition to min_teacher_prob
47 | ('loss_lambda', 0.05), # the proportion of ner loss
48 | ('loss_gamma', 1.0), # the scaling proportion of missed span sentence ner loss
49 | ('seq_reduce_type', 'MaxPooling'), # use 'MaxPooling', 'MeanPooling' or 'AWA' to reduce a tensor sequence
50 | # network parameters (follow Bert Base)
51 | ('hidden_size', 768),
52 | ('dropout', 0.1),
53 | ('ff_size', 1024), # feed-forward mid layer size
54 | ('num_tf_layers', 4), # transformer layer number
55 | # ablation study parameters,
56 | ('use_path_mem', True), # whether to use the memory module when expanding paths
57 | ('use_scheduled_sampling', True), # whether to use the scheduled sampling
58 | ('neg_field_loss_scaling', 3.0), # prefer FNs over FPs
59 | ('gcn_layer', 3), # the number of GCN layers
60 | ('ner_num_tf_layers', 8)
61 | ]
62 |
63 | def __init__(self, **kwargs):
64 | super(DEETaskSetting, self).__init__(
65 | self.base_key_attrs, self.base_attr_default_pairs, **kwargs
66 | )
67 |
68 |
69 | class DEETask(BasePytorchTask):
70 | """Doc-level Event Extraction Task"""
71 |
72 | def __init__(self, dee_setting, load_train=True, load_dev=True, load_test=True,
73 | parallel_decorate=True):
74 | super(DEETask, self).__init__(dee_setting, only_master_logging=dee_setting.only_master_logging)
75 | self.logger = logging.getLogger(self.__class__.__name__)
76 | self.logging('Initializing {}'.format(self.__class__.__name__))
77 |
78 | self.tokenizer = BERTChineseCharacterTokenizer.from_pretrained(self.setting.bert_model)
79 | self.setting.vocab_size = len(self.tokenizer.vocab)
80 |
81 | # get entity and event label name
82 | self.entity_label_list = DEEExample.get_entity_label_list()
83 | self.event_type_fields_pairs = DEEExample.get_event_type_fields_pairs() # event -> list of entity types
84 | # build example loader
85 | self.example_loader_func = DEEExampleLoader(self.setting.rearrange_sent, self.setting.max_sent_len)
86 |
87 | # build feature converter
88 | self.feature_converter_func = DEEFeatureConverter(
89 | self.entity_label_list, self.event_type_fields_pairs,
90 | self.setting.max_sent_len, self.setting.max_sent_num, self.tokenizer,
91 | include_cls=False, include_sep=False,
92 | )
93 |
94 | # LOAD DATA
95 | # self.example_loader_func: raw data -> example
96 | # self.feature_converter_func: example -> feature
97 | self._load_data(
98 | self.example_loader_func, self.feature_converter_func, convert_dee_features_to_dataset,
99 | load_train=load_train, load_dev=load_dev, load_test=load_test,
100 | )
101 | # customized mini-batch producer
102 | self.custom_collate_fn = prepare_doc_batch_dict
103 |
104 | self.setting.num_entity_labels = len(self.entity_label_list)
105 |
106 | ner_model = None
107 |
108 | self.model = GITModel(
109 | self.setting, self.event_type_fields_pairs, ner_model=ner_model,
110 | )
111 |
112 | self._decorate_model(parallel_decorate=parallel_decorate)
113 |
114 | # prepare optimizer
115 | self.optimizer = optim.Adam(self.model.parameters(), lr=self.setting.learning_rate)
116 |
117 | # # resume option
118 | # if resume_model or resume_optimizer:
119 | # self.resume_checkpoint(resume_model=resume_model, resume_optimizer=resume_optimizer)
120 |
121 | self.min_teacher_prob = None
122 | self.teacher_norm = None
123 | self.teacher_cnt = None
124 | self.teacher_base = None
125 | self.reset_teacher_prob()
126 |
127 | self.logging('Successfully initialize {}'.format(self.__class__.__name__))
128 |
129 | def reset_teacher_prob(self):
130 | self.min_teacher_prob = self.setting.min_teacher_prob
131 | if self.train_dataset is None:
132 | # avoid crashing when not loading training data
133 | num_step_per_epoch = 500
134 | else:
135 | num_step_per_epoch = int(len(self.train_dataset) / self.setting.train_batch_size)
136 | self.teacher_norm = num_step_per_epoch * self.setting.schedule_epoch_length
137 | self.teacher_base = num_step_per_epoch * self.setting.schedule_epoch_start
138 | self.teacher_cnt = 0
139 |
140 | def get_teacher_prob(self, batch_inc_flag=True):
141 | if self.teacher_cnt < self.teacher_base:
142 | prob = 1
143 | else:
144 | prob = max(
145 | self.min_teacher_prob, (self.teacher_norm - self.teacher_cnt + self.teacher_base) / self.teacher_norm
146 | )
147 |
148 | if batch_inc_flag:
149 | self.teacher_cnt += 1
150 |
151 | return prob
152 |
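# --- Illustrative note (not part of the original source): with the default
# settings (schedule_epoch_start=10, schedule_epoch_length=10,
# min_teacher_prob=0.1) and, say, 500 steps per epoch, teacher_base =
# teacher_norm = 5000, so the scheduled-sampling curve is:
#   teacher_cnt < 5000  -> prob = 1 (always use gold spans)
#   teacher_cnt = 7500  -> prob = (5000 - 7500 + 5000) / 5000 = 0.5
#   teacher_cnt >= 9500 -> prob clamps at min_teacher_prob = 0.1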
153 | def get_event_idx2entity_idx2field_idx(self):
154 | entity_idx2entity_type = {}
155 | for entity_idx, entity_label in enumerate(self.entity_label_list):
156 | if entity_label == 'O':
157 | entity_type = entity_label
158 | else:
159 | entity_type = entity_label[2:]
160 |
161 | entity_idx2entity_type[entity_idx] = entity_type
162 |
163 | event_idx2entity_idx2field_idx = {}
164 | for event_idx, (event_name, field_types) in enumerate(self.event_type_fields_pairs):
165 | field_type2field_idx = {}
166 | for field_idx, field_type in enumerate(field_types):
167 | field_type2field_idx[field_type] = field_idx
168 |
169 | entity_idx2field_idx = {}
170 | for entity_idx, entity_type in entity_idx2entity_type.items():
171 | if entity_type in field_type2field_idx:
172 | entity_idx2field_idx[entity_idx] = field_type2field_idx[entity_type]
173 | else:
174 | entity_idx2field_idx[entity_idx] = None
175 |
176 | event_idx2entity_idx2field_idx[event_idx] = entity_idx2field_idx
177 |
178 | return event_idx2entity_idx2field_idx
179 |
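# --- Illustrative note (not part of the original source): the mapping above
# strips the BIO prefix to recover the entity type, e.g. with the fields of
# EquityFreeze:
#   'B-EquityHolder'[2:] -> 'EquityHolder' -> field_idx 0
#   'I-FrozeShares'[2:]  -> 'FrozeShares'  -> field_idx 1
#   'O' matches no field type              -> None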
180 | def get_loss_on_batch(self, doc_batch_dict, features=None):
181 | if features is None:
182 | features = self.train_features
183 |
184 | # teacher_prob = 1
185 | # if use_gold_span, gold spans will be used every time
186 | # else, teacher_prob will ensure the proportion of using gold spans
187 | if self.setting.use_scheduled_sampling:
188 | use_gold_span = False
189 | teacher_prob = self.get_teacher_prob()
190 | else:
191 | use_gold_span = True
192 | teacher_prob = 1
193 |
194 | try:
195 | loss = self.model(
196 | doc_batch_dict, features, use_gold_span=use_gold_span, train_flag=True, teacher_prob=teacher_prob
197 | )
198 | except Exception as e:
199 | print('-'*30)
200 | print('Exception occurs when processing ' +
201 | ','.join([features[ex_idx].guid for ex_idx in doc_batch_dict['ex_idx']]))
202 | raise Exception('Cannot get the loss') from e
203 |
204 | return loss
205 |
206 | def get_event_decode_result_on_batch(self, doc_batch_dict, features=None, use_gold_span=False):
207 | if features is None:
209 | raise Exception('Features must be provided')
209 |
210 | event_idx2entity_idx2field_idx = None
211 | batch_eval_results = self.model(
212 | doc_batch_dict, features, use_gold_span=use_gold_span, train_flag=False,
213 | event_idx2entity_idx2field_idx=event_idx2entity_idx2field_idx,
214 | )
215 |
216 | return batch_eval_results
217 |
218 | def train(self, save_cpt_flag=True, resume_base_epoch=None):
219 | self.logging('=' * 20 + 'Start Training' + '=' * 20)
220 | self.reset_teacher_prob()
221 |
222 | # the resume_base_epoch argument takes priority over the settings
223 | if resume_base_epoch is None:
224 | # whether to resume latest cpt when restarting, very useful for preemptive scheduling clusters
225 | if self.setting.resume_latest_cpt:
226 | resume_base_epoch = self.get_latest_cpt_epoch()
227 | else:
228 | resume_base_epoch = 0
229 |
230 | # resume cpt if possible
231 | if resume_base_epoch > 0:
232 | self.logging('Training starts from epoch {}'.format(resume_base_epoch))
233 | for _ in range(resume_base_epoch):
234 | self.get_teacher_prob()
235 | self.resume_cpt_at(resume_base_epoch, resume_model=True, resume_optimizer=True)
236 | else:
237 | self.logging('Training starts from scratch')
238 |
239 | self.base_train(
240 | DEETask.get_loss_on_batch,
241 | kwargs_dict1={},
242 | epoch_eval_func=DEETask.resume_save_eval_at,
243 | kwargs_dict2={
244 | 'save_cpt_flag': save_cpt_flag,
245 | 'resume_cpt_flag': False,
246 | },
247 | base_epoch_idx=resume_base_epoch,
248 | )
249 |
250 | def resume_save_eval_at(self, epoch, resume_cpt_flag=False, save_cpt_flag=True):
251 | if self.is_master_node():
252 | print('\nPROGRESS: {:.2f}%\n'.format(epoch / self.setting.num_train_epochs * 100))
253 | self.logging('Current teacher prob {}'.format(self.get_teacher_prob(batch_inc_flag=False)))
254 |
255 | if resume_cpt_flag:
256 | self.resume_cpt_at(epoch)
257 |
258 | if self.is_master_node() and save_cpt_flag:
259 | self.save_cpt_at(epoch)
260 |
261 | eval_tasks = product(['dev', 'test'], [False, True])
262 |
263 | for task_idx, (data_type, gold_span_flag) in enumerate(eval_tasks):
264 | if self.in_distributed_mode() and task_idx % dist.get_world_size() != dist.get_rank():
265 | continue
266 |
267 | if data_type == 'test':
268 | features = self.test_features
269 | dataset = self.test_dataset
270 | elif data_type == 'dev':
271 | features = self.dev_features
272 | dataset = self.dev_dataset
273 | else:
274 | raise Exception('Unsupported data type {}'.format(data_type))
275 |
276 | if gold_span_flag:
277 | span_str = 'gold_span'
278 | else:
279 | span_str = 'pred_span'
280 |
281 | model_str = self.setting.cpt_file_name.replace('.', '~')
282 |
283 | decode_dump_name = decode_dump_template.format(data_type, span_str, model_str, epoch)
284 | eval_dump_name = eval_dump_template.format(data_type, span_str, model_str, epoch)
285 | self.eval(features, dataset, use_gold_span=gold_span_flag,
286 | dump_decode_pkl_name=decode_dump_name, dump_eval_json_name=eval_dump_name, epoch=epoch)
287 |
288 | def save_cpt_at(self, epoch):
289 | self.save_checkpoint(cpt_file_name='{}.cpt.{}'.format(self.setting.cpt_file_name, epoch), epoch=epoch)
290 |
291 | def resume_cpt_at(self, epoch, resume_model=True, resume_optimizer=False):
292 | self.resume_checkpoint(cpt_file_name='{}.cpt.{}'.format(self.setting.cpt_file_name, epoch),
293 | resume_model=resume_model, resume_optimizer=resume_optimizer)
294 |
295 | def get_latest_cpt_epoch(self):
296 | prev_epochs = []
297 | for fn in os.listdir(self.setting.model_dir):
298 | if fn.startswith('{}.cpt'.format(self.setting.cpt_file_name)):
299 | try:
300 | epoch = int(fn.split('.')[-1])
301 | prev_epochs.append(epoch)
302 | except Exception as e:
303 | continue
304 | prev_epochs.sort()
305 |
306 | if len(prev_epochs) > 0:
307 | latest_epoch = prev_epochs[-1]
308 | self.logging('Pick latest epoch {} from {}'.format(latest_epoch, str(prev_epochs)))
309 | else:
310 | latest_epoch = 0
311 | self.logging('No previous epoch checkpoints, just start from scratch')
312 |
313 | return latest_epoch
314 |
315 | def eval(self, features, dataset, use_gold_span=False,
316 | dump_decode_pkl_name=None, dump_eval_json_name=None, epoch=None):
317 | self.logging('=' * 20 + 'Start Evaluation' + '=' * 20)
318 |
319 | if dump_decode_pkl_name is not None:
320 | dump_decode_pkl_path = os.path.join(self.setting.output_dir, dump_decode_pkl_name)
321 | self.logging('Dumping decode results into {}'.format(dump_decode_pkl_name))
322 | else:
323 | dump_decode_pkl_path = None
324 |
325 | total_event_decode_results = self.base_eval(
326 | dataset, DEETask.get_event_decode_result_on_batch,
327 | reduce_info_type='none', dump_pkl_path=dump_decode_pkl_path,
328 | features=features, use_gold_span=use_gold_span,
329 | )
330 |
331 | self.logging('Measure DEE Prediction')
332 |
333 | if dump_eval_json_name is not None:
334 | dump_eval_json_path = os.path.join(self.setting.output_dir, dump_eval_json_name)
335 | self.logging('Dumping eval results into {}'.format(dump_eval_json_name))
336 | else:
337 | dump_eval_json_path = None
338 |
339 | total_eval_res = measure_dee_prediction(
340 | self.event_type_fields_pairs, features, total_event_decode_results,
341 | dump_json_path=dump_eval_json_path, writer=self.summary_writer, epoch=epoch
342 | )
343 |
344 | return total_event_decode_results, total_eval_res
345 |
346 | def reevaluate_dee_prediction(self, target_file_pre='dee_eval', target_file_suffix='.pkl',
347 | dump_flag=False):
348 | """Enumerate the evaluation directory to collect all dumped evaluation results"""
349 | eval_dir_path = self.setting.output_dir
350 | logger.info('Re-evaluate dee predictions from {}'.format(eval_dir_path))
351 | data_span_type2model_str2epoch_res_list = {}
352 | for fn in os.listdir(eval_dir_path):
353 | fn_splits = fn.split('.')
354 | if fn.startswith(target_file_pre) and fn.endswith(target_file_suffix) and len(fn_splits) == 6:
355 | _, data_type, span_type, model_str, epoch, _ = fn_splits
356 |
357 | data_span_type = (data_type, span_type)
358 | if data_span_type not in data_span_type2model_str2epoch_res_list:
359 | data_span_type2model_str2epoch_res_list[data_span_type] = {}
360 | model_str2epoch_res_list = data_span_type2model_str2epoch_res_list[data_span_type]
361 |
362 | if model_str not in model_str2epoch_res_list:
363 | model_str2epoch_res_list[model_str] = []
364 | epoch_res_list = model_str2epoch_res_list[model_str]
365 |
366 | if data_type == 'dev':
367 | features = self.dev_features
368 | elif data_type == 'test':
369 | features = self.test_features
370 | else:
371 | raise Exception('Unsupported data type {}'.format(data_type))
372 |
373 | epoch = int(epoch)
374 | fp = os.path.join(eval_dir_path, fn)
375 | self.logging('Re-evaluating {}'.format(fp))
376 | event_decode_results = default_load_pkl(fp)
377 | total_eval_res = measure_dee_prediction(
378 | event_type_fields_list, features, event_decode_results
379 | )
380 |
381 | if dump_flag:
382 | fp = fp[:-len('.pkl')] + '.json' # rstrip('.pkl') would strip characters, not the suffix
383 | self.logging('Dumping {}'.format(fp))
384 | default_dump_json(total_eval_res, fp)
385 |
386 | epoch_res_list.append((epoch, total_eval_res))
387 |
388 | for data_span_type, model_str2epoch_res_list in data_span_type2model_str2epoch_res_list.items():
389 | for model_str, epoch_res_list in model_str2epoch_res_list.items():
390 | epoch_res_list.sort(key=lambda x: x[0])
391 |
392 | return data_span_type2model_str2epoch_res_list
393 |
394 |
395 |
--------------------------------------------------------------------------------
/dee/event_type.py:
--------------------------------------------------------------------------------
1 | # Code Reference: (https://github.com/dolphin-zs/Doc2EDAG)
2 |
3 |
4 | class BaseEvent(object):
5 | def __init__(self, fields, event_name='Event', key_fields=(), recguid=None):
6 | self.recguid = recguid
7 | self.name = event_name
8 | self.fields = list(fields)
9 | self.field2content = {f: None for f in fields}
10 | self.nonempty_count = 0
11 | self.nonempty_ratio = self.nonempty_count / len(self.fields)
12 |
13 | self.key_fields = set(key_fields)
14 | for key_field in self.key_fields:
15 | assert key_field in self.field2content
16 |
17 | def __repr__(self):
18 | event_str = "\n{}[\n".format(self.name)
19 | event_str += " {}={}\n".format("recguid", self.recguid)
20 | event_str += " {}={}\n".format("nonempty_count", self.nonempty_count)
21 | event_str += " {}={:.3f}\n".format("nonempty_ratio", self.nonempty_ratio)
22 | event_str += "] (\n"
23 | for field in self.fields:
24 | if field in self.key_fields:
25 | key_str = " (key)"
26 | else:
27 | key_str = ""
28 | event_str += " " + field + "=" + str(self.field2content[field]) + ", {}\n".format(key_str)
29 | event_str += ")\n"
30 | return event_str
31 |
32 | def update_by_dict(self, field2text, recguid=None):
33 | self.nonempty_count = 0
34 | self.recguid = recguid
35 |
36 | for field in self.fields:
37 | if field in field2text and field2text[field] is not None:
38 | self.nonempty_count += 1
39 | self.field2content[field] = field2text[field]
40 | else:
41 | self.field2content[field] = None
42 |
43 | self.nonempty_ratio = self.nonempty_count / len(self.fields)
44 |
45 | def field_to_dict(self):
46 | return dict(self.field2content)
47 |
48 | def set_key_fields(self, key_fields):
49 | self.key_fields = set(key_fields)
50 |
51 | def is_key_complete(self):
52 | for key_field in self.key_fields:
53 | if self.field2content[key_field] is None:
54 | return False
55 |
56 | return True
57 |
58 | def is_good_candidate(self):
59 | raise NotImplementedError()
60 |
61 | def get_argument_tuple(self):
62 | args_tuple = tuple(self.field2content[field] for field in self.fields)
63 | return args_tuple
64 |
65 | class EquityFreezeEvent(BaseEvent):
66 | NAME = 'EquityFreeze'
67 | FIELDS = [
68 | 'EquityHolder',
69 | 'FrozeShares',
70 | 'LegalInstitution',
71 | 'TotalHoldingShares',
72 | 'TotalHoldingRatio',
73 | 'StartDate',
74 | 'EndDate',
75 | 'UnfrozeDate',
76 | ]
77 |
78 | def __init__(self, recguid=None):
79 | super().__init__(
80 | EquityFreezeEvent.FIELDS, event_name=EquityFreezeEvent.NAME, recguid=recguid
81 | )
82 | self.set_key_fields([
83 | 'EquityHolder',
84 | 'FrozeShares',
85 | 'LegalInstitution',
86 | ])
87 |
88 | def is_good_candidate(self, min_match_count=5):
89 | key_flag = self.is_key_complete()
90 | if key_flag:
91 | if self.nonempty_count >= min_match_count:
92 | return True
93 | return False
94 |
95 | class EquityRepurchaseEvent(BaseEvent):
96 | NAME = 'EquityRepurchase'
97 | FIELDS = [
98 | 'CompanyName',
99 | 'HighestTradingPrice',
100 | 'LowestTradingPrice',
101 | 'RepurchasedShares',
102 | 'ClosingDate',
103 | 'RepurchaseAmount',
104 | ]
105 |
106 | def __init__(self, recguid=None):
107 | super().__init__(
108 | EquityRepurchaseEvent.FIELDS, event_name=EquityRepurchaseEvent.NAME, recguid=recguid
109 | )
110 | self.set_key_fields([
111 | 'CompanyName',
112 | ])
113 |
114 | def is_good_candidate(self, min_match_count=4):
115 | key_flag = self.is_key_complete()
116 | if key_flag:
117 | if self.nonempty_count >= min_match_count:
118 | return True
119 | return False
120 |
121 | class EquityUnderweightEvent(BaseEvent):
122 | NAME = 'EquityUnderweight'
123 | FIELDS = [
124 | 'EquityHolder',
125 | 'TradedShares',
126 | 'StartDate',
127 | 'EndDate',
128 | 'LaterHoldingShares',
129 | 'AveragePrice',
130 | ]
131 |
132 | def __init__(self, recguid=None):
133 | super().__init__(
134 | EquityUnderweightEvent.FIELDS, event_name=EquityUnderweightEvent.NAME, recguid=recguid
135 | )
136 | self.set_key_fields([
137 | 'EquityHolder',
138 | 'TradedShares',
139 | ])
140 |
141 | def is_good_candidate(self, min_match_count=4):
142 | key_flag = self.is_key_complete()
143 | if key_flag:
144 | if self.nonempty_count >= min_match_count:
145 | return True
146 | return False
147 |
148 | class EquityOverweightEvent(BaseEvent):
149 | NAME = 'EquityOverweight'
150 | FIELDS = [
151 | 'EquityHolder',
152 | 'TradedShares',
153 | 'StartDate',
154 | 'EndDate',
155 | 'LaterHoldingShares',
156 | 'AveragePrice',
157 | ]
158 |
159 | def __init__(self, recguid=None):
160 | super().__init__(
161 | EquityOverweightEvent.FIELDS, event_name=EquityOverweightEvent.NAME, recguid=recguid
162 | )
163 | self.set_key_fields([
164 | 'EquityHolder',
165 | 'TradedShares',
166 | ])
167 |
168 | def is_good_candidate(self, min_match_count=4):
169 | key_flag = self.is_key_complete()
170 | if key_flag:
171 | if self.nonempty_count >= min_match_count:
172 | return True
173 | return False
174 |
175 | class EquityPledgeEvent(BaseEvent):
176 | NAME = 'EquityPledge'
177 | FIELDS = [
178 | 'Pledger',
179 | 'PledgedShares',
180 | 'Pledgee',
181 | 'TotalHoldingShares',
182 | 'TotalHoldingRatio',
183 | 'TotalPledgedShares',
184 | 'StartDate',
185 | 'EndDate',
186 | 'ReleasedDate',
187 | ]
188 |
189 | def __init__(self, recguid=None):
190 | # super(EquityPledgeEvent, self).__init__(
191 | super().__init__(
192 | EquityPledgeEvent.FIELDS, event_name=EquityPledgeEvent.NAME, recguid=recguid
193 | )
194 | self.set_key_fields([
195 | 'Pledger',
196 | 'PledgedShares',
197 | 'Pledgee',
198 | ])
199 |
200 | def is_good_candidate(self, min_match_count=5):
201 | key_flag = self.is_key_complete()
202 | if key_flag:
203 | if self.nonempty_count >= min_match_count:
204 | return True
205 | return False
206 |
207 | common_fields = ['StockCode', 'StockAbbr', 'CompanyName']
208 |
209 | event_type2event_class = {
210 | EquityFreezeEvent.NAME: EquityFreezeEvent,
211 | EquityRepurchaseEvent.NAME: EquityRepurchaseEvent,
212 | EquityUnderweightEvent.NAME: EquityUnderweightEvent,
213 | EquityOverweightEvent.NAME: EquityOverweightEvent,
214 | EquityPledgeEvent.NAME: EquityPledgeEvent,
215 | }
216 |
217 | event_type_fields_list = [
218 | (EquityFreezeEvent.NAME, EquityFreezeEvent.FIELDS),
219 | (EquityRepurchaseEvent.NAME, EquityRepurchaseEvent.FIELDS),
220 | (EquityUnderweightEvent.NAME, EquityUnderweightEvent.FIELDS),
221 | (EquityOverweightEvent.NAME, EquityOverweightEvent.FIELDS),
222 | (EquityPledgeEvent.NAME, EquityPledgeEvent.FIELDS),
223 | ]
224 |
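# --- Illustrative sketch (not part of the original source): filling a record
# and checking its candidate status with the helpers defined above.
def _example_equity_freeze_event():
    ev = EquityFreezeEvent(recguid=0)
    ev.update_by_dict({'EquityHolder': 'X', 'FrozeShares': '100',
                       'LegalInstitution': 'Court'})
    # all three key fields are filled, but nonempty_count == 3 < min_match_count=5
    return ev.is_key_complete(), ev.is_good_candidate()  # -> (True, False)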
--------------------------------------------------------------------------------
/dee/ner_model.py:
--------------------------------------------------------------------------------
1 | # Code Reference: (https://github.com/dolphin-zs/Doc2EDAG)
2 |
3 | import torch
4 | from torch import nn
5 | import torch.nn.functional as F
6 | import math
7 |
8 | from pytorch_pretrained_bert.modeling import PreTrainedBertModel, BertModel
9 |
10 | from . import transformer
11 |
12 |
13 | class BertForBasicNER(PreTrainedBertModel):
14 | """BERT model for basic NER functionality.
15 | This module is composed of the BERT model with a linear layer on top of
16 | the output sequences.
17 |
18 | Params:
19 | `config`: a BertConfig class instance with the configuration to build a new model.
20 | `num_entity_labels`: the number of entity classes for the classifier.
21 |
22 | Inputs:
23 | `input_ids`: a torch.LongTensor of shape [batch_size, sequence_length]
24 | with the word token indices in the vocabulary.
25 | `token_type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the token
26 | types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to
27 | a `sentence B` token (see BERT paper for more details).
28 | `attention_mask`: an optional torch.LongTensor of shape [batch_size, sequence_length] with indices
29 | selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max
30 | input sequence length in the current batch. It's the mask that we typically use for attention when
31 | a batch has varying length sentences.
32 | `label_ids`: a torch.LongTensor of shape [batch_size, sequence_length]
33 | with label indices selected in [0, ..., num_labels-1].
34 |
35 | Outputs:
36 | if `label_ids` is not `None`:
37 | Outputs the CrossEntropy classification loss of the output with the labels.
38 | if `label_ids` is `None`:
39 | Outputs the classification logits sequence.
40 | """
41 |
42 | def __init__(self, config, num_entity_labels):
43 | super(BertForBasicNER, self).__init__(config)
44 | self.bert = BertModel(config)
45 |
46 | self.dropout = nn.Dropout(config.hidden_dropout_prob)
47 | self.classifier = nn.Linear(config.hidden_size, num_entity_labels)
48 | self.apply(self.init_bert_weights)
49 |
50 | self.num_entity_labels = num_entity_labels
51 |
52 | def old_forward(self, input_ids, input_masks,
53 | token_type_ids=None, label_ids=None,
54 | eval_flag=False, eval_for_metric=True):
55 | """Assume input size [batch_size, seq_len]"""
56 | if input_masks.dtype != torch.uint8:
57 | input_masks = input_masks == 1
58 |
59 | enc_seq_out, _ = self.bert(input_ids,
60 | token_type_ids=token_type_ids,
61 | attention_mask=input_masks,
62 | output_all_encoded_layers=False)
63 | # [batch_size, seq_len, hidden_size]
64 | enc_seq_out = self.dropout(enc_seq_out)
65 | # [batch_size, seq_len, num_entity_labels]
66 | seq_logits = self.classifier(enc_seq_out)
67 |
68 | if eval_flag: # if for evaluation purpose
69 | if label_ids is None:
70 | raise Exception('Cannot do evaluation without label info')
71 | else:
72 | if eval_for_metric:
73 | batch_metrics = produce_ner_batch_metrics(seq_logits, label_ids, input_masks)
74 | return batch_metrics
75 | else:
76 | seq_logp = F.log_softmax(seq_logits, dim=-1)
77 | seq_pred = seq_logp.argmax(dim=-1, keepdim=True) # [batch_size, seq_len, 1]
78 | seq_gold = label_ids.unsqueeze(-1) # [batch_size, seq_len, 1]
79 | seq_mask = input_masks.unsqueeze(-1).long() # [batch_size, seq_len, 1]
80 | seq_pred_gold_mask = torch.cat([seq_pred, seq_gold, seq_mask], dim=-1) # [batch_size, seq_len, 3]
81 | return seq_pred_gold_mask
82 | elif label_ids is not None: # if has label_ids, calculate the loss
83 | # [num_valid_token, num_entity_labels]
84 | batch_logits = seq_logits[input_masks, :]
85 | # [num_valid_token], lid \in {0,..., num_entity_labels-1}
86 | batch_labels = label_ids[input_masks]
87 | loss = F.cross_entropy(batch_logits, batch_labels)
88 | return loss, enc_seq_out
89 | else: # just return seq_pred_logps
90 | return F.log_softmax(seq_logits, dim=-1), enc_seq_out
91 |
92 | def forward(self, input_ids, input_masks,
93 | label_ids=None, train_flag=True, decode_flag=True):
94 | """Assume input size [batch_size, seq_len]"""
95 | if input_masks.dtype != torch.uint8:
96 | input_masks = input_masks == 1
97 |
98 | batch_seq_enc, _ = self.bert(input_ids,
99 | attention_mask=input_masks,
100 | output_all_encoded_layers=False)
101 | # [batch_size, seq_len, hidden_size]
102 | batch_seq_enc = self.dropout(batch_seq_enc)
103 | # [batch_size, seq_len, num_entity_labels]
104 | batch_seq_logits = self.classifier(batch_seq_enc)
105 |
106 | batch_seq_logp = F.log_softmax(batch_seq_logits, dim=-1)
107 |
108 | if train_flag:
109 | batch_logp = batch_seq_logp.view(-1, batch_seq_logp.size(-1))
110 | batch_label = label_ids.view(-1)
111 | # ner_loss = F.nll_loss(batch_logp, batch_label, reduction='sum')
112 | ner_loss = F.nll_loss(batch_logp, batch_label, reduction='none')
113 | ner_loss = ner_loss.view(label_ids.size()).sum(dim=-1) # [batch_size]
114 | else:
115 | ner_loss = None
116 |
117 | if decode_flag:
118 | batch_seq_preds = batch_seq_logp.argmax(dim=-1)
119 | else:
120 | batch_seq_preds = None
121 |
122 | return batch_seq_enc, ner_loss, batch_seq_preds
123 |
124 | class NERModel(nn.Module):
125 | def __init__(self, config):
126 | super(NERModel, self).__init__()
127 |
128 | self.config = config
129 | # Word Embedding, Word Local Position Embedding
130 | self.token_embedding = NERTokenEmbedding(
131 | config.vocab_size, config.hidden_size,
132 | max_sent_len=config.max_sent_len, dropout=config.dropout
133 | )
134 | # Multi-layer Transformer Layers to Incorporate Contextual Information
135 | self.token_encoder = transformer.make_transformer_encoder(
136 | config.ner_num_tf_layers, config.hidden_size, ff_size=config.ff_size, dropout=config.dropout
137 | )
138 | if self.config.use_crf_layer:
139 | self.crf_layer = CRFLayer(config.hidden_size, self.config.num_entity_labels)
140 | else:
141 | # Token Label Classification
142 | self.classifier = nn.Linear(config.hidden_size, self.config.num_entity_labels)
143 |
144 | def forward(self, input_ids, input_masks,
145 | label_ids=None, train_flag=True, decode_flag=True):
146 | """Assume input size [batch_size, seq_len]"""
147 | if input_masks.dtype != torch.uint8:
148 | input_masks = input_masks == 1
149 | if train_flag:
150 | assert label_ids is not None
151 |
152 | # get contextual info
153 | input_emb = self.token_embedding(input_ids)
154 |         input_masks = input_masks.unsqueeze(-2)  # reshape to fit the transformer attention-mask convention
155 | batch_seq_enc = self.token_encoder(input_emb, input_masks)
156 |
157 | if self.config.use_crf_layer:
158 | ner_loss, batch_seq_preds = self.crf_layer(
159 | batch_seq_enc, seq_token_label=label_ids, batch_first=True,
160 | train_flag=train_flag, decode_flag=decode_flag
161 | )
162 | else:
163 | # [batch_size, seq_len, num_entity_labels]
164 | batch_seq_logits = self.classifier(batch_seq_enc)
165 | batch_seq_logp = F.log_softmax(batch_seq_logits, dim=-1)
166 |
167 | if train_flag:
168 | batch_logp = batch_seq_logp.view(-1, batch_seq_logp.size(-1))
169 | batch_label = label_ids.view(-1)
170 | # ner_loss = F.nll_loss(batch_logp, batch_label, reduction='sum')
171 | ner_loss = F.nll_loss(batch_logp, batch_label, reduction='none')
172 | ner_loss = ner_loss.view(label_ids.size()).sum(dim=-1) # [batch_size]
173 | else:
174 | ner_loss = None
175 |
176 | if decode_flag:
177 | batch_seq_preds = batch_seq_logp.argmax(dim=-1)
178 | else:
179 | batch_seq_preds = None
180 |
181 | return batch_seq_enc, ner_loss, batch_seq_preds
182 |
183 | class NERTokenEmbedding(nn.Module):
184 | """Add token position information"""
185 | def __init__(self, vocab_size, hidden_size, max_sent_len=256, dropout=0.1):
186 | super(NERTokenEmbedding, self).__init__()
187 |
188 | self.token_embedding = nn.Embedding(vocab_size, hidden_size)
189 | self.pos_embedding = nn.Embedding(max_sent_len, hidden_size)
190 |
191 | self.layer_norm = transformer.LayerNorm(hidden_size)
192 | self.dropout = nn.Dropout(dropout)
193 |
194 | def forward(self, batch_token_ids):
195 | batch_size, sent_len = batch_token_ids.size()
196 | device = batch_token_ids.device
197 |
198 | batch_pos_ids = torch.arange(
199 | sent_len, dtype=torch.long, device=device, requires_grad=False
200 | )
201 | batch_pos_ids = batch_pos_ids.unsqueeze(0).expand_as(batch_token_ids)
202 |
203 | batch_token_emb = self.token_embedding(batch_token_ids)
204 | batch_pos_emb = self.pos_embedding(batch_pos_ids)
205 |
206 | batch_token_emb = batch_token_emb + batch_pos_emb
207 |
208 | batch_token_out = self.layer_norm(batch_token_emb)
209 | batch_token_out = self.dropout(batch_token_out)
210 |
211 | return batch_token_out
212 |
213 | class CRFLayer(nn.Module):
214 |     """
215 |     Conditional Random Field Layer
216 |     Reference:
217 |         https://pytorch.org/tutorials/beginner/nlp/advanced_tutorial.html#sphx-glr-beginner-nlp-advanced-tutorial-py
218 |     The original example code operates on a single sequence, while this version operates on a whole batch
219 |     """
220 |     NEG_LOGIT = -100000.
221 |
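    |     # Training objective (standard linear-chain CRF): for emission scores x and gold
    |     # tag sequence y,  nll(y | x) = log Z(x) - score(x, y),  where score(x, y) sums
    |     # transition scores trans[y_t, y_{t-1}] (with added start/end tags) and emission
    |     # scores emit[t, y_t]; log Z(x) is computed by get_log_partition below.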
222 | def __init__(self, hidden_size, num_entity_labels):
223 | super(CRFLayer, self).__init__()
224 |
225 | self.tag_size = num_entity_labels + 2 # add start tag and end tag
226 | self.start_tag = self.tag_size - 2
227 | self.end_tag = self.tag_size - 1
228 |
229 | # Map token-level hidden state into tag scores
230 | self.hidden2tag = nn.Linear(hidden_size, self.tag_size)
231 | # Transition Matrix
232 | # [i, j] denotes transitioning from j to i
233 | self.trans_mat = nn.Parameter(torch.randn(self.tag_size, self.tag_size))
234 | self.reset_trans_mat()
235 |
236 | def reset_trans_mat(self):
237 | nn.init.kaiming_uniform_(self.trans_mat, a=math.sqrt(5)) # copy from Linear init
238 |         # these entries are pinned to a large negative score: transitions into the start tag or out of the end tag are forbidden
239 | self.trans_mat.data[self.start_tag, :] = self.NEG_LOGIT
240 | self.trans_mat.data[:, self.end_tag] = self.NEG_LOGIT
241 |
242 |     def get_log_partition(self, seq_emit_score):
243 | """
244 | Calculate the log of the partition function
245 | :param seq_emit_score: [seq_len, batch_size, tag_size]
246 | :return: Tensor with Size([batch_size])
247 | """
248 | seq_len, batch_size, tag_size = seq_emit_score.size()
249 | # dynamic programming table to store previously summarized tag logits
250 | dp_table = seq_emit_score.new_full(
251 | (batch_size, tag_size), self.NEG_LOGIT, requires_grad=False
252 | )
253 | dp_table[:, self.start_tag] = 0.
254 |
255 | batch_trans_mat = self.trans_mat.unsqueeze(0).expand(batch_size, tag_size, tag_size)
256 |
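    |         # forward-algorithm recurrence in log space:
    |         #   dp[b, j] <- logsumexp_i( dp[b, i] + trans[j, i] + emit[t, b, j] )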
257 | for token_idx in range(seq_len):
258 | prev_logit = dp_table.unsqueeze(1) # [batch_size, 1, tag_size]
259 | batch_emit_score = seq_emit_score[token_idx].unsqueeze(-1) # [batch_size, tag_size, 1]
260 | cur_logit = batch_trans_mat + batch_emit_score + prev_logit # [batch_size, tag_size, tag_size]
261 | dp_table = log_sum_exp(cur_logit) # [batch_size, tag_size]
262 | batch_logit = dp_table + self.trans_mat[self.end_tag, :].unsqueeze(0)
263 | log_partition = log_sum_exp(batch_logit) # [batch_size]
264 |
265 | return log_partition
266 |
267 | def get_gold_score(self, seq_emit_score, seq_token_label):
268 | """
269 | Calculate the score of the given sequence label
270 | :param seq_emit_score: [seq_len, batch_size, tag_size]
271 | :param seq_token_label: [seq_len, batch_size]
272 | :return: Tensor with Size([batch_size])
273 | """
274 | seq_len, batch_size, tag_size = seq_emit_score.size()
275 |
276 | end_token_label = seq_token_label.new_full(
277 | (1, batch_size), self.end_tag, requires_grad=False
278 | )
279 | seq_cur_label = torch.cat(
280 | [seq_token_label, end_token_label], dim=0
281 | ).unsqueeze(-1).unsqueeze(-1).expand(seq_len+1, batch_size, 1, tag_size)
282 |
283 | start_token_label = seq_token_label.new_full(
284 | (1, batch_size), self.start_tag, requires_grad=False
285 | )
286 | seq_prev_label = torch.cat(
287 | [start_token_label, seq_token_label], dim=0
288 | ).unsqueeze(-1).unsqueeze(-1) # [seq_len+1, batch_size, 1, 1]
289 |
290 | seq_trans_score = self.trans_mat.unsqueeze(0).unsqueeze(0).expand(seq_len+1, batch_size, tag_size, tag_size)
291 | # gather according to token label at the current token
292 | gold_trans_score = torch.gather(seq_trans_score, 2, seq_cur_label) # [seq_len+1, batch_size, 1, tag_size]
293 | # gather according to token label at the previous token
294 | gold_trans_score = torch.gather(gold_trans_score, 3, seq_prev_label) # [seq_len+1, batch_size, 1, 1]
295 | batch_trans_score = gold_trans_score.sum(dim=0).squeeze(-1).squeeze(-1) # [batch_size]
296 |
297 | gold_emit_score = torch.gather(seq_emit_score, 2, seq_token_label.unsqueeze(-1)) # [seq_len, batch_size, 1]
298 | batch_emit_score = gold_emit_score.sum(dim=0).squeeze(-1) # [batch_size]
299 |
300 | gold_score = batch_trans_score + batch_emit_score # [batch_size]
301 |
302 | return gold_score
303 |
304 | def viterbi_decode(self, seq_emit_score):
305 | """
306 | Use viterbi decoding to get prediction
307 | :param seq_emit_score: [seq_len, batch_size, tag_size]
308 | :return:
309 | batch_best_path: [batch_size, seq_len], the best tag for each token
310 | batch_best_score: [batch_size], the corresponding score for each path
311 | """
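    |         # Viterbi recurrence (max-product analogue of the forward algorithm):
    |         #   dp[b, j] <- max_i( dp[b, i] + trans[j, i] ) + emit[t, b, j],
    |         # with the argmax over i kept as a backpointer for path recovery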
312 | seq_len, batch_size, tag_size = seq_emit_score.size()
313 |
314 | dp_table = seq_emit_score.new_full((batch_size, tag_size), self.NEG_LOGIT, requires_grad=False)
315 | dp_table[:, self.start_tag] = 0
316 | backpointers = []
317 |
318 | for token_idx in range(seq_len):
319 | last_tag_score = dp_table.unsqueeze(-2) # [batch_size, 1, tag_size]
320 | batch_trans_mat = self.trans_mat.unsqueeze(0).expand(batch_size, tag_size, tag_size)
321 | cur_emit_score = seq_emit_score[token_idx].unsqueeze(-1) # [batch_size, tag_size, 1]
322 | cur_trans_score = batch_trans_mat + last_tag_score + cur_emit_score # [batch_size, tag_size, tag_size]
323 | dp_table, cur_tag_bp = cur_trans_score.max(dim=-1) # [batch_size, tag_size]
324 | backpointers.append(cur_tag_bp)
325 | # transition to the end tag
326 | last_trans_arr = self.trans_mat[self.end_tag].unsqueeze(0).expand(batch_size, tag_size)
327 | dp_table = dp_table + last_trans_arr
328 |
329 | # get the best path score and the best tag of the last token
330 | batch_best_score, best_tag = dp_table.max(dim=-1) # [batch_size]
331 | best_tag = best_tag.unsqueeze(-1) # [batch_size, 1]
332 | best_tag_list = [best_tag]
333 | # reversely traverse back pointers to recover the best path
334 | for last_tag_bp in reversed(backpointers):
335 | # best_tag Size([batch_size, 1]) records the current tag that can own the highest score
336 | # last_tag_bp Size([batch_size, tag_size]) records the last best tag that the current tag is based on
337 | best_tag = torch.gather(last_tag_bp, 1, best_tag) # [batch_size, 1]
338 | best_tag_list.append(best_tag)
339 | batch_start = best_tag_list.pop()
340 | assert (batch_start == self.start_tag).sum().item() == batch_size
341 | best_tag_list.reverse()
342 | batch_best_path = torch.cat(best_tag_list, dim=-1) # [batch_size, seq_len]
343 |
344 | return batch_best_path, batch_best_score
345 |
346 | def forward(self, seq_token_emb, seq_token_label=None, batch_first=False,
347 | train_flag=True, decode_flag=True):
348 | """
349 | Get loss and prediction with CRF support.
350 | :param seq_token_emb: assume size [seq_len, batch_size, hidden_size] if not batch_first
351 | :param seq_token_label: assume size [seq_len, batch_size] if not batch_first
352 | :param batch_first: Flag to denote the meaning of the first dimension
353 | :param train_flag: whether to calculate the loss
354 | :param decode_flag: whether to decode the path based on current parameters
355 | :return:
356 | nll_loss: negative log-likelihood loss
357 |             batch_best_path: sequence predictions from Viterbi decoding
358 | """
359 | if batch_first:
360 | # CRF assumes the input size of [seq_len, batch_size, hidden_size]
361 | seq_token_emb = seq_token_emb.transpose(0, 1).contiguous()
362 | if seq_token_label is not None:
363 | seq_token_label = seq_token_label.transpose(0, 1).contiguous()
364 |
365 | seq_emit_score = self.hidden2tag(seq_token_emb) # [seq_len, batch_size, tag_size]
366 | if train_flag:
367 | gold_score = self.get_gold_score(seq_emit_score, seq_token_label) # [batch_size]
368 |             log_partition = self.get_log_partition(seq_emit_score)  # [batch_size]
369 | nll_loss = log_partition - gold_score
370 | else:
371 | nll_loss = None
372 |
373 | if decode_flag:
374 | # Use viterbi decoding to get the current prediction
375 | # no matter what batch_first is, return size is [batch_size, seq_len]
376 | batch_best_path, batch_best_score = self.viterbi_decode(seq_emit_score)
377 | else:
378 | batch_best_path = None
379 |
380 | return nll_loss, batch_best_path
381 |
382 | # Compute log sum exp in a numerically stable way
383 | def log_sum_exp(batch_logit):
384 | """
385 |     Calculate the log-sum-exp operation for the last dimension.
386 | :param batch_logit: Size([*, logit_size]), * should at least be 1
387 | :return: Size([*])
388 | """
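    |     # worked example: for x = [1000., 1001.], a naive log(sum(exp(x))) overflows,
    |     # while the max-shifted form gives 1000 + log(exp(0) + exp(1)) ~= 1001.3133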
389 | batch_max, _ = batch_logit.max(dim=-1)
390 | batch_broadcast = batch_max.unsqueeze(-1)
391 | return batch_max + \
392 | torch.log(torch.sum(torch.exp(batch_logit - batch_broadcast), dim=-1))
393 |
394 | def produce_ner_batch_metrics(seq_logits, gold_labels, masks):
395 | # seq_logits: [batch_size, seq_len, num_entity_labels]
396 | # gold_labels: [batch_size, seq_len]
397 | # masks: [batch_size, seq_len]
398 | batch_size, seq_len, num_entities = seq_logits.size()
399 |
400 | # [batch_size, seq_len, num_entity_labels]
401 | seq_logp = F.log_softmax(seq_logits, dim=-1)
402 | # [batch_size, seq_len]
403 | pred_labels = seq_logp.argmax(dim=-1)
404 | # [batch_size*seq_len, num_entity_labels]
405 | token_logp = seq_logp.view(-1, num_entities)
406 | # [batch_size*seq_len]
407 | token_labels = gold_labels.view(-1)
408 | # [batch_size, seq_len]
409 | seq_token_loss = F.nll_loss(token_logp, token_labels, reduction='none').view(batch_size, seq_len)
410 |
411 | batch_metrics = []
412 | for bid in range(batch_size):
413 | ex_loss = seq_token_loss[bid, masks[bid]].mean().item()
414 | ex_acc = (pred_labels[bid, masks[bid]] == gold_labels[bid, masks[bid]]).float().mean().item()
415 | ex_pred_lids = pred_labels[bid, masks[bid]].tolist()
416 | ex_gold_lids = gold_labels[bid, masks[bid]].tolist()
417 | ner_tp_set, ner_fp_set, ner_fn_set = judge_ner_prediction(ex_pred_lids, ex_gold_lids)
418 | batch_metrics.append([ex_loss, ex_acc, len(ner_tp_set), len(ner_fp_set), len(ner_fn_set)])
419 |
420 | return torch.tensor(batch_metrics, dtype=torch.float, device=seq_logits.device)
421 |
422 | def judge_ner_prediction(pred_label_ids, gold_label_ids):
423 |     """Relies on a strong label-id assumption: 0 is 'O', an odd id marks an entity begin (B-), and the next even id marks its inside (I-)"""
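    |     # worked example: pred ids [0, 1, 2, 2, 0, 3, 4] decode to the span set
    |     # {(1, 4, 1), (5, 7, 3)}, i.e. (start, end-exclusive, begin-label-id) triples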
424 | if isinstance(pred_label_ids, torch.Tensor):
425 | pred_label_ids = pred_label_ids.tolist()
426 | if isinstance(gold_label_ids, torch.Tensor):
427 | gold_label_ids = gold_label_ids.tolist()
428 | # element: (ner_start_index, ner_end_index, ner_type_id)
429 | pred_ner_set = set()
430 | gold_ner_set = set()
431 |
432 | pred_ner_sid = None
433 | for idx, ner in enumerate(pred_label_ids):
434 | if pred_ner_sid is None:
435 | if ner % 2 == 1:
436 | pred_ner_sid = idx
437 | continue
438 | else:
439 | prev_ner = pred_label_ids[pred_ner_sid]
440 | if ner == 0:
441 | pred_ner_set.add((pred_ner_sid, idx, prev_ner))
442 | pred_ner_sid = None
443 | continue
444 | elif ner == prev_ner + 1: # same entity
445 | continue
446 | elif ner % 2 == 1:
447 | pred_ner_set.add((pred_ner_sid, idx, prev_ner))
448 | pred_ner_sid = idx
449 | continue
450 |                 else:  # an I- tag that breaks the current entity: close the span and skip the invalid tag
451 | pred_ner_set.add((pred_ner_sid, idx, prev_ner))
452 | pred_ner_sid = None
453 | pass
454 | if pred_ner_sid is not None:
455 | prev_ner = pred_label_ids[pred_ner_sid]
456 | pred_ner_set.add((pred_ner_sid, len(pred_label_ids), prev_ner))
457 |
458 | gold_ner_sid = None
459 | for idx, ner in enumerate(gold_label_ids):
460 | if gold_ner_sid is None:
461 | if ner % 2 == 1:
462 | gold_ner_sid = idx
463 | continue
464 | else:
465 | prev_ner = gold_label_ids[gold_ner_sid]
466 | if ner == 0:
467 | gold_ner_set.add((gold_ner_sid, idx, prev_ner))
468 | gold_ner_sid = None
469 | continue
470 | elif ner == prev_ner + 1: # same entity
471 | continue
472 | elif ner % 2 == 1:
473 | gold_ner_set.add((gold_ner_sid, idx, prev_ner))
474 | gold_ner_sid = idx
475 | continue
476 |                 else:  # an I- tag that breaks the current entity: close the span and skip the invalid tag
477 | gold_ner_set.add((gold_ner_sid, idx, prev_ner))
478 | gold_ner_sid = None
479 | pass
480 | if gold_ner_sid is not None:
481 | prev_ner = gold_label_ids[gold_ner_sid]
482 | gold_ner_set.add((gold_ner_sid, len(gold_label_ids), prev_ner))
483 |
484 | ner_tp_set = pred_ner_set.intersection(gold_ner_set)
485 | ner_fp_set = pred_ner_set - gold_ner_set
486 | ner_fn_set = gold_ner_set - pred_ner_set
487 |
488 | return ner_tp_set, ner_fp_set, ner_fn_set
489 |
--------------------------------------------------------------------------------
/dee/ner_task.py:
--------------------------------------------------------------------------------
1 | # Code Reference: (https://github.com/dolphin-zs/Doc2EDAG)
2 |
3 | import torch
4 | import logging
5 | import os
6 | import json
7 | from torch.utils.data import TensorDataset
8 | from collections import defaultdict
9 |
10 | from .utils import default_load_json, default_dump_json, EPS, BERTChineseCharacterTokenizer
11 | from .event_type import common_fields, event_type_fields_list
12 | from .ner_model import BertForBasicNER, judge_ner_prediction
13 | from .base_task import TaskSetting, BasePytorchTask
14 |
15 | logger = logging.getLogger(__name__)
16 |
17 | class NERExample(object):
18 | basic_entity_label = 'O'
19 |
20 | def __init__(self, guid, text, entity_range_span_types):
21 | self.guid = guid
22 | self.text = text
23 | self.num_chars = len(text)
24 | self.entity_range_span_types = sorted(entity_range_span_types, key=lambda x: x[0])
25 |
26 | def get_char_entity_labels(self):
27 | char_entity_labels = []
28 | char_idx = 0
29 | ent_idx = 0
30 | while ent_idx < len(self.entity_range_span_types):
31 | (ent_cid_s, ent_cid_e), ent_span, ent_type = self.entity_range_span_types[ent_idx]
32 | assert ent_cid_s < ent_cid_e <= self.num_chars
33 |
34 | if ent_cid_s > char_idx:
35 | char_entity_labels.append(NERExample.basic_entity_label)
36 | char_idx += 1
37 | elif ent_cid_s == char_idx:
38 | # tmp_ent_labels = [ent_type] * (ent_cid_e - ent_cid_s)
39 | tmp_ent_labels = ['B-' + ent_type] + ['I-' + ent_type] * (ent_cid_e - ent_cid_s - 1)
40 | char_entity_labels.extend(tmp_ent_labels)
41 | char_idx = ent_cid_e
42 | ent_idx += 1
43 | else:
44 | logger.error('Example GUID {}'.format(self.guid))
45 | logger.error('NER conflicts at char_idx {}, ent_cid_s {}'.format(char_idx, ent_cid_s))
46 | logger.error(self.text[char_idx - 20:char_idx + 20])
47 | logger.error(self.entity_range_span_types[ent_idx - 1:ent_idx + 1])
48 | raise Exception('Unexpected logic error')
49 |
50 | char_entity_labels.extend([NERExample.basic_entity_label] * (self.num_chars - char_idx))
51 | assert len(char_entity_labels) == self.num_chars
52 |
53 | return char_entity_labels
54 |
55 | @staticmethod
56 | def get_entity_label_list():
57 | visit_set = set()
58 | entity_label_list = [NERExample.basic_entity_label]
59 |
60 | for field in common_fields:
61 | if field not in visit_set:
62 | visit_set.add(field)
63 | entity_label_list.extend(['B-' + field, 'I-' + field])
64 |
65 | for event_name, fields in event_type_fields_list:
66 | for field in fields:
67 | if field not in visit_set:
68 | visit_set.add(field)
69 | entity_label_list.extend(['B-' + field, 'I-' + field])
70 |
71 | return entity_label_list
72 |
73 | def __repr__(self):
74 |         ex_str = 'NERExample(guid={}, text={}, entity_info={})'.format(
75 | self.guid, self.text, str(self.entity_range_span_types)
76 | )
77 | return ex_str
78 |
79 | def load_ner_dataset(dataset_json_path):
80 | total_ner_examples = []
81 | annguid2detail_align_info = default_load_json(dataset_json_path)
82 | for annguid, detail_align_info in annguid2detail_align_info.items():
83 | sents = detail_align_info['sentences']
84 | ann_valid_mspans = detail_align_info['ann_valid_mspans']
85 | ann_valid_dranges = detail_align_info['ann_valid_dranges']
86 | ann_mspan2guess_field = detail_align_info['ann_mspan2guess_field']
87 | assert len(ann_valid_dranges) == len(ann_valid_mspans)
88 |
89 | sent_idx2mrange_mspan_mfield_tuples = {}
90 | for drange, mspan in zip(ann_valid_dranges, ann_valid_mspans):
91 | sent_idx, char_s, char_e = drange
92 | sent_mrange = (char_s, char_e)
93 |
94 | sent_text = sents[sent_idx]
95 | assert sent_text[char_s: char_e] == mspan
96 |
97 | guess_field = ann_mspan2guess_field[mspan]
98 |
99 | if sent_idx not in sent_idx2mrange_mspan_mfield_tuples:
100 | sent_idx2mrange_mspan_mfield_tuples[sent_idx] = []
101 | sent_idx2mrange_mspan_mfield_tuples[sent_idx].append((sent_mrange, mspan, guess_field))
102 |
103 | for sent_idx in range(len(sents)):
104 | sent_text = sents[sent_idx]
105 | if sent_idx in sent_idx2mrange_mspan_mfield_tuples:
106 | mrange_mspan_mfield_tuples = sent_idx2mrange_mspan_mfield_tuples[sent_idx]
107 | else:
108 | mrange_mspan_mfield_tuples = []
109 |
110 | total_ner_examples.append(
111 | NERExample('{}-{}'.format(annguid, sent_idx),
112 | sent_text,
113 | mrange_mspan_mfield_tuples)
114 | )
115 |
116 | return total_ner_examples
117 |
118 | class NERFeature(object):
119 | def __init__(self, input_ids, input_masks, segment_ids, label_ids, seq_len=None):
120 | self.input_ids = input_ids
121 | self.input_masks = input_masks
122 | self.segment_ids = segment_ids
123 | self.label_ids = label_ids
124 | self.seq_len = seq_len
125 |
126 | def __repr__(self):
127 | fea_strs = ['NERFeature(real_seq_len={}'.format(self.seq_len), ]
128 | info_template = ' {:5} {:9} {:5} {:7} {:7}'
129 | fea_strs.append(info_template.format(
130 | 'index', 'input_ids', 'masks', 'seg_ids', 'lbl_ids'
131 | ))
132 | max_print_len = 10
133 | idx = 0
134 | for tid, mask, segid, lid in zip(
135 | self.input_ids, self.input_masks, self.segment_ids, self.label_ids):
136 | fea_strs.append(info_template.format(
137 | idx, tid, mask, segid, lid
138 | ))
139 | idx += 1
140 | if idx >= max_print_len:
141 | break
142 | fea_strs.append(info_template.format(
143 | '...', '...', '...', '...', '...'
144 | ))
145 | fea_strs.append(')')
146 |
147 | fea_str = '\n'.join(fea_strs)
148 | return fea_str
149 |
150 | class NERFeatureConverter(object):
151 | def __init__(self, entity_label_list, max_seq_len, tokenizer, include_cls=True, include_sep=True):
152 | self.entity_label_list = entity_label_list
153 | self.max_seq_len = max_seq_len # used to normalize sequence length
154 | self.tokenizer = tokenizer
155 | self.entity_label2index = { # for entity label to label index mapping
156 | entity_label: idx for idx, entity_label in enumerate(self.entity_label_list)
157 | }
158 |
159 | self.include_cls = include_cls
160 | self.include_sep = include_sep
161 |
162 | # used to track how many examples have been truncated
163 | self.truncate_count = 0
164 | # used to track the maximum length of input sentences
165 | self.data_max_seq_len = -1
166 |
167 | def convert_example_to_feature(self, ner_example, log_flag=False):
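    |         # illustration: with max_seq_len=8 and CLS/SEP included, a 3-char sentence 'ABC'
    |         # becomes [CLS] A B C [SEP] plus three pads, with masks [1,1,1,1,1,0,0,0]
    |         # and label ids padded with 0 (the index of the basic 'O' label)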
168 | ex_tokens = self.tokenizer.char_tokenize(ner_example.text)
169 | ex_entity_labels = ner_example.get_char_entity_labels()
170 |
171 | assert len(ex_tokens) == len(ex_entity_labels)
172 |
173 | # get valid token sequence length
174 | valid_token_len = self.max_seq_len
175 | if self.include_cls:
176 | valid_token_len -= 1
177 | if self.include_sep:
178 | valid_token_len -= 1
179 |
180 | # truncate according to max_seq_len and record some statistics
181 | self.data_max_seq_len = max(self.data_max_seq_len, len(ex_tokens))
182 | if len(ex_tokens) > valid_token_len:
183 | ex_tokens = ex_tokens[:valid_token_len]
184 | ex_entity_labels = ex_entity_labels[:valid_token_len]
185 |
186 | self.truncate_count += 1
187 |
188 | basic_label_index = self.entity_label2index[NERExample.basic_entity_label]
189 |
190 | # add bert-specific token
191 | if self.include_cls:
192 | fea_tokens = ['[CLS]']
193 | fea_token_labels = [NERExample.basic_entity_label]
194 | fea_label_ids = [basic_label_index]
195 | else:
196 | fea_tokens = []
197 | fea_token_labels = []
198 | fea_label_ids = []
199 |
200 | for token, ent_label in zip(ex_tokens, ex_entity_labels):
201 | fea_tokens.append(token)
202 | fea_token_labels.append(ent_label)
203 |
204 | if ent_label in self.entity_label2index:
205 | fea_label_ids.append(self.entity_label2index[ent_label])
206 | else:
207 | fea_label_ids.append(basic_label_index)
208 |
209 | if self.include_sep:
210 | fea_tokens.append('[SEP]')
211 | fea_token_labels.append(NERExample.basic_entity_label)
212 | fea_label_ids.append(basic_label_index)
213 |
214 | assert len(fea_tokens) == len(fea_token_labels) == len(fea_label_ids) <= self.max_seq_len
215 |
216 | fea_input_ids = self.tokenizer.convert_tokens_to_ids(fea_tokens)
217 | fea_seq_len = len(fea_input_ids)
218 | fea_segment_ids = [0] * fea_seq_len
219 | fea_masks = [1] * fea_seq_len
220 |
221 | # feature is padded to max_seq_len, but fea_seq_len is the real length
222 | while len(fea_input_ids) < self.max_seq_len:
223 | fea_input_ids.append(0)
224 | fea_label_ids.append(0)
225 | fea_masks.append(0)
226 | fea_segment_ids.append(0)
227 |
228 | assert len(fea_input_ids) == len(fea_label_ids) == len(fea_masks) == len(fea_segment_ids) == self.max_seq_len
229 |
230 | if log_flag:
231 | logger.info("*** Example ***")
232 | logger.info("guid: %s" % ner_example.guid)
233 | info_template = '{:8} {:4} {:2} {:2} {:2} {}'
234 | logger.info(info_template.format(
235 | 'TokenId', 'Token', 'Mask', 'SegId', 'LabelId', 'Label'
236 | ))
237 | for tid, token, mask, segid, lid, label in zip(
238 | fea_input_ids, fea_tokens, fea_masks,
239 | fea_segment_ids, fea_label_ids, fea_token_labels):
240 | logger.info(info_template.format(
241 | tid, token, mask, segid, lid, label
242 | ))
243 | if len(fea_input_ids) > len(fea_tokens):
244 | sid = len(fea_tokens)
245 | logger.info(info_template.format(
246 | fea_input_ids[sid], '[PAD]', fea_masks[sid], fea_segment_ids[sid], fea_label_ids[sid], 'O')
247 | + ' x {}'.format(len(fea_input_ids) - len(fea_tokens)))
248 |
249 | return NERFeature(fea_input_ids, fea_masks, fea_segment_ids, fea_label_ids, seq_len=fea_seq_len)
250 |
251 | def __call__(self, ner_examples, log_example_num=0):
252 | """Convert examples to features suitable for ner models"""
253 | self.truncate_count = 0
254 | self.data_max_seq_len = -1
255 | ner_features = []
256 |
257 | for ex_index, ner_example in enumerate(ner_examples):
258 | if ex_index < log_example_num:
259 | ner_feature = self.convert_example_to_feature(ner_example, log_flag=True)
260 | else:
261 | ner_feature = self.convert_example_to_feature(ner_example, log_flag=False)
262 |
263 | ner_features.append(ner_feature)
264 |
265 |         logger.info('{} examples in total, {} truncated examples, max_sent_len={}'.format(
266 | len(ner_examples), self.truncate_count, self.data_max_seq_len
267 | ))
268 |
269 | return ner_features
270 |
271 | def convert_ner_features_to_dataset(ner_features):
272 | all_input_ids = torch.tensor([f.input_ids for f in ner_features], dtype=torch.long)
273 | # very important to use the mask type of uint8 to support advanced indexing
274 | all_input_masks = torch.tensor([f.input_masks for f in ner_features], dtype=torch.uint8)
275 | all_segment_ids = torch.tensor([f.segment_ids for f in ner_features], dtype=torch.long)
276 | all_label_ids = torch.tensor([f.label_ids for f in ner_features], dtype=torch.long)
277 | all_seq_len = torch.tensor([f.seq_len for f in ner_features], dtype=torch.long)
278 | ner_tensor_dataset = TensorDataset(all_input_ids, all_input_masks, all_segment_ids, all_label_ids, all_seq_len)
279 |
280 | return ner_tensor_dataset
281 |
282 | class NERTaskSetting(TaskSetting):
283 | def __init__(self, **kwargs):
284 | ner_key_attrs = []
285 | ner_attr_default_pairs = [
286 | ('bert_model', 'bert-base-chinese'),
287 | ('train_file_name', 'train.json'),
288 | ('dev_file_name', 'dev.json'),
289 | ('test_file_name', 'test.json'),
290 | ('max_seq_len', 128),
291 | ('train_batch_size', 32),
292 | ('eval_batch_size', 256),
293 | ('learning_rate', 2e-5),
294 | ('num_train_epochs', 3.0),
295 | ('warmup_proportion', 0.1),
296 | ('no_cuda', False),
297 | ('local_rank', -1),
298 | ('seed', 99),
299 | ('gradient_accumulation_steps', 1),
300 | ('optimize_on_cpu', True),
301 | ('fp16', False),
302 | ('loss_scale', 128),
303 | ('cpt_file_name', 'ner_task.cpt'),
304 | ('summary_dir_name', '/tmp/summary'),
305 | ]
306 | super(NERTaskSetting, self).__init__(ner_key_attrs, ner_attr_default_pairs, **kwargs)
307 |
308 | class NERTask(BasePytorchTask):
309 | """Named Entity Recognition Task"""
310 |
311 | def __init__(self, setting,
312 | load_train=True, load_dev=True, load_test=True,
313 | build_model=True, parallel_decorate=True,
314 | resume_model=False, resume_optimizer=False):
315 | super(NERTask, self).__init__(setting)
316 | self.logger = logging.getLogger(self.__class__.__name__)
317 | self.logging('Initializing {}'.format(self.__class__.__name__))
318 |
319 | # initialize entity label list
320 | self.entity_label_list = NERExample.get_entity_label_list()
321 | # initialize tokenizer
322 | self.tokenizer = BERTChineseCharacterTokenizer.from_pretrained(self.setting.bert_model)
323 | # initialize feature converter
324 | self.feature_converter_func = NERFeatureConverter(
325 | self.entity_label_list, self.setting.max_seq_len, self.tokenizer
326 | )
327 |
328 | # load data
329 | self._load_data(
330 | load_ner_dataset, self.feature_converter_func, convert_ner_features_to_dataset,
331 | load_train=load_train, load_dev=load_dev, load_test=load_test
332 | )
333 |
334 | # build model
335 | if build_model:
336 | self.model = BertForBasicNER.from_pretrained(self.setting.bert_model, len(self.entity_label_list))
337 | self.setting.update_by_dict(self.model.config.__dict__) # BertConfig dictionary
338 | self._decorate_model(parallel_decorate=parallel_decorate)
339 |
340 | # prepare optimizer
341 | if build_model and load_train:
342 | self._init_bert_optimizer()
343 |
344 | # resume option
345 | if build_model and (resume_model or resume_optimizer):
346 | self.resume_checkpoint(resume_model=resume_model, resume_optimizer=resume_optimizer)
347 |
348 | self.logging('Successfully initialize {}'.format(self.__class__.__name__))
349 |
350 | def reload_data(self, data_type='return', file_name=None, file_path=None, **kwargs):
351 |         """Either file_name or file_path needs to be provided.
352 |         data_type: 'return' (default), return (examples, features, dataset);
353 |                    'train', override self.train_xxx;
354 |                    'dev', override self.dev_xxx;
355 |                    'test', override self.test_xxx
356 |         """
357 | return super(NERTask, self).reload_data(
358 | load_ner_dataset, self.feature_converter_func, convert_ner_features_to_dataset,
359 | data_type=data_type, file_name=file_name, file_path=file_path,
360 | )
361 |
362 | def train(self):
363 | self.logging('='*20 + 'Start Training' + '='*20)
364 | self.base_train(get_ner_loss_on_batch)
365 |
366 | def eval(self, eval_dataset, eval_save_prefix='', pgm_return_flag=False):
367 | self.logging('='*20 + 'Start Evaluation' + '='*20)
368 | # 1. get total prediction info
369 | # pgm denotes (pred_label, gold_label, token_mask)
370 | # size = [num_examples, max_seq_len, 3]
371 | # value = [[(pred_label, gold_label, token_mask), ...], ...]
372 | total_seq_pgm = self.get_total_prediction(eval_dataset)
373 | num_examples, max_seq_len, _ = total_seq_pgm.size()
374 |
375 | # 2. collect per-entity-label tp, fp, fn counts
376 | ent_lid2tp_cnt = defaultdict(lambda: 0)
377 | ent_lid2fp_cnt = defaultdict(lambda: 0)
378 | ent_lid2fn_cnt = defaultdict(lambda: 0)
379 | for bid in range(num_examples):
380 | seq_pgm = total_seq_pgm[bid] # [max_seq_len, 3]
381 | seq_pred = seq_pgm[:, 0] # [max_seq_len]
382 | seq_gold = seq_pgm[:, 1]
383 | seq_mask = seq_pgm[:, 2]
384 |
385 | seq_pred_lid = seq_pred[seq_mask == 1] # [seq_len]
386 | seq_gold_lid = seq_gold[seq_mask == 1]
387 | ner_tp_set, ner_fp_set, ner_fn_set = judge_ner_prediction(seq_pred_lid, seq_gold_lid)
388 | for ent_lid2cnt, ex_ner_set in [
389 | (ent_lid2tp_cnt, ner_tp_set),
390 | (ent_lid2fp_cnt, ner_fp_set),
391 | (ent_lid2fn_cnt, ner_fn_set)
392 | ]:
393 | for ent_idx_s, ent_idx_e, ent_lid in ex_ner_set:
394 | ent_lid2cnt[ent_lid] += 1
395 |
396 | # 3. calculate per-entity-label metrics and collect global counts
397 | ent_label_eval_infos = []
398 | g_ner_tp_cnt = 0
399 | g_ner_fp_cnt = 0
400 | g_ner_fn_cnt = 0
401 | # Entity Label Id, 0 for others, odd for BEGIN-ENTITY, even for INSIDE-ENTITY
402 | # using odd is enough to represent the entity type
403 | for ent_lid in range(1, len(self.entity_label_list), 2):
404 | el_name = self.entity_label_list[ent_lid]
405 | el_tp_cnt, el_fp_cnt, el_fn_cnt = ent_lid2tp_cnt[ent_lid], ent_lid2fp_cnt[ent_lid], ent_lid2fn_cnt[ent_lid]
406 |
407 | el_pred_cnt = el_tp_cnt + el_fp_cnt
408 | el_gold_cnt = el_tp_cnt + el_fn_cnt
409 | el_prec = el_tp_cnt / el_pred_cnt if el_pred_cnt > 0 else 0
410 | el_recall = el_tp_cnt / el_gold_cnt if el_gold_cnt > 0 else 0
411 | el_f1 = 2 / (1 / el_prec + 1 / el_recall) if el_prec > EPS and el_recall > EPS else 0
412 |
413 | # per-entity-label evaluation info
414 | el_eval_info = {
415 | 'entity_label_indexes': (ent_lid, ent_lid + 1),
416 | 'entity_label': el_name[2:], # omit 'B-' prefix
417 | 'ner_tp_cnt': el_tp_cnt,
418 | 'ner_fp_cnt': el_fp_cnt,
419 | 'ner_fn_cnt': el_fn_cnt,
420 | 'ner_prec': el_prec,
421 | 'ner_recall': el_recall,
422 | 'ner_f1': el_f1,
423 | }
424 | ent_label_eval_infos.append(el_eval_info)
425 |
426 | # collect global count info
427 | g_ner_tp_cnt += el_tp_cnt
428 | g_ner_fp_cnt += el_fp_cnt
429 | g_ner_fn_cnt += el_fn_cnt
430 |
431 | # 4. summarize total evaluation info
432 | g_ner_pred_cnt = g_ner_tp_cnt + g_ner_fp_cnt
433 | g_ner_gold_cnt = g_ner_tp_cnt + g_ner_fn_cnt
434 | g_ner_prec = g_ner_tp_cnt / g_ner_pred_cnt if g_ner_pred_cnt > 0 else 0
435 | g_ner_recall = g_ner_tp_cnt / g_ner_gold_cnt if g_ner_gold_cnt > 0 else 0
436 | g_ner_f1 = 2 / (1 / g_ner_prec + 1 / g_ner_recall) if g_ner_prec > EPS and g_ner_recall > EPS else 0
437 |
438 | total_eval_info = {
439 | 'eval_name': eval_save_prefix,
440 | 'num_examples': num_examples,
441 | 'ner_tp_cnt': g_ner_tp_cnt,
442 | 'ner_fp_cnt': g_ner_fp_cnt,
443 | 'ner_fn_cnt': g_ner_fn_cnt,
444 | 'ner_prec': g_ner_prec,
445 | 'ner_recall': g_ner_recall,
446 | 'ner_f1': g_ner_f1,
447 | 'per_ent_label_eval': ent_label_eval_infos
448 | }
449 |
450 | self.logging('Evaluation Results\n{:.300s} ...'.format(json.dumps(total_eval_info, indent=4)))
451 |
452 | if eval_save_prefix:
453 | eval_res_fp = os.path.join(self.setting.output_dir,
454 | '{}.eval'.format(eval_save_prefix))
455 | self.logging('Dump eval results into {}'.format(eval_res_fp))
456 | default_dump_json(total_eval_info, eval_res_fp)
457 |
458 | if pgm_return_flag:
459 | return total_seq_pgm
460 | else:
461 | return total_eval_info
462 |
463 | def get_total_prediction(self, eval_dataset):
464 | self.logging('='*20 + 'Get Total Prediction' + '='*20)
465 | total_pred_gold_mask = self.base_eval(
466 | eval_dataset, get_ner_pred_on_batch, reduce_info_type='none'
467 | )
468 | # torch.Tensor(dtype=torch.long, device='cpu')
469 | # size = [batch_size, seq_len, 3]
470 | # value = [[(pred_label, gold_label, token_mask), ...], ...]
471 | return total_pred_gold_mask
472 |
473 | def normalize_batch_seq_len(input_seq_lens, *batch_seq_tensors):
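    |     # illustration: if the longest real sequence in the batch is 5 tokens while
    |     # features are padded to max_seq_len=128, every 2-D tensor is sliced to [:, :5];
    |     # 1-D tensors (e.g. the sequence lengths themselves) pass through unchanged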
474 | batch_max_seq_len = input_seq_lens.max().item()
475 | normed_tensors = []
476 | for batch_seq_tensor in batch_seq_tensors:
477 | if batch_seq_tensor.dim() == 2:
478 | normed_tensors.append(batch_seq_tensor[:, :batch_max_seq_len])
479 | elif batch_seq_tensor.dim() == 1:
480 | normed_tensors.append(batch_seq_tensor)
481 | else:
482 | raise Exception('Unsupported batch_seq_tensor dimension {}'.format(batch_seq_tensor.dim()))
483 |
484 | return normed_tensors
485 |
486 | def prepare_ner_batch(batch, resize_len=True):
487 | # prepare batch
488 | input_ids, input_masks, segment_ids, label_ids, input_lens = batch
489 | if resize_len:
490 | input_ids, input_masks, segment_ids, label_ids = normalize_batch_seq_len(
491 | input_lens, input_ids, input_masks, segment_ids, label_ids
492 | )
493 |
494 | return input_ids, input_masks, segment_ids, label_ids
495 |
496 | def get_ner_loss_on_batch(ner_task, batch):
497 | input_ids, input_masks, segment_ids, label_ids = prepare_ner_batch(batch, resize_len=True)
498 | loss, _ = ner_task.model(input_ids, input_masks,
499 | token_type_ids=segment_ids,
500 | label_ids=label_ids)
501 |
502 | return loss
503 |
504 | def get_ner_metrics_on_batch(ner_task, batch):
505 | input_ids, input_masks, segment_ids, label_ids = prepare_ner_batch(batch, resize_len=True)
506 | batch_metrics = ner_task.model(input_ids, input_masks,
507 | token_type_ids=segment_ids,
508 | label_ids=label_ids,
509 | eval_flag=True,
510 | eval_for_metric=True)
511 |
512 | return batch_metrics
513 |
514 | def get_ner_pred_on_batch(ner_task, batch):
515 | # important to set resize_len to False to maintain the same seq len between batches
516 | input_ids, input_masks, segment_ids, label_ids = prepare_ner_batch(batch, resize_len=False)
517 | batch_seq_pred_gold_mask = ner_task.model(input_ids, input_masks,
518 | token_type_ids=segment_ids,
519 | label_ids=label_ids,
520 | eval_flag=True,
521 | eval_for_metric=False)
522 | # size = [batch_size, max_seq_len, 3]
523 | # value = [[(pred_label, gold_label, token_mask), ...], ...]
524 | return batch_seq_pred_gold_mask
525 |
--------------------------------------------------------------------------------
/dee/transformer.py:
--------------------------------------------------------------------------------
1 | # Code Reference: (https://github.com/dolphin-zs/Doc2EDAG)
2 |
3 | import copy
4 | import torch
5 | import torch.nn as nn
6 | import torch.nn.functional as F
7 | import numpy as np
8 | import math
9 |
10 |
11 | def clones(module, N):
12 | """Produce N identical layers."""
13 | return nn.ModuleList([copy.deepcopy(module) for _ in range(N)])
14 |
15 | class EncoderDecoder(nn.Module):
16 | """
17 | A standard Encoder-Decoder architecture. Base for this and many
18 | other models.
19 | """
20 |
21 | def __init__(self, encoder, decoder, src_embed, tgt_embed, generator):
22 | super(EncoderDecoder, self).__init__()
23 | self.encoder = encoder
24 | self.decoder = decoder
25 | self.src_embed = src_embed
26 | self.tgt_embed = tgt_embed
27 | self.generator = generator
28 |
29 | def forward(self, src, tgt, src_mask, tgt_mask):
30 | """Take in and process masked src and target sequences."""
31 | return self.decode(self.encode(src, src_mask), src_mask,
32 | tgt, tgt_mask)
33 |
34 | def encode(self, src, src_mask):
35 | return self.encoder(self.src_embed(src), src_mask)
36 |
37 | def decode(self, memory, src_mask, tgt, tgt_mask):
38 | return self.decoder(self.tgt_embed(tgt), memory, src_mask, tgt_mask)
39 |
40 | class Generator(nn.Module):
41 | """Define standard linear + softmax generation step."""
42 | def __init__(self, d_model, vocab):
43 | super(Generator, self).__init__()
44 | self.proj = nn.Linear(d_model, vocab)
45 |
46 | def forward(self, x):
47 | return F.log_softmax(self.proj(x), dim=-1)
48 |
49 | class LayerNorm(nn.Module):
50 | """Construct a layernorm module (See citation for details)."""
51 |
52 | def __init__(self, features, eps=1e-6):
53 | super(LayerNorm, self).__init__()
54 | # self.a_2 = nn.Parameter(torch.ones(features))
55 | # self.b_2 = nn.Parameter(torch.zeros(features))
56 | # fit for bert optimizer
57 | self.gamma = nn.Parameter(torch.ones(features))
58 | self.beta = nn.Parameter(torch.zeros(features))
59 | self.eps = eps
60 |
61 | def forward(self, x):
62 | mean = x.mean(-1, keepdim=True)
63 | std = x.std(-1, keepdim=True)
64 | # return self.a_2 * (x - mean) / (std + self.eps) + self.b_2
65 | return self.gamma * (x - mean) / (std + self.eps) + self.beta
66 |
67 |
68 | class Encoder(nn.Module):
69 |     """Core encoder is a stack of N layers"""
70 |
71 | def __init__(self, layer, N):
72 | super(Encoder, self).__init__()
73 | self.layers = clones(layer, N)
74 | self.norm = LayerNorm(layer.size)
75 |
76 | def forward(self, x, mask):
77 | """Pass the input (and mask) through each layer in turn."""
78 | for layer in self.layers:
79 | x = layer(x, mask)
80 | return self.norm(x)
81 |
82 | class SublayerConnection(nn.Module):
83 | """
84 | A residual connection followed by a layer norm.
85 | Note for code simplicity the norm is first as opposed to last.
86 | """
87 | def __init__(self, size, dropout):
88 | super(SublayerConnection, self).__init__()
89 | self.norm = LayerNorm(size)
90 | self.dropout = nn.Dropout(dropout)
91 |
92 | def forward(self, x, sublayer):
93 | """Apply residual connection to any sublayer with the same size."""
94 | return x + self.dropout(sublayer(self.norm(x)))
95 |
96 | class EncoderLayer(nn.Module):
97 | """Encoder is made up of self-attn and feed forward (defined below)"""
98 |
99 | def __init__(self, size, self_attn, feed_forward, dropout):
100 | super(EncoderLayer, self).__init__()
101 | self.self_attn = self_attn
102 | self.feed_forward = feed_forward
103 | self.sublayer = clones(SublayerConnection(size, dropout), 2)
104 | self.size = size
105 |
106 | def forward(self, x, mask):
107 | """Follow Figure 1 (left) for connections."""
108 | x = self.sublayer[0](x, lambda x: self.self_attn(x, x, x, mask))
109 | return self.sublayer[1](x, self.feed_forward)
110 |
111 | class Decoder(nn.Module):
112 | """Generic N layer decoder with masking."""
113 |
114 | def __init__(self, layer, N):
115 | super(Decoder, self).__init__()
116 | self.layers = clones(layer, N)
117 | self.norm = LayerNorm(layer.size)
118 |
119 | def forward(self, x, memory, src_mask, tgt_mask):
120 | for layer in self.layers:
121 | x = layer(x, memory, src_mask, tgt_mask)
122 | return self.norm(x)
123 |
124 | class DecoderLayer(nn.Module):
125 | """Decoder is made of self-attn, src-attn, and feed forward (defined below)"""
126 |
127 | def __init__(self, size, self_attn, src_attn, feed_forward, dropout):
128 | super(DecoderLayer, self).__init__()
129 | self.size = size
130 | self.self_attn = self_attn
131 | self.src_attn = src_attn
132 | self.feed_forward = feed_forward
133 | self.sublayer = clones(SublayerConnection(size, dropout), 3)
134 |
135 | def forward(self, x, memory, src_mask, tgt_mask):
136 | """Follow Figure 1 (right) for connections."""
137 | m = memory
138 | x = self.sublayer[0](x, lambda x: self.self_attn(x, x, x, tgt_mask))
139 | x = self.sublayer[1](x, lambda x: self.src_attn(x, m, m, src_mask))
140 | return self.sublayer[2](x, self.feed_forward)
141 |
142 | def subsequent_mask(size):
143 | """Mask out subsequent positions."""
144 | attn_shape = (1, size, size)
145 | subseq_mask = np.triu(np.ones(attn_shape), k=1).astype('uint8')
146 | return torch.from_numpy(subseq_mask) == 0
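    | # example: subsequent_mask(3)[0] is the lower-triangular boolean matrix
    | # [[True, False, False],
    | #  [True, True,  False],
    | #  [True, True,  True ]], i.e. position i may only attend to positions <= i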
147 |
148 | def attention(query, key, value, mask=None, dropout=None):
149 | """Compute 'Scaled Dot Product Attention'"""
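    |     # Attention(Q, K, V) = softmax(Q K^T / sqrt(d_k)) V   (Vaswani et al., 2017)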
150 | d_k = query.size(-1)
151 | scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k)
152 | if mask is not None:
153 | scores = scores.masked_fill(mask == 0, -1e9)
154 | p_attn = F.softmax(scores, dim=-1)
155 | if dropout is not None:
156 | p_attn = dropout(p_attn)
157 | return torch.matmul(p_attn, value), p_attn
158 |
159 | class MultiHeadedAttention(nn.Module):
160 | def __init__(self, h, d_model, dropout=0.1):
161 | """Take in model size and number of heads."""
162 | super(MultiHeadedAttention, self).__init__()
163 | assert d_model % h == 0
164 | # We assume d_v always equals d_k
165 | self.d_k = d_model // h
166 | self.h = h
167 | self.linears = clones(nn.Linear(d_model, d_model), 4)
168 | self.attn = None
169 | self.dropout = nn.Dropout(p=dropout)
170 |
171 | def forward(self, query, key, value, mask=None):
172 | """Implements Figure 2"""
173 | if mask is not None:
174 | # Same mask applied to all h heads.
175 | mask = mask.unsqueeze(1)
176 | nbatches = query.size(0)
177 |
178 | # 1) Do all the linear projections in batch from d_model => h x d_k
179 | query, key, value = [
180 | l(x).view(nbatches, -1, self.h, self.d_k).transpose(1, 2)
181 | for l, x in zip(self.linears, (query, key, value))
182 | ]
183 |
184 | # 2) Apply attention on all the projected vectors in batch.
185 | x, self.attn = attention(query, key, value, mask=mask,
186 | dropout=self.dropout)
187 |
188 | # 3) "Concat" using a view and apply a final linear.
189 | x = x.transpose(1, 2).contiguous() \
190 | .view(nbatches, -1, self.h * self.d_k)
191 |
192 | return self.linears[-1](x)
193 |
194 | class PositionwiseFeedForward(nn.Module):
195 | """Implements FFN equation."""
196 |
197 | def __init__(self, d_model, d_ff, dropout=0.1):
198 | super(PositionwiseFeedForward, self).__init__()
199 | self.w_1 = nn.Linear(d_model, d_ff)
200 | self.w_2 = nn.Linear(d_ff, d_model)
201 | self.dropout = nn.Dropout(dropout)
202 |
203 | def forward(self, x):
204 | return self.w_2(self.dropout(F.relu(self.w_1(x))))
205 |
206 | class Embeddings(nn.Module):
207 | def __init__(self, d_model, vocab):
208 | super(Embeddings, self).__init__()
209 | self.lut = nn.Embedding(vocab, d_model)
210 | self.d_model = d_model
211 |
212 | def forward(self, x):
213 | return self.lut(x) * math.sqrt(self.d_model)
214 |
215 | class PositionalEncoding(nn.Module):
216 | """Implement the PE function."""
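    |     # PE(pos, 2i)   = sin(pos / 10000^(2i / d_model))
    |     # PE(pos, 2i+1) = cos(pos / 10000^(2i / d_model))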
217 |
218 | def __init__(self, d_model, dropout, max_len=5000):
219 | super(PositionalEncoding, self).__init__()
220 | self.dropout = nn.Dropout(p=dropout)
221 |
222 | # Compute the positional encodings once in log space.
223 | pe = torch.zeros(max_len, d_model)
224 | position = torch.arange(0, max_len).unsqueeze(1)
225 | div_term = torch.exp(torch.arange(0, d_model, 2) *
226 | -(math.log(10000.0) / d_model))
227 | pe[:, 0::2] = torch.sin(position * div_term)
228 | pe[:, 1::2] = torch.cos(position * div_term)
229 | pe = pe.unsqueeze(0)
230 | self.register_buffer('pe', pe)
231 |
232 | def forward(self, x):
233 | x = x + self.pe[:, :x.size(1)].to(device=x.device)
234 | return self.dropout(x)
235 |
236 | def make_model(src_vocab, tgt_vocab, num_layers=6, d_model=512, d_ff=2048, h=8, dropout=0.1):
237 | """Helper: Construct a model from hyperparameters."""
238 | c = copy.deepcopy
239 | attn = MultiHeadedAttention(h, d_model)
240 | ff = PositionwiseFeedForward(d_model, d_ff, dropout)
241 | position = PositionalEncoding(d_model, dropout)
242 | model = EncoderDecoder(
243 | Encoder(EncoderLayer(d_model, c(attn), c(ff), dropout), num_layers),
244 | Decoder(DecoderLayer(d_model, c(attn), c(attn), c(ff), dropout), num_layers),
245 | nn.Sequential(Embeddings(d_model, src_vocab), c(position)),
246 | nn.Sequential(Embeddings(d_model, tgt_vocab), c(position)),
247 | Generator(d_model, tgt_vocab))
248 |
249 | # This was important from their code.
250 | # Initialize parameters with Glorot / fan_avg.
251 | for p in model.parameters():
252 | if p.dim() > 1:
253 |             nn.init.xavier_uniform_(p)
254 | return model
255 |
256 | def make_transformer_encoder(num_layers, hidden_size, ff_size=2048, num_att_heads=8, dropout=0.1):
257 | dcopy = copy.deepcopy
258 | mh_att = MultiHeadedAttention(num_att_heads, hidden_size, dropout=dropout)
259 | pos_ff = PositionwiseFeedForward(hidden_size, ff_size, dropout=dropout)
260 |
261 |     transformer_encoder = Encoder(
262 | EncoderLayer(hidden_size, dcopy(mh_att), dcopy(pos_ff), dropout=dropout),
263 | num_layers
264 | )
265 |
266 |     return transformer_encoder
267 |
--------------------------------------------------------------------------------
/dee/utils.py:
--------------------------------------------------------------------------------
1 | # Code Reference: (https://github.com/dolphin-zs/Doc2EDAG)
2 |
3 | import json
4 | import logging
5 | import pickle
6 | from pytorch_pretrained_bert import BertTokenizer
7 |
8 |
9 | logger = logging.getLogger(__name__)
10 |
11 | EPS = 1e-10
12 |
13 | def default_load_json(json_file_path, encoding='utf-8', **kwargs):
14 | with open(json_file_path, 'r', encoding=encoding) as fin:
15 | tmp_json = json.load(fin, **kwargs)
16 | return tmp_json
17 |
18 | def default_dump_json(obj, json_file_path, encoding='utf-8', ensure_ascii=False, indent=2, **kwargs):
19 | with open(json_file_path, 'w', encoding=encoding) as fout:
20 | json.dump(obj, fout,
21 | ensure_ascii=ensure_ascii,
22 | indent=indent,
23 | **kwargs)
24 |
25 | def default_load_pkl(pkl_file_path, **kwargs):
26 | with open(pkl_file_path, 'rb') as fin:
27 | obj = pickle.load(fin, **kwargs)
28 |
29 | return obj
30 |
31 | def default_dump_pkl(obj, pkl_file_path, **kwargs):
32 | with open(pkl_file_path, 'wb') as fout:
33 | pickle.dump(obj, fout, **kwargs)
34 |
35 | def set_basic_log_config():
36 | logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
37 | datefmt='%Y-%m-%d %H:%M:%S',
38 | level=logging.INFO)
39 |
40 | class BERTChineseCharacterTokenizer(BertTokenizer):
41 | """Customized tokenizer for Chinese financial announcements"""
42 |
43 | def __init__(self, vocab_file, do_lower_case=True):
44 | super(BERTChineseCharacterTokenizer, self).__init__(vocab_file, do_lower_case)
45 |
46 | def char_tokenize(self, text, unk_token='[UNK]'):
47 | """perform pure character-based tokenization"""
48 | tokens = list(text)
49 | out_tokens = []
50 | for token in tokens:
51 | if token in self.vocab:
52 | out_tokens.append(token)
53 | else:
54 | out_tokens.append(unk_token)
55 |
56 | return out_tokens
57 |
58 | def recursive_print_grad_fn(grad_fn, prefix='', depth=0, max_depth=50):
59 | if depth > max_depth:
60 | return
61 | print(prefix, depth, grad_fn.__class__.__name__)
62 | if hasattr(grad_fn, 'next_functions'):
63 | for nf in grad_fn.next_functions:
64 | ngfn = nf[0]
65 | recursive_print_grad_fn(ngfn, prefix=prefix + ' ', depth=depth+1, max_depth=max_depth)
66 |
67 | def strtobool(str_val):
68 | """Convert a string representation of truth to true (1) or false (0).
69 |
70 | True values are 'y', 'yes', 't', 'true', 'on', and '1'; false values
71 | are 'n', 'no', 'f', 'false', 'off', and '0'. Raises ValueError if
72 |     'str_val' is anything else.
73 | """
74 | str_val = str_val.lower()
75 | if str_val in ('y', 'yes', 't', 'true', 'on', '1'):
76 | return True
77 | elif str_val in ('n', 'no', 'f', 'false', 'off', '0'):
78 | return False
79 | else:
80 | raise ValueError("invalid truth value %r" % (str_val,))
81 |
--------------------------------------------------------------------------------
/figs/model.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RunxinXu/GIT/3f91743656ae65c49bbfbe11a7ed8152a8b0bc20/figs/model.png
--------------------------------------------------------------------------------
/figs/result.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RunxinXu/GIT/3f91743656ae65c49bbfbe11a7ed8152a8b0bc20/figs/result.png
--------------------------------------------------------------------------------
/run_dee_task.py:
--------------------------------------------------------------------------------
1 | # Code Reference: (https://github.com/dolphin-zs/Doc2EDAG)
2 |
3 | import argparse
4 | import os
5 | import torch.distributed as dist
6 |
7 | from dee.utils import set_basic_log_config, strtobool
8 | from dee.dee_task import DEETask, DEETaskSetting
9 | from dee.dee_helper import aggregate_task_eval_info, print_total_eval_info, print_single_vs_multi_performance
10 |
11 | set_basic_log_config()
12 |
13 |
14 | def parse_args(in_args=None):
15 | arg_parser = argparse.ArgumentParser()
16 | arg_parser.add_argument('--task_name', type=str, required=True,
17 |                             help='Task name')
18 | arg_parser.add_argument('--data_dir', type=str, default='./Data',
19 | help='Data directory')
20 | arg_parser.add_argument('--exp_dir', type=str, default='./Exps',
21 | help='Experiment directory')
22 | arg_parser.add_argument('--save_cpt_flag', type=strtobool, default=True,
23 | help='Whether to save cpt for each epoch')
24 | arg_parser.add_argument('--skip_train', type=strtobool, default=False,
25 | help='Whether to skip training')
26 | arg_parser.add_argument('--eval_model_names', type=str, default='GIT',
27 | help="Models to be evaluated")
28 | arg_parser.add_argument('--re_eval_flag', type=strtobool, default=False,
29 | help='Whether to re-evaluate previous predictions')
30 |
31 | # add task setting arguments
32 | for key, val in DEETaskSetting.base_attr_default_pairs:
33 | if isinstance(val, bool):
34 | arg_parser.add_argument('--' + key, type=strtobool, default=val)
35 | else:
36 | arg_parser.add_argument('--'+key, type=type(val), default=val)
37 |
38 | arg_info = arg_parser.parse_args(args=in_args)
39 |
40 | return arg_info
41 |
42 |
43 | if __name__ == '__main__':
44 | in_argv = parse_args()
45 |
46 | task_dir = os.path.join(in_argv.exp_dir, in_argv.task_name)
47 | if not os.path.exists(task_dir):
48 | os.makedirs(task_dir, exist_ok=True)
49 |
50 | in_argv.model_dir = os.path.join(task_dir, "Model")
51 | in_argv.output_dir = os.path.join(task_dir, "Output")
52 |
53 | # in_argv must contain 'data_dir', 'model_dir', 'output_dir'
54 | dee_setting = DEETaskSetting(
55 | **in_argv.__dict__
56 | )
57 | dee_setting.summary_dir_name = os.path.join(task_dir, "Summary")
58 |
59 | # build task
60 | dee_task = DEETask(dee_setting, load_train=not in_argv.skip_train)
61 |
62 | if not in_argv.skip_train:
63 | # dump hyper-parameter settings
64 | if dee_task.is_master_node():
65 | fn = '{}.task_setting.json'.format(dee_setting.cpt_file_name)
66 | dee_setting.dump_to(task_dir, file_name=fn)
67 |
68 | dee_task.train(save_cpt_flag=in_argv.save_cpt_flag)
69 | else:
70 | dee_task.logging('Skip training')
71 |
72 | if dee_task.is_master_node():
73 | if in_argv.re_eval_flag:
74 | data_span_type2model_str2epoch_res_list = dee_task.reevaluate_dee_prediction(dump_flag=True)
75 | else:
76 | data_span_type2model_str2epoch_res_list = aggregate_task_eval_info(in_argv.output_dir, dump_flag=True)
77 | data_type = 'test'
78 | span_type = 'pred_span'
79 | metric_type = 'micro'
80 | mstr_bepoch_list = print_total_eval_info(
81 | data_span_type2model_str2epoch_res_list, metric_type=metric_type, span_type=span_type,
82 | model_str=in_argv.eval_model_names,
83 | target_set=data_type
84 | )
85 | print_single_vs_multi_performance(
86 | mstr_bepoch_list, in_argv.output_dir, dee_task.test_features,
87 | metric_type=metric_type, data_type=data_type, span_type=span_type
88 | )
89 |
90 | # ensure every processes exit at the same time
91 | if dist.is_initialized():
92 | dist.barrier()
93 |
94 |
95 |
96 |
97 |
--------------------------------------------------------------------------------
/run_eval.sh:
--------------------------------------------------------------------------------
1 | #! /bin/bash
2 |
3 | GPU=0
4 | DATA_DIR=./Data
5 | EXP_DIR=./Exps
6 | COMMON_TASK_NAME=try
7 | EVAL_BS=2
8 | NUM_GPUS=1
9 | MODEL_STR=GIT
10 |
11 | echo "---> ${MODEL_STR} Run"
12 | CUDA_VISIBLE_DEVICES=${GPU} ./train_multi.sh ${NUM_GPUS} \
13 | --data_dir ${DATA_DIR} --exp_dir ${EXP_DIR} --task_name ${COMMON_TASK_NAME} \
14 | --eval_batch_size ${EVAL_BS} \
15 | --cpt_file_name ${MODEL_STR} \
16 | --skip_train True \
17 | --re_eval_flag False
18 |
--------------------------------------------------------------------------------
/run_train.sh:
--------------------------------------------------------------------------------
1 | #! /bin/bash
2 |
3 | GPU=0,1,2,3,4,5,6,7
4 | DATA_DIR=./Data
5 | EXP_DIR=./Exps
6 | COMMON_TASK_NAME=try
7 | RESUME_TRAIN=True
8 | SAVE_CPT=True
9 | N_EPOCH=100
10 | TRAIN_BS=64
11 | EVAL_BS=2
12 | NUM_GPUS=8
13 | GRAD_ACC_STEP=8
14 | MODEL_STR=GIT
15 |
16 | echo "---> ${MODEL_STR} Run"
17 | CUDA_VISIBLE_DEVICES=${GPU} ./train_multi.sh ${NUM_GPUS} --resume_latest_cpt ${RESUME_TRAIN} --save_cpt_flag ${SAVE_CPT} \
18 | --data_dir ${DATA_DIR} --exp_dir ${EXP_DIR} --task_name ${COMMON_TASK_NAME} --num_train_epochs ${N_EPOCH} \
19 | --train_batch_size ${TRAIN_BS} --gradient_accumulation_steps ${GRAD_ACC_STEP} --eval_batch_size ${EVAL_BS} \
20 | --cpt_file_name ${MODEL_STR}
21 |
--------------------------------------------------------------------------------
/train_multi.sh:
--------------------------------------------------------------------------------
1 | #! /bin/bash
2 |
3 | NUM_GPUS=$1
4 | shift
5 |
6 | python -m torch.distributed.launch --nproc_per_node ${NUM_GPUS} run_dee_task.py "$@"
7 |
--------------------------------------------------------------------------------